awwaawwa 11 сар өмнө
parent
commit
c5fb3b4a2c

+ 1 - 0
pdf2zh/cache.py

@@ -106,6 +106,7 @@ def init_db(remove_exists=False):
 
 def init_test_db():
     import tempfile
+
     cache_db_path = tempfile.mktemp(suffix=".db")
     test_db = SqliteDatabase(
         cache_db_path,

+ 12 - 4
test/test_cache.py

@@ -119,7 +119,9 @@ class TestCache(unittest.TestCase):
         params2 = {"a": 2, "b": 1}
         cache1 = cache.TranslationCache("test_engine", params1)
         cache2 = cache.TranslationCache("test_engine", params2)
-        self.assertNotEqual(cache1.translate_engine_params, cache2.translate_engine_params)
+        self.assertNotEqual(
+            cache1.translate_engine_params, cache2.translate_engine_params
+        )
 
     def test_cache_with_sorted_params(self):
         """Test that cache works correctly with sorted parameters"""
@@ -139,7 +141,9 @@ class TestCache(unittest.TestCase):
 
         # Test appending new parameter
         cache_instance.add_params("new_param", "new_value")
-        self.assertEqual(cache_instance.params, {"initial": "value", "new_param": "new_value"})
+        self.assertEqual(
+            cache_instance.params, {"initial": "value", "new_param": "new_value"}
+        )
 
         # Test that cache with appended params works correctly
         cache_instance.set("hello", "你好")
@@ -147,7 +151,9 @@ class TestCache(unittest.TestCase):
 
         # Test overwriting existing parameter
         cache_instance.add_params("initial", "new_value")
-        self.assertEqual(cache_instance.params, {"initial": "new_value", "new_param": "new_value"})
+        self.assertEqual(
+            cache_instance.params, {"initial": "new_value", "new_param": "new_value"}
+        )
 
         # Cache should work with updated params
         cache_instance.set("hello2", "你好2")
@@ -162,7 +168,9 @@ class TestCache(unittest.TestCase):
         items_per_thread = 100
 
         def generate_random_text(length=10):
-            return ''.join(random.choices(string.ascii_letters + string.digits, k=length))
+            return "".join(
+                random.choices(string.ascii_letters + string.digits, k=length)
+            )
 
         def worker():
             thread_results = []  # 线程本地存储结果

+ 214 - 0
test/test_translator.py

@@ -0,0 +1,214 @@
+import unittest
+from pdf2zh import cache
+import threading
+import multiprocessing
+import random
+import string
+
+
+class TestCache(unittest.TestCase):
+    def setUp(self):
+        self.test_db = cache.init_test_db()
+
+    def tearDown(self):
+        # Clean up
+        cache.clean_test_db(self.test_db)
+
+    def test_basic_set_get(self):
+        """Test basic set and get operations"""
+        cache_instance = cache.TranslationCache("test_engine")
+
+        # Test get with non-existent entry
+        result = cache_instance.get("hello")
+        self.assertIsNone(result)
+
+        # Test set and get
+        cache_instance.set("hello", "你好")
+        result = cache_instance.get("hello")
+        self.assertEqual(result, "你好")
+
+    def test_cache_overwrite(self):
+        """Test that cache entries can be overwritten"""
+        cache_instance = cache.TranslationCache("test_engine")
+
+        # Set initial translation
+        cache_instance.set("hello", "你好")
+
+        # Overwrite with new translation
+        cache_instance.set("hello", "您好")
+
+        # Verify the new translation is returned
+        result = cache_instance.get("hello")
+        self.assertEqual(result, "您好")
+
+    def test_non_string_params(self):
+        """Test that non-string parameters are automatically converted to JSON"""
+        params = {"model": "gpt-3.5", "temperature": 0.7}
+        cache_instance = cache.TranslationCache("test_engine", params)
+
+        # Test that params are converted to JSON string internally
+        cache_instance.set("hello", "你好")
+        result = cache_instance.get("hello")
+        self.assertEqual(result, "你好")
+
+        # Test with different param types
+        array_params = ["param1", "param2"]
+        cache_instance2 = cache.TranslationCache("test_engine", array_params)
+        cache_instance2.set("hello", "你好2")
+        self.assertEqual(cache_instance2.get("hello"), "你好2")
+
+        # Test with nested structures
+        nested_params = {"options": {"temp": 0.8, "models": ["a", "b"]}}
+        cache_instance3 = cache.TranslationCache("test_engine", nested_params)
+        cache_instance3.set("hello", "你好3")
+        self.assertEqual(cache_instance3.get("hello"), "你好3")
+
+    def test_engine_distinction(self):
+        """Test that cache distinguishes between different translation engines"""
+        cache1 = cache.TranslationCache("engine1")
+        cache2 = cache.TranslationCache("engine2")
+
+        # Set same text with different engines
+        cache1.set("hello", "你好 1")
+        cache2.set("hello", "你好 2")
+
+        # Verify each engine gets its own translation
+        self.assertEqual(cache1.get("hello"), "你好 1")
+        self.assertEqual(cache2.get("hello"), "你好 2")
+
+    def test_params_distinction(self):
+        """Test that cache distinguishes between different engine parameters"""
+        params1 = {"param": "value1"}
+        params2 = {"param": "value2"}
+        cache1 = cache.TranslationCache("test_engine", params1)
+        cache2 = cache.TranslationCache("test_engine", params2)
+
+        # Set same text with different parameters
+        cache1.set("hello", "你好 1")
+        cache2.set("hello", "你好 2")
+
+        # Verify each parameter set gets its own translation
+        self.assertEqual(cache1.get("hello"), "你好 1")
+        self.assertEqual(cache2.get("hello"), "你好 2")
+
+    def test_consistent_param_serialization(self):
+        """Test that dictionary parameters are consistently serialized regardless of key order"""
+        # Test simple dictionary
+        params1 = {"b": 1, "a": 2}
+        params2 = {"a": 2, "b": 1}
+        cache1 = cache.TranslationCache("test_engine", params1)
+        cache2 = cache.TranslationCache("test_engine", params2)
+        self.assertEqual(cache1.translate_engine_params, cache2.translate_engine_params)
+
+        # Test nested dictionary
+        params1 = {"outer2": {"inner2": 2, "inner1": 1}, "outer1": 3}
+        params2 = {"outer1": 3, "outer2": {"inner1": 1, "inner2": 2}}
+        cache1 = cache.TranslationCache("test_engine", params1)
+        cache2 = cache.TranslationCache("test_engine", params2)
+        self.assertEqual(cache1.translate_engine_params, cache2.translate_engine_params)
+
+        # Test dictionary with list of dictionaries
+        params1 = {"b": [{"y": 1, "x": 2}], "a": 3}
+        params2 = {"a": 3, "b": [{"x": 2, "y": 1}]}
+        cache1 = cache.TranslationCache("test_engine", params1)
+        cache2 = cache.TranslationCache("test_engine", params2)
+        self.assertEqual(cache1.translate_engine_params, cache2.translate_engine_params)
+
+        # Test that different values still produce different results
+        params1 = {"a": 1, "b": 2}
+        params2 = {"a": 2, "b": 1}
+        cache1 = cache.TranslationCache("test_engine", params1)
+        cache2 = cache.TranslationCache("test_engine", params2)
+        self.assertNotEqual(
+            cache1.translate_engine_params, cache2.translate_engine_params
+        )
+
+    def test_cache_with_sorted_params(self):
+        """Test that cache works correctly with sorted parameters"""
+        params1 = {"b": [{"y": 1, "x": 2}], "a": 3}
+        params2 = {"a": 3, "b": [{"x": 2, "y": 1}]}
+
+        # Both caches should work with the same key
+        cache1 = cache.TranslationCache("test_engine", params1)
+        cache1.set("hello", "你好")
+
+        cache2 = cache.TranslationCache("test_engine", params2)
+        self.assertEqual(cache2.get("hello"), "你好")
+
+    def test_append_params(self):
+        """Test that add_params adds new parameters and overwrites existing ones"""
+        cache_instance = cache.TranslationCache("test_engine", {"initial": "value"})
+
+        # Test appending new parameter
+        cache_instance.add_params("new_param", "new_value")
+        self.assertEqual(
+            cache_instance.params, {"initial": "value", "new_param": "new_value"}
+        )
+
+        # Test that cache with appended params works correctly
+        cache_instance.set("hello", "你好")
+        self.assertEqual(cache_instance.get("hello"), "你好")
+
+        # Test overwriting existing parameter
+        cache_instance.add_params("initial", "new_value")
+        self.assertEqual(
+            cache_instance.params, {"initial": "new_value", "new_param": "new_value"}
+        )
+
+        # Cache should work with updated params
+        cache_instance.set("hello2", "你好2")
+        self.assertEqual(cache_instance.get("hello2"), "你好2")
+
+    def test_thread_safety(self):
+        """Test thread safety of cache operations"""
+        cache_instance = cache.TranslationCache("test_engine")
+        lock = threading.Lock()
+        results = []
+        num_threads = multiprocessing.cpu_count()
+        items_per_thread = 100
+
+        def generate_random_text(length=10):
+            return "".join(
+                random.choices(string.ascii_letters + string.digits, k=length)
+            )
+
+        def worker():
+            thread_results = []  # results gathered locally by this thread
+            for _ in range(items_per_thread):
+                text = generate_random_text()
+                translation = f"翻译_{text}"
+
+                # Write operation
+                cache_instance.set(text, translation)
+
+                # Read operation - verify our own write
+                result = cache_instance.get(text)
+                thread_results.append((text, result))
+
+            # Once all operations are done, take the lock a single time and append the results
+            with lock:
+                results.extend(thread_results)
+
+        # Create threads equal to CPU core count
+        threads = []
+        for _ in range(num_threads):
+            thread = threading.Thread(target=worker)
+            threads.append(thread)
+            thread.start()
+
+        # Wait for all threads to complete
+        for thread in threads:
+            thread.join()
+
+        # Verify all operations were successful
+        expected_total = num_threads * items_per_thread
+        self.assertEqual(len(results), expected_total)
+
+        # Verify each thread got its correct value
+        for text, result in results:
+            expected = f"翻译_{text}"
+            self.assertEqual(result, expected)
+
+
+if __name__ == "__main__":
+    unittest.main()