|
|
@@ -1,107 +1,213 @@
|
|
|
import unittest
|
|
|
-import os
|
|
|
-import tempfile
|
|
|
-import shutil
|
|
|
-import time
|
|
|
from pdf2zh import cache
|
|
|
+import threading
|
|
|
+import multiprocessing
|
|
|
+import random
|
|
|
+import string
|
|
|
|
|
|
|
|
|
class TestCache(unittest.TestCase):
|
|
|
def setUp(self):
|
|
|
- # Create a temporary directory for testing
|
|
|
- self.test_cache_dir = os.path.join(tempfile.gettempdir(), "test_cache")
|
|
|
- self.original_cache_dir = cache.cache_dir
|
|
|
- cache.cache_dir = self.test_cache_dir
|
|
|
- os.makedirs(self.test_cache_dir, exist_ok=True)
|
|
|
+ self.test_db = cache.init_test_db()
|
|
|
|
|
|
def tearDown(self):
|
|
|
- # Clean up the test directory
|
|
|
- shutil.rmtree(self.test_cache_dir)
|
|
|
- cache.cache_dir = self.original_cache_dir
|
|
|
-
|
|
|
- def test_deterministic_hash(self):
|
|
|
- # Test hash generation for different inputs
|
|
|
- test_input = "Hello World"
|
|
|
- hash1 = cache.deterministic_hash(test_input)
|
|
|
- hash2 = cache.deterministic_hash(test_input)
|
|
|
- self.assertEqual(hash1, hash2)
|
|
|
- self.assertEqual(len(hash1), 20)
|
|
|
-
|
|
|
- # Test different inputs produce different hashes
|
|
|
- hash3 = cache.deterministic_hash("Different input")
|
|
|
- self.assertNotEqual(hash1, hash3)
|
|
|
-
|
|
|
- def test_get_dirs(self):
|
|
|
- # Create test directories
|
|
|
- test_dirs = ["dir1", "dir2", "dir3"]
|
|
|
- for dir_name in test_dirs:
|
|
|
- os.makedirs(os.path.join(self.test_cache_dir, dir_name))
|
|
|
-
|
|
|
- # Create a file (should be ignored)
|
|
|
- with open(os.path.join(self.test_cache_dir, "test.txt"), "w") as f:
|
|
|
- f.write("test")
|
|
|
-
|
|
|
- dirs = cache.get_dirs()
|
|
|
- self.assertEqual(len(dirs), 3)
|
|
|
- for dir_path in dirs:
|
|
|
- self.assertTrue(os.path.isdir(dir_path))
|
|
|
-
|
|
|
- def test_get_time(self):
|
|
|
- # Create test directory with time file
|
|
|
- test_dir = os.path.join(self.test_cache_dir, "test_dir")
|
|
|
- os.makedirs(test_dir)
|
|
|
- test_time = 1234567890.0
|
|
|
-
|
|
|
- with open(os.path.join(test_dir, cache.time_filename), "w") as f:
|
|
|
- f.write(str(test_time))
|
|
|
-
|
|
|
- # Test reading time
|
|
|
- result = cache.get_time(test_dir)
|
|
|
- self.assertEqual(result, test_time)
|
|
|
-
|
|
|
- # Test non-existent directory
|
|
|
- non_existent_dir = os.path.join(self.test_cache_dir, "non_existent")
|
|
|
- result = cache.get_time(non_existent_dir)
|
|
|
- self.assertEqual(result, float("inf"))
|
|
|
-
|
|
|
- def test_write_time(self):
|
|
|
- test_dir = os.path.join(self.test_cache_dir, "test_dir")
|
|
|
- os.makedirs(test_dir)
|
|
|
-
|
|
|
- cache.write_time(test_dir)
|
|
|
-
|
|
|
- self.assertTrue(os.path.exists(os.path.join(test_dir, cache.time_filename)))
|
|
|
- with open(os.path.join(test_dir, cache.time_filename)) as f:
|
|
|
- time_value = float(f.read())
|
|
|
- self.assertIsInstance(time_value, float)
|
|
|
-
|
|
|
- def test_remove_extra(self):
|
|
|
- # Create more than max_cache directories
|
|
|
- for i in range(cache.max_cache + 2):
|
|
|
- dir_path = os.path.join(self.test_cache_dir, f"dir{i}")
|
|
|
- os.makedirs(dir_path)
|
|
|
- time.sleep(0.1) # Ensure different timestamps
|
|
|
- cache.write_time(dir_path)
|
|
|
-
|
|
|
- cache.remove_extra()
|
|
|
-
|
|
|
- remaining_dirs = cache.get_dirs()
|
|
|
- self.assertLessEqual(len(remaining_dirs), cache.max_cache)
|
|
|
-
|
|
|
- def test_cache_operations(self):
|
|
|
- test_hash = "test123hash"
|
|
|
- test_para_hash = "para456hash"
|
|
|
- test_content = "Test paragraph content"
|
|
|
-
|
|
|
- # Test cache creation
|
|
|
- self.assertFalse(cache.is_cached(test_hash))
|
|
|
- cache.create_cache(test_hash)
|
|
|
- self.assertTrue(cache.is_cached(test_hash))
|
|
|
-
|
|
|
- # Test paragraph operations
|
|
|
- self.assertIsNone(cache.load_paragraph(test_hash, test_para_hash))
|
|
|
- cache.write_paragraph(test_hash, test_para_hash, test_content)
|
|
|
- self.assertEqual(cache.load_paragraph(test_hash, test_para_hash), test_content)
|
|
|
+ # Clean up
|
|
|
+ cache.clean_test_db(self.test_db)
|
|
|
+
|
|
|
+ def test_basic_set_get(self):
|
|
|
+ """Test basic set and get operations"""
|
|
|
+ cache_instance = cache.TranslationCache("test_engine")
|
|
|
+
|
|
|
+ # Test get with non-existent entry
|
|
|
+ result = cache_instance.get("hello")
|
|
|
+ self.assertIsNone(result)
|
|
|
+
|
|
|
+ # Test set and get
|
|
|
+ cache_instance.set("hello", "你好")
|
|
|
+ result = cache_instance.get("hello")
|
|
|
+ self.assertEqual(result, "你好")
|
|
|
+
|
|
|
+ def test_cache_overwrite(self):
|
|
|
+ """Test that cache entries can be overwritten"""
|
|
|
+ cache_instance = cache.TranslationCache("test_engine")
|
|
|
+
|
|
|
+ # Set initial translation
|
|
|
+ cache_instance.set("hello", "你好")
|
|
|
+
|
|
|
+ # Overwrite with new translation
|
|
|
+ cache_instance.set("hello", "您好")
|
|
|
+
|
|
|
+ # Verify the new translation is returned
|
|
|
+ result = cache_instance.get("hello")
|
|
|
+ self.assertEqual(result, "您好")
|
|
|
+
|
|
|
+ def test_non_string_params(self):
|
|
|
+ """Test that non-string parameters are automatically converted to JSON"""
|
|
|
+ params = {"model": "gpt-3.5", "temperature": 0.7}
|
|
|
+ cache_instance = cache.TranslationCache("test_engine", params)
|
|
|
+
|
|
|
+ # Test that params are converted to JSON string internally
|
|
|
+ cache_instance.set("hello", "你好")
|
|
|
+ result = cache_instance.get("hello")
|
|
|
+ self.assertEqual(result, "你好")
|
|
|
+
|
|
|
+ # Test with different param types
|
|
|
+ array_params = ["param1", "param2"]
|
|
|
+ cache_instance2 = cache.TranslationCache("test_engine", array_params)
|
|
|
+ cache_instance2.set("hello", "你好2")
|
|
|
+ self.assertEqual(cache_instance2.get("hello"), "你好2")
|
|
|
+
|
|
|
+ # Test with nested structures
|
|
|
+ nested_params = {"options": {"temp": 0.8, "models": ["a", "b"]}}
|
|
|
+ cache_instance3 = cache.TranslationCache("test_engine", nested_params)
|
|
|
+ cache_instance3.set("hello", "你好3")
|
|
|
+ self.assertEqual(cache_instance3.get("hello"), "你好3")
|
|
|
+
|
|
|
+ def test_engine_distinction(self):
|
|
|
+ """Test that cache distinguishes between different translation engines"""
|
|
|
+ cache1 = cache.TranslationCache("engine1")
|
|
|
+ cache2 = cache.TranslationCache("engine2")
|
|
|
+
|
|
|
+ # Set same text with different engines
|
|
|
+ cache1.set("hello", "你好 1")
|
|
|
+ cache2.set("hello", "你好 2")
|
|
|
+
|
|
|
+ # Verify each engine gets its own translation
|
|
|
+ self.assertEqual(cache1.get("hello"), "你好 1")
|
|
|
+ self.assertEqual(cache2.get("hello"), "你好 2")
|
|
|
+
|
|
|
+ def test_params_distinction(self):
|
|
|
+ """Test that cache distinguishes between different engine parameters"""
|
|
|
+ params1 = {"param": "value1"}
|
|
|
+ params2 = {"param": "value2"}
|
|
|
+ cache1 = cache.TranslationCache("test_engine", params1)
|
|
|
+ cache2 = cache.TranslationCache("test_engine", params2)
|
|
|
+
|
|
|
+ # Set same text with different parameters
|
|
|
+ cache1.set("hello", "你好 1")
|
|
|
+ cache2.set("hello", "你好 2")
|
|
|
+
|
|
|
+ # Verify each parameter set gets its own translation
|
|
|
+ self.assertEqual(cache1.get("hello"), "你好 1")
|
|
|
+ self.assertEqual(cache2.get("hello"), "你好 2")
|
|
|
+
|
|
|
+ def test_consistent_param_serialization(self):
|
|
|
+ """Test that dictionary parameters are consistently serialized regardless of key order"""
|
|
|
+ # Test simple dictionary
|
|
|
+ params1 = {"b": 1, "a": 2}
|
|
|
+ params2 = {"a": 2, "b": 1}
|
|
|
+ cache1 = cache.TranslationCache("test_engine", params1)
|
|
|
+ cache2 = cache.TranslationCache("test_engine", params2)
|
|
|
+ self.assertEqual(cache1.translate_engine_params, cache2.translate_engine_params)
|
|
|
+
|
|
|
+ # Test nested dictionary
|
|
|
+ params1 = {"outer2": {"inner2": 2, "inner1": 1}, "outer1": 3}
|
|
|
+ params2 = {"outer1": 3, "outer2": {"inner1": 1, "inner2": 2}}
|
|
|
+ cache1 = cache.TranslationCache("test_engine", params1)
|
|
|
+ cache2 = cache.TranslationCache("test_engine", params2)
|
|
|
+ self.assertEqual(cache1.translate_engine_params, cache2.translate_engine_params)
|
|
|
+
|
|
|
+ # Test dictionary with list of dictionaries
|
|
|
+ params1 = {"b": [{"y": 1, "x": 2}], "a": 3}
|
|
|
+ params2 = {"a": 3, "b": [{"x": 2, "y": 1}]}
|
|
|
+ cache1 = cache.TranslationCache("test_engine", params1)
|
|
|
+ cache2 = cache.TranslationCache("test_engine", params2)
|
|
|
+ self.assertEqual(cache1.translate_engine_params, cache2.translate_engine_params)
|
|
|
+
|
|
|
+ # Test that different values still produce different results
|
|
|
+ params1 = {"a": 1, "b": 2}
|
|
|
+ params2 = {"a": 2, "b": 1}
|
|
|
+ cache1 = cache.TranslationCache("test_engine", params1)
|
|
|
+ cache2 = cache.TranslationCache("test_engine", params2)
|
|
|
+ self.assertNotEqual(
|
|
|
+ cache1.translate_engine_params, cache2.translate_engine_params
|
|
|
+ )
|
|
|
+
|
|
|
+ def test_cache_with_sorted_params(self):
|
|
|
+ """Test that cache works correctly with sorted parameters"""
|
|
|
+ params1 = {"b": [{"y": 1, "x": 2}], "a": 3}
|
|
|
+ params2 = {"a": 3, "b": [{"x": 2, "y": 1}]}
|
|
|
+
|
|
|
+ # Both caches should work with the same key
|
|
|
+ cache1 = cache.TranslationCache("test_engine", params1)
|
|
|
+ cache1.set("hello", "你好")
|
|
|
+
|
|
|
+ cache2 = cache.TranslationCache("test_engine", params2)
|
|
|
+ self.assertEqual(cache2.get("hello"), "你好")
|
|
|
+
|
|
|
+ def test_append_params(self):
|
|
|
+ """Test the append_params method"""
|
|
|
+ cache_instance = cache.TranslationCache("test_engine", {"initial": "value"})
|
|
|
+
|
|
|
+ # Test appending new parameter
|
|
|
+ cache_instance.add_params("new_param", "new_value")
|
|
|
+ self.assertEqual(
|
|
|
+ cache_instance.params, {"initial": "value", "new_param": "new_value"}
|
|
|
+ )
|
|
|
+
|
|
|
+ # Test that cache with appended params works correctly
|
|
|
+ cache_instance.set("hello", "你好")
|
|
|
+ self.assertEqual(cache_instance.get("hello"), "你好")
|
|
|
+
|
|
|
+ # Test overwriting existing parameter
|
|
|
+ cache_instance.add_params("initial", "new_value")
|
|
|
+ self.assertEqual(
|
|
|
+ cache_instance.params, {"initial": "new_value", "new_param": "new_value"}
|
|
|
+ )
|
|
|
+
|
|
|
+ # Cache should work with updated params
|
|
|
+ cache_instance.set("hello2", "你好2")
|
|
|
+ self.assertEqual(cache_instance.get("hello2"), "你好2")
|
|
|
+
|
|
|
+ def test_thread_safety(self):
|
|
|
+ """Test thread safety of cache operations"""
|
|
|
+ cache_instance = cache.TranslationCache("test_engine")
|
|
|
+ lock = threading.Lock()
|
|
|
+ results = []
|
|
|
+ num_threads = multiprocessing.cpu_count()
|
|
|
+ items_per_thread = 100
|
|
|
+
|
|
|
+ def generate_random_text(length=10):
|
|
|
+ return "".join(
|
|
|
+ random.choices(string.ascii_letters + string.digits, k=length)
|
|
|
+ )
|
|
|
+
|
|
|
+ def worker():
|
|
|
+ thread_results = [] # 线程本地存储结果
|
|
|
+ for _ in range(items_per_thread):
|
|
|
+ text = generate_random_text()
|
|
|
+ translation = f"翻译_{text}"
|
|
|
+
|
|
|
+ # Write operation
|
|
|
+ cache_instance.set(text, translation)
|
|
|
+
|
|
|
+ # Read operation - verify our own write
|
|
|
+ result = cache_instance.get(text)
|
|
|
+ thread_results.append((text, result))
|
|
|
+
|
|
|
+ # 所有操作完成后,一次性加锁并追加结果
|
|
|
+ with lock:
|
|
|
+ results.extend(thread_results)
|
|
|
+
|
|
|
+ # Create threads equal to CPU core count
|
|
|
+ threads = []
|
|
|
+ for _ in range(num_threads):
|
|
|
+ thread = threading.Thread(target=worker)
|
|
|
+ threads.append(thread)
|
|
|
+ thread.start()
|
|
|
+
|
|
|
+ # Wait for all threads to complete
|
|
|
+ for thread in threads:
|
|
|
+ thread.join()
|
|
|
+
|
|
|
+ # Verify all operations were successful
|
|
|
+ expected_total = num_threads * items_per_thread
|
|
|
+ self.assertEqual(len(results), expected_total)
|
|
|
+
|
|
|
+ # Verify each thread got its correct value
|
|
|
+ for text, result in results:
|
|
|
+ expected = f"翻译_{text}"
|
|
|
+ self.assertEqual(result, expected)
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|