listen.py 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978
  1. import asyncio
  2. import os
  3. import re
  4. import tempfile
  5. import time
  6. import uuid
  7. import warnings
  8. import jwt
  9. import requests
  10. from pathspec import PathSpec
  11. from pathspec.patterns import GitWildMatchPattern
  12. from openhands.security.options import SecurityAnalyzers
  13. from openhands.server.data_models.feedback import FeedbackDataModel, store_feedback
  14. from openhands.server.github import (
  15. GITHUB_CLIENT_ID,
  16. GITHUB_CLIENT_SECRET,
  17. UserVerifier,
  18. authenticate_github_user,
  19. )
  20. from openhands.storage import get_file_store
  21. from openhands.utils.async_utils import call_sync_from_async
  22. with warnings.catch_warnings():
  23. warnings.simplefilter('ignore')
  24. import litellm
  25. from dotenv import load_dotenv
  26. from fastapi import (
  27. BackgroundTasks,
  28. FastAPI,
  29. HTTPException,
  30. Request,
  31. UploadFile,
  32. WebSocket,
  33. WebSocketDisconnect,
  34. status,
  35. )
  36. from fastapi.responses import FileResponse, JSONResponse
  37. from fastapi.security import HTTPBearer
  38. from fastapi.staticfiles import StaticFiles
  39. from pydantic import BaseModel
  40. import openhands.agenthub # noqa F401 (we import this to get the agents registered)
  41. from openhands.controller.agent import Agent
  42. from openhands.core.config import LLMConfig, load_app_config
  43. from openhands.core.logger import openhands_logger as logger
  44. from openhands.events.action import (
  45. ChangeAgentStateAction,
  46. FileReadAction,
  47. FileWriteAction,
  48. NullAction,
  49. )
  50. from openhands.events.observation import (
  51. AgentStateChangedObservation,
  52. ErrorObservation,
  53. FileReadObservation,
  54. FileWriteObservation,
  55. NullObservation,
  56. )
  57. from openhands.events.serialization import event_to_dict
  58. from openhands.events.stream import AsyncEventStreamWrapper
  59. from openhands.llm import bedrock
  60. from openhands.runtime.base import Runtime
  61. from openhands.server.auth.auth import get_sid_from_token, sign_token
  62. from openhands.server.middleware import (
  63. InMemoryRateLimiter,
  64. LocalhostCORSMiddleware,
  65. NoCacheMiddleware,
  66. RateLimitMiddleware,
  67. )
  68. from openhands.server.session import SessionManager
# Load variables from a .env file into the environment before reading config.
load_dotenv()
config = load_app_config()
file_store = get_file_store(config.file_store, config.file_store_path)
session_manager = SessionManager(config, file_store)
app = FastAPI()
# Starlette wraps each added middleware around the ones added before it, so the
# last middleware added here is the outermost one. Order is significant.
app.add_middleware(
    LocalhostCORSMiddleware,
    allow_credentials=True,
    allow_methods=['*'],
    allow_headers=['*'],
)
app.add_middleware(NoCacheMiddleware)
# NOTE(review): presumably at most 10 requests per 1 second per client;
# confirm the exact keying in openhands.server.middleware.InMemoryRateLimiter.
app.add_middleware(
    RateLimitMiddleware, rate_limiter=InMemoryRateLimiter(requests=10, seconds=1)
)
  84. @app.get('/health')
  85. async def health():
  86. return 'OK'
