conftest.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. import io
  2. import os
  3. import re
  4. import sys
  5. from functools import partial
  6. import pytest
  7. from litellm import completion
  8. from opendevin.llm.llm import message_separator
  9. script_dir = os.path.dirname(os.path.realpath(__file__))
  10. workspace_path = os.getenv('WORKSPACE_BASE')
  11. def filter_out_symbols(input):
  12. return ' '.join([char for char in input if char.isalnum()])
  13. def get_log_id(prompt_log_name):
  14. match = re.search(r'prompt_(\d+).log', prompt_log_name)
  15. if match:
  16. return match.group(1)
  17. def apply_prompt_and_get_mock_response(test_name: str, messages: str, id: int) -> str:
  18. """
  19. Apply the mock prompt, and find mock response based on id.
  20. If there is no matching response file, return None.
  21. Note: this function blindly replaces existing prompt file with the given
  22. input without checking the contents.
  23. """
  24. mock_dir = os.path.join(script_dir, 'mock', os.environ.get('AGENT'), test_name)
  25. prompt_file_path = os.path.join(mock_dir, f'prompt_{"{0:03}".format(id)}.log')
  26. resp_file_path = os.path.join(mock_dir, f'response_{"{0:03}".format(id)}.log')
  27. try:
  28. # load response
  29. with open(resp_file_path, 'r') as resp_file:
  30. response = resp_file.read()
  31. # apply prompt
  32. with open(prompt_file_path, 'w') as prompt_file:
  33. prompt_file.write(messages)
  34. return response
  35. except FileNotFoundError:
  36. return None
  37. def get_mock_response(test_name: str, messages: str):
  38. """
  39. Find mock response based on prompt. Prompts are stored under nested
  40. folders under mock folder. If prompt_{id}.log matches,
  41. then the mock response we're looking for is at response_{id}.log.
  42. Note: we filter out all non alpha-numerical characters, otherwise we would
  43. see surprising mismatches caused by linters and minor discrepancies between
  44. different platforms.
  45. We could have done a slightly more efficient string match with the same time
  46. complexity (early-out upon first character mismatch), but it is unnecessary
  47. for tests. Empirically, different prompts of the same task usually only
  48. differ near the end of file, so the comparison would be more efficient if
  49. we start from the end of the file, but again, that is unnecessary and only
  50. makes test code harder to understand.
  51. """
  52. mock_dir = os.path.join(script_dir, 'mock', os.environ.get('AGENT'), test_name)
  53. prompt = filter_out_symbols(messages)
  54. for root, _, files in os.walk(mock_dir):
  55. for file in files:
  56. if file.startswith('prompt_') and file.endswith('.log'):
  57. file_path = os.path.join(root, file)
  58. # Open the prompt file and compare its contents
  59. with open(file_path, 'r') as f:
  60. file_content = filter_out_symbols(f.read())
  61. if file_content == prompt:
  62. # If a match is found, construct the corresponding response file path
  63. log_id = get_log_id(file_path)
  64. resp_file_path = os.path.join(root, f'response_{log_id}.log')
  65. # Read the response file and return its content
  66. with open(resp_file_path, 'r') as resp_file:
  67. return resp_file.read()
  68. else:
  69. # print the mismatched lines
  70. print('File path', file_path)
  71. print('---' * 10)
  72. print(messages)
  73. print('---' * 10)
  74. for i, (c1, c2) in enumerate(zip(file_content, prompt)):
  75. if c1 != c2:
  76. print(
  77. f'Mismatch at index {i}: {c1[max(0,i-100):i+100]} vs {c2[max(0,i-100):i+100]}'
  78. )
  79. break
  80. def mock_user_response(*args, test_name, **kwargs):
  81. """The agent will ask for user input using `input()` when calling `asyncio.run(main(task))`.
  82. This function mocks the user input by providing the response from the mock response file.
  83. It will read the `user_responses.log` file in the test directory and set as
  84. STDIN input for the agent to read.
  85. """
  86. user_response_file = os.path.join(
  87. script_dir, 'mock', os.environ.get('AGENT'), test_name, 'user_responses.log'
  88. )
  89. if not os.path.exists(user_response_file):
  90. return ''
  91. with open(user_response_file, 'r') as f:
  92. ret = f.read().rstrip()
  93. ret += '\n'
  94. return ret
  95. def mock_completion(*args, test_name, **kwargs):
  96. global cur_id
  97. messages = kwargs['messages']
  98. message_str = ''
  99. for message in messages:
  100. message_str += message_separator + message['content']
  101. if os.environ.get('FORCE_APPLY_PROMPTS') == 'true':
  102. # this assumes all response_(*).log filenames are in numerical order, starting from one
  103. cur_id += 1
  104. mock_response = apply_prompt_and_get_mock_response(
  105. test_name, message_str, cur_id
  106. )
  107. else:
  108. mock_response = get_mock_response(test_name, message_str)
  109. if mock_response is None:
  110. print('Mock response for prompt is not found:\n\n' + message_str)
  111. print('Exiting...')
  112. sys.exit(1)
  113. response = completion(**kwargs, mock_response=mock_response)
  114. return response
  115. @pytest.fixture(autouse=True)
  116. def patch_completion(monkeypatch, request):
  117. test_name = request.node.name
  118. # Mock LLM completion
  119. monkeypatch.setattr(
  120. 'opendevin.llm.llm.litellm_completion',
  121. partial(mock_completion, test_name=test_name),
  122. )
  123. # Mock user input (only for tests that have user_responses.log)
  124. user_responses_str = mock_user_response(test_name=test_name)
  125. if user_responses_str:
  126. user_responses = io.StringIO(user_responses_str)
  127. monkeypatch.setattr('sys.stdin', user_responses)
  128. def set_up():
  129. global cur_id
  130. cur_id = 0
  131. assert workspace_path is not None
  132. if os.path.exists(workspace_path):
  133. for file in os.listdir(workspace_path):
  134. os.remove(os.path.join(workspace_path, file))
  135. @pytest.fixture(autouse=True)
  136. def resource_setup():
  137. set_up()
  138. if not os.path.exists(workspace_path):
  139. os.makedirs(workspace_path)
  140. # Yield to test execution
  141. yield