Ver código fonte

[Arch proposal] ENVIRONMENT event source (#4584)

Co-authored-by: Xingyao Wang <xingyao@all-hands.dev>
Engel Nyst 1 ano atrás
pai
commit
0687608feb

+ 5 - 3
openhands/controller/agent_controller.py

@@ -156,7 +156,7 @@ class AgentController:
         if exception is not None and isinstance(exception, litellm.AuthenticationError):
             detail = 'Please check your credentials. Is your API key correct?'
         self.event_stream.add_event(
-            ErrorObservation(f'{message}:{detail}'), EventSource.USER
+            ErrorObservation(f'{message}:{detail}'), EventSource.ENVIRONMENT
         )
 
     async def start_step_loop(self):
@@ -346,7 +346,8 @@ class AgentController:
 
         self.state.agent_state = new_state
         self.event_stream.add_event(
-            AgentStateChangedObservation('', self.state.agent_state), EventSource.AGENT
+            AgentStateChangedObservation('', self.state.agent_state),
+            EventSource.ENVIRONMENT,
         )
 
         if new_state == AgentState.INIT and self.state.resume_state:
@@ -423,7 +424,8 @@ class AgentController:
         if self._is_stuck():
             # This needs to go BEFORE report_error to sync metrics
             self.event_stream.add_event(
-                FatalErrorObservation('Agent got stuck in a loop'), EventSource.USER
+                FatalErrorObservation('Agent got stuck in a loop'),
+                EventSource.ENVIRONMENT,
             )
             return
 

+ 2 - 2
openhands/core/cli.py

@@ -61,7 +61,7 @@ def display_event(event: Event):
         if hasattr(event, 'thought'):
             display_message(event.thought)
     if isinstance(event, MessageAction):
-        if event.source != EventSource.USER:
+        if event.source == EventSource.AGENT:
             display_message(event.content)
     if isinstance(event, CmdRunAction):
         display_command(event.command)
@@ -131,7 +131,7 @@ async def main():
         next_message = input('How can I help? >> ')
         if next_message == 'exit':
             event_stream.add_event(
-                ChangeAgentStateAction(AgentState.STOPPED), EventSource.USER
+                ChangeAgentStateAction(AgentState.STOPPED), EventSource.ENVIRONMENT
             )
             return
         action = MessageAction(content=next_message)

+ 2 - 0
openhands/core/message.py

@@ -49,6 +49,8 @@ class ImageContent(Content):
 
 
 class Message(BaseModel):
+    # NOTE: this is not the same as EventSource
+    # These are the roles in the LLM's APIs
     role: Literal['user', 'system', 'assistant', 'tool']
     content: list[TextContent | ImageContent] = Field(default_factory=list)
     cache_enabled: bool = False

+ 1 - 0
openhands/events/event.py

@@ -9,6 +9,7 @@ from openhands.llm.metrics import Metrics
 class EventSource(str, Enum):
     AGENT = 'agent'
     USER = 'user'
+    ENVIRONMENT = 'environment'
 
 
 @dataclass

+ 2 - 0
openhands/runtime/base.py

@@ -136,6 +136,8 @@ class Runtime(FileEditRuntimeMixin):
             )
             observation._cause = event.id  # type: ignore[attr-defined]
             observation.tool_call_metadata = event.tool_call_metadata
+
+            # this might be unnecessary, since the source should already be set by the event stream by the time we get here
             source = event.source if event.source else EventSource.AGENT
             await self.event_stream.async_add_event(observation, source)  # type: ignore[arg-type]
 

+ 1 - 0
openhands/security/invariant/analyzer.py

@@ -147,6 +147,7 @@ class InvariantAnalyzer(SecurityAnalyzer):
         new_event = action_from_dict(
             {'action': 'change_agent_state', 'args': {'agent_state': 'user_confirmed'}}
         )
+        # we should confirm only on agent actions
         event_source = event.source if event.source else EventSource.AGENT
         await call_sync_from_async(self.event_stream.add_event, new_event, event_source)
 

+ 1 - 1
openhands/server/session/agent_session.py