# NOTE(review): not referenced anywhere in this file. Bearer tokens are parsed
# manually in attach_session. Confirm whether this can be removed.
security_scheme = HTTPBearer()
  88. def load_file_upload_config() -> tuple[int, bool, list[str]]:
  89. """Load file upload configuration from the config object.
  90. This function retrieves the file upload settings from the global config object.
  91. It handles the following settings:
  92. - Maximum file size for uploads
  93. - Whether to restrict file types
  94. - List of allowed file extensions
  95. It also performs sanity checks on the values to ensure they are valid and safe.
  96. Returns:
  97. tuple: A tuple containing:
  98. - max_file_size_mb (int): Maximum file size in MB. 0 means no limit.
  99. - restrict_file_types (bool): Whether file type restrictions are enabled.
  100. - allowed_extensions (set): Set of allowed file extensions.
  101. """
  102. # Retrieve values from config
  103. max_file_size_mb = config.file_uploads_max_file_size_mb
  104. restrict_file_types = config.file_uploads_restrict_file_types
  105. allowed_extensions = config.file_uploads_allowed_extensions
  106. # Sanity check for max_file_size_mb
  107. if not isinstance(max_file_size_mb, int) or max_file_size_mb < 0:
  108. logger.warning(
  109. f'Invalid max_file_size_mb: {max_file_size_mb}. Setting to 0 (no limit).'
  110. )
  111. max_file_size_mb = 0
  112. # Sanity check for allowed_extensions
  113. if not isinstance(allowed_extensions, (list, set)) or not allowed_extensions:
  114. logger.warning(
  115. f'Invalid allowed_extensions: {allowed_extensions}. Setting to [".*"].'
  116. )
  117. allowed_extensions = ['.*']
  118. else:
  119. # Ensure all extensions start with a dot and are lowercase
  120. allowed_extensions = [
  121. ext.lower() if ext.startswith('.') else f'.{ext.lower()}'
  122. for ext in allowed_extensions
  123. ]
  124. # If restrictions are disabled, allow all
  125. if not restrict_file_types:
  126. allowed_extensions = ['.*']
  127. logger.debug(
  128. f'File upload config: max_size={max_file_size_mb}MB, '
  129. f'restrict_types={restrict_file_types}, '
  130. f'allowed_extensions={allowed_extensions}'
  131. )
  132. return max_file_size_mb, restrict_file_types, allowed_extensions
