# test_agent.py (10 KB)

  1. import asyncio
  2. import os
  3. import shutil
  4. import subprocess
  5. import pytest
  6. from opendevin.controller.state.state import State
  7. from opendevin.core.config import AppConfig, SandboxConfig, load_from_env
  8. from opendevin.core.main import run_controller
  9. from opendevin.core.schema import AgentState
  10. from opendevin.events.action import (
  11. AgentFinishAction,
  12. AgentRejectAction,
  13. )
  14. from opendevin.events.observation.browse import BrowserOutputObservation
  15. from opendevin.events.observation.delegate import AgentDelegateObservation
  16. from opendevin.runtime import get_runtime_cls
  17. TEST_RUNTIME = os.getenv('TEST_RUNTIME')
  18. assert TEST_RUNTIME in ['eventstream', 'server']
  19. _ = get_runtime_cls(TEST_RUNTIME) # make sure it does not raise an error
  20. CONFIG = AppConfig(
  21. max_iterations=int(os.getenv('MAX_ITERATIONS', 15)),
  22. max_budget_per_task=int(os.getenv('MAX_BUDGET_PER_TASK', 15)),
  23. runtime=TEST_RUNTIME,
  24. default_agent=os.getenv('DEFAULT_AGENT'),
  25. workspace_base=os.getenv('WORKSPACE_BASE'),
  26. workspace_mount_path=os.getenv('WORKSPACE_MOUNT_PATH'),
  27. sandbox=SandboxConfig(
  28. box_type=os.getenv('SANDBOX_BOX_TYPE', 'ssh'),
  29. use_host_network=True,
  30. ),
  31. )
  32. load_from_env(CONFIG, os.environ)
  33. print('\nPaths used:')
  34. print(f'workspace_base: {CONFIG.workspace_base}')
  35. print(f'workspace_mount_path: {CONFIG.workspace_mount_path}')
  36. print(f'workspace_mount_path_in_sandbox: {CONFIG.workspace_mount_path_in_sandbox}')
  37. print(f'CONFIG: {CONFIG}')
  38. def get_number_of_prompts(test_name: str):
  39. mock_dir = os.path.join(
  40. os.environ['SCRIPT_DIR'],
  41. 'mock',
  42. f'{TEST_RUNTIME}_runtime',
  43. os.environ['DEFAULT_AGENT'],
  44. test_name,
  45. )
  46. prompt_files = [file for file in os.listdir(mock_dir) if file.startswith('prompt_')]
  47. return len(prompt_files)
  48. def validate_final_state(final_state: State | None, test_name: str):
  49. assert final_state is not None
  50. assert final_state.agent_state == AgentState.STOPPED
  51. assert final_state.last_error is None
  52. # number of LLM conversations should be the same as number of prompt/response
  53. # log files under mock/[agent]/[test_name] folder. If not, it means there are
  54. # redundant prompt/response log files checked into the repository.
  55. num_of_conversations = get_number_of_prompts(test_name)
  56. assert num_of_conversations > 0
  57. # we mock the cost of every conversation to be 1 USD
  58. assert final_state.metrics.accumulated_cost == num_of_conversations
  59. if final_state.history.has_delegation():
  60. assert final_state.iteration > final_state.local_iteration
  61. else:
  62. assert final_state.local_iteration == final_state.iteration
  63. assert final_state.iteration > 0
  64. @pytest.mark.skipif(
  65. os.getenv('DEFAULT_AGENT') == 'BrowsingAgent',
  66. reason='BrowsingAgent is a specialized agent',
  67. )
  68. @pytest.mark.skipif(
  69. (
  70. os.getenv('DEFAULT_AGENT') == 'CodeActAgent'
  71. or os.getenv('DEFAULT_AGENT') == 'CodeActSWEAgent'
  72. )
  73. and os.getenv('SANDBOX_BOX_TYPE', '').lower() != 'ssh',
  74. reason='CodeActAgent/CodeActSWEAgent only supports ssh sandbox which is stateful',
  75. )
  76. @pytest.mark.skipif(
  77. os.getenv('DEFAULT_AGENT') == 'ManagerAgent',
  78. reason='Manager agent is not capable of finishing this in reasonable steps yet',
  79. )
  80. def test_write_simple_script(current_test_name: str) -> None:
  81. task = "Write a shell script 'hello.sh' that prints 'hello'. Do not ask me for confirmation at any point."
  82. final_state: State | None = asyncio.run(
  83. run_controller(CONFIG, task, exit_on_message=True)
  84. )
  85. validate_final_state(final_state, current_test_name)
  86. # Verify the script file exists
  87. assert CONFIG.workspace_base is not None
  88. script_path = os.path.join(CONFIG.workspace_base, 'hello.sh')
  89. assert os.path.exists(script_path), 'The file "hello.sh" does not exist'
  90. # Run the script and capture the output
  91. result = subprocess.run(['bash', script_path], capture_output=True, text=True)
  92. # Verify the output from the script
  93. assert (
  94. result.stdout.strip() == 'hello'
  95. ), f'Expected output "hello", but got "{result.stdout.strip()}"'
  96. @pytest.mark.skipif(
  97. os.getenv('DEFAULT_AGENT') == 'BrowsingAgent',
  98. reason='BrowsingAgent is a specialized agent',
  99. )
  100. @pytest.mark.skipif(
  101. (
  102. os.getenv('DEFAULT_AGENT') == 'CodeActAgent'
  103. or os.getenv('DEFAULT_AGENT') == 'CodeActSWEAgent'
  104. )
  105. and os.getenv('SANDBOX_BOX_TYPE', '').lower() != 'ssh',
  106. reason='CodeActAgent/CodeActSWEAgent only supports ssh sandbox which is stateful',
  107. )
  108. @pytest.mark.skipif(
  109. os.getenv('DEFAULT_AGENT') == 'PlannerAgent',
  110. reason='We only keep basic tests for PlannerAgent',
  111. )
  112. @pytest.mark.skipif(
  113. os.getenv('SANDBOX_BOX_TYPE') == 'local',
  114. reason='local sandbox shows environment-dependent absolute path for pwd command',
  115. )
  116. def test_edits(current_test_name: str):
  117. # Copy workspace artifacts to workspace_base location
  118. source_dir = os.path.join(os.path.dirname(__file__), 'workspace/test_edits/')
  119. files = os.listdir(source_dir)
  120. for file in files:
  121. dest_file = os.path.join(CONFIG.workspace_base, file)
  122. if os.path.exists(dest_file):
  123. os.remove(dest_file)
  124. shutil.copy(os.path.join(source_dir, file), dest_file)
  125. # Execute the task
  126. task = 'Fix typos in bad.txt. Do not ask me for confirmation at any point.'
  127. final_state: State | None = asyncio.run(
  128. run_controller(CONFIG, task, exit_on_message=True)
  129. )
  130. validate_final_state(final_state, current_test_name)
  131. # Verify bad.txt has been fixed
  132. text = """This is a stupid typo.
  133. Really?
  134. No more typos!
  135. Enjoy!
  136. """
  137. with open(os.path.join(CONFIG.workspace_base, 'bad.txt'), 'r') as f:
  138. content = f.read()
  139. assert content.strip() == text.strip()
  140. @pytest.mark.skipif(
  141. os.getenv('DEFAULT_AGENT') != 'CodeActAgent'
  142. and os.getenv('DEFAULT_AGENT') != 'CodeActSWEAgent',
  143. reason='currently only CodeActAgent and CodeActSWEAgent have IPython (Jupyter) execution by default',
  144. )
  145. @pytest.mark.skipif(
  146. os.getenv('SANDBOX_BOX_TYPE') != 'ssh',
  147. reason='Currently, only ssh sandbox supports stateful tasks',
  148. )
  149. def test_ipython(current_test_name: str):
  150. # Execute the task
  151. task = "Use Jupyter IPython to write a text file containing 'hello world' to '/workspace/test.txt'. Do not ask me for confirmation at any point."
  152. final_state: State | None = asyncio.run(
  153. run_controller(CONFIG, task, exit_on_message=True)
  154. )
  155. validate_final_state(final_state, current_test_name)
  156. # Verify the file exists
  157. file_path = os.path.join(CONFIG.workspace_base, 'test.txt')
  158. assert os.path.exists(file_path), 'The file "test.txt" does not exist'
  159. # Verify the file contains the expected content
  160. with open(file_path, 'r') as f:
  161. content = f.read()
  162. assert (
  163. content.strip() == 'hello world'
  164. ), f'Expected content "hello world", but got "{content.strip()}"'
  165. @pytest.mark.skipif(
  166. os.getenv('DEFAULT_AGENT') != 'ManagerAgent',
  167. reason='Currently, only ManagerAgent supports task rejection',
  168. )
  169. @pytest.mark.skipif(
  170. os.getenv('SANDBOX_BOX_TYPE') == 'local',
  171. reason='FIXME: local sandbox does not capture stderr',
  172. )
  173. def test_simple_task_rejection(current_test_name: str):
  174. # Give an impossible task to do: cannot write a commit message because
  175. # the workspace is not a git repo
  176. task = 'Write a git commit message for the current staging area. Do not ask me for confirmation at any point.'
  177. final_state: State | None = asyncio.run(
  178. run_controller(CONFIG, task, exit_on_message=True)
  179. )
  180. validate_final_state(final_state, current_test_name)
  181. assert isinstance(final_state.history.get_last_action(), AgentRejectAction)
  182. @pytest.mark.skipif(
  183. os.getenv('DEFAULT_AGENT') != 'CodeActAgent'
  184. and os.getenv('DEFAULT_AGENT') != 'CodeActSWEAgent',
  185. reason='currently only CodeActAgent and CodeActSWEAgent have IPython (Jupyter) execution by default',
  186. )
  187. @pytest.mark.skipif(
  188. os.getenv('SANDBOX_BOX_TYPE') != 'ssh',
  189. reason='Currently, only ssh sandbox supports stateful tasks',
  190. )
  191. def test_ipython_module(current_test_name: str):
  192. # Execute the task
  193. task = "Install and import pymsgbox==1.0.9 and print it's version in /workspace/test.txt. Do not ask me for confirmation at any point."
  194. final_state: State | None = asyncio.run(
  195. run_controller(CONFIG, task, exit_on_message=True)
  196. )
  197. validate_final_state(final_state, current_test_name)
  198. # Verify the file exists
  199. file_path = os.path.join(CONFIG.workspace_base, 'test.txt')
  200. assert os.path.exists(file_path), 'The file "test.txt" does not exist'
  201. # Verify the file contains the expected content
  202. with open(file_path, 'r') as f:
  203. content = f.read()
  204. print(content)
  205. assert (
  206. content.strip().split(' ')[-1] == '1.0.9'
  207. ), f'Expected content "1.0.9", but got "{content.strip()}"'
  208. @pytest.mark.skipif(
  209. os.getenv('DEFAULT_AGENT') != 'BrowsingAgent'
  210. and os.getenv('DEFAULT_AGENT') != 'CodeActAgent',
  211. reason='currently only BrowsingAgent and CodeActAgent are capable of searching the internet',
  212. )
  213. @pytest.mark.skipif(
  214. (
  215. os.getenv('DEFAULT_AGENT') == 'CodeActAgent'
  216. or os.getenv('DEFAULT_AGENT') == 'CodeActSWEAgent'
  217. )
  218. and os.getenv('SANDBOX_BOX_TYPE', '').lower() != 'ssh',
  219. reason='CodeActAgent/CodeActSWEAgent only supports ssh sandbox which is stateful',
  220. )
  221. def test_browse_internet(http_server, current_test_name: str):
  222. # Execute the task
  223. task = 'Browse localhost:8000, and tell me the ultimate answer to life. Do not ask me for confirmation at any point.'
  224. final_state: State | None = asyncio.run(
  225. run_controller(CONFIG, task, exit_on_message=True)
  226. )
  227. validate_final_state(final_state, current_test_name)
  228. # last action
  229. last_action = final_state.history.get_last_action()
  230. assert isinstance(last_action, AgentFinishAction)
  231. # last observation
  232. last_observation = final_state.history.get_last_observation()
  233. assert isinstance(
  234. last_observation, (BrowserOutputObservation, AgentDelegateObservation)
  235. )
  236. if isinstance(last_observation, BrowserOutputObservation):
  237. assert 'OpenDevin is all you need!' in last_observation.content
  238. elif isinstance(last_observation, AgentDelegateObservation):
  239. assert 'OpenDevin is all you need!' in last_observation.outputs['content']