Browse Source

✅ test(test_cache): enhance caching tests for thread safety

- add tests for thread safety in cache operations
- ensure that all operations are consistent and correct under load
awwaawwa 11 months ago
parent
commit
053f33268e
1 changed file with 94 additions and 46 deletions
  1. 94 46
      test/test_cache.py

+ 94 - 46
test/test_cache.py

@@ -1,28 +1,22 @@
 import unittest
 from pdf2zh import cache
-from peewee import SqliteDatabase
-
-MODELS = [cache._TranslationCache]
-test_db = SqliteDatabase(":memory:")
+import threading
+import multiprocessing
+import random
+import string
 
 
 class TestCache(unittest.TestCase):
     def setUp(self):
-        # Bind model classes to test db. Since we have a complete list of
-        # all models, we do not need to recursively bind dependencies.
-        test_db.bind(MODELS, bind_refs=False, bind_backrefs=False)
-
-        test_db.connect()
-        test_db.create_tables(MODELS)
+        self.test_db = cache.init_test_db()
 
     def tearDown(self):
         # Clean up
-        test_db.drop_tables(MODELS)
-        test_db.close()
+        cache.clean_test_db(self.test_db)
 
     def test_basic_set_get(self):
         """Test basic set and get operations"""
-        cache_instance = cache.TranslationCache("test_engine", "{}")
+        cache_instance = cache.TranslationCache("test_engine")
 
         # Test get with non-existent entry
         result = cache_instance.get("hello")
@@ -31,13 +25,11 @@ class TestCache(unittest.TestCase):
         # Test set and get
         cache_instance.set("hello", "你好")
         result = cache_instance.get("hello")
-        self.assertIsNotNone(result)
-        self.assertEqual(result.translation, "你好")
-        self.assertEqual(result.original_text, "hello")
+        self.assertEqual(result, "你好")
 
     def test_cache_overwrite(self):
         """Test that cache entries can be overwritten"""
-        cache_instance = cache.TranslationCache("test_engine", "{}")
+        cache_instance = cache.TranslationCache("test_engine")
 
         # Set initial translation
         cache_instance.set("hello", "你好")
@@ -47,66 +39,57 @@ class TestCache(unittest.TestCase):
 
         # Verify the new translation is returned
         result = cache_instance.get("hello")
-        self.assertEqual(result.translation, "您好")
+        self.assertEqual(result, "您好")
 
     def test_non_string_params(self):
         """Test that non-string parameters are automatically converted to JSON"""
         params = {"model": "gpt-3.5", "temperature": 0.7}
         cache_instance = cache.TranslationCache("test_engine", params)
 
-        # Test that params are converted to JSON string
-        self.assertEqual(
-            cache_instance.translate_engine_params,
-            '{"model": "gpt-3.5", "temperature": 0.7}',
-        )
-
-        # Test that cache operations work with converted params
+        # Test that params are converted to JSON string internally
         cache_instance.set("hello", "你好")
         result = cache_instance.get("hello")
-        self.assertEqual(result.translation, "你好")
+        self.assertEqual(result, "你好")
 
         # Test with different param types
         array_params = ["param1", "param2"]
         cache_instance2 = cache.TranslationCache("test_engine", array_params)
-        self.assertEqual(cache_instance2.translate_engine_params, '["param1", "param2"]')
+        cache_instance2.set("hello", "你好2")
+        self.assertEqual(cache_instance2.get("hello"), "你好2")
 
         # Test with nested structures
         nested_params = {"options": {"temp": 0.8, "models": ["a", "b"]}}
         cache_instance3 = cache.TranslationCache("test_engine", nested_params)
-        self.assertEqual(
-            cache_instance3.translate_engine_params,
-            '{"options": {"models": ["a", "b"], "temp": 0.8}}',
-        )
+        cache_instance3.set("hello", "你好3")
+        self.assertEqual(cache_instance3.get("hello"), "你好3")
 
     def test_engine_distinction(self):
         """Test that cache distinguishes between different translation engines"""
-        cache1 = cache.TranslationCache("engine1", "{}")
-        cache2 = cache.TranslationCache("engine2", "{}")
+        cache1 = cache.TranslationCache("engine1")
+        cache2 = cache.TranslationCache("engine2")
 
         # Set same text with different engines
         cache1.set("hello", "你好 1")
         cache2.set("hello", "你好 2")
 
         # Verify each engine gets its own translation
