history.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. from typing import ClassVar, Iterable
  2. from openhands.core.logger import openhands_logger as logger
  3. from openhands.events.action.action import Action
  4. from openhands.events.action.agent import (
  5. AgentDelegateAction,
  6. ChangeAgentStateAction,
  7. )
  8. from openhands.events.action.empty import NullAction
  9. from openhands.events.action.message import MessageAction
  10. from openhands.events.event import Event, EventSource
  11. from openhands.events.observation.agent import AgentStateChangedObservation
  12. from openhands.events.observation.delegate import AgentDelegateObservation
  13. from openhands.events.observation.empty import NullObservation
  14. from openhands.events.observation.error import FatalErrorObservation
  15. from openhands.events.observation.observation import Observation
  16. from openhands.events.serialization.event import event_to_dict
  17. from openhands.events.stream import EventStream
  18. from openhands.events.utils import get_pairs_from_events
  19. class ShortTermHistory(list[Event]):
  20. """A list of events that represents the short-term memory of the agent.
  21. This class provides methods to retrieve and filter the events in the history of the running agent from the event stream.
  22. """
  23. start_id: int
  24. end_id: int
  25. _event_stream: EventStream
  26. delegates: dict[tuple[int, int], tuple[str, str]]
  27. filter_out: ClassVar[tuple[type[Event], ...]] = (
  28. NullAction,
  29. NullObservation,
  30. ChangeAgentStateAction,
  31. AgentStateChangedObservation,
  32. FatalErrorObservation,
  33. )
  34. def __init__(self):
  35. super().__init__()
  36. self.start_id = -1
  37. self.end_id = -1
  38. self.delegates = {}
  39. def set_event_stream(self, event_stream: EventStream):
  40. self._event_stream = event_stream
  41. def get_events_as_list(self, include_delegates: bool = False) -> list[Event]:
  42. """Return the history as a list of Event objects."""
  43. return list(self.get_events(include_delegates=include_delegates))
  44. def get_events(
  45. self,
  46. reverse: bool = False,
  47. include_delegates: bool = False,
  48. include_hidden=False,
  49. ) -> Iterable[Event]:
  50. """Return the events as a stream of Event objects."""
  51. # TODO handle AgentRejectAction, if it's not part of a chunk ending with an AgentDelegateObservation
  52. # or even if it is, because currently we don't add it to the summary
  53. # iterate from start_id to end_id, or reverse
  54. start_id = self.start_id if self.start_id != -1 else 0
  55. end_id = (
  56. self.end_id
  57. if self.end_id != -1
  58. else self._event_stream.get_latest_event_id()
  59. )
  60. for event in self._event_stream.get_events(
  61. start_id=start_id,
  62. end_id=end_id,
  63. reverse=reverse,
  64. filter_out_type=self.filter_out,
  65. ):
  66. if not include_hidden and hasattr(event, 'hidden') and event.hidden:
  67. continue
  68. # TODO add summaries
  69. # and filter out events that were included in a summary
  70. # filter out the events from a delegate of the current agent
  71. if not include_delegates and not any(
  72. # except for the delegate action and observation themselves, currently
  73. # AgentDelegateAction has id = delegate_start
  74. # AgentDelegateObservation has id = delegate_end
  75. delegate_start < event.id < delegate_end
  76. for delegate_start, delegate_end in self.delegates.keys()
  77. ):
  78. yield event
  79. elif include_delegates:
  80. yield event
  81. def get_last_action(self, end_id: int = -1) -> Action | None:
  82. """Return the last action from the event stream, filtered to exclude unwanted events."""
  83. # from end_id in reverse, find the first action
  84. end_id = self._event_stream.get_latest_event_id() if end_id == -1 else end_id
  85. last_action = next(
  86. (
  87. event
  88. for event in self._event_stream.get_events(
  89. end_id=end_id, reverse=True, filter_out_type=self.filter_out
  90. )
  91. if isinstance(event, Action)
  92. ),
  93. None,
  94. )
  95. return last_action
  96. def get_last_observation(self, end_id: int = -1) -> Observation | None:
  97. """Return the last observation from the event stream, filtered to exclude unwanted events."""
  98. # from end_id in reverse, find the first observation
  99. end_id = self._event_stream.get_latest_event_id() if end_id == -1 else end_id
  100. last_observation = next(
  101. (
  102. event
  103. for event in self._event_stream.get_events(
  104. end_id=end_id, reverse=True, filter_out_type=self.filter_out
  105. )
  106. if isinstance(event, Observation)
  107. ),
  108. None,
  109. )
  110. return last_observation
  111. def get_last_user_message(self) -> str:
  112. """Return the content of the last user message from the event stream."""
  113. last_user_message = next(
  114. (
  115. event.content
  116. for event in self._event_stream.get_events(reverse=True)
  117. if isinstance(event, MessageAction) and event.source == EventSource.USER
  118. ),
  119. None,
  120. )
  121. return last_user_message if last_user_message is not None else ''
  122. def get_last_agent_message(self) -> str:
  123. """Return the content of the last agent message from the event stream."""
  124. last_agent_message = next(
  125. (
  126. event.content
  127. for event in self._event_stream.get_events(reverse=True)
  128. if isinstance(event, MessageAction)
  129. and event.source == EventSource.AGENT
  130. ),
  131. None,
  132. )
  133. return last_agent_message if last_agent_message is not None else ''
  134. def get_last_events(self, n: int) -> list[Event]:
  135. """Return the last n events from the event stream."""
  136. # dummy agent is using this
  137. # it should work, but it's not great to store temporary lists now just for a test
  138. end_id = self._event_stream.get_latest_event_id()
  139. start_id = max(0, end_id - n + 1)
  140. return list(
  141. event
  142. for event in self._event_stream.get_events(
  143. start_id=start_id,
  144. end_id=end_id,
  145. filter_out_type=self.filter_out,
  146. )
  147. )
  148. def has_delegation(self) -> bool:
  149. for event in self._event_stream.get_events():
  150. if isinstance(event, AgentDelegateObservation):
  151. return True
  152. return False
  153. def on_event(self, event: Event):
  154. if not isinstance(event, AgentDelegateObservation):
  155. return
  156. logger.debug('AgentDelegateObservation received')
  157. # figure out what this delegate's actions were
  158. # from the last AgentDelegateAction to this AgentDelegateObservation
  159. # and save their ids as start and end ids
  160. # in order to use later to exclude them from parent stream
  161. # or summarize them
  162. delegate_end = event.id
  163. delegate_start = -1
  164. delegate_agent: str = ''
  165. delegate_task: str = ''
  166. for prev_event in self._event_stream.get_events(
  167. end_id=event.id - 1, reverse=True
  168. ):
  169. if isinstance(prev_event, AgentDelegateAction):
  170. delegate_start = prev_event.id
  171. delegate_agent = prev_event.agent
  172. delegate_task = prev_event.inputs.get('task', '')
  173. break
  174. if delegate_start == -1:
  175. logger.error(
  176. f'No AgentDelegateAction found for AgentDelegateObservation with id={delegate_end}'
  177. )
  178. return
  179. self.delegates[(delegate_start, delegate_end)] = (delegate_agent, delegate_task)
  180. logger.debug(
  181. f'Delegate {delegate_agent} with task {delegate_task} ran from id={delegate_start} to id={delegate_end}'
  182. )
  183. # TODO remove me when unnecessary
  184. # history is now available as a filtered stream of events, rather than list of pairs of (Action, Observation)
  185. # we rebuild the pairs here
  186. # for compatibility with the existing output format in evaluations
  187. def compatibility_for_eval_history_pairs(self) -> list[tuple[dict, dict]]:
  188. history_pairs = []
  189. for action, observation in get_pairs_from_events(
  190. self.get_events_as_list(include_delegates=True)
  191. ):
  192. history_pairs.append((event_to_dict(action), event_to_dict(observation)))
  193. return history_pairs