# conftest.py (11 KB)
  1. import io
  2. import os
  3. import re
  4. import shutil
  5. import socket
  6. import subprocess
  7. import tempfile
  8. from functools import partial
  9. from http.server import HTTPServer, SimpleHTTPRequestHandler
  10. import pytest
  11. from litellm import completion
  12. from openhands.llm.llm import message_separator
  13. script_dir = os.environ.get('SCRIPT_DIR')
  14. project_root = os.environ.get('PROJECT_ROOT')
  15. workspace_path = os.environ.get('WORKSPACE_BASE')
  16. test_runtime = os.environ.get('TEST_RUNTIME')
  17. MOCK_ROOT_DIR = os.path.join(
  18. script_dir,
  19. 'mock',
  20. f'{test_runtime}_runtime',
  21. os.environ.get('DEFAULT_AGENT'),
  22. )
  23. assert script_dir is not None, 'SCRIPT_DIR environment variable is not set'
  24. assert project_root is not None, 'PROJECT_ROOT environment variable is not set'
  25. assert workspace_path is not None, 'WORKSPACE_BASE environment variable is not set'
  26. assert test_runtime is not None, 'TEST_RUNTIME environment variable is not set'
  27. class SecretExit(Exception):
  28. pass
  29. @pytest.hookimpl(tryfirst=True)
  30. def pytest_exception_interact(node, call, report):
  31. if isinstance(call.excinfo.value, SecretExit):
  32. report.outcome = 'failed'
  33. report.longrepr = (
  34. 'SecretExit: Exiting due to an error without revealing secrets.'
  35. )
  36. call.excinfo = None
  37. def filter_out_symbols(input):
  38. # remove shell hostname patterns (e.g., will change between each run)
  39. # openhands@379c7fce40b4:/workspace $
  40. input = re.sub(r'(openhands|root)@.*(:/.*)', r'\1[DUMMY_HOSTNAME]\2', input)
  41. # mask the specific part in a poetry path
  42. input = re.sub(
  43. r'(/open[a-z]{5}/poetry/open[a-z]{5}-)[a-zA-Z0-9-]+(-py3\.\d+/bin/python)',
  44. r'\1[DUMMY_STRING]\2',
  45. input,
  46. )
  47. # handle size param
  48. input = re.sub(r' size=\d+ ', ' size=[DUMMY_SIZE] ', input)
  49. # handle sha256 hashes
  50. # sha256=4ecf8be428f55981e2a188f510ba5f9022bed88f5fb404d7d949f44382201e3d
  51. input = re.sub(r'sha256=[a-z0-9]+', 'sha256=[DUMMY_HASH]', input)
  52. # remove newlines and whitespace
  53. input = re.sub(r'\\n|\\r\\n|\\r|\s+', '', input)
  54. # remove all non-alphanumeric characters
  55. input = re.sub(r'[^a-zA-Z0-9]', '', input)
  56. return input
  57. def get_log_id(prompt_log_name):
  58. match = re.search(r'prompt_(\d+).log', prompt_log_name)
  59. if match:
  60. return match.group(1)
  61. def _format_messages(messages):
  62. message_str = ''
  63. for message in messages:
  64. if isinstance(message, str):
  65. message_str += message_separator + message if message_str else message
  66. elif isinstance(message, dict):
  67. if isinstance(message['content'], list):
  68. for m in message['content']:
  69. if isinstance(m, str):
  70. message_str += message_separator + m if message_str else m
  71. elif isinstance(m, dict) and m['type'] == 'text':
  72. message_str += (
  73. message_separator + m['text'] if message_str else m['text']
  74. )
  75. elif isinstance(message['content'], str):
  76. message_str += (
  77. message_separator + message['content']
  78. if message_str
  79. else message['content']
  80. )
  81. return message_str
  82. def apply_prompt_and_get_mock_response(
  83. test_name: str, messages: str, id: int
  84. ) -> str | None:
  85. """Apply the mock prompt, and find mock response based on id.
  86. If there is no matching response file, return None.
  87. Note: this function blindly replaces existing prompt file with the given
  88. input without checking the contents.
  89. """
  90. mock_dir = os.path.join(MOCK_ROOT_DIR, test_name)
  91. prompt_file_path = os.path.join(mock_dir, f'prompt_{"{0:03}".format(id)}.log')
  92. resp_file_path = os.path.join(mock_dir, f'response_{"{0:03}".format(id)}.log')
  93. try:
  94. # load response
  95. with open(resp_file_path, 'r') as resp_file:
  96. response = resp_file.read()
  97. # apply prompt
  98. with open(prompt_file_path, 'w') as prompt_file:
  99. prompt_file.write(messages)
  100. prompt_file.write('\n')
  101. return response
  102. except FileNotFoundError:
  103. return None
  104. def get_mock_response(test_name: str, messages: str, id: int) -> str:
  105. """Find mock response based on prompt. Prompts are stored under nested
  106. folders under mock folder. If prompt_{id}.log matches,
  107. then the mock response we're looking for is at response_{id}.log.
  108. Note: we filter out all non-alphanumerical characters, otherwise we would
  109. see surprising mismatches caused by linters and minor discrepancies between
  110. different platforms.
  111. We could have done a slightly more efficient string match with the same time
  112. complexity (early-out upon first character mismatch), but it is unnecessary
  113. for tests. Empirically, different prompts of the same task usually only
  114. differ near the end of file, so the comparison would be more efficient if
  115. we start from the end of the file, but again, that is unnecessary and only
  116. makes test code harder to understand.
  117. """
  118. mock_dir = os.path.join(MOCK_ROOT_DIR, test_name)
  119. prompt = filter_out_symbols(messages)
  120. prompt_file_path = os.path.join(mock_dir, f'prompt_{"{0:03}".format(id)}.log')
  121. resp_file_path = os.path.join(mock_dir, f'response_{"{0:03}".format(id)}.log')
  122. # Open the prompt file and compare its contents
  123. with open(prompt_file_path, 'r') as f:
  124. file_content = filter_out_symbols(f.read())
  125. if file_content.strip() == prompt.strip():
  126. # Read the response file and return its content
  127. with open(resp_file_path, 'r') as resp_file:
  128. return resp_file.read()
  129. else:
  130. # print the mismatched lines
  131. print('Mismatched Prompt File path', prompt_file_path)
  132. print('---' * 10)
  133. # Create a temporary file to store messages
  134. with tempfile.NamedTemporaryFile(
  135. delete=False, mode='w', encoding='utf-8'
  136. ) as tmp_file:
  137. tmp_file_path = tmp_file.name
  138. tmp_file.write(messages)
  139. try:
  140. # Use diff command to compare files and capture the output
  141. result = subprocess.run(
  142. ['diff', '-u', prompt_file_path, tmp_file_path],
  143. capture_output=True,
  144. text=True,
  145. )
  146. if result.returncode != 0:
  147. print('Diff:')
  148. print(result.stdout)
  149. else:
  150. print('No differences found.')
  151. finally:
  152. # Clean up the temporary file
  153. os.remove(tmp_file_path)
  154. print('---' * 10)
  155. def mock_user_response(*args, test_name, **kwargs):
  156. """The agent will ask for user input using `input()` when calling `asyncio.run(main(task))`.
  157. This function mocks the user input by providing the response from the mock response file.
  158. It will read the `user_responses.log` file in the test directory and set as
  159. STDIN input for the agent to read.
  160. """
  161. user_response_file = os.path.join(
  162. script_dir,
  163. 'mock',
  164. os.environ.get('DEFAULT_AGENT'),
  165. test_name,
  166. 'user_responses.log',
  167. )
  168. if not os.path.exists(user_response_file):
  169. return ''
  170. with open(user_response_file, 'r') as f:
  171. ret = f.read().rstrip()
  172. ret += '\n'
  173. return ret
  174. def mock_completion(*args, test_name, **kwargs):
  175. global cur_id
  176. messages = kwargs['messages']
  177. message_str = _format_messages(messages) # text only
  178. # this assumes all response_(*).log filenames are in numerical order, starting from one
  179. cur_id += 1
  180. if os.environ.get('FORCE_APPLY_PROMPTS') == 'true':
  181. mock_response = apply_prompt_and_get_mock_response(
  182. test_name, message_str, cur_id
  183. )
  184. else:
  185. mock_response = get_mock_response(test_name, message_str, cur_id)
  186. if mock_response is None:
  187. raise SecretExit('\n\n***** Mock response for prompt is not found *****\n')
  188. response = completion(**kwargs, mock_response=mock_response)
  189. return response
  190. @pytest.fixture
  191. def current_test_name(request):
  192. return request.node.name
  193. @pytest.fixture(autouse=True)
  194. def patch_completion(monkeypatch, request):
  195. test_name = request.node.name
  196. # Mock LLM completion
  197. monkeypatch.setattr(
  198. 'openhands.llm.llm.litellm_completion',
  199. partial(mock_completion, test_name=test_name),
  200. )
  201. # Mock LLM completion cost (1 USD per conversation)
  202. monkeypatch.setattr(
  203. 'openhands.llm.llm.litellm_completion_cost',
  204. lambda completion_response, **extra_kwargs: 1,
  205. )
  206. # Mock LLMConfig to disable vision support
  207. monkeypatch.setattr(
  208. 'openhands.llm.llm.LLM.vision_is_active',
  209. lambda self: False,
  210. )
  211. # Mock user input (only for tests that have user_responses.log)
  212. user_responses_str = mock_user_response(test_name=test_name)
  213. if user_responses_str:
  214. user_responses = io.StringIO(user_responses_str)
  215. monkeypatch.setattr('sys.stdin', user_responses)
  216. class MultiAddressServer(HTTPServer):
  217. def server_bind(self):
  218. self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  219. self.socket.bind(self.server_address)
  220. class LoggingHTTPRequestHandler(SimpleHTTPRequestHandler):
  221. def log_message(self, format, *args):
  222. print(
  223. f'Request received: {self.address_string()} - {self.log_date_time_string()} - {format % args}'
  224. )
  225. def set_up():
  226. global cur_id
  227. cur_id = 0
  228. assert workspace_path is not None, 'workspace_path is not set'
  229. # Remove and recreate the workspace_path
  230. if os.path.exists(workspace_path):
  231. shutil.rmtree(workspace_path)
  232. os.makedirs(workspace_path)
  233. @pytest.fixture(autouse=True)
  234. def resource_setup():
  235. try:
  236. original_cwd = os.getcwd()
  237. except FileNotFoundError:
  238. print(
  239. '[DEBUG] Original working directory does not exist. Using /tmp as fallback.'
  240. )
  241. original_cwd = '/tmp'
  242. os.chdir('/tmp')
  243. try:
  244. set_up()
  245. yield
  246. finally:
  247. try:
  248. print(f'[DEBUG] Final working directory: {os.getcwd()}')
  249. except FileNotFoundError:
  250. print('[DEBUG] Final working directory does not exist')
  251. if os.path.exists(workspace_path):
  252. shutil.rmtree(workspace_path)
  253. os.makedirs(workspace_path, exist_ok=True)
  254. # Try to change back to the original directory
  255. try:
  256. os.chdir(original_cwd)
  257. print(f'[DEBUG] Changed back to original directory: {original_cwd}')
  258. except Exception:
  259. os.chdir('/tmp')