state.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. import base64
  2. import pickle
  3. from dataclasses import dataclass, field
  4. from enum import Enum
  5. from typing import Any
  6. from openhands.controller.state.task import RootTask
  7. from openhands.core.logger import openhands_logger as logger
  8. from openhands.core.metrics import Metrics
  9. from openhands.core.schema import AgentState
  10. from openhands.events.action import (
  11. MessageAction,
  12. )
  13. from openhands.events.action.agent import AgentFinishAction
  14. from openhands.memory.history import ShortTermHistory
  15. from openhands.storage.files import FileStore
  16. class TrafficControlState(str, Enum):
  17. # default state, no rate limiting
  18. NORMAL = 'normal'
  19. # task paused due to traffic control
  20. THROTTLING = 'throttling'
  21. # traffic control is temporarily paused
  22. PAUSED = 'paused'
  23. RESUMABLE_STATES = [
  24. AgentState.RUNNING,
  25. AgentState.PAUSED,
  26. AgentState.AWAITING_USER_INPUT,
  27. AgentState.FINISHED,
  28. ]
  29. @dataclass
  30. class State:
  31. """
  32. Represents the running state of an agent in the OpenHands system, saving data of its operation and memory.
  33. - Multi-agent/delegate state:
  34. - store the task (conversation between the agent and the user)
  35. - the subtask (conversation between an agent and the user or another agent)
  36. - global and local iterations
  37. - delegate levels for multi-agent interactions
  38. - almost stuck state
  39. - Running state of an agent:
  40. - current agent state (e.g., LOADING, RUNNING, PAUSED)
  41. - traffic control state for rate limiting
  42. - confirmation mode
  43. - the last error encountered
  44. - Data for saving and restoring the agent:
  45. - save to and restore from a session
  46. - serialize with pickle and base64
  47. - Save / restore data about message history
  48. - start and end IDs for events in agent's history
  49. - summaries and delegate summaries
  50. - Metrics:
  51. - global metrics for the current task
  52. - local metrics for the current subtask
  53. - Extra data:
  54. - additional task-specific data
  55. """
  56. root_task: RootTask = field(default_factory=RootTask)
  57. # global iteration for the current task
  58. iteration: int = 0
  59. # local iteration for the current subtask
  60. local_iteration: int = 0
  61. # max number of iterations for the current task
  62. max_iterations: int = 100
  63. confirmation_mode: bool = False
  64. history: ShortTermHistory = field(default_factory=ShortTermHistory)
  65. inputs: dict = field(default_factory=dict)
  66. outputs: dict = field(default_factory=dict)
  67. last_error: str | None = None
  68. agent_state: AgentState = AgentState.LOADING
  69. resume_state: AgentState | None = None
  70. traffic_control_state: TrafficControlState = TrafficControlState.NORMAL
  71. # global metrics for the current task
  72. metrics: Metrics = field(default_factory=Metrics)
  73. # local metrics for the current subtask
  74. local_metrics: Metrics = field(default_factory=Metrics)
  75. # root agent has level 0, and every delegate increases the level by one
  76. delegate_level: int = 0
  77. # start_id and end_id track the range of events in history
  78. start_id: int = -1
  79. end_id: int = -1
  80. almost_stuck: int = 0
  81. # NOTE: This will never be used by the controller, but it can be used by different
  82. # evaluation tasks to store extra data needed to track the progress/state of the task.
  83. extra_data: dict[str, Any] = field(default_factory=dict)
  84. def save_to_session(self, sid: str, file_store: FileStore):
  85. pickled = pickle.dumps(self)
  86. logger.debug(f'Saving state to session {sid}:{self.agent_state}')
  87. encoded = base64.b64encode(pickled).decode('utf-8')
  88. try:
  89. file_store.write(f'sessions/{sid}/agent_state.pkl', encoded)
  90. except Exception as e:
  91. logger.error(f'Failed to save state to session: {e}')
  92. raise e
  93. @staticmethod
  94. def restore_from_session(sid: str, file_store: FileStore) -> 'State':
  95. try:
  96. encoded = file_store.read(f'sessions/{sid}/agent_state.pkl')
  97. pickled = base64.b64decode(encoded)
  98. state = pickle.loads(pickled)
  99. except Exception as e:
  100. logger.error(f'Failed to restore state from session: {e}')
  101. raise e
  102. # update state
  103. if state.agent_state in RESUMABLE_STATES:
  104. state.resume_state = state.agent_state
  105. else:
  106. state.resume_state = None
  107. # don't carry last_error anymore after restore
  108. state.last_error = None
  109. # first state after restore
  110. state.agent_state = AgentState.LOADING
  111. return state
  112. def __getstate__(self):
  113. state = self.__dict__.copy()
  114. # save the relevant data from recent history
  115. # so that we can restore it when the state is restored
  116. if 'history' in state:
  117. state['start_id'] = state['history'].start_id
  118. state['end_id'] = state['history'].end_id
  119. # don't save history object itself
  120. state.pop('history', None)
  121. return state
  122. def __setstate__(self, state):
  123. self.__dict__.update(state)
  124. # recreate the history object
  125. if not hasattr(self, 'history'):
  126. self.history = ShortTermHistory()
  127. # restore the relevant data in history from the state
  128. self.history.start_id = self.start_id
  129. self.history.end_id = self.end_id
  130. # remove the restored data from the state if any
  131. def get_current_user_intent(self):
  132. """Returns the latest user message and image(if provided) that appears after a FinishAction, or the first (the task) if nothing was finished yet."""
  133. last_user_message = None
  134. last_user_message_image_urls: list[str] | None = []
  135. for event in self.history.get_events(reverse=True):
  136. if isinstance(event, MessageAction) and event.source == 'user':
  137. last_user_message = event.content
  138. last_user_message_image_urls = event.images_urls
  139. elif isinstance(event, AgentFinishAction):
  140. if last_user_message is not None:
  141. return last_user_message
  142. return last_user_message, last_user_message_image_urls