-        result1 = cache1.get("hello")
-        result2 = cache2.get("hello")
-        self.assertEqual(result1.translation, "你好 1")
-        self.assertEqual(result2.translation, "你好 2")
+        self.assertEqual(cache1.get("hello"), "你好 1")
+        self.assertEqual(cache2.get("hello"), "你好 2")
 
     def test_params_distinction(self):
         """Test that cache distinguishes between different engine parameters"""
-        cache1 = cache.TranslationCache("test_engine", '{"param": "value1"}')
-        cache2 = cache.TranslationCache("test_engine", '{"param": "value2"}')
+        params1 = {"param": "value1"}
+        params2 = {"param": "value2"}
+        cache1 = cache.TranslationCache("test_engine", params1)
+        cache2 = cache.TranslationCache("test_engine", params2)
 
         # Set same text with different parameters
         cache1.set("hello", "你好 1")
         cache2.set("hello", "你好 2")
 
         # Verify each parameter set gets its own translation
-        result1 = cache1.get("hello")
-        result2 = cache2.get("hello")
-        self.assertEqual(result1.translation, "你好 1")
-        self.assertEqual(result2.translation, "你好 2")
+        self.assertEqual(cache1.get("hello"), "你好 1")
+        self.assertEqual(cache2.get("hello"), "你好 2")
 
     def test_consistent_param_serialization(self):
         """Test that dictionary parameters are consistently serialized regardless of key order"""
@@ -148,10 +131,75 @@ class TestCache(unittest.TestCase):
         cache1.set("hello", "你好")
 
         cache2 = cache.TranslationCache("test_engine", params2)
-        result = cache2.get("hello")
+        self.assertEqual(cache2.get("hello"), "你好")
+
+    def test_append_params(self):
+        """Test the append_params method"""
+        cache_instance = cache.TranslationCache("test_engine", {"initial": "value"})
 
-        self.assertIsNotNone(result)
-        self.assertEqual(result.translation, "你好")
+        # Test appending new parameter
+        cache_instance.append_params("new_param", "new_value")
+        self.assertEqual(cache_instance.params, {"initial": "value", "new_param": "new_value"})
+
+        # Test that cache with appended params works correctly
+        cache_instance.set("hello", "你好")
+        self.assertEqual(cache_instance.get("hello"), "你好")
+
+        # Test overwriting existing parameter
+        cache_instance.append_params("initial", "new_value")
+        self.assertEqual(cache_instance.params, {"initial": "new_value", "new_param": "new_value"})
+
+        # Cache should work with updated params
+        cache_instance.set("hello2", "你好2")
+        self.assertEqual(cache_instance.get("hello2"), "你好2")
+
+    def test_thread_safety(self):
+        """Test thread safety of cache operations"""
+        cache_instance = cache.TranslationCache("test_engine")
+        lock = threading.Lock()
+        results = []
+        num_threads = multiprocessing.cpu_count()
+        items_per_thread = 100
+
+        def generate_random_text(length=10):
+            return ''.join(random.choices(string.ascii_letters + string.digits, k=length))
+
+        def worker():
+            thread_results = []  # Results collected locally by this worker thread
+            for _ in range(items_per_thread):
+                text = generate_random_text()
+                translation = f"翻译_{text}"
+
+                # Write operation
+                cache_instance.set(text, translation)
+
+                # Read operation - verify our own write
+                result = cache_instance.get(text)
+                thread_results.append((text, result))
+
+            # After all operations finish, take the lock once and append the results
+            with lock:
+                results.extend(thread_results)
+
+        # Create threads equal to CPU core count
+        threads = []
+        for _ in range(num_threads):
+            thread = threading.Thread(target=worker)
+            threads.append(thread)
+            thread.start()
+
+        # Wait for all threads to complete
+        for thread in threads:
+            thread.join()
+
+        # Verify all operations were successful
+        expected_total = num_threads * items_per_thread
+        self.assertEqual(len(results), expected_total)
+
+        # Verify each thread got its correct value
+        for text, result in results:
+            expected = f"翻译_{text}"
+            self.assertEqual(result, expected)
 
 
 if __name__ == "__main__":