|
|
@@ -219,6 +219,20 @@ class LLM(RetryMixin, DebugMixin):
|
|
|
)
|
|
|
resp.choices[0].message = fn_call_response_message
|
|
|
|
|
|
+ message_back: str = resp['choices'][0]['message']['content'] or ''
|
|
|
+ tool_calls = resp['choices'][0]['message'].get('tool_calls', [])
|
|
|
+ if tool_calls:
|
|
|
+ for tool_call in tool_calls:
|
|
|
+ fn_name = tool_call.function.name
|
|
|
+ fn_args = tool_call.function.arguments
|
|
|
+ message_back += f'\nFunction call: {fn_name}({fn_args})'
|
|
|
+
|
|
|
+ # log the LLM response
|
|
|
+ self.log_response(message_back)
|
|
|
+
|
|
|
+ # post-process the response first to calculate cost
|
|
|
+ cost = self._post_completion(resp)
|
|
|
+
|
|
|
# log for evals or other scripts that need the raw completion
|
|
|
if self.config.log_completions:
|
|
|
assert self.config.log_completions_folder is not None
|
|
|
@@ -228,37 +242,27 @@ class LLM(RetryMixin, DebugMixin):
|
|
|
f'{self.metrics.model_name.replace("/", "__")}-{time.time()}.json',
|
|
|
)
|
|
|
|
|
|
+ # set up the dict to be logged
|
|
|
_d = {
|
|
|
'messages': messages,
|
|
|
'response': resp,
|
|
|
'args': args,
|
|
|
'kwargs': {k: v for k, v in kwargs.items() if k != 'messages'},
|
|
|
'timestamp': time.time(),
|
|
|
- 'cost': self._completion_cost(resp),
|
|
|
+ 'cost': cost,
|
|
|
}
|
|
|
+
|
|
|
+ # if non-native function calling, save messages/response separately
|
|
|
if mock_function_calling:
|
|
|
- # Overwrite response as non-fncall to be consistent with `messages``
|
|
|
+ # Overwrite response as non-fncall to be consistent with messages
|
|
|
_d['response'] = non_fncall_response
|
|
|
+
|
|
|
# Save fncall_messages/response separately
|
|
|
_d['fncall_messages'] = original_fncall_messages
|
|
|
_d['fncall_response'] = resp
|
|
|
with open(log_file, 'w') as f:
|
|
|
f.write(json.dumps(_d))
|
|
|
|
|
|
- message_back: str = resp['choices'][0]['message']['content'] or ''
|
|
|
- tool_calls = resp['choices'][0]['message'].get('tool_calls', [])
|
|
|
- if tool_calls:
|
|
|
- for tool_call in tool_calls:
|
|
|
- fn_name = tool_call.function.name
|
|
|
- fn_args = tool_call.function.arguments
|
|
|
- message_back += f'\nFunction call: {fn_name}({fn_args})'
|
|
|
-
|
|
|
- # log the LLM response
|
|
|
- self.log_response(message_back)
|
|
|
-
|
|
|
- # post-process the response
|
|
|
- self._post_completion(resp)
|
|
|
-
|
|
|
return resp
|
|
|
except APIError as e:
|
|
|
if 'Attention Required! | Cloudflare' in str(e):
|
|
|
@@ -414,7 +418,7 @@ class LLM(RetryMixin, DebugMixin):
|
|
|
)
|
|
|
return model_name_supported
|
|
|
|
|
|
- def _post_completion(self, response: ModelResponse) -> None:
|
|
|
+ def _post_completion(self, response: ModelResponse) -> float:
|
|
|
"""Post-process the completion response.
|
|
|
|
|
|
Logs the cost and usage stats of the completion call.
|
|
|
@@ -472,6 +476,8 @@ class LLM(RetryMixin, DebugMixin):
|
|
|
if stats:
|
|
|
logger.debug(stats)
|
|
|
|
|
|
+ return cur_cost
|
|
|
+
|
|
|
def get_token_count(self, messages) -> int:
|
|
|
"""Get the number of tokens in a list of messages.
|
|
|
|