# __init__.py (4.2 KB)
import asyncio
import traceback
from typing import List, Callable, Optional, Tuple

from opendevin.state import State
from opendevin.agent import Agent
from opendevin.action import (
    Action,
    NullAction,
    FileReadAction,
    FileWriteAction,
    AgentFinishAction,
)
from opendevin.observation import (
    Observation,
    AgentErrorObservation,
    NullObservation,
)

from .command_manager import CommandManager
  19. def print_with_indent(text: str):
  20. print("\t"+text.replace("\n","\n\t"), flush=True)
  21. class AgentController:
  22. def __init__(
  23. self,
  24. agent: Agent,
  25. workdir: str,
  26. max_iterations: int = 100,
  27. callbacks: List[Callable] = [],
  28. ):
  29. self.agent = agent
  30. self.max_iterations = max_iterations
  31. self.workdir = workdir
  32. self.command_manager = CommandManager(workdir)
  33. self.callbacks = callbacks
  34. self.state_updated_info: List[Tuple[Action, Observation]] = []
  35. def get_current_state(self) -> State:
  36. # update observations & actions
  37. state = State(
  38. background_commands_obs=self.command_manager.get_background_obs(),
  39. updated_info=self.state_updated_info,
  40. )
  41. self.state_updated_info = []
  42. return state
  43. def add_history(self, action: Action, observation: Observation):
  44. if not isinstance(action, Action):
  45. raise ValueError("action must be an instance of Action")
  46. if not isinstance(observation, Observation):
  47. raise ValueError("observation must be an instance of Observation")
  48. self.state_updated_info.append((action, observation))
  49. async def start_loop(self, task_instruction: str):
  50. finished = False
  51. self.agent.instruction = task_instruction
  52. for i in range(self.max_iterations):
  53. try:
  54. finished = await self.step(i)
  55. except Exception as e:
  56. print("Error in loop", e, flush=True)
  57. traceback.print_exc()
  58. break
  59. if finished:
  60. break
  61. if not finished:
  62. print("Exited before finishing", flush=True)
  63. async def step(self, i: int):
  64. print("\n\n==============", flush=True)
  65. print("STEP", i, flush=True)
  66. log_obs = self.command_manager.get_background_obs()
  67. for obs in log_obs:
  68. self.add_history(NullAction(), obs)
  69. await self._run_callbacks(obs)
  70. print_with_indent("\nBACKGROUND LOG:\n%s" % obs)
  71. state: State = self.get_current_state()
  72. action: Action = NullAction()
  73. observation: Observation = NullObservation("")
  74. try:
  75. action = self.agent.step(state)
  76. print_with_indent("\nACTION:\n%s" % action)
  77. except Exception as e:
  78. observation = AgentErrorObservation(str(e))
  79. print_with_indent("\nAGENT ERROR:\n%s" % observation)
  80. traceback.print_exc()
  81. await self._run_callbacks(action)
  82. if isinstance(action, AgentFinishAction):
  83. print_with_indent("\nFINISHED")
  84. return True
  85. if isinstance(action, (FileReadAction, FileWriteAction)):
  86. action_cls = action.__class__
  87. _kwargs = action.__dict__
  88. _kwargs["base_path"] = self.workdir
  89. action = action_cls(**_kwargs)
  90. print(action, flush=True)
  91. if action.executable:
  92. try:
  93. observation = action.run(self)
  94. except Exception as e:
  95. observation = AgentErrorObservation(str(e))
  96. print_with_indent("\nACTION RUN ERROR:\n%s" % observation)
  97. traceback.print_exc()
  98. if not isinstance(observation, NullObservation):
  99. print_with_indent("\nOBSERVATION:\n%s" % observation)
  100. self.add_history(action, observation)
  101. await self._run_callbacks(observation)
  102. async def _run_callbacks(self, event):
  103. if event is None:
  104. return
  105. for callback in self.callbacks:
  106. idx = self.callbacks.index(callback)
  107. try:
  108. callback(event)
  109. except Exception as e:
  110. print("Callback error:" + str(idx), e, flush=True)
  111. pass
  112. await asyncio.sleep(0.001) # Give back control for a tick, so we can await in callbacks