Pārlūkot izejas kodu

Fix issue #5383: [Bug]: LLM Cost is added to the `metrics` twice (#5396)

Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
OpenHands 1 gadu atpakaļ
vecāks
revīzija
794408cd31
1 mainītis faili ar 23 papildinājumiem un 17 dzēšanām
  1. 23 17
      openhands/llm/llm.py

+ 23 - 17
openhands/llm/llm.py

@@ -219,6 +219,20 @@ class LLM(RetryMixin, DebugMixin):
                         )
                     resp.choices[0].message = fn_call_response_message
 
+                message_back: str = resp['choices'][0]['message']['content'] or ''
+                tool_calls = resp['choices'][0]['message'].get('tool_calls', [])
+                if tool_calls:
+                    for tool_call in tool_calls:
+                        fn_name = tool_call.function.name
+                        fn_args = tool_call.function.arguments
+                        message_back += f'\nFunction call: {fn_name}({fn_args})'
+
+                # log the LLM response
+                self.log_response(message_back)
+
+                # post-process the response first to calculate cost
+                cost = self._post_completion(resp)
+
                 # log for evals or other scripts that need the raw completion
                 if self.config.log_completions:
                     assert self.config.log_completions_folder is not None
@@ -228,37 +242,27 @@ class LLM(RetryMixin, DebugMixin):
                         f'{self.metrics.model_name.replace("/", "__")}-{time.time()}.json',
                     )
 
+                    # set up the dict to be logged
                     _d = {
                         'messages': messages,
                         'response': resp,
                         'args': args,
                         'kwargs': {k: v for k, v in kwargs.items() if k != 'messages'},
                         'timestamp': time.time(),
-                        'cost': self._completion_cost(resp),
+                        'cost': cost,
                     }
+
+                    # if non-native function calling, save messages/response separately
                     if mock_function_calling:
-                        # Overwrite response as non-fncall to be consistent with `messages``
+                        # Overwrite response as non-fncall to be consistent with messages
                         _d['response'] = non_fncall_response
+
                         # Save fncall_messages/response separately
                         _d['fncall_messages'] = original_fncall_messages
                         _d['fncall_response'] = resp
                     with open(log_file, 'w') as f:
                         f.write(json.dumps(_d))
 
-                message_back: str = resp['choices'][0]['message']['content'] or ''
-                tool_calls = resp['choices'][0]['message'].get('tool_calls', [])
-                if tool_calls:
-                    for tool_call in tool_calls:
-                        fn_name = tool_call.function.name
-                        fn_args = tool_call.function.arguments
-                        message_back += f'\nFunction call: {fn_name}({fn_args})'
-
-                # log the LLM response
-                self.log_response(message_back)
-
-                # post-process the response
-                self._post_completion(resp)
-
                 return resp
             except APIError as e:
                 if 'Attention Required! | Cloudflare' in str(e):
@@ -414,7 +418,7 @@ class LLM(RetryMixin, DebugMixin):
         )
         return model_name_supported
 
-    def _post_completion(self, response: ModelResponse) -> None:
+    def _post_completion(self, response: ModelResponse) -> float:
         """Post-process the completion response.
 
         Logs the cost and usage stats of the completion call.
@@ -472,6 +476,8 @@ class LLM(RetryMixin, DebugMixin):
         if stats:
             logger.debug(stats)
 
+        return cur_cost
+
     def get_token_count(self, messages) -> int:
         """Get the number of tokens in a list of messages.