| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625 |
- import logging
- from unittest.mock import Mock, patch
- import pytest
- from pytest import TempPathFactory
- from openhands.controller.agent_controller import AgentController
- from openhands.controller.state.state import State
- from openhands.controller.stuck import StuckDetector
- from openhands.events.action import CmdRunAction, FileReadAction, MessageAction
- from openhands.events.action.commands import IPythonRunCellAction
- from openhands.events.observation import (
- CmdOutputObservation,
- FileReadObservation,
- )
- from openhands.events.observation.commands import IPythonRunCellObservation
- from openhands.events.observation.empty import NullObservation
- from openhands.events.observation.error import ErrorObservation
- from openhands.events.stream import EventSource, EventStream
- from openhands.storage import get_file_store
def collect_events(stream):
    """Return the events exposed by ``stream`` as a new list.

    Args:
        stream: Any object whose ``get_events()`` returns an iterable.

    Returns:
        A list holding the events in iteration order.
    """
    # list() copies the iterable directly; an identity comprehension adds nothing.
    return list(stream.get_events())
# Configure the root logger at DEBUG level as soon as this module is imported.
logging.basicConfig(level=logging.DEBUG)

# Trailer fragments concatenated onto the observation content built by the
# IPython helper methods of TestStuckDetector.
jupyter_line_1 = '\n[Jupyter current working directory:'
jupyter_line_2 = '\n[Jupyter Python interpreter:'

# Cell source passed as the `code` argument of the IPython actions and
# observations built by the helper methods. The call and the nested
# triple-quoted string inside it are left unclosed.
# NOTE(review): the whitespace inside this literal may have been flattened by
# the renderer this file was copied from - confirm against the upstream file.
code_snippet = """
edit_file_by_replace(
'book_store.py',
to_replace=\"""def total(basket):
if not basket:
return 0
"""
@pytest.fixture
def temp_dir(tmp_path_factory: TempPathFactory) -> str:
    """Create a ``test_is_stuck`` temporary directory and return its path as ``str``."""
    scratch_dir = tmp_path_factory.mktemp('test_is_stuck')
    return str(scratch_dir)
@pytest.fixture
def event_stream(temp_dir):
    """Yield an ``EventStream`` backed by a local file store rooted at ``temp_dir``.

    ``clear()`` is called on the stream once the requesting test has finished.
    """
    stream = EventStream('asdf', get_file_store('local', temp_dir))
    yield stream
    # Teardown half of the fixture: runs after each test that used the stream.
    stream.clear()
- class TestStuckDetector:
@pytest.fixture
def stuck_detector(self):
    """Return a ``StuckDetector`` wrapping a new ``State`` whose history is empty."""
    fresh_state = State(inputs={}, max_iterations=50)
    # The tests append events straight onto this list.
    fresh_state.history = []
    return StuckDetector(fresh_state)
def _impl_syntax_error_events(
    self,
    state: State,
    error_message: str,
    random_line: bool,
    incidents: int = 4,
):
    """Append ``incidents`` IPython action/observation pairs reporting a syntax error.

    Args:
        state: State whose ``history`` list receives the events.
        error_message: Error text embedded in each observation's content.
        random_line: When true, the reported line number and the number of
            newlines after the error text differ per incident; otherwise the
            line is always ``42`` and no newlines are added.
        incidents: Number of action/observation pairs to append.
    """
    for attempt in range(1, incidents + 1):
        state.history.append(IPythonRunCellAction(code=code_snippet))

        if random_line:
            reported_line = attempt * 10
            trailing_newlines = '\n' * attempt
        else:
            reported_line = '42'
            trailing_newlines = ''

        location = f' Cell In[1], line {reported_line}\n'
        excerpt = 'to_replace="""def largest(min_factor, max_factor):\n ^\n'
        failure = f'{error_message}{trailing_newlines}'
        state.history.append(
            IPythonRunCellObservation(
                content=location + excerpt + failure + jupyter_line_1 + jupyter_line_2,
                code=code_snippet,
            )
        )
