Просмотр исходного кода

fix session state after resuming (#1999)

* fix state resuming

* fix session reconnection

* fix lint
Robert Brennan 1 год назад
Родитель
Сommit
ea9c785075

+ 17 - 5
frontend/src/services/session.ts

@@ -7,6 +7,8 @@ import { getSettings } from "./settings";
 class Session {
   private static _socket: WebSocket | null = null;
 
+  private static _latest_event_id: number = -1;
+
   // callbacks contain a list of callable functions
   // event: function, like:
   // open: [function1, function2]
@@ -25,11 +27,10 @@ class Session {
   private static _disconnecting = false;
 
   public static restoreOrStartNewSession() {
-    const token = getToken();
     if (Session.isConnected()) {
       Session.disconnect();
     }
-    Session._connect(token);
+    Session._connect();
   }
 
   public static startNewSession() {
@@ -44,13 +45,20 @@ class Session {
     Session.send(eventString);
   };
 
-  private static _connect(token: string = ""): void {
+  private static _connect(): void {
     if (Session.isConnected()) return;
     Session._connecting = true;
 
     const protocol = window.location.protocol === "https:" ? "wss:" : "ws:";
-    const WS_URL = `${protocol}//${window.location.host}/ws?token=${token}`;
-    Session._socket = new WebSocket(WS_URL);
+    let wsURL = `${protocol}//${window.location.host}/ws`;
+    const token = getToken();
+    if (token) {
+      wsURL += `?token=${token}`;
+      if (Session._latest_event_id !== -1) {
+        wsURL += `&latest_event_id=${Session._latest_event_id}`;
+      }
+    }
+    Session._socket = new WebSocket(wsURL);
     Session._setupSocket();
   }
 
@@ -77,10 +85,14 @@ class Session {
         return;
       }
       if (data.error && data.error_code === 401) {
+        Session._latest_event_id = -1;
         clearToken();
       } else if (data.token) {
         setToken(data.token);
       } else {
+        if (data.id !== undefined) {
+          Session._latest_event_id = data.id;
+        }
         handleAssistantMessage(data);
       }
     };

+ 5 - 0
opendevin/controller/agent_controller.py

@@ -158,6 +158,7 @@ class AgentController:
         logger.info(
             f'Setting agent({type(self.agent).__name__}) state from {self.state.agent_state} to {new_state}'
         )
+
         if new_state == self.state.agent_state:
             return
 
@@ -169,6 +170,10 @@ class AgentController:
             AgentStateChangedObservation('', self.state.agent_state), EventSource.AGENT
         )
 
+        if new_state == AgentState.INIT and self.state.resume_state:
+            await self.set_agent_state_to(self.state.resume_state)
+            self.state.resume_state = None
+
     def get_agent_state(self):
         """Returns the current state of the agent task."""
         return self.state.agent_state

+ 13 - 0
opendevin/controller/state/state.py

@@ -16,6 +16,13 @@ from opendevin.events.observation import (
 )
 from opendevin.storage import get_file_store
 
+RESUMABLE_STATES = [
+    AgentState.RUNNING,
+    AgentState.PAUSED,
+    AgentState.AWAITING_USER_INPUT,
+    AgentState.FINISHED,
+]
+
 
 @dataclass
 class State:
@@ -31,6 +38,7 @@ class State:
     outputs: dict = field(default_factory=dict)
     error: str | None = None
     agent_state: AgentState = AgentState.LOADING
+    resume_state: AgentState | None = None
     metrics: Metrics = Metrics()
 
     def save_to_session(self, sid: str):
@@ -53,6 +61,11 @@ class State:
         except Exception as e:
             logger.error(f'Failed to restore state from session: {e}')
             raise e
+        if state.agent_state in RESUMABLE_STATES:
+            state.resume_state = state.agent_state
+        else:
+            state.resume_state = None
+        state.agent_state = AgentState.LOADING
         return state
 
     def get_current_user_intent(self):

+ 10 - 9
opendevin/events/stream.py

@@ -52,15 +52,16 @@ class EventStream:
         return int(filename.split('/')[-1].split('.')[0])
 
     def get_events(self, start_id=0, end_id=None) -> Iterable[Event]:
-        try:
-            events = self._file_store.list(f'sessions/{self.sid}/events')
-        except FileNotFoundError:
-            return
-        for event_str in events:
-            id = self._get_id_from_filename(event_str)
-            if start_id <= id and (end_id is None or id <= end_id):
-                event = self.get_event(id)
-                yield event
+        event_id = start_id
+        while True:
+            if end_id is not None and event_id > end_id:
+                break
+            try:
+                event = self.get_event(event_id)
+            except FileNotFoundError:
+                break
+            yield event
+            event_id += 1
 
     def get_event(self, id: int) -> Event:
         filename = self._get_filename_for_id(id)

+ 4 - 4
opendevin/server/listen.py

@@ -154,11 +154,11 @@ async def websocket_endpoint(websocket: WebSocket):
     session = session_manager.add_or_restart_session(sid, websocket)
     await websocket.send_json({'token': token, 'status': 'ok'})
 
-    last_event_id = -1
-    if websocket.query_params.get('last_event_id'):
-        last_event_id = int(websocket.query_params.get('last_event_id'))
+    latest_event_id = -1
+    if websocket.query_params.get('latest_event_id'):
+        latest_event_id = int(websocket.query_params.get('latest_event_id'))
     for event in session.agent_session.event_stream.get_events(
-        start_id=last_event_id + 1
+        start_id=latest_event_id + 1
     ):
         if isinstance(event, NullAction) or isinstance(event, NullObservation):
             continue