execute_server.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
  1. #!/usr/bin/env python3
import asyncio
import logging
import os
import re
from uuid import uuid4

import tornado
import tornado.ioloop
import tornado.iostream
import tornado.web
import tornado.websocket
from tornado.escape import json_decode, json_encode, url_escape
from tornado.httpclient import AsyncHTTPClient, HTTPRequest
from tornado.ioloop import PeriodicCallback
from tornado.websocket import websocket_connect
# Configure the root logger at INFO so the module's logging.info() messages
# (kernel creation, websocket connection, interrupts) are emitted.
logging.basicConfig(level=logging.INFO)
  13. def strip_ansi(o: str) -> str:
  14. """Removes ANSI escape sequences from `o`, as defined by ECMA-048 in
  15. http://www.ecma-international.org/publications/files/ECMA-ST/Ecma-048.pdf
  16. # https://github.com/ewen-lbh/python-strip-ansi/blob/master/strip_ansi/__init__.py
  17. >>> strip_ansi("\\033[33mLorem ipsum\\033[0m")
  18. 'Lorem ipsum'
  19. >>> strip_ansi("Lorem \\033[38;25mIpsum\\033[0m sit\\namet.")
  20. 'Lorem Ipsum sit\\namet.'
  21. >>> strip_ansi("")
  22. ''
  23. >>> strip_ansi("\\x1b[0m")
  24. ''
  25. >>> strip_ansi("Lorem")
  26. 'Lorem'
  27. >>> strip_ansi('\\x1b[38;5;32mLorem ipsum\\x1b[0m')
  28. 'Lorem ipsum'
  29. >>> strip_ansi('\\x1b[1m\\x1b[46m\\x1b[31mLorem dolor sit ipsum\\x1b[0m')
  30. 'Lorem dolor sit ipsum'
  31. """
  32. # pattern = re.compile(r'/(\x9B|\x1B\[)[0-?]*[ -\/]*[@-~]/')
  33. pattern = re.compile(r'\x1B\[\d+(;\d+){0,2}m')
  34. stripped = pattern.sub('', o)
  35. return stripped
  36. class JupyterKernel:
  37. def __init__(self, url_suffix, convid, lang='python'):
  38. self.base_url = f'http://{url_suffix}'
  39. self.base_ws_url = f'ws://{url_suffix}'
  40. self.lang = lang
  41. self.kernel_id = None
  42. self.ws = None
  43. self.convid = convid
  44. logging.info(
  45. f'Jupyter kernel created for conversation {convid} at {url_suffix}'
  46. )
  47. self.heartbeat_interval = 10000 # 10 seconds
  48. self.heartbeat_callback = None
  49. self.initialized = False
  50. async def initialize(self):
  51. await self.execute(r'%colors nocolor')
  52. # pre-defined tools
  53. self.tools_to_run = [
  54. # TODO: You can add code for your pre-defined tools here
  55. ]
  56. if os.path.exists('/opendevin/plugins/agent_skills/agentskills.py'):
  57. self.tools_to_run.append('from agentskills import *')
  58. for tool in self.tools_to_run:
  59. # logging.info(f'Tool initialized:\n{tool}')
  60. await self.execute(tool)
  61. self.initialized = True
  62. async def _send_heartbeat(self):
  63. if not self.ws:
  64. return
  65. try:
  66. self.ws.ping()
  67. # logging.info('Heartbeat sent...')
  68. except tornado.iostream.StreamClosedError:
  69. # logging.info('Heartbeat failed, reconnecting...')
  70. try:
  71. await self._connect()
  72. except ConnectionRefusedError:
  73. logging.info(
  74. 'ConnectionRefusedError: Failed to reconnect to kernel websocket - Is the kernel still running?'
  75. )
  76. async def _connect(self):
  77. if self.ws:
  78. self.ws.close()
  79. self.ws = None
  80. client = AsyncHTTPClient()
  81. if not self.kernel_id:
  82. n_tries = 5
  83. while n_tries > 0:
  84. try:
  85. response = await client.fetch(
  86. '{}/api/kernels'.format(self.base_url),
  87. method='POST',
  88. body=json_encode({'name': self.lang}),
  89. )
  90. kernel = json_decode(response.body)
  91. self.kernel_id = kernel['id']
  92. break
  93. except Exception:
  94. # kernels are not ready yet
  95. n_tries -= 1
  96. await asyncio.sleep(1)
  97. if n_tries == 0:
  98. raise ConnectionRefusedError('Failed to connect to kernel')
  99. ws_req = HTTPRequest(
  100. url='{}/api/kernels/{}/channels'.format(
  101. self.base_ws_url, url_escape(self.kernel_id)
  102. )
  103. )
  104. self.ws = await websocket_connect(ws_req)
  105. logging.info('Connected to kernel websocket')
  106. # Setup heartbeat
  107. if self.heartbeat_callback:
  108. self.heartbeat_callback.stop()
  109. self.heartbeat_callback = PeriodicCallback(
  110. self._send_heartbeat, self.heartbeat_interval
  111. )
  112. self.heartbeat_callback.start()
  113. async def execute(self, code, timeout=120):
  114. if not self.ws:
  115. await self._connect()
  116. msg_id = uuid4().hex
  117. assert self.ws is not None
  118. self.ws.write_message(
  119. json_encode(
  120. {
  121. 'header': {
  122. 'username': '',
  123. 'version': '5.0',
  124. 'session': '',
  125. 'msg_id': msg_id,
  126. 'msg_type': 'execute_request',
  127. },
  128. 'parent_header': {},
  129. 'channel': 'shell',
  130. 'content': {
  131. 'code': code,
  132. 'silent': False,
  133. 'store_history': False,
  134. 'user_expressions': {},
  135. 'allow_stdin': False,
  136. },
  137. 'metadata': {},
  138. 'buffers': {},
  139. }
  140. )
  141. )
  142. outputs = []
  143. async def wait_for_messages():
  144. execution_done = False
  145. while not execution_done:
  146. assert self.ws is not None
  147. msg = await self.ws.read_message()
  148. msg = json_decode(msg)
  149. msg_type = msg['msg_type']
  150. parent_msg_id = msg['parent_header'].get('msg_id', None)
  151. if parent_msg_id != msg_id:
  152. continue
  153. if os.environ.get('DEBUG'):
  154. logging.info(
  155. f"MSG TYPE: {msg_type.upper()} DONE:{execution_done}\nCONTENT: {msg['content']}"
  156. )
  157. if msg_type == 'error':
  158. traceback = '\n'.join(msg['content']['traceback'])
  159. outputs.append(traceback)
  160. execution_done = True
  161. elif msg_type == 'stream':
  162. outputs.append(msg['content']['text'])
  163. elif msg_type in ['execute_result', 'display_data']:
  164. outputs.append(msg['content']['data']['text/plain'])
  165. if 'image/png' in msg['content']['data']:
  166. # use markdone to display image (in case of large image)
  167. outputs.append(
  168. f"\n![image](data:image/png;base64,{msg['content']['data']['image/png']})\n"
  169. )
  170. elif msg_type == 'execute_reply':
  171. execution_done = True
  172. return execution_done
  173. async def interrupt_kernel():
  174. client = AsyncHTTPClient()
  175. interrupt_response = await client.fetch(
  176. f'{self.base_url}/api/kernels/{self.kernel_id}/interrupt',
  177. method='POST',
  178. body=json_encode({'kernel_id': self.kernel_id}),
  179. )
  180. logging.info(f'Kernel interrupted: {interrupt_response}')
  181. try:
  182. execution_done = await asyncio.wait_for(wait_for_messages(), timeout)
  183. except asyncio.TimeoutError:
  184. await interrupt_kernel()
  185. return f'[Execution timed out ({timeout} seconds).]'
  186. if not outputs and execution_done:
  187. ret = '[Code executed successfully with no output]'
  188. else:
  189. ret = ''.join(outputs)
  190. # Remove ANSI
  191. ret = strip_ansi(ret)
  192. if os.environ.get('DEBUG'):
  193. logging.info(f'OUTPUT:\n{ret}')
  194. return ret
  195. async def shutdown_async(self):
  196. if self.kernel_id:
  197. client = AsyncHTTPClient()
  198. await client.fetch(
  199. '{}/api/kernels/{}'.format(self.base_url, self.kernel_id),
  200. method='DELETE',
  201. )
  202. self.kernel_id = None
  203. if self.ws:
  204. self.ws.close()
  205. self.ws = None
  206. class ExecuteHandler(tornado.web.RequestHandler):
  207. def initialize(self, jupyter_kernel):
  208. self.jupyter_kernel = jupyter_kernel
  209. async def post(self):
  210. data = json_decode(self.request.body)
  211. code = data.get('code')
  212. if not code:
  213. self.set_status(400)
  214. self.write('Missing code')
  215. return
  216. output = await self.jupyter_kernel.execute(code)
  217. self.write(output)
  218. def make_app():
  219. jupyter_kernel = JupyterKernel(
  220. f"localhost:{os.environ.get('JUPYTER_GATEWAY_PORT')}",
  221. os.environ.get('JUPYTER_GATEWAY_KERNEL_ID'),
  222. )
  223. asyncio.get_event_loop().run_until_complete(jupyter_kernel.initialize())
  224. return tornado.web.Application(
  225. [
  226. (r'/execute', ExecuteHandler, {'jupyter_kernel': jupyter_kernel}),
  227. ]
  228. )
  229. if __name__ == '__main__':
  230. app = make_app()
  231. app.listen(os.environ.get('JUPYTER_EXEC_SERVER_PORT'))
  232. tornado.ioloop.IOLoop.current().start()