agent_controller.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
import asyncio
import traceback
from typing import Callable, List, Optional

from opendevin.plan import Plan
from opendevin.state import State
from opendevin.agent import Agent
from opendevin.action import (
    Action,
    NullAction,
    AgentFinishAction,
    AddTaskAction,
    ModifyTaskAction,
)
from opendevin.observation import (
    Observation,
    AgentErrorObservation,
    NullObservation,
)

from .command_manager import CommandManager


  20. def print_with_indent(text: str):
  21. print("\t"+text.replace("\n","\n\t"), flush=True)
  22. class AgentController:
  23. def __init__(
  24. self,
  25. agent: Agent,
  26. workdir: str,
  27. max_iterations: int = 100,
  28. callbacks: List[Callable] = [],
  29. ):
  30. self.agent = agent
  31. self.max_iterations = max_iterations
  32. self.workdir = workdir
  33. self.command_manager = CommandManager(workdir)
  34. self.callbacks = callbacks
  35. def update_state_for_step(self, i):
  36. self.state.iteration = i
  37. self.state.background_commands_obs = self.command_manager.get_background_obs()
  38. def update_state_after_step(self):
  39. self.state.updated_info = []
  40. def add_history(self, action: Action, observation: Observation):
  41. if not isinstance(action, Action):
  42. raise ValueError("action must be an instance of Action")
  43. if not isinstance(observation, Observation):
  44. raise ValueError("observation must be an instance of Observation")
  45. self.state.history.append((action, observation))
  46. self.state.updated_info.append((action, observation))
  47. async def start_loop(self, task: str):
  48. finished = False
  49. plan = Plan(task)
  50. self.state = State(plan)
  51. for i in range(self.max_iterations):
  52. try:
  53. finished = await self.step(i)
  54. except Exception as e:
  55. print("Error in loop", e, flush=True)
  56. traceback.print_exc()
  57. break
  58. if finished:
  59. break
  60. if not finished:
  61. print("Exited before finishing", flush=True)
  62. async def step(self, i: int):
  63. print("\n\n==============", flush=True)
  64. print("STEP", i, flush=True)
  65. print_with_indent("\nPLAN:\n")
  66. print_with_indent(self.state.plan.__str__())
  67. log_obs = self.command_manager.get_background_obs()
  68. for obs in log_obs:
  69. self.add_history(NullAction(), obs)
  70. await self._run_callbacks(obs)
  71. print_with_indent("\nBACKGROUND LOG:\n%s" % obs)
  72. self.update_state_for_step(i)
  73. action: Action = NullAction()
  74. observation: Observation = NullObservation("")
  75. try:
  76. action = self.agent.step(self.state)
  77. if action is None:
  78. raise ValueError("Agent must return an action")
  79. print_with_indent("\nACTION:\n%s" % action)
  80. except Exception as e:
  81. observation = AgentErrorObservation(str(e))
  82. print_with_indent("\nAGENT ERROR:\n%s" % observation)
  83. traceback.print_exc()
  84. self.update_state_after_step()
  85. await self._run_callbacks(action)
  86. finished = isinstance(action, AgentFinishAction)
  87. if finished:
  88. print_with_indent("\nFINISHED")
  89. return True
  90. if isinstance(action, AddTaskAction):
  91. try:
  92. self.state.plan.add_subtask(action.parent, action.goal, action.subtasks)
  93. except Exception as e:
  94. observation = AgentErrorObservation(str(e))
  95. print_with_indent("\nADD TASK ERROR:\n%s" % observation)
  96. traceback.print_exc()
  97. elif isinstance(action, ModifyTaskAction):
  98. try:
  99. self.state.plan.set_subtask_state(action.id, action.state)
  100. except Exception as e:
  101. observation = AgentErrorObservation(str(e))
  102. print_with_indent("\nMODIFY TASK ERROR:\n%s" % observation)
  103. traceback.print_exc()
  104. if action.executable:
  105. try:
  106. observation = action.run(self)
  107. except Exception as e:
  108. observation = AgentErrorObservation(str(e))
  109. print_with_indent("\nACTION RUN ERROR:\n%s" % observation)
  110. traceback.print_exc()
  111. if not isinstance(observation, NullObservation):
  112. print_with_indent("\nOBSERVATION:\n%s" % observation)
  113. self.add_history(action, observation)
  114. await self._run_callbacks(observation)
  115. async def _run_callbacks(self, event):
  116. if event is None:
  117. return
  118. for callback in self.callbacks:
  119. idx = self.callbacks.index(callback)
  120. try:
  121. callback(event)
  122. except Exception as e:
  123. print("Callback error:" + str(idx), e, flush=True)
  124. pass
  125. await asyncio.sleep(0.001) # Give back control for a tick, so we can await in callbacks