session.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. import asyncio
  2. import os
  3. from typing import Optional
  4. from fastapi import WebSocketDisconnect
  5. from opendevin import config
  6. from opendevin.action import (
  7. Action,
  8. NullAction,
  9. )
  10. from opendevin.observation import NullObservation
  11. from opendevin.agent import Agent
  12. from opendevin.controller import AgentController
  13. from opendevin.llm.llm import LLM
  14. from opendevin.observation import Observation, UserMessageObservation
  15. DEFAULT_API_KEY = config.get("LLM_API_KEY")
  16. DEFAULT_BASE_URL = config.get("LLM_BASE_URL")
  17. DEFAULT_WORKSPACE_DIR = config.get("WORKSPACE_DIR")
  18. LLM_MODEL = config.get("LLM_MODEL")
  19. CONTAINER_IMAGE = config.get("SANDBOX_CONTAINER_IMAGE")
  20. MAX_ITERATIONS = config.get("MAX_ITERATIONS")
  21. class Session:
  22. """Represents a session with an agent.
  23. Attributes:
  24. websocket: The WebSocket connection associated with the session.
  25. controller: The AgentController instance for controlling the agent.
  26. agent: The Agent instance representing the agent.
  27. agent_task: The task representing the agent's execution.
  28. """
  29. def __init__(self, websocket):
  30. """Initializes a new instance of the Session class.
  31. Args:
  32. websocket: The WebSocket connection associated with the session.
  33. """
  34. self.websocket = websocket
  35. self.controller: Optional[AgentController] = None
  36. self.agent: Optional[Agent] = None
  37. self.agent_task = None
  38. async def send_error(self, message):
  39. """Sends an error message to the client.
  40. Args:
  41. message: The error message to send.
  42. """
  43. await self.send({"error": True, "message": message})
  44. async def send_message(self, message):
  45. """Sends a message to the client.
  46. Args:
  47. message: The message to send.
  48. """
  49. await self.send({"message": message})
  50. async def send(self, data):
  51. """Sends data to the client.
  52. Args:
  53. data: The data to send.
  54. """
  55. if self.websocket is None:
  56. return
  57. try:
  58. await self.websocket.send_json(data)
  59. except Exception as e:
  60. print("Error sending data to client", e)
  61. async def start_listening(self):
  62. """Starts listening for messages from the client."""
  63. try:
  64. while True:
  65. try:
  66. data = await self.websocket.receive_json()
  67. except ValueError:
  68. await self.send_error("Invalid JSON")
  69. continue
  70. action = data.get("action", None)
  71. if action is None:
  72. await self.send_error("Invalid event")
  73. continue
  74. if action == "initialize":
  75. await self.create_controller(data)
  76. elif action == "start":
  77. await self.start_task(data)
  78. else:
  79. if self.controller is None:
  80. await self.send_error("No agent started. Please wait a second...")
  81. elif action == "chat":
  82. self.controller.add_history(NullAction(), UserMessageObservation(data["message"]))
  83. else:
  84. await self.send_error("I didn't recognize this action:" + action)
  85. except WebSocketDisconnect as e:
  86. print("Client websocket disconnected", e)
  87. self.disconnect()
  88. async def create_controller(self, start_event=None):
  89. """Creates an AgentController instance.
  90. Args:
  91. start_event: The start event data (optional).
  92. """
  93. directory = DEFAULT_WORKSPACE_DIR
  94. if start_event and "directory" in start_event["args"]:
  95. directory = start_event["args"]["directory"]
  96. agent_cls = "MonologueAgent"
  97. if start_event and "agent_cls" in start_event["args"]:
  98. agent_cls = start_event["args"]["agent_cls"]
  99. model = LLM_MODEL
  100. if start_event and "model" in start_event["args"]:
  101. model = start_event["args"]["model"]
  102. api_key = DEFAULT_API_KEY
  103. if start_event and "api_key" in start_event["args"]:
  104. api_key = start_event["args"]["api_key"]
  105. api_base = DEFAULT_BASE_URL
  106. if start_event and "api_base" in start_event["args"]:
  107. api_base = start_event["args"]["api_base"]
  108. container_image = CONTAINER_IMAGE
  109. if start_event and "container_image" in start_event["args"]:
  110. container_image = start_event["args"]["container_image"]
  111. max_iterations = MAX_ITERATIONS
  112. if start_event and "max_iterations" in start_event["args"]:
  113. max_iterations = start_event["args"]["max_iterations"]
  114. if not os.path.exists(directory):
  115. print(f"Workspace directory {directory} does not exist. Creating it...")
  116. os.makedirs(directory)
  117. directory = os.path.relpath(directory, os.getcwd())
  118. llm = LLM(model=model, api_key=api_key, base_url=api_base)
  119. AgentCls = Agent.get_cls(agent_cls)
  120. self.agent = AgentCls(llm)
  121. try:
  122. self.controller = AgentController(self.agent, workdir=directory, max_iterations=max_iterations, container_image=container_image, callbacks=[self.on_agent_event])
  123. except Exception:
  124. print("Error creating controller.")
  125. await self.send_error("Error creating controller. Please check Docker is running using `docker ps`.")
  126. return
  127. await self.send({"action": "initialize", "message": "Control loop started."})
  128. async def start_task(self, start_event):
  129. """Starts a task for the agent.
  130. Args:
  131. start_event: The start event data.
  132. """
  133. if "task" not in start_event["args"]:
  134. await self.send_error("No task specified")
  135. return
  136. await self.send_message("Starting new task...")
  137. task = start_event["args"]["task"]
  138. if self.controller is None:
  139. await self.send_error("No agent started. Please wait a second...")
  140. return
  141. try:
  142. self.agent_task = await asyncio.create_task(self.controller.start_loop(task), name="agent loop")
  143. except Exception:
  144. await self.send_error("Error during task loop.")
  145. def on_agent_event(self, event: Observation | Action):
  146. """Callback function for agent events.
  147. Args:
  148. event: The agent event (Observation or Action).
  149. """
  150. if isinstance(event, NullAction):
  151. return
  152. if isinstance(event, NullObservation):
  153. return
  154. event_dict = event.to_dict()
  155. asyncio.create_task(self.send(event_dict), name="send event in callback")
  156. def disconnect(self):
  157. self.websocket = None
  158. if self.agent_task:
  159. self.agent_task.cancel()
  160. if self.controller is not None:
  161. self.controller.command_manager.shell.close()