history.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270
  1. from typing import ClassVar, Iterable
  2. from openhands.core.logger import openhands_logger as logger
  3. from openhands.events.action.action import Action
  4. from openhands.events.action.agent import (
  5. AgentDelegateAction,
  6. ChangeAgentStateAction,
  7. )
  8. from openhands.events.action.empty import NullAction
  9. from openhands.events.action.message import MessageAction
  10. from openhands.events.event import Event, EventSource
  11. from openhands.events.observation.agent import AgentStateChangedObservation
  12. from openhands.events.observation.commands import CmdOutputObservation
  13. from openhands.events.observation.delegate import AgentDelegateObservation
  14. from openhands.events.observation.empty import NullObservation
  15. from openhands.events.observation.observation import Observation
  16. from openhands.events.serialization.event import event_to_dict
  17. from openhands.events.stream import EventStream
  18. class ShortTermHistory(list[Event]):
  19. """A list of events that represents the short-term memory of the agent.
  20. This class provides methods to retrieve and filter the events in the history of the running agent from the event stream.
  21. """
  22. start_id: int
  23. end_id: int
  24. _event_stream: EventStream
  25. delegates: dict[tuple[int, int], tuple[str, str]]
  26. filter_out: ClassVar[tuple[type[Event], ...]] = (
  27. NullAction,
  28. NullObservation,
  29. ChangeAgentStateAction,
  30. AgentStateChangedObservation,
  31. )
  32. def __init__(self):
  33. super().__init__()
  34. self.start_id = -1
  35. self.end_id = -1
  36. self.delegates = {}
  37. def set_event_stream(self, event_stream: EventStream):
  38. self._event_stream = event_stream
  39. def get_events_as_list(self, include_delegates: bool = False) -> list[Event]:
  40. """Return the history as a list of Event objects."""
  41. return list(self.get_events(include_delegates=include_delegates))
  42. def get_events(
  43. self,
  44. reverse: bool = False,
  45. include_delegates: bool = False,
  46. include_hidden=False,
  47. ) -> Iterable[Event]:
  48. """Return the events as a stream of Event objects."""
  49. # TODO handle AgentRejectAction, if it's not part of a chunk ending with an AgentDelegateObservation
  50. # or even if it is, because currently we don't add it to the summary
  51. # iterate from start_id to end_id, or reverse
  52. start_id = self.start_id if self.start_id != -1 else 0
  53. end_id = (
  54. self.end_id
  55. if self.end_id != -1
  56. else self._event_stream.get_latest_event_id()
  57. )
  58. for event in self._event_stream.get_events(
  59. start_id=start_id,
  60. end_id=end_id,
  61. reverse=reverse,
  62. filter_out_type=self.filter_out,
  63. ):
  64. if not include_hidden and hasattr(event, 'hidden') and event.hidden:
  65. continue
  66. # TODO add summaries
  67. # and filter out events that were included in a summary
  68. # filter out the events from a delegate of the current agent
  69. if not include_delegates and not any(
  70. # except for the delegate action and observation themselves, currently
  71. # AgentDelegateAction has id = delegate_start
  72. # AgentDelegateObservation has id = delegate_end
  73. delegate_start < event.id < delegate_end
  74. for delegate_start, delegate_end in self.delegates.keys()
  75. ):
  76. yield event
  77. elif include_delegates:
  78. yield event
  79. def get_last_action(self, end_id: int = -1) -> Action | None:
  80. """Return the last action from the event stream, filtered to exclude unwanted events."""
  81. # from end_id in reverse, find the first action
  82. end_id = self._event_stream.get_latest_event_id() if end_id == -1 else end_id
  83. last_action = next(
  84. (
  85. event
  86. for event in self._event_stream.get_events(
  87. end_id=end_id, reverse=True, filter_out_type=self.filter_out
  88. )
  89. if isinstance(event, Action)
  90. ),
  91. None,
  92. )
  93. return last_action
  94. def get_last_observation(self, end_id: int = -1) -> Observation | None:
  95. """Return the last observation from the event stream, filtered to exclude unwanted events."""
  96. # from end_id in reverse, find the first observation
  97. end_id = self._event_stream.get_latest_event_id() if end_id == -1 else end_id
  98. last_observation = next(
  99. (
  100. event
  101. for event in self._event_stream.get_events(
  102. end_id=end_id, reverse=True, filter_out_type=self.filter_out
  103. )
  104. if isinstance(event, Observation)
  105. ),
  106. None,
  107. )
  108. return last_observation
  109. def get_last_user_message(self) -> str:
  110. """Return the content of the last user message from the event stream."""
  111. last_user_message = next(
  112. (
  113. event.content
  114. for event in self._event_stream.get_events(reverse=True)
  115. if isinstance(event, MessageAction) and event.source == EventSource.USER
  116. ),
  117. None,
  118. )
  119. return last_user_message if last_user_message is not None else ''
  120. def get_last_agent_message(self) -> str:
  121. """Return the content of the last agent message from the event stream."""
  122. last_agent_message = next(
  123. (
  124. event.content
  125. for event in self._event_stream.get_events(reverse=True)
  126. if isinstance(event, MessageAction)
  127. and event.source == EventSource.AGENT
  128. ),
  129. None,
  130. )
  131. return last_agent_message if last_agent_message is not None else ''
  132. def get_last_events(self, n: int) -> list[Event]:
  133. """Return the last n events from the event stream."""
  134. # dummy agent is using this
  135. # it should work, but it's not great to store temporary lists now just for a test
  136. end_id = self._event_stream.get_latest_event_id()
  137. start_id = max(0, end_id - n + 1)
  138. return list(
  139. event
  140. for event in self._event_stream.get_events(
  141. start_id=start_id,
  142. end_id=end_id,
  143. filter_out_type=self.filter_out,
  144. )
  145. )
  146. def has_delegation(self) -> bool:
  147. for event in self._event_stream.get_events():
  148. if isinstance(event, AgentDelegateObservation):
  149. return True
  150. return False
  151. def on_event(self, event: Event):
  152. if not isinstance(event, AgentDelegateObservation):
  153. return
  154. logger.debug('AgentDelegateObservation received')
  155. # figure out what this delegate's actions were
  156. # from the last AgentDelegateAction to this AgentDelegateObservation
  157. # and save their ids as start and end ids
  158. # in order to use later to exclude them from parent stream
  159. # or summarize them
  160. delegate_end = event.id
  161. delegate_start = -1
  162. delegate_agent: str = ''
  163. delegate_task: str = ''
  164. for prev_event in self._event_stream.get_events(
  165. end_id=event.id - 1, reverse=True
  166. ):
  167. if isinstance(prev_event, AgentDelegateAction):
  168. delegate_start = prev_event.id
  169. delegate_agent = prev_event.agent
  170. delegate_task = prev_event.inputs.get('task', '')
  171. break
  172. if delegate_start == -1:
  173. logger.error(
  174. f'No AgentDelegateAction found for AgentDelegateObservation with id={delegate_end}'
  175. )
  176. return
  177. self.delegates[(delegate_start, delegate_end)] = (delegate_agent, delegate_task)
  178. logger.debug(
  179. f'Delegate {delegate_agent} with task {delegate_task} ran from id={delegate_start} to id={delegate_end}'
  180. )
  181. # TODO remove me when unnecessary
  182. # history is now available as a filtered stream of events, rather than list of pairs of (Action, Observation)
  183. # we rebuild the pairs here
  184. # for compatibility with the existing output format in evaluations
  185. def compatibility_for_eval_history_pairs(self) -> list[tuple[dict, dict]]:
  186. history_pairs = []
  187. for action, observation in self.get_pairs():
  188. history_pairs.append((event_to_dict(action), event_to_dict(observation)))
  189. return history_pairs
  190. def get_pairs(self) -> list[tuple[Action, Observation]]:
  191. """Return the history as a list of tuples (action, observation)."""
  192. tuples: list[tuple[Action, Observation]] = []
  193. action_map: dict[int, Action] = {}
  194. observation_map: dict[int, Observation] = {}
  195. # runnable actions are set as cause of observations
  196. # (MessageAction, NullObservation) for source=USER
  197. # (MessageAction, NullObservation) for source=AGENT
  198. # (other_action?, NullObservation)
  199. # (NullAction, CmdOutputObservation) background CmdOutputObservations
  200. for event in self.get_events_as_list(include_delegates=True):
  201. if event.id is None or event.id == -1:
  202. logger.debug(f'Event {event} has no ID')
  203. if isinstance(event, Action):
  204. action_map[event.id] = event
  205. if isinstance(event, Observation):
  206. if event.cause is None or event.cause == -1:
  207. logger.debug(f'Observation {event} has no cause')
  208. if event.cause is None:
  209. # runnable actions are set as cause of observations
  210. # NullObservations have no cause
  211. continue
  212. observation_map[event.cause] = event
  213. for action_id, action in action_map.items():
  214. observation = observation_map.get(action_id)
  215. if observation:
  216. # observation with a cause
  217. tuples.append((action, observation))
  218. else:
  219. tuples.append((action, NullObservation('')))
  220. for cause_id, observation in observation_map.items():
  221. if cause_id not in action_map:
  222. if isinstance(observation, NullObservation):
  223. continue
  224. if not isinstance(observation, CmdOutputObservation):
  225. logger.debug(f'Observation {observation} has no cause')
  226. tuples.append((NullAction(), observation))
  227. return tuples.copy()