# main.py (9.0 KB)

  1. import asyncio
  2. import hashlib
  3. import sys
  4. import uuid
  5. from typing import Callable, Protocol, Type
  6. import openhands.agenthub # noqa F401 (we import this to get the agents registered)
  7. from openhands.controller import AgentController
  8. from openhands.controller.agent import Agent
  9. from openhands.controller.state.state import State
  10. from openhands.core.config import (
  11. AppConfig,
  12. get_llm_config_arg,
  13. load_app_config,
  14. parse_arguments,
  15. )
  16. from openhands.core.logger import openhands_logger as logger
  17. from openhands.core.schema import AgentState
  18. from openhands.events import EventSource, EventStream, EventStreamSubscriber
  19. from openhands.events.action import MessageAction
  20. from openhands.events.action.action import Action
  21. from openhands.events.event import Event
  22. from openhands.events.observation import AgentStateChangedObservation
  23. from openhands.llm.llm import LLM
  24. from openhands.runtime import get_runtime_cls
  25. from openhands.runtime.runtime import Runtime
  26. from openhands.storage import get_file_store
  27. class FakeUserResponseFunc(Protocol):
  28. def __call__(
  29. self,
  30. state: State,
  31. encapsulate_solution: bool = ...,
  32. try_parse: Callable[[Action], str] = ...,
  33. ) -> str: ...
  34. def read_task_from_file(file_path: str) -> str:
  35. """Read task from the specified file."""
  36. with open(file_path, 'r', encoding='utf-8') as file:
  37. return file.read()
  38. def read_task_from_stdin() -> str:
  39. """Read task from stdin."""
  40. return sys.stdin.read()
  41. def create_runtime(
  42. config: AppConfig,
  43. sid: str | None = None,
  44. ) -> Runtime:
  45. """Create a runtime for the agent to run on.
  46. config: The app config.
  47. sid: The session id.
  48. """
  49. # if sid is provided on the command line, use it as the name of the event stream
  50. # otherwise generate it on the basis of the configured jwt_secret
  51. # we can do this better, this is just so that the sid is retrieved when we want to restore the session
  52. session_id = sid or generate_sid(config)
  53. # set up the event stream
  54. file_store = get_file_store(config.file_store, config.file_store_path)
  55. event_stream = EventStream(session_id, file_store)
  56. # agent class
  57. agent_cls = openhands.agenthub.Agent.get_cls(config.default_agent)
  58. # runtime and tools
  59. runtime_cls = get_runtime_cls(config.runtime)
  60. logger.info(f'Initializing runtime: {runtime_cls.__name__}')
  61. runtime: Runtime = runtime_cls(
  62. config=config,
  63. event_stream=event_stream,
  64. sid=session_id,
  65. plugins=agent_cls.sandbox_plugins,
  66. )
  67. return runtime
  68. async def run_controller(
  69. config: AppConfig,
  70. initial_user_action: Action,
  71. sid: str | None = None,
  72. runtime: Runtime | None = None,
  73. agent: Agent | None = None,
  74. exit_on_message: bool = False,
  75. fake_user_response_fn: FakeUserResponseFunc | None = None,
  76. headless_mode: bool = True,
  77. ) -> State | None:
  78. """Main coroutine to run the agent controller with task input flexibility.
  79. It's only used when you launch openhands backend directly via cmdline.
  80. Args:
  81. config: The app config.
  82. initial_user_action: An Action object containing initial user input
  83. runtime: (optional) A runtime for the agent to run on.
  84. agent: (optional) A agent to run.
  85. exit_on_message: quit if agent asks for a message from user (optional)
  86. fake_user_response_fn: An optional function that receives the current state
  87. (could be None) and returns a fake user response.
  88. headless_mode: Whether the agent is run in headless mode.
  89. """
  90. # Create the agent
  91. if agent is None:
  92. agent_cls: Type[Agent] = Agent.get_cls(config.default_agent)
  93. agent_config = config.get_agent_config(config.default_agent)
  94. llm_config = config.get_llm_config_from_agent(config.default_agent)
  95. agent = agent_cls(
  96. llm=LLM(config=llm_config),
  97. config=agent_config,
  98. )
  99. # make sure the session id is set
  100. sid = sid or generate_sid(config)
  101. if runtime is None:
  102. runtime = create_runtime(config, sid=sid)
  103. event_stream = runtime.event_stream
  104. # restore cli session if enabled
  105. initial_state = None
  106. if config.enable_cli_session:
  107. try:
  108. logger.info(f'Restoring agent state from cli session {event_stream.sid}')
  109. initial_state = State.restore_from_session(
  110. event_stream.sid, event_stream.file_store
  111. )
  112. except Exception as e:
  113. logger.info(f'Error restoring state: {e}')
  114. # init controller with this initial state
  115. controller = AgentController(
  116. agent=agent,
  117. max_iterations=config.max_iterations,
  118. max_budget_per_task=config.max_budget_per_task,
  119. agent_to_llm_config=config.get_agent_to_llm_config_map(),
  120. event_stream=event_stream,
  121. initial_state=initial_state,
  122. headless_mode=headless_mode,
  123. )
  124. if controller is not None:
  125. controller.agent_task = asyncio.create_task(controller.start_step_loop())
  126. assert isinstance(
  127. initial_user_action, Action
  128. ), f'initial user actions must be an Action, got {type(initial_user_action)}'
  129. # Logging
  130. logger.info(
  131. f'Agent Controller Initialized: Running agent {agent.name}, model '
  132. f'{agent.llm.config.model}, with actions: {initial_user_action}'
  133. )
  134. # start event is a MessageAction with the task, either resumed or new
  135. if config.enable_cli_session and initial_state is not None:
  136. # we're resuming the previous session
  137. event_stream.add_event(
  138. MessageAction(
  139. content=(
  140. "Let's get back on track. If you experienced errors before, do "
  141. 'NOT resume your task. Ask me about it.'
  142. ),
  143. ),
  144. EventSource.USER,
  145. )
  146. elif initial_state is None:
  147. # init with the provided actions
  148. event_stream.add_event(initial_user_action, EventSource.USER)
  149. async def on_event(event: Event):
  150. if isinstance(event, AgentStateChangedObservation):
  151. if event.agent_state == AgentState.AWAITING_USER_INPUT:
  152. if exit_on_message:
  153. message = '/exit'
  154. elif fake_user_response_fn is None:
  155. message = input('Request user input >> ')
  156. else:
  157. message = fake_user_response_fn(controller.get_state())
  158. action = MessageAction(content=message)
  159. event_stream.add_event(action, EventSource.USER)
  160. event_stream.subscribe(EventStreamSubscriber.MAIN, on_event)
  161. while controller.state.agent_state not in [
  162. AgentState.FINISHED,
  163. AgentState.REJECTED,
  164. AgentState.ERROR,
  165. AgentState.PAUSED,
  166. AgentState.STOPPED,
  167. ]:
  168. await asyncio.sleep(1) # Give back control for a tick, so the agent can run
  169. # save session when we're about to close
  170. if config.enable_cli_session:
  171. end_state = controller.get_state()
  172. end_state.save_to_session(event_stream.sid, event_stream.file_store)
  173. # close when done
  174. await controller.close()
  175. state = controller.get_state()
  176. return state
  177. def generate_sid(config: AppConfig, session_name: str | None = None) -> str:
  178. """Generate a session id based on the session name and the jwt secret."""
  179. session_name = session_name or str(uuid.uuid4())
  180. jwt_secret = config.jwt_secret
  181. hash_str = hashlib.sha256(f'{session_name}{jwt_secret}'.encode('utf-8')).hexdigest()
  182. return f'{session_name}_{hash_str[:16]}'
  183. if __name__ == '__main__':
  184. args = parse_arguments()
  185. # Determine the task
  186. if args.file:
  187. task_str = read_task_from_file(args.file)
  188. elif args.task:
  189. task_str = args.task
  190. elif not sys.stdin.isatty():
  191. task_str = read_task_from_stdin()
  192. else:
  193. raise ValueError('No task provided. Please specify a task through -t, -f.')
  194. initial_user_action: MessageAction = MessageAction(content=task_str)
  195. # Load the app config
  196. # this will load config from config.toml in the current directory
  197. # as well as from the environment variables
  198. config = load_app_config(config_file=args.config_file)
  199. # Override default LLM configs ([llm] section in config.toml)
  200. if args.llm_config:
  201. llm_config = get_llm_config_arg(args.llm_config)
  202. if llm_config is None:
  203. raise ValueError(f'Invalid toml file, cannot read {args.llm_config}')
  204. config.set_llm_config(llm_config)
  205. # Set default agent
  206. config.default_agent = args.agent_cls
  207. # Set session name
  208. session_name = args.name
  209. sid = generate_sid(config, session_name)
  210. # if max budget per task is not sent on the command line, use the config value
  211. if args.max_budget_per_task is not None:
  212. config.max_budget_per_task = args.max_budget_per_task
  213. if args.max_iterations is not None:
  214. config.max_iterations = args.max_iterations
  215. asyncio.run(
  216. run_controller(
  217. config=config,
  218. initial_user_action=initial_user_action,
  219. sid=sid,
  220. )
  221. )