history.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265
  1. from typing import ClassVar, Iterable
  2. from openhands.core.logger import openhands_logger as logger
  3. from openhands.events.action.action import Action
  4. from openhands.events.action.agent import (
  5. AgentDelegateAction,
  6. ChangeAgentStateAction,
  7. )
  8. from openhands.events.action.empty import NullAction
  9. from openhands.events.action.message import MessageAction
  10. from openhands.events.event import Event, EventSource
  11. from openhands.events.observation.agent import AgentStateChangedObservation
  12. from openhands.events.observation.commands import CmdOutputObservation
  13. from openhands.events.observation.delegate import AgentDelegateObservation
  14. from openhands.events.observation.empty import NullObservation
  15. from openhands.events.observation.observation import Observation
  16. from openhands.events.serialization.event import event_to_dict
  17. from openhands.events.stream import EventStream
  class ShortTermHistory(list[Event]):
      """A list of events that represents the short-term memory of the agent.

      This class provides methods to retrieve and filter the events in the
      history of the running agent from the event stream.

      None of the methods here append to the underlying list; every query reads
      from the attached EventStream, so set_event_stream() must be called before
      any of the getters (``_event_stream`` is not assigned in ``__init__``).
      """

      # Id window read by get_events(); -1 means "unbounded" (0 for the start,
      # the stream's latest event id for the end).
      start_id: int
      end_id: int
      _event_stream: EventStream
      # (delegate_start_id, delegate_end_id) -> (delegate agent name, task),
      # filled in by on_event() when an AgentDelegateObservation arrives.
      delegates: dict[tuple[int, int], tuple[str, str]]
      # Event types excluded by the filtered queries (get_events,
      # get_last_action, get_last_observation, get_last_events).
      filter_out: ClassVar[tuple[type[Event], ...]] = (
          NullAction,
          NullObservation,
          ChangeAgentStateAction,
          AgentStateChangedObservation,
      )

      def __init__(self):
          super().__init__()
          self.start_id = -1
          self.end_id = -1
          self.delegates = {}

      def set_event_stream(self, event_stream: EventStream):
          """Attach the EventStream that all queries read from."""
          self._event_stream = event_stream

      def get_events_as_list(self, include_delegates: bool = False) -> list[Event]:
          """Return the history as a list of Event objects (see get_events)."""
          return list(self.get_events(include_delegates=include_delegates))

      def get_events(
          self, reverse: bool = False, include_delegates: bool = False
      ) -> Iterable[Event]:
          """Return the events as a stream of Event objects.

          Args:
              reverse: iterate from the newest event to the oldest.
              include_delegates: when False, events strictly inside a recorded
                  delegate run are skipped. The AgentDelegateAction and
                  AgentDelegateObservation that bound the run are still yielded.

          Events are read from start_id to end_id (see the class attributes),
          and the types in filter_out are excluded.
          """
          # TODO handle AgentRejectAction, if it's not part of a chunk ending with an AgentDelegateObservation
          # or even if it is, because currently we don't add it to the summary
          start_id = self.start_id if self.start_id != -1 else 0
          end_id = (
              self.end_id
              if self.end_id != -1
              else self._event_stream.get_latest_event_id()
          )

          for event in self._event_stream.get_events(
              start_id=start_id,
              end_id=end_id,
              reverse=reverse,
              filter_out_type=self.filter_out,
          ):
              # TODO add summaries
              # and filter out events that were included in a summary
              if include_delegates or not any(
                  # Strict inequalities on purpose: the AgentDelegateAction has
                  # id == delegate_start and the AgentDelegateObservation has
                  # id == delegate_end, and both are kept in the parent history.
                  delegate_start < event.id < delegate_end
                  for delegate_start, delegate_end in self.delegates
              ):
                  yield event

      def get_last_action(self, end_id: int = -1) -> Action | None:
          """Return the newest Action at or before end_id, or None.

          end_id == -1 means "the latest event in the stream". Types in
          filter_out are never returned.
          """
          end_id = self._event_stream.get_latest_event_id() if end_id == -1 else end_id
          return next(
              (
                  event
                  for event in self._event_stream.get_events(
                      end_id=end_id, reverse=True, filter_out_type=self.filter_out
                  )
                  if isinstance(event, Action)
              ),
              None,
          )

      def get_last_observation(self, end_id: int = -1) -> Observation | None:
          """Return the newest Observation at or before end_id, or None.

          end_id == -1 means "the latest event in the stream". Types in
          filter_out are never returned.
          """
          end_id = self._event_stream.get_latest_event_id() if end_id == -1 else end_id
          return next(
              (
                  event
                  for event in self._event_stream.get_events(
                      end_id=end_id, reverse=True, filter_out_type=self.filter_out
                  )
                  if isinstance(event, Observation)
              ),
              None,
          )

      def _get_last_message_content(self, source: EventSource) -> str:
          """Return the content of the newest MessageAction from source, or ''.

          Scans the whole stream backwards without applying filter_out.
          """
          content = next(
              (
                  event.content
                  for event in self._event_stream.get_events(reverse=True)
                  if isinstance(event, MessageAction) and event.source == source
              ),
              None,
          )
          return content if content is not None else ''

      def get_last_user_message(self) -> str:
          """Return the content of the last user message from the event stream."""
          return self._get_last_message_content(EventSource.USER)

      def get_last_agent_message(self) -> str:
          """Return the content of the last agent message from the event stream."""
          return self._get_last_message_content(EventSource.AGENT)

      def get_last_events(self, n: int) -> list[Event]:
          """Return the events whose ids fall within the last n ids of the stream.

          The id window (latest - n + 1 .. latest, clamped at 0) is chosen first
          and the filter_out types are dropped afterwards, so fewer than n
          events may be returned.
          """
          # dummy agent is using this
          # it should work, but it's not great to store temporary lists now just for a test
          end_id = self._event_stream.get_latest_event_id()
          start_id = max(0, end_id - n + 1)
          return list(
              self._event_stream.get_events(
                  start_id=start_id,
                  end_id=end_id,
                  filter_out_type=self.filter_out,
              )
          )

      def has_delegation(self) -> bool:
          """Return True if the stream contains any AgentDelegateObservation."""
          return any(
              isinstance(event, AgentDelegateObservation)
              for event in self._event_stream.get_events()
          )

      def on_event(self, event: Event):
          """Record the id range of a finished delegate run.

          Only AgentDelegateObservation events are handled; anything else is
          ignored. The run's start is the closest AgentDelegateAction before
          the observation. The (start_id, end_id) pair is stored in
          self.delegates together with the delegate's agent name and task, so
          get_events() can later exclude the delegate's inner events.
          """
          if not isinstance(event, AgentDelegateObservation):
              return

          logger.debug('AgentDelegateObservation received')

          # figure out what this delegate's actions were
          # from the last AgentDelegateAction to this AgentDelegateObservation
          # and save their ids as start and end ids
          # in order to use later to exclude them from parent stream
          # or summarize them
          delegate_end = event.id
          delegate_start = -1
          delegate_agent: str = ''
          delegate_task: str = ''
          # NOTE(review): with nested delegation, the closest preceding
          # AgentDelegateAction may belong to an inner delegate rather than the
          # one this observation closes. Confirm nesting cannot occur.
          for prev_event in self._event_stream.get_events(
              end_id=event.id - 1, reverse=True
          ):
              if isinstance(prev_event, AgentDelegateAction):
                  delegate_start = prev_event.id
                  delegate_agent = prev_event.agent
                  delegate_task = prev_event.inputs.get('task', '')
                  break

          if delegate_start == -1:
              logger.error(
                  f'No AgentDelegateAction found for AgentDelegateObservation with id={delegate_end}'
              )
              return

          self.delegates[(delegate_start, delegate_end)] = (delegate_agent, delegate_task)
          logger.debug(
              f'Delegate {delegate_agent} with task {delegate_task} ran from id={delegate_start} to id={delegate_end}'
          )

      # TODO remove me when unnecessary
      # history is now available as a filtered stream of events, rather than list of pairs of (Action, Observation)
      # we rebuild the pairs here
      # for compatibility with the existing output format in evaluations
      def compatibility_for_eval_history_pairs(self) -> list[tuple[dict, dict]]:
          """Return get_pairs() with each event serialized by event_to_dict."""
          return [
              (event_to_dict(action), event_to_dict(observation))
              for action, observation in self.get_pairs()
          ]

      def get_pairs(self) -> list[tuple[Action, Observation]]:
          """Return the history as a list of tuples (action, observation).

          Observations are matched to actions through ``observation.cause``.
          Each action appears in the order it was first seen, paired with the
          observation it caused, or with ``NullObservation('')`` when there is
          none. Observations whose cause is not one of the collected actions
          are then appended, paired with ``NullAction()``. Delegate events are
          included.
          """
          tuples: list[tuple[Action, Observation]] = []
          action_map: dict[int, Action] = {}
          observation_map: dict[int, Observation] = {}

          # runnable actions are set as cause of observations
          # (MessageAction, NullObservation) for source=USER
          # (MessageAction, NullObservation) for source=AGENT
          # (other_action?, NullObservation)
          # (NullAction, CmdOutputObservation) background CmdOutputObservations

          for event in self.get_events_as_list(include_delegates=True):
              if event.id is None or event.id == -1:
                  logger.debug(f'Event {event} has no ID')

              if isinstance(event, Action):
                  action_map[event.id] = event

              if isinstance(event, Observation):
                  if event.cause is None or event.cause == -1:
                      logger.debug(f'Observation {event} has no cause')

                  if event.cause is None:
                      # runnable actions are set as cause of observations
                      # NullObservations have no cause
                      continue

                  # If several observations share a cause, the last one wins.
                  observation_map[event.cause] = event

          for action_id, action in action_map.items():
              observation = observation_map.get(action_id)
              if observation is not None:
                  # observation with a cause
                  tuples.append((action, observation))
              else:
                  tuples.append((action, NullObservation('')))

          for cause_id, observation in observation_map.items():
              if cause_id in action_map:
                  continue
              if isinstance(observation, NullObservation):
                  continue
              if not isinstance(observation, CmdOutputObservation):
                  logger.debug(f'Observation {observation} has no cause')
              tuples.append((NullAction(), observation))

          return tuples