def _impl_unterminated_string_error_events(
    self, state: State, random_line: bool, incidents: int = 4
):
    """Append ``incidents`` IPython pairs reporting an unterminated string literal.

    Args:
        state: State whose ``history`` list receives the events.
        random_line: When true, the line number in the message differs per
            incident; otherwise it is always ``1``.
        incidents: Number of action/observation pairs to append.
    """
    for attempt in range(1, incidents + 1):
        state.history.append(IPythonRunCellAction(code=code_snippet))

        reported_line = attempt * 10 if random_line else '1'
        # The same line number appears in the cell header and in the error text.
        error_text = (
            f'print(" Cell In[1], line {reported_line}\nhello\n ^\n'
            f'SyntaxError: unterminated string literal (detected at line {reported_line})'
        )
        state.history.append(
            IPythonRunCellObservation(
                content=error_text + jupyter_line_1 + jupyter_line_2,
                code=code_snippet,
            )
        )
def test_history_too_short(self, stuck_detector: StuckDetector):
    """A history of one user message and one ``ls`` round is not reported as stuck."""
    user_greeting = MessageAction(content='Hello', wait_for_response=False)
    user_greeting._source = EventSource.USER

    stuck_detector.state.history.extend(
        [
            user_greeting,
            NullObservation(content=''),
            CmdRunAction(command='ls'),
            CmdOutputObservation(
                command_id=1, command='ls', content='file1.txt\nfile2.txt'
            ),
        ]
    )

    assert stuck_detector.is_stuck(headless_mode=True) is False
def test_interactive_mode_resets_after_user_message(
    self, stuck_detector: StuckDetector
):
    """A user message resets detection in non-headless mode but not in headless mode."""
    state = stuck_detector.state

    def append_ls_pair(event_id: int) -> None:
        # One `ls` action plus the empty output observation linked to it.
        action = CmdRunAction(command='ls')
        action._id = event_id
        state.history.append(action)
        output = CmdOutputObservation(content='', command='ls', command_id=event_id)
        output._cause = action._id
        state.history.append(output)

    # Four identical rounds: reported as stuck in both modes.
    for event_id in range(4):
        append_ls_pair(event_id)
    assert stuck_detector.is_stuck(headless_mode=True) is True
    assert stuck_detector.is_stuck(headless_mode=False) is True

    # A user message arrives.
    user_message = MessageAction(content='Hello', wait_for_response=False)
    user_message._source = EventSource.USER
    state.history.append(user_message)

    # Non-headless mode ignores the history before the user message ...
    assert stuck_detector.is_stuck(headless_mode=False) is False
    # ... while headless mode does not count user messages and stays stuck.
    assert stuck_detector.is_stuck(headless_mode=True) is True

    # Two identical rounds after the message are not yet enough.
    for event_id in (4, 5):
        append_ls_pair(event_id)
    assert stuck_detector.is_stuck(headless_mode=False) is False

    # Two further rounds make it stuck again.
    for event_id in (6, 7):
        append_ls_pair(event_id)
    assert stuck_detector.is_stuck(headless_mode=False) is True
def test_is_stuck_repeating_action_observation(self, stuck_detector: StuckDetector):
    """Four identical ``ls`` rounds trigger the action/observation loop warning.

    With only two or three rounds in the history (and a user message in
    between) the detector must still answer ``False`` in headless mode.
    """
    state = stuck_detector.state

    def append_ls_pair(event_id: int) -> None:
        # One `ls` action plus the empty output observation linked to it.
        action = CmdRunAction(command='ls')
        action._id = event_id
        state.history.append(action)
        output = CmdOutputObservation(content='', command='ls', command_id=event_id)
        output._cause = action._id
        state.history.append(output)

    user_done = MessageAction(content='Done', wait_for_response=False)
    user_done._source = EventSource.USER

    # Events 1-2: opening message and its null observation.
    state.history.extend(
        [MessageAction(content='Hello', wait_for_response=False), NullObservation('')]
    )
    # Events 3-6: two identical rounds.
    append_ls_pair(1)
    append_ls_pair(2)
    # Events 7-8: a user message in the middle of the repetition.
    state.history.extend([user_done, NullObservation(content='')])
    assert stuck_detector.is_stuck(headless_mode=True) is False

    # Events 9-10: third round, still not stuck.
    append_ls_pair(3)
    assert len(state.history) == 10
    assert stuck_detector.is_stuck(headless_mode=True) is False

    # Events 11-12: fourth round, now stuck.
    append_ls_pair(4)
    assert len(state.history) == 12
    with patch('logging.Logger.warning') as warning_spy:
        assert stuck_detector.is_stuck(headless_mode=True) is True
        warning_spy.assert_called_once_with('Action, Observation loop detected')
