cli.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. import asyncio
  2. import logging
  3. from typing import Type
  4. from termcolor import colored
  5. import openhands.agenthub # noqa F401 (we import this to get the agents registered)
  6. from openhands import __version__
  7. from openhands.controller import AgentController
  8. from openhands.controller.agent import Agent
  9. from openhands.core.config import (
  10. get_parser,
  11. load_app_config,
  12. )
  13. from openhands.core.logger import openhands_logger as logger
  14. from openhands.core.schema import AgentState
  15. from openhands.events import EventSource, EventStream, EventStreamSubscriber
  16. from openhands.events.action import (
  17. Action,
  18. ChangeAgentStateAction,
  19. CmdRunAction,
  20. FileEditAction,
  21. MessageAction,
  22. )
  23. from openhands.events.event import Event
  24. from openhands.events.observation import (
  25. AgentStateChangedObservation,
  26. CmdOutputObservation,
  27. FileEditObservation,
  28. )
  29. from openhands.llm.llm import LLM
  30. from openhands.runtime import get_runtime_cls
  31. from openhands.runtime.base import Runtime
  32. from openhands.storage import get_file_store
  33. def display_message(message: str):
  34. print(colored('🤖 ' + message + '\n', 'yellow'))
  35. def display_command(command: str):
  36. print('❯ ' + colored(command + '\n', 'green'))
  37. def display_command_output(output: str):
  38. lines = output.split('\n')
  39. for line in lines:
  40. if line.startswith('[Python Interpreter') or line.startswith('openhands@'):
  41. # TODO: clean this up once we clean up terminal output
  42. continue
  43. print(colored(line, 'blue'))
  44. print('\n')
  45. def display_file_edit(event: FileEditAction | FileEditObservation):
  46. print(colored(str(event), 'green'))
  47. def display_event(event: Event):
  48. if isinstance(event, Action):
  49. if hasattr(event, 'thought'):
  50. display_message(event.thought)
  51. if isinstance(event, MessageAction):
  52. if event.source != EventSource.USER:
  53. display_message(event.content)
  54. if isinstance(event, CmdRunAction):
  55. display_command(event.command)
  56. if isinstance(event, CmdOutputObservation):
  57. display_command_output(event.content)
  58. if isinstance(event, FileEditAction):
  59. display_file_edit(event)
  60. if isinstance(event, FileEditObservation):
  61. display_file_edit(event)
  62. async def main():
  63. """Runs the agent in CLI mode"""
  64. parser = get_parser()
  65. # Add the version argument
  66. parser.add_argument(
  67. '-v',
  68. '--version',
  69. action='version',
  70. version=f'{__version__}',
  71. help='Show the version number and exit',
  72. default=None,
  73. )
  74. args = parser.parse_args()
  75. if args.version:
  76. print(f'OpenHands version: {__version__}')
  77. return
  78. logger.setLevel(logging.WARNING)
  79. config = load_app_config(config_file=args.config_file)
  80. sid = 'cli'
  81. agent_cls: Type[Agent] = Agent.get_cls(config.default_agent)
  82. agent_config = config.get_agent_config(config.default_agent)
  83. llm_config = config.get_llm_config_from_agent(config.default_agent)
  84. agent = agent_cls(
  85. llm=LLM(config=llm_config),
  86. config=agent_config,
  87. )
  88. file_store = get_file_store(config.file_store, config.file_store_path)
  89. event_stream = EventStream(sid, file_store)
  90. runtime_cls = get_runtime_cls(config.runtime)
  91. runtime: Runtime = runtime_cls( # noqa: F841
  92. config=config,
  93. event_stream=event_stream,
  94. sid=sid,
  95. plugins=agent_cls.sandbox_plugins,
  96. )
  97. await runtime.connect()
  98. controller = AgentController(
  99. agent=agent,
  100. max_iterations=config.max_iterations,
  101. max_budget_per_task=config.max_budget_per_task,
  102. agent_to_llm_config=config.get_agent_to_llm_config_map(),
  103. event_stream=event_stream,
  104. )
  105. if controller is not None:
  106. controller.agent_task = asyncio.create_task(controller.start_step_loop())
  107. async def prompt_for_next_task():
  108. next_message = input('How can I help? >> ')
  109. if next_message == 'exit':
  110. event_stream.add_event(
  111. ChangeAgentStateAction(AgentState.STOPPED), EventSource.USER
  112. )
  113. return
  114. action = MessageAction(content=next_message)
  115. event_stream.add_event(action, EventSource.USER)
  116. async def on_event(event: Event):
  117. display_event(event)
  118. if isinstance(event, AgentStateChangedObservation):
  119. if event.agent_state == AgentState.ERROR:
  120. print('An error occurred. Please try again.')
  121. if event.agent_state in [
  122. AgentState.AWAITING_USER_INPUT,
  123. AgentState.FINISHED,
  124. AgentState.ERROR,
  125. ]:
  126. await prompt_for_next_task()
  127. event_stream.subscribe(EventStreamSubscriber.MAIN, on_event)
  128. await prompt_for_next_task()
  129. while controller.state.agent_state not in [
  130. AgentState.STOPPED,
  131. ]:
  132. await asyncio.sleep(1) # Give back control for a tick, so the agent can run
  133. print('Exiting...')
  134. await controller.close()
  135. if __name__ == '__main__':
  136. loop = asyncio.get_event_loop()
  137. try:
  138. loop.run_until_complete(main())
  139. finally:
  140. pass