@@ -118,7 +118,7 @@ class AgentSession:
             agent_configs=agent_configs,
         )
         self.event_stream.add_event(
-            ChangeAgentStateAction(AgentState.INIT), EventSource.USER
+            ChangeAgentStateAction(AgentState.INIT), EventSource.ENVIRONMENT
         )
         if self.controller:
             self.controller.agent_task = self.controller.start_step_loop()

+ 13 - 5
openhands/server/session/session.py

@@ -73,10 +73,11 @@ class Session:
 
     async def _initialize_agent(self, data: dict):
         self.agent_session.event_stream.add_event(
-            ChangeAgentStateAction(AgentState.LOADING), EventSource.USER
+            ChangeAgentStateAction(AgentState.LOADING), EventSource.ENVIRONMENT
         )
         self.agent_session.event_stream.add_event(
-            AgentStateChangedObservation('', AgentState.LOADING), EventSource.AGENT
+            AgentStateChangedObservation('', AgentState.LOADING),
+            EventSource.ENVIRONMENT,
         )
         # Extract the agent-relevant arguments from the request
         args = {key: value for key, value in data.get('args', {}).items()}
@@ -138,12 +139,19 @@ class Session:
             return
         if event.source == EventSource.AGENT:
             await self.send(event_to_dict(event))
-        elif event.source == EventSource.USER and isinstance(
+        # NOTE: ipython observations are not sent here currently
+        elif event.source == EventSource.ENVIRONMENT and isinstance(
             event, CmdOutputObservation
         ):
-            await self.send(event_to_dict(event))
+            # feedback from the environment to agent actions is understood as agent events by the UI
+            event_dict = event_to_dict(event)
+            event_dict['source'] = EventSource.AGENT
+            await self.send(event_dict)
         elif isinstance(event, ErrorObservation):
-            await self.send(event_to_dict(event))
+            # send error events as agent events to the UI
+            event_dict = event_to_dict(event)
+            event_dict['source'] = EventSource.AGENT
+            await self.send(event_dict)
 
     async def dispatch(self, data: dict):
         action = data.get('action', '')

+ 3 - 1
tests/unit/test_agent_controller.py

@@ -207,7 +207,9 @@ async def test_run_controller_stop_with_stuck(mock_agent, mock_event_stream):
                 'Non fatal error here to trigger loop'
             )
             non_fatal_error_obs._cause = event.id
-            await event_stream.async_add_event(non_fatal_error_obs, EventSource.USER)
+            await event_stream.async_add_event(
+                non_fatal_error_obs, EventSource.ENVIRONMENT
+            )
 
     event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event)
     runtime.event_stream = event_stream

+ 37 - 37
tests/unit/test_is_stuck.py

@@ -80,7 +80,7 @@ class TestStuckDetector:
                 code=code_snippet,
             )
             ipython_observation._cause = ipython_action._id
-            event_stream.add_event(ipython_observation, EventSource.USER)
+            event_stream.add_event(ipython_observation, EventSource.ENVIRONMENT)
 
     def _impl_unterminated_string_error_events(
         self, event_stream: EventStream, random_line: bool, incidents: int = 4
@@ -96,7 +96,7 @@ class TestStuckDetector:
                 code=code_snippet,
             )
             ipython_observation._cause = ipython_action._id
-            event_stream.add_event(ipython_observation, EventSource.USER)
+            event_stream.add_event(ipython_observation, EventSource.ENVIRONMENT)
 
     def test_history_too_short(
         self, stuck_detector: StuckDetector, event_stream: EventStream
@@ -106,7 +106,7 @@ class TestStuckDetector:
         observation = NullObservation(content='')
         observation._cause = message_action.id
         event_stream.add_event(message_action, EventSource.USER)
-        event_stream.add_event(observation, EventSource.USER)
+        event_stream.add_event(observation, EventSource.ENVIRONMENT)
 
         cmd_action = CmdRunAction(command='ls')
         event_stream.add_event(cmd_action, EventSource.AGENT)
@@ -114,7 +114,7 @@ class TestStuckDetector:
             command_id=1, command='ls', content='file1.txt\nfile2.txt'
         )
         cmd_observation._cause = cmd_action._id
