from ..tracing import tracer

-LANGCHAIN_TO_OPENLAYER_PROVIDER_MAP = {"openai-chat": "OpenAI"}
-PROVIDER_TO_STEP_NAME = {"OpenAI": "OpenAI Chat Completion"}
+LANGCHAIN_TO_OPENLAYER_PROVIDER_MAP = {"openai-chat": "OpenAI", "chat-ollama": "Ollama"}
+PROVIDER_TO_STEP_NAME = {"OpenAI": "OpenAI Chat Completion", "Ollama": "Ollama Chat Completion"}


class OpenlayerHandler(BaseCallbackHandler):
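For reference, these constants are resolved from the `_type` value that LangChain puts in `invocation_params` (see the `on_chat_model_start` hunk below). A minimal sketch of the lookup with the new entry, assuming "chat-ollama" is what a ChatOllama run reports:

# Sketch only: how the new mapping entries are resolved (values taken from this diff).
invocation_type = "chat-ollama"  # assumed value of invocation_params["_type"] for ChatOllama
provider = LANGCHAIN_TO_OPENLAYER_PROVIDER_MAP.get(invocation_type)       # -> "Ollama"
step_name = PROVIDER_TO_STEP_NAME.get(provider, "Chat Completion Model")  # -> "Ollama Chat Completion"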
@@ -45,13 +45,16 @@ def on_chat_model_start(
    ) -> Any:
        """Run when Chat Model starts running."""
        self.model_parameters = kwargs.get("invocation_params", {})
+        self.metadata = kwargs.get("metadata", {})

        provider = self.model_parameters.get("_type", None)
        if provider in LANGCHAIN_TO_OPENLAYER_PROVIDER_MAP:
            self.provider = LANGCHAIN_TO_OPENLAYER_PROVIDER_MAP[provider]
            self.model_parameters.pop("_type")
+            self.metadata.pop("ls_provider", None)
+            self.metadata.pop("ls_model_type", None)

-        self.model = self.model_parameters.get("model_name", None)
+        self.model = self.model_parameters.get("model_name", None) or self.metadata.pop("ls_model_name", None)
        self.output = ""
        self.prompt = self._langchain_messages_to_prompt(messages)
        self.start_time = time.time()
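The net effect of the new metadata handling: when `model_name` is absent from `invocation_params` (which the fallback suggests is the case for Ollama runs), the model name is recovered from LangChain's `ls_model_name` run metadata, and the `ls_provider`/`ls_model_type` keys are dropped so they are not traced redundantly. A minimal sketch; the metadata values are illustrative, not taken from this diff:

# Illustrative values only; the keys are the ones handled above.
model_parameters = {"_type": "chat-ollama"}   # assumed: no "model_name" for an Ollama chat run
metadata = {"ls_provider": "ollama", "ls_model_type": "chat", "ls_model_name": "llama3"}

metadata.pop("ls_provider", None)
metadata.pop("ls_model_type", None)
model = model_parameters.get("model_name", None) or metadata.pop("ls_model_name", None)
# model == "llama3"; metadata is now {} and gets traced without the ls_* duplicates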
@@ -82,17 +85,32 @@ def on_llm_end(self, response: langchain_schema.LLMResult, **kwargs: Any) -> Any
        self.end_time = time.time()
        self.latency = (self.end_time - self.start_time) * 1000

-        if response.llm_output and "token_usage" in response.llm_output:
-            self.prompt_tokens = response.llm_output["token_usage"].get("prompt_tokens", 0)
-            self.completion_tokens = response.llm_output["token_usage"].get("completion_tokens", 0)
-            self.total_tokens = response.llm_output["token_usage"].get("total_tokens", 0)
+        if self.provider == "OpenAI":
+            self._openai_token_information(response)
+        elif self.provider == "Ollama":
+            self._ollama_token_information(response)

        for generations in response.generations:
            for generation in generations:
                self.output += generation.text.replace("\n", " ")

        self._add_to_trace()

+    def _openai_token_information(self, response: langchain_schema.LLMResult) -> None:
+        """Extracts OpenAI's token information."""
+        if response.llm_output and "token_usage" in response.llm_output:
+            self.prompt_tokens = response.llm_output["token_usage"].get("prompt_tokens", 0)
+            self.completion_tokens = response.llm_output["token_usage"].get("completion_tokens", 0)
+            self.total_tokens = response.llm_output["token_usage"].get("total_tokens", 0)
+
+    def _ollama_token_information(self, response: langchain_schema.LLMResult) -> None:
+        """Extracts Ollama's token information."""
+        generation_info = response.generations[0][0].generation_info
+        if generation_info:
+            self.prompt_tokens = generation_info.get("prompt_eval_count", 0)
+            self.completion_tokens = generation_info.get("eval_count", 0)
+            self.total_tokens = self.prompt_tokens + self.completion_tokens
+
    def _add_to_trace(self) -> None:
        """Adds to the trace."""
        name = PROVIDER_TO_STEP_NAME.get(self.provider, "Chat Completion Model")
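The two extraction paths differ in where the token counts live: OpenAI reports them under `response.llm_output["token_usage"]`, while Ollama exposes `prompt_eval_count`/`eval_count` on the first generation's `generation_info`. A quick sketch of the Ollama path; the numbers are made up for illustration:

# Shape expected by _ollama_token_information; field names from this diff, counts invented.
generation_info = {"prompt_eval_count": 26, "eval_count": 282}

prompt_tokens = generation_info.get("prompt_eval_count", 0)    # 26
completion_tokens = generation_info.get("eval_count", 0)       # 282
total_tokens = prompt_tokens + completion_tokens               # 308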
@@ -109,7 +127,7 @@ def _add_to_trace(self) -> None:
            model_parameters=self.model_parameters,
            prompt_tokens=self.prompt_tokens,
            completion_tokens=self.completion_tokens,
-            metadata=self.metatada,
+            metadata=self.metadata,
        )

    def on_llm_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> Any:
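A minimal usage sketch for the new path, assuming `langchain_community`'s `ChatOllama` and a locally running Ollama server; the handler's import path and the model name are assumptions, not part of this diff:

# Usage sketch (not from this diff). Assumes a local Ollama server and that the
# handler lives at this import path; "llama3" is an illustrative model name.
from langchain_community.chat_models import ChatOllama
from openlayer.lib.integrations.langchain_callback import OpenlayerHandler

chat = ChatOllama(model="llama3", callbacks=[OpenlayerHandler()])
chat.invoke("Briefly explain what a callback handler does.")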