cli.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. import asyncio
  2. import logging
  3. import sys
  4. from typing import Type
  5. from uuid import uuid4
  6. from termcolor import colored
  7. import openhands.agenthub # noqa F401 (we import this to get the agents registered)
  8. from openhands import __version__
  9. from openhands.controller import AgentController
  10. from openhands.controller.agent import Agent
  11. from openhands.core.config import (
  12. get_parser,
  13. load_app_config,
  14. )
  15. from openhands.core.logger import openhands_logger as logger
  16. from openhands.core.loop import run_agent_until_done
  17. from openhands.core.schema import AgentState
  18. from openhands.events import EventSource, EventStream, EventStreamSubscriber
  19. from openhands.events.action import (
  20. Action,
  21. ChangeAgentStateAction,
  22. CmdRunAction,
  23. FileEditAction,
  24. MessageAction,
  25. )
  26. from openhands.events.event import Event
  27. from openhands.events.observation import (
  28. AgentStateChangedObservation,
  29. CmdOutputObservation,
  30. FileEditObservation,
  31. )
  32. from openhands.llm.llm import LLM
  33. from openhands.runtime import get_runtime_cls
  34. from openhands.runtime.base import Runtime
  35. from openhands.storage import get_file_store
  36. def display_message(message: str):
  37. print(colored('🤖 ' + message + '\n', 'yellow'))
  38. def display_command(command: str):
  39. print('❯ ' + colored(command + '\n', 'green'))
  40. def display_command_output(output: str):
  41. lines = output.split('\n')
  42. for line in lines:
  43. if line.startswith('[Python Interpreter') or line.startswith('openhands@'):
  44. # TODO: clean this up once we clean up terminal output
  45. continue
  46. print(colored(line, 'blue'))
  47. print('\n')
  48. def display_file_edit(event: FileEditAction | FileEditObservation):
  49. print(colored(str(event), 'green'))
  50. def display_event(event: Event):
  51. if isinstance(event, Action):
  52. if hasattr(event, 'thought'):
  53. display_message(event.thought)
  54. if isinstance(event, MessageAction):
  55. if event.source == EventSource.AGENT:
  56. display_message(event.content)
  57. if isinstance(event, CmdRunAction):
  58. display_command(event.command)
  59. if isinstance(event, CmdOutputObservation):
  60. display_command_output(event.content)
  61. if isinstance(event, FileEditAction):
  62. display_file_edit(event)
  63. if isinstance(event, FileEditObservation):
  64. display_file_edit(event)
  65. async def main():
  66. """Runs the agent in CLI mode"""
  67. parser = get_parser()
  68. # Add the version argument
  69. parser.add_argument(
  70. '-v',
  71. '--version',
  72. action='version',
  73. version=f'{__version__}',
  74. help='Show the version number and exit',
  75. default=None,
  76. )
  77. args = parser.parse_args()
  78. if args.version:
  79. print(f'OpenHands version: {__version__}')
  80. return
  81. logger.setLevel(logging.WARNING)
  82. config = load_app_config(config_file=args.config_file)
  83. sid = 'cli'
  84. agent_cls: Type[Agent] = Agent.get_cls(config.default_agent)
  85. agent_config = config.get_agent_config(config.default_agent)
  86. llm_config = config.get_llm_config_from_agent(config.default_agent)
  87. agent = agent_cls(
  88. llm=LLM(config=llm_config),
  89. config=agent_config,
  90. )
  91. file_store = get_file_store(config.file_store, config.file_store_path)
  92. event_stream = EventStream(sid, file_store)
  93. runtime_cls = get_runtime_cls(config.runtime)
  94. runtime: Runtime = runtime_cls( # noqa: F841
  95. config=config,
  96. event_stream=event_stream,
  97. sid=sid,
  98. plugins=agent_cls.sandbox_plugins,
  99. )
  100. controller = AgentController(
  101. agent=agent,
  102. max_iterations=config.max_iterations,
  103. max_budget_per_task=config.max_budget_per_task,
  104. agent_to_llm_config=config.get_agent_to_llm_config_map(),
  105. event_stream=event_stream,
  106. )
  107. async def prompt_for_next_task():
  108. # Run input() in a thread pool to avoid blocking the event loop
  109. loop = asyncio.get_event_loop()
  110. next_message = await loop.run_in_executor(
  111. None, lambda: input('How can I help? >> ')
  112. )
  113. if not next_message.strip():
  114. await prompt_for_next_task()
  115. if next_message == 'exit':
  116. event_stream.add_event(
  117. ChangeAgentStateAction(AgentState.STOPPED), EventSource.ENVIRONMENT
  118. )
  119. return
  120. action = MessageAction(content=next_message)
  121. event_stream.add_event(action, EventSource.USER)
  122. async def on_event(event: Event):
  123. display_event(event)
  124. if isinstance(event, AgentStateChangedObservation):
  125. if event.agent_state in [
  126. AgentState.AWAITING_USER_INPUT,
  127. AgentState.FINISHED,
  128. ]:
  129. await prompt_for_next_task()
  130. event_stream.subscribe(EventStreamSubscriber.MAIN, on_event, str(uuid4()))
  131. await runtime.connect()
  132. asyncio.create_task(prompt_for_next_task())
  133. await run_agent_until_done(
  134. controller, runtime, [AgentState.STOPPED, AgentState.ERROR]
  135. )
  136. if __name__ == '__main__':
  137. loop = asyncio.new_event_loop()
  138. asyncio.set_event_loop(loop)
  139. try:
  140. loop.run_until_complete(main())
  141. except KeyboardInterrupt:
  142. print('Received keyboard interrupt, shutting down...')
  143. except ConnectionRefusedError as e:
  144. print(f'Connection refused: {e}')
  145. sys.exit(1)
  146. except Exception as e:
  147. print(f'An error occurred: {e}')
  148. sys.exit(1)
  149. finally:
  150. try:
  151. # Cancel all running tasks
  152. pending = asyncio.all_tasks(loop)
  153. for task in pending:
  154. task.cancel()
  155. # Wait for all tasks to complete with a timeout
  156. loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
  157. loop.close()
  158. except Exception as e:
  159. print(f'Error during cleanup: {e}')
  160. sys.exit(1)