state.py 3.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. import base64
  2. import pickle
  3. from dataclasses import dataclass, field
  4. from enum import Enum
  5. from opendevin.controller.state.task import RootTask
  6. from opendevin.core.logger import opendevin_logger as logger
  7. from opendevin.core.metrics import Metrics
  8. from opendevin.core.schema import AgentState
  9. from opendevin.events.action import (
  10. Action,
  11. MessageAction,
  12. )
  13. from opendevin.events.observation import (
  14. CmdOutputObservation,
  15. Observation,
  16. )
  17. from opendevin.storage import get_file_store
  18. class TRAFFIC_CONTROL_STATE(str, Enum):
  19. # default state, no rate limiting
  20. NORMAL = 'normal'
  21. # task paused due to traffic control
  22. THROTTLING = 'throttling'
  23. # traffic control is temporarily paused
  24. PAUSED = 'paused'
  25. RESUMABLE_STATES = [
  26. AgentState.RUNNING,
  27. AgentState.PAUSED,
  28. AgentState.AWAITING_USER_INPUT,
  29. AgentState.FINISHED,
  30. ]
  31. @dataclass
  32. class State:
  33. root_task: RootTask = field(default_factory=RootTask)
  34. iteration: int = 0
  35. max_iterations: int = 100
  36. background_commands_obs: list[CmdOutputObservation] = field(default_factory=list)
  37. history: list[tuple[Action, Observation]] = field(default_factory=list)
  38. updated_info: list[tuple[Action, Observation]] = field(default_factory=list)
  39. inputs: dict = field(default_factory=dict)
  40. outputs: dict = field(default_factory=dict)
  41. last_error: str | None = None
  42. agent_state: AgentState = AgentState.LOADING
  43. resume_state: AgentState | None = None
  44. traffic_control_state: TRAFFIC_CONTROL_STATE = TRAFFIC_CONTROL_STATE.NORMAL
  45. metrics: Metrics = Metrics()
  46. # root agent has level 0, and every delegate increases the level by one
  47. delegate_level: int = 0
  48. def save_to_session(self, sid: str):
  49. fs = get_file_store()
  50. pickled = pickle.dumps(self)
  51. encoded = base64.b64encode(pickled).decode('utf-8')
  52. try:
  53. fs.write(f'sessions/{sid}/agent_state.pkl', encoded)
  54. except Exception as e:
  55. logger.error(f'Failed to save state to session: {e}')
  56. raise e
  57. @staticmethod
  58. def restore_from_session(sid: str) -> 'State':
  59. fs = get_file_store()
  60. try:
  61. encoded = fs.read(f'sessions/{sid}/agent_state.pkl')
  62. pickled = base64.b64decode(encoded)
  63. state = pickle.loads(pickled)
  64. except Exception as e:
  65. logger.error(f'Failed to restore state from session: {e}')
  66. raise e
  67. if state.agent_state in RESUMABLE_STATES:
  68. state.resume_state = state.agent_state
  69. else:
  70. state.resume_state = None
  71. state.agent_state = AgentState.LOADING
  72. return state
  73. def get_current_user_intent(self):
  74. # TODO: this is used to understand the user's main goal, but it's possible
  75. # the latest message is an interruption. We should look for a space where
  76. # the agent goes to FINISHED, and then look for the next user message.
  77. for action, obs in reversed(self.history):
  78. if isinstance(action, MessageAction) and action.source == 'user':
  79. return action.content