|
|
@@ -2,6 +2,7 @@ import asyncio
|
|
|
import json
|
|
|
import time
|
|
|
from dataclasses import dataclass, field
|
|
|
+from uuid import uuid4
|
|
|
|
|
|
import socketio
|
|
|
|
|
|
@@ -27,6 +28,14 @@ class ConversationDoesNotExistError(Exception):
|
|
|
pass
|
|
|
|
|
|
|
|
|
+@dataclass
|
|
|
+class _SessionIsRunningCheck:
|
|
|
+ request_id: str
|
|
|
+ request_sids: list[str]
|
|
|
+ running_sids: set[str] = field(default_factory=set)
|
|
|
+ flag: asyncio.Event = field(default_factory=asyncio.Event)
|
|
|
+
|
|
|
+
|
|
|
@dataclass
|
|
|
class SessionManager:
|
|
|
sio: socketio.AsyncServer
|
|
|
@@ -36,7 +45,9 @@ class SessionManager:
|
|
|
local_connection_id_to_session_id: dict[str, str] = field(default_factory=dict)
|
|
|
_last_alive_timestamps: dict[str, float] = field(default_factory=dict)
|
|
|
_redis_listen_task: asyncio.Task | None = None
|
|
|
- _session_is_running_flags: dict[str, asyncio.Event] = field(default_factory=dict)
|
|
|
+ _session_is_running_checks: dict[str, _SessionIsRunningCheck] = field(
|
|
|
+ default_factory=dict
|
|
|
+ )
|
|
|
_active_conversations: dict[str, tuple[Conversation, int]] = field(
|
|
|
default_factory=dict
|
|
|
)
|
|
|
@@ -97,27 +108,41 @@ class SessionManager:
|
|
|
async def _process_message(self, message: dict):
|
|
|
data = json.loads(message['data'])
|
|
|
logger.debug(f'got_published_message:{message}')
|
|
|
- sid = data['sid']
|
|
|
message_type = data['message_type']
|
|
|
if message_type == 'event':
|
|
|
+ sid = data['sid']
|
|
|
session = self._local_agent_loops_by_sid.get(sid)
|
|
|
if session:
|
|
|
await session.dispatch(data['data'])
|
|
|
elif message_type == 'is_session_running':
|
|
|
# Another node in the cluster is asking if the current node is running the session given.
|
|
|
- session = self._local_agent_loops_by_sid.get(sid)
|
|
|
- if session:
|
|
|
+ request_id = data['request_id']
|
|
|
+ sids = [
|
|
|
+ sid for sid in data['sids'] if sid in self._local_agent_loops_by_sid
|
|
|
+ ]
|
|
|
+ if sids:
|
|
|
await self._get_redis_client().publish(
|
|
|
'oh_event',
|
|
|
- json.dumps({'sid': sid, 'message_type': 'session_is_running'}),
|
|
|
+ json.dumps(
|
|
|
+ {
|
|
|
+ 'request_id': request_id,
|
|
|
+ 'sids': sids,
|
|
|
+ 'message_type': 'session_is_running',
|
|
|
+ }
|
|
|
+ ),
|
|
|
)
|
|
|
elif message_type == 'session_is_running':
|
|
|
- self._last_alive_timestamps[sid] = time.time()
|
|
|
- flag = self._session_is_running_flags.get(sid)
|
|
|
- if flag:
|
|
|
- flag.set()
|
|
|
+ request_id = data['request_id']
|
|
|
+ for sid in data['sids']:
|
|
|
+ self._last_alive_timestamps[sid] = time.time()
|
|
|
+ check = self._session_is_running_checks.get(request_id)
|
|
|
+ if check:
|
|
|
+ check.running_sids.update(data['sids'])
|
|
|
+ if len(check.request_sids) == len(check.running_sids):
|
|
|
+ check.flag.set()
|
|
|
elif message_type == 'has_remote_connections_query':
|
|
|
# Another node in the cluster is asking if the current node is connected to a session
|
|
|
+ sid = data['sid']
|
|
|
required = sid in self.local_connection_id_to_session_id.values()
|
|
|
if required:
|
|
|
await self._get_redis_client().publish(
|
|
|
@@ -127,12 +152,14 @@ class SessionManager:
|
|
|
),
|
|
|
)
|
|
|
elif message_type == 'has_remote_connections_response':
|
|
|
+ sid = data['sid']
|
|
|
flag = self._has_remote_connections_flags.get(sid)
|
|
|
if flag:
|
|
|
flag.set()
|
|
|
elif message_type == 'session_closing':
|
|
|
# Session closing event - We only get this in the event of graceful shutdown,
|
|
|
# which can't be guaranteed - nodes can simply vanish unexpectedly!
|
|
|
+ sid = data['sid']
|
|
|
logger.debug(f'session_closing:{sid}')
|
|
|
for (
|
|
|
connection_id,
|
|
|
@@ -234,33 +261,47 @@ class SessionManager:
|
|
|
logger.warning('error_cleaning_detached_conversations', exc_info=True)
|
|
|
await asyncio.sleep(_CLEANUP_EXCEPTION_WAIT_TIME)
|
|
|
|
|
|
- async def _is_agent_loop_running(self, sid: str) -> bool:
|
|
|
- if await self._is_agent_loop_running_locally(sid):
|
|
|
+ async def get_agent_loop_running(self, sids: set[str]) -> set[str]:
|
|
|
+ running_sids = set(sid for sid in sids if sid in self._local_agent_loops_by_sid)
|
|
|
+ check_cluster_sids = [sid for sid in sids if sid not in running_sids]
|
|
|
+ running_cluster_sids = await self.get_agent_loop_running_in_cluster(
|
|
|
+ check_cluster_sids
|
|
|
+ )
|
|
|
+ running_sids.union(running_cluster_sids)
|
|
|
+ return running_sids
|
|
|
+
|
|
|
+ async def is_agent_loop_running(self, sid: str) -> bool:
|
|
|
+ if await self.is_agent_loop_running_locally(sid):
|
|
|
return True
|
|
|
- if await self._is_agent_loop_running_in_cluster(sid):
|
|
|
+ if await self.is_agent_loop_running_in_cluster(sid):
|
|
|
return True
|
|
|
return False
|
|
|
|
|
|
- async def _is_agent_loop_running_locally(self, sid: str) -> bool:
|
|
|
- if self._local_agent_loops_by_sid.get(sid, None):
|
|
|
- return True
|
|
|
- return False
|
|
|
+ async def is_agent_loop_running_locally(self, sid: str) -> bool:
|
|
|
+ return sid in self._local_agent_loops_by_sid
|
|
|
+
|
|
|
+ async def is_agent_loop_running_in_cluster(self, sid: str) -> bool:
|
|
|
+ running_sids = await self.get_agent_loop_running_in_cluster([sid])
|
|
|
+ return bool(running_sids)
|
|
|
|
|
|
- async def _is_agent_loop_running_in_cluster(self, sid: str) -> bool:
|
|
|
+ async def get_agent_loop_running_in_cluster(self, sids: list[str]) -> set[str]:
|
|
|
"""As the rest of the cluster if a session is running. Wait a for a short timeout for a reply"""
|
|
|
redis_client = self._get_redis_client()
|
|
|
if not redis_client:
|
|
|
- return False
|
|
|
+ return set()
|
|
|
|
|
|
flag = asyncio.Event()
|
|
|
- self._session_is_running_flags[sid] = flag
|
|
|
+ request_id = str(uuid4())
|
|
|
+ check = _SessionIsRunningCheck(request_id=request_id, request_sids=sids)
|
|
|
+ self._session_is_running_checks[request_id] = check
|
|
|
try:
|
|
|
- logger.debug(f'publish:is_session_running:{sid}')
|
|
|
+ logger.debug(f'publish:is_session_running:{sids}')
|
|
|
await redis_client.publish(
|
|
|
'oh_event',
|
|
|
json.dumps(
|
|
|
{
|
|
|
- 'sid': sid,
|
|
|
+ 'request_id': request_id,
|
|
|
+ 'sids': sids,
|
|
|
'message_type': 'is_session_running',
|
|
|
}
|
|
|
),
|
|
|
@@ -268,13 +309,12 @@ class SessionManager:
|
|
|
async with asyncio.timeout(_REDIS_POLL_TIMEOUT):
|
|
|
await flag.wait()
|
|
|
|
|
|
- result = flag.is_set()
|
|
|
- return result
|
|
|
+ return check.running_sids
|
|
|
except TimeoutError:
|
|
|
# Nobody replied in time
|
|
|
- return False
|
|
|
+ return check.running_sids
|
|
|
finally:
|
|
|
- self._session_is_running_flags.pop(sid, None)
|
|
|
+ self._session_is_running_checks.pop(request_id, None)
|
|
|
|
|
|
async def _has_remote_connections(self, sid: str) -> bool:
|
|
|
"""As the rest of the cluster if they still want this session running. Wait a for a short timeout for a reply"""
|
|
|
@@ -307,7 +347,7 @@ class SessionManager:
|
|
|
) -> EventStream:
|
|
|
logger.info(f'maybe_start_agent_loop:{sid}')
|
|
|
session: Session | None = None
|
|
|
- if not await self._is_agent_loop_running(sid):
|
|
|
+ if not await self.is_agent_loop_running(sid):
|
|
|
logger.info(f'start_agent_loop:{sid}')
|
|
|
session = Session(
|
|
|
sid=sid, file_store=self.file_store, config=self.config, sio=self.sio
|
|
|
@@ -328,7 +368,7 @@ class SessionManager:
|
|
|
logger.info(f'found_local_agent_loop:{sid}')
|
|
|
return session.agent_session.event_stream
|
|
|
|
|
|
- if await self._is_agent_loop_running_in_cluster(sid):
|
|
|
+ if await self.is_agent_loop_running_in_cluster(sid):
|
|
|
logger.info(f'found_remote_agent_loop:{sid}')
|
|
|
return EventStream(sid, self.file_store)
|
|
|
|
|
|
@@ -352,7 +392,7 @@ class SessionManager:
|
|
|
next_alive_check = last_alive_at + _CHECK_ALIVE_INTERVAL
|
|
|
if (
|
|
|
next_alive_check > time.time()
|
|
|
- or await self._is_agent_loop_running_in_cluster(sid)
|
|
|
+ or await self.is_agent_loop_running_in_cluster(sid)
|
|
|
):
|
|
|
# Send the event to the other pod
|
|
|
await redis_client.publish(
|