cli.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. import asyncio
  2. import logging
  3. import sys
  4. from typing import Type
  5. from uuid import uuid4
  6. from termcolor import colored
  7. import openhands.agenthub # noqa F401 (we import this to get the agents registered)
  8. from openhands import __version__
  9. from openhands.controller import AgentController
  10. from openhands.controller.agent import Agent
  11. from openhands.core.config import (
  12. AppConfig,
  13. get_parser,
  14. load_app_config,
  15. )
  16. from openhands.core.logger import openhands_logger as logger
  17. from openhands.core.loop import run_agent_until_done
  18. from openhands.core.schema import AgentState
  19. from openhands.events import EventSource, EventStream, EventStreamSubscriber
  20. from openhands.events.action import (
  21. Action,
  22. ActionConfirmationStatus,
  23. ChangeAgentStateAction,
  24. CmdRunAction,
  25. FileEditAction,
  26. MessageAction,
  27. )
  28. from openhands.events.event import Event
  29. from openhands.events.observation import (
  30. AgentStateChangedObservation,
  31. CmdOutputObservation,
  32. FileEditObservation,
  33. NullObservation,
  34. )
  35. from openhands.llm.llm import LLM
  36. from openhands.runtime import get_runtime_cls
  37. from openhands.runtime.base import Runtime
  38. from openhands.security import SecurityAnalyzer, options
  39. from openhands.storage import get_file_store
  def display_message(message: str):
      """Print an agent message to stdout in yellow, prefixed with a robot emoji.

      Args:
          message: Text to show; a trailing newline is appended before colouring.
      """
      decorated = '🤖 ' + message + '\n'
      print(colored(decorated, 'yellow'))
  def display_command(command: str):
      """Print a command to stdout after a prompt marker.

      Only the command text (plus a trailing newline) is coloured green; the
      leading '❯ ' marker is printed uncoloured.

      Args:
          command: The command string to display.
      """
      highlighted = colored(command + '\n', 'green')
      print('❯ ' + highlighted)
  def display_confirmation(confirmation_state: ActionConfirmationStatus):
      """Print the confirmation status of an action with a matching icon and colour.

      CONFIRMED is shown green with a check mark, REJECTED red with a cross,
      and any other state yellow with an hourglass.

      Args:
          confirmation_state: Status to display. It is concatenated directly
              with string literals, so it must support ``str + value``.
      """
      if confirmation_state == ActionConfirmationStatus.CONFIRMED:
          icon, color = '✅ ', 'green'
      elif confirmation_state == ActionConfirmationStatus.REJECTED:
          icon, color = '❌ ', 'red'
      else:
          icon, color = '⏳ ', 'yellow'
      print(colored(icon + confirmation_state + '\n', color))
  def display_command_output(output: str):
      """Print command output line by line in blue, hiding prompt/banner lines.

      Lines that start with '[Python Interpreter' or 'openhands@' are dropped.
      After all lines, a final ``print('\\n')`` emits extra blank space.

      Args:
          output: Raw command output; it is split on '\\n'.
      """
      # TODO: clean this up once we clean up terminal output
      hidden_prefixes = ('[Python Interpreter', 'openhands@')
      for row in output.split('\n'):
          if row.startswith(hidden_prefixes):
              continue
          print(colored(row, 'blue'))
      print('\n')
  def display_file_edit(event: FileEditAction | FileEditObservation):
      """Print the string form of a file-edit action or observation in green.

      Args:
          event: The file-edit event; ``str(event)`` is what gets displayed.
      """
      rendered = str(event)
      print(colored(rendered, 'green'))
  def display_event(event: Event, config: AppConfig):
      """Render a single event on the terminal.

      Each check below is independent, so one event may produce several
      outputs (for example a thought followed by a command). The order of the
      checks is the order in which the pieces are printed.

      Args:
          event: The event to display.
          config: App config; ``config.security.confirmation_mode`` decides
              whether an event's ``confirmation_state`` is shown.
      """
      # Any action carrying a 'thought' attribute shows it first.
      if isinstance(event, Action) and hasattr(event, 'thought'):
          display_message(event.thought)

      # Only messages authored by the agent are echoed.
      if isinstance(event, MessageAction) and event.source == EventSource.AGENT:
          display_message(event.content)

      if isinstance(event, CmdRunAction):
          display_command(event.command)

      if isinstance(event, CmdOutputObservation):
          display_command_output(event.content)

      if isinstance(event, (FileEditAction, FileEditObservation)):
          display_file_edit(event)

      # The attribute check comes first so config is only consulted for
      # events that actually carry a confirmation state.
      if hasattr(event, 'confirmation_state') and config.security.confirmation_mode:
          display_confirmation(event.confirmation_state)
  async def main():
      """Runs the agent in CLI mode.

      Parses command-line arguments, loads the app config, builds the agent,
      event stream, runtime and controller, subscribes a handler that prints
      events and prompts the user when needed, then runs the agent until it
      reaches the STOPPED or ERROR state.
      """
      parser = get_parser()
      # Add the version argument
      parser.add_argument(
          '-v',
          '--version',
          action='version',
          version=f'{__version__}',
          help='Show the version number and exit',
          default=None,
      )
      args = parser.parse_args()

      # NOTE(review): argparse's 'version' action prints the version and exits
      # inside parse_args(), so this branch looks unreachable — confirm before
      # removing it.
      if args.version:
          print(f'OpenHands version: {__version__}')
          return

      logger.setLevel(logging.WARNING)
      config = load_app_config(config_file=args.config_file)
      sid = 'cli'

      agent_cls: Type[Agent] = Agent.get_cls(config.default_agent)
      agent_config = config.get_agent_config(config.default_agent)
      llm_config = config.get_llm_config_from_agent(config.default_agent)
      agent = agent_cls(
          llm=LLM(config=llm_config),
          config=agent_config,
      )

      file_store = get_file_store(config.file_store, config.file_store_path)
      event_stream = EventStream(sid, file_store)

      runtime_cls = get_runtime_cls(config.runtime)
      runtime: Runtime = runtime_cls(
          config=config,
          event_stream=event_stream,
          sid=sid,
          plugins=agent_cls.sandbox_plugins,
          headless_mode=True,
      )

      if config.security.security_analyzer:
          # The analyzer instance is not kept here; presumably it attaches
          # itself to the event stream on construction — TODO confirm.
          options.SecurityAnalyzers.get(
              config.security.security_analyzer, SecurityAnalyzer
          )(event_stream)

      controller = AgentController(
          agent=agent,
          max_iterations=config.max_iterations,
          max_budget_per_task=config.max_budget_per_task,
          agent_to_llm_config=config.get_agent_to_llm_config_map(),
          event_stream=event_stream,
          confirmation_mode=config.security.confirmation_mode,
      )

      async def prompt_for_next_task():
          """Ask the user for the next task and post it to the event stream.

          Blank input is re-prompted. The exact input 'exit' posts a
          STOPPED state change instead of a user message.
          """
          loop = asyncio.get_running_loop()
          next_message = ''
          # Loop instead of recursing: the recursive form fell through after
          # the inner call returned and also posted the blank message.
          while not next_message.strip():
              # Run input() in a thread pool to avoid blocking the event loop
              next_message = await loop.run_in_executor(
                  None, lambda: input('How can I help? >> ')
              )
          if next_message == 'exit':
              event_stream.add_event(
                  ChangeAgentStateAction(AgentState.STOPPED), EventSource.ENVIRONMENT
              )
              return
          action = MessageAction(content=next_message)
          event_stream.add_event(action, EventSource.USER)

      async def prompt_for_user_confirmation():
          """Ask the user to approve an action; returns True only for 'y'/'Y'."""
          loop = asyncio.get_running_loop()
          user_confirmation = await loop.run_in_executor(
              None, lambda: input('Confirm action (possible security risk)? (y/n) >> ')
          )
          return user_confirmation.lower() == 'y'

      async def on_event(event: Event):
          """Display each event and prompt the user when the agent is waiting."""
          display_event(event, config)
          if isinstance(event, AgentStateChangedObservation):
              if event.agent_state in [
                  AgentState.AWAITING_USER_INPUT,
                  AgentState.FINISHED,
              ]:
                  await prompt_for_next_task()
          if (
              isinstance(event, NullObservation)
              and controller.state.agent_state == AgentState.AWAITING_USER_CONFIRMATION
          ):
              user_confirmed = await prompt_for_user_confirmation()
              if user_confirmed:
                  event_stream.add_event(
                      ChangeAgentStateAction(AgentState.USER_CONFIRMED), EventSource.USER
                  )
              else:
                  event_stream.add_event(
                      ChangeAgentStateAction(AgentState.USER_REJECTED), EventSource.USER
                  )

      event_stream.subscribe(EventStreamSubscriber.MAIN, on_event, str(uuid4()))

      await runtime.connect()

      # Keep a strong reference to the task: the event loop only holds weak
      # references, so an unreferenced task may be garbage-collected mid-run.
      initial_prompt_task = asyncio.create_task(prompt_for_next_task())  # noqa: F841

      await run_agent_until_done(
          controller, runtime, [AgentState.STOPPED, AgentState.ERROR]
      )
  if __name__ == '__main__':
      # Script entry point: run main() on a dedicated event loop, then cancel
      # and drain any tasks still pending before closing the loop.
      loop = asyncio.new_event_loop()
      asyncio.set_event_loop(loop)
      try:
          loop.run_until_complete(main())
      except KeyboardInterrupt:
          # Ctrl-C is a normal way to leave the CLI; fall through to cleanup.
          print('Received keyboard interrupt, shutting down...')
      except ConnectionRefusedError as err:
          print(f'Connection refused: {err}')
          sys.exit(1)
      except Exception as err:
          print(f'An error occurred: {err}')
          sys.exit(1)
      finally:
          # Runs on every path above, including the sys.exit(1) ones.
          try:
              # Cancel all running tasks
              leftover_tasks = asyncio.all_tasks(loop)
              for leftover in leftover_tasks:
                  leftover.cancel()
              # Wait for the cancelled tasks to finish (no timeout is applied);
              # return_exceptions=True keeps their CancelledError from propagating.
              drained = asyncio.gather(*leftover_tasks, return_exceptions=True)
              loop.run_until_complete(drained)
              loop.close()
          except Exception as err:
              print(f'Error during cleanup: {err}')
              sys.exit(1)