Explorar o código

Merge branch 'main' into main

Byaidu hai 1 ano
pai
achega
c9d9e675bb

+ 0 - 2
.github/workflows/python-build.yml

@@ -2,8 +2,6 @@ name: Test and Build Python Package
 
 on:
   push:
-    branches:
-      - main
   pull_request:
 
 jobs:

+ 19 - 15
docs/README_ja-JP.md

@@ -2,7 +2,7 @@
 
 [English](../README.md) | [简体中文](README_zh-CN.md) | 日本語
 
-<img src="./docs/images/banner.png" width="320px"  alt="PDF2ZH"/>  
+<img src="./images/banner.png" width="320px"  alt="PDF2ZH"/>  
 
 <h2 id="title">PDFMathTranslate</h2>
 
@@ -56,7 +56,7 @@
 <h2 id="preview">プレビュー</h2>
 
 <div align="center">
-<img src="./docs/images/preview.gif" width="80%"/>
+<img src="./images/preview.gif" width="80%"/>
 </div>
 
 <h2 id="demo">公共サービス 🌟</h2>
@@ -122,9 +122,9 @@ Python環境を事前にインストールする必要はありません
     http://localhost:7860/
     ```
 
-    <img src="./docs/images/gui.gif" width="500"/>
+    <img src="./images/gui.gif" width="500"/>
 
-詳細については、[GUIのドキュメント](./docs/README_GUI.md) を参照してください。
+詳細については、[GUIのドキュメント](./README_GUI.md) を参照してください。
 
 <h3 id="docker">方法4. Docker</h3>
 
@@ -158,7 +158,7 @@ Python環境を事前にインストールする必要はありません
 
 コマンドラインで翻訳コマンドを実行し、現在の作業ディレクトリに翻訳されたドキュメント `example-mono.pdf` とバイリンガルドキュメント `example-dual.pdf` を生成します。デフォルトではGoogle翻訳サービスを使用します。
 
-<img src="./docs/images/cmd.explained.png" width="580px"  alt="cmd"/>  
+<img src="./images/cmd.explained.png" width="580px"  alt="cmd"/>  
 
 以下の表に、参考のためにすべての高度なオプションをリストしました:
 
@@ -263,16 +263,18 @@ pdf2zh example.pdf -f "(CM[^R]|(MS|XY|MT|BL|RM|EU|LA|RS)[A-Z]|LINE|LCIRCLE|TeX-|
 pdf2zh example.pdf -t 1
 ```
 
-<h3 id="prompt">custom prompt</h3>
-(need Japenese translation)
-Use `--prompt` to specify which prompt to use in llm:
+<h3 id="prompt">カスタム プロンプト</h3>
+
+`--prompt`を使用して、LLMで使用するプロンプトを指定します:
+
 ```bash
 pdf2zh example.pdf -pr prompt.txt
 ```
 
 
