浏览代码

Add event search endpoint with pagination and filtering (#4688)

Co-authored-by: AI Assistant <assistant@example.com>
tofarr 1 年之前
父节点
当前提交
be6ca4a3ce
共有 3 个文件被更改，包括 253 次插入和 0 次删除
  1. 89 0
      openhands/events/stream.py
  2. 60 0
      openhands/server/listen.py
  3. 104 0
      tests/unit/test_event_stream.py

+ 89 - 0
openhands/events/stream.py

@@ -211,6 +211,95 @@ class EventStream:
             if event.source == source:
                 yield event
 
+    def _should_filter_event(
+        self,
+        event,
+        query: str | None = None,
+        event_type: str | None = None,
+        source: str | None = None,
+        start_date: str | None = None,
+        end_date: str | None = None,
+    ) -> bool:
+        """Decide whether an event fails any of the supplied search criteria.
+
+        Criteria that are None (or empty) are ignored. An event whose source or
+        timestamp is unset can never satisfy a source or date criterion, so it is
+        filtered out rather than raising.
+
+        Args:
+            event: The event to check
+            query (str, optional): Text searched for, case-insensitively, in the
+                string form of the serialized event
+            event_type (str, optional): Exact class name of the event (e.g., "FileReadAction")
+            source (str, optional): Value of the event source (e.g., "agent")
+            start_date (str, optional): Inclusive lower bound, compared directly
+                against event.timestamp (ISO format)
+            end_date (str, optional): Inclusive upper bound, compared directly
+                against event.timestamp (ISO format)
+
+        Returns:
+            bool: True if the event should be filtered out, False if it matches all criteria
+        """
+        if event_type and event.__class__.__name__ != event_type:
+            return True
+
+        if source:
+            event_source = event.source
+            # NOTE(review): presumably events that were never added to a stream
+            # have no source; treat that as "does not match" instead of crashing.
+            if event_source is None or event_source.value != source:
+                return True
+
+        if start_date or end_date:
+            timestamp = event.timestamp
+            # An event without a timestamp cannot fall inside a date range.
+            if timestamp is None:
+                return True
+            # Plain comparison against the filter strings: this is only
+            # meaningful when both sides use the same ISO format.
+            if start_date and timestamp < start_date:
+                return True
+            if end_date and timestamp > end_date:
+                return True
+
+        # Text search in the serialized event if a query was provided
+        if query:
+            event_str = str(event_to_dict(event)).lower()
+            if query.lower() not in event_str:
+                return True
+
+        return False
+
+    def get_matching_events(
+        self,
+        query: str | None = None,
+        event_type: str | None = None,
+        source: str | None = None,
+        start_date: str | None = None,
+        end_date: str | None = None,
+        start_id: int = 0,
+        limit: int = 100,
+    ) -> list:
+        """Collect serialized events from the stream that pass every filter.
+
+        Events are read in stream order beginning at start_id, and collection
+        stops as soon as limit matches have been gathered.
+
+        Args:
+            query (str, optional): Text to search for in event content
+            event_type (str, optional): Filter by event type (e.g., "FileReadAction")
+            source (str, optional): Filter by event source
+            start_date (str, optional): Filter events after this date (ISO format)
+            end_date (str, optional): Filter events before this date (ISO format)
+            start_id (int): Starting ID in the event stream. Defaults to 0
+            limit (int): Maximum number of events to return. Must be between 1 and 100. Defaults to 100
+
+        Returns:
+            list: List of matching events (as dicts)
+
+        Raises:
+            ValueError: If limit is less than 1 or greater than 100
+        """
+        if limit > 100 or limit < 1:
+            raise ValueError('Limit must be between 1 and 100')
+
+        results: list = []
+
+        for candidate in self.get_events(start_id=start_id):
+            rejected = self._should_filter_event(
+                candidate, query, event_type, source, start_date, end_date
+            )
+            if not rejected:
+                results.append(event_to_dict(candidate))
+                # The page is full; no need to read further events
+                if len(results) >= limit:
+                    break
+
+        return results
+
     def clear(self):
         self.file_store.delete(f'sessions/{self.sid}')
         self._cur_id = 0

+ 60 - 0
openhands/server/listen.py

@@ -279,6 +279,66 @@ async def attach_session(request: Request, call_next):
     return response
 
 
+@app.get('/api/events/search')
+async def search_events(
+    request: Request,
+    query: str | None = None,
+    start_id: int = 0,
+    limit: int = 20,
+    event_type: str | None = None,
+    source: str | None = None,
+    start_date: str | None = None,
+    end_date: str | None = None,
+):
+    """Search through the event stream with filtering and pagination.
+
+    Args:
+        request (Request): The incoming request object
+        query (str, optional): Text to search for in event content
+        start_id (int): Starting ID in the event stream. Defaults to 0
+        limit (int): Maximum number of events to return. Must be between 1 and 100. Defaults to 20
+        event_type (str, optional): Filter by event type (e.g., "FileReadAction")
+        source (str, optional): Filter by event source
+        start_date (str, optional): Filter events after this date (ISO format)
+        end_date (str, optional): Filter events before this date (ISO format)
+
+    Returns:
+        dict: Dictionary containing:
+            - events: List of matching events
+            - has_more: Whether there are more matching events after this batch
+
+    Raises:
+        HTTPException: 404 if conversation is not found; 400 if limit is less
+            than 1 or greater than 100
+    """
+    if not request.state.conversation:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND, detail='Conversation not found'
+        )
+
+    # Validate here so a bad limit becomes a client error rather than an
+    # unhandled ValueError raised from the event stream.
+    if limit < 1 or limit > 100:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail='Limit must be between 1 and 100',
+        )
+
+    event_stream = request.state.conversation.event_stream
+    filters = {
+        'query': query,
+        'event_type': event_type,
+        'source': source,
+        'start_date': start_date,
+        'end_date': end_date,
+    }
+
+    # Request exactly `limit` events. Asking for `limit + 1` would exceed the
+    # stream's own maximum of 100 whenever the caller asks for a full page.
+    matching_events = event_stream.get_matching_events(
+        start_id=start_id, limit=limit, **filters
+    )
+
+    has_more = False
+    if len(matching_events) == limit:
+        # Full page: probe for a single further match after the last event.
+        # NOTE(review): assumes every serialized event carries its integer
+        # stream 'id' - confirm against event_to_dict.
+        next_start_id = matching_events[-1]['id'] + 1
+        has_more = bool(
+            event_stream.get_matching_events(
+                start_id=next_start_id, limit=1, **filters
+            )
+        )
+
+    return {
+        'events': matching_events,
+        'has_more': has_more,
+    }
+
+
 @app.get('/api/options/models')
 async def get_litellm_models() -> list[str]:
     """

+ 104 - 0
tests/unit/test_event_stream.py

@@ -62,3 +62,107 @@ def test_rehydration(temp_dir: str):
     assert len(events) == 2
     assert events[0].content == 'obs1'
     assert events[1].content == 'obs2'
+
+
+def test_get_matching_events_type_filter(temp_dir: str):
+    """Filtering by event_type returns only events of the named class."""
+    stream = EventStream('abc', get_file_store('local', temp_dir))
+
+    # Two actions with a single observation in between
+    for item in (NullAction(), NullObservation('test'), NullAction()):
+        stream.add_event(item, EventSource.AGENT)
+
+    # Only the two NullAction events should come back
+    null_actions = stream.get_matching_events(event_type='NullAction')
+    assert len(null_actions) == 2
+    for entry in null_actions:
+        assert entry['action'] == 'null'
+
+    # Only the single NullObservation should come back
+    null_observations = stream.get_matching_events(event_type='NullObservation')
+    assert len(null_observations) == 1
+    assert null_observations[0]['observation'] == 'null'
+
+
+def test_get_matching_events_query_search(temp_dir: str):
+    """Text queries match event content regardless of letter case."""
+    stream = EventStream('abc', get_file_store('local', temp_dir))
+
+    for text in ('hello world', 'test message', 'another hello'):
+        stream.add_event(NullObservation(text), EventSource.AGENT)
+
+    # Query -> expected number of hits: a plain match, the same match in
+    # upper case (case-insensitive), and text that appears nowhere.
+    expected_hits = {'hello': 2, 'HELLO': 2, 'nonexistent': 0}
+    for needle, hits in expected_hits.items():
+        assert len(stream.get_matching_events(query=needle)) == hits
+
+
+def test_get_matching_events_source_filter(temp_dir: str):
+    """Filtering by source returns only events recorded with that source."""
+    stream = EventStream('abc', get_file_store('local', temp_dir))
+
+    recorded = (
+        ('test1', EventSource.AGENT),
+        ('test2', EventSource.ENVIRONMENT),
+        ('test3', EventSource.AGENT),
+    )
+    for content, origin in recorded:
+        stream.add_event(NullObservation(content), origin)
+
+    # Two events came from the agent
+    from_agent = stream.get_matching_events(source='agent')
+    assert len(from_agent) == 2
+    for entry in from_agent:
+        assert entry['source'] == 'agent'
+
+    # One event came from the environment
+    from_environment = stream.get_matching_events(source='environment')
+    assert len(from_environment) == 1
+    assert from_environment[0]['source'] == 'environment'
+
+
+def test_get_matching_events_pagination(temp_dir: str):
+    """start_id and limit select a window of the stream, alone or combined."""
+    stream = EventStream('abc', get_file_store('local', temp_dir))
+
+    # Five events with contents test0 .. test4
+    for index in range(5):
+        stream.add_event(NullObservation(f'test{index}'), EventSource.AGENT)
+
+    # limit alone caps the number of results
+    capped = stream.get_matching_events(limit=3)
+    assert len(capped) == 3
+
+    # start_id alone skips the earlier events
+    tail = stream.get_matching_events(start_id=2)
+    assert len(tail) == 3
+    assert tail[0]['content'] == 'test2'
+
+    # Both together yield a window in the middle of the stream
+    window = stream.get_matching_events(start_id=1, limit=2)
+    assert [entry['content'] for entry in window] == ['test1', 'test2']
+
+
+def test_get_matching_events_limit_validation(temp_dir: str):
+    """Limits outside 1..100 raise ValueError; the boundary values work."""
+    stream = EventStream('abc', get_file_store('local', temp_dir))
+
+    # Just below and just above the allowed range
+    for bad_limit in (0, 101):
+        with pytest.raises(ValueError, match='Limit must be between 1 and 100'):
+            stream.get_matching_events(limit=bad_limit)
+
+    # Both boundary values are accepted and return the single stored event
+    stream.add_event(NullObservation('test'), EventSource.AGENT)
+    for ok_limit in (1, 100):
+        assert len(stream.get_matching_events(limit=ok_limit)) == 1