def test_is_stuck_repeating_action_error(self, stuck_detector: StuckDetector):
    """The same command failing four times is stuck even when the error texts differ."""
    state = stuck_detector.state

    def append_failed_command(error_text: str) -> None:
        # The same invalid command each time; only the error text varies.
        state.history.append(CmdRunAction(command='invalid_command'))
        state.history.append(ErrorObservation(content=error_text))

    user_done = MessageAction(content='Done', wait_for_response=False)
    user_done._source = EventSource.USER

    # Events 1-2: opening message and its null observation.
    state.history.extend(
        [
            MessageAction(content='Hello', wait_for_response=False),
            NullObservation(content=''),
        ]
    )
    # Events 3-6: two failures.
    append_failed_command('Command not found')
    append_failed_command('Command still not found or another error')
    # Events 7-8: a user message in between.
    state.history.extend([user_done, NullObservation(content='')])
    # Events 9-12: two more failures.
    append_failed_command('Different error')
    append_failed_command('Command not found')

    with patch('logging.Logger.warning') as warning_spy:
        assert stuck_detector.is_stuck(headless_mode=True) is True
        warning_spy.assert_called_once_with(
            'Action, ErrorObservation loop detected'
        )
def test_is_stuck_invalid_syntax_error(self, stuck_detector: StuckDetector):
    """Four 'invalid syntax' failures reported on the same line count as stuck."""
    message = 'SyntaxError: invalid syntax. Perhaps you forgot a comma?'
    self._impl_syntax_error_events(
        stuck_detector.state, error_message=message, random_line=False
    )

    with patch('logging.Logger.warning'):
        stuck = stuck_detector.is_stuck(headless_mode=True)
    assert stuck is True
def test_is_not_stuck_invalid_syntax_error_random_lines(
    self, stuck_detector: StuckDetector
):
    """'Invalid syntax' failures on varying line numbers are not reported as stuck."""
    message = 'SyntaxError: invalid syntax. Perhaps you forgot a comma?'
    self._impl_syntax_error_events(
        stuck_detector.state, error_message=message, random_line=True
    )

    with patch('logging.Logger.warning'):
        stuck = stuck_detector.is_stuck(headless_mode=True)
    assert stuck is False
def test_is_not_stuck_invalid_syntax_error_only_three_incidents(
    self, stuck_detector: StuckDetector
):
    """Three 'invalid syntax' failures are not reported as stuck."""
    # NOTE(review): random_line=True also varies the reported line per incident,
    # so this case does not isolate the incident count on its own; the sibling
    # unterminated-string test passes random_line=False - confirm which is intended.
    message = 'SyntaxError: invalid syntax. Perhaps you forgot a comma?'
    self._impl_syntax_error_events(
        stuck_detector.state,
        error_message=message,
        random_line=True,
        incidents=3,
    )

    with patch('logging.Logger.warning'):
        stuck = stuck_detector.is_stuck(headless_mode=True)
    assert stuck is False
def test_is_stuck_incomplete_input_error(self, stuck_detector: StuckDetector):
    """Four 'incomplete input' failures reported on the same line count as stuck."""
    self._impl_syntax_error_events(
        stuck_detector.state,
        error_message='SyntaxError: incomplete input',
        random_line=False,
    )

    with patch('logging.Logger.warning'):
        stuck = stuck_detector.is_stuck(headless_mode=True)
    assert stuck is True
