main.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. import asyncio
  2. import hashlib
  3. import json
  4. import os
  5. import sys
  6. import uuid
  7. from typing import Callable, Protocol, Type
  8. import openhands.agenthub # noqa F401 (we import this to get the agents registered)
  9. from openhands.controller import AgentController
  10. from openhands.controller.agent import Agent
  11. from openhands.controller.state.state import State
  12. from openhands.core.config import (
  13. AppConfig,
  14. get_llm_config_arg,
  15. load_app_config,
  16. parse_arguments,
  17. )
  18. from openhands.core.logger import openhands_logger as logger
  19. from openhands.core.loop import run_agent_until_done
  20. from openhands.core.schema import AgentState
  21. from openhands.events import EventSource, EventStream, EventStreamSubscriber
  22. from openhands.events.action import MessageAction
  23. from openhands.events.action.action import Action
  24. from openhands.events.event import Event
  25. from openhands.events.observation import AgentStateChangedObservation
  26. from openhands.events.serialization.event import event_to_trajectory
  27. from openhands.llm.llm import LLM
  28. from openhands.runtime import get_runtime_cls
  29. from openhands.runtime.base import Runtime
  30. from openhands.storage import get_file_store
  31. class FakeUserResponseFunc(Protocol):
  32. def __call__(
  33. self,
  34. state: State,
  35. encapsulate_solution: bool = False,
  36. try_parse: Callable[[Action | None], str] | None = None,
  37. ) -> str: ...
  38. def read_task_from_file(file_path: str) -> str:
  39. """Read task from the specified file."""
  40. with open(file_path, 'r', encoding='utf-8') as file:
  41. return file.read()
  42. def read_task_from_stdin() -> str:
  43. """Read task from stdin."""
  44. return sys.stdin.read()
  45. def create_runtime(
  46. config: AppConfig,
  47. sid: str | None = None,
  48. headless_mode: bool = True,
  49. ) -> Runtime:
  50. """Create a runtime for the agent to run on.
  51. config: The app config.
  52. sid: (optional) The session id. IMPORTANT: please don't set this unless you know what you're doing.
  53. Set it to incompatible value will cause unexpected behavior on RemoteRuntime.
  54. headless_mode: Whether the agent is run in headless mode. `create_runtime` is typically called within evaluation scripts,
  55. where we don't want to have the VSCode UI open, so it defaults to True.
  56. """
  57. # if sid is provided on the command line, use it as the name of the event stream
  58. # otherwise generate it on the basis of the configured jwt_secret
  59. # we can do this better, this is just so that the sid is retrieved when we want to restore the session
  60. session_id = sid or generate_sid(config)
  61. # set up the event stream
  62. file_store = get_file_store(config.file_store, config.file_store_path)
  63. event_stream = EventStream(session_id, file_store)
  64. # agent class
  65. agent_cls = openhands.agenthub.Agent.get_cls(config.default_agent)
  66. # runtime and tools
  67. runtime_cls = get_runtime_cls(config.runtime)
  68. logger.debug(f'Initializing runtime: {runtime_cls.__name__}')
  69. runtime: Runtime = runtime_cls(
  70. config=config,
  71. event_stream=event_stream,
  72. sid=session_id,
  73. plugins=agent_cls.sandbox_plugins,
  74. headless_mode=headless_mode,
  75. )
  76. return runtime
  77. async def run_controller(
  78. config: AppConfig,
  79. initial_user_action: Action,
  80. sid: str | None = None,
  81. runtime: Runtime | None = None,
  82. agent: Agent | None = None,
  83. exit_on_message: bool = False,
  84. fake_user_response_fn: FakeUserResponseFunc | None = None,
  85. headless_mode: bool = True,
  86. ) -> State | None:
  87. """Main coroutine to run the agent controller with task input flexibility.
  88. It's only used when you launch openhands backend directly via cmdline.
  89. Args:
  90. config: The app config.
  91. initial_user_action: An Action object containing initial user input
  92. sid: (optional) The session id. IMPORTANT: please don't set this unless you know what you're doing.
  93. Set it to incompatible value will cause unexpected behavior on RemoteRuntime.
  94. runtime: (optional) A runtime for the agent to run on.
  95. agent: (optional) A agent to run.
  96. exit_on_message: quit if agent asks for a message from user (optional)
  97. fake_user_response_fn: An optional function that receives the current state
  98. (could be None) and returns a fake user response.
  99. headless_mode: Whether the agent is run in headless mode.
  100. """
  101. # Create the agent
  102. if agent is None:
  103. agent_cls: Type[Agent] = Agent.get_cls(config.default_agent)
  104. agent_config = config.get_agent_config(config.default_agent)
  105. llm_config = config.get_llm_config_from_agent(config.default_agent)
  106. agent = agent_cls(
  107. llm=LLM(config=llm_config),
  108. config=agent_config,
  109. )
  110. # make sure the session id is set
  111. sid = sid or generate_sid(config)
  112. if runtime is None:
  113. runtime = create_runtime(config, sid=sid, headless_mode=headless_mode)
  114. await runtime.connect()
  115. event_stream = runtime.event_stream
  116. # restore cli session if available
  117. initial_state = None
  118. try:
  119. logger.debug(
  120. f'Trying to restore agent state from cli session {event_stream.sid} if available'
  121. )
  122. initial_state = State.restore_from_session(
  123. event_stream.sid, event_stream.file_store
  124. )
  125. except Exception as e:
  126. logger.debug(f'Cannot restore agent state: {e}')
  127. # init controller with this initial state
  128. controller = AgentController(
  129. agent=agent,
  130. max_iterations=config.max_iterations,
  131. max_budget_per_task=config.max_budget_per_task,
  132. agent_to_llm_config=config.get_agent_to_llm_config_map(),
  133. event_stream=event_stream,
  134. initial_state=initial_state,
  135. headless_mode=headless_mode,
  136. )
  137. assert isinstance(
  138. initial_user_action, Action
  139. ), f'initial user actions must be an Action, got {type(initial_user_action)}'
  140. # Logging
  141. logger.debug(
  142. f'Agent Controller Initialized: Running agent {agent.name}, model '
  143. f'{agent.llm.config.model}, with actions: {initial_user_action}'
  144. )
  145. # start event is a MessageAction with the task, either resumed or new
  146. if initial_state is not None:
  147. # we're resuming the previous session
  148. event_stream.add_event(
  149. MessageAction(
  150. content=(
  151. "Let's get back on track. If you experienced errors before, do "
  152. 'NOT resume your task. Ask me about it.'
  153. ),
  154. ),
  155. EventSource.USER,
  156. )
  157. else:
  158. # init with the provided actions
  159. event_stream.add_event(initial_user_action, EventSource.USER)
  160. async def on_event(event: Event):
  161. if isinstance(event, AgentStateChangedObservation):
  162. if event.agent_state == AgentState.AWAITING_USER_INPUT:
  163. if exit_on_message:
  164. message = '/exit'
  165. elif fake_user_response_fn is None:
  166. # read until EOF (Ctrl+D on Unix, Ctrl+Z on Windows)
  167. print('Request user input (press Ctrl+D/Z when done) >> ')
  168. message = sys.stdin.read().rstrip()
  169. else:
  170. message = fake_user_response_fn(controller.get_state())
  171. action = MessageAction(content=message)
  172. event_stream.add_event(action, EventSource.USER)
  173. event_stream.subscribe(EventStreamSubscriber.MAIN, on_event, sid)
  174. end_states = [
  175. AgentState.FINISHED,
  176. AgentState.REJECTED,
  177. AgentState.ERROR,
  178. AgentState.PAUSED,
  179. AgentState.STOPPED,
  180. ]
  181. try:
  182. await run_agent_until_done(controller, runtime, end_states)
  183. except Exception as e:
  184. logger.error(f'Exception in main loop: {e}')
  185. # save session when we're about to close
  186. if config.file_store is not None and config.file_store != 'memory':
  187. end_state = controller.get_state()
  188. # NOTE: the saved state does not include delegates events
  189. end_state.save_to_session(event_stream.sid, event_stream.file_store)
  190. state = controller.get_state()
  191. # save trajectories if applicable
  192. if config.trajectories_path is not None:
  193. # if trajectories_path is a folder, use session id as file name
  194. if os.path.isdir(config.trajectories_path):
  195. file_path = os.path.join(config.trajectories_path, sid + '.json')
  196. else:
  197. file_path = config.trajectories_path
  198. os.makedirs(os.path.dirname(file_path), exist_ok=True)
  199. histories = [event_to_trajectory(event) for event in state.history]
  200. with open(file_path, 'w') as f:
  201. json.dump(histories, f)
  202. return state
  203. def generate_sid(config: AppConfig, session_name: str | None = None) -> str:
  204. """Generate a session id based on the session name and the jwt secret."""
  205. session_name = session_name or str(uuid.uuid4())
  206. jwt_secret = config.jwt_secret
  207. hash_str = hashlib.sha256(f'{session_name}{jwt_secret}'.encode('utf-8')).hexdigest()
  208. return f'{session_name}-{hash_str[:16]}'
  209. def auto_continue_response(
  210. state: State,
  211. encapsulate_solution: bool = False,
  212. try_parse: Callable[[Action | None], str] | None = None,
  213. ) -> str:
  214. """Default function to generate user responses.
  215. Returns 'continue' to tell the agent to proceed without asking for more input.
  216. """
  217. return 'continue'
  218. if __name__ == '__main__':
  219. args = parse_arguments()
  220. # Determine the task
  221. if args.file:
  222. task_str = read_task_from_file(args.file)
  223. elif args.task:
  224. task_str = args.task
  225. elif not sys.stdin.isatty():
  226. task_str = read_task_from_stdin()
  227. else:
  228. raise ValueError('No task provided. Please specify a task through -t, -f.')
  229. initial_user_action: MessageAction = MessageAction(content=task_str)
  230. # Load the app config
  231. # this will load config from config.toml in the current directory
  232. # as well as from the environment variables
  233. config = load_app_config(config_file=args.config_file)
  234. # Override default LLM configs ([llm] section in config.toml)
  235. if args.llm_config:
  236. llm_config = get_llm_config_arg(args.llm_config)
  237. if llm_config is None:
  238. raise ValueError(f'Invalid toml file, cannot read {args.llm_config}')
  239. config.set_llm_config(llm_config)
  240. # Set default agent
  241. config.default_agent = args.agent_cls
  242. # Set session name
  243. session_name = args.name
  244. sid = generate_sid(config, session_name)
  245. # if max budget per task is not sent on the command line, use the config value
  246. if args.max_budget_per_task is not None:
  247. config.max_budget_per_task = args.max_budget_per_task
  248. if args.max_iterations is not None:
  249. config.max_iterations = args.max_iterations
  250. asyncio.run(
  251. run_controller(
  252. config=config,
  253. initial_user_action=initial_user_action,
  254. sid=sid,
  255. fake_user_response_fn=None
  256. if args.no_auto_continue
  257. else auto_continue_response,
  258. )
  259. )