# test_event_stream.py
  1. import json
  2. import pytest
  3. from pytest import TempPathFactory
  4. from openhands.events import EventSource, EventStream
  5. from openhands.events.action import (
  6. NullAction,
  7. )
  8. from openhands.events.observation import NullObservation
  9. from openhands.storage import get_file_store
  10. @pytest.fixture
  11. def temp_dir(tmp_path_factory: TempPathFactory) -> str:
  12. return str(tmp_path_factory.mktemp('test_event_stream'))
  13. def collect_events(stream):
  14. return [event for event in stream.get_events()]
  15. def test_basic_flow(temp_dir: str):
  16. file_store = get_file_store('local', temp_dir)
  17. event_stream = EventStream('abc', file_store)
  18. event_stream.add_event(NullAction(), EventSource.AGENT)
  19. assert len(collect_events(event_stream)) == 1
  20. def test_stream_storage(temp_dir: str):
  21. file_store = get_file_store('local', temp_dir)
  22. event_stream = EventStream('abc', file_store)
  23. event_stream.add_event(NullObservation(''), EventSource.AGENT)
  24. assert len(collect_events(event_stream)) == 1
  25. content = event_stream.file_store.read('sessions/abc/events/0.json')
  26. assert content is not None
  27. data = json.loads(content)
  28. assert 'timestamp' in data
  29. del data['timestamp']
  30. assert data == {
  31. 'id': 0,
  32. 'source': 'agent',
  33. 'observation': 'null',
  34. 'content': '',
  35. 'extras': {},
  36. 'message': 'No observation',
  37. }
  38. def test_rehydration(temp_dir: str):
  39. file_store = get_file_store('local', temp_dir)
  40. event_stream = EventStream('abc', file_store)
  41. event_stream.add_event(NullObservation('obs1'), EventSource.AGENT)
  42. event_stream.add_event(NullObservation('obs2'), EventSource.AGENT)
  43. assert len(collect_events(event_stream)) == 2
  44. stream2 = EventStream('es2', file_store)
  45. assert len(collect_events(stream2)) == 0
  46. stream1rehydrated = EventStream('abc', file_store)
  47. events = collect_events(stream1rehydrated)
  48. assert len(events) == 2
  49. assert events[0].content == 'obs1'
  50. assert events[1].content == 'obs2'
  51. def test_get_matching_events_type_filter(temp_dir: str):
  52. file_store = get_file_store('local', temp_dir)
  53. event_stream = EventStream('abc', file_store)
  54. # Add mixed event types
  55. event_stream.add_event(NullAction(), EventSource.AGENT)
  56. event_stream.add_event(NullObservation('test'), EventSource.AGENT)
  57. event_stream.add_event(NullAction(), EventSource.AGENT)
  58. # Filter by NullAction
  59. events = event_stream.get_matching_events(event_type='NullAction')
  60. assert len(events) == 2
  61. assert all(e['action'] == 'null' for e in events)
  62. # Filter by NullObservation
  63. events = event_stream.get_matching_events(event_type='NullObservation')
  64. assert len(events) == 1
  65. assert events[0]['observation'] == 'null'
  66. def test_get_matching_events_query_search(temp_dir: str):
  67. file_store = get_file_store('local', temp_dir)
  68. event_stream = EventStream('abc', file_store)
  69. event_stream.add_event(NullObservation('hello world'), EventSource.AGENT)
  70. event_stream.add_event(NullObservation('test message'), EventSource.AGENT)
  71. event_stream.add_event(NullObservation('another hello'), EventSource.AGENT)
  72. # Search for 'hello'
  73. events = event_stream.get_matching_events(query='hello')
  74. assert len(events) == 2
  75. # Search should be case-insensitive
  76. events = event_stream.get_matching_events(query='HELLO')
  77. assert len(events) == 2
  78. # Search for non-existent text
  79. events = event_stream.get_matching_events(query='nonexistent')
  80. assert len(events) == 0
  81. def test_get_matching_events_source_filter(temp_dir: str):
  82. file_store = get_file_store('local', temp_dir)
  83. event_stream = EventStream('abc', file_store)
  84. event_stream.add_event(NullObservation('test1'), EventSource.AGENT)
  85. event_stream.add_event(NullObservation('test2'), EventSource.ENVIRONMENT)
  86. event_stream.add_event(NullObservation('test3'), EventSource.AGENT)
  87. # Filter by AGENT source
  88. events = event_stream.get_matching_events(source='agent')
  89. assert len(events) == 2
  90. assert all(e['source'] == 'agent' for e in events)
  91. # Filter by ENVIRONMENT source
  92. events = event_stream.get_matching_events(source='environment')
  93. assert len(events) == 1
  94. assert events[0]['source'] == 'environment'
  95. def test_get_matching_events_pagination(temp_dir: str):
  96. file_store = get_file_store('local', temp_dir)
  97. event_stream = EventStream('abc', file_store)
  98. # Add 5 events
  99. for i in range(5):
  100. event_stream.add_event(NullObservation(f'test{i}'), EventSource.AGENT)
  101. # Test limit
  102. events = event_stream.get_matching_events(limit=3)
  103. assert len(events) == 3
  104. # Test start_id
  105. events = event_stream.get_matching_events(start_id=2)
  106. assert len(events) == 3
  107. assert events[0]['content'] == 'test2'
  108. # Test combination of start_id and limit
  109. events = event_stream.get_matching_events(start_id=1, limit=2)
  110. assert len(events) == 2
  111. assert events[0]['content'] == 'test1'
  112. assert events[1]['content'] == 'test2'
  113. def test_get_matching_events_limit_validation(temp_dir: str):
  114. file_store = get_file_store('local', temp_dir)
  115. event_stream = EventStream('abc', file_store)
  116. # Test limit less than 1
  117. with pytest.raises(ValueError, match='Limit must be between 1 and 100'):
  118. event_stream.get_matching_events(limit=0)
  119. # Test limit greater than 100
  120. with pytest.raises(ValueError, match='Limit must be between 1 and 100'):
  121. event_stream.get_matching_events(limit=101)
  122. # Test valid limits work
  123. event_stream.add_event(NullObservation('test'), EventSource.AGENT)
  124. events = event_stream.get_matching_events(limit=1)
  125. assert len(events) == 1
  126. events = event_stream.get_matching_events(limit=100)
  127. assert len(events) == 1