state.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. import base64
  2. import pickle
  3. from dataclasses import dataclass, field
  4. from enum import Enum
  5. from typing import Any
  6. from openhands.controller.state.task import RootTask
  7. from openhands.core.logger import openhands_logger as logger
  8. from openhands.core.schema import AgentState
  9. from openhands.events.action import (
  10. MessageAction,
  11. )
  12. from openhands.events.action.agent import AgentFinishAction
  13. from openhands.events.event import Event, EventSource
  14. from openhands.llm.metrics import Metrics
  15. from openhands.storage.files import FileStore
  16. class TrafficControlState(str, Enum):
  17. # default state, no rate limiting
  18. NORMAL = 'normal'
  19. # task paused due to traffic control
  20. THROTTLING = 'throttling'
  21. # traffic control is temporarily paused
  22. PAUSED = 'paused'
  23. RESUMABLE_STATES = [
  24. AgentState.RUNNING,
  25. AgentState.PAUSED,
  26. AgentState.AWAITING_USER_INPUT,
  27. AgentState.FINISHED,
  28. ]
  29. @dataclass
  30. class State:
  31. """
  32. Represents the running state of an agent in the OpenHands system, saving data of its operation and memory.
  33. - Multi-agent/delegate state:
  34. - store the task (conversation between the agent and the user)
  35. - the subtask (conversation between an agent and the user or another agent)
  36. - global and local iterations
  37. - delegate levels for multi-agent interactions
  38. - almost stuck state
  39. - Running state of an agent:
  40. - current agent state (e.g., LOADING, RUNNING, PAUSED)
  41. - traffic control state for rate limiting
  42. - confirmation mode
  43. - the last error encountered
  44. - Data for saving and restoring the agent:
  45. - save to and restore from a session
  46. - serialize with pickle and base64
  47. - Save / restore data about message history
  48. - start and end IDs for events in agent's history
  49. - summaries and delegate summaries
  50. - Metrics:
  51. - global metrics for the current task
  52. - local metrics for the current subtask
  53. - Extra data:
  54. - additional task-specific data
  55. """
  56. root_task: RootTask = field(default_factory=RootTask)
  57. # global iteration for the current task
  58. iteration: int = 0
  59. # local iteration for the current subtask
  60. local_iteration: int = 0
  61. # max number of iterations for the current task
  62. max_iterations: int = 100
  63. confirmation_mode: bool = False
  64. history: list[Event] = field(default_factory=list)
  65. inputs: dict = field(default_factory=dict)
  66. outputs: dict = field(default_factory=dict)
  67. agent_state: AgentState = AgentState.LOADING
  68. resume_state: AgentState | None = None
  69. traffic_control_state: TrafficControlState = TrafficControlState.NORMAL
  70. # global metrics for the current task
  71. metrics: Metrics = field(default_factory=Metrics)
  72. # local metrics for the current subtask
  73. local_metrics: Metrics = field(default_factory=Metrics)
  74. # root agent has level 0, and every delegate increases the level by one
  75. delegate_level: int = 0
  76. # start_id and end_id track the range of events in history
  77. start_id: int = -1
  78. end_id: int = -1
  79. # truncation_id tracks where to load history after context window truncation
  80. truncation_id: int = -1
  81. delegates: dict[tuple[int, int], tuple[str, str]] = field(default_factory=dict)
  82. # NOTE: This will never be used by the controller, but it can be used by different
  83. # evaluation tasks to store extra data needed to track the progress/state of the task.
  84. extra_data: dict[str, Any] = field(default_factory=dict)
  85. last_error: str = ''
  86. def save_to_session(self, sid: str, file_store: FileStore):
  87. pickled = pickle.dumps(self)
  88. logger.debug(f'Saving state to session {sid}:{self.agent_state}')
  89. encoded = base64.b64encode(pickled).decode('utf-8')
  90. try:
  91. file_store.write(f'sessions/{sid}/agent_state.pkl', encoded)
  92. except Exception as e:
  93. logger.error(f'Failed to save state to session: {e}')
  94. raise e
  95. @staticmethod
  96. def restore_from_session(sid: str, file_store: FileStore) -> 'State':
  97. try:
  98. encoded = file_store.read(f'sessions/{sid}/agent_state.pkl')
  99. pickled = base64.b64decode(encoded)
  100. state = pickle.loads(pickled)
  101. except Exception as e:
  102. logger.warning(f'Could not restore state from session: {e}')
  103. raise e
  104. # update state
  105. if state.agent_state in RESUMABLE_STATES:
  106. state.resume_state = state.agent_state
  107. else:
  108. state.resume_state = None
  109. # first state after restore
  110. state.agent_state = AgentState.LOADING
  111. return state
  112. def __getstate__(self):
  113. # don't pickle history, it will be restored from the event stream
  114. state = self.__dict__.copy()
  115. state['history'] = []
  116. return state
  117. def __setstate__(self, state):
  118. self.__dict__.update(state)
  119. # make sure we always have the attribute history
  120. if not hasattr(self, 'history'):
  121. self.history = []
  122. def get_current_user_intent(self) -> tuple[str | None, list[str] | None]:
  123. """Returns the latest user message and image(if provided) that appears after a FinishAction, or the first (the task) if nothing was finished yet."""
  124. last_user_message = None
  125. last_user_message_image_urls: list[str] | None = []
  126. for event in reversed(self.history):
  127. if isinstance(event, MessageAction) and event.source == 'user':
  128. last_user_message = event.content
  129. last_user_message_image_urls = event.image_urls
  130. elif isinstance(event, AgentFinishAction):
  131. if last_user_message is not None:
  132. return last_user_message, None
  133. return last_user_message, last_user_message_image_urls
  134. def get_last_agent_message(self) -> MessageAction | None:
  135. for event in reversed(self.history):
  136. if isinstance(event, MessageAction) and event.source == EventSource.AGENT:
  137. return event
  138. return None
  139. def get_last_user_message(self) -> MessageAction | None:
  140. for event in reversed(self.history):
  141. if isinstance(event, MessageAction) and event.source == EventSource.USER:
  142. return event
  143. return None