# cli.py (4.2 KB)
  1. import asyncio
  2. import logging
  3. from typing import Type
  4. from termcolor import colored
  5. import agenthub # noqa F401 (we import this to get the agents registered)
  6. from openhands.controller import AgentController
  7. from openhands.controller.agent import Agent
  8. from openhands.core.config import (
  9. load_app_config,
  10. )
  11. from openhands.core.logger import openhands_logger as logger
  12. from openhands.core.schema import AgentState
  13. from openhands.events import EventSource, EventStream, EventStreamSubscriber
  14. from openhands.events.action import (
  15. Action,
  16. ChangeAgentStateAction,
  17. CmdRunAction,
  18. MessageAction,
  19. )
  20. from openhands.events.event import Event
  21. from openhands.events.observation import (
  22. AgentStateChangedObservation,
  23. CmdOutputObservation,
  24. )
  25. from openhands.llm.llm import LLM
  26. from openhands.runtime import get_runtime_cls
  27. from openhands.runtime.runtime import Runtime
  28. from openhands.storage import get_file_store
  29. def display_message(message: str):
  30. print(colored('🤖 ' + message + '\n', 'yellow'))
  31. def display_command(command: str):
  32. print('❯ ' + colored(command + '\n', 'green'))
  33. def display_command_output(output: str):
  34. lines = output.split('\n')
  35. for line in lines:
  36. if line.startswith('[Python Interpreter') or line.startswith('openhands@'):
  37. # TODO: clean this up once we clean up terminal output
  38. continue
  39. print(colored(line, 'blue'))
  40. print('\n')
  41. def display_event(event: Event):
  42. if isinstance(event, Action):
  43. if hasattr(event, 'thought'):
  44. display_message(event.thought)
  45. if isinstance(event, MessageAction):
  46. if event.source != EventSource.USER:
  47. display_message(event.content)
  48. if isinstance(event, CmdRunAction):
  49. display_command(event.command)
  50. if isinstance(event, CmdOutputObservation):
  51. display_command_output(event.content)
  52. async def main():
  53. """Runs the agent in CLI mode"""
  54. logger.setLevel(logging.WARNING)
  55. config = load_app_config()
  56. sid = 'cli'
  57. agent_cls: Type[Agent] = Agent.get_cls(config.default_agent)
  58. agent_config = config.get_agent_config(config.default_agent)
  59. llm_config = config.get_llm_config_from_agent(config.default_agent)
  60. agent = agent_cls(
  61. llm=LLM(config=llm_config),
  62. config=agent_config,
  63. )
  64. file_store = get_file_store(config.file_store, config.file_store_path)
  65. event_stream = EventStream(sid, file_store)
  66. runtime_cls = get_runtime_cls(config.runtime)
  67. runtime: Runtime = runtime_cls(
  68. config=config,
  69. event_stream=event_stream,
  70. sid=sid,
  71. plugins=agent_cls.sandbox_plugins,
  72. )
  73. await runtime.ainit()
  74. controller = AgentController(
  75. agent=agent,
  76. max_iterations=config.max_iterations,
  77. max_budget_per_task=config.max_budget_per_task,
  78. agent_to_llm_config=config.get_agent_to_llm_config_map(),
  79. event_stream=event_stream,
  80. )
  81. async def prompt_for_next_task():
  82. next_message = input('How can I help? >> ')
  83. if next_message == 'exit':
  84. event_stream.add_event(
  85. ChangeAgentStateAction(AgentState.STOPPED), EventSource.USER
  86. )
  87. return
  88. action = MessageAction(content=next_message)
  89. event_stream.add_event(action, EventSource.USER)
  90. async def on_event(event: Event):
  91. display_event(event)
  92. if isinstance(event, AgentStateChangedObservation):
  93. if event.agent_state == AgentState.ERROR:
  94. print('An error occurred. Please try again.')
  95. if event.agent_state in [
  96. AgentState.AWAITING_USER_INPUT,
  97. AgentState.FINISHED,
  98. AgentState.ERROR,
  99. ]:
  100. await prompt_for_next_task()
  101. event_stream.subscribe(EventStreamSubscriber.MAIN, on_event)
  102. await prompt_for_next_task()
  103. while controller.state.agent_state not in [
  104. AgentState.STOPPED,
  105. ]:
  106. await asyncio.sleep(1) # Give back control for a tick, so the agent can run
  107. print('Exiting...')
  108. await controller.close()
  109. if __name__ == '__main__':
  110. loop = asyncio.get_event_loop()
  111. try:
  112. loop.run_until_complete(main())
  113. finally:
  114. pass