session.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. import asyncio
  2. import os
  3. from typing import Optional
  4. from fastapi import WebSocketDisconnect
  5. from opendevin import config
  6. from opendevin.action import (
  7. Action,
  8. NullAction,
  9. )
  10. from opendevin.observation import NullObservation
  11. from opendevin.agent import Agent
  12. from opendevin.controller import AgentController
  13. from opendevin.llm.llm import LLM
  14. from opendevin.observation import Observation, UserMessageObservation
  15. DEFAULT_WORKSPACE_DIR = config.get_or_default("WORKSPACE_DIR", os.path.join(os.getcwd(), "workspace"))
  16. LLM_MODEL = config.get_or_default("LLM_MODEL", "gpt-4-0125-preview")
  17. class Session:
  18. def __init__(self, websocket):
  19. self.websocket = websocket
  20. self.controller: Optional[AgentController] = None
  21. self.agent: Optional[Agent] = None
  22. self.agent_task = None
  23. asyncio.create_task(self.create_controller(), name="create controller") # FIXME: starting the docker container synchronously causes a websocket error...
  24. async def send_error(self, message):
  25. await self.send({"error": True, "message": message})
  26. async def send_message(self, message):
  27. await self.send({"message": message})
  28. async def send(self, data):
  29. if self.websocket is None:
  30. return
  31. try:
  32. await self.websocket.send_json(data)
  33. except Exception as e:
  34. print("Error sending data to client", e)
  35. async def start_listening(self):
  36. try:
  37. while True:
  38. try:
  39. data = await self.websocket.receive_json()
  40. except ValueError:
  41. await self.send_error("Invalid JSON")
  42. continue
  43. action = data.get("action", None)
  44. if action is None:
  45. await self.send_error("Invalid event")
  46. continue
  47. if action == "initialize":
  48. await self.create_controller(data)
  49. elif action == "start":
  50. await self.start_task(data)
  51. else:
  52. if self.controller is None:
  53. await self.send_error("No agent started. Please wait a second...")
  54. elif action == "chat":
  55. self.controller.add_history(NullAction(), UserMessageObservation(data["message"]))
  56. else:
  57. # TODO: we only need to implement user message for now
  58. # since even Devin does not support having the user taking other
  59. # actions (e.g., edit files) while the agent is running
  60. raise NotImplementedError
  61. except WebSocketDisconnect as e:
  62. self.websocket = None
  63. if self.agent_task:
  64. self.agent_task.cancel()
  65. print("Client websocket disconnected", e)
  66. async def create_controller(self, start_event=None):
  67. directory = DEFAULT_WORKSPACE_DIR
  68. if start_event and "directory" in start_event["args"]:
  69. directory = start_event["args"]["directory"]
  70. agent_cls = "LangchainsAgent"
  71. if start_event and "agent_cls" in start_event["args"]:
  72. agent_cls = start_event["args"]["agent_cls"]
  73. model = LLM_MODEL
  74. if start_event and "model" in start_event["args"]:
  75. model = start_event["args"]["model"]
  76. if not os.path.exists(directory):
  77. print(f"Workspace directory {directory} does not exist. Creating it...")
  78. os.makedirs(directory)
  79. directory = os.path.relpath(directory, os.getcwd())
  80. llm = LLM(model)
  81. AgentCls = Agent.get_cls(agent_cls)
  82. self.agent = AgentCls(llm)
  83. try:
  84. self.controller = AgentController(self.agent, workdir=directory, callbacks=[self.on_agent_event])
  85. except Exception:
  86. print("Error creating controller.")
  87. await self.send_error("Error creating controller. Please check Docker is running using `docker ps`.")
  88. return
  89. await self.send({"action": "initialize", "message": "Control loop started."})
  90. async def start_task(self, start_event):
  91. if "task" not in start_event["args"]:
  92. await self.send_error("No task specified")
  93. return
  94. await self.send_message("Starting new task...")
  95. task = start_event["args"]["task"]
  96. if self.controller is None:
  97. await self.send_error("No agent started. Please wait a second...")
  98. return
  99. self.agent_task = asyncio.create_task(self.controller.start_loop(task), name="agent loop")
  100. def on_agent_event(self, event: Observation | Action):
  101. if isinstance(event, NullAction):
  102. return
  103. if isinstance(event, NullObservation):
  104. return
  105. event_dict = event.to_dict()
  106. asyncio.create_task(self.send(event_dict), name="send event in callback")