session.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. import os
  2. import asyncio
  3. from typing import Optional, Dict, Type
  4. from fastapi import WebSocketDisconnect
  5. from opendevin.action import (
  6. Action,
  7. NullAction,
  8. CmdRunAction,
  9. CmdKillAction,
  10. BrowseURLAction,
  11. FileReadAction,
  12. FileWriteAction,
  13. AgentRecallAction,
  14. AgentThinkAction,
  15. AgentFinishAction,
  16. )
  17. from opendevin.agent import Agent
  18. from opendevin.controller import AgentController
  19. from opendevin.llm.llm import LLM
  20. from opendevin.observation import (
  21. Observation,
  22. UserMessageObservation
  23. )
  # Short action names used on the websocket protocol, mapped to Action classes.
  # NOTE: temporary solution - ideally Action/Observation objects would be used
  # directly throughout the codebase.
  ACTION_TYPE_TO_CLASS: Dict[str, Type[Action]] = dict(
      run=CmdRunAction,
      kill=CmdKillAction,
      browse=BrowseURLAction,
      read=FileReadAction,
      write=FileWriteAction,
      recall=AgentRecallAction,
      think=AgentThinkAction,
      finish=AgentFinishAction,
  )

  # Workspace root used when the client does not supply one ($WORKSPACE_DIR overrides).
  DEFAULT_WORKSPACE_DIR = os.environ.get(
      "WORKSPACE_DIR", os.path.join(os.getcwd(), "workspace")
  )
  # Model name used when the client does not supply one ($LLM_MODEL overrides).
  LLM_MODEL = os.environ.get("LLM_MODEL", "gpt-4-0125-preview")
  def parse_event(data: object) -> Optional[dict]:
      """Validate a decoded websocket message and normalise it into an event.

      Args:
          data: The value returned by ``receive_json``; a client may send any
              JSON value (object, array, string, number, null).

      Returns:
          ``None`` when ``data`` is not a JSON object or lacks an ``"action"``
          key. Otherwise a dict with keys ``action``, ``args`` (always a dict;
          ``{}`` when absent, null, or not an object) and ``message``
          (``None`` when absent).
      """
      # Membership tests on a str/list would "succeed" (substring / element
      # match) and then crash on indexing, so only JSON objects are accepted.
      if not isinstance(data, dict) or "action" not in data:
          return None
      args = data.get("args")
      return {
          "action": data["action"],
          # Callers do `"task" in args` / `args[...]`, so guarantee a dict here.
          "args": args if isinstance(args, dict) else {},
          "message": data.get("message"),
      }
  class Session:
      """A single websocket client bound to one agent controller.

      Incoming JSON messages are parsed with ``parse_event`` and dispatched
      (``initialize`` / ``start`` / ``chat``); events produced by the
      controller are relayed back to the client through ``on_agent_event``.
      """

      # Class names found in ``event.to_dict()["action"]`` -> protocol names.
      _ACTION_ALIASES: dict = {
          "CmdRunAction": "run",
          "CmdKillAction": "kill",
          "BrowseURLAction": "browse",
          "FileReadAction": "read",
          "FileWriteAction": "write",
          "AgentFinishAction": "finish",
          "AgentRecallAction": "recall",
          "AgentThinkAction": "think",
      }
      # Class names found in ``event.to_dict()["observation"]`` -> protocol names.
      _OBSERVATION_ALIASES: dict = {
          "UserMessageObservation": "chat",
          "AgentMessageObservation": "chat",
          "CmdOutputObservation": "run",
          "FileReadObservation": "read",
      }

      def __init__(self, websocket):
          self.websocket = websocket
          self.controller: Optional[AgentController] = None
          self.agent: Optional[Agent] = None
          self.agent_task = None
          # The event loop keeps only weak references to tasks; hold strong ones
          # so fire-and-forget tasks are not garbage-collected mid-execution.
          self._background_tasks: set = set()
          # FIXME: starting the docker container synchronously causes a websocket error...
          self._spawn(self.create_controller(), "create controller")

      def _spawn(self, coro, name: str) -> asyncio.Task:
          """Schedule *coro* as a task and keep a reference until it completes."""
          task = asyncio.create_task(coro, name=name)
          self._background_tasks.add(task)
          task.add_done_callback(self._background_tasks.discard)
          return task

      @staticmethod
      def _event_args(event) -> dict:
          """Return the ``args`` dict of *event*, or ``{}`` when it has none.

          Accepts both the dicts produced by ``parse_event`` and objects that
          expose an ``args`` attribute.
          """
          if event is None:
              return {}
          if isinstance(event, dict):
              args = event.get("args")
          else:
              args = getattr(event, "args", None)
          return args if isinstance(args, dict) else {}

      async def send_error(self, message):
          """Send an error payload to the client."""
          await self.send({"error": True, "message": message})

      async def send_message(self, message):
          """Send an informational message to the client."""
          await self.send({"message": message})

      async def send(self, data):
          """Send *data* as JSON; a no-op once the client has disconnected.

          Send failures are printed rather than raised.
          """
          if self.websocket is None:
              return
          try:
              await self.websocket.send_json(data)
          except Exception as e:
              print("Error sending data to client", e)

      async def start_listening(self):
          """Receive and dispatch client messages until the websocket disconnects."""
          try:
              while True:
                  try:
                      data = await self.websocket.receive_json()
                  except ValueError:
                      await self.send_error("Invalid JSON")
                      continue
                  event = parse_event(data)
                  if event is None:
                      await self.send_error("Invalid event")
                      continue
                  action = event["action"]
                  if action == "initialize":
                      await self.create_controller(event)
                  elif action == "start":
                      await self.start_task(event)
                  elif self.controller is None:
                      await self.send_error("No agent started. Please wait a second...")
                  elif action == "chat":
                      self.controller.add_history(NullAction(), UserMessageObservation(event["message"]))
                  else:
                      # Only user chat messages are supported while the agent runs
                      # (even Devin does not let the user act directly mid-run).
                      # Report instead of raising: an exception here would end this
                      # listener without cancelling the running agent task.
                      await self.send_error(f"Unsupported action: {action}")
          except WebSocketDisconnect as e:
              self.websocket = None
              if self.agent_task:
                  self.agent_task.cancel()
              print("Client websocket disconnected", e)

      async def create_controller(self, start_event=None):
          """Create the agent and its controller, then notify the client.

          Args:
              start_event: Optional ``initialize`` event whose ``args`` may
                  override ``directory``, ``agent_cls`` and ``model``.
          """
          args = self._event_args(start_event)
          directory = args.get("directory", DEFAULT_WORKSPACE_DIR)
          agent_cls = args.get("agent_cls", "LangchainsAgent")
          model = args.get("model", LLM_MODEL)
          if not os.path.exists(directory):
              print(f"Workspace directory {directory} does not exist. Creating it...")
              # exist_ok: tolerate the directory appearing between the check and here.
              os.makedirs(directory, exist_ok=True)
          directory = os.path.relpath(directory, os.getcwd())
          llm = LLM(model)
          AgentCls = Agent.get_cls(agent_cls)
          self.agent = AgentCls(llm)
          self.controller = AgentController(self.agent, workdir=directory, callbacks=[self.on_agent_event])
          await self.send({"action": "initialize", "message": "Control loop started."})

      async def start_task(self, start_event):
          """Start the controller loop for ``start_event["args"]["task"]``.

          Both preconditions are checked before the client is told the task is
          starting, so a rejected start never reports "Starting new task...".
          """
          args = self._event_args(start_event)
          if "task" not in args:
              await self.send_error("No task specified")
              return
          if self.controller is None:
              await self.send_error("No agent started. Please wait a second...")
              return
          await self.send_message("Starting new task...")
          self.agent_task = self._spawn(self.controller.start_loop(args["task"]), "agent loop")

      def on_agent_event(self, event: "Observation | Action"):
          """Controller callback: rename the event to protocol names and send it.

          Called synchronously, so the send is scheduled as a background task.
          """
          # FIXME: we need better serialization
          event_dict = event.to_dict()
          if "action" in event_dict:
              name = event_dict["action"]
              event_dict["action"] = self._ACTION_ALIASES.get(name, name)
          if "observation" in event_dict:
              name = event_dict["observation"]
              event_dict["observation"] = self._OBSERVATION_ALIASES.get(name, name)
          self._spawn(self.send(event_dict), "send event in callback")