def test_is_not_stuck_incomplete_input_error(self, stuck_detector: StuckDetector):
    """'Incomplete input' failures on varying line numbers are not reported as stuck."""
    self._impl_syntax_error_events(
        stuck_detector.state,
        error_message='SyntaxError: incomplete input',
        random_line=True,
    )

    with patch('logging.Logger.warning'):
        stuck = stuck_detector.is_stuck(headless_mode=True)
    assert stuck is False
def test_is_not_stuck_ipython_unterminated_string_error_random_lines(
    self, stuck_detector: StuckDetector
):
    """Unterminated-string failures on varying line numbers are not reported as stuck."""
    self._impl_unterminated_string_error_events(
        stuck_detector.state, random_line=True
    )

    with patch('logging.Logger.warning'):
        stuck = stuck_detector.is_stuck(headless_mode=True)
    assert stuck is False
def test_is_not_stuck_ipython_unterminated_string_error_only_three_incidents(
    self, stuck_detector: StuckDetector
):
    """Three identical unterminated-string failures are not reported as stuck."""
    self._impl_unterminated_string_error_events(
        stuck_detector.state, random_line=False, incidents=3
    )

    with patch('logging.Logger.warning'):
        stuck = stuck_detector.is_stuck(headless_mode=True)
    assert stuck is False
def test_is_stuck_ipython_unterminated_string_error(
    self, stuck_detector: StuckDetector
):
    """Four identical unterminated-string failures count as stuck."""
    self._impl_unterminated_string_error_events(
        stuck_detector.state, random_line=False
    )

    with patch('logging.Logger.warning'):
        stuck = stuck_detector.is_stuck(headless_mode=True)
    assert stuck is True
def test_is_not_stuck_ipython_syntax_error_not_at_end(
    self, stuck_detector: StuckDetector
):
    """Guard against false positives when output follows the syntax error.

    The four observations share the same failing code, but the reported
    "detected at line" number and the trailing output differ between them, so
    the detector must answer ``False`` and log no warning.
    """
    state = stuck_detector.state
    failing_code = 'print("hello'
    # (line number in the error text, output printed after the error)
    variants = (
        (1, 'This is some additional output'),
        (1, 'Too much output here on and on'),
        (3, 'Enough'),
        (2, 'Last line of output'),
    )

    for detected_line, trailing_output in variants:
        state.history.append(IPythonRunCellAction(code=failing_code))
        state.history.append(
            IPythonRunCellObservation(
                content=(
                    'print("hello\n ^\n'
                    f'SyntaxError: unterminated string literal (detected at line {detected_line})\n'
                    f'{trailing_output}'
                ),
                code=failing_code,
            )
        )

    with patch('logging.Logger.warning') as warning_spy:
        assert stuck_detector.is_stuck(headless_mode=True) is False
        warning_spy.assert_not_called()
def test_is_stuck_repeating_action_observation_pattern(
    self, stuck_detector: StuckDetector
):
    """An ``ls`` then file-read cycle repeated three times triggers the pattern warning."""
    state = stuck_detector.state

    def append_user_nudge() -> None:
        # A user message followed by its null observation.
        nudge = MessageAction(content='Come on', wait_for_response=False)
        nudge._source = EventSource.USER
        state.history.append(nudge)
        state.history.append(NullObservation(content=''))

    def append_ls_then_read(command_id: int) -> None:
        # One full cycle: list the directory, then read file1.txt.
        state.history.append(CmdRunAction(command='ls'))
        state.history.append(
            CmdOutputObservation(
                command_id=command_id, command='ls', content='file1.txt\nfile2.txt'
            )
        )
        state.history.append(FileReadAction(path='file1.txt'))
        state.history.append(
            FileReadObservation(content='File content', path='file1.txt')
        )

    append_user_nudge()
    append_ls_then_read(1)
    append_ls_then_read(2)
    append_user_nudge()
    append_ls_then_read(3)

    with patch('logging.Logger.warning') as warning_spy:
        assert stuck_detector.is_stuck(headless_mode=True) is True
        warning_spy.assert_called_once_with('Action, Observation pattern detected')
