msg_stack.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. import asyncio
  2. import atexit
  3. import json
  4. import os
  5. import uuid
  6. from opendevin.core.logger import opendevin_logger as logger
  7. from opendevin.core.schema.action import ActionType
  8. CACHE_DIR = os.getenv('CACHE_DIR', 'cache')
  9. MSG_CACHE_FILE = os.path.join(CACHE_DIR, 'messages.json')
  10. class Message:
  11. id: str = str(uuid.uuid4())
  12. role: str # "user"| "assistant"
  13. payload: dict[str, object]
  14. def __init__(self, role: str, payload: dict[str, object]):
  15. self.role = role
  16. self.payload = payload
  17. def to_dict(self):
  18. return {'id': self.id, 'role': self.role, 'payload': self.payload}
  19. @classmethod
  20. def from_dict(cls, data: dict):
  21. m = cls(data['role'], data['payload'])
  22. m.id = data['id']
  23. return m
  24. class MessageStack:
  25. _messages: dict[str, list[Message]] = {}
  26. def __init__(self):
  27. self._load_messages()
  28. atexit.register(self.close)
  29. def close(self):
  30. logger.info('Saving messages...')
  31. self._save_messages()
  32. def add_message(self, sid: str, role: str, message: dict[str, object]):
  33. if sid not in self._messages:
  34. self._messages[sid] = []
  35. self._messages[sid].append(Message(role, message))
  36. def del_messages(self, sid: str):
  37. if sid not in self._messages:
  38. return
  39. del self._messages[sid]
  40. asyncio.create_task(self._del_messages(sid))
  41. def get_messages(self, sid: str) -> list[dict[str, object]]:
  42. if sid not in self._messages:
  43. return []
  44. return [msg.to_dict() for msg in self._messages[sid]]
  45. def get_message_total(self, sid: str) -> int:
  46. if sid not in self._messages:
  47. return 0
  48. cnt = 0
  49. for msg in self._messages[sid]:
  50. # Ignore assistant init message for now.
  51. if 'action' in msg.payload and msg.payload['action'] in [
  52. ActionType.INIT,
  53. ActionType.CHANGE_AGENT_STATE,
  54. ]:
  55. continue
  56. cnt += 1
  57. return cnt
  58. def _save_messages(self):
  59. if not os.path.exists(CACHE_DIR):
  60. os.makedirs(CACHE_DIR)
  61. data = {}
  62. for sid, msgs in self._messages.items():
  63. data[sid] = [msg.to_dict() for msg in msgs]
  64. with open(MSG_CACHE_FILE, 'w+') as file:
  65. json.dump(data, file)
  66. def _load_messages(self):
  67. try:
  68. with open(MSG_CACHE_FILE, 'r') as file:
  69. data = json.load(file)
  70. for sid, msgs in data.items():
  71. self._messages[sid] = [Message.from_dict(msg) for msg in msgs]
  72. except FileNotFoundError:
  73. pass
  74. except json.decoder.JSONDecodeError:
  75. pass
  76. async def _del_messages(self, del_sid: str):
  77. logger.info('Deleting messages...')
  78. try:
  79. with open(MSG_CACHE_FILE, 'r+') as file:
  80. data = json.load(file)
  81. new_data = {}
  82. for sid, msgs in data.items():
  83. if sid != del_sid:
  84. new_data[sid] = msgs
  85. # Move the file pointer to the beginning of the file to overwrite the original contents
  86. file.seek(0)
  87. # clean previous content
  88. file.truncate()
  89. json.dump(new_data, file)
  90. except FileNotFoundError:
  91. pass
  92. except json.decoder.JSONDecodeError:
  93. pass
  94. message_stack = MessageStack()