execute_server.py 9.1 KB


  1. #!/usr/bin/env python3
  2. import asyncio
  3. import logging
  4. import os
  5. import re
  6. from uuid import uuid4
  7. import tornado
  8. from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
  9. from tornado.escape import json_decode, json_encode, url_escape
  10. from tornado.httpclient import AsyncHTTPClient, HTTPRequest
  11. from tornado.ioloop import PeriodicCallback
  12. from tornado.websocket import websocket_connect
# Show INFO-level (and above) log records from this server on the root logger's default
# stderr handler (basicConfig is a no-op if the root logger already has handlers).
logging.basicConfig(level=logging.INFO)
  14. def strip_ansi(o: str) -> str:
  15. """Removes ANSI escape sequences from `o`, as defined by ECMA-048 in
  16. http://www.ecma-international.org/publications/files/ECMA-ST/Ecma-048.pdf
  17. # https://github.com/ewen-lbh/python-strip-ansi/blob/master/strip_ansi/__init__.py
  18. >>> strip_ansi("\\033[33mLorem ipsum\\033[0m")
  19. 'Lorem ipsum'
  20. >>> strip_ansi("Lorem \\033[38;25mIpsum\\033[0m sit\\namet.")
  21. 'Lorem Ipsum sit\\namet.'
  22. >>> strip_ansi("")
  23. ''
  24. >>> strip_ansi("\\x1b[0m")
  25. ''
  26. >>> strip_ansi("Lorem")
  27. 'Lorem'
  28. >>> strip_ansi('\\x1b[38;5;32mLorem ipsum\\x1b[0m')
  29. 'Lorem ipsum'
  30. >>> strip_ansi('\\x1b[1m\\x1b[46m\\x1b[31mLorem dolor sit ipsum\\x1b[0m')
  31. 'Lorem dolor sit ipsum'
  32. """
  33. # pattern = re.compile(r'/(\x9B|\x1B\[)[0-?]*[ -\/]*[@-~]/')
  34. pattern = re.compile(r'\x1B\[\d+(;\d+){0,2}m')
  35. stripped = pattern.sub('', o)
  36. return stripped
  37. class JupyterKernel:
  38. def __init__(self, url_suffix, convid, lang='python'):
  39. self.base_url = f'http://{url_suffix}'
  40. self.base_ws_url = f'ws://{url_suffix}'
  41. self.lang = lang
  42. self.kernel_id = None
  43. self.ws = None
  44. self.convid = convid
  45. logging.info(
  46. f'Jupyter kernel created for conversation {convid} at {url_suffix}'
  47. )
  48. self.heartbeat_interval = 10000 # 10 seconds
  49. self.heartbeat_callback = None
  50. self.initialized = False
  51. async def initialize(self):
  52. await self.execute(r'%colors nocolor')
  53. # pre-defined tools
  54. self.tools_to_run: list[str] = [
  55. # TODO: You can add code for your pre-defined tools here
  56. ]
  57. for tool in self.tools_to_run:
  58. res = await self.execute(tool)
  59. logging.info(f'Tool [{tool}] initialized:\n{res}')
  60. self.initialized = True
  61. async def _send_heartbeat(self):
  62. if not self.ws:
  63. return
  64. try:
  65. self.ws.ping()
  66. # logging.info('Heartbeat sent...')
  67. except tornado.iostream.StreamClosedError:
  68. # logging.info('Heartbeat failed, reconnecting...')
  69. try:
  70. await self._connect()
  71. except ConnectionRefusedError:
  72. logging.info(
  73. 'ConnectionRefusedError: Failed to reconnect to kernel websocket - Is the kernel still running?'
  74. )
  75. async def _connect(self):
  76. if self.ws:
  77. self.ws.close()
  78. self.ws = None
  79. client = AsyncHTTPClient()
  80. if not self.kernel_id:
  81. n_tries = 5
  82. while n_tries > 0:
  83. try:
  84. response = await client.fetch(
  85. '{}/api/kernels'.format(self.base_url),
  86. method='POST',
  87. body=json_encode({'name': self.lang}),
  88. )
  89. kernel = json_decode(response.body)
  90. self.kernel_id = kernel['id']
  91. break
  92. except Exception:
  93. # kernels are not ready yet
  94. n_tries -= 1
  95. await asyncio.sleep(1)
  96. if n_tries == 0:
  97. raise ConnectionRefusedError('Failed to connect to kernel')
  98. ws_req = HTTPRequest(
  99. url='{}/api/kernels/{}/channels'.format(
  100. self.base_ws_url, url_escape(self.kernel_id)
  101. )
  102. )
  103. self.ws = await websocket_connect(ws_req)
  104. logging.info('Connected to kernel websocket')
  105. # Setup heartbeat
  106. if self.heartbeat_callback:
  107. self.heartbeat_callback.stop()
  108. self.heartbeat_callback = PeriodicCallback(
  109. self._send_heartbeat, self.heartbeat_interval
  110. )
  111. self.heartbeat_callback.start()
  112. @retry(
  113. retry=retry_if_exception_type(ConnectionRefusedError),
  114. stop=stop_after_attempt(3),
  115. wait=wait_fixed(2),
  116. )
  117. async def execute(self, code, timeout=120):
  118. if not self.ws:
  119. await self._connect()
  120. msg_id = uuid4().hex
  121. assert self.ws is not None
  122. res = await self.ws.write_message(
  123. json_encode(
  124. {
  125. 'header': {
  126. 'username': '',
  127. 'version': '5.0',
  128. 'session': '',
  129. 'msg_id': msg_id,
  130. 'msg_type': 'execute_request',
  131. },
  132. 'parent_header': {},
  133. 'channel': 'shell',
  134. 'content': {
  135. 'code': code,
  136. 'silent': False,
  137. 'store_history': False,
  138. 'user_expressions': {},
  139. 'allow_stdin': False,
  140. },
  141. 'metadata': {},
  142. 'buffers': {},
  143. }
  144. )
  145. )
  146. logging.info(f'Executed code in jupyter kernel:\n{res}')
  147. outputs = []
  148. async def wait_for_messages():
  149. execution_done = False
  150. while not execution_done:
  151. assert self.ws is not None
  152. msg = await self.ws.read_message()
  153. msg = json_decode(msg)
  154. msg_type = msg['msg_type']
  155. parent_msg_id = msg['parent_header'].get('msg_id', None)
  156. if parent_msg_id != msg_id:
  157. continue
  158. if os.environ.get('DEBUG'):
  159. logging.info(
  160. f"MSG TYPE: {msg_type.upper()} DONE:{execution_done}\nCONTENT: {msg['content']}"
  161. )
  162. if msg_type == 'error':
  163. traceback = '\n'.join(msg['content']['traceback'])
  164. outputs.append(traceback)
  165. execution_done = True
  166. elif msg_type == 'stream':
  167. outputs.append(msg['content']['text'])
  168. elif msg_type in ['execute_result', 'display_data']:
  169. outputs.append(msg['content']['data']['text/plain'])
  170. if 'image/png' in msg['content']['data']:
  171. # use markdone to display image (in case of large image)
  172. outputs.append(
  173. f"\n![image](data:image/png;base64,{msg['content']['data']['image/png']})\n"
  174. )
  175. elif msg_type == 'execute_reply':
  176. execution_done = True
  177. return execution_done
  178. async def interrupt_kernel():
  179. client = AsyncHTTPClient()
  180. interrupt_response = await client.fetch(
  181. f'{self.base_url}/api/kernels/{self.kernel_id}/interrupt',
  182. method='POST',
  183. body=json_encode({'kernel_id': self.kernel_id}),
  184. )
  185. logging.info(f'Kernel interrupted: {interrupt_response}')
  186. try:
  187. execution_done = await asyncio.wait_for(wait_for_messages(), timeout)
  188. except asyncio.TimeoutError:
  189. await interrupt_kernel()
  190. return f'[Execution timed out ({timeout} seconds).]'
  191. if not outputs and execution_done:
  192. ret = '[Code executed successfully with no output]'
  193. else:
  194. ret = ''.join(outputs)
  195. # Remove ANSI
  196. ret = strip_ansi(ret)
  197. if os.environ.get('DEBUG'):
  198. logging.info(f'OUTPUT:\n{ret}')
  199. return ret
  200. async def shutdown_async(self):
  201. if self.kernel_id:
  202. client = AsyncHTTPClient()
  203. await client.fetch(
  204. '{}/api/kernels/{}'.format(self.base_url, self.kernel_id),
  205. method='DELETE',
  206. )
  207. self.kernel_id = None
  208. if self.ws:
  209. self.ws.close()
  210. self.ws = None
  211. class ExecuteHandler(tornado.web.RequestHandler):
  212. def initialize(self, jupyter_kernel):
  213. self.jupyter_kernel = jupyter_kernel
  214. async def post(self):
  215. data = json_decode(self.request.body)
  216. code = data.get('code')
  217. if not code:
  218. self.set_status(400)
  219. self.write('Missing code')
  220. return
  221. output = await self.jupyter_kernel.execute(code)
  222. self.write(output)
  223. def make_app():
  224. jupyter_kernel = JupyterKernel(
  225. f"localhost:{os.environ.get('JUPYTER_GATEWAY_PORT')}",
  226. os.environ.get('JUPYTER_GATEWAY_KERNEL_ID'),
  227. )
  228. asyncio.get_event_loop().run_until_complete(jupyter_kernel.initialize())
  229. return tornado.web.Application(
  230. [
  231. (r'/execute', ExecuteHandler, {'jupyter_kernel': jupyter_kernel}),
  232. ]
  233. )
  234. if __name__ == '__main__':
  235. app = make_app()
  236. app.listen(os.environ.get('JUPYTER_EXEC_SERVER_PORT'))
  237. tornado.ioloop.IOLoop.current().start()