server.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. import asyncio
  2. import json
  3. import os
  4. from time import sleep
  5. import docker
  6. import websockets
  7. from fastapi import FastAPI, WebSocket, WebSocketDisconnect
  8. from starlette.websockets import WebSocketState
  9. app = FastAPI()
  10. CONTAINER_NAME = "devin-agent"
  11. AGENT_LISTEN_PORT = 8080
  12. AGENT_BIND_PORT = os.environ.get("AGENT_PORT", 4522)
  13. MAX_WAIT_TIME_SECONDS = 30
  14. agent_listener = None
  15. client_fast_websocket = None
  16. agent_websocket = None
  17. def get_message_payload(message):
  18. return {"source": "server", "message": message}
  19. def get_error_payload(message):
  20. payload = get_message_payload(message)
  21. payload["error"] = True
  22. return payload
  23. # This endpoint recieves events from the client (i.e. the browser)
  24. @app.websocket("/ws")
  25. async def websocket_endpoint(websocket: WebSocket):
  26. global client_fast_websocket
  27. global agent_websocket
  28. await websocket.accept()
  29. client_fast_websocket = websocket
  30. try:
  31. while True:
  32. data = await websocket.receive_json()
  33. if "action" not in data:
  34. await send_message_to_client(get_error_payload("No action specified"))
  35. continue
  36. action = data["action"]
  37. if action == "start":
  38. await send_message_to_client(get_message_payload("Starting new agent..."))
  39. directory = os.getcwd()
  40. if "directory" in data:
  41. directory = data["directory"]
  42. try:
  43. await restart_docker_container(directory)
  44. except Exception as e:
  45. print("error while restarting docker container:", e)
  46. await send_message_to_client(get_error_payload("Failed to start container: " + str(e)))
  47. continue
  48. agent_listener = asyncio.create_task(listen_for_agent_messages())
  49. if action == "terminal":
  50. msg = {
  51. "action": "terminal",
  52. "data": data["data"]
  53. }
  54. await send_message_to_client(get_message_payload(msg))
  55. else:
  56. if agent_websocket is None:
  57. await send_message_to_client(get_error_payload("Agent not connected"))
  58. continue
  59. await send_message_to_agent(data)
  60. except WebSocketDisconnect:
  61. print("Client websocket disconnected")
  62. await close_all_websockets(get_error_payload("Client disconnected"))
  63. async def stop_docker_container():
  64. docker_client = docker.from_env()
  65. try:
  66. container = docker_client.containers.get(CONTAINER_NAME)
  67. container.stop()
  68. container.remove()
  69. elapsed = 0
  70. while container.status != "exited":
  71. print("waiting for container to stop...")
  72. sleep(1)
  73. elapsed += 1
  74. if elapsed > MAX_WAIT_TIME_SECONDS:
  75. break
  76. container = docker_client.containers.get(CONTAINER_NAME)
  77. except docker.errors.NotFound:
  78. pass
  79. async def restart_docker_container(directory):
  80. await stop_docker_container()
  81. docker_client = docker.from_env()
  82. container = docker_client.containers.run(
  83. "jmalloc/echo-server",
  84. name=CONTAINER_NAME,
  85. detach=True,
  86. ports={str(AGENT_LISTEN_PORT) + "/tcp": AGENT_BIND_PORT},
  87. volumes={directory: {"bind": "/workspace", "mode": "rw"}})
  88. # wait for container to be ready
  89. elapsed = 0
  90. while container.status != "running":
  91. if container.status == "exited":
  92. print("container exited")
  93. print("container logs:")
  94. print(container.logs())
  95. break
  96. print("waiting for container to start...")
  97. sleep(1)
  98. elapsed += 1
  99. container = docker_client.containers.get(CONTAINER_NAME)
  100. if elapsed > MAX_WAIT_TIME_SECONDS:
  101. break
  102. if container.status != "running":
  103. raise Exception("Failed to start container")
  104. async def listen_for_agent_messages():
  105. global agent_websocket
  106. global client_fast_websocket
  107. try:
  108. async with websockets.connect("ws://localhost:" + str(AGENT_BIND_PORT)) as ws:
  109. agent_websocket = ws
  110. await send_message_to_client(get_message_payload("Agent connected!"))
  111. await send_message_to_agent({"source": "server", "message": "Hello, agent!"})
  112. try:
  113. async for message in agent_websocket:
  114. if client_fast_websocket is None:
  115. print("Client websocket not connected")
  116. await close_all_websockets(get_error_payload("Client not connected"))
  117. break
  118. try:
  119. data = json.loads(message)
  120. except Exception as e:
  121. print("error parsing message from agent:", message)
  122. print(e)
  123. continue
  124. if "source" not in data or data["source"] != "agent":
  125. # TODO: remove this once we're not using echo server
  126. print("echo server responded", data)
  127. continue
  128. await send_message_to_agent(data)
  129. except websockets.exceptions.ConnectionClosed:
  130. await send_message_to_client(get_error_payload("Agent disconnected"))
  131. except Exception as e:
  132. print("error connecting to agent:", e)
  133. payload = get_error_payload("Failed to connect to agent: " + str(e))
  134. await send_message_to_client(payload)
  135. await close_agent_websocket(payload)
  136. async def send_message_to_client(data):
  137. print("to client:", data)
  138. if client_fast_websocket is None:
  139. return
  140. await client_fast_websocket.send_json(data)
  141. async def send_message_to_agent(data):
  142. print("to agent:", data)
  143. if agent_websocket is None:
  144. return
  145. await agent_websocket.send(json.dumps(data))
  146. async def close_agent_websocket(payload):
  147. global agent_websocket
  148. if agent_websocket is not None:
  149. if not agent_websocket.closed:
  150. await send_message_to_agent(payload)
  151. await agent_websocket.close()
  152. agent_websocket = None
  153. await stop_docker_container()
  154. async def close_client_websocket(payload):
  155. global client_fast_websocket
  156. if client_fast_websocket is not None:
  157. if client_fast_websocket.client_state != WebSocketState.DISCONNECTED:
  158. await send_message_to_client(payload)
  159. await client_fast_websocket.close()
  160. client_fast_websocket = None
  161. async def close_all_websockets(payload):
  162. await close_agent_websocket(payload)
  163. await close_client_websocket(payload)