-        event_stream.add_event(cmd_observation, EventSource.USER)
+        event_stream.add_event(cmd_observation, EventSource.ENVIRONMENT)
 
         # stuck_detector.state.history.set_event_stream(event_stream)
 
@@ -131,7 +131,7 @@ class TestStuckDetector:
 
         # 2 events
         event_stream.add_event(hello_action, EventSource.USER)
-        event_stream.add_event(hello_observation, EventSource.USER)
+        event_stream.add_event(hello_observation, EventSource.ENVIRONMENT)
 
         cmd_action_1 = CmdRunAction(command='ls')
         event_stream.add_event(cmd_action_1, EventSource.AGENT)
@@ -139,7 +139,7 @@ class TestStuckDetector:
             content='', command='ls', command_id=cmd_action_1._id
         )
         cmd_observation_1._cause = cmd_action_1._id
-        event_stream.add_event(cmd_observation_1, EventSource.USER)
+        event_stream.add_event(cmd_observation_1, EventSource.ENVIRONMENT)
         # 4 events
 
         cmd_action_2 = CmdRunAction(command='ls')
@@ -148,13 +148,13 @@ class TestStuckDetector:
             content='', command='ls', command_id=cmd_action_2._id
         )
         cmd_observation_2._cause = cmd_action_2._id
-        event_stream.add_event(cmd_observation_2, EventSource.USER)
+        event_stream.add_event(cmd_observation_2, EventSource.ENVIRONMENT)
         # 6 events
 
         # random user message just because we can
         message_null_observation = NullObservation(content='')
         event_stream.add_event(message_action, EventSource.USER)
-        event_stream.add_event(message_null_observation, EventSource.USER)
+        event_stream.add_event(message_null_observation, EventSource.ENVIRONMENT)
         # 8 events
 
         assert stuck_detector.is_stuck() is False
@@ -166,7 +166,7 @@ class TestStuckDetector:
             content='', command='ls', command_id=cmd_action_3._id
         )
         cmd_observation_3._cause = cmd_action_3._id
-        event_stream.add_event(cmd_observation_3, EventSource.USER)
+        event_stream.add_event(cmd_observation_3, EventSource.ENVIRONMENT)
         # 10 events
 
         assert len(collect_events(event_stream)) == 10
@@ -191,7 +191,7 @@ class TestStuckDetector:
             content='', command='ls', command_id=cmd_action_4._id
         )
         cmd_observation_4._cause = cmd_action_4._id
-        event_stream.add_event(cmd_observation_4, EventSource.USER)
+        event_stream.add_event(cmd_observation_4, EventSource.ENVIRONMENT)
         # 12 events
 
         assert len(collect_events(event_stream)) == 12
@@ -223,14 +223,14 @@ class TestStuckDetector:
         hello_observation = NullObservation(content='')
         event_stream.add_event(hello_action, EventSource.USER)
         hello_observation._cause = hello_action._id
-        event_stream.add_event(hello_observation, EventSource.USER)
+        event_stream.add_event(hello_observation, EventSource.ENVIRONMENT)
         # 2 events
 
         cmd_action_1 = CmdRunAction(command='invalid_command')
         event_stream.add_event(cmd_action_1, EventSource.AGENT)
         error_observation_1 = ErrorObservation(content='Command not found')
         error_observation_1._cause = cmd_action_1._id
-        event_stream.add_event(error_observation_1, EventSource.USER)
+        event_stream.add_event(error_observation_1, EventSource.ENVIRONMENT)
         # 4 events
 
         cmd_action_2 = CmdRunAction(command='invalid_command')
@@ -239,26 +239,26 @@ class TestStuckDetector:
             content='Command still not found or another error'
         )
         error_observation_2._cause = cmd_action_2._id
