import io
import os
import re
import sys
from functools import partial

import pytest
from litellm import completion

from opendevin.llm.llm import message_separator

script_dir = os.path.dirname(os.path.realpath(__file__))
workspace_path = os.getenv('WORKSPACE_BASE')


def filter_out_symbols(input):
    return ' '.join([char for char in input if char.isalnum()])
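# Illustrative example of the normalization: only alphanumeric characters
# survive and they are joined by single spaces, e.g.
#   filter_out_symbols('Hello, world!\n') == 'H e l l o w o r l d'
# so prompt comparison is insensitive to punctuation, whitespace and
# platform-specific line endings.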


def get_log_id(prompt_log_name):
    match = re.search(r'prompt_(\d+).log', prompt_log_name)
    if match:
        return match.group(1)
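# For example (hypothetical path), get_log_id('mock/SomeAgent/test_edits/prompt_001.log')
# returns '001', which is later used to look up the matching response_001.log.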


def apply_prompt_and_get_mock_response(test_name: str, messages: str, id: int) -> str | None:
    """Apply the mock prompt, and find the mock response based on id.

    If there is no matching response file, return None.

    Note: this function blindly replaces any existing prompt file with the
    given input, without checking its contents.
    """
    mock_dir = os.path.join(script_dir, 'mock', os.environ.get('AGENT'), test_name)
    prompt_file_path = os.path.join(mock_dir, f'prompt_{"{0:03}".format(id)}.log')
    resp_file_path = os.path.join(mock_dir, f'response_{"{0:03}".format(id)}.log')
    try:
        # load the canned response first; if it is missing, the prompt is left untouched
        with open(resp_file_path, 'r') as resp_file:
            response = resp_file.read()
        # overwrite the stored prompt with the one actually produced by the agent
        with open(prompt_file_path, 'w') as prompt_file:
            prompt_file.write(messages)
        return response
    except FileNotFoundError:
        return None
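# Mock files are expected at <script_dir>/mock/<AGENT>/<test_name>/, with one
# prompt_NNN.log / response_NNN.log pair per LLM call (NNN is zero-padded and
# starts at 001). Re-running a test with FORCE_APPLY_PROMPTS=true (see
# mock_completion below) regenerates the prompt logs in place while still
# replaying the stored responses.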


def get_mock_response(test_name: str, messages: str):
    """Find the mock response based on the prompt.

    Prompts are stored in nested folders under the mock folder. If
    prompt_{id}.log matches, then the mock response we're looking for is at
    response_{id}.log.

    Note: we filter out all non-alphanumeric characters; otherwise we would
    see surprising mismatches caused by linters and by minor discrepancies
    between platforms.

    We could do a slightly more efficient string match with the same time
    complexity (early-out on the first character mismatch), but that is
    unnecessary for tests. Empirically, different prompts for the same task
    usually only differ near the end of the file, so the comparison would be
    faster if it started from the end of the file, but again, that is
    unnecessary and would only make the test code harder to understand.
    """
    mock_dir = os.path.join(script_dir, 'mock', os.environ.get('AGENT'), test_name)
    prompt = filter_out_symbols(messages)
    for root, _, files in os.walk(mock_dir):
        for file in files:
            if file.startswith('prompt_') and file.endswith('.log'):
                file_path = os.path.join(root, file)
                # Open the prompt file and compare its contents
                with open(file_path, 'r') as f:
                    file_content = filter_out_symbols(f.read())
                if file_content == prompt:
                    # If a match is found, construct the corresponding response file path
                    log_id = get_log_id(file_path)
                    resp_file_path = os.path.join(root, f'response_{log_id}.log')
                    # Read the response file and return its content
                    with open(resp_file_path, 'r') as resp_file:
                        return resp_file.read()
                else:
                    # Print some context around the first mismatching character
                    # to make it easier to see why the prompt did not match.
                    print('File path', file_path)
                    print('---' * 10)
                    print(messages)
                    print('---' * 10)
                    for i, (c1, c2) in enumerate(zip(file_content, prompt)):
                        if c1 != c2:
                            print(
                                f'Mismatch at index {i}: '
                                f'{file_content[max(0, i - 100):i + 100]} vs '
                                f'{prompt[max(0, i - 100):i + 100]}'
                            )
                            break
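# Because get_mock_response() walks the whole per-test mock directory, a test
# may contain several prompt/response pairs (one per agent step); the first
# prompt file whose normalized content matches wins.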


def mock_user_response(*args, test_name, **kwargs):
    """Mock the user input that the agent requests via `input()` during `asyncio.run(main(task))`.

    Reads the `user_responses.log` file in the test's mock directory and returns
    its contents so the caller can install them as STDIN for the agent to read.
    Returns an empty string if the file does not exist.
    """
    user_response_file = os.path.join(
        script_dir, 'mock', os.environ.get('AGENT'), test_name, 'user_responses.log'
    )
    if not os.path.exists(user_response_file):
        return ''
    with open(user_response_file, 'r') as f:
        ret = f.read().rstrip()
    ret += '\n'
    return ret
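# Each line of user_responses.log feeds one `input()` call once the fixture
# below installs the returned string as sys.stdin; the appended newline ensures
# the final response is properly terminated.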


def mock_completion(*args, test_name, **kwargs):
    global cur_id
    messages = kwargs['messages']
    message_str = ''
    for message in messages:
        message_str += message_separator + message['content']
    if os.environ.get('FORCE_APPLY_PROMPTS') == 'true':
        # this assumes all response_(*).log filenames are in numerical order,
        # starting from one
        cur_id += 1
        mock_response = apply_prompt_and_get_mock_response(
            test_name, message_str, cur_id
        )
    else:
        mock_response = get_mock_response(test_name, message_str)
    if mock_response is None:
        print('Mock response for prompt is not found:\n\n' + message_str)
        print('Exiting...')
        sys.exit(1)
    response = completion(**kwargs, mock_response=mock_response)
    return response
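# `mock_response` is a litellm testing hook: when set, completion() returns the
# given string as the model output instead of calling a real provider, so the
# agent runs fully offline against the recorded responses.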


@pytest.fixture(autouse=True)
def patch_completion(monkeypatch, request):
    test_name = request.node.name
    # Mock LLM completion
    monkeypatch.setattr(
        'opendevin.llm.llm.litellm_completion',
        partial(mock_completion, test_name=test_name),
    )
    # Mock user input (only for tests that have user_responses.log)
    user_responses_str = mock_user_response(test_name=test_name)
    if user_responses_str:
        user_responses = io.StringIO(user_responses_str)
        monkeypatch.setattr('sys.stdin', user_responses)
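# Since the fixture is autouse, every test collected next to this conftest.py
# runs against the patched LLM automatically; a test only has to drive the
# agent, e.g. (hypothetical sketch, not part of this file):
#
#   def test_write_simple_script():
#       asyncio.run(main('Write a shell script that prints hello'))
#       assert os.path.exists(os.path.join(workspace_path, 'hello.sh'))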


def set_up():
    global cur_id
    cur_id = 0
    assert workspace_path is not None
    if os.path.exists(workspace_path):
        for file in os.listdir(workspace_path):
            os.remove(os.path.join(workspace_path, file))
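# set_up() resets the per-test prompt counter and clears the workspace so tests
# do not see each other's artifacts; it assumes the workspace contains only
# regular files (os.remove would fail on a subdirectory).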


@pytest.fixture(autouse=True)
def resource_setup():
    set_up()
    if not os.path.exists(workspace_path):
        os.makedirs(workspace_path)
    # Yield to test execution
    yield