listen.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868
  1. import asyncio
  2. import io
  3. import os
  4. import re
  5. import tempfile
  6. import uuid
  7. import warnings
  8. from contextlib import asynccontextmanager
  9. import requests
  10. from pathspec import PathSpec
  11. from pathspec.patterns import GitWildMatchPattern
  12. from openhands.security.options import SecurityAnalyzers
  13. from openhands.server.data_models.feedback import FeedbackDataModel, store_feedback
  14. from openhands.storage import get_file_store
  15. from openhands.utils.async_utils import call_sync_from_async
  16. with warnings.catch_warnings():
  17. warnings.simplefilter('ignore')
  18. import litellm
  19. from dotenv import load_dotenv
  20. from fastapi import (
  21. FastAPI,
  22. HTTPException,
  23. Request,
  24. UploadFile,
  25. WebSocket,
  26. status,
  27. )
  28. from fastapi.middleware.cors import CORSMiddleware
  29. from fastapi.responses import JSONResponse, StreamingResponse
  30. from fastapi.security import HTTPBearer
  31. from fastapi.staticfiles import StaticFiles
  32. from pydantic import BaseModel
  33. from starlette.middleware.base import BaseHTTPMiddleware
  34. import openhands.agenthub # noqa F401 (we import this to get the agents registered)
  35. from openhands.controller.agent import Agent
  36. from openhands.core.config import LLMConfig, load_app_config
  37. from openhands.core.logger import openhands_logger as logger
  38. from openhands.events.action import (
  39. ChangeAgentStateAction,
  40. FileReadAction,
  41. FileWriteAction,
  42. NullAction,
  43. )
  44. from openhands.events.observation import (
  45. AgentStateChangedObservation,
  46. ErrorObservation,
  47. FileReadObservation,
  48. FileWriteObservation,
  49. NullObservation,
  50. )
  51. from openhands.events.serialization import event_to_dict
  52. from openhands.llm import bedrock
  53. from openhands.runtime.base import Runtime
  54. from openhands.server.auth import get_sid_from_token, sign_token
  55. from openhands.server.session import SessionManager
  56. load_dotenv()
  57. config = load_app_config()
  58. file_store = get_file_store(config.file_store, config.file_store_path)
  59. session_manager = SessionManager(config, file_store)
  60. GITHUB_CLIENT_ID = os.getenv('GITHUB_CLIENT_ID', '').strip()
  61. GITHUB_CLIENT_SECRET = os.getenv('GITHUB_CLIENT_SECRET', '').strip()
  62. @asynccontextmanager
  63. async def lifespan(app: FastAPI):
  64. global session_manager
  65. async with session_manager:
  66. yield
  67. app = FastAPI(lifespan=lifespan)
  68. app.add_middleware(
  69. CORSMiddleware,
  70. allow_origins=['http://localhost:3001', 'http://127.0.0.1:3001'],
  71. allow_credentials=True,
  72. allow_methods=['*'],
  73. allow_headers=['*'],
  74. )
  75. class NoCacheMiddleware(BaseHTTPMiddleware):
  76. """
  77. Middleware to disable caching for all routes by adding appropriate headers
  78. """
  79. async def dispatch(self, request, call_next):
  80. response = await call_next(request)
  81. if not request.url.path.startswith('/assets'):
  82. response.headers['Cache-Control'] = (
  83. 'no-cache, no-store, must-revalidate, max-age=0'
  84. )
  85. response.headers['Pragma'] = 'no-cache'
  86. response.headers['Expires'] = '0'
  87. return response
  88. app.add_middleware(NoCacheMiddleware)
  89. security_scheme = HTTPBearer()
  90. def load_file_upload_config() -> tuple[int, bool, list[str]]:
  91. """Load file upload configuration from the config object.
  92. This function retrieves the file upload settings from the global config object.
  93. It handles the following settings:
  94. - Maximum file size for uploads
  95. - Whether to restrict file types
  96. - List of allowed file extensions
  97. It also performs sanity checks on the values to ensure they are valid and safe.
  98. Returns:
  99. tuple: A tuple containing:
  100. - max_file_size_mb (int): Maximum file size in MB. 0 means no limit.
  101. - restrict_file_types (bool): Whether file type restrictions are enabled.
  102. - allowed_extensions (set): Set of allowed file extensions.
  103. """
  104. # Retrieve values from config
  105. max_file_size_mb = config.file_uploads_max_file_size_mb
  106. restrict_file_types = config.file_uploads_restrict_file_types
  107. allowed_extensions = config.file_uploads_allowed_extensions
  108. # Sanity check for max_file_size_mb
  109. if not isinstance(max_file_size_mb, int) or max_file_size_mb < 0:
  110. logger.warning(
  111. f'Invalid max_file_size_mb: {max_file_size_mb}. Setting to 0 (no limit).'
  112. )
  113. max_file_size_mb = 0
  114. # Sanity check for allowed_extensions
  115. if not isinstance(allowed_extensions, (list, set)) or not allowed_extensions:
  116. logger.warning(
  117. f'Invalid allowed_extensions: {allowed_extensions}. Setting to [".*"].'
  118. )
  119. allowed_extensions = ['.*']
  120. else:
  121. # Ensure all extensions start with a dot and are lowercase
  122. allowed_extensions = [
  123. ext.lower() if ext.startswith('.') else f'.{ext.lower()}'
  124. for ext in allowed_extensions
  125. ]
  126. # If restrictions are disabled, allow all
  127. if not restrict_file_types:
  128. allowed_extensions = ['.*']
  129. logger.debug(
  130. f'File upload config: max_size={max_file_size_mb}MB, '
  131. f'restrict_types={restrict_file_types}, '
  132. f'allowed_extensions={allowed_extensions}'
  133. )
  134. return max_file_size_mb, restrict_file_types, allowed_extensions
  135. # Load configuration
  136. MAX_FILE_SIZE_MB, RESTRICT_FILE_TYPES, ALLOWED_EXTENSIONS = load_file_upload_config()
  137. def is_extension_allowed(filename):
  138. """Check if the file extension is allowed based on the current configuration.
  139. This function supports wildcards and files without extensions.
  140. The check is case-insensitive for extensions.
  141. Args:
  142. filename (str): The name of the file to check.
  143. Returns:
  144. bool: True if the file extension is allowed, False otherwise.
  145. """
  146. if not RESTRICT_FILE_TYPES:
  147. return True
  148. file_ext = os.path.splitext(filename)[1].lower() # Convert to lowercase
  149. return (
  150. '.*' in ALLOWED_EXTENSIONS
  151. or file_ext in (ext.lower() for ext in ALLOWED_EXTENSIONS)
  152. or (file_ext == '' and '.' in ALLOWED_EXTENSIONS)
  153. )
  154. @app.middleware('http')
  155. async def attach_session(request: Request, call_next):
  156. """Middleware to attach session information to the request.
  157. This middleware checks for the Authorization header, validates the token,
  158. and attaches the corresponding session to the request state.
  159. Args:
  160. request (Request): The incoming request object.
  161. call_next (Callable): The next middleware or route handler in the chain.
  162. Returns:
  163. Response: The response from the next middleware or route handler.
  164. """
  165. non_authed_paths = [
  166. '/api/options/',
  167. '/api/github/callback',
  168. '/api/authenticate',
  169. ]
  170. if any(
  171. request.url.path.startswith(path) for path in non_authed_paths
  172. ) or not request.url.path.startswith('/api/'):
  173. response = await call_next(request)
  174. return response
  175. # Bypass authentication for OPTIONS requests (preflight)
  176. if request.method == 'OPTIONS':
  177. response = await call_next(request)
  178. return response
  179. # For all other methods, validate the Authorization header
  180. if not request.headers.get('Authorization'):
  181. logger.warning('Missing Authorization header')
  182. return JSONResponse(
  183. status_code=status.HTTP_401_UNAUTHORIZED,
  184. content={'error': 'Missing Authorization header'},
  185. )
  186. auth_token = request.headers.get('Authorization')
  187. if 'Bearer' in auth_token:
  188. auth_token = auth_token.split('Bearer')[1].strip()
  189. request.state.sid = get_sid_from_token(auth_token, config.jwt_secret)
  190. if request.state.sid == '':
  191. logger.warning('Invalid token')
  192. return JSONResponse(
  193. status_code=status.HTTP_401_UNAUTHORIZED,
  194. content={'error': 'Invalid token'},
  195. )
  196. request.state.conversation = await session_manager.attach_to_conversation(
  197. request.state.sid
  198. )
  199. if request.state.conversation is None:
  200. return JSONResponse(
  201. status_code=status.HTTP_404_NOT_FOUND,
  202. content={'error': 'Session not found'},
  203. )
  204. response = await call_next(request)
  205. return response
  206. @app.websocket('/ws')
  207. async def websocket_endpoint(websocket: WebSocket):
  208. """WebSocket endpoint for receiving events from the client (i.e., the browser).
  209. Once connected, the client can send various actions:
  210. - Initialize the agent:
  211. session management, and event streaming.
  212. ```json
  213. {"action": "initialize", "args": {"LLM_MODEL": "ollama/llama3", "AGENT": "CodeActAgent", "LANGUAGE": "en", "LLM_API_KEY": "ollama"}}
  214. Args:
  215. ```
  216. websocket (WebSocket): The WebSocket connection object.
  217. - Start a new development task:
  218. ```json
  219. {"action": "start", "args": {"task": "write a bash script that prints hello"}}
  220. ```
  221. - Send a message:
  222. ```json
  223. {"action": "message", "args": {"content": "Hello, how are you?", "images_urls": ["base64_url1", "base64_url2"]}}
  224. ```
  225. - Write contents to a file:
  226. ```json
  227. {"action": "write", "args": {"path": "./greetings.txt", "content": "Hello, OpenHands?"}}
  228. ```
  229. - Read the contents of a file:
  230. ```json
  231. {"action": "read", "args": {"path": "./greetings.txt"}}
  232. ```
  233. - Run a command:
  234. ```json
  235. {"action": "run", "args": {"command": "ls -l", "thought": "", "confirmation_state": "confirmed"}}
  236. ```
  237. - Run an IPython command:
  238. ```json
  239. {"action": "run_ipython", "args": {"command": "print('Hello, IPython!')"}}
  240. ```
  241. - Open a web page:
  242. ```json
  243. {"action": "browse", "args": {"url": "https://arxiv.org/html/2402.01030v2"}}
  244. ```
  245. - Add a task to the root_task:
  246. ```json
  247. {"action": "add_task", "args": {"task": "Implement feature X"}}
  248. ```
  249. - Update a task in the root_task:
  250. ```json
  251. {"action": "modify_task", "args": {"id": "0", "state": "in_progress", "thought": ""}}
  252. ```
  253. - Change the agent's state:
  254. ```json
  255. {"action": "change_agent_state", "args": {"state": "paused"}}
  256. ```
  257. - Finish the task:
  258. ```json
  259. {"action": "finish", "args": {}}
  260. ```
  261. """
  262. await asyncio.wait_for(websocket.accept(), 10)
  263. if websocket.query_params.get('token'):
  264. token = websocket.query_params.get('token')
  265. sid = get_sid_from_token(token, config.jwt_secret)
  266. if sid == '':
  267. await websocket.send_json({'error': 'Invalid token', 'error_code': 401})
  268. await websocket.close()
  269. return
  270. else:
  271. sid = str(uuid.uuid4())
  272. token = sign_token({'sid': sid}, config.jwt_secret)
  273. session = session_manager.add_or_restart_session(sid, websocket)
  274. await websocket.send_json({'token': token, 'status': 'ok'})
  275. latest_event_id = -1
  276. if websocket.query_params.get('latest_event_id'):
  277. latest_event_id = int(websocket.query_params.get('latest_event_id'))
  278. for event in session.agent_session.event_stream.get_events(
  279. start_id=latest_event_id + 1
  280. ):
  281. if isinstance(
  282. event,
  283. (
  284. NullAction,
  285. NullObservation,
  286. ChangeAgentStateAction,
  287. AgentStateChangedObservation,
  288. ),
  289. ):
  290. continue
  291. await websocket.send_json(event_to_dict(event))
  292. await session.loop_recv()
  293. @app.get('/api/options/models')
  294. async def get_litellm_models() -> list[str]:
  295. """
  296. Get all models supported by LiteLLM.
  297. This function combines models from litellm and Bedrock, removing any
  298. error-prone Bedrock models.
  299. To get the models:
  300. ```sh
  301. curl http://localhost:3000/api/litellm-models
  302. ```
  303. Returns:
  304. list: A sorted list of unique model names.
  305. """
  306. litellm_model_list = litellm.model_list + list(litellm.model_cost.keys())
  307. litellm_model_list_without_bedrock = bedrock.remove_error_modelId(
  308. litellm_model_list
  309. )
  310. # TODO: for bedrock, this is using the default config
  311. llm_config: LLMConfig = config.get_llm_config()
  312. bedrock_model_list = []
  313. if (
  314. llm_config.aws_region_name
  315. and llm_config.aws_access_key_id
  316. and llm_config.aws_secret_access_key
  317. ):
  318. bedrock_model_list = bedrock.list_foundation_models(
  319. llm_config.aws_region_name,
  320. llm_config.aws_access_key_id,
  321. llm_config.aws_secret_access_key,
  322. )
  323. model_list = litellm_model_list_without_bedrock + bedrock_model_list
  324. for llm_config in config.llms.values():
  325. ollama_base_url = llm_config.ollama_base_url
  326. if llm_config.model.startswith('ollama'):
  327. if not ollama_base_url:
  328. ollama_base_url = llm_config.base_url
  329. if ollama_base_url:
  330. ollama_url = ollama_base_url.strip('/') + '/api/tags'
  331. try:
  332. ollama_models_list = requests.get(ollama_url, timeout=3).json()[
  333. 'models'
  334. ]
  335. for model in ollama_models_list:
  336. model_list.append('ollama/' + model['name'])
  337. break
  338. except requests.exceptions.RequestException as e:
  339. logger.error(f'Error getting OLLAMA models: {e}', exc_info=True)
  340. return list(sorted(set(model_list)))
  341. @app.get('/api/options/agents')
  342. async def get_agents():
  343. """Get all agents supported by LiteLLM.
  344. To get the agents:
  345. ```sh
  346. curl http://localhost:3000/api/agents
  347. ```
  348. Returns:
  349. list: A sorted list of agent names.
  350. """
  351. agents = sorted(Agent.list_agents())
  352. return agents
  353. @app.get('/api/options/security-analyzers')
  354. async def get_security_analyzers():
  355. """Get all supported security analyzers.
  356. To get the security analyzers:
  357. ```sh
  358. curl http://localhost:3000/api/security-analyzers
  359. ```
  360. Returns:
  361. list: A sorted list of security analyzer names.
  362. """
  363. return sorted(SecurityAnalyzers.keys())
  364. FILES_TO_IGNORE = [
  365. '.git/',
  366. '.DS_Store',
  367. 'node_modules/',
  368. '__pycache__/',
  369. ]
  370. @app.get('/api/list-files')
  371. async def list_files(request: Request, path: str | None = None):
  372. """List files in the specified path.
  373. This function retrieves a list of files from the agent's runtime file store,
  374. excluding certain system and hidden files/directories.
  375. To list files:
  376. ```sh
  377. curl http://localhost:3000/api/list-files
  378. ```
  379. Args:
  380. request (Request): The incoming request object.
  381. path (str, optional): The path to list files from. Defaults to None.
  382. Returns:
  383. list: A list of file names in the specified path.
  384. Raises:
  385. HTTPException: If there's an error listing the files.
  386. """
  387. if not request.state.conversation.runtime:
  388. return JSONResponse(
  389. status_code=status.HTTP_404_NOT_FOUND,
  390. content={'error': 'Runtime not yet initialized'},
  391. )
  392. runtime: Runtime = request.state.conversation.runtime
  393. file_list = await asyncio.create_task(
  394. call_sync_from_async(runtime.list_files, path)
  395. )
  396. if path:
  397. file_list = [os.path.join(path, f) for f in file_list]
  398. file_list = [f for f in file_list if f not in FILES_TO_IGNORE]
  399. def filter_for_gitignore(file_list, base_path):
  400. gitignore_path = os.path.join(base_path, '.gitignore')
  401. try:
  402. read_action = FileReadAction(gitignore_path)
  403. observation = runtime.run_action(read_action)
  404. spec = PathSpec.from_lines(
  405. GitWildMatchPattern, observation.content.splitlines()
  406. )
  407. except Exception as e:
  408. print(e)
  409. return file_list
  410. file_list = [entry for entry in file_list if not spec.match_file(entry)]
  411. return file_list
  412. file_list = filter_for_gitignore(file_list, '')
  413. return file_list
  414. @app.get('/api/select-file')
  415. async def select_file(file: str, request: Request):
  416. """Retrieve the content of a specified file.
  417. To select a file:
  418. ```sh
  419. curl http://localhost:3000/api/select-file?file=<file_path>
  420. ```
  421. Args:
  422. file (str): The path of the file to be retrieved.
  423. Expect path to be absolute inside the runtime.
  424. request (Request): The incoming request object.
  425. Returns:
  426. dict: A dictionary containing the file content.
  427. Raises:
  428. HTTPException: If there's an error opening the file.
  429. """
  430. runtime: Runtime = request.state.conversation.runtime
  431. file = os.path.join(runtime.config.workspace_mount_path_in_sandbox, file)
  432. read_action = FileReadAction(file)
  433. observation = await call_sync_from_async(runtime.run_action, read_action)
  434. if isinstance(observation, FileReadObservation):
  435. content = observation.content
  436. return {'code': content}
  437. elif isinstance(observation, ErrorObservation):
  438. logger.error(f'Error opening file {file}: {observation}', exc_info=False)
  439. return JSONResponse(
  440. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  441. content={'error': f'Error opening file: {observation}'},
  442. )
  443. def sanitize_filename(filename):
  444. """Sanitize the filename to prevent directory traversal"""
  445. # Remove any directory components
  446. filename = os.path.basename(filename)
  447. # Remove any non-alphanumeric characters except for .-_
  448. filename = re.sub(r'[^\w\-_\.]', '', filename)
  449. # Limit the filename length
  450. max_length = 255
  451. if len(filename) > max_length:
  452. name, ext = os.path.splitext(filename)
  453. filename = name[: max_length - len(ext)] + ext
  454. return filename
  455. @app.post('/api/upload-files')
  456. async def upload_file(request: Request, files: list[UploadFile]):
  457. """Upload a list of files to the workspace.
  458. To upload a files:
  459. ```sh
  460. curl -X POST -F "file=@<file_path1>" -F "file=@<file_path2>" http://localhost:3000/api/upload-files
  461. ```
  462. Args:
  463. request (Request): The incoming request object.
  464. files (list[UploadFile]): A list of files to be uploaded.
  465. Returns:
  466. dict: A message indicating the success of the upload operation.
  467. Raises:
  468. HTTPException: If there's an error saving the files.
  469. """
  470. try:
  471. uploaded_files = []
  472. skipped_files = []
  473. for file in files:
  474. safe_filename = sanitize_filename(file.filename)
  475. file_contents = await file.read()
  476. if (
  477. MAX_FILE_SIZE_MB > 0
  478. and len(file_contents) > MAX_FILE_SIZE_MB * 1024 * 1024
  479. ):
  480. skipped_files.append(
  481. {
  482. 'name': safe_filename,
  483. 'reason': f'Exceeds maximum size limit of {MAX_FILE_SIZE_MB}MB',
  484. }
  485. )
  486. continue
  487. if not is_extension_allowed(safe_filename):
  488. skipped_files.append(
  489. {'name': safe_filename, 'reason': 'File type not allowed'}
  490. )
  491. continue
  492. # copy the file to the runtime
  493. with tempfile.TemporaryDirectory() as tmp_dir:
  494. tmp_file_path = os.path.join(tmp_dir, safe_filename)
  495. with open(tmp_file_path, 'wb') as tmp_file:
  496. tmp_file.write(file_contents)
  497. tmp_file.flush()
  498. runtime: Runtime = request.state.conversation.runtime
  499. runtime.copy_to(
  500. tmp_file_path, runtime.config.workspace_mount_path_in_sandbox
  501. )
  502. uploaded_files.append(safe_filename)
  503. response_content = {
  504. 'message': 'File upload process completed',
  505. 'uploaded_files': uploaded_files,
  506. 'skipped_files': skipped_files,
  507. }
  508. if not uploaded_files and skipped_files:
  509. return JSONResponse(
  510. status_code=status.HTTP_400_BAD_REQUEST,
  511. content={
  512. **response_content,
  513. 'error': 'No files were uploaded successfully',
  514. },
  515. )
  516. return JSONResponse(status_code=status.HTTP_200_OK, content=response_content)
  517. except Exception as e:
  518. logger.error(f'Error during file upload: {e}', exc_info=True)
  519. return JSONResponse(
  520. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  521. content={
  522. 'error': f'Error during file upload: {str(e)}',
  523. 'uploaded_files': [],
  524. 'skipped_files': [],
  525. },
  526. )
  527. @app.post('/api/submit-feedback')
  528. async def submit_feedback(request: Request, feedback: FeedbackDataModel):
  529. """Submit user feedback.
  530. This function stores the provided feedback data.
  531. To submit feedback:
  532. ```sh
  533. curl -X POST -F "email=test@example.com" -F "token=abc" -F "feedback=positive" -F "permissions=private" -F "trajectory={}" http://localhost:3000/api/submit-feedback
  534. ```
  535. Args:
  536. request (Request): The incoming request object.
  537. feedback (FeedbackDataModel): The feedback data to be stored.
  538. Returns:
  539. dict: The stored feedback data.
  540. Raises:
  541. HTTPException: If there's an error submitting the feedback.
  542. """
  543. # Assuming the storage service is already configured in the backend
  544. # and there is a function to handle the storage.
  545. try:
  546. feedback_data = store_feedback(feedback)
  547. return JSONResponse(status_code=200, content=feedback_data)
  548. except Exception as e:
  549. logger.error(f'Error submitting feedback: {e}')
  550. return JSONResponse(
  551. status_code=500, content={'error': 'Failed to submit feedback'}
  552. )
  553. @app.get('/api/defaults')
  554. async def appconfig_defaults():
  555. """Retrieve the default configuration settings.
  556. To get the default configurations:
  557. ```sh
  558. curl http://localhost:3000/api/defaults
  559. ```
  560. Returns:
  561. dict: The default configuration settings.
  562. """
  563. return config.defaults_dict
  564. @app.post('/api/save-file')
  565. async def save_file(request: Request):
  566. """Save a file to the agent's runtime file store.
  567. This endpoint allows saving a file when the agent is in a paused, finished,
  568. or awaiting user input state. It checks the agent's state before proceeding
  569. with the file save operation.
  570. Args:
  571. request (Request): The incoming FastAPI request object.
  572. Returns:
  573. JSONResponse: A JSON response indicating the success of the operation.
  574. Raises:
  575. HTTPException:
  576. - 403 error if the agent is not in an allowed state for editing.
  577. - 400 error if the file path or content is missing.
  578. - 500 error if there's an unexpected error during the save operation.
  579. """
  580. try:
  581. # Extract file path and content from the request
  582. data = await request.json()
  583. file_path = data.get('filePath')
  584. content = data.get('content')
  585. # Validate the presence of required data
  586. if not file_path or content is None:
  587. raise HTTPException(status_code=400, detail='Missing filePath or content')
  588. # Save the file to the agent's runtime file store
  589. runtime: Runtime = request.state.conversation.runtime
  590. file_path = os.path.join(
  591. runtime.config.workspace_mount_path_in_sandbox, file_path
  592. )
  593. write_action = FileWriteAction(file_path, content)
  594. observation = await call_sync_from_async(runtime.run_action, write_action)
  595. if isinstance(observation, FileWriteObservation):
  596. return JSONResponse(
  597. status_code=200, content={'message': 'File saved successfully'}
  598. )
  599. elif isinstance(observation, ErrorObservation):
  600. return JSONResponse(
  601. status_code=500,
  602. content={'error': f'Failed to save file: {observation}'},
  603. )
  604. else:
  605. return JSONResponse(
  606. status_code=500,
  607. content={'error': f'Unexpected observation: {observation}'},
  608. )
  609. except Exception as e:
  610. # Log the error and return a 500 response
  611. logger.error(f'Error saving file: {e}', exc_info=True)
  612. raise HTTPException(status_code=500, detail=f'Error saving file: {e}')
  613. @app.route('/api/security/{path:path}', methods=['GET', 'POST', 'PUT', 'DELETE'])
  614. async def security_api(request: Request):
  615. """Catch-all route for security analyzer API requests.
  616. Each request is handled directly to the security analyzer.
  617. Args:
  618. request (Request): The incoming FastAPI request object.
  619. Returns:
  620. Any: The response from the security analyzer.
  621. Raises:
  622. HTTPException: If the security analyzer is not initialized.
  623. """
  624. if not request.state.conversation.security_analyzer:
  625. raise HTTPException(status_code=404, detail='Security analyzer not initialized')
  626. return await request.state.conversation.security_analyzer.handle_api_request(
  627. request
  628. )
  629. @app.get('/api/zip-directory')
  630. async def zip_current_workspace(request: Request):
  631. try:
  632. logger.info('Zipping workspace')
  633. runtime: Runtime = request.state.conversation.runtime
  634. path = runtime.config.workspace_mount_path_in_sandbox
  635. zip_file_bytes = await call_sync_from_async(runtime.copy_from, path)
  636. zip_stream = io.BytesIO(zip_file_bytes) # Wrap to behave like a file stream
  637. response = StreamingResponse(
  638. zip_stream,
  639. media_type='application/x-zip-compressed',
  640. headers={'Content-Disposition': 'attachment; filename=workspace.zip'},
  641. )
  642. return response
  643. except Exception as e:
  644. logger.error(f'Error zipping workspace: {e}', exc_info=True)
  645. raise HTTPException(
  646. status_code=500,
  647. detail='Failed to zip workspace',
  648. )
  649. class AuthCode(BaseModel):
  650. code: str
  651. @app.post('/api/github/callback')
  652. def github_callback(auth_code: AuthCode):
  653. # Prepare data for the token exchange request
  654. data = {
  655. 'client_id': GITHUB_CLIENT_ID,
  656. 'client_secret': GITHUB_CLIENT_SECRET,
  657. 'code': auth_code.code,
  658. }
  659. logger.info('Exchanging code for GitHub token')
  660. headers = {'Accept': 'application/json'}
  661. response = requests.post(
  662. 'https://github.com/login/oauth/access_token', data=data, headers=headers
  663. )
  664. if response.status_code != 200:
  665. logger.error(f'Failed to exchange code for token: {response.text}')
  666. return JSONResponse(
  667. status_code=status.HTTP_400_BAD_REQUEST,
  668. content={'error': 'Failed to exchange code for token'},
  669. )
  670. token_response = response.json()
  671. if 'access_token' not in token_response:
  672. return JSONResponse(
  673. status_code=status.HTTP_400_BAD_REQUEST,
  674. content={'error': 'No access token in response'},
  675. )
  676. return JSONResponse(
  677. status_code=status.HTTP_200_OK,
  678. content={'access_token': token_response['access_token']},
  679. )
  680. class User(BaseModel):
  681. login: str # GitHub login handle
  682. @app.post('/api/authenticate')
  683. def authenticate(user: User | None = None):
  684. waitlist = os.getenv('GITHUB_USER_LIST_FILE')
  685. # Only check if waitlist is provided
  686. if waitlist is not None:
  687. try:
  688. with open(waitlist, 'r') as f:
  689. users = f.read().splitlines()
  690. if user is None or user.login not in users:
  691. return JSONResponse(
  692. status_code=status.HTTP_403_FORBIDDEN,
  693. content={'error': 'User not on waitlist'},
  694. )
  695. except FileNotFoundError:
  696. return JSONResponse(
  697. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  698. content={'error': 'Waitlist file not found'},
  699. )
  700. return JSONResponse(
  701. status_code=status.HTTP_200_OK, content={'message': 'User authenticated'}
  702. )
  703. class SPAStaticFiles(StaticFiles):
  704. async def get_response(self, path: str, scope):
  705. try:
  706. return await super().get_response(path, scope)
  707. except Exception:
  708. # FIXME: just making this HTTPException doesn't work for some reason
  709. return await super().get_response('index.html', scope)
  710. app.mount('/', SPAStaticFiles(directory='./frontend/build', html=True), name='dist')