# Load configuration once at import time. Later config changes are not picked
# up by the upload endpoint without a restart.
MAX_FILE_SIZE_MB, RESTRICT_FILE_TYPES, ALLOWED_EXTENSIONS = load_file_upload_config()
  135. def is_extension_allowed(filename):
  136. """Check if the file extension is allowed based on the current configuration.
  137. This function supports wildcards and files without extensions.
  138. The check is case-insensitive for extensions.
  139. Args:
  140. filename (str): The name of the file to check.
  141. Returns:
  142. bool: True if the file extension is allowed, False otherwise.
  143. """
  144. if not RESTRICT_FILE_TYPES:
  145. return True
  146. file_ext = os.path.splitext(filename)[1].lower() # Convert to lowercase
  147. return (
  148. '.*' in ALLOWED_EXTENSIONS
  149. or file_ext in (ext.lower() for ext in ALLOWED_EXTENSIONS)
  150. or (file_ext == '' and '.' in ALLOWED_EXTENSIONS)
  151. )
  152. @app.middleware('http')
  153. async def attach_session(request: Request, call_next):
  154. """Middleware to attach session information to the request.
  155. This middleware checks for the Authorization header, validates the token,
  156. and attaches the corresponding session to the request state.
  157. Args:
  158. request (Request): The incoming request object.
  159. call_next (Callable): The next middleware or route handler in the chain.
  160. Returns:
  161. Response: The response from the next middleware or route handler.
  162. """
  163. non_authed_paths = [
  164. '/api/options/',
  165. '/api/github/callback',
  166. '/api/authenticate',
  167. ]
  168. if any(
  169. request.url.path.startswith(path) for path in non_authed_paths
  170. ) or not request.url.path.startswith('/api/'):
  171. response = await call_next(request)
  172. return response
  173. # Bypass authentication for OPTIONS requests (preflight)
  174. if request.method == 'OPTIONS':
  175. response = await call_next(request)
  176. return response
  177. user_verifier = UserVerifier()
  178. if user_verifier.is_active():
  179. signed_token = request.cookies.get('github_auth')
  180. if not signed_token:
  181. return JSONResponse(
  182. status_code=status.HTTP_401_UNAUTHORIZED,
  183. content={'error': 'Not authenticated'},
  184. )
  185. try:
  186. jwt.decode(signed_token, config.jwt_secret, algorithms=['HS256'])
  187. except Exception as e:
  188. logger.warning(f'Invalid token: {e}')
  189. return JSONResponse(
  190. status_code=status.HTTP_401_UNAUTHORIZED,
  191. content={'error': 'Invalid token'},
  192. )
  193. if not request.headers.get('Authorization'):
  194. logger.warning('Missing Authorization header')
  195. return JSONResponse(
  196. status_code=status.HTTP_401_UNAUTHORIZED,
  197. content={'error': 'Missing Authorization header'},
  198. )
  199. auth_token = request.headers.get('Authorization')
  200. if 'Bearer' in auth_token:
  201. auth_token = auth_token.split('Bearer')[1].strip()
  202. request.state.sid = get_sid_from_token(auth_token, config.jwt_secret)
  203. if request.state.sid == '':
  204. logger.warning('Invalid token')
  205. return JSONResponse(
  206. status_code=status.HTTP_401_UNAUTHORIZED,
  207. content={'error': 'Invalid token'},
  208. )
  209. request.state.conversation = await session_manager.attach_to_conversation(
  210. request.state.sid
  211. )
  212. if not request.state.conversation:
  213. logger.error(f'Runtime not found for session: {request.state.sid}')
  214. return JSONResponse(
  215. status_code=status.HTTP_404_NOT_FOUND,
  216. content={'error': 'Session not found'},
  217. )
  218. try:
  219. response = await call_next(request)
  220. finally:
  221. await session_manager.detach_from_conversation(request.state.conversation)
  222. return response
  223. @app.websocket('/ws')
  224. async def websocket_endpoint(websocket: WebSocket):
  225. """WebSocket endpoint for receiving events from the client (i.e., the browser).
  226. Once connected, the client can send various actions:
  227. - Initialize the agent:
  228. session management, and event streaming.
  229. ```json
  230. {"action": "initialize", "args": {"LLM_MODEL": "ollama/llama3", "AGENT": "CodeActAgent", "LANGUAGE": "en", "LLM_API_KEY": "ollama"}}
  231. Args:
  232. ```
  233. websocket (WebSocket): The WebSocket connection object.
  234. - Start a new development task:
  235. ```json
  236. {"action": "start", "args": {"task": "write a bash script that prints hello"}}
  237. ```
  238. - Send a message:
  239. ```json
  240. {"action": "message", "args": {"content": "Hello, how are you?", "image_urls": ["base64_url1", "base64_url2"]}}
  241. ```
  242. - Write contents to a file:
  243. ```json
  244. {"action": "write", "args": {"path": "./greetings.txt", "content": "Hello, OpenHands?"}}
  245. ```
  246. - Read the contents of a file:
  247. ```json
  248. {"action": "read", "args": {"path": "./greetings.txt"}}
  249. ```
  250. - Run a command:
  251. ```json
  252. {"action": "run", "args": {"command": "ls -l", "thought": "", "confirmation_state": "confirmed"}}
  253. ```
  254. - Run an IPython command:
  255. ```json
  256. {"action": "run_ipython", "args": {"command": "print('Hello, IPython!')"}}
  257. ```
  258. - Open a web page:
  259. ```json
  260. {"action": "browse", "args": {"url": "https://arxiv.org/html/2402.01030v2"}}
  261. ```
  262. - Add a task to the root_task:
  263. ```json
  264. {"action": "add_task", "args": {"task": "Implement feature X"}}
  265. ```
  266. - Update a task in the root_task:
  267. ```json
  268. {"action": "modify_task", "args": {"id": "0", "state": "in_progress", "thought": ""}}
  269. ```
  270. - Change the agent's state:
  271. ```json
  272. {"action": "change_agent_state", "args": {"state": "paused"}}
  273. ```
  274. - Finish the task:
  275. ```json
  276. {"action": "finish", "args": {}}
  277. ```
  278. """
  279. # Get protocols from Sec-WebSocket-Protocol header
  280. protocols = websocket.headers.get('sec-websocket-protocol', '').split(', ')
  281. # The first protocol should be our real protocol (e.g. 'openhands')
  282. # The second protocol should contain our auth token
  283. if len(protocols) < 3:
  284. logger.error('Expected 3 websocket protocols, got %d', len(protocols))
  285. await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
  286. return
  287. real_protocol = protocols[0]
  288. jwt_token = protocols[1] if protocols[1] != 'NO_JWT' else ''
  289. github_token = protocols[2] if protocols[2] != 'NO_GITHUB' else ''
  290. if not await authenticate_github_user(github_token):
  291. await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
  292. return
  293. await asyncio.wait_for(websocket.accept(subprotocol=real_protocol), 10)
  294. if jwt_token:
  295. sid = get_sid_from_token(jwt_token, config.jwt_secret)
  296. if sid == '':
  297. await websocket.send_json({'error': 'Invalid token', 'error_code': 401})
  298. await websocket.close()
  299. return
  300. else:
  301. sid = str(uuid.uuid4())
  302. jwt_token = sign_token({'sid': sid}, config.jwt_secret)
  303. logger.info(f'New session: {sid}')
  304. session = session_manager.add_or_restart_session(sid, websocket)
  305. await websocket.send_json({'token': jwt_token, 'status': 'ok'})
  306. latest_event_id = -1
  307. if websocket.query_params.get('latest_event_id'):
  308. try:
  309. latest_event_id = int(websocket.query_params.get('latest_event_id'))
  310. except ValueError:
  311. logger.warning(
  312. f'Invalid latest_event_id: {websocket.query_params.get("latest_event_id")}'
  313. )
  314. pass
  315. async_stream = AsyncEventStreamWrapper(
  316. session.agent_session.event_stream, latest_event_id + 1
  317. )
  318. async for event in async_stream:
  319. if isinstance(
  320. event,
  321. (
  322. NullAction,
  323. NullObservation,
  324. ChangeAgentStateAction,
  325. AgentStateChangedObservation,
  326. ),
  327. ):
  328. continue
  329. try:
  330. await websocket.send_json(event_to_dict(event))
  331. except WebSocketDisconnect:
  332. logger.warning(
  333. 'Websocket disconnected while sending event history, before loop started'
  334. )
  335. session.close()
  336. return
  337. await session.loop_recv()
  338. @app.get('/api/options/models')
  339. async def get_litellm_models() -> list[str]:
  340. """
  341. Get all models supported by LiteLLM.
  342. This function combines models from litellm and Bedrock, removing any
  343. error-prone Bedrock models.
  344. To get the models:
  345. ```sh
  346. curl http://localhost:3000/api/litellm-models
  347. ```
  348. Returns:
  349. list: A sorted list of unique model names.
  350. """
  351. litellm_model_list = litellm.model_list + list(litellm.model_cost.keys())
  352. litellm_model_list_without_bedrock = bedrock.remove_error_modelId(
  353. litellm_model_list
  354. )
  355. # TODO: for bedrock, this is using the default config
  356. llm_config: LLMConfig = config.get_llm_config()
  357. bedrock_model_list = []
  358. if (
  359. llm_config.aws_region_name
  360. and llm_config.aws_access_key_id
  361. and llm_config.aws_secret_access_key
  362. ):
  363. bedrock_model_list = bedrock.list_foundation_models(
  364. llm_config.aws_region_name,
  365. llm_config.aws_access_key_id,
  366. llm_config.aws_secret_access_key,
  367. )
  368. model_list = litellm_model_list_without_bedrock + bedrock_model_list
  369. for llm_config in config.llms.values():
  370. ollama_base_url = llm_config.ollama_base_url
  371. if llm_config.model.startswith('ollama'):
  372. if not ollama_base_url:
  373. ollama_base_url = llm_config.base_url
  374. if ollama_base_url:
  375. ollama_url = ollama_base_url.strip('/') + '/api/tags'
  376. try:
  377. ollama_models_list = requests.get(ollama_url, timeout=3).json()[
  378. 'models'
  379. ]
  380. for model in ollama_models_list:
  381. model_list.append('ollama/' + model['name'])
  382. break
  383. except requests.exceptions.RequestException as e:
  384. logger.error(f'Error getting OLLAMA models: {e}', exc_info=True)
  385. return list(sorted(set(model_list)))
  386. @app.get('/api/options/agents')
  387. async def get_agents():
  388. """Get all agents supported by LiteLLM.
  389. To get the agents:
  390. ```sh
  391. curl http://localhost:3000/api/agents
  392. ```
  393. Returns:
  394. list: A sorted list of agent names.
  395. """
  396. agents = sorted(Agent.list_agents())
  397. return agents
  398. @app.get('/api/options/security-analyzers')
  399. async def get_security_analyzers():
  400. """Get all supported security analyzers.
  401. To get the security analyzers:
  402. ```sh
  403. curl http://localhost:3000/api/security-analyzers
  404. ```
  405. Returns:
  406. list: A sorted list of security analyzer names.
  407. """
  408. return sorted(SecurityAnalyzers.keys())
