Răsfoiți Sursa

fix: Allow evaluation benchmarks to pass image urls in run_controller() instead of simply passing strings (#4100)

Co-authored-by: Xingyao Wang <xingyao@all-hands.dev>
Aditya Bharat Soni 1 an în urmă
părinte
comite
0809d26f4d

+ 2 - 1
evaluation/EDA/run_infer.py

@@ -22,6 +22,7 @@ from openhands.core.config import (
 )
 from openhands.core.logger import openhands_logger as logger
 from openhands.core.main import create_runtime, run_controller
+from openhands.events.action import MessageAction
 
 game = None
 
@@ -122,7 +123,7 @@ def process_instance(
     state: State | None = asyncio.run(
         run_controller(
             config=config,
-            task_str=instruction,
+            initial_user_action=MessageAction(content=instruction),
             runtime=runtime,
             fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
                 metadata.agent_class

+ 1 - 1
evaluation/agent_bench/run_infer.py

@@ -217,7 +217,7 @@ def process_instance(
     state: State | None = asyncio.run(
         run_controller(
             config=config,
-            task_str=instruction,
+            initial_user_action=MessageAction(content=instruction),
             runtime=runtime,
             fake_user_response_fn=FAKE_RESPONSES[metadata.agent_class],
         )

+ 2 - 2
evaluation/aider_bench/run_infer.py

@@ -30,7 +30,7 @@ from openhands.core.config import (
 )
 from openhands.core.logger import openhands_logger as logger
 from openhands.core.main import create_runtime, run_controller
-from openhands.events.action import CmdRunAction
+from openhands.events.action import CmdRunAction, MessageAction
 from openhands.events.observation import CmdOutputObservation
 from openhands.runtime.runtime import Runtime
 
@@ -211,7 +211,7 @@ def process_instance(
     state: State | None = asyncio.run(
         run_controller(
             config=config,
-            task_str=instruction,
+            initial_user_action=MessageAction(content=instruction),
             runtime=runtime,
             fake_user_response_fn=FAKE_RESPONSES[metadata.agent_class],
         )

+ 2 - 2
evaluation/biocoder/run_infer.py

@@ -27,7 +27,7 @@ from openhands.core.config import (
 )
 from openhands.core.logger import openhands_logger as logger
 from openhands.core.main import create_runtime, run_controller
-from openhands.events.action import CmdRunAction
+from openhands.events.action import CmdRunAction, MessageAction
 from openhands.events.observation import CmdOutputObservation
 from openhands.runtime.runtime import Runtime
 
@@ -285,7 +285,7 @@ def process_instance(
     state: State | None = asyncio.run(
         run_controller(
             config=config,
-            task_str=instruction,
+            initial_user_action=MessageAction(content=instruction),
             runtime=runtime,
             fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
                 metadata.agent_class

+ 1 - 1
evaluation/bird/run_infer.py

@@ -409,7 +409,7 @@ def process_instance(
     state: State | None = asyncio.run(
         run_controller(
             config=config,
-            task_str=instruction,
+            initial_user_action=MessageAction(content=instruction),
             fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
                 metadata.agent_class
             ],

+ 2 - 1
evaluation/browsing_delegation/run_infer.py

@@ -23,6 +23,7 @@ from openhands.core.config import (
 )
 from openhands.core.logger import openhands_logger as logger
 from openhands.core.main import create_runtime, run_controller
+from openhands.events.action import MessageAction
 
 # Only CodeActAgent can delegate to BrowsingAgent
 SUPPORTED_AGENT_CLS = {'CodeActAgent'}
@@ -76,7 +77,7 @@ def process_instance(
     state: State | None = asyncio.run(
         run_controller(
             config=config,
-            task_str=instruction,
+            initial_user_action=MessageAction(content=instruction),
             runtime=runtime,
         )
     )

+ 1 - 1
evaluation/gaia/run_infer.py

@@ -148,7 +148,7 @@ def process_instance(
     state: State | None = asyncio.run(
         run_controller(
             config=config,
-            task_str=instruction,
+            initial_user_action=MessageAction(content=instruction),
             runtime=runtime,
             fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
                 metadata.agent_class

+ 2 - 1
evaluation/gorilla/run_infer.py

@@ -24,6 +24,7 @@ from openhands.core.config import (
 )
 from openhands.core.logger import openhands_logger as logger
 from openhands.core.main import create_runtime, run_controller
+from openhands.events.action import MessageAction
 
 AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
     'CodeActAgent': codeact_user_response,
@@ -83,7 +84,7 @@ def process_instance(
     state: State | None = asyncio.run(
         run_controller(
             config=config,
-            task_str=instruction,
+            initial_user_action=MessageAction(content=instruction),
             runtime=runtime,
             fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
                 metadata.agent_class

+ 1 - 1
evaluation/gpqa/run_infer.py

@@ -219,7 +219,7 @@ Ok now its time to start solving the question. Good luck!
     state: State | None = asyncio.run(
         run_controller(
             config=config,
-            task_str=instruction,
+            initial_user_action=MessageAction(content=instruction),
             runtime=runtime,
             fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
                 metadata.agent_class

+ 2 - 2
evaluation/humanevalfix/run_infer.py

@@ -35,7 +35,7 @@ from openhands.core.config import (
 )
 from openhands.core.logger import openhands_logger as logger
 from openhands.core.main import create_runtime, run_controller
-from openhands.events.action import CmdRunAction
+from openhands.events.action import CmdRunAction, MessageAction
 from openhands.events.observation import CmdOutputObservation
 from openhands.runtime.runtime import Runtime
 
@@ -237,7 +237,7 @@ def process_instance(
     state: State | None = asyncio.run(
         run_controller(
             config=config,
-            task_str=instruction,
+            initial_user_action=MessageAction(content=instruction),
             runtime=runtime,
             fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
                 metadata.agent_class

+ 1 - 1
evaluation/logic_reasoning/run_infer.py

@@ -211,7 +211,7 @@ def process_instance(
     state: State | None = asyncio.run(
         run_controller(
             config=config,
-            task_str=instruction,
+            initial_user_action=MessageAction(content=instruction),
             runtime=runtime,
             fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
                 metadata.agent_class

+ 3 - 2
evaluation/miniwob/run_infer.py

@@ -128,11 +128,12 @@ def process_instance(
 
     runtime = create_runtime(config, sid=env_id)
     task_str = initialize_runtime(runtime)
-
     state: State | None = asyncio.run(
         run_controller(
             config=config,
-            task_str=task_str,  # take output from initialize_runtime
+            initial_user_action=MessageAction(
+                content=task_str
+            ),  # take output from initialize_runtime
             runtime=runtime,
         )
     )

+ 2 - 1
evaluation/mint/run_infer.py

@@ -29,6 +29,7 @@ from openhands.core.logger import openhands_logger as logger
 from openhands.core.main import create_runtime, run_controller
 from openhands.events.action import (
     CmdRunAction,
+    MessageAction,
 )
 from openhands.events.observation import CmdOutputObservation
 from openhands.runtime.runtime import Runtime
@@ -180,7 +181,7 @@ def process_instance(
     state: State | None = asyncio.run(
         run_controller(
             config=config,
-            task_str=instruction,
+            initial_user_action=MessageAction(content=instruction),
             runtime=runtime,
             fake_user_response_fn=fake_user_response_fn,
         )

+ 2 - 2
evaluation/ml_bench/run_infer.py

@@ -39,7 +39,7 @@ from openhands.core.config import (
 )
 from openhands.core.logger import openhands_logger as logger
 from openhands.core.main import create_runtime, run_controller
-from openhands.events.action import CmdRunAction
+from openhands.events.action import CmdRunAction, MessageAction
 from openhands.events.observation import CmdOutputObservation
 from openhands.runtime.runtime import Runtime
 
@@ -242,7 +242,7 @@ def process_instance(instance: Any, metadata: EvalMetadata, reset_logger: bool =
     state: State | None = asyncio.run(
         run_controller(
             config=config,
-            task_str=instruction,
+            initial_user_action=MessageAction(content=instruction),
             runtime=runtime,
             fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
                 metadata.agent_class

+ 3 - 2
evaluation/swe_bench/run_infer.py

@@ -29,7 +29,7 @@ from openhands.core.config import (
 )
 from openhands.core.logger import openhands_logger as logger
 from openhands.core.main import create_runtime, run_controller
-from openhands.events.action import CmdRunAction
+from openhands.events.action import CmdRunAction, MessageAction
 from openhands.events.observation import CmdOutputObservation, ErrorObservation
 from openhands.events.serialization.event import event_to_dict
 from openhands.runtime.runtime import Runtime
@@ -365,6 +365,7 @@ def process_instance(
         logger.info(f'Starting evaluation for instance {instance.instance_id}.')
 
     runtime = create_runtime(config, sid=instance.instance_id)
+
     try:
         initialize_runtime(runtime, instance)
 
@@ -374,7 +375,7 @@ def process_instance(
         state: State | None = asyncio.run(
             run_controller(
                 config=config,
-                task_str=instruction,
+                initial_user_action=MessageAction(content=instruction),
                 runtime=runtime,
                 fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
                     metadata.agent_class

+ 2 - 2
evaluation/toolqa/run_infer.py

@@ -23,7 +23,7 @@ from openhands.core.config import (
 )
 from openhands.core.logger import openhands_logger as logger
 from openhands.core.main import create_runtime, run_controller
-from openhands.events.action import CmdRunAction
+from openhands.events.action import CmdRunAction, MessageAction
 from openhands.events.observation import CmdOutputObservation
 from openhands.runtime.runtime import Runtime
 
@@ -109,7 +109,7 @@ def process_instance(instance: Any, metadata: EvalMetadata, reset_logger: bool =
     state: State | None = asyncio.run(
         run_controller(
             config=config,
-            task_str=instruction,
+            initial_user_action=MessageAction(content=instruction),
             runtime=runtime,
             fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
                 metadata.agent_class

+ 1 - 1
evaluation/webarena/run_infer.py

@@ -148,7 +148,7 @@ def process_instance(
     state: State | None = asyncio.run(
         run_controller(
             config=config,
-            task_str=task_str,
+            initial_user_action=MessageAction(content=task_str),
             runtime=runtime,
         )
     )

+ 10 - 8
openhands/core/main.py

@@ -83,7 +83,7 @@ def create_runtime(
 
 async def run_controller(
     config: AppConfig,
-    task_str: str,
+    initial_user_action: Action,
     sid: str | None = None,
     runtime: Runtime | None = None,
     agent: Agent | None = None,
@@ -96,7 +96,7 @@ async def run_controller(
 
     Args:
         config: The app config.
-        task_str: The task to run. It can be a string.
+        initial_user_action: An Action object containing initial user input
         runtime: (optional) A runtime for the agent to run on.
         agent: (optional) A agent to run.
         exit_on_message: quit if agent asks for a message from user (optional)
@@ -146,11 +146,13 @@ async def run_controller(
     if controller is not None:
         controller.agent_task = asyncio.create_task(controller.start_step_loop())
 
-    assert isinstance(task_str, str), f'task_str must be a string, got {type(task_str)}'
+    assert isinstance(
+        initial_user_action, Action
+    ), f'initial user actions must be an Action, got {type(initial_user_action)}'
     # Logging
     logger.info(
         f'Agent Controller Initialized: Running agent {agent.name}, model '
-        f'{agent.llm.config.model}, with task: "{task_str}"'
+        f'{agent.llm.config.model}, with actions: {initial_user_action}'
     )
 
     # start event is a MessageAction with the task, either resumed or new
@@ -166,8 +168,8 @@ async def run_controller(
             EventSource.USER,
         )
     elif initial_state is None:
-        # init with the provided task
-        event_stream.add_event(MessageAction(content=task_str), EventSource.USER)
+        # init with the provided actions
+        event_stream.add_event(initial_user_action, EventSource.USER)
 
     async def on_event(event: Event):
         if isinstance(event, AgentStateChangedObservation):
@@ -224,7 +226,7 @@ if __name__ == '__main__':
         task_str = read_task_from_stdin()
     else:
         raise ValueError('No task provided. Please specify a task through -t, -f.')
-
+    initial_user_action: MessageAction = MessageAction(content=task_str)
     # Load the app config
     # this will load config from config.toml in the current directory
     # as well as from the environment variables
@@ -253,7 +255,7 @@ if __name__ == '__main__':
     asyncio.run(
         run_controller(
             config=config,
-            task_str=task_str,
+            initial_user_action=initial_user_action,
             sid=sid,
         )
     )

+ 7 - 10
tests/integration/test_agent.py

@@ -9,10 +9,7 @@ from openhands.controller.state.state import State
 from openhands.core.config import load_app_config
 from openhands.core.main import run_controller
 from openhands.core.schema import AgentState
-from openhands.events.action import (
-    AgentFinishAction,
-    AgentRejectAction,
-)
+from openhands.events.action import AgentFinishAction, AgentRejectAction, MessageAction
 from openhands.events.observation.browse import BrowserOutputObservation
 from openhands.events.observation.delegate import AgentDelegateObservation
 from openhands.runtime import get_runtime_cls
@@ -90,7 +87,7 @@ def test_write_simple_script(current_test_name: str) -> None:
     task = "Write a shell script 'hello.sh' that prints 'hello'. Do not ask me for confirmation at any point."
 
     final_state: State | None = asyncio.run(
-        run_controller(CONFIG, task, exit_on_message=True)
+        run_controller(CONFIG, MessageAction(content=task), exit_on_message=True)
     )
     validate_final_state(final_state, current_test_name)
 
@@ -136,7 +133,7 @@ def test_edits(current_test_name: str):
     # Execute the task
     task = 'Fix typos in bad.txt. Do not ask me for confirmation at any point.'
     final_state: State | None = asyncio.run(
-        run_controller(CONFIG, task, exit_on_message=True)
+        run_controller(CONFIG, MessageAction(content=task), exit_on_message=True)
     )
     validate_final_state(final_state, current_test_name)
 
@@ -160,7 +157,7 @@ def test_ipython(current_test_name: str):
     # Execute the task
     task = "Use Jupyter IPython to write a text file containing 'hello world' to '/workspace/test.txt'. Do not ask me for confirmation at any point."
     final_state: State | None = asyncio.run(
-        run_controller(CONFIG, task, exit_on_message=True)
+        run_controller(CONFIG, MessageAction(content=task), exit_on_message=True)
     )
     validate_final_state(final_state, current_test_name)
 
@@ -185,7 +182,7 @@ def test_simple_task_rejection(current_test_name: str):
     # the workspace is not a git repo
     task = 'Write a git commit message for the current staging area. Do not ask me for confirmation at any point.'
     final_state: State | None = asyncio.run(
-        run_controller(CONFIG, task, exit_on_message=True)
+        run_controller(CONFIG, MessageAction(content=task), exit_on_message=True)
     )
     validate_final_state(final_state, current_test_name)
     assert isinstance(final_state.history.get_last_action(), AgentRejectAction)
@@ -200,7 +197,7 @@ def test_ipython_module(current_test_name: str):
     # Execute the task
     task = "Install and import pymsgbox==1.0.9 and print it's version in /workspace/test.txt. Do not ask me for confirmation at any point."
     final_state: State | None = asyncio.run(
-        run_controller(CONFIG, task, exit_on_message=True)
+        run_controller(CONFIG, MessageAction(content=task), exit_on_message=True)
     )
     validate_final_state(final_state, current_test_name)
 
@@ -226,7 +223,7 @@ def test_browse_internet(current_test_name: str):
     # Execute the task
     task = 'Browse localhost:8000, and tell me the ultimate answer to life. Do not ask me for confirmation at any point.'
     final_state: State | None = asyncio.run(
-        run_controller(CONFIG, task, exit_on_message=True)
+        run_controller(CONFIG, MessageAction(content=task), exit_on_message=True)
     )
     validate_final_state(final_state, current_test_name)