-example prompt.txt
-```
+`prompt.txt`の例:
+
+```txt
 [
     {
         "role": "system",
@@ -286,12 +288,14 @@ example prompt.txt
 ```
 
 
-In custom prompt file, there are three variables can be used.
-|**variables**|**comment**|
+カスタムプロンプトファイルでは、以下の3つの変数が使用できます。
+
+|**変数**|**内容**|
 |-|-|
-|`lang_in`|input language|
-|`lang_out`|output language|
-|`text`|text need to be translated|
+|`lang_in`|ソース言語|
+|`lang_out`|ターゲット言語|
+|`text`|翻訳するテキスト|
+
 <h2 id="todo">API</h2>
 
 ### Python

+ 5 - 5
docs/README_zh-CN.md

@@ -2,7 +2,7 @@
 
 [English](../README.md) | 简体中文 | [日本語](README_ja-JP.md)
 
-<img src="./docs/images/banner.png" width="320px"  alt="PDF2ZH"/>  
+<img src="./images/banner.png" width="320px"  alt="PDF2ZH"/>  
 
 <h2 id="title">PDFMathTranslate</h2>
 
@@ -56,7 +56,7 @@
 <h2 id="preview">效果预览</h2>
 
 <div align="center">
-<img src="./docs/images/preview.gif" width="80%"/>
+<img src="./images/preview.gif" width="80%"/>
 </div>
 
 <h2 id="demo">在线演示 🌟</h2>
@@ -121,9 +121,9 @@ set HF_ENDPOINT=https://hf-mirror.com
     http://localhost:7860/
     ```
 
-    <img src="./docs/images/gui.gif" width="500"/>
+    <img src="./images/gui.gif" width="500"/>
 
-查看 [documentation for GUI](./docs/README_GUI.md) 获取细节说明
+查看 [documentation for GUI](./README_GUI.md) 获取细节说明
 
 <h3 id="docker">方法四、容器化部署</h3>
 
@@ -157,7 +157,7 @@ set HF_ENDPOINT=https://hf-mirror.com
 
 在命令行中执行翻译命令,在当前工作目录下生成译文文档 `example-mono.pdf` 和双语对照文档 `example-dual.pdf`,默认使用 Google 翻译服务
 
-<img src="./docs/images/cmd.explained.png" width="580px"  alt="cmd"/>  
+<img src="./images/cmd.explained.png" width="580px"  alt="cmd"/>  
 
 我们在下表中列出了所有高级选项,以供参考:
 

+ 140 - 90
pdf2zh/cache.py

@@ -1,91 +1,141 @@
-import tempfile
 import os
-import time
-import hashlib
-import shutil
-
-cache_dir = os.path.join(tempfile.gettempdir(), "cache")
-os.makedirs(cache_dir, exist_ok=True)
-time_filename = "update_time"
-max_cache = 5
-
-
-def deterministic_hash(obj):
-    hash_object = hashlib.sha256()
-    hash_object.update(str(obj).encode())
-    return hash_object.hexdigest()[0:20]
-
-
-def get_dirs():
-    dirs = [
-        os.path.join(cache_dir, dir)
-        for dir in os.listdir(cache_dir)
-        if os.path.isdir(os.path.join(cache_dir, dir))
-    ]
-    return dirs
-
-
-def get_time(dir):
-    try:
-        timefile = os.path.join(dir, time_filename)
-        t = float(open(timefile, encoding="utf-8").read())
-        return t
-    except FileNotFoundError:
-        # handle the error as needed, for now we'll just return a default value
-        return float(
-            "inf"
-        )  # This ensures that this directory will be the first to be removed if required
-
-
-def write_time(dir):
-    timefile = os.path.join(dir, time_filename)
-    t = time.time()
-    print(t, file=open(timefile, "w", encoding="utf-8"), end="")
-
-
-def argmin(iterable):
-    return min(enumerate(iterable), key=lambda x: x[1])[0]
-
-
-def remove_extra():
-    dirs = get_dirs()
-    for dir in dirs:
-        if not os.path.isdir(
-            dir
-        ):  # This line might be redundant now, as get_dirs() ensures only directories are returned
-            os.remove(dir)
-        try:
-            get_time(dir)
-        except BaseException:
-            shutil.rmtree(dir)
-    while True:
-        dirs = get_dirs()
-        if len(dirs) <= max_cache:
-            break
-        times = [get_time(dir) for dir in dirs]
-        arg = argmin(times)
-        shutil.rmtree(dirs[arg])
-
-
-def is_cached(hash_key):
-    dir = os.path.join(cache_dir, hash_key)
-    return os.path.exists(dir)
-
-
-def create_cache(hash_key):
-    dir = os.path.join(cache_dir, hash_key)
-    os.makedirs(dir, exist_ok=True)
-    write_time(dir)
-
-
-def load_paragraph(hash_key, hash_key_paragraph):
-    filename = os.path.join(cache_dir, hash_key, hash_key_paragraph)
-    if os.path.exists(filename):
-        return open(filename, encoding="utf-8").read()
-    else:
-        return None
-
-
-def write_paragraph(hash_key, hash_key_paragraph, paragraph):
-    filename = os.path.join(cache_dir, hash_key, hash_key_paragraph)
-    print(paragraph, file=open(filename, "w", encoding="utf-8"), end="")
+import json
+from peewee import Model, SqliteDatabase, AutoField, CharField, TextField, SQL
+from typing import Optional
+
+
+# we don't init the database here
+db = SqliteDatabase(None)
+
+
+class _TranslationCache(Model):
+    id = AutoField()
+    translate_engine = CharField(max_length=20)
+    translate_engine_params = TextField()
+    original_text = TextField()
+    translation = TextField()
+
+    class Meta:
+        database = db
+        constraints = [
+            SQL(
+                """
+            UNIQUE (
+                translate_engine,
+                translate_engine_params,
+                original_text
+                )
+            ON CONFLICT REPLACE
+            """
+            )
+        ]
+
+
+class TranslationCache:
+    @staticmethod
+    def _sort_dict_recursively(obj):
+        if isinstance(obj, dict):
+            return {
+                k: TranslationCache._sort_dict_recursively(v)
+                for k in sorted(obj.keys())
+                for v in [obj[k]]
+            }
+        elif isinstance(obj, list):
+            return [TranslationCache._sort_dict_recursively(item) for item in obj]
+        return obj
+
+    def __init__(self, translate_engine: str, translate_engine_params: dict = None):
+        assert (
+            len(translate_engine) < 20
+        ), "current cache require translate engine name less than 20 characters"
+        self.translate_engine = translate_engine
+        self.replace_params(translate_engine_params)
+
+    # The program typically starts multi-threaded translation
+    # only after cache parameters are fully configured,
+    # so thread safety doesn't need to be considered here.
+    def replace_params(self, params: dict = None):
+        if params is None:
+            params = {}
+        self.params = params
+        params = self._sort_dict_recursively(params)
+        self.translate_engine_params = json.dumps(params)
+
+    def update_params(self, params: dict = None):
+        if params is None:
+            params = {}
+        self.params.update(params)
+        self.replace_params(self.params)
+
+    def add_params(self, k: str, v):
+        self.params[k] = v
+        self.replace_params(self.params)
+
+    # Since peewee and the underlying sqlite are thread-safe,
+    # get and set operations don't need locks.
+    def get(self, original_text: str) -> Optional[str]:
+        result = _TranslationCache.get_or_none(
+            translate_engine=self.translate_engine,
+            translate_engine_params=self.translate_engine_params,
+            original_text=original_text,
+        )
+        return result.translation if result else None
+
+    def set(self, original_text: str, translation: str):
+        _TranslationCache.create(
+            translate_engine=self.translate_engine,
+            translate_engine_params=self.translate_engine_params,
+            original_text=original_text,
+            translation=translation,
+        )
+
+
+def init_db(remove_exists=False):
+    cache_folder = os.path.join(os.path.expanduser("~"), ".cache", "pdf2zh")
+    os.makedirs(cache_folder, exist_ok=True)
+    # The current version does not support database migration, so add the version number to the file name.
+    cache_db_path = os.path.join(cache_folder, "cache.v1.db")
+    if remove_exists and os.path.exists(cache_db_path):
+        os.remove(cache_db_path)
+    db.init(
+        cache_db_path,
+        pragmas={
+            "journal_mode": "wal",
+            "busy_timeout": 1000,
+        },
+    )
+    db.create_tables([_TranslationCache], safe=True)
+
+
+def init_test_db():
+    import tempfile
+
+    cache_db_path = tempfile.mktemp(suffix=".db")
+    test_db = SqliteDatabase(
+        cache_db_path,
+        pragmas={
+            "journal_mode": "wal",
+            "busy_timeout": 1000,
+        },
+    )
+    test_db.bind([_TranslationCache], bind_refs=False, bind_backrefs=False)
+    test_db.connect()
+    test_db.create_tables([_TranslationCache], safe=True)
+    return test_db
+
+
+def clean_test_db(test_db):
+    test_db.drop_tables([_TranslationCache])
+    test_db.close()
+    db_path = test_db.database
+    if os.path.exists(db_path):
+        os.remove(test_db.database)
+    wal_path = db_path + "-wal"
+    if os.path.exists(wal_path):
+        os.remove(wal_path)
+    shm_path = db_path + "-shm"
+    if os.path.exists(shm_path):
+        os.remove(shm_path)
+
+
+init_db()

+ 1 - 10
pdf2zh/converter.py

@@ -17,7 +17,6 @@ import concurrent.futures
 import numpy as np
 import unicodedata
 from tenacity import retry, wait_fixed
-from pdf2zh import cache
 from pdf2zh.translator import (
     AzureOpenAITranslator,
     BaseTranslator,
@@ -329,21 +328,13 @@ class TranslateConverter(PDFConverterEx):
         ############################################################
         # B. 段落翻译
         log.debug("\n==========[SSTACK]==========\n")
-        hash_key = cache.deterministic_hash("PDFMathTranslate")
-        cache.create_cache(hash_key)
 
         @retry(wait=wait_fixed(1))
         def worker(s: str):  # 多线程翻译
             if not s.strip() or re.match(r"^\{v\d+\}$", s):  # 空白和公式不翻译
                 return s
             try:
-                hash_key_paragraph = cache.deterministic_hash(
-                    (s, str(self.translator))
-                )
-                new = cache.load_paragraph(hash_key, hash_key_paragraph)  # 查询缓存
-                if new is None:
-                    new = self.translator.translate(s)
-                    cache.write_paragraph(hash_key, hash_key_paragraph, new)
+                new = self.translator.translate(s)
                 return new
             except BaseException as e:
                 if log.isEnabledFor(logging.DEBUG):

+ 74 - 16
pdf2zh/translator.py

@@ -8,6 +8,7 @@ import deepl
 import ollama
 import openai
 import requests
+from pdf2zh.cache import TranslationCache
 from azure.ai.translation.text import TextTranslationClient
 from azure.core.credentials import AzureKeyCredential
 from tencentcloud.common import credential
@@ -29,6 +30,7 @@ class BaseTranslator:
     envs = {}
     lang_map = {}
     CustomPrompt = False
+    ignore_cache = False
 
     def __init__(self, lang_in, lang_out, model):
         lang_in = self.lang_map.get(lang_in.lower(), lang_in)
@@ -37,6 +39,15 @@ class BaseTranslator:
         self.lang_out = lang_out
         self.model = model
 
+        self.cache = TranslationCache(
+            self.name,
+            {
+                "lang_in": lang_in,
+                "lang_out": lang_out,
+                "model": model,
+            },
+        )
+
     def set_envs(self, envs):
         # Detach from self.__class__.envs
         # Cannot use self.envs = copy(self.__class__.envs)
@@ -49,8 +60,36 @@ class BaseTranslator:
             for key in envs:
                 self.envs[key] = envs[key]
 
-    def translate(self, text):
-        pass
+    def add_cache_impact_parameters(self, k: str, v):
+        """
+        Add parameters that affect the translation quality to distinguish the translation effects under different parameters.
+        :param k: key
+        :param v: value
+        """
+        self.cache.add_params(k, v)
+
+    def translate(self, text, ignore_cache=False):
+        """
+        Translate the text, and the other part should call this method.
+        :param text: text to translate
+        :return: translated text
+        """
+        if not (self.ignore_cache or ignore_cache):
+            cache = self.cache.get(text)
+            if cache is not None:
+                return cache
+
+        translation = self.do_translate(text)
+        self.cache.set(text, translation)
+        return translation
+
+    def do_translate(self, text):
+        """
+        Actual translate text, override this method
+        :param text: text to translate
+        :return: translated text
+        """
+        raise NotImplementedError
 
     def prompt(self, text, prompt):
         if prompt:
@@ -88,7 +127,7 @@ class GoogleTranslator(BaseTranslator):
             "User-Agent": "Mozilla/4.0 (compatible;MSIE 6.0;Windows NT 5.1;SV1;.NET CLR 1.1.4322;.NET CLR 2.0.50727;.NET CLR 3.0.04506.30)"  # noqa: E501
         }
 
-    def translate(self, text):
+    def do_translate(self, text):
         text = text[:5000]  # google translate max length
         response = self.session.get(
             self.endpoint,
@@ -119,7 +158,7 @@ class BingTranslator(BaseTranslator):
             "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36 Edg/131.0.0.0",  # noqa: E501
         }
 
-    def findSID(self):
+    def find_sid(self):
         response = self.session.get(self.endpoint)
         response.raise_for_status()
         url = response.url[:-10]
@@ -130,9 +169,9 @@ class BingTranslator(BaseTranslator):
         )[0]
         return url, ig, iid, key, token
 
-    def translate(self, text):
+    def do_translate(self, text):
         text = text[:1000]  # bing translate max length
-        url, ig, iid, key, token = self.findSID()
+        url, ig, iid, key, token = self.find_sid()
         response = self.session.post(
             f"{url}ttranslatev3?IG={ig}&IID={iid}",
             data={
@@ -162,7 +201,7 @@ class DeepLTranslator(BaseTranslator):
         auth_key = self.envs["DEEPL_AUTH_KEY"]
         self.client = deepl.Translator(auth_key)
 
-    def translate(self, text):
+    def do_translate(self, text):
         response = self.client.translate_text(
             text, target_lang=self.lang_out, source_lang=self.lang_in
         )
@@ -183,7 +222,7 @@ class DeepLXTranslator(BaseTranslator):
         self.endpoint = self.envs["DEEPLX_ENDPOINT"]
         self.session = requests.Session()
 
-    def translate(self, text):
+    def do_translate(self, text):
         response = self.session.post(
             self.endpoint,
             json={
@@ -213,8 +252,11 @@ class OllamaTranslator(BaseTranslator):
         self.options = {"temperature": 0}  # 随机采样可能会打断公式标记
         self.client = ollama.Client()
         self.prompttext = prompt
+        self.add_cache_impact_parameters("temperature", self.options["temperature"])
+        if prompt:
+            self.add_cache_impact_parameters("prompt", prompt)
 
-    def translate(self, text):
+    def do_translate(self, text):
         maxlen = max(2000, len(text) * 5)
         for model in self.model.split(";"):
             try:
@@ -263,8 +305,11 @@ class OpenAITranslator(BaseTranslator):
         self.options = {"temperature": 0}  # 随机采样可能会打断公式标记
         self.client = openai.OpenAI(base_url=base_url, api_key=api_key)
         self.prompttext = prompt
+        self.add_cache_impact_parameters("temperature", self.options["temperature"])
+        if prompt:
+            self.add_cache_impact_parameters("prompt", prompt)
 
-    def translate(self, text) -> str:
+    def do_translate(self, text) -> str:
         response = self.client.chat.completions.create(
             model=self.model,
             **self.options,
@@ -305,8 +350,11 @@ class AzureOpenAITranslator(BaseTranslator):
             api_key=api_key,
         )
         self.prompttext = prompt
+        self.add_cache_impact_parameters("temperature", self.options["temperature"])
+        if prompt:
+            self.add_cache_impact_parameters("prompt", prompt)
 
-    def translate(self, text) -> str:
+    def do_translate(self, text) -> str:
         response = self.client.chat.completions.create(
             model=self.model,
             **self.options,
@@ -341,6 +389,8 @@ class ModelScopeTranslator(OpenAITranslator):
             model = self.envs["MODELSCOPE_MODEL"]
         super().__init__(lang_in, lang_out, model, base_url=base_url, api_key=api_key)
         self.prompttext = prompt
+        if prompt:
+            self.add_cache_impact_parameters("prompt", prompt)
 
 
 class ZhipuTranslator(OpenAITranslator):
@@ -360,8 +410,10 @@ class ZhipuTranslator(OpenAITranslator):
             model = self.envs["ZHIPU_MODEL"]
         super().__init__(lang_in, lang_out, model, base_url=base_url, api_key=api_key)
         self.prompttext = prompt
+        if prompt:
+            self.add_cache_impact_parameters("prompt", prompt)
 
-    def translate(self, text) -> str:
+    def do_translate(self, text) -> str:
         try:
             response = self.client.chat.completions.create(
                 model=self.model,
@@ -395,6 +447,8 @@ class SiliconTranslator(OpenAITranslator):
             model = self.envs["SILICON_MODEL"]
         super().__init__(lang_in, lang_out, model, base_url=base_url, api_key=api_key)
         self.prompttext = prompt
+        if prompt:
+            self.add_cache_impact_parameters("prompt", prompt)
 
 
 class GeminiTranslator(OpenAITranslator):
@@ -414,6 +468,8 @@ class GeminiTranslator(OpenAITranslator):
             model = self.envs["GEMINI_MODEL"]
         super().__init__(lang_in, lang_out, model, base_url=base_url, api_key=api_key)
         self.prompttext = prompt
+        if prompt:
+            self.add_cache_impact_parameters("prompt", prompt)
 
 
 class AzureTranslator(BaseTranslator):
@@ -438,7 +494,7 @@ class AzureTranslator(BaseTranslator):
         logger = logging.getLogger("azure.core.pipeline.policies.http_logging_policy")
         logger.setLevel(logging.WARNING)
 
-    def translate(self, text) -> str:
+    def do_translate(self, text) -> str:
         response = self.client.translate(
             body=[text],
             from_language=self.lang_in,
@@ -466,7 +522,7 @@ class TencentTranslator(BaseTranslator):
         self.req.Target = self.lang_out
         self.req.ProjectId = 0
 
-    def translate(self, text):
+    def do_translate(self, text):
         self.req.SourceText = text
         resp: TextTranslateResponse = self.client.TextTranslate(self.req)
         return resp.TargetText
@@ -491,8 +547,10 @@ class AnythingLLMTranslator(BaseTranslator):
             "Content-Type": "application/json",
         }
         self.prompttext = prompt
+        if prompt:
+            self.add_cache_impact_parameters("prompt", prompt)
 
-    def translate(self, text):
+    def do_translate(self, text):
         messages = self.prompt(text, self.prompttext)
         payload = {
             "message": messages,
@@ -523,7 +581,7 @@ class DifyTranslator(BaseTranslator):
         self.api_url = self.envs["DIFY_API_URL"]
         self.api_key = self.envs["DIFY_API_KEY"]
 
-    def translate(self, text):
+    def do_translate(self, text):
         headers = {
             "Authorization": f"Bearer {self.api_key}",
             "Content-Type": "application/json",

+ 4 - 0
pyproject.toml

@@ -56,3 +56,7 @@ build-backend = "hatchling.build"
 
 [project.scripts]
 pdf2zh = "pdf2zh.pdf2zh:main"
+
+[tool.flake8]
+ignore = ["E203", "E261", "E501", "W503", "E741"]
+max-line-length = 88

+ 202 - 96
test/test_cache.py

@@ -1,107 +1,213 @@
 import unittest
-import os
-import tempfile
-import shutil
-import time
 from pdf2zh import cache
+import threading
+import multiprocessing
+import random
+import string
 
 
 class TestCache(unittest.TestCase):
     def setUp(self):
-        # Create a temporary directory for testing
-        self.test_cache_dir = os.path.join(tempfile.gettempdir(), "test_cache")
-        self.original_cache_dir = cache.cache_dir
-        cache.cache_dir = self.test_cache_dir
-        os.makedirs(self.test_cache_dir, exist_ok=True)
+        self.test_db = cache.init_test_db()
 
     def tearDown(self):
-        # Clean up the test directory
-        shutil.rmtree(self.test_cache_dir)
-        cache.cache_dir = self.original_cache_dir
-
-    def test_deterministic_hash(self):
-        # Test hash generation for different inputs
-        test_input = "Hello World"
-        hash1 = cache.deterministic_hash(test_input)
-        hash2 = cache.deterministic_hash(test_input)
-        self.assertEqual(hash1, hash2)
-        self.assertEqual(len(hash1), 20)
-
-        # Test different inputs produce different hashes
-        hash3 = cache.deterministic_hash("Different input")
-        self.assertNotEqual(hash1, hash3)
-
-    def test_get_dirs(self):
-        # Create test directories
-        test_dirs = ["dir1", "dir2", "dir3"]
-        for dir_name in test_dirs:
-            os.makedirs(os.path.join(self.test_cache_dir, dir_name))
-
-        # Create a file (should be ignored)
-        with open(os.path.join(self.test_cache_dir, "test.txt"), "w") as f:
-            f.write("test")
-
-        dirs = cache.get_dirs()
-        self.assertEqual(len(dirs), 3)
-        for dir_path in dirs:
-            self.assertTrue(os.path.isdir(dir_path))
-
-    def test_get_time(self):
-        # Create test directory with time file
-        test_dir = os.path.join(self.test_cache_dir, "test_dir")
-        os.makedirs(test_dir)
-        test_time = 1234567890.0
-
-        with open(os.path.join(test_dir, cache.time_filename), "w") as f:
-            f.write(str(test_time))
-
-        # Test reading time
-        result = cache.get_time(test_dir)
-        self.assertEqual(result, test_time)
-
-        # Test non-existent directory
-        non_existent_dir = os.path.join(self.test_cache_dir, "non_existent")
-        result = cache.get_time(non_existent_dir)
-        self.assertEqual(result, float("inf"))
-
-    def test_write_time(self):
-        test_dir = os.path.join(self.test_cache_dir, "test_dir")
-        os.makedirs(test_dir)
-
-        cache.write_time(test_dir)
-
-        self.assertTrue(os.path.exists(os.path.join(test_dir, cache.time_filename)))
-        with open(os.path.join(test_dir, cache.time_filename)) as f:
-            time_value = float(f.read())
-        self.assertIsInstance(time_value, float)
-
-    def test_remove_extra(self):
-        # Create more than max_cache directories
-        for i in range(cache.max_cache + 2):
-            dir_path = os.path.join(self.test_cache_dir, f"dir{i}")
-            os.makedirs(dir_path)
-            time.sleep(0.1)  # Ensure different timestamps
-            cache.write_time(dir_path)
-
-        cache.remove_extra()
-
-        remaining_dirs = cache.get_dirs()
-        self.assertLessEqual(len(remaining_dirs), cache.max_cache)
-
-    def test_cache_operations(self):
-        test_hash = "test123hash"
-        test_para_hash = "para456hash"
-        test_content = "Test paragraph content"
-
-        # Test cache creation
-        self.assertFalse(cache.is_cached(test_hash))
-        cache.create_cache(test_hash)
-        self.assertTrue(cache.is_cached(test_hash))
-
-        # Test paragraph operations
-        self.assertIsNone(cache.load_paragraph(test_hash, test_para_hash))
-        cache.write_paragraph(test_hash, test_para_hash, test_content)
-        self.assertEqual(cache.load_paragraph(test_hash, test_para_hash), test_content)
+        # Clean up
+        cache.clean_test_db(self.test_db)
+
+    def test_basic_set_get(self):
+        """Test basic set and get operations"""
+        cache_instance = cache.TranslationCache("test_engine")
+
+        # Test get with non-existent entry
+        result = cache_instance.get("hello")
+        self.assertIsNone(result)
+
+        # Test set and get
+        cache_instance.set("hello", "你好")
+        result = cache_instance.get("hello")
+        self.assertEqual(result, "你好")
+
+    def test_cache_overwrite(self):
+        """Test that cache entries can be overwritten"""
+        cache_instance = cache.TranslationCache("test_engine")
+
+        # Set initial translation
+        cache_instance.set("hello", "你好")
+
+        # Overwrite with new translation
+        cache_instance.set("hello", "您好")
+
+        # Verify the new translation is returned
+        result = cache_instance.get("hello")
+        self.assertEqual(result, "您好")
+
+    def test_non_string_params(self):
+        """Test that non-string parameters are automatically converted to JSON"""
+        params = {"model": "gpt-3.5", "temperature": 0.7}
+        cache_instance = cache.TranslationCache("test_engine", params)
+
+        # Test that params are converted to JSON string internally
+        cache_instance.set("hello", "你好")
+        result = cache_instance.get("hello")
+        self.assertEqual(result, "你好")
+
+        # Test with different param types
+        array_params = ["param1", "param2"]
+        cache_instance2 = cache.TranslationCache("test_engine", array_params)
+        cache_instance2.set("hello", "你好2")
+        self.assertEqual(cache_instance2.get("hello"), "你好2")
+
+        # Test with nested structures
+        nested_params = {"options": {"temp": 0.8, "models": ["a", "b"]}}
+        cache_instance3 = cache.TranslationCache("test_engine", nested_params)
+        cache_instance3.set("hello", "你好3")
+        self.assertEqual(cache_instance3.get("hello"), "你好3")
+
+    def test_engine_distinction(self):
+        """Test that cache distinguishes between different translation engines"""
+        cache1 = cache.TranslationCache("engine1")
+        cache2 = cache.TranslationCache("engine2")
+
+        # Set same text with different engines
+        cache1.set("hello", "你好 1")
+        cache2.set("hello", "你好 2")
+
+        # Verify each engine gets its own translation
+        self.assertEqual(cache1.get("hello"), "你好 1")
+        self.assertEqual(cache2.get("hello"), "你好 2")
+
+    def test_params_distinction(self):
+        """Test that cache distinguishes between different engine parameters"""
+        params1 = {"param": "value1"}
+        params2 = {"param": "value2"}
+        cache1 = cache.TranslationCache("test_engine", params1)
+        cache2 = cache.TranslationCache("test_engine", params2)
+
+        # Set same text with different parameters
+        cache1.set("hello", "你好 1")
+        cache2.set("hello", "你好 2")
+
+        # Verify each parameter set gets its own translation
+        self.assertEqual(cache1.get("hello"), "你好 1")
+        self.assertEqual(cache2.get("hello"), "你好 2")
+
+    def test_consistent_param_serialization(self):
+        """Test that dictionary parameters are consistently serialized regardless of key order"""
+        # Test simple dictionary
+        params1 = {"b": 1, "a": 2}
+        params2 = {"a": 2, "b": 1}
+        cache1 = cache.TranslationCache("test_engine", params1)
+        cache2 = cache.TranslationCache("test_engine", params2)
+        self.assertEqual(cache1.translate_engine_params, cache2.translate_engine_params)
+
+        # Test nested dictionary
+        params1 = {"outer2": {"inner2": 2, "inner1": 1}, "outer1": 3}
+        params2 = {"outer1": 3, "outer2": {"inner1": 1, "inner2": 2}}
+        cache1 = cache.TranslationCache("test_engine", params1)
+        cache2 = cache.TranslationCache("test_engine", params2)
+        self.assertEqual(cache1.translate_engine_params, cache2.translate_engine_params)
+
+        # Test dictionary with list of dictionaries
+        params1 = {"b": [{"y": 1, "x": 2}], "a": 3}
+        params2 = {"a": 3, "b": [{"x": 2, "y": 1}]}
+        cache1 = cache.TranslationCache("test_engine", params1)
+        cache2 = cache.TranslationCache("test_engine", params2)
+        self.assertEqual(cache1.translate_engine_params, cache2.translate_engine_params)
+
+        # Test that different values still produce different results
+        params1 = {"a": 1, "b": 2}
+        params2 = {"a": 2, "b": 1}
+        cache1 = cache.TranslationCache("test_engine", params1)
+        cache2 = cache.TranslationCache("test_engine", params2)
+        self.assertNotEqual(
+            cache1.translate_engine_params, cache2.translate_engine_params
+        )
+
+    def test_cache_with_sorted_params(self):
+        """Test that cache works correctly with sorted parameters"""
+        params1 = {"b": [{"y": 1, "x": 2}], "a": 3}
+        params2 = {"a": 3, "b": [{"x": 2, "y": 1}]}
+
+        # Both caches should work with the same key
+        cache1 = cache.TranslationCache("test_engine", params1)
+        cache1.set("hello", "你好")
+
+        cache2 = cache.TranslationCache("test_engine", params2)
+        self.assertEqual(cache2.get("hello"), "你好")
+
+    def test_append_params(self):
+        """Test the append_params method"""
+        cache_instance = cache.TranslationCache("test_engine", {"initial": "value"})
+
+        # Test appending new parameter
+        cache_instance.add_params("new_param", "new_value")
+        self.assertEqual(
+            cache_instance.params, {"initial": "value", "new_param": "new_value"}
+        )
+
+        # Test that cache with appended params works correctly
+        cache_instance.set("hello", "你好")
+        self.assertEqual(cache_instance.get("hello"), "你好")
+
+        # Test overwriting existing parameter
+        cache_instance.add_params("initial", "new_value")
+        self.assertEqual(
+            cache_instance.params, {"initial": "new_value", "new_param": "new_value"}
+        )
+
+        # Cache should work with updated params
+        cache_instance.set("hello2", "你好2")
+        self.assertEqual(cache_instance.get("hello2"), "你好2")
+
+    def test_thread_safety(self):
+        """Test thread safety of cache operations"""
+        cache_instance = cache.TranslationCache("test_engine")
+        lock = threading.Lock()
+        results = []
+        num_threads = multiprocessing.cpu_count()
+        items_per_thread = 100
+
+        def generate_random_text(length=10):
+            return "".join(
+                random.choices(string.ascii_letters + string.digits, k=length)
+            )
+
+        def worker():
+            thread_results = []  # 线程本地存储结果
+            for _ in range(items_per_thread):
+                text = generate_random_text()
+                translation = f"翻译_{text}"
+
+                # Write operation
+                cache_instance.set(text, translation)
+
+                # Read operation - verify our own write
+                result = cache_instance.get(text)
+                thread_results.append((text, result))
+
+            # 所有操作完成后,一次性加锁并追加结果
+            with lock:
+                results.extend(thread_results)
+
+        # Create threads equal to CPU core count
+        threads = []
+        for _ in range(num_threads):
+            thread = threading.Thread(target=worker)
+            threads.append(thread)
+            thread.start()
+
+        # Wait for all threads to complete
+        for thread in threads:
+            thread.join()
+
+        # Verify all operations were successful
+        expected_total = num_threads * items_per_thread
+        self.assertEqual(len(results), expected_total)
+
+        # Verify each thread got its correct value
+        for text, result in results:
+            expected = f"翻译_{text}"
+            self.assertEqual(result, expected)
 
 
 if __name__ == "__main__":

+ 0 - 22
test/test_converter.py

@@ -80,28 +80,6 @@ class TestTranslateConverter(unittest.TestCase):
         self.converter.receive_layout(mock_page)
         mock_receive_layout.assert_called_once_with(mock_page)
 
-    @patch("concurrent.futures.ThreadPoolExecutor")
-    @patch("pdf2zh.cache")
-    def test_translation(self, mock_cache, mock_executor):
-        mock_executor.return_value.__enter__.return_value.map.return_value = [
-            "你好",
-            "{v1}",
-        ]
-        mock_cache.deterministic_hash.return_value = "test_hash"
-        mock_cache.load_paragraph.return_value = None
-        mock_cache.write_paragraph.return_value = None
-
-        sstk = ["Hello", "{v1}"]
-        self.converter.thread = 2
-        results = []
-        with patch.object(self.converter, "translator") as mock_translator:
-            mock_translator.translate.side_effect = lambda x: (
-                "你好" if x == "Hello" else x
-            )
-            for s in sstk:
-                results.append(self.converter.translator.translate(s))
-        self.assertEqual(results, ["你好", "{v1}"])
-
     def test_receive_layout_with_complex_formula(self):
         ltpage = LTPage(1, (0, 0, 500, 500))
         ltchar = Mock()

+ 77 - 0
test/test_translator.py

@@ -0,0 +1,77 @@
+import unittest
+from pdf2zh.translator import BaseTranslator
+from pdf2zh import cache
+
+
+class AutoIncreaseTranslator(BaseTranslator):
+    """Deterministic stub translator for cache tests.
+
+    Each do_translate call returns the next integer as a string
+    ("1", "2", ...), so a cache hit (same result twice) is trivially
+    distinguishable from a fresh translation (new number).
+    """
+
+    name = "auto_increase"
+    # Call counter; incremented on every real (non-cached) translation.
+    n = 0
+
+    def do_translate(self, text):
+        self.n += 1
+        return str(self.n)
+
+
+class TestTranslator(unittest.TestCase):
+    def setUp(self):
+        """Point the cache module at a throwaway test database."""
+        self.test_db = cache.init_test_db()
+
+    def tearDown(self):
+        """Remove the test database created in setUp."""
+        cache.clean_test_db(self.test_db)
+
+    def test_cache(self):
+        """Repeated input hits the cache; ignore_cache bypasses it."""
+        translator = AutoIncreaseTranslator("en", "zh", "test")
+        # First translation should be cached
+        text = "Hello World"
+        first_result = translator.translate(text)
+
+        # Second translation should return the same result from cache
+        second_result = translator.translate(text)
+        self.assertEqual(first_result, second_result)
+
+        # Different input should give different result
+        different_text = "Different Text"
+        different_result = translator.translate(different_text)
+        self.assertNotEqual(first_result, different_result)
+
+        # Test cache with ignore_cache=True
+        translator.ignore_cache = True
+        no_cache_result = translator.translate(text)
+        self.assertNotEqual(first_result, no_cache_result)
+
+    def test_add_cache_impact_parameters(self):
+        """Changing cache-impact parameters must invalidate prior cache keys."""
+        translator = AutoIncreaseTranslator("en", "zh", "test")
+
+        # Test cache with added parameters: a new parameter changes the cache
+        # key, so the same text must be re-translated (different counter value)
+        text = "Hello World"
+        first_result = translator.translate(text)
+        translator.add_cache_impact_parameters("test", "value")
+        second_result = translator.translate(text)
+        self.assertNotEqual(first_result, second_result)
+
+        # Test cache with ignore_cache=True (per-call override)
+        no_cache_result1 = translator.translate(text, ignore_cache=True)
+        self.assertNotEqual(first_result, no_cache_result1)
+
+        # Instance-level ignore_cache flag also forces a fresh translation
+        translator.ignore_cache = True
+        no_cache_result2 = translator.translate(text)
+        self.assertNotEqual(no_cache_result1, no_cache_result2)
+
+        # Test cache with ignore_cache=False: the most recent translation
+        # was written to cache, so we get it back unchanged
+        translator.ignore_cache = False
+        cache_result = translator.translate(text)
+        self.assertEqual(no_cache_result2, cache_result)
+
+        # Test cache with another parameter: key changes again, fresh result
+        translator.add_cache_impact_parameters("test2", "value2")
+        another_result = translator.translate(text)
+        self.assertNotEqual(second_result, another_result)
+
+    def test_base_translator_throw(self):
+        """BaseTranslator is abstract: translate must raise NotImplementedError."""
+        translator = BaseTranslator("en", "zh", "test")
+        with self.assertRaises(NotImplementedError):
+            translator.translate("Hello World")
+
+
+if __name__ == "__main__":
+    unittest.main()