test_agent.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. import asyncio
  2. import os
  3. import shutil
  4. import subprocess
  5. import pytest
  6. from opendevin.controller.state.state import State
  7. from opendevin.core.config import AppConfig, SandboxConfig, load_from_env
  8. from opendevin.core.main import run_controller
  9. from opendevin.core.schema import AgentState
  10. from opendevin.events.action import (
  11. AgentFinishAction,
  12. AgentRejectAction,
  13. )
  14. from opendevin.events.observation.browse import BrowserOutputObservation
  15. from opendevin.events.observation.delegate import AgentDelegateObservation
  16. from opendevin.runtime import get_runtime_cls
  17. TEST_RUNTIME = os.getenv('TEST_RUNTIME')
  18. assert TEST_RUNTIME in ['eventstream', 'server']
  19. _ = get_runtime_cls(TEST_RUNTIME) # make sure it does not raise an error
  20. CONFIG = AppConfig(
  21. max_iterations=int(os.getenv('MAX_ITERATIONS', 15)),
  22. max_budget_per_task=int(os.getenv('MAX_BUDGET_PER_TASK', 15)),
  23. runtime=TEST_RUNTIME,
  24. default_agent=os.getenv('DEFAULT_AGENT'),
  25. workspace_base=os.getenv('WORKSPACE_BASE'),
  26. workspace_mount_path=os.getenv('WORKSPACE_MOUNT_PATH'),
  27. sandbox=SandboxConfig(
  28. use_host_network=True,
  29. ),
  30. )
  31. load_from_env(CONFIG, os.environ)
  32. print('\nPaths used:')
  33. print(f'workspace_base: {CONFIG.workspace_base}')
  34. print(f'workspace_mount_path: {CONFIG.workspace_mount_path}')
  35. print(f'workspace_mount_path_in_sandbox: {CONFIG.workspace_mount_path_in_sandbox}')
  36. print(f'CONFIG: {CONFIG}')
  37. def get_number_of_prompts(test_name: str):
  38. mock_dir = os.path.join(
  39. os.environ['SCRIPT_DIR'],
  40. 'mock',
  41. f'{TEST_RUNTIME}_runtime',
  42. os.environ['DEFAULT_AGENT'],
  43. test_name,
  44. )
  45. prompt_files = [file for file in os.listdir(mock_dir) if file.startswith('prompt_')]
  46. return len(prompt_files)
  47. def validate_final_state(final_state: State | None, test_name: str):
  48. assert final_state is not None
  49. assert final_state.agent_state == AgentState.STOPPED
  50. assert final_state.last_error is None
  51. # number of LLM conversations should be the same as number of prompt/response
  52. # log files under mock/[agent]/[test_name] folder. If not, it means there are
  53. # redundant prompt/response log files checked into the repository.
  54. num_of_conversations = get_number_of_prompts(test_name)
  55. assert num_of_conversations > 0
  56. # we mock the cost of every conversation to be 1 USD
  57. assert final_state.metrics.accumulated_cost == num_of_conversations
  58. if final_state.history.has_delegation():
  59. assert final_state.iteration > final_state.local_iteration
  60. else:
  61. assert final_state.local_iteration == final_state.iteration
  62. assert final_state.iteration > 0
  63. @pytest.mark.skipif(
  64. os.getenv('DEFAULT_AGENT') == 'BrowsingAgent',
  65. reason='BrowsingAgent is a specialized agent',
  66. )
  67. @pytest.mark.skipif(
  68. (
  69. os.getenv('DEFAULT_AGENT') == 'CodeActAgent'
  70. or os.getenv('DEFAULT_AGENT') == 'CodeActSWEAgent'
  71. ),
  72. reason='CodeActAgent/CodeActSWEAgent only supports ssh sandbox which is stateful',
  73. )
  74. @pytest.mark.skipif(
  75. os.getenv('DEFAULT_AGENT') == 'ManagerAgent',
  76. reason='Manager agent is not capable of finishing this in reasonable steps yet',
  77. )
  78. def test_write_simple_script(current_test_name: str) -> None:
  79. task = "Write a shell script 'hello.sh' that prints 'hello'. Do not ask me for confirmation at any point."
  80. final_state: State | None = asyncio.run(
  81. run_controller(CONFIG, task, exit_on_message=True)
  82. )
  83. validate_final_state(final_state, current_test_name)
  84. # Verify the script file exists
  85. assert CONFIG.workspace_base is not None
  86. script_path = os.path.join(CONFIG.workspace_base, 'hello.sh')
  87. assert os.path.exists(script_path), 'The file "hello.sh" does not exist'
  88. # Run the script and capture the output
  89. result = subprocess.run(['bash', script_path], capture_output=True, text=True)
  90. # Verify the output from the script
  91. assert (
  92. result.stdout.strip() == 'hello'
  93. ), f'Expected output "hello", but got "{result.stdout.strip()}"'
  94. @pytest.mark.skipif(
  95. os.getenv('DEFAULT_AGENT') == 'BrowsingAgent',
  96. reason='BrowsingAgent is a specialized agent',
  97. )
  98. @pytest.mark.skipif(
  99. (
  100. os.getenv('DEFAULT_AGENT') == 'CodeActAgent'
  101. or os.getenv('DEFAULT_AGENT') == 'CodeActSWEAgent'
  102. ),
  103. reason='CodeActAgent/CodeActSWEAgent only supports ssh sandbox which is stateful',
  104. )
  105. @pytest.mark.skipif(
  106. os.getenv('DEFAULT_AGENT') == 'PlannerAgent',
  107. reason='We only keep basic tests for PlannerAgent',
  108. )
  109. def test_edits(current_test_name: str):
  110. # Copy workspace artifacts to workspace_base location
  111. source_dir = os.path.join(os.path.dirname(__file__), 'workspace/test_edits/')
  112. files = os.listdir(source_dir)
  113. for file in files:
  114. dest_file = os.path.join(CONFIG.workspace_base, file)
  115. if os.path.exists(dest_file):
  116. os.remove(dest_file)
  117. shutil.copy(os.path.join(source_dir, file), dest_file)
  118. # Execute the task
  119. task = 'Fix typos in bad.txt. Do not ask me for confirmation at any point.'
  120. final_state: State | None = asyncio.run(
  121. run_controller(CONFIG, task, exit_on_message=True)
  122. )
  123. validate_final_state(final_state, current_test_name)
  124. # Verify bad.txt has been fixed
  125. text = """This is a stupid typo.
  126. Really?
  127. No more typos!
  128. Enjoy!
  129. """
  130. with open(os.path.join(CONFIG.workspace_base, 'bad.txt'), 'r') as f:
  131. content = f.read()
  132. assert content.strip() == text.strip()
  133. @pytest.mark.skipif(
  134. os.getenv('DEFAULT_AGENT') != 'CodeActAgent'
  135. and os.getenv('DEFAULT_AGENT') != 'CodeActSWEAgent',
  136. reason='currently only CodeActAgent and CodeActSWEAgent have IPython (Jupyter) execution by default',
  137. )
  138. def test_ipython(current_test_name: str):
  139. # Execute the task
  140. task = "Use Jupyter IPython to write a text file containing 'hello world' to '/workspace/test.txt'. Do not ask me for confirmation at any point."
  141. final_state: State | None = asyncio.run(
  142. run_controller(CONFIG, task, exit_on_message=True)
  143. )
  144. validate_final_state(final_state, current_test_name)
  145. # Verify the file exists
  146. file_path = os.path.join(CONFIG.workspace_base, 'test.txt')
  147. assert os.path.exists(file_path), 'The file "test.txt" does not exist'
  148. # Verify the file contains the expected content
  149. with open(file_path, 'r') as f:
  150. content = f.read()
  151. assert (
  152. content.strip() == 'hello world'
  153. ), f'Expected content "hello world", but got "{content.strip()}"'
  154. @pytest.mark.skipif(
  155. os.getenv('DEFAULT_AGENT') != 'ManagerAgent',
  156. reason='Currently, only ManagerAgent supports task rejection',
  157. )
  158. def test_simple_task_rejection(current_test_name: str):
  159. # Give an impossible task to do: cannot write a commit message because
  160. # the workspace is not a git repo
  161. task = 'Write a git commit message for the current staging area. Do not ask me for confirmation at any point.'
  162. final_state: State | None = asyncio.run(
  163. run_controller(CONFIG, task, exit_on_message=True)
  164. )
  165. validate_final_state(final_state, current_test_name)
  166. assert isinstance(final_state.history.get_last_action(), AgentRejectAction)
  167. @pytest.mark.skipif(
  168. os.getenv('DEFAULT_AGENT') != 'CodeActAgent'
  169. and os.getenv('DEFAULT_AGENT') != 'CodeActSWEAgent',
  170. reason='currently only CodeActAgent and CodeActSWEAgent have IPython (Jupyter) execution by default',
  171. )
  172. def test_ipython_module(current_test_name: str):
  173. # Execute the task
  174. task = "Install and import pymsgbox==1.0.9 and print it's version in /workspace/test.txt. Do not ask me for confirmation at any point."
  175. final_state: State | None = asyncio.run(
  176. run_controller(CONFIG, task, exit_on_message=True)
  177. )
  178. validate_final_state(final_state, current_test_name)
  179. # Verify the file exists
  180. file_path = os.path.join(CONFIG.workspace_base, 'test.txt')
  181. assert os.path.exists(file_path), 'The file "test.txt" does not exist'
  182. # Verify the file contains the expected content
  183. with open(file_path, 'r') as f:
  184. content = f.read()
  185. print(content)
  186. assert (
  187. content.strip().split(' ')[-1] == '1.0.9'
  188. ), f'Expected content "1.0.9", but got "{content.strip()}"'
  189. @pytest.mark.skipif(
  190. os.getenv('DEFAULT_AGENT') != 'BrowsingAgent'
  191. and os.getenv('DEFAULT_AGENT') != 'CodeActAgent',
  192. reason='currently only BrowsingAgent and CodeActAgent are capable of searching the internet',
  193. )
  194. @pytest.mark.skipif(
  195. (
  196. os.getenv('DEFAULT_AGENT') == 'CodeActAgent'
  197. or os.getenv('DEFAULT_AGENT') == 'CodeActSWEAgent'
  198. ),
  199. reason='CodeActAgent/CodeActSWEAgent only supports ssh sandbox which is stateful',
  200. )
  201. def test_browse_internet(http_server, current_test_name: str):
  202. # Execute the task
  203. task = 'Browse localhost:8000, and tell me the ultimate answer to life. Do not ask me for confirmation at any point.'
  204. final_state: State | None = asyncio.run(
  205. run_controller(CONFIG, task, exit_on_message=True)
  206. )
  207. validate_final_state(final_state, current_test_name)
  208. # last action
  209. last_action = final_state.history.get_last_action()
  210. assert isinstance(last_action, AgentFinishAction)
  211. # last observation
  212. last_observation = final_state.history.get_last_observation()
  213. assert isinstance(
  214. last_observation, (BrowserOutputObservation, AgentDelegateObservation)
  215. )
  216. if isinstance(last_observation, BrowserOutputObservation):
  217. assert 'OpenDevin is all you need!' in last_observation.content
  218. elif isinstance(last_observation, AgentDelegateObservation):
  219. assert 'OpenDevin is all you need!' in last_observation.outputs['content']