server.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  import asyncio
  import json
  import os
  from time import sleep

  import docker
  import websockets
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
  from starlette.websockets import WebSocketState
  app = FastAPI()

  # Name given to the agent container; used both to create it and to look it up.
  CONTAINER_NAME = "devin-agent"
  # TCP port published from inside the container.
  AGENT_LISTEN_PORT = 8080
  # Host port that AGENT_LISTEN_PORT is mapped to, and that the server connects
  # to. Environment values are strings, so coerce to int to keep a single type
  # regardless of whether AGENT_PORT is set.
  AGENT_BIND_PORT = int(os.environ.get("AGENT_PORT", 4522))
  # Upper bound (in one-second polls) when waiting for the container to change state.
  MAX_WAIT_TIME_SECONDS = 30

  # Intended to hold the background agent listener; None when none is running.
  agent_listener = None
  # Websocket of the connected browser client, or None.
  client_fast_websocket = None
  # Websocket connection to the agent, or None.
  agent_websocket = None
  def get_message_payload(message):
      """Wrap ``message`` in the envelope used for server-originated messages.

      Returns a new dict with ``"source"`` set to ``"server"`` and ``"message"``
      set to the given value (which is stored as-is, not copied).
      """
      envelope = dict(source="server", message=message)
      return envelope
  def get_error_payload(message):
      """Build a server message envelope for ``message`` flagged as an error.

      The result is the dict from ``get_message_payload`` with an extra
      ``"error": True`` entry.
      """
      return {**get_message_payload(message), "error": True}
  # This endpoint receives events from the client (i.e. the browser)
  @app.websocket("/ws")
  async def websocket_endpoint(websocket: WebSocket):
      """Serve the browser client's websocket connection.

      Accepts the connection, records it in ``client_fast_websocket`` and then
      handles JSON messages until the client disconnects. Each message must be
      an object with an ``"action"`` key:

      * ``"start"``: (re)start the agent container, mounting ``"directory"``
        (default: the current working directory), then start the background
        task that connects to the agent.
      * ``"terminal"``: send the message's ``"data"`` back to the client wrapped
        in a server payload.
      * anything else: forward the message to the agent, or report
        "Agent not connected" if there is no agent connection.

      On ``WebSocketDisconnect`` all websockets are closed.
      """
      global client_fast_websocket
      global agent_listener

      await websocket.accept()
      client_fast_websocket = websocket
      try:
          while True:
              data = await websocket.receive_json()
              # Non-object JSON (list, number, ...) cannot carry an action.
              if not isinstance(data, dict) or "action" not in data:
                  await send_message_to_client(get_error_payload("No action specified"))
                  continue

              action = data["action"]
              if action == "start":
                  await send_message_to_client(get_message_payload("Starting new agent..."))
                  directory = data.get("directory", os.getcwd())
                  try:
                      await restart_docker_container(directory)
                  except Exception as e:
                      print("error while restarting docker container:", e)
                      await send_message_to_client(
                          get_error_payload("Failed to start container: " + str(e))
                      )
                      continue
                  # Drop any listener left over from a previous container, then
                  # connect to the new one. The task is stored in a global so
                  # that a reference to it is kept while it runs.
                  if agent_listener is not None and not agent_listener.done():
                      agent_listener.cancel()
                  agent_listener = asyncio.create_task(listen_for_agent_messages())
              elif action == "terminal":
                  if "data" not in data:
                      await send_message_to_client(
                          get_error_payload("No data specified for terminal action")
                      )
                      continue
                  msg = {
                      "action": "terminal",
                      "data": data["data"],
                  }
                  await send_message_to_client(get_message_payload(msg))
              elif agent_websocket is None:
                  await send_message_to_client(get_error_payload("Agent not connected"))
              else:
                  await send_message_to_agent(data)
      except WebSocketDisconnect:
          print("Client websocket disconnected")
          await close_all_websockets(get_error_payload("Client disconnected"))
  async def stop_docker_container():
      """Stop and remove the container named ``CONTAINER_NAME``, if it exists.

      After requesting the stop, polls once per second (up to
      ``MAX_WAIT_TIME_SECONDS`` polls) until the container reports ``"exited"``,
      then removes it. Does nothing if the container is not found.
      """
      docker_client = docker.from_env()
      try:
          container = docker_client.containers.get(CONTAINER_NAME)
          container.stop()
          # Wait for the exit *before* removing: once the container is removed
          # there is nothing left to poll.
          elapsed = 0
          container.reload()  # refresh the cached status
          while container.status != "exited":
              print("waiting for container to stop...")
              # Non-blocking sleep so the event loop keeps serving websockets.
              await asyncio.sleep(1)
              elapsed += 1
              if elapsed > MAX_WAIT_TIME_SECONDS:
                  break
              container.reload()
          # force=True so a container that did not exit in time is still removed.
          container.remove(force=True)
      except docker.errors.NotFound:
          # Already gone, which is the desired end state.
          pass
  async def restart_docker_container(directory):
      """Replace the agent container with a fresh one and wait for it to run.

      Stops any existing container first, then starts a detached
      ``jmalloc/echo-server`` container named ``CONTAINER_NAME`` with
      ``AGENT_LISTEN_PORT`` published on host port ``AGENT_BIND_PORT`` and
      ``directory`` mounted read-write at ``/workspace``.

      Args:
          directory: Host path to mount at ``/workspace`` in the container.

      Raises:
          RuntimeError: If the container is not ``"running"`` after polling
              once per second for up to ``MAX_WAIT_TIME_SECONDS`` polls, or if
              it exited.
      """
      await stop_docker_container()
      docker_client = docker.from_env()
      container = docker_client.containers.run(
          "jmalloc/echo-server",
          name=CONTAINER_NAME,
          detach=True,
          ports={str(AGENT_LISTEN_PORT) + "/tcp": AGENT_BIND_PORT},
          volumes={directory: {"bind": "/workspace", "mode": "rw"}},
      )

      # wait for container to be ready
      elapsed = 0
      while container.status != "running":
          if container.status == "exited":
              print("container exited")
              print("container logs:")
              print(container.logs())
              break
          print("waiting for container to start...")
          # Non-blocking sleep so the event loop keeps serving websockets.
          await asyncio.sleep(1)
          elapsed += 1
          container.reload()  # refresh the cached status
          if elapsed > MAX_WAIT_TIME_SECONDS:
              break
      if container.status != "running":
          raise RuntimeError("Failed to start container")
  async def listen_for_agent_messages():
      """Connect to the agent and relay its messages to the client.

      Opens a websocket to ``ws://localhost:<AGENT_BIND_PORT>``, stores it in
      ``agent_websocket``, greets both sides, then reads messages until the
      connection ends. Messages that are JSON objects with ``"source"`` equal to
      ``"agent"`` are forwarded to the client; anything else is logged and
      skipped. If no client is connected, all websockets are closed and the
      loop stops.

      Any other failure is reported to the client and the agent connection is
      closed via ``close_agent_websocket``.
      """
      global agent_websocket
      try:
          async with websockets.connect("ws://localhost:" + str(AGENT_BIND_PORT)) as ws:
              agent_websocket = ws
              await send_message_to_client(get_message_payload("Agent connected!"))
              await send_message_to_agent({"source": "server", "message": "Hello, agent!"})
              try:
                  async for message in ws:
                      if client_fast_websocket is None:
                          print("Client websocket not connected")
                          await close_all_websockets(get_error_payload("Client not connected"))
                          break
                      try:
                          data = json.loads(message)
                      except ValueError as e:
                          print("error parsing message from agent:", message)
                          print(e)
                          continue
                      if not isinstance(data, dict) or data.get("source") != "agent":
                          # TODO: remove this once we're not using echo server
                          print("echo server responded", data)
                          continue
                      # Relay to the client. Sending this back to the agent
                      # would bounce forever against an echo server.
                      await send_message_to_client(data)
              except websockets.exceptions.ConnectionClosed:
                  await send_message_to_client(get_error_payload("Agent disconnected"))
      except Exception as e:
          print("error connecting to agent:", e)
          payload = get_error_payload("Failed to connect to agent: " + str(e))
          await send_message_to_client(payload)
          await close_agent_websocket(payload)
  async def send_message_to_client(data):
      """Log ``data`` and send it as JSON to the client websocket, if any.

      When ``client_fast_websocket`` is None the message is only logged.
      """
      print("to client:", data)
      if client_fast_websocket is not None:
          await client_fast_websocket.send_json(data)
  async def send_message_to_agent(data):
      """Log ``data`` and send it JSON-encoded to the agent websocket, if any.

      When ``agent_websocket`` is None the message is only logged.
      """
      print("to agent:", data)
      if agent_websocket is not None:
          encoded = json.dumps(data)
          await agent_websocket.send(encoded)
  async def close_agent_websocket(payload):
      """Close the agent connection and stop the agent container.

      If an agent websocket is set and not yet closed, ``payload`` is sent to
      it before it is closed. ``agent_websocket`` is then reset to None and the
      container is stopped.
      """
      global agent_websocket
      if agent_websocket is not None and not agent_websocket.closed:
          await send_message_to_agent(payload)
          await agent_websocket.close()
      # Resetting an already-None reference is a no-op, so this is unconditional.
      agent_websocket = None
      await stop_docker_container()
  async def close_client_websocket(payload):
      """Close the client connection, sending ``payload`` first if still connected.

      ``client_fast_websocket`` is reset to None afterwards.
      """
      global client_fast_websocket
      if (
          client_fast_websocket is not None
          and client_fast_websocket.client_state != WebSocketState.DISCONNECTED
      ):
          await send_message_to_client(payload)
          await client_fast_websocket.close()
      # Resetting an already-None reference is a no-op, so this is unconditional.
      client_fast_websocket = None
  async def close_all_websockets(payload):
      """Close the agent websocket, then the client websocket, passing ``payload`` to each."""
      for close_socket in (close_agent_websocket, close_client_websocket):
          await close_socket(payload)