| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159 |
- import asyncio
- import logging
- from typing import Type
- from termcolor import colored
- import openhands.agenthub # noqa F401 (we import this to get the agents registered)
- from openhands import __version__
- from openhands.controller import AgentController
- from openhands.controller.agent import Agent
- from openhands.core.config import (
- get_parser,
- load_app_config,
- )
- from openhands.core.logger import openhands_logger as logger
- from openhands.core.schema import AgentState
- from openhands.events import EventSource, EventStream, EventStreamSubscriber
- from openhands.events.action import (
- Action,
- ChangeAgentStateAction,
- CmdRunAction,
- MessageAction,
- )
- from openhands.events.event import Event
- from openhands.events.observation import (
- AgentStateChangedObservation,
- CmdOutputObservation,
- )
- from openhands.llm.llm import LLM
- from openhands.runtime import get_runtime_cls
- from openhands.runtime.runtime import Runtime
- from openhands.storage import get_file_store
def display_message(message: str):
    """Print an agent message in yellow, prefixed with a robot emoji.

    The message is followed by an extra newline so consecutive messages are
    separated by a blank line.
    """
    text = ''.join(('🤖 ', message, '\n'))
    print(colored(text, color='yellow'))
def display_command(command: str):
    """Print a shell command as ``❯ <command>``, with the command in green.

    The command is followed by an extra newline so it is separated from the
    output that follows.
    """
    print('❯ ', colored(command + '\n', color='green'), sep='')
def display_command_output(output: str):
    """Print command output line by line in blue, then a blank separator.

    Lines that start with the Python-interpreter banner or an ``openhands@``
    shell prompt are hidden, since they are terminal noise rather than output.
    """
    # TODO: clean this up once we clean up terminal output
    hidden_prefixes = ('[Python Interpreter', 'openhands@')
    for row in output.split('\n'):
        if not row.startswith(hidden_prefixes):
            print(colored(row, 'blue'))
    print('\n')
def display_event(event: Event):
    """Render a single event-stream event to the terminal.

    - Any Action with a non-empty ``thought`` shows that thought.
    - MessageActions that did not come from the user show their content.
    - CmdRunActions show the command being run.
    - CmdOutputObservations show the command's output.
    The checks are independent, so one event may produce several lines.
    """
    if isinstance(event, Action):
        # hasattr() alone is true even for an empty thought, which printed a
        # bare "🤖" line; a None thought would also make display_message's
        # string concatenation raise TypeError. Only show real thoughts.
        thought = getattr(event, 'thought', None)
        if thought:
            display_message(thought)
    if isinstance(event, MessageAction) and event.source != EventSource.USER:
        # The user's own messages are not echoed back.
        display_message(event.content)
    if isinstance(event, CmdRunAction):
        display_command(event.command)
    if isinstance(event, CmdOutputObservation):
        display_command_output(event.content)
async def main():
    """Run the agent interactively in CLI mode.

    Parses command-line arguments, loads the app config, builds the agent,
    event stream, runtime and controller, then reads tasks from stdin and
    forwards them to the agent until the user types ``exit`` (or stdin is
    closed). The controller is closed before returning.
    """
    parser = get_parser()
    # Add the version argument. argparse's 'version' action prints the
    # version string and exits from inside parse_args(), so no separate
    # check of args.version is needed (one could never be reached).
    parser.add_argument(
        '-v',
        '--version',
        action='version',
        version=f'{__version__}',
        help='Show the version number and exit',
        default=None,
    )
    args = parser.parse_args()

    logger.setLevel(logging.WARNING)
    config = load_app_config(config_file=args.config_file)
    sid = 'cli'

    agent_cls: Type[Agent] = Agent.get_cls(config.default_agent)
    agent_config = config.get_agent_config(config.default_agent)
    llm_config = config.get_llm_config_from_agent(config.default_agent)
    agent = agent_cls(
        llm=LLM(config=llm_config),
        config=agent_config,
    )

    file_store = get_file_store(config.file_store, config.file_store_path)
    event_stream = EventStream(sid, file_store)

    runtime_cls = get_runtime_cls(config.runtime)
    # Not referenced again in this function; it is constructed with the
    # shared event stream (presumably so it can react to its events — confirm).
    runtime: Runtime = runtime_cls(  # noqa: F841
        config=config,
        event_stream=event_stream,
        sid=sid,
        plugins=agent_cls.sandbox_plugins,
    )

    controller = AgentController(
        agent=agent,
        max_iterations=config.max_iterations,
        max_budget_per_task=config.max_budget_per_task,
        agent_to_llm_config=config.get_agent_to_llm_config_map(),
        event_stream=event_stream,
    )
    # A constructor call never evaluates to None, so start the step loop
    # unconditionally.
    controller.agent_task = asyncio.create_task(controller.start_step_loop())

    async def prompt_for_next_task():
        """Read the next task from stdin and publish it on the event stream.

        Blank lines are ignored. ``exit`` (surrounding whitespace ignored) or
        end of input publishes a request to stop the agent instead.
        """
        # NOTE(review): input() is synchronous and blocks the event loop
        # while waiting for the user.
        next_message = ''
        while not next_message.strip():
            try:
                next_message = input('How can I help? >> ')
            except EOFError:
                # stdin was closed (Ctrl-D or end of piped input): treat it as
                # 'exit' rather than letting EOFError escape the callback.
                print()
                next_message = 'exit'
        if next_message.strip() == 'exit':
            event_stream.add_event(
                ChangeAgentStateAction(AgentState.STOPPED), EventSource.USER
            )
            return
        action = MessageAction(content=next_message)
        event_stream.add_event(action, EventSource.USER)

    async def on_event(event: Event):
        """Display each event; prompt again whenever the agent hands back control."""
        display_event(event)
        if isinstance(event, AgentStateChangedObservation):
            if event.agent_state == AgentState.ERROR:
                print('An error occurred. Please try again.')
            if event.agent_state in (
                AgentState.AWAITING_USER_INPUT,
                AgentState.FINISHED,
                AgentState.ERROR,
            ):
                await prompt_for_next_task()

    event_stream.subscribe(EventStreamSubscriber.MAIN, on_event)

    try:
        await prompt_for_next_task()
        # Poll until an 'exit' request has moved the agent to STOPPED.
        while controller.state.agent_state != AgentState.STOPPED:
            await asyncio.sleep(1)  # Give back control for a tick, so the agent can run
    finally:
        # Release the controller even if the wait above raised or was cancelled.
        print('Exiting...')
        await controller.close()
if __name__ == '__main__':
    # Create and install the loop explicitly: asyncio.get_event_loop() with no
    # running loop is deprecated (warns on Python 3.12/3.13, raises RuntimeError
    # on 3.14+). Unlike asyncio.run(), this keeps the default Ctrl-C behavior,
    # so KeyboardInterrupt is raised immediately even while blocked in input().
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(main())
    except KeyboardInterrupt:
        # Exit without a traceback, using the conventional status for SIGINT.
        print('\nReceived keyboard interrupt, shutting down...')
        raise SystemExit(130)
|