test_agent.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. import asyncio
  2. import os
  3. import shutil
  4. import subprocess
  5. import pytest
  6. from openhands.controller.state.state import State
  7. from openhands.core.config import AppConfig, SandboxConfig, load_from_env
  8. from openhands.core.main import run_controller
  9. from openhands.core.schema import AgentState
  10. from openhands.events.action import (
  11. AgentFinishAction,
  12. AgentRejectAction,
  13. )
  14. from openhands.events.observation.browse import BrowserOutputObservation
  15. from openhands.events.observation.delegate import AgentDelegateObservation
  16. from openhands.runtime import get_runtime_cls
  17. TEST_RUNTIME = os.getenv('TEST_RUNTIME')
  18. assert TEST_RUNTIME in ['eventstream', 'remote']
  19. _ = get_runtime_cls(TEST_RUNTIME) # make sure it does not raise an error
  20. CONFIG = AppConfig(
  21. max_iterations=int(os.getenv('MAX_ITERATIONS', 20)),
  22. max_budget_per_task=int(os.getenv('MAX_BUDGET_PER_TASK', 15)),
  23. runtime=TEST_RUNTIME,
  24. default_agent=os.getenv('DEFAULT_AGENT'),
  25. workspace_base=os.getenv('WORKSPACE_BASE'),
  26. workspace_mount_path=os.getenv('WORKSPACE_MOUNT_PATH'),
  27. sandbox=SandboxConfig(
  28. use_host_network=True,
  29. ),
  30. )
  31. load_from_env(CONFIG, os.environ)
  32. print('\nPaths used:')
  33. print(f'workspace_base: {CONFIG.workspace_base}')
  34. print(f'workspace_mount_path: {CONFIG.workspace_mount_path}')
  35. print(f'workspace_mount_path_in_sandbox: {CONFIG.workspace_mount_path_in_sandbox}')
  36. # Check if running in WSL environment
  37. if 'WSL_DISTRO_NAME' in os.environ:
  38. if (
  39. CONFIG.workspace_base
  40. and CONFIG.workspace_mount_path
  41. and CONFIG.workspace_base != CONFIG.workspace_mount_path
  42. ):
  43. print(
  44. '\n**********\nWARNING: if WORKSPACE_MOUNT_PATH is set differently to'
  45. '\nWORKSPACE_BASE some file operation tests may fail!\n**********\n'
  46. )
  47. def get_number_of_prompts(test_name: str):
  48. mock_dir = os.path.join(
  49. os.environ['SCRIPT_DIR'],
  50. 'mock',
  51. f'{TEST_RUNTIME}_runtime',
  52. os.environ['DEFAULT_AGENT'],
  53. test_name,
  54. )
  55. prompt_files = [file for file in os.listdir(mock_dir) if file.startswith('prompt_')]
  56. return len(prompt_files)
  57. def validate_final_state(final_state: State | None, test_name: str):
  58. regen = os.getenv('FORCE_REGENERATE', False).lower() in ['true', '1', 'yes']
  59. assert final_state is not None
  60. assert final_state.agent_state == AgentState.STOPPED
  61. if not regen:
  62. assert final_state.last_error is None
  63. # number of LLM conversations should be the same as number of prompt/response
  64. # log files under mock/[runtime]/[agent]/[test_name] folder. If not, it means there are
  65. # redundant prompt/response log files checked into the repository.
  66. num_of_conversations = get_number_of_prompts(test_name)
  67. assert num_of_conversations > 0
  68. # we mock the cost of every conversation to be 1 USD
  69. assert int(final_state.metrics.accumulated_cost) == num_of_conversations
  70. if final_state.history.has_delegation():
  71. assert final_state.iteration > final_state.local_iteration
  72. else:
  73. assert final_state.local_iteration == final_state.iteration
  74. assert final_state.iteration > 0
  75. @pytest.mark.skipif(
  76. os.getenv('DEFAULT_AGENT') == 'BrowsingAgent',
  77. reason='BrowsingAgent is a specialized agent',
  78. )
  79. @pytest.mark.skipif(
  80. (
  81. os.getenv('DEFAULT_AGENT') == 'CodeActAgent'
  82. or os.getenv('DEFAULT_AGENT') == 'CodeActSWEAgent'
  83. ),
  84. reason='CodeActAgent/CodeActSWEAgent only supports ssh sandbox which is stateful',
  85. )
  86. @pytest.mark.skipif(
  87. os.getenv('DEFAULT_AGENT') == 'ManagerAgent',
  88. reason='Manager agent is not capable of finishing this in reasonable steps yet',
  89. )
  90. def test_write_simple_script(current_test_name: str) -> None:
  91. task = "Write a shell script 'hello.sh' that prints 'hello'. Do not ask me for confirmation at any point."
  92. final_state: State | None = asyncio.run(
  93. run_controller(CONFIG, task, exit_on_message=True)
  94. )
  95. validate_final_state(final_state, current_test_name)
  96. # Verify the script file exists
  97. assert CONFIG.workspace_base is not None
  98. script_path = os.path.join(CONFIG.workspace_base, 'hello.sh')
  99. assert os.path.exists(script_path), 'The file "hello.sh" does not exist'
  100. # Run the script and capture the output
  101. result = subprocess.run(['bash', script_path], capture_output=True, text=True)
  102. # Verify the output from the script
  103. assert (
  104. result.stdout.strip() == 'hello'
  105. ), f'Expected output "hello", but got "{result.stdout.strip()}"'
  106. @pytest.mark.skipif(
  107. os.getenv('DEFAULT_AGENT') == 'BrowsingAgent',
  108. reason='BrowsingAgent is a specialized agent',
  109. )
  110. @pytest.mark.skipif(
  111. (
  112. os.getenv('DEFAULT_AGENT') == 'CodeActAgent'
  113. or os.getenv('DEFAULT_AGENT') == 'CodeActSWEAgent'
  114. ),
  115. reason='CodeActAgent/CodeActSWEAgent only supports ssh sandbox which is stateful',
  116. )
  117. @pytest.mark.skipif(
  118. os.getenv('DEFAULT_AGENT') == 'PlannerAgent',
  119. reason='We only keep basic tests for PlannerAgent',
  120. )
  121. def test_edits(current_test_name: str):
  122. # Copy workspace artifacts to workspace_base location
  123. source_dir = os.path.join(os.path.dirname(__file__), 'workspace/test_edits/')
  124. files = os.listdir(source_dir)
  125. for file in files:
  126. dest_file = os.path.join(CONFIG.workspace_base, file)
  127. if os.path.exists(dest_file):
  128. os.remove(dest_file)
  129. shutil.copy(os.path.join(source_dir, file), dest_file)
  130. # Execute the task
  131. task = 'Fix typos in bad.txt. Do not ask me for confirmation at any point.'
  132. final_state: State | None = asyncio.run(
  133. run_controller(CONFIG, task, exit_on_message=True)
  134. )
  135. validate_final_state(final_state, current_test_name)
  136. # Verify bad.txt has been fixed
  137. text = """This is a stupid typo.
  138. Really?
  139. No more typos!
  140. Enjoy!
  141. """
  142. with open(os.path.join(CONFIG.workspace_base, 'bad.txt'), 'r') as f:
  143. content = f.read()
  144. assert content.strip() == text.strip()
  145. @pytest.mark.skipif(
  146. os.getenv('DEFAULT_AGENT') != 'CodeActAgent'
  147. and os.getenv('DEFAULT_AGENT') != 'CodeActSWEAgent',
  148. reason='currently only CodeActAgent and CodeActSWEAgent have IPython (Jupyter) execution by default',
  149. )
  150. def test_ipython(current_test_name: str):
  151. # Execute the task
  152. task = "Use Jupyter IPython to write a text file containing 'hello world' to '/workspace/test.txt'. Do not ask me for confirmation at any point."
  153. final_state: State | None = asyncio.run(
  154. run_controller(CONFIG, task, exit_on_message=True)
  155. )
  156. validate_final_state(final_state, current_test_name)
  157. # Verify the file exists
  158. file_path = os.path.join(CONFIG.workspace_base, 'test.txt')
  159. assert os.path.exists(file_path), 'The file "test.txt" does not exist'
  160. # Verify the file contains the expected content
  161. with open(file_path, 'r') as f:
  162. content = f.read()
  163. assert (
  164. content.strip() == 'hello world'
  165. ), f'Expected content "hello world", but got "{content.strip()}"'
  166. @pytest.mark.skipif(
  167. os.getenv('DEFAULT_AGENT') != 'ManagerAgent',
  168. reason='Currently, only ManagerAgent supports task rejection',
  169. )
  170. def test_simple_task_rejection(current_test_name: str):
  171. # Give an impossible task to do: cannot write a commit message because
  172. # the workspace is not a git repo
  173. task = 'Write a git commit message for the current staging area. Do not ask me for confirmation at any point.'
  174. final_state: State | None = asyncio.run(
  175. run_controller(CONFIG, task, exit_on_message=True)
  176. )
  177. validate_final_state(final_state, current_test_name)
  178. assert isinstance(final_state.history.get_last_action(), AgentRejectAction)
  179. @pytest.mark.skipif(
  180. os.getenv('DEFAULT_AGENT') != 'CodeActAgent'
  181. and os.getenv('DEFAULT_AGENT') != 'CodeActSWEAgent',
  182. reason='currently only CodeActAgent and CodeActSWEAgent have IPython (Jupyter) execution by default',
  183. )
  184. def test_ipython_module(current_test_name: str):
  185. # Execute the task
  186. task = "Install and import pymsgbox==1.0.9 and print it's version in /workspace/test.txt. Do not ask me for confirmation at any point."
  187. final_state: State | None = asyncio.run(
  188. run_controller(CONFIG, task, exit_on_message=True)
  189. )
  190. validate_final_state(final_state, current_test_name)
  191. # Verify the file exists
  192. file_path = os.path.join(CONFIG.workspace_base, 'test.txt')
  193. assert os.path.exists(file_path), 'The file "test.txt" does not exist'
  194. # Verify the file contains the expected content
  195. with open(file_path, 'r') as f:
  196. content = f.read()
  197. print(content)
  198. assert (
  199. content.strip().split(' ')[-1] == '1.0.9'
  200. ), f'Expected content "1.0.9", but got "{content.strip()}"'
  201. @pytest.mark.skipif(
  202. os.getenv('DEFAULT_AGENT') != 'BrowsingAgent'
  203. and os.getenv('DEFAULT_AGENT') != 'CodeActAgent',
  204. reason='currently only BrowsingAgent and CodeActAgent are capable of searching the internet',
  205. )
  206. def test_browse_internet(current_test_name: str):
  207. # Execute the task
  208. task = 'Browse localhost:8000, and tell me the ultimate answer to life. Do not ask me for confirmation at any point.'
  209. final_state: State | None = asyncio.run(
  210. run_controller(CONFIG, task, exit_on_message=True)
  211. )
  212. validate_final_state(final_state, current_test_name)
  213. # last action
  214. last_action = final_state.history.get_last_action()
  215. assert isinstance(last_action, AgentFinishAction)
  216. # last observation
  217. last_observation = final_state.history.get_last_observation()
  218. assert isinstance(
  219. last_observation, (BrowserOutputObservation, AgentDelegateObservation)
  220. )
  221. if isinstance(last_observation, BrowserOutputObservation):
  222. assert 'OpenHands is all you need!' in last_observation.content
  223. elif isinstance(last_observation, AgentDelegateObservation):
  224. assert 'OpenHands is all you need!' in last_observation.outputs['content']