jupyter_kernel.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. import re
  2. import os
  3. import tornado
  4. import asyncio
  5. from tornado.escape import json_encode, json_decode, url_escape
  6. from tornado.websocket import websocket_connect, WebSocketClientConnection
  7. from tornado.ioloop import PeriodicCallback
  8. from tornado.httpclient import AsyncHTTPClient, HTTPRequest
  9. from opendevin.logger import opendevin_logger as logger
  10. from uuid import uuid4
  def strip_ansi(o: str) -> str:
      """
      Remove ANSI CSI escape sequences from `o`, as defined by ECMA-048 in
      http://www.ecma-international.org/publications/files/ECMA-ST/Ecma-048.pdf

      A CSI sequence is the introducer (``ESC [`` or the single character
      ``0x9B``), any number of parameter characters ``0-?``, any number of
      intermediate characters ``space-/`` and one final character ``@-~``.
      This covers colour/style codes with any number of parameters as well as
      cursor-movement and erase codes. Other escape families (for example
      ``ESC ]`` operating-system commands) are left untouched.

      Adapted from
      https://github.com/ewen-lbh/python-strip-ansi/blob/master/strip_ansi/__init__.py

      Args:
          o: Text that may contain escape sequences.

      Returns:
          `o` with every CSI sequence removed.

      >>> strip_ansi("\\033[33mLorem ipsum\\033[0m")
      'Lorem ipsum'
      >>> strip_ansi("Lorem \\033[38;25mIpsum\\033[0m sit\\namet.")
      'Lorem Ipsum sit\\namet.'
      >>> strip_ansi("")
      ''
      >>> strip_ansi("\\x1b[0m")
      ''
      >>> strip_ansi("Lorem")
      'Lorem'
      >>> strip_ansi('\\x1b[38;5;32mLorem ipsum\\x1b[0m')
      'Lorem ipsum'
      >>> strip_ansi('\\x1b[1m\\x1b[46m\\x1b[31mLorem dolor sit ipsum\\x1b[0m')
      'Lorem dolor sit ipsum'
      >>> strip_ansi('\\x1b[mLorem')
      'Lorem'
      >>> strip_ansi('\\x1b[38;2;255;0;0mLorem\\x1b[0m')
      'Lorem'
      >>> strip_ansi('\\x1b[2KLorem')
      'Lorem'
      """
      # The previous pattern (r'\x1B\[\d+(;\d+){0,2}m') only matched colour
      # codes carrying one to three numeric parameters; this one follows the
      # full CSI grammar and is a strict superset of it.
      pattern = re.compile(r'(?:\x9B|\x1B\[)[0-?]*[ -/]*[@-~]')
      return pattern.sub('', o)
  class JupyterKernel:
      """Client for a single kernel on a Jupyter server.

      The kernel is created lazily with ``POST /api/kernels`` the first time a
      connection is needed. Code is then submitted as ``execute_request``
      messages over the kernel's ``/channels`` websocket, and a periodic
      websocket ping is used as a heartbeat.
      """

      def __init__(self, url_suffix, convid, lang='python'):
          """Store connection settings; no network traffic happens here.

          Args:
              url_suffix: Server address without a scheme. It is prefixed with
                  ``http://`` and ``ws://`` to build the two base URLs.
              convid: Conversation identifier; stored and logged.
              lang: Kernel name sent to the server when the kernel is created.
          """
          self.base_url = f'http://{url_suffix}'
          self.base_ws_url = f'ws://{url_suffix}'
          self.lang = lang
          self.kernel_id = None
          # Defined up front so every method can test `self.ws` directly, even
          # when a connection attempt failed before the websocket was opened.
          self.ws = None  # WebSocketClientConnection once connected
          self.convid = convid
          logger.info(
              f'Jupyter kernel created for conversation {convid} at {url_suffix}'
          )

          self.heartbeat_interval = 10000  # 10 seconds
          self.heartbeat_callback = None

      async def initialize(self):
          """Run the ``%colors nocolor`` magic on the kernel."""
          await self.execute(r'%colors nocolor')
          # pre-defined tools
          # self.tools_to_run = [
          #     # TODO: You can add code for your pre-defined tools here
          # ]
          # for tool in self.tools_to_run:
          #     # logger.info(f"Tool initialized:\n{tool}")
          #     await self.execute(tool)

      async def _send_heartbeat(self):
          """Ping the websocket; try to reconnect if it turns out to be closed."""
          if not self.ws:
              return

          try:
              self.ws.ping()
              logger.debug('Heartbeat sent...')
          except (
              tornado.iostream.StreamClosedError,
              # Tornado documents this one for operations on a closed websocket.
              tornado.websocket.WebSocketClosedError,
          ):
              logger.info('Heartbeat failed, reconnecting...')
              try:
                  await self._connect()
              except ConnectionRefusedError:
                  logger.info(
                      'ConnectionRefusedError: Failed to reconnect to kernel websocket - Is the kernel still running?'
                  )

      async def _connect(self):
          """Create the kernel if needed, open its websocket, start the heartbeat.

          Any existing websocket is closed first. Kernel creation is attempted
          up to five times with a five second pause after each failure.

          Raises:
              ConnectionRefusedError: If every kernel-creation attempt failed.
          """
          if self.ws:
              self.ws.close()
              self.ws = None

          client = AsyncHTTPClient()
          if not self.kernel_id:
              n_tries = 5
              while n_tries > 0:
                  try:
                      response = await client.fetch(
                          '{}/api/kernels'.format(self.base_url),
                          method='POST',
                          body=json_encode({'name': self.lang}),
                      )
                      kernel = json_decode(response.body)
                      self.kernel_id = kernel['id']
                      break
                  except Exception as e:
                      logger.error('Failed to connect to kernel')
                      logger.exception(e)
                      logger.info('Retrying in 5 seconds...')
                      # kernels are not ready yet
                      n_tries -= 1
                      await asyncio.sleep(5)

              if n_tries == 0:
                  raise ConnectionRefusedError('Failed to connect to kernel')

          ws_req = HTTPRequest(
              url='{}/api/kernels/{}/channels'.format(
                  self.base_ws_url, url_escape(self.kernel_id)
              )
          )
          self.ws = await websocket_connect(ws_req)
          logger.info('Connected to kernel websocket')

          # Setup heartbeat (replacing the callback of any earlier connection)
          if self.heartbeat_callback:
              self.heartbeat_callback.stop()
          self.heartbeat_callback = PeriodicCallback(
              self._send_heartbeat, self.heartbeat_interval
          )
          self.heartbeat_callback.start()

      async def execute(self, code, timeout=60):
          """Run `code` on the kernel and return its textual output.

          Connects first when there is no open websocket. Replies are matched
          to this request through the parent ``msg_id``; unrelated messages are
          skipped.

          Args:
              code: Source text to execute.
              timeout: Seconds to wait for completion before the kernel is
                  interrupted.

          Returns:
              The concatenated stream / result / traceback text with ANSI
              escape sequences removed, or a bracketed notice when the code
              produced no output, timed out, or the websocket closed early.
          """
          if not self.ws:
              await self._connect()

          msg_id = uuid4().hex
          self.ws.write_message(
              json_encode(
                  {
                      'header': {
                          'username': '',
                          'version': '5.0',
                          'session': '',
                          'msg_id': msg_id,
                          'msg_type': 'execute_request',
                      },
                      'parent_header': {},
                      'channel': 'shell',
                      'content': {
                          'code': code,
                          'silent': False,
                          'store_history': False,
                          'user_expressions': {},
                          'allow_stdin': False,
                      },
                      'metadata': {},
                      'buffers': {},
                  }
              )
          )
          logger.info(f'EXECUTE REQUEST SENT:\n{code}')

          outputs = []

          async def wait_for_messages():
              """Collect output until the request finishes or the socket closes."""
              execution_done = False
              while not execution_done:
                  msg = await self.ws.read_message()
                  if msg is None:
                      # read_message() yields None once the connection is
                      # closed. Drop the dead socket so the next execute()
                      # reconnects, and report instead of failing inside
                      # json_decode.
                      logger.info('Kernel websocket closed while waiting for output')
                      self.ws = None
                      outputs.append(
                          '[Kernel connection closed before execution finished.]'
                      )
                      return False

                  msg = json_decode(msg)
                  msg_type = msg['msg_type']
                  parent_msg_id = msg['parent_header'].get('msg_id', None)

                  if parent_msg_id != msg_id:
                      continue

                  if os.environ.get('DEBUG', False):
                      logger.info(
                          f"MSG TYPE: {msg_type.upper()} DONE:{execution_done}\nCONTENT: {msg['content']}"
                      )

                  if msg_type == 'error':
                      traceback = '\n'.join(msg['content']['traceback'])
                      outputs.append(traceback)
                      execution_done = True
                  elif msg_type == 'stream':
                      outputs.append(msg['content']['text'])
                  elif msg_type in ['execute_result', 'display_data']:
                      data = msg['content']['data']
                      # A display bundle is not guaranteed to carry a plain
                      # text representation (e.g. image-only output).
                      if 'text/plain' in data:
                          outputs.append(data['text/plain'])
                      if 'image/png' in data:
                          # use markdown to display image (in case of large image)
                          outputs.append(
                              f"![image](data:image/png;base64,{data['image/png']})"
                          )
                  elif msg_type == 'execute_reply':
                      execution_done = True
              return execution_done

          async def interrupt_kernel():
              """Ask the server to interrupt the kernel via its REST endpoint."""
              client = AsyncHTTPClient()
              interrupt_response = await client.fetch(
                  f'{self.base_url}/api/kernels/{self.kernel_id}/interrupt',
                  method='POST',
                  body=json_encode({'kernel_id': self.kernel_id}),
              )
              logger.info(f'Kernel interrupted: {interrupt_response}')

          try:
              execution_done = await asyncio.wait_for(wait_for_messages(), timeout)
          except asyncio.TimeoutError:
              await interrupt_kernel()
              return f'[Execution timed out ({timeout} seconds).]'

          if not outputs and execution_done:
              ret = '[Code executed successfully with no output]'
          else:
              ret = ''.join(outputs)

          # Remove ANSI
          ret = strip_ansi(ret)

          if os.environ.get('DEBUG', False):
              logger.info(f'OUTPUT:\n{ret}')
          return ret

      async def shutdown_async(self):
          """Stop the heartbeat, delete the kernel and close the websocket.

          The websocket is closed even when the DELETE request raises; in that
          case `kernel_id` is kept so the shutdown can be retried.
          """
          # Without this the periodic callback keeps firing after shutdown.
          if self.heartbeat_callback:
              self.heartbeat_callback.stop()
              self.heartbeat_callback = None

          try:
              if self.kernel_id:
                  client = AsyncHTTPClient()
                  await client.fetch(
                      '{}/api/kernels/{}'.format(self.base_url, self.kernel_id),
                      method='DELETE',
                  )
                  self.kernel_id = None
          finally:
              if self.ws:
                  self.ws.close()
                  self.ws = None
  202. self.ws = None