Эх сурвалжийг харах

feat(sandbox): Add Jupyter Kernel for Interactive Python Interpreter for Sandbox (#1215)

* add initial version of py interpreter

* fix bug

* fix async issue

* remove debugging print statement

* initialize kernel & update printing

* fix port mapping

* uncomment debug lines

* fix poetry lock

* make jupyter py interpreter into a subclass
Xingyao Wang 1 жил өмнө
parent
commit
492feecb67

+ 6 - 0
containers/sandbox/Dockerfile

@@ -19,3 +19,9 @@ RUN apt-get update && apt-get install -y \
     && rm -rf /var/lib/apt/lists/*
 
 RUN service ssh start
+
+RUN pip install jupyterlab notebook jupyter_kernel_gateway
+# Add common data science utils
+RUN pip install transformers[torch]
+RUN pip install torch --index-url https://download.pytorch.org/whl/cpu
+RUN pip install boilerpy3 pandas datasets sympy scikit-learn matplotlib seaborn

+ 3 - 0
opendevin/sandbox/exec_box.py

@@ -120,6 +120,9 @@ class DockerExecBox(Sandbox):
                 return -1, f'Command: "{cmd}" timed out'
         return exit_code, logs.decode('utf-8')
 
+    def execute_python(self, code: str) -> str:
+        raise NotImplementedError('execute_python is not supported in DockerExecBox')
+
     def execute_in_background(self, cmd: str) -> BackgroundCommand:
         result = self.container.exec_run(
             self.get_exec_cmd(cmd), socket=True, workdir=SANDBOX_WORKSPACE_DIR

+ 238 - 0
opendevin/sandbox/jupyter_kernel.py

@@ -0,0 +1,238 @@
+import re
+import os
+import tornado
+import asyncio
+from tornado.escape import json_encode, json_decode, url_escape
+from tornado.websocket import websocket_connect, WebSocketClientConnection
+from tornado.ioloop import PeriodicCallback
+from tornado.httpclient import AsyncHTTPClient, HTTPRequest
+
+from opendevin.logger import opendevin_logger as logger
+from uuid import uuid4
+
+
def strip_ansi(o: str) -> str:
    """
    Removes ANSI escape sequences from `o`, as defined by ECMA-048 in
    http://www.ecma-international.org/publications/files/ECMA-ST/Ecma-048.pdf

    # https://github.com/ewen-lbh/python-strip-ansi/blob/master/strip_ansi/__init__.py

    >>> strip_ansi("\\033[33mLorem ipsum\\033[0m")
    'Lorem ipsum'

    >>> strip_ansi("Lorem \\033[38;25mIpsum\\033[0m sit\\namet.")
    'Lorem Ipsum sit\\namet.'

    >>> strip_ansi("")
    ''

    >>> strip_ansi("\\x1b[0m")
    ''

    >>> strip_ansi("Lorem")
    'Lorem'

    >>> strip_ansi('\\x1b[38;5;32mLorem ipsum\\x1b[0m')
    'Lorem ipsum'

    >>> strip_ansi('\\x1b[1m\\x1b[46m\\x1b[31mLorem dolor sit ipsum\\x1b[0m')
    'Lorem dolor sit ipsum'
    """
    # Match any ESC-introduced sequence: a two-character escape
    # (ESC followed by one of "@" through "_") or a full CSI sequence
    # (ESC "[" parameter bytes, intermediate bytes, final byte).
    # This covers SGR color codes such as "\x1b[31m" with any number of
    # parameters, as well as cursor/erase controls such as "\x1b[2J" and
    # "\x1b[K" that the previous color-only pattern
    # (r'\x1B\[\d+(;\d+){0,2}m') left behind.
    pattern = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])')
    stripped = pattern.sub('', o)
    return stripped
+
+
class JupyterKernel:
    """Async client for a single kernel on a Jupyter Kernel Gateway.

    Creates a kernel through the gateway's HTTP API (POST /api/kernels),
    then talks to it over the /api/kernels/<id>/channels websocket to
    execute code and collect its output as plain text.
    """

    def __init__(self, url_suffix, convid, lang='python'):
        # url_suffix is the gateway's "<host>:<port>"; it serves both the
        # HTTP API and the websocket endpoint.
        self.base_url = f'http://{url_suffix}'
        self.base_ws_url = f'ws://{url_suffix}'
        self.lang = lang  # kernel name requested from the gateway
        self.kernel_id = None  # assigned on first successful _connect()
        self.convid = convid  # conversation id, used for logging only
        logger.info(
            f'Jupyter kernel created for conversation {convid} at {url_suffix}'
        )

        # PeriodicCallback takes its interval in milliseconds.
        self.heartbeat_interval = 10000  # 10 seconds
        self.heartbeat_callback = None

    async def initialize(self):
        """One-time kernel setup; connects lazily via execute()."""
        # Disable IPython output coloring so less ANSI needs stripping.
        await self.execute(r'%colors nocolor')
        # pre-defined tools
        # self.tools_to_run = [
        #     # TODO: You can add code for your pre-defined tools here
        # ]
        # for tool in self.tools_to_run:
        #     # logger.info(f"Tool initialized:\n{tool}")
        #     await self.execute(tool)

    async def _send_heartbeat(self):
        """Ping the websocket; reconnect if the stream has closed."""
        if not hasattr(self, 'ws') or not self.ws:
            return
        try:
            self.ws.ping()
            logger.debug('Heartbeat sent...')
        except tornado.iostream.StreamClosedError:
            # Connection dropped; try to re-establish it. The kernel_id is
            # kept, so _connect() only reopens the websocket.
            logger.info('Heartbeat failed, reconnecting...')
            try:
                await self._connect()
            except ConnectionRefusedError:
                logger.info(
                    'ConnectionRefusedError: Failed to reconnect to kernel websocket - Is the kernel still running?'
                )

    async def _connect(self):
        """(Re)establish the kernel websocket, creating the kernel on the
        gateway first if no kernel_id exists yet.

        Raises:
            ConnectionRefusedError: if kernel creation fails after 5 tries.
        """
        # Drop any stale connection before reconnecting.
        if hasattr(self, 'ws') and self.ws:
            self.ws.close()
            self.ws: WebSocketClientConnection = None

        client = AsyncHTTPClient()
        if not self.kernel_id:
            # The gateway may still be starting up; retry kernel creation
            # up to 5 times with a 5-second pause between attempts.
            n_tries = 5
            while n_tries > 0:
                try:
                    response = await client.fetch(
                        '{}/api/kernels'.format(self.base_url),
                        method='POST',
                        body=json_encode({'name': self.lang}),
                    )
                    kernel = json_decode(response.body)
                    self.kernel_id = kernel['id']
                    break
                except Exception as e:
                    logger.error('Failed to connect to kernel')
                    logger.exception(e)
                    logger.info('Retrying in 5 seconds...')
                    # kernels are not ready yet
                    n_tries -= 1
                    await asyncio.sleep(5)

            if n_tries == 0:
                raise ConnectionRefusedError('Failed to connect to kernel')

        ws_req = HTTPRequest(
            url='{}/api/kernels/{}/channels'.format(
                self.base_ws_url, url_escape(self.kernel_id)
            )
        )
        self.ws = await websocket_connect(ws_req)
        logger.info('Connected to kernel websocket')

        # Setup heartbeat: periodically ping so a dead connection is
        # detected and re-opened (see _send_heartbeat).
        if self.heartbeat_callback:
            self.heartbeat_callback.stop()
        self.heartbeat_callback = PeriodicCallback(
            self._send_heartbeat, self.heartbeat_interval
        )
        self.heartbeat_callback.start()

    async def execute(self, code, timeout=60):
        """Execute `code` in the kernel and return its output as a string.

        Accumulates stream text, rich results (text/plain plus a
        markdown-embedded PNG when present) and error tracebacks until the
        kernel sends execute_reply. On timeout, interrupts the kernel and
        returns a timeout message instead.
        """
        if not hasattr(self, 'ws') or not self.ws:
            await self._connect()

        # Tag the request so replies can be matched by parent msg_id.
        msg_id = uuid4().hex
        # NOTE(review): write_message's returned Future is not awaited, so
        # send failures would surface later rather than here - confirm
        # this is intentional.
        self.ws.write_message(
            json_encode(
                {
                    'header': {
                        'username': '',
                        'version': '5.0',
                        'session': '',
                        'msg_id': msg_id,
                        'msg_type': 'execute_request',
                    },
                    'parent_header': {},
                    'channel': 'shell',
                    'content': {
                        'code': code,
                        'silent': False,
                        'store_history': False,
                        'user_expressions': {},
                        'allow_stdin': False,
                    },
                    'metadata': {},
                    'buffers': {},
                }
            )
        )
        logger.info(f'EXECUTE REQUEST SENT:\n{code}')

        outputs = []

        async def wait_for_messages():
            # Drain kernel messages until execute_reply arrives for our
            # msg_id; appends displayable content to `outputs` (closure).
            execution_done = False
            while not execution_done:
                msg = await self.ws.read_message()
                msg = json_decode(msg)
                msg_type = msg['msg_type']
                parent_msg_id = msg['parent_header'].get('msg_id', None)

                # Skip messages triggered by other requests.
                if parent_msg_id != msg_id:
                    continue

                # NOTE(review): any non-empty DEBUG value (even "0")
                # enables this logging - os.environ values are strings.
                if os.environ.get('DEBUG', False):
                    logger.info(
                        f"MSG TYPE: {msg_type.upper()} DONE:{execution_done}\nCONTENT: {msg['content']}"
                    )

                if msg_type == 'error':
                    traceback = '\n'.join(msg['content']['traceback'])
                    outputs.append(traceback)
                    execution_done = True
                elif msg_type == 'stream':
                    outputs.append(msg['content']['text'])
                elif msg_type in ['execute_result', 'display_data']:
                    outputs.append(msg['content']['data']['text/plain'])
                    if 'image/png' in msg['content']['data']:
                        # use markdown to display image (in case of large image)
                        # outputs.append(f"\n<img src=\"data:image/png;base64,{msg['content']['data']['image/png']}\"/>\n")
                        outputs.append(
                            f"![image](data:image/png;base64,{msg['content']['data']['image/png']})"
                        )

                elif msg_type == 'execute_reply':
                    execution_done = True
            return execution_done

        async def interrupt_kernel():
            # Best-effort interrupt of a long-running execution via the
            # gateway's HTTP API.
            client = AsyncHTTPClient()
            interrupt_response = await client.fetch(
                f'{self.base_url}/api/kernels/{self.kernel_id}/interrupt',
                method='POST',
                body=json_encode({'kernel_id': self.kernel_id}),
            )
            logger.info(f'Kernel interrupted: {interrupt_response}')

        try:
            execution_done = await asyncio.wait_for(wait_for_messages(), timeout)
        except asyncio.TimeoutError:
            await interrupt_kernel()
            return f'[Execution timed out ({timeout} seconds).]'

        if not outputs and execution_done:
            ret = '[Code executed successfully with no output]'
        else:
            ret = ''.join(outputs)

        # Remove ANSI
        ret = strip_ansi(ret)

        if os.environ.get('DEBUG', False):
            logger.info(f'OUTPUT:\n{ret}')
        return ret

    async def shutdown_async(self):
        """Delete the kernel on the gateway and close the websocket."""
        if self.kernel_id:
            client = AsyncHTTPClient()
            await client.fetch(
                '{}/api/kernels/{}'.format(self.base_url, self.kernel_id),
                method='DELETE',
            )
            self.kernel_id = None
            if self.ws:
                self.ws.close()
                self.ws = None

+ 3 - 0
opendevin/sandbox/local_box.py

@@ -37,6 +37,9 @@ class LocalBox(Sandbox):
         except subprocess.TimeoutExpired:
             return -1, 'Command timed out'
 
+    def execute_python(self, code: str) -> str:
+        raise NotImplementedError('execute_python is not supported in LocalBox')
+
     def execute_in_background(self, cmd: str) -> BackgroundCommand:
         process = subprocess.Popen(
             cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,

+ 4 - 0
opendevin/sandbox/sandbox.py

@@ -131,6 +131,10 @@ class Sandbox(ABC):
     def execute(self, cmd: str) -> Tuple[int, str]:
         pass
 
    @abstractmethod
    def execute_python(self, code: str) -> str:
        """Run `code` in an interactive Python interpreter inside the
        sandbox and return the textual output.

        Subclasses without interpreter support raise NotImplementedError.
        """
        pass
+
     @abstractmethod
     def execute_in_background(self, cmd: str):
         pass

+ 63 - 4
opendevin/sandbox/ssh_box.py

@@ -4,6 +4,7 @@ import platform
 import sys
 import time
 import uuid
+import asyncio
 from collections import namedtuple
 from typing import Dict, List, Tuple, Union
 
@@ -16,6 +17,7 @@ from opendevin.sandbox.sandbox import Sandbox, BackgroundCommand
 from opendevin.schema import ConfigType
 from opendevin.utils import find_available_tcp_port
 from opendevin.exceptions import SandboxInvalidBackgroundCommandError
+from opendevin.sandbox.jupyter_kernel import JupyterKernel
 
 InputType = namedtuple('InputType', ['content'])
 OutputType = namedtuple('OutputType', ['content'])
@@ -265,6 +267,11 @@ class DockerSSHBox(Sandbox):
         except docker.errors.NotFound:
             return False
 
+    def _get_port_mapping(self):
+        return {
+            f'{self._ssh_port}/tcp': self._ssh_port
+        }
+
     def restart_docker_container(self):
         try:
             self.stop_docker_container()
@@ -279,9 +286,10 @@ class DockerSSHBox(Sandbox):
                 network_kwargs['network_mode'] = 'host'
             else:
                 # FIXME: This is a temporary workaround for Mac OS
-                network_kwargs['ports'] = {f'{self._ssh_port}/tcp': self._ssh_port}
+
+                network_kwargs['ports'] = self._get_port_mapping()
                 logger.warning(
-                    ('Using port forwarding for Mac OS. '
+                    ('Using port forwarding. '
                      'Server started by OpenDevin will not be accessible from the host machine at the moment. '
                      'See https://github.com/OpenDevin/OpenDevin/issues/897 for more information.'
                      )
@@ -340,11 +348,56 @@ class DockerSSHBox(Sandbox):
             except docker.errors.NotFound:
                 pass
 
+    def execute_python(self, code: str) -> str:
+        raise NotImplementedError('execute_python is not supported in DockerSSHBox. Please use DockerSSHJupyterBox.')
+
+
class DockerSSHJupyterBox(DockerSSHBox):
    """DockerSSHBox that also runs a Jupyter Kernel Gateway inside the
    sandbox container, so execute_python() can run code interactively."""

    # Host port forwarded to the container's kernel gateway (port 8888).
    _jupyter_port: int

    def __init__(
        self,
        container_image: str | None = None,
        timeout: int = 120,
        sid: str | None = None,
    ):
        # Pick the host port before the parent constructor runs, because
        # container startup consults _get_port_mapping() (overridden below).
        self._jupyter_port = find_available_tcp_port()
        super().__init__(container_image, timeout, sid)
        self.setup_jupyter()

    def _get_port_mapping(self):
        """Expose SSH as in the base class, plus the kernel gateway's
        container port 8888 mapped to the chosen host port."""
        return {
            f'{self._ssh_port}/tcp': self._ssh_port,
            '8888/tcp': self._jupyter_port,
        }

    def setup_jupyter(self):
        """Start the kernel gateway in the container and initialize a
        JupyterKernel client connected to it."""
        # Setup Jupyter
        # NOTE(review): attribute name 'jupyer_background_cmd' is
        # misspelled ('jupyter'); the __main__ demo uses the same spelling,
        # so a rename must update both places.
        self.jupyer_background_cmd = self.execute_in_background(
            'jupyter kernelgateway --KernelGatewayApp.ip=0.0.0.0 --KernelGatewayApp.port=8888'
        )
        self.jupyter_kernel = JupyterKernel(
            url_suffix=f'{SSH_HOSTNAME}:{self._jupyter_port}',
            convid=self.instance_id,
        )
        logger.info(f'Jupyter Kernel Gateway started at {SSH_HOSTNAME}:{self._jupyter_port}: {self.jupyer_background_cmd.read_logs()}')

        # initialize the kernel
        logger.info('Initializing Jupyter Kernel Gateway...')
        # Fixed 1s wait is best-effort; JupyterKernel._connect retries if
        # the gateway is not up yet.
        time.sleep(1)  # wait for the kernel to start
        # NOTE(review): asyncio.get_event_loop() + run_until_complete fails
        # if an event loop is already running in this thread - confirm this
        # is only called from synchronous contexts.
        loop = asyncio.get_event_loop()
        loop.run_until_complete(self.jupyter_kernel.initialize())
        logger.info('Jupyter Kernel Gateway initialized')

    def execute_python(self, code: str) -> str:
        """Run `code` in the sandboxed Jupyter kernel; return its output."""
        loop = asyncio.get_event_loop()
        return loop.run_until_complete(self.jupyter_kernel.execute(code))
+
 
 if __name__ == '__main__':
 
     try:
-        ssh_box = DockerSSHBox()
+        ssh_box = DockerSSHJupyterBox()
     except Exception as e:
         logger.exception('Failed to start Docker container: %s', e)
         sys.exit(1)
@@ -353,7 +406,7 @@ if __name__ == '__main__':
         "Interactive Docker container started. Type 'exit' or use Ctrl+C to exit.")
 
     bg_cmd = ssh_box.execute_in_background(
-        "while true; do echo 'dot ' && sleep 1; done"
+        "while true; do echo 'dot ' && sleep 5; done"
     )
 
     sys.stdout.flush()
@@ -371,6 +424,12 @@ if __name__ == '__main__':
                 ssh_box.kill_background(bg_cmd.id)
                 logger.info('Background process killed')
                 continue
+            if user_input.startswith('py:'):
+                output = ssh_box.execute_python(user_input[3:])
+                logger.info(output)
+                sys.stdout.flush()
+                continue
+            print('JUPYTER LOG:', ssh_box.jupyer_background_cmd.read_logs())
             exit_code, output = ssh_box.execute(user_input)
             logger.info('exit code: %d', exit_code)
             logger.info(output)

+ 23 - 1
poetry.lock

@@ -3391,6 +3391,7 @@ optional = false
 python-versions = ">=3.9"
 files = [
     {file = "pandas-2.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:90c6fca2acf139569e74e8781709dccb6fe25940488755716d1d354d6bc58bce"},
+    {file = "pandas-2.2.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c7adfc142dac335d8c1e0dcbd37eb8617eac386596eb9e1a1b77791cf2498238"},
     {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4abfe0be0d7221be4f12552995e58723c7422c80a659da13ca382697de830c08"},
     {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8635c16bf3d99040fdf3ca3db669a7250ddf49c55dc4aa8fe0ae0fa8d6dcc1f0"},
     {file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:40ae1dffb3967a52203105a077415a86044a2bea011b5f321c6aa64b379a3f51"},
@@ -3411,6 +3412,7 @@ files = [
     {file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:43498c0bdb43d55cb162cdc8c06fac328ccb5d2eabe3cadeb3529ae6f0517c32"},
     {file = "pandas-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:d187d355ecec3629624fccb01d104da7d7f391db0311145817525281e2804d23"},
     {file = "pandas-2.2.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0ca6377b8fca51815f382bd0b697a0814c8bda55115678cbc94c30aacbb6eff2"},
+    {file = "pandas-2.2.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9057e6aa78a584bc93a13f0a9bf7e753a5e9770a30b4d758b8d5f2a62a9433cd"},
     {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:001910ad31abc7bf06f49dcc903755d2f7f3a9186c0c040b827e522e9cef0863"},
     {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66b479b0bd07204e37583c191535505410daa8df638fd8e75ae1b383851fe921"},
     {file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a77e9d1c386196879aa5eb712e77461aaee433e54c68cf253053a73b7e49c33a"},
@@ -5176,6 +5178,26 @@ typing-extensions = ">=4.8.0"
 opt-einsum = ["opt-einsum (>=3.3)"]
 optree = ["optree (>=0.9.1)"]
 
+[[package]]
+name = "tornado"
+version = "6.4"
+description = "Tornado is a Python web framework and asynchronous networking library, originally developed at FriendFeed."
+optional = false
+python-versions = ">= 3.8"
+files = [
+    {file = "tornado-6.4-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:02ccefc7d8211e5a7f9e8bc3f9e5b0ad6262ba2fbb683a6443ecc804e5224ce0"},
+    {file = "tornado-6.4-cp38-abi3-macosx_10_9_x86_64.whl", hash = "sha256:27787de946a9cffd63ce5814c33f734c627a87072ec7eed71f7fc4417bb16263"},
+    {file = "tornado-6.4-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f7894c581ecdcf91666a0912f18ce5e757213999e183ebfc2c3fdbf4d5bd764e"},
+    {file = "tornado-6.4-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e43bc2e5370a6a8e413e1e1cd0c91bedc5bd62a74a532371042a18ef19e10579"},
+    {file = "tornado-6.4-cp38-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f0251554cdd50b4b44362f73ad5ba7126fc5b2c2895cc62b14a1c2d7ea32f212"},
+    {file = "tornado-6.4-cp38-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:fd03192e287fbd0899dd8f81c6fb9cbbc69194d2074b38f384cb6fa72b80e9c2"},
+    {file = "tornado-6.4-cp38-abi3-musllinux_1_1_i686.whl", hash = "sha256:88b84956273fbd73420e6d4b8d5ccbe913c65d31351b4c004ae362eba06e1f78"},
+    {file = "tornado-6.4-cp38-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:71ddfc23a0e03ef2df1c1397d859868d158c8276a0603b96cf86892bff58149f"},
+    {file = "tornado-6.4-cp38-abi3-win32.whl", hash = "sha256:6f8a6c77900f5ae93d8b4ae1196472d0ccc2775cc1dfdc9e7727889145c45052"},
+    {file = "tornado-6.4-cp38-abi3-win_amd64.whl", hash = "sha256:10aeaa8006333433da48dec9fe417877f8bcc21f48dda8d661ae79da357b2a63"},
+    {file = "tornado-6.4.tar.gz", hash = "sha256:72291fa6e6bc84e626589f1c29d90a5a6d593ef5ae68052ee2ef000dfd273dee"},
+]
+
 [[package]]
 name = "tqdm"
 version = "4.66.2"
@@ -5970,4 +5992,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.11"
-content-hash = "3a5ca3c8b47e0e43994032d1620d85a8d602c52a93790d192b9fdb3a8ac36d97"
+content-hash = "ba00c1217b404cad1139884cdaa2072ab9c416ff95aff1d5f525e562ee7d1350"

+ 1 - 0
pyproject.toml

@@ -24,6 +24,7 @@ numpy = "*"
 json-repair = "*"
 playwright = "*"
 pexpect = "*"
+tornado = "*"
 
 [tool.poetry.group.llama-index.dependencies]
 llama-index = "*"