socket.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. from fastapi import status
  2. from openhands.core.logger import openhands_logger as logger
  3. from openhands.core.schema.action import ActionType
  4. from openhands.events.action import (
  5. NullAction,
  6. )
  7. from openhands.events.observation import (
  8. NullObservation,
  9. )
  10. from openhands.events.serialization import event_to_dict
  11. from openhands.events.stream import AsyncEventStreamWrapper
  12. from openhands.server.auth import get_sid_from_token, sign_token
  13. from openhands.server.github_utils import authenticate_github_user
  14. from openhands.server.shared import config, session_manager, sio
  15. @sio.event
  16. async def connect(connection_id: str, environ):
  17. logger.info(f'sio:connect: {connection_id}')
  18. @sio.event
  19. async def oh_action(connection_id: str, data: dict):
  20. # If it's an init, we do it here.
  21. action = data.get('action', '')
  22. if action == ActionType.INIT:
  23. await init_connection(connection_id, data)
  24. return
  25. logger.info(f'sio:oh_action:{connection_id}')
  26. await session_manager.send_to_event_stream(connection_id, data)
  27. async def init_connection(connection_id: str, data: dict):
  28. gh_token = data.pop('github_token', None)
  29. if not await authenticate_github_user(gh_token):
  30. raise RuntimeError(status.WS_1008_POLICY_VIOLATION)
  31. token = data.pop('token', None)
  32. if token:
  33. sid = get_sid_from_token(token, config.jwt_secret)
  34. if sid == '':
  35. await sio.send({'error': 'Invalid token', 'error_code': 401})
  36. return
  37. logger.info(f'Existing session: {sid}')
  38. else:
  39. sid = connection_id
  40. logger.info(f'New session: {sid}')
  41. token = sign_token({'sid': sid}, config.jwt_secret)
  42. await sio.emit('oh_event', {'token': token, 'status': 'ok'}, to=connection_id)
  43. latest_event_id = int(data.pop('latest_event_id', -1))
  44. # The session in question should exist, but may not actually be running locally...
  45. event_stream = await session_manager.init_or_join_session(sid, connection_id, data)
  46. # Send events
  47. async_stream = AsyncEventStreamWrapper(event_stream, latest_event_id + 1)
  48. async for event in async_stream:
  49. if isinstance(
  50. event,
  51. (
  52. NullAction,
  53. NullObservation,
  54. ),
  55. ):
  56. continue
  57. await sio.emit('oh_event', event_to_dict(event), to=connection_id)
  58. @sio.event
  59. async def disconnect(connection_id: str):
  60. logger.info(f'sio:disconnect:{connection_id}')
  61. await session_manager.disconnect_from_session(connection_id)