msg_stack.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. import os
  2. import json
  3. import atexit
  4. import signal
  5. import uuid
  6. from typing import Dict, List
  7. from opendevin.schema.action import ActionType
  8. from opendevin.logging import opendevin_logger as logger
  # Directory holding persisted caches; override with the CACHE_DIR env var.
  CACHE_DIR = os.environ.get("CACHE_DIR", "cache")
  # JSON file in which MessageStack persists per-session message history.
  MSG_CACHE_FILE = os.path.join(CACHE_DIR, "messages.json")
  class Message:
      """A single message belonging to a session's history.

      Attributes:
          id: Unique identifier. A fresh UUID4 string is generated for every
              instance; ``from_dict`` restores a previously stored id.
          role: Who produced the message, e.g. "user" or "assistant".
          payload: Message body. It is persisted with ``json``, so it should
              be JSON-serializable.
      """

      id: str
      role: str  # "user" | "assistant"
      payload: Dict[str, object]

      def __init__(self, role: str, payload: Dict[str, object]):
          # The id is assigned per instance. A class-level default such as
          # ``id = str(uuid.uuid4())`` is evaluated once at class definition
          # and would be shared by every message.
          self.id = str(uuid.uuid4())
          self.role = role
          self.payload = payload

      def to_dict(self) -> Dict[str, object]:
          """Return a plain dict with the keys ``id``, ``role`` and ``payload``."""
          return {"id": self.id, "role": self.role, "payload": self.payload}

      @classmethod
      def from_dict(cls, data: Dict) -> "Message":
          """Rebuild a message from a dict shaped like ``to_dict``'s output.

          The stored ``id`` is restored when present; otherwise the freshly
          generated id is kept.

          Raises:
              KeyError: if ``role`` or ``payload`` is missing from ``data``.
          """
          msg = cls(data["role"], data["payload"])
          if "id" in data:
              msg.id = data["id"]
          return msg
  class MessageStack:
      """Per-session message history, persisted to ``MSG_CACHE_FILE``.

      Messages are grouped by session id (``sid``). The cache file is loaded
      on construction and written back at interpreter exit (via ``atexit``)
      and when SIGINT or SIGTERM is received.
      """

      def __init__(self):
          # Instance-level storage: a class-level dict would be shared by
          # every MessageStack instance.
          self._messages: Dict[str, List[Message]] = {}
          self._load_messages()
          atexit.register(self.close)
          # NOTE: signal.signal may only be called from the main thread;
          # constructing a MessageStack elsewhere raises ValueError.
          signal.signal(signal.SIGINT, self.handle_signal)
          signal.signal(signal.SIGTERM, self.handle_signal)

      def close(self):
          """Persist all messages to the cache file."""
          self._save_messages()

      def handle_signal(self, signum, _):
          """Signal handler: save messages, then exit with status 0."""
          logger.info("Received signal %s, exiting...", signum)
          self.close()
          # The builtin ``exit`` is injected by the ``site`` module and may be
          # absent (e.g. ``python -S``). SystemExit is what ``sys.exit`` raises.
          # The atexit hook will save once more on the way out; that is harmless.
          raise SystemExit(0)

      def add_message(self, sid: str, role: str, message: Dict[str, object]):
          """Append a new ``Message(role, message)`` to session ``sid``."""
          self._messages.setdefault(sid, []).append(Message(role, message))

      def del_messages(self, sid: str):
          """Drop every message of session ``sid``; no-op if it is unknown."""
          self._messages.pop(sid, None)

      def get_messages(self, sid: str) -> List[Dict[str, object]]:
          """Return session ``sid``'s messages as dicts (empty list if unknown)."""
          return [msg.to_dict() for msg in self._messages.get(sid, [])]

      def get_message_total(self, sid: str) -> int:
          """Count session ``sid``'s messages, skipping INIT action payloads."""
          # Ignore assistant init message for now.
          return sum(
              1
              for msg in self._messages.get(sid, [])
              if msg.payload.get("action") != ActionType.INIT
          )

      def _save_messages(self):
          """Write every session's messages to ``MSG_CACHE_FILE`` as JSON.

          The JSON is written to a temporary sibling file that is then swapped
          into place with ``os.replace``. If the process dies mid-write, or a
          payload is not JSON-serializable, the previous cache is left intact
          instead of a truncated file. ``_load_messages`` would discard a
          truncated file, losing all history.
          """
          os.makedirs(CACHE_DIR, exist_ok=True)
          data = {
              sid: [msg.to_dict() for msg in msgs]
              for sid, msgs in self._messages.items()
          }
          tmp_path = MSG_CACHE_FILE + ".tmp"
          try:
              with open(tmp_path, "w", encoding="utf-8") as file:
                  json.dump(data, file)
              os.replace(tmp_path, MSG_CACHE_FILE)
          except BaseException:
              # Remove the partial temp file, then propagate the original error.
              if os.path.exists(tmp_path):
                  os.remove(tmp_path)
              raise

      def _load_messages(self):
          """Populate the stack from ``MSG_CACHE_FILE`` on a best-effort basis.

          A missing cache file is normal (e.g. first run) and is ignored. A
          cache that is not valid JSON, or not a JSON object, is logged and
          ignored so that startup is not blocked.
          """
          # TODO: delete useless messages
          try:
              with open(MSG_CACHE_FILE, "r", encoding="utf-8") as file:
                  data = json.load(file)
          except FileNotFoundError:
              return
          except json.decoder.JSONDecodeError:
              logger.warning("Ignoring corrupt message cache %s", MSG_CACHE_FILE)
              return
          if not isinstance(data, dict):
              logger.warning("Ignoring malformed message cache %s", MSG_CACHE_FILE)
              return
          for sid, msgs in data.items():
              self._messages[sid] = [Message.from_dict(msg) for msg in msgs]
  # Shared, process-wide message store. Constructing it loads the cache file
  # and registers atexit and SIGINT/SIGTERM handlers, so importing this
  # module has those side effects.
  message_stack: MessageStack = MessageStack()