# __init__.py (1.7 KB)
  1. from .base import Observation, NullObservation
  2. from .run import CmdOutputObservation
  3. from .browse import BrowserOutputObservation
  4. from .files import FileReadObservation, FileWriteObservation
  5. from .message import UserMessageObservation, AgentMessageObservation
  6. from .recall import AgentRecallObservation
  7. from .error import AgentErrorObservation
  8. observations = (
  9. CmdOutputObservation,
  10. BrowserOutputObservation,
  11. FileReadObservation,
  12. FileWriteObservation,
  13. UserMessageObservation,
  14. AgentMessageObservation,
  15. AgentRecallObservation,
  16. AgentErrorObservation,
  17. )
  18. OBSERVATION_TYPE_TO_CLASS = {observation_class.observation:observation_class for observation_class in observations} # type: ignore[attr-defined]
  19. def observation_from_dict(observation: dict) -> Observation:
  20. observation = observation.copy()
  21. if "observation" not in observation:
  22. raise KeyError(f"'observation' key is not found in {observation=}")
  23. observation_class = OBSERVATION_TYPE_TO_CLASS.get(observation["observation"])
  24. if observation_class is None:
  25. raise KeyError(f"'{observation['observation']=}' is not defined. Available observations: {OBSERVATION_TYPE_TO_CLASS.keys()}")
  26. observation.pop("observation")
  27. observation.pop("message", None)
  28. content = observation.pop("content", "")
  29. extras = observation.pop("extras", {})
  30. return observation_class(content=content, **extras)
  31. __all__ = [
  32. "Observation",
  33. "NullObservation",
  34. "CmdOutputObservation",
  35. "BrowserOutputObservation",
  36. "FileReadObservation",
  37. "FileWriteObservation",
  38. "UserMessageObservation",
  39. "AgentMessageObservation",
  40. "AgentRecallObservation",
  41. "AgentErrorObservation",
  42. ]