# Entry names that list_files() hides from the client. They are compared
# against the entries returned by runtime.list_files. The trailing '/'
# presumably matches how the runtime marks directories; confirm in Runtime.list_files.
FILES_TO_IGNORE = [
    '.git/',
    '.DS_Store',
    'node_modules/',
    '__pycache__/',
]
  415. @app.get('/api/list-files')
  416. async def list_files(request: Request, path: str | None = None):
  417. """List files in the specified path.
  418. This function retrieves a list of files from the agent's runtime file store,
  419. excluding certain system and hidden files/directories.
  420. To list files:
  421. ```sh
  422. curl http://localhost:3000/api/list-files
  423. ```
  424. Args:
  425. request (Request): The incoming request object.
  426. path (str, optional): The path to list files from. Defaults to None.
  427. Returns:
  428. list: A list of file names in the specified path.
  429. Raises:
  430. HTTPException: If there's an error listing the files.
  431. """
  432. if not request.state.conversation.runtime:
  433. return JSONResponse(
  434. status_code=status.HTTP_404_NOT_FOUND,
  435. content={'error': 'Runtime not yet initialized'},
  436. )
  437. runtime: Runtime = request.state.conversation.runtime
  438. file_list = await call_sync_from_async(runtime.list_files, path)
  439. if path:
  440. file_list = [os.path.join(path, f) for f in file_list]
  441. file_list = [f for f in file_list if f not in FILES_TO_IGNORE]
  442. async def filter_for_gitignore(file_list, base_path):
  443. gitignore_path = os.path.join(base_path, '.gitignore')
  444. try:
  445. read_action = FileReadAction(gitignore_path)
  446. observation = await call_sync_from_async(runtime.run_action, read_action)
  447. spec = PathSpec.from_lines(
  448. GitWildMatchPattern, observation.content.splitlines()
  449. )
  450. except Exception as e:
  451. logger.warning(e)
  452. return file_list
  453. file_list = [entry for entry in file_list if not spec.match_file(entry)]
  454. return file_list
  455. file_list = await filter_for_gitignore(file_list, '')
  456. return file_list
  457. @app.get('/api/select-file')
  458. async def select_file(file: str, request: Request):
  459. """Retrieve the content of a specified file.
  460. To select a file:
  461. ```sh
  462. curl http://localhost:3000/api/select-file?file=<file_path>
  463. ```
  464. Args:
  465. file (str): The path of the file to be retrieved.
  466. Expect path to be absolute inside the runtime.
  467. request (Request): The incoming request object.
  468. Returns:
  469. dict: A dictionary containing the file content.
  470. Raises:
  471. HTTPException: If there's an error opening the file.
  472. """
  473. runtime: Runtime = request.state.conversation.runtime
  474. file = os.path.join(runtime.config.workspace_mount_path_in_sandbox, file)
  475. read_action = FileReadAction(file)
  476. observation = await call_sync_from_async(runtime.run_action, read_action)
  477. if isinstance(observation, FileReadObservation):
  478. content = observation.content
  479. return {'code': content}
  480. elif isinstance(observation, ErrorObservation):
  481. logger.error(f'Error opening file {file}: {observation}', exc_info=False)
  482. return JSONResponse(
  483. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  484. content={'error': f'Error opening file: {observation}'},
  485. )
  486. def sanitize_filename(filename):
  487. """Sanitize the filename to prevent directory traversal"""
  488. # Remove any directory components
  489. filename = os.path.basename(filename)
  490. # Remove any non-alphanumeric characters except for .-_
  491. filename = re.sub(r'[^\w\-_\.]', '', filename)
  492. # Limit the filename length
  493. max_length = 255
  494. if len(filename) > max_length:
  495. name, ext = os.path.splitext(filename)
  496. filename = name[: max_length - len(ext)] + ext
  497. return filename
  498. @app.get('/api/conversation')
  499. async def get_remote_runtime_config(request: Request):
  500. """Retrieve the remote runtime configuration.
  501. Currently, this is the runtime ID.
  502. """
  503. runtime = request.state.conversation.runtime
  504. runtime_id = runtime.runtime_id if hasattr(runtime, 'runtime_id') else None
  505. session_id = runtime.sid if hasattr(runtime, 'sid') else None
  506. return JSONResponse(
  507. content={
  508. 'runtime_id': runtime_id,
  509. 'session_id': session_id,
  510. }
  511. )
  512. @app.post('/api/upload-files')
  513. async def upload_file(request: Request, files: list[UploadFile]):
  514. """Upload a list of files to the workspace.
  515. To upload a files:
  516. ```sh
  517. curl -X POST -F "file=@<file_path1>" -F "file=@<file_path2>" http://localhost:3000/api/upload-files
  518. ```
  519. Args:
  520. request (Request): The incoming request object.
  521. files (list[UploadFile]): A list of files to be uploaded.
  522. Returns:
  523. dict: A message indicating the success of the upload operation.
  524. Raises:
  525. HTTPException: If there's an error saving the files.
  526. """
  527. try:
  528. uploaded_files = []
  529. skipped_files = []
  530. for file in files:
  531. safe_filename = sanitize_filename(file.filename)
  532. file_contents = await file.read()
  533. if (
  534. MAX_FILE_SIZE_MB > 0
  535. and len(file_contents) > MAX_FILE_SIZE_MB * 1024 * 1024
  536. ):
  537. skipped_files.append(
  538. {
  539. 'name': safe_filename,
  540. 'reason': f'Exceeds maximum size limit of {MAX_FILE_SIZE_MB}MB',
  541. }
  542. )
  543. continue
  544. if not is_extension_allowed(safe_filename):
  545. skipped_files.append(
  546. {'name': safe_filename, 'reason': 'File type not allowed'}
  547. )
  548. continue
  549. # copy the file to the runtime
  550. with tempfile.TemporaryDirectory() as tmp_dir:
  551. tmp_file_path = os.path.join(tmp_dir, safe_filename)
  552. with open(tmp_file_path, 'wb') as tmp_file:
  553. tmp_file.write(file_contents)
  554. tmp_file.flush()
  555. runtime: Runtime = request.state.conversation.runtime
  556. runtime.copy_to(
  557. tmp_file_path, runtime.config.workspace_mount_path_in_sandbox
  558. )
  559. uploaded_files.append(safe_filename)
  560. response_content = {
  561. 'message': 'File upload process completed',
  562. 'uploaded_files': uploaded_files,
  563. 'skipped_files': skipped_files,
  564. }
  565. if not uploaded_files and skipped_files:
  566. return JSONResponse(
  567. status_code=status.HTTP_400_BAD_REQUEST,
  568. content={
  569. **response_content,
  570. 'error': 'No files were uploaded successfully',
  571. },
  572. )
  573. return JSONResponse(status_code=status.HTTP_200_OK, content=response_content)
  574. except Exception as e:
  575. logger.error(f'Error during file upload: {e}', exc_info=True)
  576. return JSONResponse(
  577. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  578. content={
  579. 'error': f'Error during file upload: {str(e)}',
  580. 'uploaded_files': [],
  581. 'skipped_files': [],
  582. },
  583. )
  584. @app.post('/api/submit-feedback')
  585. async def submit_feedback(request: Request):
  586. """Submit user feedback.
  587. This function stores the provided feedback data.
  588. To submit feedback:
  589. ```sh
  590. curl -X POST -d '{"email": "test@example.com"}' -H "Authorization:"
  591. ```
  592. Args:
  593. request (Request): The incoming request object.
  594. feedback (FeedbackDataModel): The feedback data to be stored.
  595. Returns:
  596. dict: The stored feedback data.
  597. Raises:
  598. HTTPException: If there's an error submitting the feedback.
  599. """
  600. # Assuming the storage service is already configured in the backend
  601. # and there is a function to handle the storage.
  602. body = await request.json()
  603. async_stream = AsyncEventStreamWrapper(
  604. request.state.conversation.event_stream, filter_hidden=True
  605. )
  606. trajectory = []
  607. async for event in async_stream:
  608. trajectory.append(event_to_dict(event))
  609. feedback = FeedbackDataModel(
  610. email=body.get('email', ''),
  611. version=body.get('version', ''),
  612. permissions=body.get('permissions', 'private'),
  613. polarity=body.get('polarity', ''),
  614. feedback=body.get('polarity', ''),
  615. trajectory=trajectory,
  616. )
  617. try:
  618. feedback_data = await call_sync_from_async(store_feedback, feedback)
  619. return JSONResponse(status_code=200, content=feedback_data)
  620. except Exception as e:
  621. logger.error(f'Error submitting feedback: {e}')
  622. return JSONResponse(
  623. status_code=500, content={'error': 'Failed to submit feedback'}
  624. )
  625. @app.get('/api/defaults')
  626. async def appconfig_defaults():
  627. """Retrieve the default configuration settings.
  628. To get the default configurations:
  629. ```sh
  630. curl http://localhost:3000/api/defaults
  631. ```
  632. Returns:
  633. dict: The default configuration settings.
  634. """
  635. return config.defaults_dict
  636. @app.post('/api/save-file')
  637. async def save_file(request: Request):
  638. """Save a file to the agent's runtime file store.
  639. This endpoint allows saving a file when the agent is in a paused, finished,
  640. or awaiting user input state. It checks the agent's state before proceeding
  641. with the file save operation.
  642. Args:
  643. request (Request): The incoming FastAPI request object.
  644. Returns:
  645. JSONResponse: A JSON response indicating the success of the operation.
  646. Raises:
  647. HTTPException:
  648. - 403 error if the agent is not in an allowed state for editing.
  649. - 400 error if the file path or content is missing.
  650. - 500 error if there's an unexpected error during the save operation.
  651. """
  652. try:
  653. # Extract file path and content from the request
  654. data = await request.json()
  655. file_path = data.get('filePath')
  656. content = data.get('content')
  657. # Validate the presence of required data
  658. if not file_path or content is None:
  659. raise HTTPException(status_code=400, detail='Missing filePath or content')
  660. # Save the file to the agent's runtime file store
  661. runtime: Runtime = request.state.conversation.runtime
  662. file_path = os.path.join(
  663. runtime.config.workspace_mount_path_in_sandbox, file_path
  664. )
  665. write_action = FileWriteAction(file_path, content)
  666. observation = await call_sync_from_async(runtime.run_action, write_action)
  667. if isinstance(observation, FileWriteObservation):
  668. return JSONResponse(
  669. status_code=200, content={'message': 'File saved successfully'}
  670. )
  671. elif isinstance(observation, ErrorObservation):
  672. return JSONResponse(
  673. status_code=500,
  674. content={'error': f'Failed to save file: {observation}'},
  675. )
  676. else:
  677. return JSONResponse(
  678. status_code=500,
  679. content={'error': f'Unexpected observation: {observation}'},
  680. )
  681. except Exception as e:
  682. # Log the error and return a 500 response
  683. logger.error(f'Error saving file: {e}', exc_info=True)
  684. raise HTTPException(status_code=500, detail=f'Error saving file: {e}')
  685. @app.route('/api/security/{path:path}', methods=['GET', 'POST', 'PUT', 'DELETE'])
  686. async def security_api(request: Request):
  687. """Catch-all route for security analyzer API requests.
  688. Each request is handled directly to the security analyzer.
  689. Args:
  690. request (Request): The incoming FastAPI request object.
  691. Returns:
  692. Any: The response from the security analyzer.
  693. Raises:
  694. HTTPException: If the security analyzer is not initialized.
  695. """
  696. if not request.state.conversation.security_analyzer:
  697. raise HTTPException(status_code=404, detail='Security analyzer not initialized')
  698. return await request.state.conversation.security_analyzer.handle_api_request(
  699. request
  700. )
  701. @app.get('/api/zip-directory')
  702. async def zip_current_workspace(request: Request, background_tasks: BackgroundTasks):
  703. try:
  704. logger.debug('Zipping workspace')
  705. runtime: Runtime = request.state.conversation.runtime
  706. path = runtime.config.workspace_mount_path_in_sandbox
  707. zip_file = await call_sync_from_async(runtime.copy_from, path)
  708. response = FileResponse(
  709. path=zip_file,
  710. filename='workspace.zip',
  711. media_type='application/x-zip-compressed',
  712. )
  713. # This will execute after the response is sent (So the file is not deleted before being sent)
  714. background_tasks.add_task(zip_file.unlink)
  715. return response
  716. except Exception as e:
  717. logger.error(f'Error zipping workspace: {e}', exc_info=True)
  718. raise HTTPException(
  719. status_code=500,
  720. detail='Failed to zip workspace',
  721. )
