test_translator.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. import unittest
  2. from pdf2zh import cache
  3. import threading
  4. import multiprocessing
  5. import random
  6. import string
  7. class TestCache(unittest.TestCase):
  8. def setUp(self):
  9. self.test_db = cache.init_test_db()
  10. def tearDown(self):
  11. # Clean up
  12. cache.clean_test_db(self.test_db)
  13. def test_basic_set_get(self):
  14. """Test basic set and get operations"""
  15. cache_instance = cache.TranslationCache("test_engine")
  16. # Test get with non-existent entry
  17. result = cache_instance.get("hello")
  18. self.assertIsNone(result)
  19. # Test set and get
  20. cache_instance.set("hello", "你好")
  21. result = cache_instance.get("hello")
  22. self.assertEqual(result, "你好")
  23. def test_cache_overwrite(self):
  24. """Test that cache entries can be overwritten"""
  25. cache_instance = cache.TranslationCache("test_engine")
  26. # Set initial translation
  27. cache_instance.set("hello", "你好")
  28. # Overwrite with new translation
  29. cache_instance.set("hello", "您好")
  30. # Verify the new translation is returned
  31. result = cache_instance.get("hello")
  32. self.assertEqual(result, "您好")
  33. def test_non_string_params(self):
  34. """Test that non-string parameters are automatically converted to JSON"""
  35. params = {"model": "gpt-3.5", "temperature": 0.7}
  36. cache_instance = cache.TranslationCache("test_engine", params)
  37. # Test that params are converted to JSON string internally
  38. cache_instance.set("hello", "你好")
  39. result = cache_instance.get("hello")
  40. self.assertEqual(result, "你好")
  41. # Test with different param types
  42. array_params = ["param1", "param2"]
  43. cache_instance2 = cache.TranslationCache("test_engine", array_params)
  44. cache_instance2.set("hello", "你好2")
  45. self.assertEqual(cache_instance2.get("hello"), "你好2")
  46. # Test with nested structures
  47. nested_params = {"options": {"temp": 0.8, "models": ["a", "b"]}}
  48. cache_instance3 = cache.TranslationCache("test_engine", nested_params)
  49. cache_instance3.set("hello", "你好3")
  50. self.assertEqual(cache_instance3.get("hello"), "你好3")
  51. def test_engine_distinction(self):
  52. """Test that cache distinguishes between different translation engines"""
  53. cache1 = cache.TranslationCache("engine1")
  54. cache2 = cache.TranslationCache("engine2")
  55. # Set same text with different engines
  56. cache1.set("hello", "你好 1")
  57. cache2.set("hello", "你好 2")
  58. # Verify each engine gets its own translation
  59. self.assertEqual(cache1.get("hello"), "你好 1")
  60. self.assertEqual(cache2.get("hello"), "你好 2")
  61. def test_params_distinction(self):
  62. """Test that cache distinguishes between different engine parameters"""
  63. params1 = {"param": "value1"}
  64. params2 = {"param": "value2"}
  65. cache1 = cache.TranslationCache("test_engine", params1)
  66. cache2 = cache.TranslationCache("test_engine", params2)
  67. # Set same text with different parameters
  68. cache1.set("hello", "你好 1")
  69. cache2.set("hello", "你好 2")
  70. # Verify each parameter set gets its own translation
  71. self.assertEqual(cache1.get("hello"), "你好 1")
  72. self.assertEqual(cache2.get("hello"), "你好 2")
  73. def test_consistent_param_serialization(self):
  74. """Test that dictionary parameters are consistently serialized regardless of key order"""
  75. # Test simple dictionary
  76. params1 = {"b": 1, "a": 2}
  77. params2 = {"a": 2, "b": 1}
  78. cache1 = cache.TranslationCache("test_engine", params1)
  79. cache2 = cache.TranslationCache("test_engine", params2)
  80. self.assertEqual(cache1.translate_engine_params, cache2.translate_engine_params)
  81. # Test nested dictionary
  82. params1 = {"outer2": {"inner2": 2, "inner1": 1}, "outer1": 3}
  83. params2 = {"outer1": 3, "outer2": {"inner1": 1, "inner2": 2}}
  84. cache1 = cache.TranslationCache("test_engine", params1)
  85. cache2 = cache.TranslationCache("test_engine", params2)
  86. self.assertEqual(cache1.translate_engine_params, cache2.translate_engine_params)
  87. # Test dictionary with list of dictionaries
  88. params1 = {"b": [{"y": 1, "x": 2}], "a": 3}
  89. params2 = {"a": 3, "b": [{"x": 2, "y": 1}]}
  90. cache1 = cache.TranslationCache("test_engine", params1)
  91. cache2 = cache.TranslationCache("test_engine", params2)
  92. self.assertEqual(cache1.translate_engine_params, cache2.translate_engine_params)
  93. # Test that different values still produce different results
  94. params1 = {"a": 1, "b": 2}
  95. params2 = {"a": 2, "b": 1}
  96. cache1 = cache.TranslationCache("test_engine", params1)
  97. cache2 = cache.TranslationCache("test_engine", params2)
  98. self.assertNotEqual(
  99. cache1.translate_engine_params, cache2.translate_engine_params
  100. )
  101. def test_cache_with_sorted_params(self):
  102. """Test that cache works correctly with sorted parameters"""
  103. params1 = {"b": [{"y": 1, "x": 2}], "a": 3}
  104. params2 = {"a": 3, "b": [{"x": 2, "y": 1}]}
  105. # Both caches should work with the same key
  106. cache1 = cache.TranslationCache("test_engine", params1)
  107. cache1.set("hello", "你好")
  108. cache2 = cache.TranslationCache("test_engine", params2)
  109. self.assertEqual(cache2.get("hello"), "你好")
  110. def test_append_params(self):
  111. """Test the append_params method"""
  112. cache_instance = cache.TranslationCache("test_engine", {"initial": "value"})
  113. # Test appending new parameter
  114. cache_instance.add_params("new_param", "new_value")
  115. self.assertEqual(
  116. cache_instance.params, {"initial": "value", "new_param": "new_value"}
  117. )
  118. # Test that cache with appended params works correctly
  119. cache_instance.set("hello", "你好")
  120. self.assertEqual(cache_instance.get("hello"), "你好")
  121. # Test overwriting existing parameter
  122. cache_instance.add_params("initial", "new_value")
  123. self.assertEqual(
  124. cache_instance.params, {"initial": "new_value", "new_param": "new_value"}
  125. )
  126. # Cache should work with updated params
  127. cache_instance.set("hello2", "你好2")
  128. self.assertEqual(cache_instance.get("hello2"), "你好2")
  129. def test_thread_safety(self):
  130. """Test thread safety of cache operations"""
  131. cache_instance = cache.TranslationCache("test_engine")
  132. lock = threading.Lock()
  133. results = []
  134. num_threads = multiprocessing.cpu_count()
  135. items_per_thread = 100
  136. def generate_random_text(length=10):
  137. return "".join(
  138. random.choices(string.ascii_letters + string.digits, k=length)
  139. )
  140. def worker():
  141. thread_results = [] # 线程本地存储结果
  142. for _ in range(items_per_thread):
  143. text = generate_random_text()
  144. translation = f"翻译_{text}"
  145. # Write operation
  146. cache_instance.set(text, translation)
  147. # Read operation - verify our own write
  148. result = cache_instance.get(text)
  149. thread_results.append((text, result))
  150. # 所有操作完成后,一次性加锁并追加结果
  151. with lock:
  152. results.extend(thread_results)
  153. # Create threads equal to CPU core count
  154. threads = []
  155. for _ in range(num_threads):
  156. thread = threading.Thread(target=worker)
  157. threads.append(thread)
  158. thread.start()
  159. # Wait for all threads to complete
  160. for thread in threads:
  161. thread.join()
  162. # Verify all operations were successful
  163. expected_total = num_threads * items_per_thread
  164. self.assertEqual(len(results), expected_total)
  165. # Verify each thread got its correct value
  166. for text, result in results:
  167. expected = f"翻译_{text}"
  168. self.assertEqual(result, expected)
  169. if __name__ == "__main__":
  170. unittest.main()