|
|
@@ -33,7 +33,7 @@ from opendevin.events.observation import (
|
|
|
Observation,
|
|
|
)
|
|
|
from opendevin.events.stream import EventSource, EventStream, EventStreamSubscriber
|
|
|
-from opendevin.runtime import DockerSSHBox
|
|
|
+from opendevin.runtime import DockerSSHBox, Sandbox
|
|
|
from opendevin.runtime.runtime import Runtime
|
|
|
from opendevin.runtime.server.runtime import ServerRuntime
|
|
|
|
|
|
@@ -60,6 +60,8 @@ class AgentController:
|
|
|
sid: str = 'default',
|
|
|
max_iterations: int = MAX_ITERATIONS,
|
|
|
max_chars: int = MAX_CHARS,
|
|
|
+ sandbox: Optional[Sandbox] = None,
|
|
|
+ remind_iterations: bool = config.remind_iterations,
|
|
|
):
|
|
|
"""Initializes a new instance of the AgentController class.
|
|
|
|
|
|
@@ -68,6 +70,8 @@ class AgentController:
|
|
|
sid: The session ID of the agent.
|
|
|
max_iterations: The maximum number of iterations the agent can run.
|
|
|
max_chars: The maximum number of characters the agent can output.
|
|
|
+ sandbox: An optional initialized sandbox to run the agent in. If not provided, a default sandbox will be created based on config.
|
|
|
+ remind_iterations: A boolean value indicating whether to remind the agent its remaining budget of interaction.
|
|
|
"""
|
|
|
self.id = sid
|
|
|
self.agent = agent
|
|
|
@@ -76,8 +80,15 @@ class AgentController:
|
|
|
EventStreamSubscriber.AGENT_CONTROLLER, self.on_event
|
|
|
)
|
|
|
self.max_iterations = max_iterations
|
|
|
- self.runtime = ServerRuntime(self.id)
|
|
|
+
|
|
|
+ self.remind_iterations = remind_iterations
|
|
|
+ if self.remind_iterations:
|
|
|
+ logger.info(
|
|
|
+ 'Iteration reminder is ENABLED: agent will be reminded of remaining turns.'
|
|
|
+ )
|
|
|
+ self.runtime = ServerRuntime(sandbox=sandbox, sid=self.id)
|
|
|
self.max_chars = max_chars
|
|
|
+
|
|
|
# Initialize agent-required plugins for sandbox (if any)
|
|
|
self.runtime.init_sandbox_plugins(agent.sandbox_plugins)
|
|
|
|
|
|
@@ -187,7 +198,9 @@ class AgentController:
|
|
|
self.agent.reset()
|
|
|
|
|
|
async def set_agent_state_to(self, new_state: AgentState):
|
|
|
- logger.info(f'Setting agent({type(self.agent).__name__}) state from {self._agent_state} to {new_state}')
|
|
|
+ logger.info(
|
|
|
+ f'Setting agent({type(self.agent).__name__}) state from {self._agent_state} to {new_state}'
|
|
|
+ )
|
|
|
if new_state == self._agent_state:
|
|
|
return
|
|
|
|
|
|
@@ -201,7 +214,11 @@ class AgentController:
|
|
|
self._cur_step += 1
|
|
|
if self.agent_task is not None:
|
|
|
self.agent_task.cancel()
|
|
|
- elif new_state == AgentState.STOPPED or new_state == AgentState.ERROR or new_state == AgentState.FINISHED:
|
|
|
+ elif (
|
|
|
+ new_state == AgentState.STOPPED
|
|
|
+ or new_state == AgentState.ERROR
|
|
|
+ or new_state == AgentState.FINISHED
|
|
|
+ ):
|
|
|
await self.reset_task()
|
|
|
|
|
|
await self.event_stream.add_event(
|
|
|
@@ -225,6 +242,17 @@ class AgentController:
|
|
|
task = action.inputs.get('task') or ''
|
|
|
await self.delegate.setup_task(task, action.inputs)
|
|
|
|
|
|
+ def add_iteration_reminder_when_needed(self, i: int, obs: Observation):
|
|
|
+ """Add iteration reminder to the observation if needed.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ i: The current iteration number (0-indexed).
|
|
|
+ obs: The observation to add the reminder to.
|
|
|
+ """
|
|
|
+ if self.remind_iterations:
|
|
|
+ obs.content += f'\n\nENVIRONMENT REMINDER: You have {self.max_iterations - i - 1} turns left to complete the task.'
|
|
|
+ return obs
|
|
|
+
|
|
|
async def step(self, i: int) -> bool:
|
|
|
if self.state is None:
|
|
|
raise ValueError('No task to run')
|
|
|
@@ -265,6 +293,7 @@ class AgentController:
|
|
|
if isinstance(action, AgentFinishAction):
|
|
|
self.state.outputs = action.outputs # type: ignore[attr-defined]
|
|
|
logger.info(action, extra={'msg_type': 'INFO'})
|
|
|
+ await self.add_history(action, NullObservation(''))
|
|
|
return True
|
|
|
elif isinstance(action, MessageAction) and action.wait_for_response:
|
|
|
# FIXME: remove this once history is managed outside the agent controller
|
|
|
@@ -280,6 +309,7 @@ class AgentController:
|
|
|
elif not isinstance(observation, ErrorObservation):
|
|
|
observation = await self.runtime.run_action(action)
|
|
|
|
|
|
+ observation = self.add_iteration_reminder_when_needed(i, observation)
|
|
|
if not isinstance(observation, NullObservation):
|
|
|
logger.info(observation, extra={'msg_type': 'OBSERVATION'})
|
|
|
await self.add_history(action, observation)
|