|
|
@@ -55,12 +55,20 @@ class AgentController:
|
|
|
def __init__(
|
|
|
self,
|
|
|
agent: Agent,
|
|
|
- inputs: dict = {},
|
|
|
sid: str = 'default',
|
|
|
max_iterations: int = MAX_ITERATIONS,
|
|
|
max_chars: int = MAX_CHARS,
|
|
|
callbacks: List[Callable] = [],
|
|
|
):
|
|
|
+ """Initializes a new instance of the AgentController class.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ agent: The agent instance to control.
|
|
|
+ sid: The session ID of the agent.
|
|
|
+ max_iterations: The maximum number of iterations the agent can run.
|
|
|
+ max_chars: The maximum number of characters the agent can output.
|
|
|
+ callbacks: A list of callback functions to run after each action.
|
|
|
+ """
|
|
|
self.id = sid
|
|
|
self.agent = agent
|
|
|
self.max_iterations = max_iterations
|
|
|
@@ -73,8 +81,12 @@ class AgentController:
|
|
|
self.browser = BrowserEnv()
|
|
|
|
|
|
|
|
|
- if isinstance(agent, CodeActAgent) and not isinstance(self.action_manager.sandbox, DockerSSHBox):
|
|
|
- logger.warning('CodeActAgent requires DockerSSHBox as sandbox! Using other sandbox that are not stateful (LocalBox, DockerExecBox) will not work properly.')
|
|
|
+ if isinstance(agent, CodeActAgent) and not isinstance(
|
|
|
+ self.action_manager.sandbox, DockerSSHBox
|
|
|
+ ):
|
|
|
+ logger.warning(
|
|
|
+ 'CodeActAgent requires DockerSSHBox as sandbox! Using other sandbox that are not stateful (LocalBox, DockerExecBox) will not work properly.'
|
|
|
+ )
|
|
|
|
|
|
self._await_user_message_queue: asyncio.Queue = asyncio.Queue()
|
|
|
|
|
|
@@ -119,7 +131,10 @@ class AgentController:
|
|
|
except Exception:
|
|
|
logger.error('Error in loop', exc_info=True)
|
|
|
await self._run_callbacks(
|
|
|
- AgentErrorObservation('Oops! Something went wrong while completing your task. You can check the logs for more info.'))
|
|
|
+ AgentErrorObservation(
|
|
|
+ 'Oops! Something went wrong while completing your task. You can check the logs for more info.'
|
|
|
+ )
|
|
|
+ )
|
|
|
await self.set_task_state_to(TaskState.STOPPED)
|
|
|
break
|
|
|
|
|
|
@@ -139,14 +154,15 @@ class AgentController:
|
|
|
|
|
|
if self._is_stuck():
|
|
|
logger.info('Loop detected, stopping task')
|
|
|
- observation = AgentErrorObservation('I got stuck into a loop, the task has stopped.')
|
|
|
+ observation = AgentErrorObservation(
|
|
|
+ 'I got stuck into a loop, the task has stopped.'
|
|
|
+ )
|
|
|
await self._run_callbacks(observation)
|
|
|
await self.set_task_state_to(TaskState.STOPPED)
|
|
|
break
|
|
|
|
|
|
async def setup_task(self, task: str, inputs: dict = {}):
|
|
|
- """Sets up the agent controller with a task.
|
|
|
- """
|
|
|
+ """Sets up the agent controller with a task."""
|
|
|
self._task_state = TaskState.RUNNING
|
|
|
await self.notify_task_state_changed()
|
|
|
self.state = State(Plan(task))
|
|
|
@@ -203,14 +219,19 @@ class AgentController:
|
|
|
self.add_history(NullAction(), message)
|
|
|
|
|
|
else:
|
|
|
- raise ValueError(f'Task (state: {self._task_state}) is not in a state to add user message')
|
|
|
+ raise ValueError(
|
|
|
+ f'Task (state: {self._task_state}) is not in a state to add user message'
|
|
|
+ )
|
|
|
|
|
|
async def wait_for_user_input(self) -> UserMessageObservation:
|
|
|
self._task_state = TaskState.AWAITING_USER_INPUT
|
|
|
await self.notify_task_state_changed()
|
|
|
# wait for the next user message
|
|
|
if len(self.callbacks) == 0:
|
|
|
- logger.info('Use STDIN to request user message as no callbacks are registered', extra={'msg_type': 'INFO'})
|
|
|
+ logger.info(
|
|
|
+ 'Use STDIN to request user message as no callbacks are registered',
|
|
|
+ extra={'msg_type': 'INFO'},
|
|
|
+ )
|
|
|
message = input('Request user input [type /exit to stop interaction] >> ')
|
|
|
user_message_observation = UserMessageObservation(message)
|
|
|
else:
|
|
|
@@ -312,22 +333,33 @@ class AgentController:
|
|
|
return self.state
|
|
|
|
|
|
def _is_stuck(self):
|
|
|
- if self.state is None or self.state.history is None or len(self.state.history) < 3:
|
|
|
+ if (
|
|
|
+ self.state is None
|
|
|
+ or self.state.history is None
|
|
|
+ or len(self.state.history) < 3
|
|
|
+ ):
|
|
|
return False
|
|
|
|
|
|
# if the last three (Action, Observation) tuples are too repetitive
|
|
|
# the agent got stuck in a loop
|
|
|
if all(
|
|
|
- [self.state.history[-i][0] == self.state.history[-3][0] for i in range(1, 3)]
|
|
|
+ [
|
|
|
+ self.state.history[-i][0] == self.state.history[-3][0]
|
|
|
+ for i in range(1, 3)
|
|
|
+ ]
|
|
|
):
|
|
|
# it repeats same action, give it a chance, but not if:
|
|
|
- if (all
|
|
|
- (isinstance(self.state.history[-i][1], NullObservation) for i in range(1, 4))):
|
|
|
+ if all(
|
|
|
+ isinstance(self.state.history[-i][1], NullObservation)
|
|
|
+ for i in range(1, 4)
|
|
|
+ ):
|
|
|
# same (Action, NullObservation): like 'think' the same thought over and over
|
|
|
logger.debug('Action, NullObservation loop detected')
|
|
|
return True
|
|
|
- elif (all
|
|
|
- (isinstance(self.state.history[-i][1], AgentErrorObservation) for i in range(1, 4))):
|
|
|
+ elif all(
|
|
|
+ isinstance(self.state.history[-i][1], AgentErrorObservation)
|
|
|
+ for i in range(1, 4)
|
|
|
+ ):
|
|
|
# (NullAction, AgentErrorObservation): errors coming from an exception
|
|
|
# (Action, AgentErrorObservation): the same action getting an error, even if not necessarily the same error
|
|
|
logger.debug('Action, AgentErrorObservation loop detected')
|