agent.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
  1. from typing import TypedDict, Union
  2. from openhands.controller.agent import Agent
  3. from openhands.controller.state.state import State
  4. from openhands.core.config import AgentConfig
  5. from openhands.core.schema import AgentState
  6. from openhands.events.action import (
  7. Action,
  8. AddTaskAction,
  9. AgentFinishAction,
  10. AgentRejectAction,
  11. BrowseInteractiveAction,
  12. BrowseURLAction,
  13. CmdRunAction,
  14. FileReadAction,
  15. FileWriteAction,
  16. MessageAction,
  17. ModifyTaskAction,
  18. )
  19. from openhands.events.observation import (
  20. AgentStateChangedObservation,
  21. CmdOutputObservation,
  22. FileReadObservation,
  23. FileWriteObservation,
  24. NullObservation,
  25. Observation,
  26. )
  27. from openhands.events.serialization.event import event_to_dict
  28. from openhands.llm.llm import LLM
"""
FIXME: This agent surfaced a few problems:
* File writes seem to add an unintended newline at the end of the file.
* Browser actions do not work.
"""
  34. ActionObs = TypedDict(
  35. 'ActionObs', {'action': Action, 'observations': list[Observation]}
  36. )
  37. class DummyAgent(Agent):
  38. VERSION = '1.0'
  39. """
  40. The DummyAgent is used for e2e testing. It just sends the same set of actions deterministically,
  41. without making any LLM calls.
  42. """
  43. def __init__(self, llm: LLM, config: AgentConfig):
  44. super().__init__(llm, config)
  45. self.steps: list[ActionObs] = [
  46. {
  47. 'action': AddTaskAction(
  48. parent='None', goal='check the current directory'
  49. ),
  50. 'observations': [],
  51. },
  52. {
  53. 'action': AddTaskAction(parent='0', goal='run ls'),
  54. 'observations': [],
  55. },
  56. {
  57. 'action': ModifyTaskAction(task_id='0', state='in_progress'),
  58. 'observations': [],
  59. },
  60. {
  61. 'action': MessageAction('Time to get started!'),
  62. 'observations': [],
  63. },
  64. {
  65. 'action': CmdRunAction(command='echo "foo"'),
  66. 'observations': [
  67. CmdOutputObservation(
  68. 'foo', command_id=-1, command='echo "foo"', exit_code=0
  69. )
  70. ],
  71. },
  72. {
  73. 'action': FileWriteAction(
  74. content='echo "Hello, World!"', path='hello.sh'
  75. ),
  76. 'observations': [
  77. FileWriteObservation(
  78. content='echo "Hello, World!"', path='hello.sh'
  79. )
  80. ],
  81. },
  82. {
  83. 'action': FileReadAction(path='hello.sh'),
  84. 'observations': [
  85. FileReadObservation('echo "Hello, World!"\n', path='hello.sh')
  86. ],
  87. },
  88. {
  89. 'action': CmdRunAction(command='bash hello.sh'),
  90. 'observations': [
  91. CmdOutputObservation(
  92. 'bash: hello.sh: No such file or directory',
  93. command_id=-1,
  94. command='bash workspace/hello.sh',
  95. exit_code=127,
  96. )
  97. ],
  98. },
  99. {
  100. 'action': BrowseURLAction(url='https://google.com'),
  101. 'observations': [
  102. # BrowserOutputObservation('<html><body>Simulated Google page</body></html>',url='https://google.com',screenshot=''),
  103. ],
  104. },
  105. {
  106. 'action': BrowseInteractiveAction(
  107. browser_actions='goto("https://google.com")'
  108. ),
  109. 'observations': [
  110. # BrowserOutputObservation('<html><body>Simulated Google page after interaction</body></html>',url='https://google.com',screenshot=''),
  111. ],
  112. },
  113. {
  114. 'action': AgentRejectAction(),
  115. 'observations': [NullObservation('')],
  116. },
  117. {
  118. 'action': AgentFinishAction(
  119. outputs={}, thought='Task completed', action='finish'
  120. ),
  121. 'observations': [AgentStateChangedObservation('', AgentState.FINISHED)],
  122. },
  123. ]
  124. def step(self, state: State) -> Action:
  125. if state.iteration >= len(self.steps):
  126. return AgentFinishAction()
  127. current_step = self.steps[state.iteration]
  128. action = current_step['action']
  129. # If the action is AddTaskAction or ModifyTaskAction, update the parent ID or task_id
  130. if isinstance(action, AddTaskAction):
  131. if action.parent == 'None':
  132. action.parent = '' # Root task has no parent
  133. elif action.parent == '0':
  134. action.parent = state.root_task.id
  135. elif action.parent.startswith('0.'):
  136. action.parent = f'{state.root_task.id}{action.parent[1:]}'
  137. elif isinstance(action, ModifyTaskAction):
  138. if action.task_id == '0':
  139. action.task_id = state.root_task.id
  140. elif action.task_id.startswith('0.'):
  141. action.task_id = f'{state.root_task.id}{action.task_id[1:]}'
  142. # Ensure the task_id doesn't start with a dot
  143. if action.task_id.startswith('.'):
  144. action.task_id = action.task_id[1:]
  145. elif isinstance(action, (BrowseURLAction, BrowseInteractiveAction)):
  146. try:
  147. return self.simulate_browser_action(action)
  148. except (
  149. Exception
  150. ): # This could be a specific exception for browser unavailability
  151. return self.handle_browser_unavailable(action)
  152. if state.iteration > 0:
  153. prev_step = self.steps[state.iteration - 1]
  154. if 'observations' in prev_step and prev_step['observations']:
  155. expected_observations = prev_step['observations']
  156. hist_events = state.history.get_last_events(len(expected_observations))
  157. if len(hist_events) < len(expected_observations):
  158. print(
  159. f'Warning: Expected {len(expected_observations)} observations, but got {len(hist_events)}'
  160. )
  161. for i in range(min(len(expected_observations), len(hist_events))):
  162. hist_obs = event_to_dict(hist_events[i])
  163. expected_obs = event_to_dict(expected_observations[i])
  164. # Remove dynamic fields for comparison
  165. for obs in [hist_obs, expected_obs]:
  166. obs.pop('id', None)
  167. obs.pop('timestamp', None)
  168. obs.pop('cause', None)
  169. obs.pop('source', None)
  170. if 'extras' in obs:
  171. obs['extras'].pop('command_id', None)
  172. if hist_obs != expected_obs:
  173. print(
  174. f'Warning: Observation mismatch. Expected {expected_obs}, got {hist_obs}'
  175. )
  176. return action
  177. def simulate_browser_action(
  178. self, action: Union[BrowseURLAction, BrowseInteractiveAction]
  179. ) -> Action:
  180. # Instead of simulating, we'll reject the browser action
  181. return self.handle_browser_unavailable(action)
  182. def handle_browser_unavailable(
  183. self, action: Union[BrowseURLAction, BrowseInteractiveAction]
  184. ) -> Action:
  185. # Create a message action to inform that browsing is not available
  186. message = 'Browser actions are not available in the DummyAgent environment.'
  187. if isinstance(action, BrowseURLAction):
  188. message += f' Unable to browse URL: {action.url}'
  189. elif isinstance(action, BrowseInteractiveAction):
  190. message += (
  191. f' Unable to perform interactive browsing: {action.browser_actions}'
  192. )
  193. return MessageAction(content=message)