conftest.py 2.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. import re
  2. import os
  3. from functools import partial
  4. import pytest
  5. from litellm import completion
  6. script_dir = os.path.dirname(os.path.realpath(__file__))
  7. def filter_out_symbols(input):
  8. return ' '.join([char for char in input if char.isalpha()])
  9. def get_log_id(prompt_log_name):
  10. match = re.search(r'prompt_(\d+).log', prompt_log_name)
  11. if match:
  12. return match.group(1)
  13. def get_mock_response(test_name, messages):
  14. """
  15. Find mock response based on prompt. Prompts are stored under nested
  16. folders under mock folder. If prompt_{id}.log matches,
  17. then the mock response we're looking for is at response_{id}.log.
  18. Note: we filter out all non alpha-numerical characters, otherwise we would
  19. see surprising mismatches caused by linters and minor discrepancies between
  20. different platforms.
  21. We could have done a slightly more efficient string match with the same time
  22. complexity (early-out upon first character mismatch), but it is unnecessary
  23. for tests. Empirically, different prompts of the same task usually only
  24. differ near the end of file, so the comparison would be more efficient if
  25. we start from the end of the file, but again, that is unnecessary and only
  26. makes test code harder to understand.
  27. """
  28. mock_dir = os.path.join(script_dir, 'mock', os.environ.get('AGENT'), test_name)
  29. prompt = filter_out_symbols(messages)
  30. for root, _, files in os.walk(mock_dir):
  31. for file in files:
  32. if file.startswith('prompt_') and file.endswith('.log'):
  33. file_path = os.path.join(root, file)
  34. # Open the prompt file and compare its contents
  35. with open(file_path, 'r') as f:
  36. file_content = filter_out_symbols(f.read())
  37. if file_content == prompt:
  38. # If a match is found, construct the corresponding response file path
  39. log_id = get_log_id(file_path)
  40. resp_file_path = os.path.join(root, f'response_{log_id}.log')
  41. # Read the response file and return its content
  42. with open(resp_file_path, 'r') as resp_file:
  43. return resp_file.read()
  44. def mock_completion(*args, test_name, **kwargs):
  45. messages = kwargs['messages']
  46. message_str = ''
  47. for message in messages:
  48. message_str += message['content']
  49. mock_response = get_mock_response(test_name, message_str)
  50. assert mock_response is not None, 'Mock response for prompt is not found'
  51. response = completion(**kwargs, mock_response=mock_response)
  52. return response
  53. @pytest.fixture(autouse=True)
  54. def patch_completion(monkeypatch, request):
  55. test_name = request.node.name
  56. monkeypatch.setattr('opendevin.llm.llm.litellm_completion', partial(mock_completion, test_name=test_name))