Przeglądaj źródła

Feat google cloud storage (#3574)

* Google cloud storage implementation
* Unit test refactor
tofarr 1 rok temu
rodzic
commit
8c4c3b18b5

+ 3 - 0
openhands/storage/__init__.py

@@ -1,4 +1,5 @@
 from openhands.storage.files import FileStore
+from openhands.storage.google_cloud import GoogleCloudFileStore
 from openhands.storage.local import LocalFileStore
 from openhands.storage.memory import InMemoryFileStore
 from openhands.storage.s3 import S3FileStore
@@ -11,4 +12,6 @@ def get_file_store(file_store: str, file_store_path: str | None = None) -> FileS
         return LocalFileStore(file_store_path)
     elif file_store == 's3':
         return S3FileStore()
+    elif file_store == 'google_cloud':
+        return GoogleCloudFileStore()
     return InMemoryFileStore()

+ 59 - 0
openhands/storage/google_cloud.py

@@ -0,0 +1,59 @@
+import os
+from typing import List, Optional
+
+from google.cloud import storage
+
+from openhands.storage.files import FileStore
+
+
+class GoogleCloudFileStore(FileStore):
+    def __init__(self, bucket_name: Optional[str] = None) -> None:
+        """
+        Create a new FileStore. If GOOGLE_APPLICATION_CREDENTIALS is defined in the
+        environment it will be used for authentication. Otherwise access will be
+        anonymous.
+        """
+        if bucket_name is None:
+            bucket_name = os.environ['GOOGLE_CLOUD_BUCKET_NAME']
+        self.storage_client = storage.Client()
+        self.bucket = self.storage_client.bucket(bucket_name)
+
+    def write(self, path: str, contents: str | bytes) -> None:
+        blob = self.bucket.blob(path)
+        with blob.open('w') as f:
+            f.write(contents)
+
+    def read(self, path: str) -> str:
+        blob = self.bucket.blob(path)
+        with blob.open('r') as f:
+            return f.read()
+
+    def list(self, path: str) -> List[str]:
+        if not path or path == '/':
+            path = ''
+        elif not path.endswith('/'):
+            path += '/'
+        # The delimiter logic screens out directories, so we can't use it. :(
+        # For example, given a structure:
+        #   foo/bar/zap.txt
+        #   foo/bar/bang.txt
+        #   ping.txt
+        # prefix=None, delimiter="/"   yields  ["ping.txt"]  # :(
+        # prefix="foo", delimiter="/"  yields  []  # :(
+        blobs = set()
+        prefix_len = len(path)
+        for blob in self.bucket.list_blobs(prefix=path):
+            name = blob.name
+            if name == path:
+                continue
+            try:
+                index = name.index('/', prefix_len + 1)
+                if index != prefix_len:
+                    blobs.add(name[: index + 1])
+            except ValueError:
+                blobs.add(name)
+        return list(blobs)
+
+    def delete(self, path: str) -> None:
+        blob = self.bucket.blob(path)
+        blob.delete()

+ 118 - 33
tests/unit/test_storage.py

@@ -1,65 +1,150 @@
+from __future__ import annotations
+
 import os
 import shutil
+from abc import ABC
+from dataclasses import dataclass, field
+from io import StringIO
+from typing import Dict, List, Optional
+from unittest import TestCase
+from unittest.mock import patch
 
-import pytest
-
+from openhands.storage.files import FileStore
+from openhands.storage.google_cloud import GoogleCloudFileStore
 from openhands.storage.local import LocalFileStore
 from openhands.storage.memory import InMemoryFileStore
 
 
-@pytest.fixture
-def setup_env():
-    os.makedirs('./_test_files_tmp', exist_ok=True)
-
-    yield
-
-    shutil.rmtree('./_test_files_tmp')
+class _StorageTest(ABC):
+    store: FileStore
 
+    def get_store(self) -> FileStore:
+        return self.store
 
-def test_basic_fileops(setup_env):
-    filename = 'test.txt'
-    for store in [LocalFileStore('./_test_files_tmp'), InMemoryFileStore()]:
+    def test_basic_fileops(self):
+        filename = 'test.txt'
+        store = self.get_store()
         store.write(filename, 'Hello, world!')
-        assert store.read(filename) == 'Hello, world!'
-        assert store.list('') == [filename]
+        self.assertEqual(store.read(filename), 'Hello, world!')
+        self.assertEqual(store.list(''), [filename])
         store.delete(filename)
-        with pytest.raises(FileNotFoundError):
+        with self.assertRaises(FileNotFoundError):
             store.read(filename)
 
-
-def test_complex_path_fileops(setup_env):
-    filenames = ['foo.bar.baz', './foo/bar/baz', 'foo/bar/baz', '/foo/bar/baz']
-    for store in [LocalFileStore('./_test_files_tmp'), InMemoryFileStore()]:
+    def test_complex_path_fileops(self):
+        filenames = ['foo.bar.baz', './foo/bar/baz', 'foo/bar/baz', '/foo/bar/baz']
+        store = self.get_store()
         for filename in filenames:
             store.write(filename, 'Hello, world!')
-            assert store.read(filename) == 'Hello, world!'
+            self.assertEqual(store.read(filename), 'Hello, world!')
             store.delete(filename)
-            with pytest.raises(FileNotFoundError):
+            with self.assertRaises(FileNotFoundError):
                 store.read(filename)
 
-
-def test_list(setup_env):
-    for store in [LocalFileStore('./_test_files_tmp'), InMemoryFileStore()]:
+    def test_list(self):
+        store = self.get_store()
         store.write('foo.txt', 'Hello, world!')
         store.write('bar.txt', 'Hello, world!')
         store.write('baz.txt', 'Hello, world!')
-        assert store.list('').sort() == ['foo.txt', 'bar.txt', 'baz.txt'].sort()
+        file_names = store.list('')
+        file_names.sort()
+        self.assertEqual(file_names, ['bar.txt', 'baz.txt', 'foo.txt'])
         store.delete('foo.txt')
         store.delete('bar.txt')
         store.delete('baz.txt')
 
-
-def test_deep_list(setup_env):
-    for store in [LocalFileStore('./_test_files_tmp'), InMemoryFileStore()]:
+    def test_deep_list(self):
+        store = self.get_store()
         store.write('foo/bar/baz.txt', 'Hello, world!')
         store.write('foo/bar/qux.txt', 'Hello, world!')
         store.write('foo/bar/quux.txt', 'Hello, world!')
-        assert store.list('') == ['foo/'], f'for class {store.__class__}'
-        assert store.list('foo') == ['foo/bar/']
-        assert (
-            store.list('foo/bar').sort()
-            == ['foo/bar/baz.txt', 'foo/bar/qux.txt', 'foo/bar/quux.txt'].sort()
+        self.assertEqual(store.list(''), ['foo/'])
+        self.assertEqual(store.list('foo'), ['foo/bar/'])
+        file_names = store.list('foo/bar')
+        file_names.sort()
+        self.assertEqual(
+            file_names, ['foo/bar/baz.txt', 'foo/bar/quux.txt', 'foo/bar/qux.txt']
         )
         store.delete('foo/bar/baz.txt')
         store.delete('foo/bar/qux.txt')
         store.delete('foo/bar/quux.txt')
+
+
+class TestLocalFileStore(TestCase, _StorageTest):
+    def setUp(self):
+        os.makedirs('./_test_files_tmp', exist_ok=True)
+        self.store = LocalFileStore('./_test_files_tmp')
+
+    def tearDown(self):
+        shutil.rmtree('./_test_files_tmp')
+
+
+class TestInMemoryFileStore(TestCase, _StorageTest):
+    def setUp(self):
+        self.store = InMemoryFileStore()
+
+
+class TestGoogleCloudFileStore(TestCase, _StorageTest):
+    def setUp(self):
+        with patch('google.cloud.storage.Client', _MockGoogleCloudClient):
+            self.store = GoogleCloudFileStore('dear-liza')
+
+
+# I would have liked to use cloud-storage-mocker here but the python versions were incompatible :(
+# If we write tests for the S3 storage class I would definitely recommend we use moto.
+class _MockGoogleCloudClient:
+    def bucket(self, name: str):
+        assert name == 'dear-liza'
+        return _MockGoogleCloudBucket()
+
+
+@dataclass
+class _MockGoogleCloudBucket:
+    blobs_by_path: Dict[str, _MockGoogleCloudBlob] = field(default_factory=dict)
+
+    def blob(self, path: Optional[str] = None) -> _MockGoogleCloudBlob:
+        return self.blobs_by_path.get(path) or _MockGoogleCloudBlob(self, path)
+
+    def list_blobs(self, prefix: Optional[str] = None) -> List[_MockGoogleCloudBlob]:
+        blobs = list(self.blobs_by_path.values())
+        if prefix and prefix != '/':
+            blobs = [blob for blob in blobs if blob.name.startswith(prefix)]
+        return blobs
+
+
+@dataclass
+class _MockGoogleCloudBlob:
+    bucket: _MockGoogleCloudBucket
+    name: str
+    content: Optional[str | bytes] = None
+
+    def open(self, op: str):
+        if op == 'r':
+            if self.content is None:
+                raise FileNotFoundError()
+            return StringIO(self.content)
+        if op == 'w':
+            return _MockGoogleCloudBlobWriter(self)
+
+    def delete(self):
+        del self.bucket.blobs_by_path[self.name]
+
+
+@dataclass
+class _MockGoogleCloudBlobWriter:
+    blob: _MockGoogleCloudBlob
+    content: str | bytes = None
+
+    def __enter__(self):
+        return self
+
+    def write(self, __b):
+        assert (
+            self.content is None
+        )  # We don't support buffered writes in this mock for now, as it is not needed
+        self.content = __b
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        blob = self.blob
+        blob.content = self.content
+        blob.bucket.blobs_by_path[blob.name] = blob