parser.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. from typing import Union
  2. from pydantic import BaseModel, Field
  3. from opendevin.core.logger import opendevin_logger as logger
  4. from opendevin.events.action import (
  5. Action,
  6. ChangeAgentStateAction,
  7. MessageAction,
  8. NullAction,
  9. )
  10. from opendevin.events.event import EventSource
  11. from opendevin.events.observation import (
  12. AgentStateChangedObservation,
  13. NullObservation,
  14. Observation,
  15. )
  16. from opendevin.events.serialization.event import event_to_dict
  17. from opendevin.security.invariant.nodes import Function, Message, ToolCall, ToolOutput
# The node types that can appear in a parsed Invariant trace; every trace list
# in this module is annotated as list[TraceElement].
TraceElement = Union[Message, ToolCall, ToolOutput, Function]
  19. def get_next_id(trace: list[TraceElement]) -> str:
  20. used_ids = [el.id for el in trace if type(el) == ToolCall]
  21. for i in range(1, len(used_ids) + 2):
  22. if str(i) not in used_ids:
  23. return str(i)
  24. return '1'
  25. def get_last_id(
  26. trace: list[TraceElement],
  27. ) -> str | None:
  28. for el in reversed(trace):
  29. if type(el) == ToolCall:
  30. return el.id
  31. return None
  32. def parse_action(trace: list[TraceElement], action: Action) -> list[TraceElement]:
  33. next_id = get_next_id(trace)
  34. inv_trace = [] # type: list[TraceElement]
  35. if type(action) == MessageAction:
  36. if action.source == EventSource.USER:
  37. inv_trace.append(Message(role='user', content=action.content))
  38. else:
  39. inv_trace.append(Message(role='assistant', content=action.content))
  40. elif type(action) in [NullAction, ChangeAgentStateAction]:
  41. pass
  42. elif hasattr(action, 'action') and action.action is not None:
  43. event_dict = event_to_dict(action)
  44. args = event_dict.get('args', {})
  45. thought = args.pop('thought', None)
  46. function = Function(name=action.action, arguments=args)
  47. if thought is not None:
  48. inv_trace.append(Message(role='assistant', content=thought))
  49. inv_trace.append(ToolCall(id=next_id, type='function', function=function))
  50. else:
  51. logger.error(f'Unknown action type: {type(action)}')
  52. return inv_trace
  53. def parse_observation(
  54. trace: list[TraceElement], obs: Observation
  55. ) -> list[TraceElement]:
  56. last_id = get_last_id(trace)
  57. if type(obs) in [NullObservation, AgentStateChangedObservation]:
  58. return []
  59. elif hasattr(obs, 'content') and obs.content is not None:
  60. return [ToolOutput(role='tool', content=obs.content, tool_call_id=last_id)]
  61. else:
  62. logger.error(f'Unknown observation type: {type(obs)}')
  63. return []
  64. def parse_element(
  65. trace: list[TraceElement], element: Action | Observation
  66. ) -> list[TraceElement]:
  67. if isinstance(element, Action):
  68. return parse_action(trace, element)
  69. return parse_observation(trace, element)
  70. def parse_trace(trace: list[tuple[Action, Observation]]):
  71. inv_trace = [] # type: list[TraceElement]
  72. for action, obs in trace:
  73. inv_trace.extend(parse_action(inv_trace, action))
  74. inv_trace.extend(parse_observation(inv_trace, obs))
  75. return inv_trace
  76. class InvariantState(BaseModel):
  77. trace: list[TraceElement] = Field(default_factory=list)
  78. def add_action(self, action: Action):
  79. self.trace.extend(parse_action(self.trace, action))
  80. def add_observation(self, obs: Observation):
  81. self.trace.extend(parse_observation(self.trace, obs))
  82. def concatenate(self, other: 'InvariantState'):
  83. self.trace.extend(other.trace)