Browse Source

✅ test(cache): enhance unit tests for cache module

- bind and initialize in-memory test database for isolated testing
- add tests for translation cache operations and engine distinction
- include additional scenarios for parameter-based cache separation
- ensure cleanup in tearDown (drop tables, close the connection) so tests stay isolated from each other
awwaawwa 11 months ago
parent
commit
53f426a100
1 changed file with 72 additions and 96 deletions
  1. 72 96
      test/test_cache.py

+ 72 - 96
test/test_cache.py

@@ -1,107 +1,83 @@
 import unittest
-import os
-import tempfile
-import shutil
-import time
 from pdf2zh import cache
+from peewee import SqliteDatabase
+
+MODELS = [cache._TranslationCache]
+test_db = SqliteDatabase(":memory:")
 
 
 class TestCache(unittest.TestCase):
     def setUp(self):
-        # Create a temporary directory for testing
-        self.test_cache_dir = os.path.join(tempfile.gettempdir(), "test_cache")
-        self.original_cache_dir = cache.cache_dir
-        cache.cache_dir = self.test_cache_dir
-        os.makedirs(self.test_cache_dir, exist_ok=True)
+        # Bind model classes to test db. Since we have a complete list of
+        # all models, we do not need to recursively bind dependencies.
+        test_db.bind(MODELS, bind_refs=False, bind_backrefs=False)
+
+        test_db.connect()
+        test_db.create_tables(MODELS)
 
     def tearDown(self):
-        # Clean up the test directory
-        shutil.rmtree(self.test_cache_dir)
-        cache.cache_dir = self.original_cache_dir
-
-    def test_deterministic_hash(self):
-        # Test hash generation for different inputs
-        test_input = "Hello World"
-        hash1 = cache.deterministic_hash(test_input)
-        hash2 = cache.deterministic_hash(test_input)
-        self.assertEqual(hash1, hash2)
-        self.assertEqual(len(hash1), 20)
-
-        # Test different inputs produce different hashes
-        hash3 = cache.deterministic_hash("Different input")
-        self.assertNotEqual(hash1, hash3)
-
-    def test_get_dirs(self):
-        # Create test directories
-        test_dirs = ["dir1", "dir2", "dir3"]
-        for dir_name in test_dirs:
-            os.makedirs(os.path.join(self.test_cache_dir, dir_name))
-
-        # Create a file (should be ignored)
-        with open(os.path.join(self.test_cache_dir, "test.txt"), "w") as f:
-            f.write("test")
-
-        dirs = cache.get_dirs()
-        self.assertEqual(len(dirs), 3)
-        for dir_path in dirs:
-            self.assertTrue(os.path.isdir(dir_path))
-
-    def test_get_time(self):
-        # Create test directory with time file
-        test_dir = os.path.join(self.test_cache_dir, "test_dir")
-        os.makedirs(test_dir)
-        test_time = 1234567890.0
-
-        with open(os.path.join(test_dir, cache.time_filename), "w") as f:
-            f.write(str(test_time))
-
-        # Test reading time
-        result = cache.get_time(test_dir)
-        self.assertEqual(result, test_time)
-
-        # Test non-existent directory
-        non_existent_dir = os.path.join(self.test_cache_dir, "non_existent")
-        result = cache.get_time(non_existent_dir)
-        self.assertEqual(result, float("inf"))
-
-    def test_write_time(self):
-        test_dir = os.path.join(self.test_cache_dir, "test_dir")
-        os.makedirs(test_dir)
-
-        cache.write_time(test_dir)
-
-        self.assertTrue(os.path.exists(os.path.join(test_dir, cache.time_filename)))
-        with open(os.path.join(test_dir, cache.time_filename)) as f:
-            time_value = float(f.read())
-        self.assertIsInstance(time_value, float)
-
-    def test_remove_extra(self):
-        # Create more than max_cache directories
-        for i in range(cache.max_cache + 2):
-            dir_path = os.path.join(self.test_cache_dir, f"dir{i}")
-            os.makedirs(dir_path)
-            time.sleep(0.1)  # Ensure different timestamps
-            cache.write_time(dir_path)
-
-        cache.remove_extra()
-
-        remaining_dirs = cache.get_dirs()
-        self.assertLessEqual(len(remaining_dirs), cache.max_cache)
-
-    def test_cache_operations(self):
-        test_hash = "test123hash"
-        test_para_hash = "para456hash"
-        test_content = "Test paragraph content"
-
-        # Test cache creation
-        self.assertFalse(cache.is_cached(test_hash))
-        cache.create_cache(test_hash)
-        self.assertTrue(cache.is_cached(test_hash))
-
-        # Test paragraph operations
-        self.assertIsNone(cache.load_paragraph(test_hash, test_para_hash))
-        cache.write_paragraph(test_hash, test_para_hash, test_content)
-        self.assertEqual(cache.load_paragraph(test_hash, test_para_hash), test_content)
+        # Clean up
+        test_db.drop_tables(MODELS)
+        test_db.close()
+
+    def test_basic_set_get(self):
+        """Test basic set and get operations"""
+        cache_instance = cache.TranslationCache("test_engine", "{}")
+
+        # Test get with non-existent entry
+        result = cache_instance.get("hello")
+        self.assertIsNone(result)
+
+        # Test set and get
+        cache_instance.set("hello", "你好")
+        result = cache_instance.get("hello")
+        self.assertIsNotNone(result)
+        self.assertEqual(result.translation, "你好")
+        self.assertEqual(result.original_text, "hello")
+
+    def test_cache_overwrite(self):
+        """Test that cache entries can be overwritten"""
+        cache_instance = cache.TranslationCache("test_engine", "{}")
+
+        # Set initial translation
+        cache_instance.set("hello", "你好")
+
+        # Overwrite with new translation
+        cache_instance.set("hello", "您好")
+
+        # Verify the new translation is returned
+        result = cache_instance.get("hello")
+        self.assertEqual(result.translation, "您好")
+
+    def test_engine_distinction(self):
+        """Test that cache distinguishes between different translation engines"""
+        cache1 = cache.TranslationCache("engine1", "{}")
+        cache2 = cache.TranslationCache("engine2", "{}")
+
+        # Set same text with different engines
+        cache1.set("hello", "你好 1")
+        cache2.set("hello", "你好 2")
+
+        # Verify each engine gets its own translation
+        result1 = cache1.get("hello")
+        result2 = cache2.get("hello")
+        self.assertEqual(result1.translation, "你好 1")
+        self.assertEqual(result2.translation, "你好 2")
+
+    def test_params_distinction(self):
+        """Test that cache distinguishes between different engine parameters"""
+        cache1 = cache.TranslationCache("test_engine", '{"param": "value1"}')
+        cache2 = cache.TranslationCache("test_engine", '{"param": "value2"}')
+
+        # Set same text with different parameters
+        cache1.set("hello", "你好 1")
+        cache2.set("hello", "你好 2")
+
+        # Verify each parameter set gets its own translation
+        result1 = cache1.get("hello")
+        result2 = cache2.get("hello")
+        self.assertEqual(result1.translation, "你好 1")
+        self.assertEqual(result2.translation, "你好 2")
 
 
 if __name__ == "__main__":