# msg_stack.py (2.7 KB)
  1. import atexit
  2. import json
  3. import os
  4. import uuid
  5. from opendevin.core.logger import opendevin_logger as logger
  6. from opendevin.core.schema.action import ActionType
  7. CACHE_DIR = os.getenv('CACHE_DIR', 'cache')
  8. MSG_CACHE_FILE = os.path.join(CACHE_DIR, 'messages.json')
  9. class Message:
  10. id: str = str(uuid.uuid4())
  11. role: str # "user"| "assistant"
  12. payload: dict[str, object]
  13. def __init__(self, role: str, payload: dict[str, object]):
  14. self.role = role
  15. self.payload = payload
  16. def to_dict(self):
  17. return {'id': self.id, 'role': self.role, 'payload': self.payload}
  18. @classmethod
  19. def from_dict(cls, data: dict):
  20. m = cls(data['role'], data['payload'])
  21. m.id = data['id']
  22. return m
  23. class MessageStack:
  24. _messages: dict[str, list[Message]] = {}
  25. def __init__(self):
  26. self._load_messages()
  27. atexit.register(self.close)
  28. def close(self):
  29. logger.info('Saving messages...')
  30. self._save_messages()
  31. def add_message(self, sid: str, role: str, message: dict[str, object]):
  32. if sid not in self._messages:
  33. self._messages[sid] = []
  34. self._messages[sid].append(Message(role, message))
  35. def del_messages(self, sid: str):
  36. if sid not in self._messages:
  37. return
  38. del self._messages[sid]
  39. def get_messages(self, sid: str) -> list[dict[str, object]]:
  40. if sid not in self._messages:
  41. return []
  42. return [msg.to_dict() for msg in self._messages[sid]]
  43. def get_message_total(self, sid: str) -> int:
  44. if sid not in self._messages:
  45. return 0
  46. cnt = 0
  47. for msg in self._messages[sid]:
  48. # Ignore assistant init message for now.
  49. if 'action' in msg.payload and msg.payload['action'] in [
  50. ActionType.INIT,
  51. ActionType.CHANGE_AGENT_STATE,
  52. ]:
  53. continue
  54. cnt += 1
  55. return cnt
  56. def _save_messages(self):
  57. if not os.path.exists(CACHE_DIR):
  58. os.makedirs(CACHE_DIR)
  59. data = {}
  60. for sid, msgs in self._messages.items():
  61. data[sid] = [msg.to_dict() for msg in msgs]
  62. with open(MSG_CACHE_FILE, 'w+') as file:
  63. json.dump(data, file)
  64. def _load_messages(self):
  65. try:
  66. # TODO: delete useless messages
  67. with open(MSG_CACHE_FILE, 'r') as file:
  68. data = json.load(file)
  69. for sid, msgs in data.items():
  70. self._messages[sid] = [Message.from_dict(msg) for msg in msgs]
  71. except FileNotFoundError:
  72. pass
  73. except json.decoder.JSONDecodeError:
  74. pass
  75. message_stack = MessageStack()