|
|
@@ -1,4 +1,5 @@
|
|
|
import logging
|
|
|
+import random
|
|
|
from unittest.mock import Mock, patch
|
|
|
|
|
|
import pytest
|
|
|
@@ -27,6 +28,9 @@ def collect_events(stream):
|
|
|
|
|
|
logging.basicConfig(level=logging.DEBUG)
|
|
|
|
|
|
+jupyter_line_1 = '\n[Jupyter current working directory:'
|
|
|
+jupyter_line_2 = '\n[Jupyter Python interpreter:'
|
|
|
+
|
|
|
|
|
|
@pytest.fixture
|
|
|
def temp_dir(tmp_path_factory: TempPathFactory) -> str:
|
|
|
@@ -46,11 +50,44 @@ class TestStuckDetector:
|
|
|
@pytest.fixture
|
|
|
def stuck_detector(self, event_stream):
|
|
|
state = State(inputs={}, max_iterations=50)
|
|
|
- # state.history = ShortTermHistory()
|
|
|
state.history.set_event_stream(event_stream)
|
|
|
|
|
|
return StuckDetector(state)
|
|
|
|
|
|
+ def _impl_syntax_error_events(
|
|
|
+ self, event_stream: EventStream, error_message: str, random_line: bool
|
|
|
+ ):
|
|
|
+ for _ in range(4):
|
|
|
+ ipython_action = IPythonRunCellAction(code='print("hello')
|
|
|
+ event_stream.add_event(ipython_action, EventSource.AGENT)
|
|
|
+ extra_number = random.randint(88, 222) if random_line else '42'
|
|
|
+ extra_line = '\n' * random.randint(2, 3) if random_line else ''
|
|
|
+ ipython_observation = IPythonRunCellObservation(
|
|
|
+ content=f' Cell In[1], line {extra_number}\n'
|
|
|
+ 'to_replace="""def largest(min_factor, max_factor):\n ^\n'
|
|
|
+ f'{error_message}{extra_line}' + jupyter_line_1 + jupyter_line_2,
|
|
|
+ code='print("hello',
|
|
|
+ )
|
|
|
+ print(ipython_observation.content)
|
|
|
+ ipython_observation._cause = ipython_action._id
|
|
|
+ event_stream.add_event(ipython_observation, EventSource.USER)
|
|
|
+
|
|
|
+ def _impl_unterminated_string_error_events(
|
|
|
+ self, event_stream: EventStream, random_line: bool
|
|
|
+ ):
|
|
|
+ for _ in range(4):
|
|
|
+ ipython_action = IPythonRunCellAction(code='print("hello')
|
|
|
+ event_stream.add_event(ipython_action, EventSource.AGENT)
|
|
|
+ line_number = str(random.randint(1, 10)) if random_line else '1'
|
|
|
+ ipython_observation = IPythonRunCellObservation(
|
|
|
+ content=f'print("hello\n ^\nSyntaxError: unterminated string literal (detected at line {line_number})'
|
|
|
+ + jupyter_line_1
|
|
|
+ + jupyter_line_2,
|
|
|
+ code='print("hello',
|
|
|
+ )
|
|
|
+ ipython_observation._cause = ipython_action._id
|
|
|
+ event_stream.add_event(ipython_observation, EventSource.USER)
|
|
|
+
|
|
|
def test_history_too_short(
|
|
|
self, stuck_detector: StuckDetector, event_stream: EventStream
|
|
|
):
|
|
|
@@ -202,81 +239,75 @@ class TestStuckDetector:
|
|
|
'Action, ErrorObservation loop detected'
|
|
|
)
|
|
|
|
|
|
- def test_is_stuck_ipython_syntax_error(
|
|
|
+ def test_is_stuck_invalid_syntax_error(
|
|
|
self, stuck_detector: StuckDetector, event_stream: EventStream
|
|
|
):
|
|
|
- ipython_action_1 = IPythonRunCellAction(code='print("hello')
|
|
|
- event_stream.add_event(ipython_action_1, EventSource.AGENT)
|
|
|
- ipython_observation_1 = IPythonRunCellObservation(
|
|
|
- content='print("hello\n ^\nSyntaxError: unterminated string literal (detected at line 1)',
|
|
|
- code='print("hello',
|
|
|
+ self._impl_syntax_error_events(
|
|
|
+ event_stream,
|
|
|
+ error_message='SyntaxError: invalid syntax. Perhaps you forgot a comma?',
|
|
|
+ random_line=False,
|
|
|
)
|
|
|
- ipython_observation_1._cause = ipython_action_1._id
|
|
|
- event_stream.add_event(ipython_observation_1, EventSource.USER)
|
|
|
|
|
|
- ipython_action_2 = IPythonRunCellAction(code='print("hello')
|
|
|
- event_stream.add_event(ipython_action_2, EventSource.AGENT)
|
|
|
- ipython_observation_2 = IPythonRunCellObservation(
|
|
|
- content='print("hello\n ^\nSyntaxError: unterminated string literal (detected at line 1)',
|
|
|
- code='print("hello',
|
|
|
+ with patch('logging.Logger.warning'):
|
|
|
+ assert stuck_detector.is_stuck() is True
|
|
|
+
|
|
|
+ def test_is_not_stuck_invalid_syntax_error(
|
|
|
+ self, stuck_detector: StuckDetector, event_stream: EventStream
|
|
|
+ ):
|
|
|
+ self._impl_syntax_error_events(
|
|
|
+ event_stream,
|
|
|
+ error_message='SyntaxError: invalid syntax. Perhaps you forgot a comma?',
|
|
|
+ random_line=True,
|
|
|
)
|
|
|
- ipython_observation_2._cause = ipython_action_2._id
|
|
|
- event_stream.add_event(ipython_observation_2, EventSource.USER)
|
|
|
|
|
|
- ipython_action_3 = IPythonRunCellAction(code='print("hello')
|
|
|
- event_stream.add_event(ipython_action_3, EventSource.AGENT)
|
|
|
- ipython_observation_3 = IPythonRunCellObservation(
|
|
|
- content='print("hello\n ^\nSyntaxError: unterminated string literal (detected at line 3)',
|
|
|
- code='print("hello',
|
|
|
+ with patch('logging.Logger.warning'):
|
|
|
+ assert stuck_detector.is_stuck() is False
|
|
|
+
|
|
|
+ def test_is_stuck_incomplete_input_error(
|
|
|
+ self, stuck_detector: StuckDetector, event_stream: EventStream
|
|
|
+ ):
|
|
|
+ self._impl_syntax_error_events(
|
|
|
+ event_stream,
|
|
|
+ error_message='SyntaxError: incomplete input',
|
|
|
+ random_line=False,
|
|
|
)
|
|
|
- ipython_observation_3._cause = ipython_action_3._id
|
|
|
- event_stream.add_event(ipython_observation_3, EventSource.USER)
|
|
|
|
|
|
- ipython_action_4 = IPythonRunCellAction(code='print("hello')
|
|
|
- event_stream.add_event(ipython_action_4, EventSource.AGENT)
|
|
|
- ipython_observation_4 = IPythonRunCellObservation(
|
|
|
- content='print("hello\n ^\nSyntaxError: unterminated string literal (detected at line 2)',
|
|
|
- code='print("hello',
|
|
|
+ with patch('logging.Logger.warning'):
|
|
|
+ assert stuck_detector.is_stuck() is True
|
|
|
+
|
|
|
+ def test_is_not_stuck_incomplete_input_error(
|
|
|
+ self, stuck_detector: StuckDetector, event_stream: EventStream
|
|
|
+ ):
|
|
|
+ self._impl_syntax_error_events(
|
|
|
+ event_stream,
|
|
|
+ error_message='SyntaxError: incomplete input',
|
|
|
+ random_line=True,
|
|
|
)
|
|
|
- ipython_observation_4._cause = ipython_action_4._id
|
|
|
- event_stream.add_event(ipython_observation_4, EventSource.USER)
|
|
|
|
|
|
- # stuck_detector.state.history.set_event_stream(event_stream)
|
|
|
+ with patch('logging.Logger.warning'):
|
|
|
+ assert stuck_detector.is_stuck() is False
|
|
|
|
|
|
- last_observations = [
|
|
|
- ipython_observation_1,
|
|
|
- ipython_observation_2,
|
|
|
- ipython_observation_3,
|
|
|
- ipython_observation_4,
|
|
|
- ]
|
|
|
- for observation in last_observations:
|
|
|
- has_string = (
|
|
|
- observation.content[-100:].find(
|
|
|
- 'SyntaxError: unterminated string literal (detected at line'
|
|
|
- )
|
|
|
- != -1
|
|
|
- )
|
|
|
- assert has_string
|
|
|
-
|
|
|
- string_is_last = (
|
|
|
- len(
|
|
|
- observation.content.split(
|
|
|
- 'SyntaxError: unterminated string literal (detected at line'
|
|
|
- )[-1]
|
|
|
- )
|
|
|
- < 10
|
|
|
- )
|
|
|
- assert string_is_last
|
|
|
+ def test_is_not_stuck_ipython_unterminated_string_error(
|
|
|
+ self, stuck_detector: StuckDetector, event_stream: EventStream
|
|
|
+ ):
|
|
|
+ self._impl_unterminated_string_error_events(event_stream, random_line=True)
|
|
|
|
|
|
- with patch('logging.Logger.warning') as mock_warning:
|
|
|
+ with patch('logging.Logger.warning'):
|
|
|
+ assert stuck_detector.is_stuck() is False
|
|
|
+
|
|
|
+ def test_is_stuck_ipython_unterminated_string_error(
|
|
|
+ self, stuck_detector: StuckDetector, event_stream: EventStream
|
|
|
+ ):
|
|
|
+ self._impl_unterminated_string_error_events(event_stream, random_line=False)
|
|
|
+
|
|
|
+ with patch('logging.Logger.warning'):
|
|
|
assert stuck_detector.is_stuck() is True
|
|
|
- mock_warning.assert_called_once_with(
|
|
|
- 'Action, IPythonRunCellObservation loop detected'
|
|
|
- )
|
|
|
|
|
|
def test_is_stuck_ipython_syntax_error_not_at_end(
|
|
|
self, stuck_detector: StuckDetector, event_stream: EventStream
|
|
|
):
|
|
|
+ # This test makes sure we don't get a false positive,
|
|
|
+ # since the "at line x" number changes between observations.
|
|
|
ipython_action_1 = IPythonRunCellAction(code='print("hello')
|
|
|
event_stream.add_event(ipython_action_1, EventSource.AGENT)
|
|
|
ipython_observation_1 = IPythonRunCellObservation(
|
|
|
@@ -480,8 +511,6 @@ class TestStuckDetector:
|
|
|
event_stream.add_event(message_action_6, EventSource.AGENT)
|
|
|
message_action_6._source = EventSource.AGENT
|
|
|
|
|
|
- # stuck_detector.state.history.set_event_stream(event_stream)
|
|
|
-
|
|
|
assert stuck_detector.is_stuck()
|
|
|
|
|
|
# Add an observation event between the repeated message actions
|
|
|
@@ -502,7 +531,8 @@ class TestStuckDetector:
|
|
|
event_stream.add_event(message_action_8, EventSource.AGENT)
|
|
|
message_action_8._source = EventSource.AGENT
|
|
|
|
|
|
- assert not stuck_detector.is_stuck()
|
|
|
+ with patch('logging.Logger.warning'):
|
|
|
+ assert not stuck_detector.is_stuck()
|
|
|
|
|
|
|
|
|
class TestAgentController:
|