-        event_stream.add_event(error_observation_2, EventSource.USER)
+        event_stream.add_event(error_observation_2, EventSource.ENVIRONMENT)
         # 6 events
 
         message_null_observation = NullObservation(content='')
         event_stream.add_event(message_action, EventSource.USER)
-        event_stream.add_event(message_null_observation, EventSource.USER)
+        event_stream.add_event(message_null_observation, EventSource.ENVIRONMENT)
         # 8 events
 
         cmd_action_3 = CmdRunAction(command='invalid_command')
         event_stream.add_event(cmd_action_3, EventSource.AGENT)
         error_observation_3 = ErrorObservation(content='Different error')
         error_observation_3._cause = cmd_action_3._id
-        event_stream.add_event(error_observation_3, EventSource.USER)
+        event_stream.add_event(error_observation_3, EventSource.ENVIRONMENT)
         # 10 events
 
         cmd_action_4 = CmdRunAction(command='invalid_command')
         event_stream.add_event(cmd_action_4, EventSource.AGENT)
         error_observation_4 = ErrorObservation(content='Command not found')
         error_observation_4._cause = cmd_action_4._id
-        event_stream.add_event(error_observation_4, EventSource.USER)
+        event_stream.add_event(error_observation_4, EventSource.ENVIRONMENT)
         # 12 events
 
         with patch('logging.Logger.warning') as mock_warning:
@@ -366,7 +366,7 @@ class TestStuckDetector:
             code='print("hello',
         )
         ipython_observation_1._cause = ipython_action_1._id
-        event_stream.add_event(ipython_observation_1, EventSource.USER)
+        event_stream.add_event(ipython_observation_1, EventSource.ENVIRONMENT)
 
         ipython_action_2 = IPythonRunCellAction(code='print("hello')
         event_stream.add_event(ipython_action_2, EventSource.AGENT)
@@ -375,7 +375,7 @@ class TestStuckDetector:
             code='print("hello',
         )
         ipython_observation_2._cause = ipython_action_2._id
-        event_stream.add_event(ipython_observation_2, EventSource.USER)
+        event_stream.add_event(ipython_observation_2, EventSource.ENVIRONMENT)
 
         ipython_action_3 = IPythonRunCellAction(code='print("hello')
         event_stream.add_event(ipython_action_3, EventSource.AGENT)
@@ -384,7 +384,7 @@ class TestStuckDetector:
             code='print("hello',
         )
         ipython_observation_3._cause = ipython_action_3._id
-        event_stream.add_event(ipython_observation_3, EventSource.USER)
+        event_stream.add_event(ipython_observation_3, EventSource.ENVIRONMENT)
 
         ipython_action_4 = IPythonRunCellAction(code='print("hello')
         event_stream.add_event(ipython_action_4, EventSource.AGENT)
@@ -393,7 +393,7 @@ class TestStuckDetector:
             code='print("hello',
         )
         ipython_observation_4._cause = ipython_action_4._id
-        event_stream.add_event(ipython_observation_4, EventSource.USER)
+        event_stream.add_event(ipython_observation_4, EventSource.ENVIRONMENT)
 
         with patch('logging.Logger.warning') as mock_warning:
             assert stuck_detector.is_stuck() is False
@@ -406,7 +406,7 @@ class TestStuckDetector:
         message_action._source = EventSource.USER
         event_stream.add_event(message_action, EventSource.USER)
         message_observation = NullObservation(content='')
-        event_stream.add_event(message_observation, EventSource.USER)
+        event_stream.add_event(message_observation, EventSource.ENVIRONMENT)
 
         cmd_action_1 = CmdRunAction(command='ls')
         event_stream.add_event(cmd_action_1, EventSource.AGENT)
@@ -414,7 +414,7 @@ class TestStuckDetector:
             command_id=1, command='ls', content='file1.txt\nfile2.txt'
         )
         cmd_observation_1._cause = cmd_action_1._id
-        event_stream.add_event(cmd_observation_1, EventSource.USER)
+        event_stream.add_event(cmd_observation_1, EventSource.ENVIRONMENT)
 
         read_action_1 = FileReadAction(path='file1.txt')
         event_stream.add_event(read_action_1, EventSource.AGENT)
