conftest.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
import io
import os
import re
import shutil
import subprocess
import tempfile
from functools import partial
from http.server import HTTPServer, SimpleHTTPRequestHandler
from threading import Thread
from typing import Optional

import pytest
from litellm import completion

from opendevin.llm.llm import message_separator
# Absolute directory of this conftest; the mock prompt/response logs are
# stored in subfolders beneath it (see apply_prompt_and_get_mock_response).
script_dir = os.path.dirname(os.path.realpath(__file__))
# Scratch workspace root for tests, taken from the environment.
# NOTE(review): may be None if WORKSPACE_BASE is unset; set_up() asserts on it.
workspace_path = os.getenv('WORKSPACE_BASE')
class SecretExit(Exception):
    """Raised to abort a test without exposing secret data.

    pytest_exception_interact below replaces the failure report for this
    exception so prompts/responses never appear in pytest output.
    """

    pass
  16. @pytest.hookimpl(tryfirst=True)
  17. def pytest_exception_interact(node, call, report):
  18. if isinstance(call.excinfo.value, SecretExit):
  19. report.outcome = 'failed'
  20. report.longrepr = (
  21. 'SecretExit: Exiting due to an error without revealing secrets.'
  22. )
  23. call.excinfo = None
  24. def filter_out_symbols(input):
  25. input = re.sub(r'\\n|\\r\\n|\\r|\s+', '', input)
  26. return input
  27. def get_log_id(prompt_log_name):
  28. match = re.search(r'prompt_(\d+).log', prompt_log_name)
  29. if match:
  30. return match.group(1)
  31. def apply_prompt_and_get_mock_response(test_name: str, messages: str, id: int) -> str:
  32. """Apply the mock prompt, and find mock response based on id.
  33. If there is no matching response file, return None.
  34. Note: this function blindly replaces existing prompt file with the given
  35. input without checking the contents.
  36. """
  37. mock_dir = os.path.join(
  38. script_dir, 'mock', os.environ.get('DEFAULT_AGENT'), test_name
  39. )
  40. prompt_file_path = os.path.join(mock_dir, f'prompt_{"{0:03}".format(id)}.log')
  41. resp_file_path = os.path.join(mock_dir, f'response_{"{0:03}".format(id)}.log')
  42. try:
  43. # load response
  44. with open(resp_file_path, 'r') as resp_file:
  45. response = resp_file.read()
  46. # apply prompt
  47. with open(prompt_file_path, 'w') as prompt_file:
  48. prompt_file.write(messages)
  49. return response
  50. except FileNotFoundError:
  51. return None
  52. def get_mock_response(test_name: str, messages: str, id: int) -> str:
  53. """Find mock response based on prompt. Prompts are stored under nested
  54. folders under mock folder. If prompt_{id}.log matches,
  55. then the mock response we're looking for is at response_{id}.log.
  56. Note: we filter out all non-alphanumerical characters, otherwise we would
  57. see surprising mismatches caused by linters and minor discrepancies between
  58. different platforms.
  59. We could have done a slightly more efficient string match with the same time
  60. complexity (early-out upon first character mismatch), but it is unnecessary
  61. for tests. Empirically, different prompts of the same task usually only
  62. differ near the end of file, so the comparison would be more efficient if
  63. we start from the end of the file, but again, that is unnecessary and only
  64. makes test code harder to understand.
  65. """
  66. prompt = filter_out_symbols(messages)
  67. mock_dir = os.path.join(
  68. script_dir, 'mock', os.environ.get('DEFAULT_AGENT'), test_name
  69. )
  70. prompt_file_path = os.path.join(mock_dir, f'prompt_{"{0:03}".format(id)}.log')
  71. resp_file_path = os.path.join(mock_dir, f'response_{"{0:03}".format(id)}.log')
  72. # Open the prompt file and compare its contents
  73. with open(prompt_file_path, 'r') as f:
  74. file_content = filter_out_symbols(f.read())
  75. if file_content == prompt:
  76. # Read the response file and return its content
  77. with open(resp_file_path, 'r') as resp_file:
  78. return resp_file.read()
  79. else:
  80. # print the mismatched lines
  81. print('Mismatched Prompt File path', prompt_file_path)
  82. print('---' * 10)
  83. # Create a temporary file to store messages
  84. with tempfile.NamedTemporaryFile(
  85. delete=False, mode='w', encoding='utf-8'
  86. ) as tmp_file:
  87. tmp_file_path = tmp_file.name
  88. tmp_file.write(messages)
  89. try:
  90. # Use diff command to compare files and capture the output
  91. result = subprocess.run(
  92. ['diff', '-u', prompt_file_path, tmp_file_path],
  93. capture_output=True,
  94. text=True,
  95. )
  96. if result.returncode != 0:
  97. print('Diff:')
  98. print(result.stdout)
  99. else:
  100. print('No differences found.')
  101. finally:
  102. # Clean up the temporary file
  103. os.remove(tmp_file_path)
  104. print('---' * 10)
  105. def mock_user_response(*args, test_name, **kwargs):
  106. """The agent will ask for user input using `input()` when calling `asyncio.run(main(task))`.
  107. This function mocks the user input by providing the response from the mock response file.
  108. It will read the `user_responses.log` file in the test directory and set as
  109. STDIN input for the agent to read.
  110. """
  111. user_response_file = os.path.join(
  112. script_dir,
  113. 'mock',
  114. os.environ.get('DEFAULT_AGENT'),
  115. test_name,
  116. 'user_responses.log',
  117. )
  118. if not os.path.exists(user_response_file):
  119. return ''
  120. with open(user_response_file, 'r') as f:
  121. ret = f.read().rstrip()
  122. ret += '\n'
  123. return ret
  124. def mock_completion(*args, test_name, **kwargs):
  125. global cur_id
  126. messages = kwargs['messages']
  127. message_str = ''
  128. for message in messages:
  129. message_str += message_separator + message['content']
  130. # this assumes all response_(*).log filenames are in numerical order, starting from one
  131. cur_id += 1
  132. if os.environ.get('FORCE_APPLY_PROMPTS') == 'true':
  133. mock_response = apply_prompt_and_get_mock_response(
  134. test_name, message_str, cur_id
  135. )
  136. else:
  137. mock_response = get_mock_response(test_name, message_str, cur_id)
  138. if mock_response is None:
  139. raise SecretExit('Mock response for prompt is not found')
  140. response = completion(**kwargs, mock_response=mock_response)
  141. return response
  142. @pytest.fixture(autouse=True)
  143. def patch_completion(monkeypatch, request):
  144. test_name = request.node.name
  145. # Mock LLM completion
  146. monkeypatch.setattr(
  147. 'opendevin.llm.llm.litellm_completion',
  148. partial(mock_completion, test_name=test_name),
  149. )
  150. # Mock user input (only for tests that have user_responses.log)
  151. user_responses_str = mock_user_response(test_name=test_name)
  152. if user_responses_str:
  153. user_responses = io.StringIO(user_responses_str)
  154. monkeypatch.setattr('sys.stdin', user_responses)
  155. @pytest.fixture
  156. def http_server():
  157. web_dir = os.path.join(os.path.dirname(__file__), 'static')
  158. os.chdir(web_dir)
  159. handler = SimpleHTTPRequestHandler
  160. # Start the server
  161. server = HTTPServer(('localhost', 8000), handler)
  162. thread = Thread(target=server.serve_forever)
  163. thread.setDaemon(True)
  164. thread.start()
  165. yield server
  166. # Stop the server
  167. server.shutdown()
  168. thread.join()
  169. def set_up():
  170. global cur_id
  171. cur_id = 0
  172. assert workspace_path is not None
  173. if os.path.exists(workspace_path):
  174. for file in os.listdir(workspace_path):
  175. os.remove(os.path.join(workspace_path, file))
  176. @pytest.fixture(autouse=True)
  177. def resource_setup():
  178. set_up()
  179. if not os.path.exists(workspace_path):
  180. os.makedirs(workspace_path)
  181. # Yield to test execution
  182. yield