msg_stack.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. import atexit
  2. import json
  3. import os
  4. import uuid
  5. from typing import Dict
  6. from opendevin.core.logger import opendevin_logger as logger
  7. from opendevin.core.schema.action import ActionType
  8. CACHE_DIR = os.getenv('CACHE_DIR', 'cache')
  9. MSG_CACHE_FILE = os.path.join(CACHE_DIR, 'messages.json')
  10. class Message:
  11. id: str = str(uuid.uuid4())
  12. role: str # "user"| "assistant"
  13. payload: Dict[str, object]
  14. def __init__(self, role: str, payload: Dict[str, object]):
  15. self.role = role
  16. self.payload = payload
  17. def to_dict(self):
  18. return {'id': self.id, 'role': self.role, 'payload': self.payload}
  19. @classmethod
  20. def from_dict(cls, data: Dict):
  21. m = cls(data['role'], data['payload'])
  22. m.id = data['id']
  23. return m
  24. class MessageStack:
  25. _messages: Dict[str, list[Message]] = {}
  26. def __init__(self):
  27. self._load_messages()
  28. atexit.register(self.close)
  29. def close(self):
  30. logger.info('Saving messages...')
  31. self._save_messages()
  32. def add_message(self, sid: str, role: str, message: Dict[str, object]):
  33. if sid not in self._messages:
  34. self._messages[sid] = []
  35. self._messages[sid].append(Message(role, message))
  36. def del_messages(self, sid: str):
  37. if sid not in self._messages:
  38. return
  39. del self._messages[sid]
  40. def get_messages(self, sid: str) -> list[Dict[str, object]]:
  41. if sid not in self._messages:
  42. return []
  43. return [msg.to_dict() for msg in self._messages[sid]]
  44. def get_message_total(self, sid: str) -> int:
  45. if sid not in self._messages:
  46. return 0
  47. cnt = 0
  48. for msg in self._messages[sid]:
  49. # Ignore assistant init message for now.
  50. if 'action' in msg.payload and msg.payload['action'] in [
  51. ActionType.INIT,
  52. ActionType.CHANGE_AGENT_STATE,
  53. ]:
  54. continue
  55. cnt += 1
  56. return cnt
  57. def _save_messages(self):
  58. if not os.path.exists(CACHE_DIR):
  59. os.makedirs(CACHE_DIR)
  60. data = {}
  61. for sid, msgs in self._messages.items():
  62. data[sid] = [msg.to_dict() for msg in msgs]
  63. with open(MSG_CACHE_FILE, 'w+') as file:
  64. json.dump(data, file)
  65. def _load_messages(self):
  66. try:
  67. # TODO: delete useless messages
  68. with open(MSG_CACHE_FILE, 'r') as file:
  69. data = json.load(file)
  70. for sid, msgs in data.items():
  71. self._messages[sid] = [Message.from_dict(msg) for msg in msgs]
  72. except FileNotFoundError:
  73. pass
  74. except json.decoder.JSONDecodeError:
  75. pass
  76. message_stack = MessageStack()