test_agent.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. import asyncio
  2. import os
  3. import shutil
  4. import subprocess
  5. import pytest
  6. from openhands.controller.state.state import State
  7. from openhands.core.config import load_app_config
  8. from openhands.core.main import run_controller
  9. from openhands.core.schema import AgentState
  10. from openhands.events.action import AgentFinishAction, AgentRejectAction, MessageAction
  11. from openhands.events.observation.browse import BrowserOutputObservation
  12. from openhands.events.observation.delegate import AgentDelegateObservation
  13. from openhands.runtime import get_runtime_cls
  14. TEST_RUNTIME = os.getenv('TEST_RUNTIME')
  15. assert TEST_RUNTIME in ['eventstream', 'remote']
  16. _ = get_runtime_cls(TEST_RUNTIME) # make sure it does not raise an error
  17. CONFIG = load_app_config()
  18. CONFIG.max_iterations = int(os.getenv('MAX_ITERATIONS', 20))
  19. CONFIG.max_budget_per_task = int(os.getenv('MAX_BUDGET_PER_TASK', 15))
  20. CONFIG.runtime = TEST_RUNTIME
  21. CONFIG.default_agent = os.getenv('DEFAULT_AGENT')
  22. CONFIG.workspace_base = os.getenv('WORKSPACE_BASE')
  23. CONFIG.workspace_mount_path = os.getenv('WORKSPACE_MOUNT_PATH')
  24. CONFIG.workspace_mount_path_in_sandbox = os.getenv(
  25. 'WORKSPACE_MOUNT_PATH_IN_SANDBOX', '/workspace'
  26. )
  27. CONFIG.sandbox.use_host_network = True
  28. print('\nPaths used:')
  29. print(f'workspace_base: {CONFIG.workspace_base}')
  30. print(f'workspace_mount_path: {CONFIG.workspace_mount_path}')
  31. print(f'workspace_mount_path_in_sandbox: {CONFIG.workspace_mount_path_in_sandbox}')
  32. def get_number_of_prompts(test_name: str):
  33. mock_dir = os.path.join(
  34. os.environ['SCRIPT_DIR'],
  35. 'mock',
  36. f'{TEST_RUNTIME}_runtime',
  37. os.environ['DEFAULT_AGENT'],
  38. test_name,
  39. )
  40. prompt_files = [file for file in os.listdir(mock_dir) if file.startswith('prompt_')]
  41. return len(prompt_files)
  42. def validate_final_state(final_state: State | None, test_name: str):
  43. regen = os.getenv('FORCE_REGENERATE', False).lower() in ['true', '1', 'yes']
  44. assert final_state is not None
  45. assert final_state.agent_state == AgentState.STOPPED
  46. if not regen:
  47. assert final_state.last_error is None
  48. # number of LLM conversations should be the same as number of prompt/response
  49. # log files under mock/[runtime]/[agent]/[test_name] folder. If not, it means there are
  50. # redundant prompt/response log files checked into the repository.
  51. num_of_conversations = get_number_of_prompts(test_name)
  52. assert num_of_conversations > 0
  53. # we mock the cost of every conversation to be 1 USD
  54. # assert int(final_state.metrics.accumulated_cost) == num_of_conversations
  55. if final_state.history.has_delegation():
  56. assert final_state.iteration > final_state.local_iteration
  57. else:
  58. assert final_state.local_iteration == final_state.iteration
  59. assert final_state.iteration > 0
  60. @pytest.mark.skipif(
  61. os.getenv('DEFAULT_AGENT') == 'BrowsingAgent',
  62. reason='BrowsingAgent is a specialized agent',
  63. )
  64. @pytest.mark.skipif(
  65. (
  66. os.getenv('DEFAULT_AGENT') == 'CodeActAgent'
  67. or os.getenv('DEFAULT_AGENT') == 'CodeActSWEAgent'
  68. ),
  69. reason='CodeActAgent/CodeActSWEAgent only supports ssh sandbox which is stateful',
  70. )
  71. @pytest.mark.skipif(
  72. os.getenv('DEFAULT_AGENT') == 'ManagerAgent',
  73. reason='Manager agent is not capable of finishing this in reasonable steps yet',
  74. )
  75. def test_write_simple_script(current_test_name: str) -> None:
  76. task = "Write a shell script 'hello.sh' that prints 'hello'. Do not ask me for confirmation at any point."
  77. final_state: State | None = asyncio.run(
  78. run_controller(CONFIG, MessageAction(content=task), exit_on_message=True)
  79. )
  80. validate_final_state(final_state, current_test_name)
  81. # Verify the script file exists
  82. assert CONFIG.workspace_base is not None
  83. script_path = os.path.join(CONFIG.workspace_base, 'hello.sh')
  84. assert os.path.exists(script_path), 'The file "hello.sh" does not exist'
  85. # Run the script and capture the output
  86. result = subprocess.run(['bash', script_path], capture_output=True, text=True)
  87. # Verify the output from the script
  88. assert (
  89. result.stdout.strip() == 'hello'
  90. ), f'Expected output "hello", but got "{result.stdout.strip()}"'
  91. @pytest.mark.skipif(
  92. os.getenv('DEFAULT_AGENT') == 'BrowsingAgent',
  93. reason='BrowsingAgent is a specialized agent',
  94. )
  95. @pytest.mark.skipif(
  96. (
  97. os.getenv('DEFAULT_AGENT') == 'CodeActAgent'
  98. or os.getenv('DEFAULT_AGENT') == 'CodeActSWEAgent'
  99. ),
  100. reason='CodeActAgent/CodeActSWEAgent only supports ssh sandbox which is stateful',
  101. )
  102. @pytest.mark.skipif(
  103. os.getenv('DEFAULT_AGENT') == 'PlannerAgent',
  104. reason='We only keep basic tests for PlannerAgent',
  105. )
  106. def test_edits(current_test_name: str):
  107. # Copy workspace artifacts to workspace_base location
  108. source_dir = os.path.join(os.path.dirname(__file__), 'workspace/test_edits/')
  109. files = os.listdir(source_dir)
  110. for file in files:
  111. dest_file = os.path.join(CONFIG.workspace_base, file)
  112. if os.path.exists(dest_file):
  113. os.remove(dest_file)
  114. shutil.copy(os.path.join(source_dir, file), dest_file)
  115. # Execute the task
  116. task = 'Fix typos in bad.txt. Do not ask me for confirmation at any point.'
  117. final_state: State | None = asyncio.run(
  118. run_controller(CONFIG, MessageAction(content=task), exit_on_message=True)
  119. )
  120. validate_final_state(final_state, current_test_name)
  121. # Verify bad.txt has been fixed
  122. text = """This is a stupid typo.
  123. Really?
  124. No more typos!
  125. Enjoy!
  126. """
  127. with open(os.path.join(CONFIG.workspace_base, 'bad.txt'), 'r') as f:
  128. content = f.read()
  129. assert content.strip() == text.strip()
  130. @pytest.mark.skipif(
  131. os.getenv('DEFAULT_AGENT') != 'CodeActAgent'
  132. and os.getenv('DEFAULT_AGENT') != 'CodeActSWEAgent',
  133. reason='currently only CodeActAgent and CodeActSWEAgent have IPython (Jupyter) execution by default',
  134. )
  135. def test_ipython(current_test_name: str):
  136. # Execute the task
  137. task = "Use Jupyter IPython to write a text file containing 'hello world' to '/workspace/test.txt'. Do not ask me for confirmation at any point."
  138. final_state: State | None = asyncio.run(
  139. run_controller(CONFIG, MessageAction(content=task), exit_on_message=True)
  140. )
  141. validate_final_state(final_state, current_test_name)
  142. # Verify the file exists
  143. file_path = os.path.join(CONFIG.workspace_base, 'test.txt')
  144. assert os.path.exists(file_path), 'The file "test.txt" does not exist'
  145. # Verify the file contains the expected content
  146. with open(file_path, 'r') as f:
  147. content = f.read()
  148. assert (
  149. content.strip() == 'hello world'
  150. ), f'Expected content "hello world", but got "{content.strip()}"'
  151. @pytest.mark.skipif(
  152. os.getenv('DEFAULT_AGENT') != 'ManagerAgent',
  153. reason='Currently, only ManagerAgent supports task rejection',
  154. )
  155. def test_simple_task_rejection(current_test_name: str):
  156. # Give an impossible task to do: cannot write a commit message because
  157. # the workspace is not a git repo
  158. task = 'Write a git commit message for the current staging area. Do not ask me for confirmation at any point.'
  159. final_state: State | None = asyncio.run(
  160. run_controller(CONFIG, MessageAction(content=task), exit_on_message=True)
  161. )
  162. validate_final_state(final_state, current_test_name)
  163. assert isinstance(final_state.history.get_last_action(), AgentRejectAction)
  164. @pytest.mark.skipif(
  165. os.getenv('DEFAULT_AGENT') != 'CodeActAgent'
  166. and os.getenv('DEFAULT_AGENT') != 'CodeActSWEAgent',
  167. reason='currently only CodeActAgent and CodeActSWEAgent have IPython (Jupyter) execution by default',
  168. )
  169. def test_ipython_module(current_test_name: str):
  170. # Execute the task
  171. task = "Install and import pymsgbox==1.0.9 and print it's version in /workspace/test.txt. Do not ask me for confirmation at any point."
  172. final_state: State | None = asyncio.run(
  173. run_controller(CONFIG, MessageAction(content=task), exit_on_message=True)
  174. )
  175. validate_final_state(final_state, current_test_name)
  176. # Verify the file exists
  177. file_path = os.path.join(CONFIG.workspace_base, 'test.txt')
  178. assert os.path.exists(file_path), 'The file "test.txt" does not exist'
  179. # Verify the file contains the expected content
  180. with open(file_path, 'r') as f:
  181. content = f.read()
  182. print(content)
  183. assert (
  184. content.strip().split(' ')[-1] == '1.0.9'
  185. ), f'Expected content "1.0.9", but got "{content.strip()}"'
  186. @pytest.mark.skipif(
  187. os.getenv('DEFAULT_AGENT') != 'BrowsingAgent'
  188. and os.getenv('DEFAULT_AGENT') != 'CodeActAgent',
  189. reason='currently only BrowsingAgent and CodeActAgent are capable of searching the internet',
  190. )
  191. def test_browse_internet(current_test_name: str):
  192. # Execute the task
  193. task = 'Browse localhost:8000, and tell me the ultimate answer to life. Do not ask me for confirmation at any point.'
  194. final_state: State | None = asyncio.run(
  195. run_controller(CONFIG, MessageAction(content=task), exit_on_message=True)
  196. )
  197. validate_final_state(final_state, current_test_name)
  198. # last action
  199. last_action = final_state.history.get_last_action()
  200. assert isinstance(last_action, AgentFinishAction)
  201. # last observation
  202. last_observation = final_state.history.get_last_observation()
  203. assert isinstance(
  204. last_observation, (BrowserOutputObservation, AgentDelegateObservation)
  205. )
  206. if isinstance(last_observation, BrowserOutputObservation):
  207. assert 'OpenHands is all you need!' in last_observation.content
  208. elif isinstance(last_observation, AgentDelegateObservation):
  209. assert 'OpenHands is all you need!' in last_observation.outputs['content']