|
|
@@ -1,6 +1,8 @@
|
|
|
import os
|
|
|
import json
|
|
|
from peewee import Model, SqliteDatabase, AutoField, CharField, TextField, SQL
|
|
|
+from typing import Optional
|
|
|
+
|
|
|
|
|
|
# we don't init the database here
|
|
|
db = SqliteDatabase(None)
|
|
|
@@ -33,29 +35,44 @@ class TranslationCache:
|
|
|
@staticmethod
|
|
|
def _sort_dict_recursively(obj):
|
|
|
if isinstance(obj, dict):
|
|
|
- return {k: TranslationCache._sort_dict_recursively(v) for k in sorted(obj.keys()) for v in [obj[k]]}
|
|
|
+ return {
|
|
|
+ k: TranslationCache._sort_dict_recursively(v)
|
|
|
+ for k in sorted(obj.keys())
|
|
|
+ for v in [obj[k]]
|
|
|
+ }
|
|
|
elif isinstance(obj, list):
|
|
|
return [TranslationCache._sort_dict_recursively(item) for item in obj]
|
|
|
return obj
|
|
|
|
|
|
- def __init__(self, translate_engine, translate_engine_params):
|
|
|
+ def __init__(self, translate_engine: str, translate_engine_params: dict = None):
|
|
|
self.translate_engine = translate_engine
|
|
|
self.update_params(translate_engine_params)
|
|
|
|
|
|
- def update_params(self, params):
|
|
|
- if not isinstance(params, str):
|
|
|
- params = self._sort_dict_recursively(params)
|
|
|
- params = json.dumps(params)
|
|
|
- self.translate_engine_params = params
|
|
|
+ # The program typically starts multi-threaded translation
|
|
|
+ # only after cache parameters are fully configured,
|
|
|
+ # so thread safety doesn't need to be considered here.
|
|
|
+ def update_params(self, params: dict = None):
|
|
|
+ if params is None:
|
|
|
+ params = {}
|
|
|
+ self.params = params
|
|
|
+ params = self._sort_dict_recursively(params)
|
|
|
+ self.translate_engine_params = json.dumps(params)
|
|
|
+
|
|
|
+ def append_params(self, k: str, v):
|
|
|
+ self.params[k] = v
|
|
|
+ self.update_params(self.params)
|
|
|
|
|
|
- def get(self, original_text):
|
|
|
- return _TranslationCache.get_or_none(
|
|
|
+ # Since peewee and the underlying sqlite are thread-safe,
|
|
|
+ # get and set operations don't need locks.
|
|
|
+ def get(self, original_text: str) -> Optional[str]:
|
|
|
+ result = _TranslationCache.get_or_none(
|
|
|
translate_engine=self.translate_engine,
|
|
|
translate_engine_params=self.translate_engine_params,
|
|
|
original_text=original_text,
|
|
|
)
|
|
|
+ return result.translation if result else None
|
|
|
|
|
|
- def set(self, original_text, translation):
|
|
|
+ def set(self, original_text: str, translation: str):
|
|
|
_TranslationCache.create(
|
|
|
translate_engine=self.translate_engine,
|
|
|
translate_engine_params=self.translate_engine_params,
|
|
|
@@ -81,4 +98,34 @@ def init_db(remove_exists=False):
|
|
|
db.create_tables([_TranslationCache], safe=True)
|
|
|
|
|
|
|
|
|
+def init_test_db():
|
|
|
+ import tempfile
|
|
|
+ cache_db_path = tempfile.mktemp(suffix=".db")
|
|
|
+ test_db = SqliteDatabase(
|
|
|
+ cache_db_path,
|
|
|
+ pragmas={
|
|
|
+ "journal_mode": "wal",
|
|
|
+ "busy_timeout": 1000,
|
|
|
+ },
|
|
|
+ )
|
|
|
+ test_db.bind([_TranslationCache], bind_refs=False, bind_backrefs=False)
|
|
|
+ test_db.connect()
|
|
|
+ test_db.create_tables([_TranslationCache], safe=True)
|
|
|
+ return test_db
|
|
|
+
|
|
|
+
|
|
|
+def clean_test_db(test_db):
|
|
|
+ test_db.drop_tables([_TranslationCache])
|
|
|
+ test_db.close()
|
|
|
+ db_path = test_db.database
|
|
|
+ if os.path.exists(db_path):
|
|
|
+ os.remove(test_db.database)
|
|
|
+ wal_path = db_path + "-wal"
|
|
|
+ if os.path.exists(wal_path):
|
|
|
+ os.remove(wal_path)
|
|
|
+ shm_path = db_path + "-shm"
|
|
|
+ if os.path.exists(shm_path):
|
|
|
+ os.remove(shm_path)
|
|
|
+
|
|
|
+
|
|
|
init_db()
|