@@ -422,7 +422,7 @@ class TestStuckDetector:
             content='File content', path='file1.txt'
         )
         read_observation_1._cause = read_action_1._id
-        event_stream.add_event(read_observation_1, EventSource.USER)
+        event_stream.add_event(read_observation_1, EventSource.ENVIRONMENT)
 
         cmd_action_2 = CmdRunAction(command='ls')
         event_stream.add_event(cmd_action_2, EventSource.AGENT)
@@ -430,7 +430,7 @@ class TestStuckDetector:
             command_id=2, command='ls', content='file1.txt\nfile2.txt'
         )
         cmd_observation_2._cause = cmd_action_2._id
-        event_stream.add_event(cmd_observation_2, EventSource.USER)
+        event_stream.add_event(cmd_observation_2, EventSource.ENVIRONMENT)
 
         read_action_2 = FileReadAction(path='file1.txt')
         event_stream.add_event(read_action_2, EventSource.AGENT)
@@ -438,12 +438,12 @@ class TestStuckDetector:
             content='File content', path='file1.txt'
         )
         read_observation_2._cause = read_action_2._id
-        event_stream.add_event(read_observation_2, EventSource.USER)
+        event_stream.add_event(read_observation_2, EventSource.ENVIRONMENT)
 
         # one more message to break the pattern
         message_null_observation = NullObservation(content='')
         event_stream.add_event(message_action, EventSource.USER)
-        event_stream.add_event(message_null_observation, EventSource.USER)
+        event_stream.add_event(message_null_observation, EventSource.ENVIRONMENT)
 
         cmd_action_3 = CmdRunAction(command='ls')
         event_stream.add_event(cmd_action_3, EventSource.AGENT)
@@ -451,7 +451,7 @@ class TestStuckDetector:
             command_id=3, command='ls', content='file1.txt\nfile2.txt'
         )
         cmd_observation_3._cause = cmd_action_3._id
-        event_stream.add_event(cmd_observation_3, EventSource.USER)
+        event_stream.add_event(cmd_observation_3, EventSource.ENVIRONMENT)
 
         read_action_3 = FileReadAction(path='file1.txt')
         event_stream.add_event(read_action_3, EventSource.AGENT)
@@ -459,7 +459,7 @@ class TestStuckDetector:
             content='File content', path='file1.txt'
         )
         read_observation_3._cause = read_action_3._id
-        event_stream.add_event(read_observation_3, EventSource.USER)
+        event_stream.add_event(read_observation_3, EventSource.ENVIRONMENT)
 
         with patch('logging.Logger.warning') as mock_warning:
             assert stuck_detector.is_stuck() is True
@@ -475,7 +475,7 @@ class TestStuckDetector:
         event_stream.add_event(hello_action, EventSource.USER)
         hello_observation = NullObservation(content='')
         hello_observation._cause = hello_action._id
-        event_stream.add_event(hello_observation, EventSource.USER)
+        event_stream.add_event(hello_observation, EventSource.ENVIRONMENT)
 
         cmd_action_1 = CmdRunAction(command='ls')
         event_stream.add_event(cmd_action_1, EventSource.AGENT)
@@ -483,7 +483,7 @@ class TestStuckDetector:
             command_id=cmd_action_1.id, command='ls', content='file1.txt\nfile2.txt'
         )
         cmd_observation_1._cause = cmd_action_1._id
-        event_stream.add_event(cmd_observation_1, EventSource.USER)
+        event_stream.add_event(cmd_observation_1, EventSource.ENVIRONMENT)
 
         read_action_1 = FileReadAction(path='file1.txt')
         event_stream.add_event(read_action_1, EventSource.AGENT)
@@ -491,7 +491,7 @@ class TestStuckDetector:
             content='File content', path='file1.txt'
         )
         read_observation_1._cause = read_action_1._id
-        event_stream.add_event(read_observation_1, EventSource.USER)
+        event_stream.add_event(read_observation_1, EventSource.ENVIRONMENT)
 
         cmd_action_2 = CmdRunAction(command='pwd')
         event_stream.add_event(cmd_action_2, EventSource.AGENT)
