history.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. from typing import ClassVar, Iterable
  2. from opendevin.core.logger import opendevin_logger as logger
  3. from opendevin.events.action.action import Action
  4. from opendevin.events.action.agent import (
  5. AgentDelegateAction,
  6. ChangeAgentStateAction,
  7. )
  8. from opendevin.events.action.empty import NullAction
  9. from opendevin.events.action.message import MessageAction
  10. from opendevin.events.event import Event, EventSource
  11. from opendevin.events.observation.agent import AgentStateChangedObservation
  12. from opendevin.events.observation.commands import CmdOutputObservation
  13. from opendevin.events.observation.delegate import AgentDelegateObservation
  14. from opendevin.events.observation.empty import NullObservation
  15. from opendevin.events.observation.observation import Observation
  16. from opendevin.events.serialization.event import event_to_dict
  17. from opendevin.events.stream import EventStream
  class ShortTermHistory(list[Event]):
      """A list of events that represents the short-term memory of the agent.

      This class provides methods to retrieve and filter the events in the
      history of the running agent from the event stream.

      NOTE(review): although this subclasses ``list``, none of the methods
      below read or write the list's own items; every query goes to the
      attached event stream. ``set_event_stream`` must be called before any
      query method, because ``_event_stream`` is not assigned in ``__init__``.
      """

      # Inclusive id window used by ``get_events``; -1 means "unbounded"
      # (0 for the start, the latest stream id for the end).
      start_id: int
      end_id: int
      _event_stream: EventStream
      # Maps (delegate_start_id, delegate_end_id) -> (delegate agent, task).
      delegates: dict[tuple[int, int], tuple[str, str]]
      # Event types that the filtered queries ask the stream to leave out.
      filter_out: ClassVar[tuple[type[Event], ...]] = (
          NullAction,
          NullObservation,
          ChangeAgentStateAction,
          AgentStateChangedObservation,
      )

      def __init__(self) -> None:
          """Create an empty history with an unbounded window and no delegates."""
          super().__init__()
          self.start_id = -1
          self.end_id = -1
          self.delegates = {}

      def set_event_stream(self, event_stream: EventStream) -> None:
          """Attach the event stream that all query methods read from."""
          self._event_stream = event_stream

      def get_events_as_list(self) -> list[Event]:
          """Return the history as a list of Event objects."""
          return list(self.get_events())

      def get_events(self, reverse: bool = False) -> Iterable[Event]:
          """Return the events as a stream of Event objects.

          Events are read from ``start_id`` to ``end_id`` (or in reverse), with
          the types in ``filter_out`` excluded. Events whose id lies strictly
          between the start and end id of a recorded delegate are skipped; the
          delegate action and observation themselves are kept.
          """
          # TODO handle AgentRejectAction, if it's not part of a chunk ending with an AgentDelegateObservation
          # or even if it is, because currently we don't add it to the summary

          # iterate from start_id to end_id, or reverse
          first_id = self.start_id if self.start_id != -1 else 0
          if self.end_id != -1:
              last_id = self.end_id
          else:
              last_id = self._event_stream.get_latest_event_id()

          for event in self._event_stream.get_events(
              start_id=first_id,
              end_id=last_id,
              reverse=reverse,
              filter_out_type=self.filter_out,
          ):
              # TODO add summaries
              # and filter out events that were included in a summary

              # filter out the events from a delegate of the current agent
              if not self._is_inside_delegate(event):
                  yield event

      def _is_inside_delegate(self, event: Event) -> bool:
          """Return True if the event id lies strictly inside a delegate range."""
          # The bounds are exclusive on purpose: currently
          # AgentDelegateAction has id = delegate_start
          # AgentDelegateObservation has id = delegate_end
          # and both must remain visible to the parent agent.
          return any(
              delegate_start < event.id < delegate_end
              for delegate_start, delegate_end in self.delegates
          )

      def _filtered_events_backwards(self, end_id: int) -> Iterable[Event]:
          """Return the filtered stream events from ``end_id`` going backwards.

          An ``end_id`` of -1 is replaced by the latest event id of the stream.
          """
          if end_id == -1:
              end_id = self._event_stream.get_latest_event_id()
          return self._event_stream.get_events(
              end_id=end_id, reverse=True, filter_out_type=self.filter_out
          )

      def get_last_action(self, end_id: int = -1) -> Action | None:
          """Return the last action from the event stream, filtered to exclude unwanted events.

          Searches backwards from ``end_id`` (-1 means the latest event) and
          returns None when no action is found.
          """
          return next(
              (
                  event
                  for event in self._filtered_events_backwards(end_id)
                  if isinstance(event, Action)
              ),
              None,
          )

      def get_last_observation(self, end_id: int = -1) -> Observation | None:
          """Return the last observation from the event stream, filtered to exclude unwanted events.

          Searches backwards from ``end_id`` (-1 means the latest event) and
          returns None when no observation is found.
          """
          return next(
              (
                  event
                  for event in self._filtered_events_backwards(end_id)
                  if isinstance(event, Observation)
              ),
              None,
          )

      def _last_message_content(self, source: EventSource) -> str:
          """Return the content of the most recent MessageAction from ``source``.

          Returns '' when there is no such message or its content is None.
          """
          for event in self._event_stream.get_events(reverse=True):
              if isinstance(event, MessageAction) and event.source == source:
                  content = event.content
                  return content if content is not None else ''
          return ''

      def get_last_user_message(self) -> str:
          """Return the content of the last user message from the event stream."""
          return self._last_message_content(EventSource.USER)

      def get_last_agent_message(self) -> str:
          """Return the content of the last agent message from the event stream."""
          return self._last_message_content(EventSource.AGENT)

      def get_last_events(self, n: int) -> list[Event]:
          """Return the last n events from the event stream.

          The window covers the last ``n`` event ids; because the types in
          ``filter_out`` are excluded from that window, fewer than ``n`` events
          may be returned.
          """
          # dummy agent is using this
          # it should work, but it's not great to store temporary lists now just for a test
          end_id = self._event_stream.get_latest_event_id()
          start_id = max(0, end_id - n + 1)

          return list(
              self._event_stream.get_events(
                  start_id=start_id,
                  end_id=end_id,
                  filter_out_type=self.filter_out,
              )
          )

      def has_delegation(self) -> bool:
          """Return True if the event stream contains an AgentDelegateObservation."""
          return any(
              isinstance(event, AgentDelegateObservation)
              for event in self._event_stream.get_events()
          )

      def on_event(self, event: Event) -> None:
          """Record the id range of a finished delegate.

          Only AgentDelegateObservation events are handled. The range runs from
          the closest preceding AgentDelegateAction to this observation, and is
          stored in ``delegates`` together with the delegate's agent and task.
          If no preceding AgentDelegateAction exists, an error is logged and
          nothing is recorded.
          """
          if not isinstance(event, AgentDelegateObservation):
              return

          logger.debug('AgentDelegateObservation received')

          # figure out what this delegate's actions were
          # from the last AgentDelegateAction to this AgentDelegateObservation
          # and save their ids as start and end ids
          # in order to use later to exclude them from parent stream
          # or summarize them
          delegate_end = event.id
          delegate_start = -1
          delegate_agent: str = ''
          delegate_task: str = ''
          for prev_event in self._event_stream.get_events(
              end_id=event.id - 1, reverse=True
          ):
              if isinstance(prev_event, AgentDelegateAction):
                  delegate_start = prev_event.id
                  delegate_agent = prev_event.agent
                  delegate_task = prev_event.inputs.get('task', '')
                  break

          if delegate_start == -1:
              logger.error(
                  f'No AgentDelegateAction found for AgentDelegateObservation with id={delegate_end}'
              )
              return

          self.delegates[(delegate_start, delegate_end)] = (delegate_agent, delegate_task)
          logger.debug(
              f'Delegate {delegate_agent} with task {delegate_task} ran from id={delegate_start} to id={delegate_end}'
          )

      # TODO remove me when unnecessary
      # history is now available as a filtered stream of events, rather than list of pairs of (Action, Observation)
      # we rebuild the pairs here
      # for compatibility with the existing output format in evaluations
      def compatibility_for_eval_history_pairs(self) -> list[tuple[dict, dict]]:
          """Return the (action, observation) pairs serialized with event_to_dict."""
          return [
              (event_to_dict(action), event_to_dict(observation))
              for action, observation in self.get_pairs()
          ]

      def get_pairs(self) -> list[tuple[Action, Observation]]:
          """Return the history as a list of tuples (action, observation).

          An observation is matched to the action whose id equals the
          observation's ``cause``. The result lists every action first, in
          stream order, paired with its observation or with
          ``NullObservation('')`` when none matched. It then lists the
          observations whose cause matches no collected action, each paired
          with ``NullAction()``; NullObservations among those are dropped.
          """
          action_map: dict[int, Action] = {}
          observation_map: dict[int, Observation] = {}

          # runnable actions are set as cause of observations
          # (MessageAction, NullObservation) for source=USER
          # (MessageAction, NullObservation) for source=AGENT
          # (other_action?, NullObservation)
          # (NullAction, CmdOutputObservation) background CmdOutputObservations

          for event in self.get_events():
              if event.id is None or event.id == -1:
                  logger.debug(f'Event {event} has no ID')

              if isinstance(event, Action):
                  action_map[event.id] = event

              if isinstance(event, Observation):
                  if event.cause is None or event.cause == -1:
                      logger.debug(f'Observation {event} has no cause')

                  if event.cause is None:
                      # runnable actions are set as cause of observations
                      # NullObservations have no cause
                      continue

                  # A later observation with the same cause replaces an earlier one.
                  observation_map[event.cause] = event

          pairs: list[tuple[Action, Observation]] = []

          for action_id, action in action_map.items():
              matched = observation_map.get(action_id)
              # Compare against None explicitly: the question is whether a
              # match exists, not whether the observation object is truthy.
              if matched is not None:
                  # observation with a cause
                  pairs.append((action, matched))
              else:
                  pairs.append((action, NullObservation('')))

          for cause_id, observation in observation_map.items():
              if cause_id in action_map:
                  continue
              if isinstance(observation, NullObservation):
                  continue
              if not isinstance(observation, CmdOutputObservation):
                  logger.debug(f'Observation {observation} has no cause')
              pairs.append((NullAction(), observation))

          return pairs