Răsfoiți Sursa

feat(eval): rewrite log_completions to save completions to directory (#4566)

Xingyao Wang 1 an în urmă
părinte
comite
7340b78962

+ 10 - 1
evaluation/integration_tests/run_infer.py

@@ -33,6 +33,7 @@ FAKE_RESPONSES = {
 
 def get_config(
     metadata: EvalMetadata,
+    instance_id: str,
 ) -> AppConfig:
     config = AppConfig(
         default_agent=metadata.agent_class,
@@ -49,6 +50,14 @@ def get_config(
         workspace_base=None,
         workspace_mount_path=None,
     )
+    if metadata.llm_config.log_completions:
+        metadata.llm_config.log_completions_folder = os.path.join(
+            metadata.eval_output_dir, 'llm_completions', instance_id
+        )
+        logger.info(
+            f'Logging LLM completions for instance {instance_id} to '
+            f'{metadata.llm_config.log_completions_folder}'
+        )
     config.set_llm_config(metadata.llm_config)
     return config
 
@@ -58,7 +67,7 @@ def process_instance(
     metadata: EvalMetadata,
     reset_logger: bool = True,
 ) -> EvalOutput:
-    config = get_config(metadata)
+    config = get_config(metadata, instance.instance_id)
 
     # Setup the logger properly, so you can run multi-processing to parallelize the evaluation
     if reset_logger:

+ 8 - 1
evaluation/swe_bench/run_infer.py

@@ -143,6 +143,14 @@ def get_config(
         workspace_base=None,
         workspace_mount_path=None,
     )
+    if metadata.llm_config.log_completions:
+        metadata.llm_config.log_completions_folder = os.path.join(
+            metadata.eval_output_dir, 'llm_completions', instance['instance_id']
+        )
+        logger.info(
+            f'Logging LLM completions for instance {instance["instance_id"]} to '
+            f'{metadata.llm_config.log_completions_folder}'
+        )
     config.set_llm_config(metadata.llm_config)
     return config
 
@@ -432,7 +440,6 @@ def process_instance(
         metadata=metadata,
         history=histories,
         metrics=metrics,
-        llm_completions=state.extra_data.get('llm_completions', []),
         error=state.last_error if state and state.last_error else None,
     )
     return output

+ 0 - 1
evaluation/utils/shared.py

@@ -61,7 +61,6 @@ class EvalOutput(BaseModel):
     history: (
         list[dict[str, Any]] | list[tuple[dict[str, Any], dict[str, Any]]] | None
     ) = None
-    llm_completions: list[dict[str, Any]] | None = None
     metrics: dict[str, Any] | None = None
     error: str | None = None
 

+ 0 - 4
openhands/controller/agent_controller.py

@@ -132,10 +132,6 @@ class AgentController:
     async def update_state_after_step(self):
         # update metrics especially for cost. Use deepcopy to avoid it being modified by agent.reset()
         self.state.local_metrics = copy.deepcopy(self.agent.llm.metrics)
-        if 'llm_completions' not in self.state.extra_data:
-            self.state.extra_data['llm_completions'] = []
-        self.state.extra_data['llm_completions'].extend(self.agent.llm.llm_completions)
-        self.agent.llm.llm_completions.clear()
 
     async def report_error(self, message: str, exception: Exception | None = None):
         """Reports an error to the user and sends the exception to the LLM next step, in the hope it can self-correct.

+ 2 - 0
openhands/core/config/llm_config.py

@@ -40,6 +40,7 @@ class LLMConfig:
         disable_vision: If model is vision capable, this option allows to disable image processing (useful for cost reduction).
         caching_prompt: Use the prompt caching feature if provided by the LLM and supported by the provider.
         log_completions: Whether to log LLM completions to the state.
+        log_completions_folder: The folder to log LLM completions to. Required if log_completions is True.
         draft_editor: A more efficient LLM to use for file editing. Introduced in [PR 3985](https://github.com/All-Hands-AI/OpenHands/pull/3985).
     """
 
@@ -73,6 +74,7 @@ class LLMConfig:
     disable_vision: bool | None = None
     caching_prompt: bool = True
     log_completions: bool = False
+    log_completions_folder: str | None = None
     draft_editor: Optional['LLMConfig'] = None
 
     def defaults_to_dict(self) -> dict:

+ 26 - 13
openhands/llm/llm.py

@@ -1,4 +1,6 @@
 import copy
+import json
+import os
 import time
 import warnings
 from functools import partial
@@ -77,11 +79,6 @@ class LLM(RetryMixin, DebugMixin):
         self.cost_metric_supported: bool = True
         self.config: LLMConfig = copy.deepcopy(config)
 
-        # list of LLM completions (for logging purposes). Each completion is a dict with the following keys:
-        # - 'messages': list of messages
-        # - 'response': response from the LLM
-        self.llm_completions: list[dict[str, Any]] = []
-
         # litellm actually uses base Exception here for unknown model
         self.model_info: ModelInfo | None = None
         try:
@@ -95,6 +92,13 @@ class LLM(RetryMixin, DebugMixin):
         except Exception as e:
             logger.warning(f'Could not get model info for {config.model}:\n{e}')
 
+        if self.config.log_completions:
+            if self.config.log_completions_folder is None:
+                raise RuntimeError(
+                    'log_completions_folder is required when log_completions is enabled'
+                )
+            os.makedirs(self.config.log_completions_folder, exist_ok=True)
+
         # Set the max tokens in an LM-specific way if not set
         if self.config.max_input_tokens is None:
             if (
@@ -194,14 +198,24 @@ class LLM(RetryMixin, DebugMixin):
 
                 # log for evals or other scripts that need the raw completion
                 if self.config.log_completions:
-                    self.llm_completions.append(
-                        {
-                            'messages': messages,
-                            'response': resp,
-                            'timestamp': time.time(),
-                            'cost': self._completion_cost(resp),
-                        }
+                    assert self.config.log_completions_folder is not None
+                    log_file = os.path.join(
+                        self.config.log_completions_folder,
+                        # use the metric model name (for draft editor)
+                        f'{self.metrics.model_name}-{time.time()}.json',
                     )
+                    with open(log_file, 'w') as f:
+                        json.dump(
+                            {
+                                'messages': messages,
+                                'response': resp,
+                                'args': args,
+                                'kwargs': kwargs,
+                                'timestamp': time.time(),
+                                'cost': self._completion_cost(resp),
+                            },
+                            f,
+                        )
 
                 message_back: str = resp['choices'][0]['message']['content']
 
@@ -400,7 +414,6 @@ class LLM(RetryMixin, DebugMixin):
 
     def reset(self):
         self.metrics.reset()
-        self.llm_completions = []
 
     def format_messages_for_llm(self, messages: Message | list[Message]) -> list[dict]:
         if isinstance(messages, Message):