class AuthCode(BaseModel):
    # Request body for POST /api/github/callback. `code` is the OAuth
    # authorization code that github_callback exchanges for an access token.
    code: str
  724. @app.post('/api/github/callback')
  725. def github_callback(auth_code: AuthCode):
  726. # Prepare data for the token exchange request
  727. data = {
  728. 'client_id': GITHUB_CLIENT_ID,
  729. 'client_secret': GITHUB_CLIENT_SECRET,
  730. 'code': auth_code.code,
  731. }
  732. logger.debug('Exchanging code for GitHub token')
  733. headers = {'Accept': 'application/json'}
  734. response = requests.post(
  735. 'https://github.com/login/oauth/access_token', data=data, headers=headers
  736. )
  737. if response.status_code != 200:
  738. logger.error(f'Failed to exchange code for token: {response.text}')
  739. return JSONResponse(
  740. status_code=status.HTTP_400_BAD_REQUEST,
  741. content={'error': 'Failed to exchange code for token'},
  742. )
  743. token_response = response.json()
  744. if 'access_token' not in token_response:
  745. return JSONResponse(
  746. status_code=status.HTTP_400_BAD_REQUEST,
  747. content={'error': 'No access token in response'},
  748. )
  749. return JSONResponse(
  750. status_code=status.HTTP_200_OK,
  751. content={'access_token': token_response['access_token']},
  752. )
  753. @app.post('/api/authenticate')
  754. async def authenticate(request: Request):
  755. token = request.headers.get('X-GitHub-Token')
  756. if not await authenticate_github_user(token):
  757. return JSONResponse(
  758. status_code=status.HTTP_401_UNAUTHORIZED,
  759. content={'error': 'Not authorized via GitHub waitlist'},
  760. )
  761. # Create a signed JWT token with 1-hour expiration
  762. cookie_data = {
  763. 'github_token': token,
  764. 'exp': int(time.time()) + 3600, # 1 hour expiration
  765. }
  766. signed_token = sign_token(cookie_data, config.jwt_secret)
  767. response = JSONResponse(
  768. status_code=status.HTTP_200_OK, content={'message': 'User authenticated'}
  769. )
  770. # Set secure cookie with signed token
  771. response.set_cookie(
  772. key='github_auth',
  773. value=signed_token,
  774. max_age=3600, # 1 hour in seconds
  775. httponly=True,
  776. secure=True,
  777. samesite='strict',
  778. )
  779. return response
  780. @app.get('/api/vscode-url')
  781. async def get_vscode_url(request: Request):
  782. """Get the VSCode URL.
  783. This endpoint allows getting the VSCode URL.
  784. Args:
  785. request (Request): The incoming FastAPI request object.
  786. Returns:
  787. JSONResponse: A JSON response indicating the success of the operation.
  788. """
  789. try:
  790. runtime: Runtime = request.state.conversation.runtime
  791. logger.debug(f'Runtime type: {type(runtime)}')
  792. logger.debug(f'Runtime VSCode URL: {runtime.vscode_url}')
  793. return JSONResponse(status_code=200, content={'vscode_url': runtime.vscode_url})
  794. except Exception as e:
  795. logger.error(f'Error getting VSCode URL: {e}', exc_info=True)
  796. return JSONResponse(
  797. status_code=500,
  798. content={
  799. 'vscode_url': None,
  800. 'error': f'Error getting VSCode URL: {e}',
  801. },
  802. )
  803. class SPAStaticFiles(StaticFiles):
  804. async def get_response(self, path: str, scope):
  805. try:
  806. return await super().get_response(path, scope)
  807. except Exception:
  808. # FIXME: just making this HTTPException doesn't work for some reason
  809. return await super().get_response('index.html', scope)
# Serve the built frontend at '/'. Mounted after all API routes so that, since
# Starlette matches routes in registration order, those routes take precedence.
app.mount('/', SPAStaticFiles(directory='./frontend/build', html=True), name='dist')