Просмотр исходного кода

♻️ refactor(cache): enhance parameter handling in translation cache

- update parameter type annotations for clarity
- improve default parameter handling in update_params method
- ensure thread safety considerations are noted in comments
awwaawwa 1 год назад
Родитель
Сommit
27971979df
1 измененных файлов с 57 добавлено и 10 удалено
  1. 57 10
      pdf2zh/cache.py

+ 57 - 10
pdf2zh/cache.py

@@ -1,6 +1,8 @@
 import os
 import json
 from peewee import Model, SqliteDatabase, AutoField, CharField, TextField, SQL
+from typing import Optional
+
 
 # we don't init the database here
 db = SqliteDatabase(None)
@@ -33,29 +35,44 @@ class TranslationCache:
     @staticmethod
     def _sort_dict_recursively(obj):
         if isinstance(obj, dict):
-            return {k: TranslationCache._sort_dict_recursively(v) for k in sorted(obj.keys()) for v in [obj[k]]}
+            return {
+                k: TranslationCache._sort_dict_recursively(v)
+                for k in sorted(obj.keys())
+                for v in [obj[k]]
+            }
         elif isinstance(obj, list):
             return [TranslationCache._sort_dict_recursively(item) for item in obj]
         return obj
 
-    def __init__(self, translate_engine, translate_engine_params):
+    def __init__(self, translate_engine: str, translate_engine_params: dict = None):
         self.translate_engine = translate_engine
         self.update_params(translate_engine_params)
 
-    def update_params(self, params):
-        if not isinstance(params, str):
-            params = self._sort_dict_recursively(params)
-            params = json.dumps(params)
-        self.translate_engine_params = params
+    # The program typically starts multi-threaded translation
+    # only after cache parameters are fully configured,
+    # so thread safety doesn't need to be considered here.
+    def update_params(self, params: dict = None):
+        if params is None:
+            params = {}
+        self.params = params
+        params = self._sort_dict_recursively(params)
+        self.translate_engine_params = json.dumps(params)
+
+    def append_params(self, k: str, v):
+        self.params[k] = v
+        self.update_params(self.params)
 
-    def get(self, original_text):
-        return _TranslationCache.get_or_none(
+    # Since peewee and the underlying sqlite are thread-safe,
+    # get and set operations don't need locks.
+    def get(self, original_text: str) -> Optional[str]:
+        result = _TranslationCache.get_or_none(
             translate_engine=self.translate_engine,
             translate_engine_params=self.translate_engine_params,
             original_text=original_text,
         )
+        return result.translation if result else None
 
-    def set(self, original_text, translation):
+    def set(self, original_text: str, translation: str):
         _TranslationCache.create(
             translate_engine=self.translate_engine,
             translate_engine_params=self.translate_engine_params,
@@ -81,4 +98,34 @@ def init_db(remove_exists=False):
     db.create_tables([_TranslationCache], safe=True)
 
 
+def init_test_db():
+    import tempfile
+    cache_db_path = tempfile.mktemp(suffix=".db")
+    test_db = SqliteDatabase(
+        cache_db_path,
+        pragmas={
+            "journal_mode": "wal",
+            "busy_timeout": 1000,
+        },
+    )
+    test_db.bind([_TranslationCache], bind_refs=False, bind_backrefs=False)
+    test_db.connect()
+    test_db.create_tables([_TranslationCache], safe=True)
+    return test_db
+
+
+def clean_test_db(test_db):
+    test_db.drop_tables([_TranslationCache])
+    test_db.close()
+    db_path = test_db.database
+    if os.path.exists(db_path):
+        os.remove(test_db.database)
+    wal_path = db_path + "-wal"
+    if os.path.exists(wal_path):
+        os.remove(wal_path)
+    shm_path = db_path + "-shm"
+    if os.path.exists(shm_path):
+        os.remove(shm_path)
+
+
 init_db()