|
|
@@ -1,4 +1,5 @@
|
|
|
import asyncio
|
|
|
+import hashlib
|
|
|
import sys
|
|
|
import uuid
|
|
|
from typing import Callable, Type
|
|
|
@@ -47,9 +48,13 @@ async def create_runtime(
|
|
|
sid: The session id.
|
|
|
runtime_tools_config: (will be deprecated) The runtime tools config.
|
|
|
"""
|
|
|
+ # if sid is provided on the command line, use it as the name of the event stream
|
|
|
+ # otherwise generate it on the basis of the configured jwt_secret
|
|
|
+ # we can do this better, this is just so that the sid is retrieved when we want to restore the session
|
|
|
+ session_id = sid or generate_sid(config)
|
|
|
+
|
|
|
# set up the event stream
|
|
|
file_store = get_file_store(config.file_store, config.file_store_path)
|
|
|
- session_id = 'main' + ('_' + sid if sid else str(uuid.uuid4()))
|
|
|
event_stream = EventStream(session_id, file_store)
|
|
|
|
|
|
# agent class
|
|
|
@@ -72,6 +77,7 @@ async def create_runtime(
|
|
|
async def run_controller(
|
|
|
config: AppConfig,
|
|
|
task_str: str,
|
|
|
+ sid: str | None = None,
|
|
|
runtime: Runtime | None = None,
|
|
|
agent: Agent | None = None,
|
|
|
exit_on_message: bool = False,
|
|
|
@@ -100,15 +106,18 @@ async def run_controller(
|
|
|
config=agent_config,
|
|
|
)
|
|
|
|
|
|
+ # make sure the session id is set
|
|
|
+ sid = sid or generate_sid(config)
|
|
|
+
|
|
|
if runtime is None:
|
|
|
- runtime = await create_runtime(config)
|
|
|
+ runtime = await create_runtime(config, sid=sid)
|
|
|
|
|
|
event_stream = runtime.event_stream
|
|
|
# restore cli session if enabled
|
|
|
initial_state = None
|
|
|
if config.enable_cli_session:
|
|
|
try:
|
|
|
- logger.info('Restoring agent state from cli session')
|
|
|
+ logger.info(f'Restoring agent state from cli session {event_stream.sid}')
|
|
|
initial_state = State.restore_from_session(
|
|
|
event_stream.sid, event_stream.file_store
|
|
|
)
|
|
|
@@ -179,6 +188,15 @@ async def run_controller(
|
|
|
return state
|
|
|
|
|
|
|
|
|
+def generate_sid(config: AppConfig, session_name: str | None = None) -> str:
|
|
|
+ """Generate a session id based on the session name and the jwt secret."""
|
|
|
+ session_name = session_name or str(uuid.uuid4())
|
|
|
+ jwt_secret = config.jwt_secret
|
|
|
+
|
|
|
+ hash_str = hashlib.sha256(f'{session_name}{jwt_secret}'.encode('utf-8')).hexdigest()
|
|
|
+ return f'{session_name}_{hash_str[:16]}'
|
|
|
+
|
|
|
+
|
|
|
if __name__ == '__main__':
|
|
|
args = parse_arguments()
|
|
|
|
|
|
@@ -207,6 +225,10 @@ if __name__ == '__main__':
|
|
|
# Set default agent
|
|
|
config.default_agent = args.agent_cls
|
|
|
|
|
|
+ # Set session name
|
|
|
+ session_name = args.name
|
|
|
+ sid = generate_sid(config, session_name)
|
|
|
+
|
|
|
# if max budget per task is not sent on the command line, use the config value
|
|
|
if args.max_budget_per_task is not None:
|
|
|
config.max_budget_per_task = args.max_budget_per_task
|
|
|
@@ -217,5 +239,6 @@ if __name__ == '__main__':
|
|
|
run_controller(
|
|
|
config=config,
|
|
|
task_str=task_str,
|
|
|
+ sid=sid,
|
|
|
)
|
|
|
)
|