listen_socket.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. from fastapi import status
  2. from openhands.core.logger import openhands_logger as logger
  3. from openhands.core.schema.action import ActionType
  4. from openhands.events.action import (
  5. NullAction,
  6. )
  7. from openhands.events.observation import (
  8. NullObservation,
  9. )
  10. from openhands.events.observation.agent import AgentStateChangedObservation
  11. from openhands.events.serialization import event_to_dict
  12. from openhands.events.stream import AsyncEventStreamWrapper
  13. from openhands.server.auth import get_sid_from_token, sign_token
  14. from openhands.server.github_utils import authenticate_github_user
  15. from openhands.server.session.session_init_data import SessionInitData
  16. from openhands.server.shared import config, session_manager, sio
  17. @sio.event
  18. async def connect(connection_id: str, environ):
  19. logger.info(f'sio:connect: {connection_id}')
  20. @sio.event
  21. async def oh_action(connection_id: str, data: dict):
  22. # If it's an init, we do it here.
  23. action = data.get('action', '')
  24. if action == ActionType.INIT:
  25. token = data.pop('token', None)
  26. github_token = data.pop('github_token', None)
  27. latest_event_id = int(data.pop('latest_event_id', -1))
  28. kwargs = {k.lower(): v for k, v in (data.get('args') or {}).items()}
  29. session_init_data = SessionInitData(**kwargs)
  30. session_init_data.github_token = github_token
  31. session_init_data.selected_repository = data.get('selected_repository', None)
  32. await init_connection(
  33. connection_id, token, github_token, session_init_data, latest_event_id
  34. )
  35. return
  36. logger.info(f'sio:oh_action:{connection_id}')
  37. await session_manager.send_to_event_stream(connection_id, data)
  38. async def init_connection(
  39. connection_id: str,
  40. token: str | None,
  41. gh_token: str | None,
  42. session_init_data: SessionInitData,
  43. latest_event_id: int,
  44. ):
  45. if not await authenticate_github_user(gh_token):
  46. raise RuntimeError(status.WS_1008_POLICY_VIOLATION)
  47. if token:
  48. sid = get_sid_from_token(token, config.jwt_secret)
  49. if sid == '':
  50. await sio.emit('oh_event', {'error': 'Invalid token', 'error_code': 401})
  51. return
  52. logger.info(f'Existing session: {sid}')
  53. else:
  54. sid = connection_id
  55. logger.info(f'New session: {sid}')
  56. token = sign_token({'sid': sid}, config.jwt_secret)
  57. await sio.emit('oh_event', {'token': token, 'status': 'ok'}, to=connection_id)
  58. # The session in question should exist, but may not actually be running locally...
  59. event_stream = await session_manager.init_or_join_session(
  60. sid, connection_id, session_init_data
  61. )
  62. # Send events
  63. agent_state_changed = None
  64. async_stream = AsyncEventStreamWrapper(event_stream, latest_event_id + 1)
  65. async for event in async_stream:
  66. if isinstance(
  67. event,
  68. (
  69. NullAction,
  70. NullObservation,
  71. ),
  72. ):
  73. continue
  74. elif isinstance(event, AgentStateChangedObservation):
  75. if event.agent_state == 'init':
  76. await sio.emit('oh_event', event_to_dict(event), to=connection_id)
  77. else:
  78. agent_state_changed = event
  79. continue
  80. await sio.emit('oh_event', event_to_dict(event), to=connection_id)
  81. if agent_state_changed:
  82. await sio.emit('oh_event', event_to_dict(agent_state_changed), to=connection_id)
  83. @sio.event
  84. async def disconnect(connection_id: str):
  85. logger.info(f'sio:disconnect:{connection_id}')
  86. await session_manager.disconnect_from_session(connection_id)