
Merge branch 'main' of https://github.com/Byaidu/PDFMathTranslate

Byaidu 11 months ago
parent
commit
d3b0b4ea46

+ 0 - 2
.github/workflows/python-build.yml

@@ -2,8 +2,6 @@ name: Test and Build Python Package
 
 on:
   push:
-    branches:
-      - main
   pull_request:
 
 jobs:

+ 3 - 1
README.md

@@ -188,8 +188,10 @@ In the following table, we list all advanced options for reference:
 | `-f`, `-c`     | [Exceptions](https://github.com/Byaidu/PDFMathTranslate/blob/main/docs/ADVANCED.md#exceptions)                | `pdf2zh example.pdf -f "(MS.*)"`               |
 | `-cp`          | Compatibility Mode                                                                                            | `pdf2zh example.pdf --compatible`              |
 | `--share`      | Public link                                                                                                   | `pdf2zh -i --share`                            |
-| `--authorized` | Authorization                                                                                                 | `pdf2zh -i --authorized users.txt [auth.html]` |
+| `--authorized` | [Authorization](https://github.com/Byaidu/PDFMathTranslate/blob/main/docs/ADVANCED.md#auth)                                                                                                 | `pdf2zh -i --authorized users.txt [auth.html]` |
 | `--prompt`     | [Custom Prompt](https://github.com/Byaidu/PDFMathTranslate/blob/main/docs/ADVANCED.md#prompt)                 | `pdf2zh --prompt [prompt.txt]`                 |
+| `--onnx` | Use custom DocLayout-YOLO ONNX model | `pdf2zh --onnx [onnx/model/path]` |
+| `--serverport` | Use custom WebUI port | `pdf2zh --serverport 7860` |
 
 For detailed explanations, please refer to our document about [Advanced Usage](./docs/ADVANCED.md) for a full list of each option.
 

+ 42 - 1
docs/ADVANCED.md

@@ -64,6 +64,9 @@ We've provided a detailed table on the required [environment variables](https://
 | **Tencent**          | `tencent`      | `TENCENTCLOUD_SECRET_ID`, `TENCENTCLOUD_SECRET_KEY`                   | `[Your ID]`, `[Your Key]`                                | See [Tencent](https://www.tencentcloud.com/products/tmt?from_qcintl=122110104)                                                                                                                            |
 | **Dify**             | `dify`         | `DIFY_API_URL`, `DIFY_API_KEY`                                        | `[Your DIFY URL]`, `[Your Key]`                          | See [Dify](https://github.com/langgenius/dify),Three variables, lang_out, lang_in, and text, need to be defined in Dify's workflow input.                                                                 |
 | **AnythingLLM**      | `anythingllm`  | `AnythingLLM_URL`, `AnythingLLM_APIKEY`                               | `[Your AnythingLLM URL]`, `[Your Key]`                   | See [anything-llm](https://github.com/Mintplex-Labs/anything-llm)                                                                                                                                         |
+|**Argos Translate**|`argos`| | |See [argos-translate](https://github.com/argosopentech/argos-translate)|
+
+For large language models that are compatible with the OpenAI API but not listed in the table above, you can set environment variables using the same method outlined for OpenAI in the table.
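A minimal sketch of that setup, assuming the endpoint speaks the OpenAI protocol and that the variable names match the OpenAI row of the table above (the values below are placeholders):

```python
# Sketch: reuse the OpenAI environment variables for any OpenAI-compatible endpoint.
# Variable names and values are assumptions for illustration, not project-mandated.
import os

os.environ["OPENAI_BASE_URL"] = "http://localhost:8000/v1"  # your compatible endpoint
os.environ["OPENAI_API_KEY"] = "sk-placeholder"
os.environ["OPENAI_MODEL"] = "my-local-model"
```

With these set, `pdf2zh example.pdf -s openai` (or `-s openai:my-local-model`) should talk to the compatible endpoint instead of api.openai.com.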
 
 Use `-s service` or `-s service:model` to specify service:
 
@@ -117,7 +120,7 @@ pdf2zh example.pdf -t 1
 Use `--prompt` to specify which prompt to use in llm:
 
 ```bash
-pdf2zh example.pdf -pr prompt.txt
+pdf2zh example.pdf --prompt prompt.txt
 ```
 
 example prompt.txt
@@ -145,3 +148,41 @@ In custom prompt file, there are three variables can be used.
 [⬆️ Back to top](#toc)
 
 ---
+
+<h3 id="auth">Authorization</h3>
+
+Use `--authorized` to specify which users may access the Web UI and to customize the login page:
+
+```bash
+pdf2zh example.pdf --authorized users.txt auth.html
+```
+
+example users.txt
+Each line contains a username and a password, separated by a comma.
+
+```
+admin,123456
+user1,password1
+user2,abc123
+guest,guest123
+test,test123
+```
+
+example auth.html
+
+```html
+<!DOCTYPE html>
+<html>
+<head>
+    <title>Simple HTML</title>
+</head>
+<body>
+    <h1>Hello, World!</h1>
+    <p>Welcome to my simple HTML page.</p>
+</body>
+</html>
+```
+
+[⬆️ Back to top](#toc)
+
+---

+ 1 - 1
docs/APIS.md

@@ -1,7 +1,7 @@
 [**Documentation**](https://github.com/Byaidu/PDFMathTranslate) > **API Details** _(current)_
 
 <h2 id="toc">Table of Content</h2>
-The present project supports two types of APIs;
+The present project supports two types of APIs; all methods require Redis:
 
 - [Functional calls in Python](#api-python)
 - [HTTP protocols](#api-http)

+ 26 - 16
docs/README_ja-JP.md

@@ -2,7 +2,7 @@
 
 [English](../README.md) | [简体中文](README_zh-CN.md) | 日本語
 
-<img src="./docs/images/banner.png" width="320px"  alt="PDF2ZH"/>  
+<img src="./images/banner.png" width="320px"  alt="PDF2ZH"/>  
 
 <h2 id="title">PDFMathTranslate</h2>
 
@@ -56,7 +56,7 @@
 <h2 id="preview">プレビュー</h2>
 
 <div align="center">
-<img src="./docs/images/preview.gif" width="80%"/>
+<img src="./images/preview.gif" width="80%"/>
 </div>
 
 <h2 id="demo">公共サービス 🌟</h2>
@@ -122,9 +122,9 @@ Python環境を事前にインストールする必要はありません
     http://localhost:7860/
     ```
 
-    <img src="./docs/images/gui.gif" width="500"/>
+    <img src="./images/gui.gif" width="500"/>
 
-詳細については、[GUIのドキュメント](./docs/README_GUI.md) を参照してください。
+詳細については、[GUIのドキュメント](./README_GUI.md) を参照してください。
 
 <h3 id="docker">方法4. Docker</h3>
 
@@ -158,7 +158,7 @@ Python環境を事前にインストールする必要はありません
 
 コマンドラインで翻訳コマンドを実行し、現在の作業ディレクトリに翻訳されたドキュメント `example-mono.pdf` とバイリンガルドキュメント `example-dual.pdf` を生成します。デフォルトではGoogle翻訳サービスを使用します。
 
-<img src="./docs/images/cmd.explained.png" width="580px"  alt="cmd"/>  
+<img src="./images/cmd.explained.png" width="580px"  alt="cmd"/>  
 
 以下の表に、参考のためにすべての高度なオプションをリストしました:
 
@@ -175,8 +175,10 @@ Python環境を事前にインストールする必要はありません
 | `-o`  | 出力ディレクトリ | `pdf2zh example.pdf -o output` |
 | `-f`, `-c` | [例外](#exceptions) | `pdf2zh example.pdf -f "(MS.*)"` |
 | `--share` | [gradio公開リンクを取得] | `pdf2zh -i --share` |
-| `--authorized` | [ウェブ認証とカスタム認証ページの追加] | `pdf2zh -i --authorized users.txt [auth.html]` |
+| `--authorized` | [[ウェブ認証とカスタム認証ページの追加](https://github.com/Byaidu/PDFMathTranslate/blob/main/docs/ADVANCED.md#auth)] | `pdf2zh -i --authorized users.txt [auth.html]` |
 | `--prompt` | [カスタムビッグモデルのプロンプトを使用する] | `pdf2zh --prompt [prompt.txt]` |
+| `--onnx` | [カスタムDocLayout-YOLO ONNXモデルの使用] | `pdf2zh --onnx [onnx/model/path]` |
+| `--serverport` | [カスタムWebUIポートを使用する] | `pdf2zh --serverport 7860` |
 
 <h3 id="partial">全文または部分的なドキュメント翻訳</h3>
 
@@ -221,6 +223,10 @@ pdf2zh example.pdf -li en -lo ja
 |**Tencent**|`tencent`|`TENCENTCLOUD_SECRET_ID`, `TENCENTCLOUD_SECRET_KEY`|`[Your ID]`, `[Your Key]`|See [Tencent](https://www.tencentcloud.com/products/tmt?from_qcintl=122110104)|
 |**Dify**|`dify`|`DIFY_API_URL`, `DIFY_API_KEY`|`[Your DIFY URL]`, `[Your Key]`|See [Dify](https://github.com/langgenius/dify),Three variables, lang_out, lang_in, and text, need to be defined in Dify's workflow input.|
 |**AnythingLLM**|`anythingllm`|`AnythingLLM_URL`, `AnythingLLM_APIKEY`|`[Your AnythingLLM URL]`, `[Your Key]`|See [anything-llm](https://github.com/Mintplex-Labs/anything-llm)|
+|**Argos Translate**|`argos`| | |See [argos-translate](https://github.com/argosopentech/argos-translate)|
+
+(needs Japanese translation)
+For large language models that are compatible with the OpenAI API but not listed in the table above, you can set environment variables using the same method outlined for OpenAI in the table.
 
 `-s service` または `-s service:model` を使用してサービスを指定します:
 
@@ -257,16 +263,18 @@ pdf2zh example.pdf -f "(CM[^R]|(MS|XY|MT|BL|RM|EU|LA|RS)[A-Z]|LINE|LCIRCLE|TeX-|
 pdf2zh example.pdf -t 1
 ```
 
-<h3 id="prompt">custom prompt</h3>
-(need Japenese translation)
-Use `--prompt` to specify which prompt to use in llm:
+<h3 id="prompt">カスタム プロンプト</h3>
+
+`--prompt`を使用して、LLMで使用するプロンプトを指定します:
+
 ```bash
 pdf2zh example.pdf -pr prompt.txt
 ```
 
 
-example prompt.txt
-```
+`prompt.txt`の例:
+
+```txt
 [
     {
         "role": "system",
@@ -280,12 +288,14 @@ example prompt.txt
 ```
 
 
-In custom prompt file, there are three variables can be used.
-|**variables**|**comment**|
+カスタムプロンプトファイルでは、以下の3つの変数が使用できます。
+
+|**変数**|**内容**|
 |-|-|
-|`lang_in`|input language|
-|`lang_out`|output language|
-|`text`|text need to be translated|
+|`lang_in`|ソース言語|
+|`lang_out`|ターゲット言語|
+|`text`|翻訳するテキスト|
+
 <h2 id="todo">API</h2>
 
 ### Python

+ 11 - 6
docs/README_zh-CN.md

@@ -2,7 +2,7 @@
 
 [English](../README.md) | 简体中文 | [日本語](README_ja-JP.md)
 
-<img src="./docs/images/banner.png" width="320px"  alt="PDF2ZH"/>  
+<img src="./images/banner.png" width="320px"  alt="PDF2ZH"/>  
 
 <h2 id="title">PDFMathTranslate</h2>
 
@@ -58,7 +58,7 @@
 <h2 id="preview">效果预览</h2>
 
 <div align="center">
-<img src="./docs/images/preview.gif" width="80%"/>
+<img src="./images/preview.gif" width="80%"/>
 </div>
 
 <h2 id="demo">在线演示 🌟</h2>
@@ -123,9 +123,9 @@ set HF_ENDPOINT=https://hf-mirror.com
     http://localhost:7860/
     ```
 
-    <img src="./docs/images/gui.gif" width="500"/>
+    <img src="./images/gui.gif" width="500"/>
 
-查看 [documentation for GUI](./docs/README_GUI.md) 获取细节说明
+查看 [documentation for GUI](./README_GUI.md) 获取细节说明
 
 <h3 id="docker">方法四、容器化部署</h3>
 
@@ -159,7 +159,7 @@ set HF_ENDPOINT=https://hf-mirror.com
 
 在命令行中执行翻译命令,在当前工作目录下生成译文文档 `example-mono.pdf` 和双语对照文档 `example-dual.pdf`,默认使用 Google 翻译服务
 
-<img src="./docs/images/cmd.explained.png" width="580px"  alt="cmd"/>  
+<img src="./images/cmd.explained.png" width="580px"  alt="cmd"/>  
 
 我们在下表中列出了所有高级选项,以供参考:
 
@@ -176,8 +176,10 @@ set HF_ENDPOINT=https://hf-mirror.com
 | `-o`  | 输出目录 | `pdf2zh example.pdf -o output` |
 | `-f`, `-c` | [例外规则](#exceptions) | `pdf2zh example.pdf -f "(MS.*)"` |
 | `--share` | [获取 gradio 公开链接] | `pdf2zh -i --share` |
-| `--authorized` | [添加网页认证和自定义认证页] | `pdf2zh -i --authorized users.txt [auth.html]` |
+| `--authorized` | [[添加网页认证和自定义认证页](https://github.com/Byaidu/PDFMathTranslate/blob/main/docs/ADVANCED.md#auth)] | `pdf2zh -i --authorized users.txt [auth.html]` |
 | `--prompt` | [使用自定义的大模型prompt] | `pdf2zh --prompt [prompt.txt]` |
+| `--onnx` | [使用自定义的 DocLayout-YOLO ONNX 模型] | `pdf2zh --onnx [onnx/model/path]` |
+| `--serverport` | [使用自定义的 WebUI 端口] | `pdf2zh --serverport 7860` |
 
 <h3 id="partial">全文或部分文档翻译</h3>
 
@@ -222,6 +224,9 @@ pdf2zh example.pdf -li en -lo ja
 |**Tencent**|`tencent`|`TENCENTCLOUD_SECRET_ID`, `TENCENTCLOUD_SECRET_KEY`|`[Your ID]`, `[Your Key]`|See [Tencent](https://www.tencentcloud.com/products/tmt?from_qcintl=122110104)|
 |**Dify**|`dify`|`DIFY_API_URL`, `DIFY_API_KEY`|`[Your DIFY URL]`, `[Your Key]`|See [Dify](https://github.com/langgenius/dify),Three variables, lang_out, lang_in, and text, need to be defined in Dify's workflow input.|
 |**AnythingLLM**|`anythingllm`|`AnythingLLM_URL`, `AnythingLLM_APIKEY`|`[Your AnythingLLM URL]`, `[Your Key]`|See [anything-llm](https://github.com/Mintplex-Labs/anything-llm)|
+|**Argos Translate**|`argos`| | |See [argos-translate](https://github.com/argosopentech/argos-translate)|
+
+对于未在上述表格中的,并且兼容 OpenAI api 的大语言模型,可使用表格中的 OpenAI 的方式进行环境变量的设置。
 
 使用 `-s service` 或 `-s service:model` 指定翻译服务:
 

+ 2 - 0
pdf2zh/backend.py

@@ -6,6 +6,7 @@ from pdf2zh import translate_stream
 import tqdm
 import json
 import io
+from pdf2zh.pdf2zh import model
 
 flask_app = Flask("pdf2zh")
 flask_app.config.from_mapping(
@@ -47,6 +48,7 @@ def translate_task(
     doc_mono, doc_dual = translate_stream(
         stream,
         callback=progress_bar,
+        model=model,
         **args,
     )
     return doc_mono, doc_dual

+ 140 - 90
pdf2zh/cache.py

@@ -1,91 +1,141 @@
-import tempfile
 import os
-import time
-import hashlib
-import shutil
-
-cache_dir = os.path.join(tempfile.gettempdir(), "cache")
-os.makedirs(cache_dir, exist_ok=True)
-time_filename = "update_time"
-max_cache = 5
-
-
-def deterministic_hash(obj):
-    hash_object = hashlib.sha256()
-    hash_object.update(str(obj).encode())
-    return hash_object.hexdigest()[0:20]
-
-
-def get_dirs():
-    dirs = [
-        os.path.join(cache_dir, dir)
-        for dir in os.listdir(cache_dir)
-        if os.path.isdir(os.path.join(cache_dir, dir))
-    ]
-    return dirs
-
-
-def get_time(dir):
-    try:
-        timefile = os.path.join(dir, time_filename)
-        t = float(open(timefile, encoding="utf-8").read())
-        return t
-    except FileNotFoundError:
-        # handle the error as needed, for now we'll just return a default value
-        return float(
-            "inf"
-        )  # This ensures that this directory will be the first to be removed if required
-
-
-def write_time(dir):
-    timefile = os.path.join(dir, time_filename)
-    t = time.time()
-    print(t, file=open(timefile, "w", encoding="utf-8"), end="")
-
-
-def argmin(iterable):
-    return min(enumerate(iterable), key=lambda x: x[1])[0]
-
-
-def remove_extra():
-    dirs = get_dirs()
-    for dir in dirs:
-        if not os.path.isdir(
-            dir
-        ):  # This line might be redundant now, as get_dirs() ensures only directories are returned
-            os.remove(dir)
-        try:
-            get_time(dir)
-        except BaseException:
-            shutil.rmtree(dir)
-    while True:
-        dirs = get_dirs()
-        if len(dirs) <= max_cache:
-            break
-        times = [get_time(dir) for dir in dirs]
-        arg = argmin(times)
-        shutil.rmtree(dirs[arg])
-
-
-def is_cached(hash_key):
-    dir = os.path.join(cache_dir, hash_key)
-    return os.path.exists(dir)
-
-
-def create_cache(hash_key):
-    dir = os.path.join(cache_dir, hash_key)
-    os.makedirs(dir, exist_ok=True)
-    write_time(dir)
-
-
-def load_paragraph(hash_key, hash_key_paragraph):
-    filename = os.path.join(cache_dir, hash_key, hash_key_paragraph)
-    if os.path.exists(filename):
-        return open(filename, encoding="utf-8").read()
-    else:
-        return None
-
-
-def write_paragraph(hash_key, hash_key_paragraph, paragraph):
-    filename = os.path.join(cache_dir, hash_key, hash_key_paragraph)
-    print(paragraph, file=open(filename, "w", encoding="utf-8"), end="")
+import json
+from peewee import Model, SqliteDatabase, AutoField, CharField, TextField, SQL
+from typing import Optional
+
+
+# we don't init the database here
+db = SqliteDatabase(None)
+
+
+class _TranslationCache(Model):
+    id = AutoField()
+    translate_engine = CharField(max_length=20)
+    translate_engine_params = TextField()
+    original_text = TextField()
+    translation = TextField()
+
+    class Meta:
+        database = db
+        constraints = [
+            SQL(
+                """
+            UNIQUE (
+                translate_engine,
+                translate_engine_params,
+                original_text
+                )
+            ON CONFLICT REPLACE
+            """
+            )
+        ]
+
+
+class TranslationCache:
+    @staticmethod
+    def _sort_dict_recursively(obj):
+        if isinstance(obj, dict):
+            return {
+                k: TranslationCache._sort_dict_recursively(v)
+                for k in sorted(obj.keys())
+                for v in [obj[k]]
+            }
+        elif isinstance(obj, list):
+            return [TranslationCache._sort_dict_recursively(item) for item in obj]
+        return obj
+
+    def __init__(self, translate_engine: str, translate_engine_params: dict = None):
+        assert (
+            len(translate_engine) < 20
+        ), "current cache require translate engine name less than 20 characters"
+        self.translate_engine = translate_engine
+        self.replace_params(translate_engine_params)
+
+    # The program typically starts multi-threaded translation
+    # only after cache parameters are fully configured,
+    # so thread safety doesn't need to be considered here.
+    def replace_params(self, params: dict = None):
+        if params is None:
+            params = {}
+        self.params = params
+        params = self._sort_dict_recursively(params)
+        self.translate_engine_params = json.dumps(params)
+
+    def update_params(self, params: dict = None):
+        if params is None:
+            params = {}
+        self.params.update(params)
+        self.replace_params(self.params)
+
+    def add_params(self, k: str, v):
+        self.params[k] = v
+        self.replace_params(self.params)
+
+    # Since peewee and the underlying sqlite are thread-safe,
+    # get and set operations don't need locks.
+    def get(self, original_text: str) -> Optional[str]:
+        result = _TranslationCache.get_or_none(
+            translate_engine=self.translate_engine,
+            translate_engine_params=self.translate_engine_params,
+            original_text=original_text,
+        )
+        return result.translation if result else None
+
+    def set(self, original_text: str, translation: str):
+        _TranslationCache.create(
+            translate_engine=self.translate_engine,
+            translate_engine_params=self.translate_engine_params,
+            original_text=original_text,
+            translation=translation,
+        )
+
+
+def init_db(remove_exists=False):
+    cache_folder = os.path.join(os.path.expanduser("~"), ".cache", "pdf2zh")
+    os.makedirs(cache_folder, exist_ok=True)
+    # The current version does not support database migration, so add the version number to the file name.
+    cache_db_path = os.path.join(cache_folder, "cache.v1.db")
+    if remove_exists and os.path.exists(cache_db_path):
+        os.remove(cache_db_path)
+    db.init(
+        cache_db_path,
+        pragmas={
+            "journal_mode": "wal",
+            "busy_timeout": 1000,
+        },
+    )
+    db.create_tables([_TranslationCache], safe=True)
+
+
+def init_test_db():
+    import tempfile
+
+    cache_db_path = tempfile.mktemp(suffix=".db")
+    test_db = SqliteDatabase(
+        cache_db_path,
+        pragmas={
+            "journal_mode": "wal",
+            "busy_timeout": 1000,
+        },
+    )
+    test_db.bind([_TranslationCache], bind_refs=False, bind_backrefs=False)
+    test_db.connect()
+    test_db.create_tables([_TranslationCache], safe=True)
+    return test_db
+
+
+def clean_test_db(test_db):
+    test_db.drop_tables([_TranslationCache])
+    test_db.close()
+    db_path = test_db.database
+    if os.path.exists(db_path):
+        os.remove(test_db.database)
+    wal_path = db_path + "-wal"
+    if os.path.exists(wal_path):
+        os.remove(wal_path)
+    shm_path = db_path + "-shm"
+    if os.path.exists(shm_path):
+        os.remove(shm_path)
+
+
+init_db()
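
For reference, a minimal usage sketch of the new peewee-backed cache, using only the API shown above (`init_db()` already runs at import time, so no extra setup is needed):

```python
from pdf2zh.cache import TranslationCache

cache = TranslationCache("google", {"lang_in": "en", "lang_out": "zh"})

print(cache.get("Hello World"))   # None: cache miss on the first run
cache.set("Hello World", "你好,世界")
print(cache.get("Hello World"))   # hit: returns the stored translation

# Parameters that influence translation quality are part of the cache key,
# so changing them separates the entries.
cache.add_params("prompt", "terse style")
print(cache.get("Hello World"))   # None again under the new parameter set
```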

+ 3 - 11
pdf2zh/converter.py

@@ -17,7 +17,6 @@ import concurrent.futures
 import numpy as np
 import unicodedata
 from tenacity import retry, wait_fixed
-from pdf2zh import cache
 from pdf2zh.translator import (
     AzureOpenAITranslator,
     BaseTranslator,
@@ -35,6 +34,7 @@ from pdf2zh.translator import (
     TencentTranslator,
     DifyTranslator,
     AnythingLLMTranslator,
+    ArgosTranslator,
 )
 from pymupdf import Font
 
@@ -150,7 +150,7 @@ class TranslateConverter(PDFConverterEx):
         service_name = param[0]
         service_model = param[1] if len(param) > 1 else None
         for translator in [GoogleTranslator, BingTranslator, DeepLTranslator, DeepLXTranslator, OllamaTranslator, AzureOpenAITranslator,
-                           OpenAITranslator, ZhipuTranslator, ModelScopeTranslator, SiliconTranslator, GeminiTranslator, AzureTranslator, TencentTranslator, DifyTranslator, AnythingLLMTranslator]:
+                           OpenAITranslator, ZhipuTranslator, ModelScopeTranslator, SiliconTranslator, GeminiTranslator, AzureTranslator, TencentTranslator, DifyTranslator, AnythingLLMTranslator, ArgosTranslator]:
             if service_name == translator.name:
                 self.translator = translator(lang_in, lang_out, service_model, envs=envs, prompt=prompt)
         if not self.translator:
@@ -328,21 +328,13 @@ class TranslateConverter(PDFConverterEx):
         ############################################################
         # B. 段落翻译
         log.debug("\n==========[SSTACK]==========\n")
-        hash_key = cache.deterministic_hash("PDFMathTranslate")
-        cache.create_cache(hash_key)
 
         @retry(wait=wait_fixed(1))
         def worker(s: str):  # 多线程翻译
             if not s.strip() or re.match(r"^\{v\d+\}$", s):  # 空白和公式不翻译
                 return s
             try:
-                hash_key_paragraph = cache.deterministic_hash(
-                    (s, str(self.translator))
-                )
-                new = cache.load_paragraph(hash_key, hash_key_paragraph)  # 查询缓存
-                if new is None:
-                    new = self.translator.translate(s)
-                    cache.write_paragraph(hash_key, hash_key_paragraph, new)
+                new = self.translator.translate(s)
                 return new
             except BaseException as e:
                 if log.isEnabledFor(logging.DEBUG):

+ 23 - 10
pdf2zh/gui.py

@@ -13,6 +13,7 @@ from gradio_pdf import PDF
 
 from pdf2zh import __version__
 from pdf2zh.high_level import translate
+from pdf2zh.pdf2zh import model
 from pdf2zh.translator import (
     AnythingLLMTranslator,
     AzureOpenAITranslator,
@@ -22,6 +23,7 @@ from pdf2zh.translator import (
     DeepLTranslator,
     DeepLXTranslator,
     DifyTranslator,
+    ArgosTranslator,
     GeminiTranslator,
     GoogleTranslator,
     ModelScopeTranslator,
@@ -49,6 +51,7 @@ service_map: dict[str, BaseTranslator] = {
     "Tencent": TencentTranslator,
     "Dify": DifyTranslator,
     "AnythingLLM": AnythingLLMTranslator,
+    "Argos Translate": ArgosTranslator,
 }
 
 # The following variables associate strings with specific languages
@@ -89,12 +92,6 @@ if os.getenv("PDF2ZH_DEMO"):
     client_key = os.getenv("PDF2ZH_CLIENT_KEY")
     server_key = os.getenv("PDF2ZH_SERVER_KEY")
 
-# Check if everything unconfigured
-if os.getenv("PDF2ZH_INIT") is not False:
-    service_map = {
-        "Google": GoogleTranslator,
-    }
-
 
 # Public demo control
 def verify_recaptcha(response):
@@ -269,6 +266,7 @@ def translate_file(
         "cancellation_event": cancellation_event_map[session_id],
         "envs": _envs,
         "prompt": prompt,
+        "model": model,
     }
     try:
         translate(**param)
@@ -591,7 +589,9 @@ def parse_user_passwd(file_path: str) -> tuple:
     return tuple_list, content
 
 
-def setup_gui(share: bool = False, auth_file: list = ["", ""]) -> None:
+def setup_gui(
+    share: bool = False, auth_file: list = ["", ""], server_port=7860
+) -> None:
     """
     Setup the GUI with the given parameters.
 
@@ -609,7 +609,11 @@ def setup_gui(share: bool = False, auth_file: list = ["", ""]) -> None:
         if len(user_list) == 0:
             try:
                 demo.launch(
-                    server_name="0.0.0.0", debug=True, inbrowser=True, share=share
+                    server_name="0.0.0.0",
+                    debug=True,
+                    inbrowser=True,
+                    share=share,
+                    server_port=server_port,
                 )
             except Exception:
                 print(
@@ -617,13 +621,19 @@ def setup_gui(share: bool = False, auth_file: list = ["", ""]) -> None:
                 )
                 try:
                     demo.launch(
-                        server_name="127.0.0.1", debug=True, inbrowser=True, share=share
+                        server_name="127.0.0.1",
+                        debug=True,
+                        inbrowser=True,
+                        share=share,
+                        server_port=server_port,
                     )
                 except Exception:
                     print(
                         "Error launching GUI using 127.0.0.1.\nThis may be caused by global mode of proxy software."
                     )
-                    demo.launch(debug=True, inbrowser=True, share=True)
+                    demo.launch(
+                        debug=True, inbrowser=True, share=True, server_port=server_port
+                    )
         else:
             try:
                 demo.launch(
@@ -633,6 +643,7 @@ def setup_gui(share: bool = False, auth_file: list = ["", ""]) -> None:
                     share=share,
                     auth=user_list,
                     auth_message=html,
+                    server_port=server_port,
                 )
             except Exception:
                 print(
@@ -646,6 +657,7 @@ def setup_gui(share: bool = False, auth_file: list = ["", ""]) -> None:
                         share=share,
                         auth=user_list,
                         auth_message=html,
+                        server_port=server_port,
                     )
                 except Exception:
                     print(
@@ -657,6 +669,7 @@ def setup_gui(share: bool = False, auth_file: list = ["", ""]) -> None:
                         share=True,
                         auth=user_list,
                         auth_message=html,
+                        server_port=server_port,
                     )
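
A short sketch of calling the updated `setup_gui()` directly with the new `server_port` parameter; the file names are placeholders following the authorization docs above:

```python
from pdf2zh.gui import setup_gui

# Launch the WebUI on a non-default port with web authentication enabled.
# "users.txt" and "auth.html" are placeholder paths, not files shipped with the project.
setup_gui(share=False, auth_file=["users.txt", "auth.html"], server_port=7861)
```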
 
 

+ 7 - 5
pdf2zh/high_level.py

@@ -21,11 +21,9 @@ from pdfminer.pdfparser import PDFParser
 from pymupdf import Document, Font
 
 from pdf2zh.converter import TranslateConverter
-from pdf2zh.doclayout import DocLayoutModel
+from pdf2zh.doclayout import OnnxModel
 from pdf2zh.pdfinterp import PDFPageInterpreterEx
 
-model = DocLayoutModel.load_available()
-
 resfont_map = {
     "zh-cn": "china-ss",
     "zh-tw": "china-ts",
@@ -88,6 +86,7 @@ def translate_patch(
     noto: Font = None,
     callback: object = None,
     cancellation_event: asyncio.Event = None,
+    model: OnnxModel = None,
     **kwarg: Any,
 ) -> None:
     rsrcmgr = PDFResourceManager()
@@ -179,6 +178,7 @@ def translate_stream(
     vchar: str = "",
     callback: object = None,
     cancellation_event: asyncio.Event = None,
+    model: OnnxModel = None,
     **kwarg: Any,
 ):
     font_list = [("tiro", None)]
@@ -206,6 +206,8 @@ def translate_stream(
         font_list.append(("china-ss", None))
 
     doc_en = Document(stream=stream)
+    stream = io.BytesIO()
+    doc_en.save(stream)
     doc_zh = Document(stream=stream)
     page_count = doc_zh.page_count
     # font_list = [("china-ss", None), ("tiro", None)]
@@ -232,7 +234,7 @@ def translate_stream(
 
     fp = io.BytesIO()
     doc_zh.save(fp)
-    obj_patch: dict = translate_patch(fp, prompt=kwarg["prompt"], **locals())
+    obj_patch: dict = translate_patch(fp, **locals())
 
     for obj_id, ops_new in obj_patch.items():
         # ops_old=doc_en.xref_stream(obj_id)
@@ -310,6 +312,7 @@ def translate(
     callback: object = None,
     compatible: bool = False,
     cancellation_event: asyncio.Event = None,
+    model: OnnxModel = None,
     **kwarg: Any,
 ):
     if not files:
@@ -362,7 +365,6 @@ def translate(
 
         if file.startswith(tempfile.gettempdir()):
             os.unlink(file)
-
         s_mono, s_dual = translate_stream(
             s_raw,
             envs=kwarg.get("envs", {}),

+ 66 - 2
pdf2zh/pdf2zh.py

@@ -13,6 +13,8 @@ from typing import List, Optional
 
 from pdf2zh import __version__, log
 from pdf2zh.high_level import translate
+from pdf2zh.doclayout import OnnxModel
+import os
 
 
 def create_parser() -> argparse.ArgumentParser:
@@ -136,6 +138,24 @@ def create_parser() -> argparse.ArgumentParser:
         help="Convert the PDF file into PDF/A format to improve compatibility.",
     )
 
+    parse_params.add_argument(
+        "--onnx",
+        type=str,
+        help="custom onnx model path.",
+    )
+
+    parse_params.add_argument(
+        "--serverport",
+        type=int,
+        help="custom WebUI port.",
+    )
+
+    parse_params.add_argument(
+        "--dir",
+        action="store_true",
+        help="translate directory.",
+    )
+
     return parser
 
 
@@ -155,6 +175,33 @@ def parse_args(args: Optional[List[str]]) -> argparse.Namespace:
     return parsed_args
 
 
+def find_all_files_in_directory(directory_path):
+    """
+    Recursively search all PDF files in the given directory and return their paths as a list.
+
+    :param directory_path: str, the path to the directory to search
+    :return: list of PDF file paths
+    """
+    # Check if the provided path is a directory
+    if not os.path.isdir(directory_path):
+        raise ValueError(f"The provided path '{directory_path}' is not a directory.")
+
+    file_paths = []
+
+    # Walk through the directory recursively
+    for root, _, files in os.walk(directory_path):
+        for file in files:
+            # Check if the file is a PDF
+            if file.lower().endswith(".pdf"):
+                # Append the full file path to the list
+                file_paths.append(os.path.join(root, file))
+
+    return file_paths
+
+
+model = None
+
+
 def main(args: Optional[List[str]] = None) -> int:
     logging.basicConfig()
 
@@ -162,11 +209,21 @@ def main(args: Optional[List[str]] = None) -> int:
 
     if parsed_args.debug:
         log.setLevel(logging.DEBUG)
+    global model
+    if parsed_args.onnx:
+        model = OnnxModel(parsed_args.onnx)
+    else:
+        model = OnnxModel.load_available()
 
     if parsed_args.interactive:
         from pdf2zh.gui import setup_gui
 
-        setup_gui(parsed_args.share, parsed_args.authorized)
+        if parsed_args.serverport:
+            setup_gui(
+                parsed_args.share, parsed_args.authorized, int(parsed_args.serverport)
+            )
+        else:
+            setup_gui(parsed_args.share, parsed_args.authorized)
         return 0
 
     if parsed_args.flask:
@@ -189,7 +246,14 @@ def main(args: Optional[List[str]] = None) -> int:
         except Exception:
             raise ValueError("prompt error.")
 
-    translate(**vars(parsed_args))
+    if parsed_args.dir:
+        untranlate_file = find_all_files_in_directory(parsed_args.files[0])
+        parsed_args.files = untranlate_file
+        print(parsed_args)
+        translate(model=model, **vars(parsed_args))
+        return 0
+    # print(parsed_args)
+    translate(model=model, **vars(parsed_args))
     return 0
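
Roughly, `--dir` just swaps the single input file for the list returned by the helper above before the normal `translate()` call. A quick sketch of the helper together with the model selection ("papers/" and the custom model path are placeholders):

```python
from pdf2zh.doclayout import OnnxModel
from pdf2zh.pdf2zh import find_all_files_in_directory

# Default model, as used when --onnx is not given;
# with --onnx this becomes OnnxModel("path/to/custom.onnx").
model = OnnxModel.load_available()

# Recursively collect every *.pdf under the directory (non-PDF files are skipped).
pdfs = find_all_files_in_directory("papers/")
print(pdfs)   # e.g. ['papers/a.pdf', 'papers/nested/b.pdf']
```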
 
 

+ 116 - 16
pdf2zh/translator.py

@@ -8,12 +8,15 @@ import deepl
 import ollama
 import openai
 import requests
+from pdf2zh.cache import TranslationCache
 from azure.ai.translation.text import TextTranslationClient
 from azure.core.credentials import AzureKeyCredential
 from tencentcloud.common import credential
 from tencentcloud.tmt.v20180321.tmt_client import TmtClient
 from tencentcloud.tmt.v20180321.models import TextTranslateRequest
 from tencentcloud.tmt.v20180321.models import TextTranslateResponse
+import argostranslate.package
+import argostranslate.translate
 
 import json
 
@@ -27,6 +30,7 @@ class BaseTranslator:
     envs = {}
     lang_map = {}
     CustomPrompt = False
+    ignore_cache = False
 
     def __init__(self, lang_in, lang_out, model):
         lang_in = self.lang_map.get(lang_in.lower(), lang_in)
@@ -35,6 +39,15 @@ class BaseTranslator:
         self.lang_out = lang_out
         self.model = model
 
+        self.cache = TranslationCache(
+            self.name,
+            {
+                "lang_in": lang_in,
+                "lang_out": lang_out,
+                "model": model,
+            },
+        )
+
     def set_envs(self, envs):
         # Detach from self.__class__.envs
         # Cannot use self.envs = copy(self.__class__.envs)
@@ -47,8 +60,36 @@ class BaseTranslator:
             for key in envs:
                 self.envs[key] = envs[key]
 
-    def translate(self, text):
-        pass
+    def add_cache_impact_parameters(self, k: str, v):
+        """
+        Add parameters that affect translation quality, so that results produced under different parameters are cached separately.
+        :param k: key
+        :param v: value
+        """
+        self.cache.add_params(k, v)
+
+    def translate(self, text, ignore_cache=False):
+        """
+        Translate the text; other parts of the program should call this method rather than do_translate.
+        :param text: text to translate
+        :return: translated text
+        """
+        if not (self.ignore_cache or ignore_cache):
+            cache = self.cache.get(text)
+            if cache is not None:
+                return cache
+
+        translation = self.do_translate(text)
+        self.cache.set(text, translation)
+        return translation
+
+    def do_translate(self, text):
+        """
+        Actually translate the text; subclasses should override this method.
+        :param text: text to translate
+        :return: translated text
+        """
+        raise NotImplementedError
 
     def prompt(self, text, prompt):
         if prompt:
@@ -86,7 +127,7 @@ class GoogleTranslator(BaseTranslator):
             "User-Agent": "Mozilla/4.0 (compatible;MSIE 6.0;Windows NT 5.1;SV1;.NET CLR 1.1.4322;.NET CLR 2.0.50727;.NET CLR 3.0.04506.30)"  # noqa: E501
         }
 
-    def translate(self, text):
+    def do_translate(self, text):
         text = text[:5000]  # google translate max length
         response = self.session.get(
             self.endpoint,
@@ -117,7 +158,7 @@ class BingTranslator(BaseTranslator):
             "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36 Edg/131.0.0.0",  # noqa: E501
         }
 
-    def findSID(self):
+    def find_sid(self):
         response = self.session.get(self.endpoint)
         response.raise_for_status()
         url = response.url[:-10]
@@ -128,9 +169,9 @@ class BingTranslator(BaseTranslator):
         )[0]
         return url, ig, iid, key, token
 
-    def translate(self, text):
+    def do_translate(self, text):
         text = text[:1000]  # bing translate max length
-        url, ig, iid, key, token = self.findSID()
+        url, ig, iid, key, token = self.find_sid()
         response = self.session.post(
             f"{url}ttranslatev3?IG={ig}&IID={iid}",
             data={
@@ -160,7 +201,7 @@ class DeepLTranslator(BaseTranslator):
         auth_key = self.envs["DEEPL_AUTH_KEY"]
         self.client = deepl.Translator(auth_key)
 
-    def translate(self, text):
+    def do_translate(self, text):
         response = self.client.translate_text(
             text, target_lang=self.lang_out, source_lang=self.lang_in
         )
@@ -181,7 +222,7 @@ class DeepLXTranslator(BaseTranslator):
         self.endpoint = self.envs["DEEPLX_ENDPOINT"]
         self.session = requests.Session()
 
-    def translate(self, text):
+    def do_translate(self, text):
         response = self.session.post(
             self.endpoint,
             json={
@@ -211,8 +252,11 @@ class OllamaTranslator(BaseTranslator):
         self.options = {"temperature": 0}  # 随机采样可能会打断公式标记
         self.client = ollama.Client()
         self.prompttext = prompt
+        self.add_cache_impact_parameters("temperature", self.options["temperature"])
+        if prompt:
+            self.add_cache_impact_parameters("prompt", prompt)
 
-    def translate(self, text):
+    def do_translate(self, text):
         maxlen = max(2000, len(text) * 5)
         for model in self.model.split(";"):
             try:
@@ -261,8 +305,11 @@ class OpenAITranslator(BaseTranslator):
         self.options = {"temperature": 0}  # 随机采样可能会打断公式标记
         self.client = openai.OpenAI(base_url=base_url, api_key=api_key)
         self.prompttext = prompt
+        self.add_cache_impact_parameters("temperature", self.options["temperature"])
+        if prompt:
+            self.add_cache_impact_parameters("prompt", prompt)
 
-    def translate(self, text) -> str:
+    def do_translate(self, text) -> str:
         response = self.client.chat.completions.create(
             model=self.model,
             **self.options,
@@ -303,8 +350,11 @@ class AzureOpenAITranslator(BaseTranslator):
             api_key=api_key,
         )
         self.prompttext = prompt
+        self.add_cache_impact_parameters("temperature", self.options["temperature"])
+        if prompt:
+            self.add_cache_impact_parameters("prompt", prompt)
 
-    def translate(self, text) -> str:
+    def do_translate(self, text) -> str:
         response = self.client.chat.completions.create(
             model=self.model,
             **self.options,
@@ -339,6 +389,8 @@ class ModelScopeTranslator(OpenAITranslator):
             model = self.envs["MODELSCOPE_MODEL"]
         super().__init__(lang_in, lang_out, model, base_url=base_url, api_key=api_key)
         self.prompttext = prompt
+        if prompt:
+            self.add_cache_impact_parameters("prompt", prompt)
 
 
 class ZhipuTranslator(OpenAITranslator):
@@ -358,8 +410,10 @@ class ZhipuTranslator(OpenAITranslator):
             model = self.envs["ZHIPU_MODEL"]
         super().__init__(lang_in, lang_out, model, base_url=base_url, api_key=api_key)
         self.prompttext = prompt
+        if prompt:
+            self.add_cache_impact_parameters("prompt", prompt)
 
-    def translate(self, text) -> str:
+    def do_translate(self, text) -> str:
         try:
             response = self.client.chat.completions.create(
                 model=self.model,
@@ -393,6 +447,8 @@ class SiliconTranslator(OpenAITranslator):
             model = self.envs["SILICON_MODEL"]
         super().__init__(lang_in, lang_out, model, base_url=base_url, api_key=api_key)
         self.prompttext = prompt
+        if prompt:
+            self.add_cache_impact_parameters("prompt", prompt)
 
 
 class GeminiTranslator(OpenAITranslator):
@@ -412,6 +468,8 @@ class GeminiTranslator(OpenAITranslator):
             model = self.envs["GEMINI_MODEL"]
         super().__init__(lang_in, lang_out, model, base_url=base_url, api_key=api_key)
         self.prompttext = prompt
+        if prompt:
+            self.add_cache_impact_parameters("prompt", prompt)
 
 
 class AzureTranslator(BaseTranslator):
@@ -436,7 +494,7 @@ class AzureTranslator(BaseTranslator):
         logger = logging.getLogger("azure.core.pipeline.policies.http_logging_policy")
         logger.setLevel(logging.WARNING)
 
-    def translate(self, text) -> str:
+    def do_translate(self, text) -> str:
         response = self.client.translate(
             body=[text],
             from_language=self.lang_in,
@@ -464,7 +522,7 @@ class TencentTranslator(BaseTranslator):
         self.req.Target = self.lang_out
         self.req.ProjectId = 0
 
-    def translate(self, text):
+    def do_translate(self, text):
         self.req.SourceText = text
         resp: TextTranslateResponse = self.client.TextTranslate(self.req)
         return resp.TargetText
@@ -489,8 +547,10 @@ class AnythingLLMTranslator(BaseTranslator):
             "Content-Type": "application/json",
         }
         self.prompttext = prompt
+        if prompt:
+            self.add_cache_impact_parameters("prompt", prompt)
 
-    def translate(self, text):
+    def do_translate(self, text):
         messages = self.prompt(text, self.prompttext)
         payload = {
             "message": messages,
@@ -521,7 +581,7 @@ class DifyTranslator(BaseTranslator):
         self.api_url = self.envs["DIFY_API_URL"]
         self.api_key = self.envs["DIFY_API_KEY"]
 
-    def translate(self, text):
+    def do_translate(self, text):
         headers = {
             "Authorization": f"Bearer {self.api_key}",
             "Content-Type": "application/json",
@@ -546,3 +606,43 @@ class DifyTranslator(BaseTranslator):
 
         # 解析响应
         return response_data.get("data", {}).get("outputs", {}).get("text", [])
+
+
+class ArgosTranslator(BaseTranslator):
+    name = "argos"
+
+    def __init__(self, lang_in, lang_out, model, **kwargs):
+        super().__init__(lang_in, lang_out, model)
+        lang_in = self.lang_map.get(lang_in.lower(), lang_in)
+        lang_out = self.lang_map.get(lang_out.lower(), lang_out)
+        self.lang_in = lang_in
+        self.lang_out = lang_out
+        argostranslate.package.update_package_index()
+        available_packages = argostranslate.package.get_available_packages()
+        try:
+            available_package = list(
+                filter(
+                    lambda x: x.from_code == self.lang_in
+                    and x.to_code == self.lang_out,
+                    available_packages,
+                )
+            )[0]
+        except Exception:
+            raise ValueError(
+                "lang_in and lang_out pair not supported by Argos Translate."
+            )
+        download_path = available_package.download()
+        argostranslate.package.install_from_path(download_path)
+
+    def translate(self, text):
+        # Translate
+        installed_languages = argostranslate.translate.get_installed_languages()
+        from_lang = list(filter(lambda x: x.code == self.lang_in, installed_languages))[
+            0
+        ]
+        to_lang = list(filter(lambda x: x.code == self.lang_out, installed_languages))[
+            0
+        ]
+        translation = from_lang.get_translation(to_lang)
+        translatedText = translation.translate(text)
+        return translatedText
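
With caching now handled in `BaseTranslator.translate()`, most services only implement `do_translate()`. A minimal sketch of a custom translator built on that hook (the class is hypothetical and not part of the project):

```python
from pdf2zh.translator import BaseTranslator


class ReverseTranslator(BaseTranslator):
    """Demo service that 'translates' by reversing the string."""

    name = "reverse"  # must stay under 20 characters for the cache table

    def do_translate(self, text):
        return text[::-1]


t = ReverseTranslator("en", "zh", None)
print(t.translate("Hello"))                     # computed, then written to the cache
print(t.translate("Hello"))                     # served from the SQLite cache
print(t.translate("Hello", ignore_cache=True))  # bypasses the cache for this call
```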

+ 6 - 0
pyproject.toml

@@ -29,6 +29,8 @@ dependencies = [
     "pdfminer.six>=20240706",
     "gradio_pdf>=0.0.21",
     "pikepdf",
+    "peewee>=3.17.8",
+    "argostranslate",
 ]
 
 [project.optional-dependencies]
@@ -54,3 +56,7 @@ build-backend = "hatchling.build"
 
 [project.scripts]
 pdf2zh = "pdf2zh.pdf2zh:main"
+
+[tool.flake8]
+ignore = ["E203", "E261", "E501", "W503", "E741"]
+max-line-length = 88

+ 202 - 96
test/test_cache.py

@@ -1,107 +1,213 @@
 import unittest
-import os
-import tempfile
-import shutil
-import time
 from pdf2zh import cache
+import threading
+import multiprocessing
+import random
+import string
 
 
 class TestCache(unittest.TestCase):
     def setUp(self):
-        # Create a temporary directory for testing
-        self.test_cache_dir = os.path.join(tempfile.gettempdir(), "test_cache")
-        self.original_cache_dir = cache.cache_dir
-        cache.cache_dir = self.test_cache_dir
-        os.makedirs(self.test_cache_dir, exist_ok=True)
+        self.test_db = cache.init_test_db()
 
     def tearDown(self):
-        # Clean up the test directory
-        shutil.rmtree(self.test_cache_dir)
-        cache.cache_dir = self.original_cache_dir
-
-    def test_deterministic_hash(self):
-        # Test hash generation for different inputs
-        test_input = "Hello World"
-        hash1 = cache.deterministic_hash(test_input)
-        hash2 = cache.deterministic_hash(test_input)
-        self.assertEqual(hash1, hash2)
-        self.assertEqual(len(hash1), 20)
-
-        # Test different inputs produce different hashes
-        hash3 = cache.deterministic_hash("Different input")
-        self.assertNotEqual(hash1, hash3)
-
-    def test_get_dirs(self):
-        # Create test directories
-        test_dirs = ["dir1", "dir2", "dir3"]
-        for dir_name in test_dirs:
-            os.makedirs(os.path.join(self.test_cache_dir, dir_name))
-
-        # Create a file (should be ignored)
-        with open(os.path.join(self.test_cache_dir, "test.txt"), "w") as f:
-            f.write("test")
-
-        dirs = cache.get_dirs()
-        self.assertEqual(len(dirs), 3)
-        for dir_path in dirs:
-            self.assertTrue(os.path.isdir(dir_path))
-
-    def test_get_time(self):
-        # Create test directory with time file
-        test_dir = os.path.join(self.test_cache_dir, "test_dir")
-        os.makedirs(test_dir)
-        test_time = 1234567890.0
-
-        with open(os.path.join(test_dir, cache.time_filename), "w") as f:
-            f.write(str(test_time))
-
-        # Test reading time
-        result = cache.get_time(test_dir)
-        self.assertEqual(result, test_time)
-
-        # Test non-existent directory
-        non_existent_dir = os.path.join(self.test_cache_dir, "non_existent")
-        result = cache.get_time(non_existent_dir)
-        self.assertEqual(result, float("inf"))
-
-    def test_write_time(self):
-        test_dir = os.path.join(self.test_cache_dir, "test_dir")
-        os.makedirs(test_dir)
-
-        cache.write_time(test_dir)
-
-        self.assertTrue(os.path.exists(os.path.join(test_dir, cache.time_filename)))
-        with open(os.path.join(test_dir, cache.time_filename)) as f:
-            time_value = float(f.read())
-        self.assertIsInstance(time_value, float)
-
-    def test_remove_extra(self):
-        # Create more than max_cache directories
-        for i in range(cache.max_cache + 2):
-            dir_path = os.path.join(self.test_cache_dir, f"dir{i}")
-            os.makedirs(dir_path)
-            time.sleep(0.1)  # Ensure different timestamps
-            cache.write_time(dir_path)
-
-        cache.remove_extra()
-
-        remaining_dirs = cache.get_dirs()
-        self.assertLessEqual(len(remaining_dirs), cache.max_cache)
-
-    def test_cache_operations(self):
-        test_hash = "test123hash"
-        test_para_hash = "para456hash"
-        test_content = "Test paragraph content"
-
-        # Test cache creation
-        self.assertFalse(cache.is_cached(test_hash))
-        cache.create_cache(test_hash)
-        self.assertTrue(cache.is_cached(test_hash))
-
-        # Test paragraph operations
-        self.assertIsNone(cache.load_paragraph(test_hash, test_para_hash))
-        cache.write_paragraph(test_hash, test_para_hash, test_content)
-        self.assertEqual(cache.load_paragraph(test_hash, test_para_hash), test_content)
+        # Clean up
+        cache.clean_test_db(self.test_db)
+
+    def test_basic_set_get(self):
+        """Test basic set and get operations"""
+        cache_instance = cache.TranslationCache("test_engine")
+
+        # Test get with non-existent entry
+        result = cache_instance.get("hello")
+        self.assertIsNone(result)
+
+        # Test set and get
+        cache_instance.set("hello", "你好")
+        result = cache_instance.get("hello")
+        self.assertEqual(result, "你好")
+
+    def test_cache_overwrite(self):
+        """Test that cache entries can be overwritten"""
+        cache_instance = cache.TranslationCache("test_engine")
+
+        # Set initial translation
+        cache_instance.set("hello", "你好")
+
+        # Overwrite with new translation
+        cache_instance.set("hello", "您好")
+
+        # Verify the new translation is returned
+        result = cache_instance.get("hello")
+        self.assertEqual(result, "您好")
+
+    def test_non_string_params(self):
+        """Test that non-string parameters are automatically converted to JSON"""
+        params = {"model": "gpt-3.5", "temperature": 0.7}
+        cache_instance = cache.TranslationCache("test_engine", params)
+
+        # Test that params are converted to JSON string internally
+        cache_instance.set("hello", "你好")
+        result = cache_instance.get("hello")
+        self.assertEqual(result, "你好")
+
+        # Test with different param types
+        array_params = ["param1", "param2"]
+        cache_instance2 = cache.TranslationCache("test_engine", array_params)
+        cache_instance2.set("hello", "你好2")
+        self.assertEqual(cache_instance2.get("hello"), "你好2")
+
+        # Test with nested structures
+        nested_params = {"options": {"temp": 0.8, "models": ["a", "b"]}}
+        cache_instance3 = cache.TranslationCache("test_engine", nested_params)
+        cache_instance3.set("hello", "你好3")
+        self.assertEqual(cache_instance3.get("hello"), "你好3")
+
+    def test_engine_distinction(self):
+        """Test that cache distinguishes between different translation engines"""
+        cache1 = cache.TranslationCache("engine1")
+        cache2 = cache.TranslationCache("engine2")
+
+        # Set same text with different engines
+        cache1.set("hello", "你好 1")
+        cache2.set("hello", "你好 2")
+
+        # Verify each engine gets its own translation
+        self.assertEqual(cache1.get("hello"), "你好 1")
+        self.assertEqual(cache2.get("hello"), "你好 2")
+
+    def test_params_distinction(self):
+        """Test that cache distinguishes between different engine parameters"""
+        params1 = {"param": "value1"}
+        params2 = {"param": "value2"}
+        cache1 = cache.TranslationCache("test_engine", params1)
+        cache2 = cache.TranslationCache("test_engine", params2)
+
+        # Set same text with different parameters
+        cache1.set("hello", "你好 1")
+        cache2.set("hello", "你好 2")
+
+        # Verify each parameter set gets its own translation
+        self.assertEqual(cache1.get("hello"), "你好 1")
+        self.assertEqual(cache2.get("hello"), "你好 2")
+
+    def test_consistent_param_serialization(self):
+        """Test that dictionary parameters are consistently serialized regardless of key order"""
+        # Test simple dictionary
+        params1 = {"b": 1, "a": 2}
+        params2 = {"a": 2, "b": 1}
+        cache1 = cache.TranslationCache("test_engine", params1)
+        cache2 = cache.TranslationCache("test_engine", params2)
+        self.assertEqual(cache1.translate_engine_params, cache2.translate_engine_params)
+
+        # Test nested dictionary
+        params1 = {"outer2": {"inner2": 2, "inner1": 1}, "outer1": 3}
+        params2 = {"outer1": 3, "outer2": {"inner1": 1, "inner2": 2}}
+        cache1 = cache.TranslationCache("test_engine", params1)
+        cache2 = cache.TranslationCache("test_engine", params2)
+        self.assertEqual(cache1.translate_engine_params, cache2.translate_engine_params)
+
+        # Test dictionary with list of dictionaries
+        params1 = {"b": [{"y": 1, "x": 2}], "a": 3}
+        params2 = {"a": 3, "b": [{"x": 2, "y": 1}]}
+        cache1 = cache.TranslationCache("test_engine", params1)
+        cache2 = cache.TranslationCache("test_engine", params2)
+        self.assertEqual(cache1.translate_engine_params, cache2.translate_engine_params)
+
+        # Test that different values still produce different results
+        params1 = {"a": 1, "b": 2}
+        params2 = {"a": 2, "b": 1}
+        cache1 = cache.TranslationCache("test_engine", params1)
+        cache2 = cache.TranslationCache("test_engine", params2)
+        self.assertNotEqual(
+            cache1.translate_engine_params, cache2.translate_engine_params
+        )
+
+    def test_cache_with_sorted_params(self):
+        """Test that cache works correctly with sorted parameters"""
+        params1 = {"b": [{"y": 1, "x": 2}], "a": 3}
+        params2 = {"a": 3, "b": [{"x": 2, "y": 1}]}
+
+        # Both caches should work with the same key
+        cache1 = cache.TranslationCache("test_engine", params1)
+        cache1.set("hello", "你好")
+
+        cache2 = cache.TranslationCache("test_engine", params2)
+        self.assertEqual(cache2.get("hello"), "你好")
+
+    def test_append_params(self):
+        """Test the append_params method"""
+        cache_instance = cache.TranslationCache("test_engine", {"initial": "value"})
+
+        # Test appending new parameter
+        cache_instance.add_params("new_param", "new_value")
+        self.assertEqual(
+            cache_instance.params, {"initial": "value", "new_param": "new_value"}
+        )
+
+        # Test that cache with appended params works correctly
+        cache_instance.set("hello", "你好")
+        self.assertEqual(cache_instance.get("hello"), "你好")
+
+        # Test overwriting existing parameter
+        cache_instance.add_params("initial", "new_value")
+        self.assertEqual(
+            cache_instance.params, {"initial": "new_value", "new_param": "new_value"}
+        )
+
+        # Cache should work with updated params
+        cache_instance.set("hello2", "你好2")
+        self.assertEqual(cache_instance.get("hello2"), "你好2")
+
+    def test_thread_safety(self):
+        """Test thread safety of cache operations"""
+        cache_instance = cache.TranslationCache("test_engine")
+        lock = threading.Lock()
+        results = []
+        num_threads = multiprocessing.cpu_count()
+        items_per_thread = 100
+
+        def generate_random_text(length=10):
+            return "".join(
+                random.choices(string.ascii_letters + string.digits, k=length)
+            )
+
+        def worker():
+            thread_results = []  # thread-local result storage
+            for _ in range(items_per_thread):
+                text = generate_random_text()
+                translation = f"翻译_{text}"
+
+                # Write operation
+                cache_instance.set(text, translation)
+
+                # Read operation - verify our own write
+                result = cache_instance.get(text)
+                thread_results.append((text, result))
+
+            # After all operations finish, take the lock once and append the results
+            with lock:
+                results.extend(thread_results)
+
+        # Create threads equal to CPU core count
+        threads = []
+        for _ in range(num_threads):
+            thread = threading.Thread(target=worker)
+            threads.append(thread)
+            thread.start()
+
+        # Wait for all threads to complete
+        for thread in threads:
+            thread.join()
+
+        # Verify all operations were successful
+        expected_total = num_threads * items_per_thread
+        self.assertEqual(len(results), expected_total)
+
+        # Verify each thread got its correct value
+        for text, result in results:
+            expected = f"翻译_{text}"
+            self.assertEqual(result, expected)
 
 
 if __name__ == "__main__":

+ 0 - 22
test/test_converter.py

@@ -80,28 +80,6 @@ class TestTranslateConverter(unittest.TestCase):
         self.converter.receive_layout(mock_page)
         mock_receive_layout.assert_called_once_with(mock_page)
 
-    @patch("concurrent.futures.ThreadPoolExecutor")
-    @patch("pdf2zh.cache")
-    def test_translation(self, mock_cache, mock_executor):
-        mock_executor.return_value.__enter__.return_value.map.return_value = [
-            "你好",
-            "{v1}",
-        ]
-        mock_cache.deterministic_hash.return_value = "test_hash"
-        mock_cache.load_paragraph.return_value = None
-        mock_cache.write_paragraph.return_value = None
-
-        sstk = ["Hello", "{v1}"]
-        self.converter.thread = 2
-        results = []
-        with patch.object(self.converter, "translator") as mock_translator:
-            mock_translator.translate.side_effect = lambda x: (
-                "你好" if x == "Hello" else x
-            )
-            for s in sstk:
-                results.append(self.converter.translator.translate(s))
-        self.assertEqual(results, ["你好", "{v1}"])
-
     def test_receive_layout_with_complex_formula(self):
         ltpage = LTPage(1, (0, 0, 500, 500))
         ltchar = Mock()

+ 77 - 0
test/test_translator.py

@@ -0,0 +1,77 @@
+import unittest
+from pdf2zh.translator import BaseTranslator
+from pdf2zh import cache
+
+
+class AutoIncreaseTranslator(BaseTranslator):
+    name = "auto_increase"
+    n = 0
+
+    def do_translate(self, text):
+        self.n += 1
+        return str(self.n)
+
+
+class TestTranslator(unittest.TestCase):
+    def setUp(self):
+        self.test_db = cache.init_test_db()
+
+    def tearDown(self):
+        cache.clean_test_db(self.test_db)
+
+    def test_cache(self):
+        translator = AutoIncreaseTranslator("en", "zh", "test")
+        # First translation should be cached
+        text = "Hello World"
+        first_result = translator.translate(text)
+
+        # Second translation should return the same result from cache
+        second_result = translator.translate(text)
+        self.assertEqual(first_result, second_result)
+
+        # Different input should give different result
+        different_text = "Different Text"
+        different_result = translator.translate(different_text)
+        self.assertNotEqual(first_result, different_result)
+
+        # Test cache with ignore_cache=True
+        translator.ignore_cache = True
+        no_cache_result = translator.translate(text)
+        self.assertNotEqual(first_result, no_cache_result)
+
+    def test_add_cache_impact_parameters(self):
+        translator = AutoIncreaseTranslator("en", "zh", "test")
+
+        # Test cache with added parameters
+        text = "Hello World"
+        first_result = translator.translate(text)
+        translator.add_cache_impact_parameters("test", "value")
+        second_result = translator.translate(text)
+        self.assertNotEqual(first_result, second_result)
+
+        # Test cache with ignore_cache=True
+        no_cache_result1 = translator.translate(text, ignore_cache=True)
+        self.assertNotEqual(first_result, no_cache_result1)
+
+        translator.ignore_cache = True
+        no_cache_result2 = translator.translate(text)
+        self.assertNotEqual(no_cache_result1, no_cache_result2)
+
+        # Test cache with ignore_cache=False
+        translator.ignore_cache = False
+        cache_result = translator.translate(text)
+        self.assertEqual(no_cache_result2, cache_result)
+
+        # Test cache with another parameter
+        translator.add_cache_impact_parameters("test2", "value2")
+        another_result = translator.translate(text)
+        self.assertNotEqual(second_result, another_result)
+
+    def test_base_translator_throw(self):
+        translator = BaseTranslator("en", "zh", "test")
+        with self.assertRaises(NotImplementedError):
+            translator.translate("Hello World")
+
+
+if __name__ == "__main__":
+    unittest.main()