@@ -499,7 +499,7 @@ class TestStuckDetector:
             command_id=2, command='pwd', content='/home/user'
         )
         cmd_observation_2._cause = cmd_action_2._id
-        event_stream.add_event(cmd_observation_2, EventSource.USER)
+        event_stream.add_event(cmd_observation_2, EventSource.ENVIRONMENT)
 
         read_action_2 = FileReadAction(path='file2.txt')
         event_stream.add_event(read_action_2, EventSource.AGENT)
@@ -507,11 +507,11 @@ class TestStuckDetector:
             content='Another file content', path='file2.txt'
         )
         read_observation_2._cause = read_action_2._id
-        event_stream.add_event(read_observation_2, EventSource.USER)
+        event_stream.add_event(read_observation_2, EventSource.ENVIRONMENT)
 
         message_null_observation = NullObservation(content='')
         event_stream.add_event(message_action, EventSource.USER)
-        event_stream.add_event(message_null_observation, EventSource.USER)
+        event_stream.add_event(message_null_observation, EventSource.ENVIRONMENT)
 
         cmd_action_3 = CmdRunAction(command='pwd')
         event_stream.add_event(cmd_action_3, EventSource.AGENT)
@@ -519,7 +519,7 @@ class TestStuckDetector:
             command_id=cmd_action_3.id, command='pwd', content='/home/user'
         )
         cmd_observation_3._cause = cmd_action_3._id
-        event_stream.add_event(cmd_observation_3, EventSource.USER)
+        event_stream.add_event(cmd_observation_3, EventSource.ENVIRONMENT)
 
         read_action_3 = FileReadAction(path='file2.txt')
         event_stream.add_event(read_action_3, EventSource.AGENT)
@@ -527,7 +527,7 @@ class TestStuckDetector:
             content='Another file content', path='file2.txt'
         )
         read_observation_3._cause = read_action_3._id
-        event_stream.add_event(read_observation_3, EventSource.USER)
+        event_stream.add_event(read_observation_3, EventSource.ENVIRONMENT)
 
         assert stuck_detector.is_stuck() is False
 
@@ -572,7 +572,7 @@ class TestStuckDetector:
             exit_code=0,
         )
         cmd_output_observation._cause = cmd_kill_action._id
-        event_stream.add_event(cmd_output_observation, EventSource.USER)
+        event_stream.add_event(cmd_output_observation, EventSource.ENVIRONMENT)
 
         message_action_7 = MessageAction(content="I'm doing well, thanks for asking.")
         event_stream.add_event(message_action_7, EventSource.AGENT)

+ 1 - 1
tests/unit/test_memory.py

@@ -88,7 +88,7 @@ def _create_observation_event(observation: str) -> Event:
     event = Event()
     event._id = -1
     event._timestamp = datetime.now(timezone.utc).isoformat()
-    event._source = EventSource.USER
+    event._source = EventSource.ENVIRONMENT
     event.observation = observation
     return event
 

+ 2 - 2
tests/unit/test_prompt_caching.py

@@ -155,7 +155,7 @@ def test_get_messages_with_cmd_action(codeact_agent, mock_event_stream):
         command='ls -l',
         exit_code=0,
     )
-    mock_event_stream.add_event(cmd_observation_1, EventSource.USER)
+    mock_event_stream.add_event(cmd_observation_1, EventSource.ENVIRONMENT)
 
     message_action_2 = MessageAction("Now, let's create a new directory.")
     mock_event_stream.add_event(message_action_2, EventSource.AGENT)
@@ -169,7 +169,7 @@ def test_get_messages_with_cmd_action(codeact_agent, mock_event_stream):
         command='mkdir new_directory',
         exit_code=0,
     )
-    mock_event_stream.add_event(cmd_observation_2, EventSource.USER)
+    mock_event_stream.add_event(cmd_observation_2, EventSource.ENVIRONMENT)
 
     codeact_agent.reset()
     messages = codeact_agent._get_messages(