def test_is_stuck_not_stuck(self, stuck_detector: StuckDetector):
    """A history of varied commands and file reads is not reported as stuck."""
    state = stuck_detector.state

    user_done = MessageAction(content='Done', wait_for_response=False)
    user_done._source = EventSource.USER

    # Opening message and its null observation.
    state.history.extend(
        [
            MessageAction(content='Hello', wait_for_response=False),
            NullObservation(content=''),
        ]
    )

    # `ls`, then read file1.txt.
    first_ls = CmdRunAction(command='ls')
    state.history.append(first_ls)
    state.history.append(
        CmdOutputObservation(
            command_id=first_ls.id, command='ls', content='file1.txt\nfile2.txt'
        )
    )
    state.history.extend(
        [
            FileReadAction(path='file1.txt'),
            FileReadObservation(content='File content', path='file1.txt'),
        ]
    )

    # `pwd`, then read file2.txt.
    state.history.extend(
        [
            CmdRunAction(command='pwd'),
            CmdOutputObservation(command_id=2, command='pwd', content='/home/user'),
            FileReadAction(path='file2.txt'),
            FileReadObservation(content='Another file content', path='file2.txt'),
        ]
    )

    # A user message in between.
    state.history.extend([user_done, NullObservation(content='')])

    # `pwd` and the file2.txt read once more.
    second_pwd = CmdRunAction(command='pwd')
    state.history.append(second_pwd)
    state.history.append(
        CmdOutputObservation(
            command_id=second_pwd.id, command='pwd', content='/home/user'
        )
    )
    state.history.extend(
        [
            FileReadAction(path='file2.txt'),
            FileReadObservation(content='Another file content', path='file2.txt'),
        ]
    )

    assert stuck_detector.is_stuck(headless_mode=True) is False
def test_is_stuck_monologue(self, stuck_detector):
    """Three identical agent messages in a row are stuck; an observation in between is not."""
    state = stuck_detector.state
    repeated_reply = "I'm doing well, thanks for asking."

    def append_message(text: str, source: EventSource) -> None:
        # Messages go straight onto the history list with their source set.
        message = MessageAction(content=text)
        message._source = source
        state.history.append(message)

    append_message('Hi there!', EventSource.USER)
    append_message('Hi there!', EventSource.AGENT)
    append_message('How are you?', EventSource.USER)
    state.history.append(
        CmdRunAction(command='echo 42', thought="I'm not stuck, he's stuck")
    )
    for _ in range(3):
        append_message(repeated_reply, EventSource.AGENT)

    assert stuck_detector.is_stuck(headless_mode=True)

    # An observation now separates the earlier repeats from the next two.
    state.history.append(
        CmdOutputObservation(
            content='OK, I was stuck, but no more.',
            command_id=42,
            command='storybook',
            exit_code=0,
        )
    )
    for _ in range(2):
        append_message(repeated_reply, EventSource.AGENT)

    with patch('logging.Logger.warning'):
        assert not stuck_detector.is_stuck(headless_mode=True)
class TestAgentController:
    """Exercises ``AgentController._is_stuck`` on a mocked controller."""

    @pytest.fixture
    def controller(self):
        """Return an ``AgentController``-spec'd mock running the real ``_is_stuck``."""
        mocked = Mock(spec=AgentController)
        mocked.delegate = None
        mocked.state = Mock()
        # Bind the genuine implementation so the method under test is not a mock.
        mocked._is_stuck = AgentController._is_stuck.__get__(mocked, AgentController)
        return mocked

    def test_is_stuck_delegate_stuck(self, controller: AgentController):
        """The controller reports stuck when its delegate's ``_is_stuck`` returns True."""
        stuck_delegate = Mock()
        stuck_delegate._is_stuck.return_value = True
        controller.delegate = stuck_delegate

        assert controller._is_stuck() is True
|