| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214 |
import multiprocessing
import random
import string
import threading
import unittest

from pdf2zh import cache
class TestCache(unittest.TestCase):
    """Tests for ``cache.TranslationCache``.

    Each test runs between ``cache.init_test_db()`` (in ``setUp``) and
    ``cache.clean_test_db()`` (in ``tearDown``), so the handle returned by
    the former is always passed back to the latter.
    """

    def setUp(self):
        """Initialise the test database and keep its handle for cleanup."""
        self.test_db = cache.init_test_db()

    def tearDown(self):
        """Hand the test database created in ``setUp`` back for cleanup."""
        cache.clean_test_db(self.test_db)

    def test_basic_set_get(self):
        """A missing key yields ``None``; a stored key yields its value."""
        cache_instance = cache.TranslationCache("test_engine")

        # Nothing has been stored yet, so the lookup must miss.
        self.assertIsNone(cache_instance.get("hello"))

        # Round-trip a single entry.
        cache_instance.set("hello", "你好")
        self.assertEqual(cache_instance.get("hello"), "你好")

    def test_cache_overwrite(self):
        """A second ``set`` for the same text replaces the first value."""
        cache_instance = cache.TranslationCache("test_engine")

        cache_instance.set("hello", "你好")
        cache_instance.set("hello", "您好")

        # Only the most recent translation must be visible.
        self.assertEqual(cache_instance.get("hello"), "您好")

    def test_non_string_params(self):
        """Dict, list and nested params are accepted as engine params."""
        # Flat dictionary.
        params = {"model": "gpt-3.5", "temperature": 0.7}
        cache_instance = cache.TranslationCache("test_engine", params)
        cache_instance.set("hello", "你好")
        self.assertEqual(cache_instance.get("hello"), "你好")

        # List.
        array_params = ["param1", "param2"]
        cache_instance2 = cache.TranslationCache("test_engine", array_params)
        cache_instance2.set("hello", "你好2")
        self.assertEqual(cache_instance2.get("hello"), "你好2")

        # Nested structure.
        nested_params = {"options": {"temp": 0.8, "models": ["a", "b"]}}
        cache_instance3 = cache.TranslationCache("test_engine", nested_params)
        cache_instance3.set("hello", "你好3")
        self.assertEqual(cache_instance3.get("hello"), "你好3")

    def test_engine_distinction(self):
        """The same text cached under two engines stays separate."""
        cache1 = cache.TranslationCache("engine1")
        cache2 = cache.TranslationCache("engine2")

        cache1.set("hello", "你好 1")
        cache2.set("hello", "你好 2")

        self.assertEqual(cache1.get("hello"), "你好 1")
        self.assertEqual(cache2.get("hello"), "你好 2")

    def test_params_distinction(self):
        """The same text cached under different params stays separate."""
        cache1 = cache.TranslationCache("test_engine", {"param": "value1"})
        cache2 = cache.TranslationCache("test_engine", {"param": "value2"})

        cache1.set("hello", "你好 1")
        cache2.set("hello", "你好 2")

        self.assertEqual(cache1.get("hello"), "你好 1")
        self.assertEqual(cache2.get("hello"), "你好 2")

    def test_consistent_param_serialization(self):
        """``translate_engine_params`` must not depend on dict key order."""
        # Pairs that differ only in key ordering and must compare equal.
        equivalent_pairs = [
            ("simple dict", {"b": 1, "a": 2}, {"a": 2, "b": 1}),
            (
                "nested dict",
                {"outer2": {"inner2": 2, "inner1": 1}, "outer1": 3},
                {"outer1": 3, "outer2": {"inner1": 1, "inner2": 2}},
            ),
            (
                "list of dicts",
                {"b": [{"y": 1, "x": 2}], "a": 3},
                {"a": 3, "b": [{"x": 2, "y": 1}]},
            ),
        ]
        for label, params1, params2 in equivalent_pairs:
            # subTest keeps later cases running when an earlier one fails.
            with self.subTest(case=label):
                cache1 = cache.TranslationCache("test_engine", params1)
                cache2 = cache.TranslationCache("test_engine", params2)
                self.assertEqual(
                    cache1.translate_engine_params,
                    cache2.translate_engine_params,
                )

        # Same keys but swapped values must still be told apart.
        cache1 = cache.TranslationCache("test_engine", {"a": 1, "b": 2})
        cache2 = cache.TranslationCache("test_engine", {"a": 2, "b": 1})
        self.assertNotEqual(
            cache1.translate_engine_params, cache2.translate_engine_params
        )

    def test_cache_with_sorted_params(self):
        """Entries written under one key order are readable under another."""
        params1 = {"b": [{"y": 1, "x": 2}], "a": 3}
        params2 = {"a": 3, "b": [{"x": 2, "y": 1}]}

        cache1 = cache.TranslationCache("test_engine", params1)
        cache1.set("hello", "你好")

        # A second instance with reordered params must hit the same entry.
        cache2 = cache.TranslationCache("test_engine", params2)
        self.assertEqual(cache2.get("hello"), "你好")

    def test_append_params(self):
        """``add_params`` adds new keys and overwrites existing ones."""
        cache_instance = cache.TranslationCache("test_engine", {"initial": "value"})

        # Adding a key that is not present yet.
        cache_instance.add_params("new_param", "new_value")
        self.assertEqual(
            cache_instance.params, {"initial": "value", "new_param": "new_value"}
        )

        # The cache must remain usable after the params changed.
        cache_instance.set("hello", "你好")
        self.assertEqual(cache_instance.get("hello"), "你好")

        # Adding a key that already exists replaces its value.
        cache_instance.add_params("initial", "new_value")
        self.assertEqual(
            cache_instance.params, {"initial": "new_value", "new_param": "new_value"}
        )

        # And the cache must still work with the updated params.
        cache_instance.set("hello2", "你好2")
        self.assertEqual(cache_instance.get("hello2"), "你好2")

    def test_thread_safety(self):
        """Concurrent ``set``/``get`` on one instance return correct values."""
        cache_instance = cache.TranslationCache("test_engine")
        lock = threading.Lock()
        results = []
        errors = []
        # Use at least two threads: with a bare cpu_count() a single-core
        # host would run one worker and exercise no concurrency at all.
        num_threads = max(2, multiprocessing.cpu_count())
        items_per_thread = 100
        alphabet = string.ascii_letters + string.digits

        def generate_random_text(length=10):
            return "".join(random.choices(alphabet, k=length))

        def worker():
            thread_results = []  # thread-local result buffer
            try:
                for _ in range(items_per_thread):
                    text = generate_random_text()
                    translation = f"翻译_{text}"
                    # Write, then read back our own write.
                    cache_instance.set(text, translation)
                    thread_results.append((text, cache_instance.get(text)))
            except Exception as exc:
                # Exceptions do not propagate out of a thread; record them so
                # the main thread can fail the test with the real cause.
                with lock:
                    errors.append(exc)
                return
            # After all operations finish, take the lock once and publish
            # the whole batch.
            with lock:
                results.extend(thread_results)

        threads = [threading.Thread(target=worker) for _ in range(num_threads)]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

        # No worker may have failed.
        self.assertEqual(errors, [])

        # Every worker must have reported all of its items.
        self.assertEqual(len(results), num_threads * items_per_thread)

        # Every read must have returned the value written for that text.
        for text, result in results:
            self.assertEqual(result, f"翻译_{text}")
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
|