msg_stack.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. import atexit
  2. import json
  3. import os
  4. import uuid
  5. from typing import Dict, List
  6. from opendevin.logger import opendevin_logger as logger
  7. from opendevin.schema.action import ActionType
  8. CACHE_DIR = os.getenv('CACHE_DIR', 'cache')
  9. MSG_CACHE_FILE = os.path.join(CACHE_DIR, 'messages.json')
  10. class Message:
  11. id: str = str(uuid.uuid4())
  12. role: str # "user"| "assistant"
  13. payload: Dict[str, object]
  14. def __init__(self, role: str, payload: Dict[str, object]):
  15. self.role = role
  16. self.payload = payload
  17. def to_dict(self):
  18. return {'id': self.id, 'role': self.role, 'payload': self.payload}
  19. @classmethod
  20. def from_dict(cls, data: Dict):
  21. m = cls(data['role'], data['payload'])
  22. m.id = data['id']
  23. return m
  24. class MessageStack:
  25. _messages: Dict[str, List[Message]] = {}
  26. def __init__(self):
  27. self._load_messages()
  28. atexit.register(self.close)
  29. def close(self):
  30. logger.info('Saving messages...')
  31. self._save_messages()
  32. def add_message(self, sid: str, role: str, message: Dict[str, object]):
  33. if sid not in self._messages:
  34. self._messages[sid] = []
  35. self._messages[sid].append(Message(role, message))
  36. def del_messages(self, sid: str):
  37. if sid not in self._messages:
  38. return
  39. del self._messages[sid]
  40. def get_messages(self, sid: str) -> List[Dict[str, object]]:
  41. if sid not in self._messages:
  42. return []
  43. return [msg.to_dict() for msg in self._messages[sid]]
  44. def get_message_total(self, sid: str) -> int:
  45. if sid not in self._messages:
  46. return 0
  47. cnt = 0
  48. for msg in self._messages[sid]:
  49. # Ignore assistant init message for now.
  50. if 'action' in msg.payload and msg.payload['action'] in [ActionType.INIT, ActionType.CHANGE_TASK_STATE]:
  51. continue
  52. cnt += 1
  53. return cnt
  54. def _save_messages(self):
  55. if not os.path.exists(CACHE_DIR):
  56. os.makedirs(CACHE_DIR)
  57. data = {}
  58. for sid, msgs in self._messages.items():
  59. data[sid] = [msg.to_dict() for msg in msgs]
  60. with open(MSG_CACHE_FILE, 'w+') as file:
  61. json.dump(data, file)
  62. def _load_messages(self):
  63. try:
  64. # TODO: delete useless messages
  65. with open(MSG_CACHE_FILE, 'r') as file:
  66. data = json.load(file)
  67. for sid, msgs in data.items():
  68. self._messages[sid] = [
  69. Message.from_dict(msg) for msg in msgs]
  70. except FileNotFoundError:
  71. pass
  72. except json.decoder.JSONDecodeError:
  73. pass
  74. message_stack = MessageStack()