session.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. import asyncio
  2. import os
  3. from typing import Optional
  4. from fastapi import WebSocketDisconnect
  5. from opendevin import config
  6. from opendevin.action import (
  7. Action,
  8. NullAction,
  9. )
  10. from opendevin.observation import NullObservation
  11. from opendevin.agent import Agent
  12. from opendevin.controller import AgentController
  13. from opendevin.llm.llm import LLM
  14. from opendevin.observation import Observation, UserMessageObservation
  15. DEFAULT_API_KEY = config.get_or_none("LLM_API_KEY")
  16. DEFAULT_BASE_URL = config.get_or_none("LLM_BASE_URL")
  17. DEFAULT_WORKSPACE_DIR = config.get_or_default("WORKSPACE_DIR", os.path.join(os.getcwd(), "workspace"))
  18. LLM_MODEL = config.get_or_default("LLM_MODEL", "gpt-4-0125-preview")
  19. class Session:
  20. def __init__(self, websocket):
  21. self.websocket = websocket
  22. self.controller: Optional[AgentController] = None
  23. self.agent: Optional[Agent] = None
  24. self.agent_task = None
  25. asyncio.create_task(self.create_controller(), name="create controller") # FIXME: starting the docker container synchronously causes a websocket error...
  26. async def send_error(self, message):
  27. await self.send({"error": True, "message": message})
  28. async def send_message(self, message):
  29. await self.send({"message": message})
  30. async def send(self, data):
  31. if self.websocket is None:
  32. return
  33. try:
  34. await self.websocket.send_json(data)
  35. except Exception as e:
  36. print("Error sending data to client", e)
  37. async def start_listening(self):
  38. try:
  39. while True:
  40. try:
  41. data = await self.websocket.receive_json()
  42. except ValueError:
  43. await self.send_error("Invalid JSON")
  44. continue
  45. action = data.get("action", None)
  46. if action is None:
  47. await self.send_error("Invalid event")
  48. continue
  49. if action == "initialize":
  50. await self.create_controller(data)
  51. elif action == "start":
  52. await self.start_task(data)
  53. else:
  54. if self.controller is None:
  55. await self.send_error("No agent started. Please wait a second...")
  56. elif action == "chat":
  57. self.controller.add_history(NullAction(), UserMessageObservation(data["message"]))
  58. else:
  59. # TODO: we only need to implement user message for now
  60. # since even Devin does not support having the user taking other
  61. # actions (e.g., edit files) while the agent is running
  62. raise NotImplementedError
  63. except WebSocketDisconnect as e:
  64. self.websocket = None
  65. if self.agent_task:
  66. self.agent_task.cancel()
  67. print("Client websocket disconnected", e)
  68. async def create_controller(self, start_event=None):
  69. directory = DEFAULT_WORKSPACE_DIR
  70. if start_event and "directory" in start_event["args"]:
  71. directory = start_event["args"]["directory"]
  72. agent_cls = "LangchainsAgent"
  73. if start_event and "agent_cls" in start_event["args"]:
  74. agent_cls = start_event["args"]["agent_cls"]
  75. model = LLM_MODEL
  76. if start_event and "model" in start_event["args"]:
  77. model = start_event["args"]["model"]
  78. api_key = DEFAULT_API_KEY
  79. if start_event and "api_key" in start_event["args"]:
  80. api_key = start_event["args"]["api_key"]
  81. api_base = DEFAULT_BASE_URL
  82. if start_event and "api_base" in start_event["args"]:
  83. api_base = start_event["args"]["api_base"]
  84. if not os.path.exists(directory):
  85. print(f"Workspace directory {directory} does not exist. Creating it...")
  86. os.makedirs(directory)
  87. directory = os.path.relpath(directory, os.getcwd())
  88. llm = LLM(model=model, api_key=api_key, base_url=api_base)
  89. AgentCls = Agent.get_cls(agent_cls)
  90. self.agent = AgentCls(llm)
  91. try:
  92. self.controller = AgentController(self.agent, workdir=directory, callbacks=[self.on_agent_event])
  93. except Exception:
  94. print("Error creating controller.")
  95. await self.send_error("Error creating controller. Please check Docker is running using `docker ps`.")
  96. return
  97. await self.send({"action": "initialize", "message": "Control loop started."})
  98. async def start_task(self, start_event):
  99. if "task" not in start_event["args"]:
  100. await self.send_error("No task specified")
  101. return
  102. await self.send_message("Starting new task...")
  103. task = start_event["args"]["task"]
  104. if self.controller is None:
  105. await self.send_error("No agent started. Please wait a second...")
  106. return
  107. self.agent_task = asyncio.create_task(self.controller.start_loop(task), name="agent loop")
  108. def on_agent_event(self, event: Observation | Action):
  109. if isinstance(event, NullAction):
  110. return
  111. if isinstance(event, NullObservation):
  112. return
  113. event_dict = event.to_dict()
  114. asyncio.create_task(self.send(event_dict), name="send event in callback")