瀏覽代碼

Add event synchronously (#2700)

* add to event stream sync

* remove async from tests
Engel Nyst 1 年之前
父節點
當前提交
0b8d357bef

+ 4 - 6
opendevin/controller/agent_controller.py

@@ -122,7 +122,7 @@ class AgentController:
         self.state.last_error = message
         if exception:
             self.state.last_error += f': {exception}'
-        await self.event_stream.add_event(ErrorObservation(message), EventSource.AGENT)
+        self.event_stream.add_event(ErrorObservation(message), EventSource.AGENT)
 
     async def add_history(self, action: Action, observation: Observation):
         if isinstance(action, NullAction) and isinstance(observation, NullObservation):
@@ -211,7 +211,7 @@ class AgentController:
         if new_state == AgentState.STOPPED or new_state == AgentState.ERROR:
             self.reset_task()
 
-        await self.event_stream.add_event(
+        self.event_stream.add_event(
             AgentStateChangedObservation('', self.state.agent_state), EventSource.AGENT
         )
 
@@ -221,8 +221,6 @@ class AgentController:
 
     def get_agent_state(self):
         """Returns the current state of the agent task."""
-        if self.delegate is not None:
-            return self.delegate.get_agent_state()
         return self.state.agent_state
 
     async def start_delegate(self, action: AgentDelegateAction):
@@ -301,7 +299,7 @@ class AgentController:
                 # clean up delegate status
                 self.delegate = None
                 self.delegateAction = None
-                await self.event_stream.add_event(obs, EventSource.AGENT)
+                self.event_stream.add_event(obs, EventSource.AGENT)
             return
 
         logger.info(
@@ -358,7 +356,7 @@ class AgentController:
             await self.add_history(action, NullObservation(''))
 
         if not isinstance(action, NullAction):
-            await self.event_stream.add_event(action, EventSource.AGENT)
+            self.event_stream.add_event(action, EventSource.AGENT)
 
         await self.update_state_after_step()
 

+ 2 - 0
opendevin/core/logger.py

@@ -114,6 +114,8 @@ def get_console_handler():
     """
     console_handler = logging.StreamHandler()
     console_handler.setLevel(logging.INFO)
+    if config.debug:
+        console_handler.setLevel(logging.DEBUG)
     console_handler.setFormatter(console_formatter)
     return console_handler
 

+ 4 - 4
opendevin/core/main.py

@@ -100,7 +100,7 @@ async def run_agent_controller(
     # start event is a MessageAction with the task, either resumed or new
     if config.enable_cli_session and initial_state is not None:
         # we're resuming the previous session
-        await event_stream.add_event(
+        event_stream.add_event(
             MessageAction(
                 content="Let's get back on track. If you experienced errors before, do NOT resume your task. Ask me about it."
             ),
@@ -108,7 +108,7 @@ async def run_agent_controller(
         )
     elif initial_state is None:
         # init with the provided task
-        await event_stream.add_event(MessageAction(content=task_str), EventSource.USER)
+        event_stream.add_event(MessageAction(content=task_str), EventSource.USER)
 
     async def on_event(event: Event):
         if isinstance(event, AgentStateChangedObservation):
@@ -120,10 +120,10 @@ async def run_agent_controller(
                 else:
                     message = fake_user_response_fn(controller.get_state())
                 action = MessageAction(content=message)
-                await event_stream.add_event(action, EventSource.USER)
+                event_stream.add_event(action, EventSource.USER)
 
     event_stream.subscribe(EventStreamSubscriber.MAIN, on_event)
-    while controller.get_agent_state() not in [
+    while controller.state.agent_state not in [
         AgentState.FINISHED,
         AgentState.REJECTED,
         AgentState.ERROR,

+ 8 - 9
opendevin/events/stream.py

@@ -1,5 +1,6 @@
 import asyncio
 import json
+import threading
 from datetime import datetime
 from enum import Enum
 from typing import Callable, Iterable
@@ -25,7 +26,7 @@ class EventStream:
     # when there are agent delegates
     _subscribers: dict[str, list[Callable]]
     _cur_id: int
-    _lock: asyncio.Lock
+    _lock: threading.Lock
     _file_store: FileStore
 
     def __init__(self, sid: str):
@@ -33,7 +34,7 @@ class EventStream:
         self._file_store = get_file_store()
         self._subscribers = {}
         self._cur_id = 0
-        self._lock = asyncio.Lock()
+        self._lock = threading.Lock()
         self._reinitialize_from_file_store()
 
     def _reinitialize_from_file_store(self):
@@ -93,12 +94,11 @@ class EventStream:
             if len(self._subscribers[id]) == 0:
                 del self._subscribers[id]
 
-    # TODO: make this not async
-    async def add_event(self, event: Event, source: EventSource):
-        logger.debug(f'Adding event {event} from {source}')
-        async with self._lock:
-            event._id = self._cur_id  # type: ignore[attr-defined]
+    def add_event(self, event: Event, source: EventSource):
+        with self._lock:
+            event._id = self._cur_id  # type: ignore [attr-defined]
             self._cur_id += 1
+        logger.debug(f'Adding {type(event).__name__} id={event.id} from {source.name}')
         event._timestamp = datetime.now()  # type: ignore[attr-defined]
         event._source = source  # type: ignore[attr-defined]
         data = event_to_dict(event)
@@ -108,5 +108,4 @@ class EventStream:
             )
         for stack in self._subscribers.values():
             callback = stack[-1]
-            logger.debug(f'Notifying subscriber {callback} of event {event}')
-            await callback(event)
+            asyncio.create_task(callback(event))

+ 2 - 2
opendevin/runtime/runtime.py

@@ -114,7 +114,7 @@ class Runtime:
             observation = await self.run_action(event)
             observation._cause = event.id  # type: ignore[attr-defined]
             source = event.source if event.source else EventSource.AGENT
-            await self.event_stream.add_event(observation, source)
+            self.event_stream.add_event(observation, source)
 
     async def run_action(self, action: Action) -> Observation:
         """
@@ -149,7 +149,7 @@ class Runtime:
         for _id, cmd in self.sandbox.background_commands.items():
             output = cmd.read_logs()
             if output:
-                await self.event_stream.add_event(
+                self.event_stream.add_event(
                     CmdOutputObservation(
                         content=output, command_id=_id, command=cmd.command
                     ),

+ 4 - 4
opendevin/server/session/session.py

@@ -61,10 +61,10 @@ class Session:
             logger.exception('Error in loop_recv: %s', e)
 
     async def _initialize_agent(self, data: dict):
-        await self.agent_session.event_stream.add_event(
+        self.agent_session.event_stream.add_event(
             ChangeAgentStateAction(AgentState.LOADING), EventSource.USER
         )
-        await self.agent_session.event_stream.add_event(
+        self.agent_session.event_stream.add_event(
             AgentStateChangedObservation('', AgentState.LOADING), EventSource.AGENT
         )
         try:
@@ -75,7 +75,7 @@ class Session:
                 f'Error creating controller. Please check Docker is running and visit `{TROUBLESHOOTING_URL}` for more debugging information..'
             )
             return
-        await self.agent_session.event_stream.add_event(
+        self.agent_session.event_stream.add_event(
             ChangeAgentStateAction(AgentState.INIT), EventSource.USER
         )
 
@@ -102,7 +102,7 @@ class Session:
             await self._initialize_agent(data)
             return
         event = event_from_dict(data.copy())
-        await self.agent_session.event_stream.add_event(event, EventSource.USER)
+        self.agent_session.event_stream.add_event(event, EventSource.USER)
 
     async def send(self, data: dict[str, object]) -> bool:
         try:

+ 7 - 12
tests/unit/test_event_stream.py

@@ -1,7 +1,5 @@
 import json
 
-import pytest
-
 from opendevin.events import EventSource, EventStream
 from opendevin.events.action import NullAction
 from opendevin.events.observation import NullObservation
@@ -11,17 +9,15 @@ def collect_events(stream):
     return [event for event in stream.get_events()]
 
 
-@pytest.mark.asyncio
-async def test_basic_flow():
+def test_basic_flow():
     stream = EventStream('abc')
-    await stream.add_event(NullAction(), EventSource.AGENT)
+    stream.add_event(NullAction(), EventSource.AGENT)
     assert len(collect_events(stream)) == 1
 
 
-@pytest.mark.asyncio
-async def test_stream_storage():
+def test_stream_storage():
     stream = EventStream('def')
-    await stream.add_event(NullObservation(''), EventSource.AGENT)
+    stream.add_event(NullObservation(''), EventSource.AGENT)
     assert len(collect_events(stream)) == 1
     content = stream._file_store.read('sessions/def/events/0.json')
     assert content is not None
@@ -38,11 +34,10 @@ async def test_stream_storage():
     }
 
 
-@pytest.mark.asyncio
-async def test_rehydration():
+def test_rehydration():
     stream1 = EventStream('es1')
-    await stream1.add_event(NullObservation('obs1'), EventSource.AGENT)
-    await stream1.add_event(NullObservation('obs2'), EventSource.AGENT)
+    stream1.add_event(NullObservation('obs1'), EventSource.AGENT)
+    stream1.add_event(NullObservation('obs2'), EventSource.AGENT)
     assert len(collect_events(stream1)) == 2
 
     stream2 = EventStream('es2')