|
|
@@ -28,6 +28,17 @@ print(f'workspace_mount_path: {workspace_mount_path}')
|
|
|
print(f'workspace_mount_path_in_sandbox: {workspace_mount_path_in_sandbox}')
|
|
|
|
|
|
|
|
|
+def validate_final_state(final_state: State | None):
|
|
|
+ assert final_state is not None
|
|
|
+ assert final_state.agent_state == AgentState.STOPPED
|
|
|
+ assert final_state.last_error is None
|
|
|
+ if final_state.history.has_delegation():
|
|
|
+ assert final_state.iteration > final_state.local_iteration
|
|
|
+ else:
|
|
|
+ assert final_state.local_iteration == final_state.iteration
|
|
|
+ assert final_state.iteration > 0
|
|
|
+
|
|
|
+
|
|
|
@pytest.mark.skipif(
|
|
|
os.getenv('DEFAULT_AGENT') == 'BrowsingAgent',
|
|
|
reason='BrowsingAgent is a specialized agent',
|
|
|
@@ -112,8 +123,7 @@ def test_edits():
|
|
|
final_state: State | None = asyncio.run(
|
|
|
run_agent_controller(agent, task, exit_on_message=True)
|
|
|
)
|
|
|
- assert final_state.agent_state == AgentState.STOPPED
|
|
|
- assert final_state.last_error is None
|
|
|
+ validate_final_state(final_state)
|
|
|
|
|
|
# Verify bad.txt has been fixed
|
|
|
text = """This is a stupid typo.
|
|
|
@@ -146,8 +156,7 @@ def test_ipython():
|
|
|
final_state: State | None = asyncio.run(
|
|
|
run_agent_controller(agent, task, exit_on_message=True)
|
|
|
)
|
|
|
- assert final_state.agent_state == AgentState.STOPPED
|
|
|
- assert final_state.last_error is None
|
|
|
+ validate_final_state(final_state)
|
|
|
|
|
|
# Verify the file exists
|
|
|
file_path = os.path.join(workspace_base, 'test.txt')
|
|
|
@@ -179,8 +188,7 @@ def test_simple_task_rejection():
|
|
|
# the workspace is not a git repo
|
|
|
task = 'Write a git commit message for the current staging area. Do not ask me for confirmation at any point.'
|
|
|
final_state: State | None = asyncio.run(run_agent_controller(agent, task))
|
|
|
- assert final_state.agent_state == AgentState.STOPPED
|
|
|
- assert final_state.last_error is None
|
|
|
+ validate_final_state(final_state)
|
|
|
assert isinstance(final_state.history.get_last_action(), AgentRejectAction)
|
|
|
|
|
|
|
|
|
@@ -204,8 +212,7 @@ def test_ipython_module():
|
|
|
final_state: State | None = asyncio.run(
|
|
|
run_agent_controller(agent, task, exit_on_message=True)
|
|
|
)
|
|
|
- assert final_state.agent_state == AgentState.STOPPED
|
|
|
- assert final_state.last_error is None
|
|
|
+ validate_final_state(final_state)
|
|
|
|
|
|
# Verify the file exists
|
|
|
file_path = os.path.join(workspace_base, 'test.txt')
|
|
|
@@ -244,8 +251,7 @@ def test_browse_internet(http_server):
|
|
|
final_state: State | None = asyncio.run(
|
|
|
run_agent_controller(agent, task, exit_on_message=True)
|
|
|
)
|
|
|
- assert final_state.agent_state == AgentState.STOPPED
|
|
|
- assert final_state.last_error is None
|
|
|
+ validate_final_state(final_state)
|
|
|
|
|
|
# last action
|
|
|
last_action = final_state.history.get_last_action()
|