Răsfoiți Sursa

refactor code

yuze.zyz 1 an în urmă
părinte
comite
16c2426012
7 a modificat fișierele cu 68 adăugiri și 44 ștergeri
  1. +0 −3
      pdf2zh/__init__.py
  2. +4 −1
      pdf2zh/converter.py
  3. +2 −1
      pdf2zh/entrance.py
  4. +5 −1
      pdf2zh/gui.py
  5. +1 −1
      pdf2zh/high_level.py
  6. +55 −36
      pdf2zh/translator.py
  7. +1 −1
      pyproject.toml

+ 0 - 3
pdf2zh/__init__.py

@@ -1,8 +1,5 @@
 import logging
-from pdf2zh.high_level import translate, translate_stream
-
 log = logging.getLogger(__name__)
 
 __version__ = "1.8.8"
 __author__ = "Byaidu"
-__all__ = ["translate", "translate_stream"]

+ 4 - 1
pdf2zh/converter.py

@@ -1,3 +1,5 @@
+from typing import List, Dict
+
 from pdfminer.pdfinterp import PDFGraphicState, PDFResourceManager
 from pdfminer.pdffont import PDFCIDFont
 from pdfminer.converter import PDFConverter
@@ -133,6 +135,7 @@ class TranslateConverter(PDFConverterEx):
         service: str = "",
         resfont: str = "",
         noto: Font = None,
+        envs: Dict = None,
     ) -> None:
         super().__init__(rsrcmgr)
         self.vfont = vfont
@@ -148,7 +151,7 @@ class TranslateConverter(PDFConverterEx):
         for translator in [GoogleTranslator, BingTranslator, DeepLTranslator, DeepLXTranslator, OllamaTranslator, AzureOpenAITranslator,
                            OpenAITranslator, ZhipuTranslator, ModelScopeTranslator, SiliconTranslator, GeminiTranslator, AzureTranslator, TencentTranslator, DifyTranslator, AnythingLLMTranslator]:
             if service_name == translator.name:
-                self.translator = translator(lang_in, lang_out, service_model)
+                self.translator = translator(lang_in, lang_out, service_model, envs=envs)
         if not self.translator:
             raise ValueError("Unsupported translation service")
 

+ 2 - 1
pdf2zh/pdf2zh.py → pdf2zh/entrance.py

@@ -9,11 +9,11 @@ import argparse
 import sys
 import logging
 from typing import List, Optional
-from pdf2zh import __version__, log
 from pdf2zh.high_level import translate
 
 
 def create_parser() -> argparse.ArgumentParser:
+    from pdf2zh import __version__
     parser = argparse.ArgumentParser(description=__doc__, add_help=True)
     parser.add_argument(
         "files",
@@ -136,6 +136,7 @@ def parse_args(args: Optional[List[str]]) -> argparse.Namespace:
 
 
 def main(args: Optional[List[str]] = None) -> int:
+    from pdf2zh import log
    logging.basicConfig()
 
     parsed_args = parse_args(args)

+ 5 - 1
pdf2zh/gui.py

@@ -164,6 +164,10 @@ def translate_file(
     lang_from = lang_map[lang_from]
     lang_from = lang_map[lang_from]
     lang_to = lang_map[lang_to]
     lang_to = lang_map[lang_to]
 
 
+    _envs = {}
+    for i, env in enumerate(translator.envs.items()):
+        _envs[env[0]] = envs[i]
+
     print(f"Files before translation: {os.listdir(output)}")
     print(f"Files before translation: {os.listdir(output)}")
 
 
     def progress_bar(t: tqdm.tqdm):
     def progress_bar(t: tqdm.tqdm):
@@ -179,8 +183,8 @@ def translate_file(
         "thread": 4,
         "thread": 4,
         "callback": progress_bar,
         "callback": progress_bar,
         "cancellation_event": cancellation_event_map[session_id],
         "cancellation_event": cancellation_event_map[session_id],
+        "envs": _envs,
     }
     }
-    print(param)
     try:
     try:
         translate(**param)
         translate(**param)
     except CancelledError:
     except CancelledError:

+ 1 - 1
pdf2zh/high_level.py

@@ -92,7 +92,7 @@ def translate_patch(
     rsrcmgr = PDFResourceManager()
     layout = {}
     device = TranslateConverter(
-        rsrcmgr, vfont, vchar, thread, layout, lang_in, lang_out, service, resfont, noto
+        rsrcmgr, vfont, vchar, thread, layout, lang_in, lang_out, service, resfont, noto, kwarg.get('envs', {})
     )
 
     assert device is not None

+ 55 - 36
pdf2zh/translator.py

@@ -34,6 +34,14 @@ class BaseTranslator:
         self.lang_out = lang_out
         self.model = model
 
+    def set_envs(self, envs):
+        for key in self.envs:
+            if key in os.environ:
+                self.envs[key] = os.environ[key]
+        if envs is not None:
+            for key in envs:
+                self.envs[key] = envs[key]
+
     def translate(self, text):
         pass
 
@@ -57,7 +65,7 @@ class GoogleTranslator(BaseTranslator):
     name = "google"
     lang_map = {"zh": "zh-CN"}
 
-    def __init__(self, lang_in, lang_out, model):
+    def __init__(self, lang_in, lang_out, model, **kwargs):
         super().__init__(lang_in, lang_out, model)
         self.session = requests.Session()
         self.endpoint = "http://translate.google.com/m"
@@ -88,7 +96,7 @@ class BingTranslator(BaseTranslator):
     name = "bing"
     lang_map = {"zh": "zh-Hans"}
 
-    def __init__(self, lang_in, lang_out, model):
+    def __init__(self, lang_in, lang_out, model, **kwargs):
         super().__init__(lang_in, lang_out, model)
         self.session = requests.Session()
         self.endpoint = "https://www.bing.com/translator"
@@ -133,9 +141,10 @@ class DeepLTranslator(BaseTranslator):
     }
     lang_map = {"zh": "zh-Hans"}
 
-    def __init__(self, lang_in, lang_out, model):
+    def __init__(self, lang_in, lang_out, model, envs=None):
+        self.set_envs(envs)
         super().__init__(lang_in, lang_out, model)
-        auth_key = os.getenv("DEEPL_AUTH_KEY")
+        auth_key = self.envs["DEEPL_AUTH_KEY"]
         self.client = deepl.Translator(auth_key)
 
     def translate(self, text):
@@ -153,9 +162,10 @@ class DeepLXTranslator(BaseTranslator):
     }
     lang_map = {"zh": "zh-Hans"}
 
-    def __init__(self, lang_in, lang_out, model):
+    def __init__(self, lang_in, lang_out, model, envs=None):
+        self.set_envs(envs)
         super().__init__(lang_in, lang_out, model)
-        self.endpoint = os.getenv("DEEPLX_ENDPOINT", self.envs["DEEPLX_ENDPOINT"])
+        self.endpoint = self.envs["DEEPLX_ENDPOINT"]
         self.session = requests.Session()
 
     def translate(self, text):
@@ -179,9 +189,10 @@ class OllamaTranslator(BaseTranslator):
         "OLLAMA_MODEL": "gemma2",
     }
 
-    def __init__(self, lang_in, lang_out, model):
+    def __init__(self, lang_in, lang_out, model, envs=None):
+        self.set_envs(envs)
         if not model:
-            model = os.getenv("OLLAMA_MODEL", self.envs["OLLAMA_MODEL"])
+            model = self.envs["OLLAMA_MODEL"]
         super().__init__(lang_in, lang_out, model)
         self.options = {"temperature": 0}  # 随机采样可能会打断公式标记
         self.client = ollama.Client()
@@ -204,9 +215,10 @@ class OpenAITranslator(BaseTranslator):
         "OPENAI_MODEL": "gpt-4o-mini",
     }
 
-    def __init__(self, lang_in, lang_out, model, base_url=None, api_key=None):
+    def __init__(self, lang_in, lang_out, model, base_url=None, api_key=None, envs=None):
+        self.set_envs(envs)
         if not model:
-            model = os.getenv("OPENAI_MODEL", self.envs["OPENAI_MODEL"])
+            model = self.envs["OPENAI_MODEL"]
         super().__init__(lang_in, lang_out, model)
         self.options = {"temperature": 0}  # 随机采样可能会打断公式标记
         self.client = openai.OpenAI(base_url=base_url, api_key=api_key)
@@ -228,12 +240,11 @@ class AzureOpenAITranslator(BaseTranslator):
         "AZURE_OPENAI_MODEL": "gpt-4o-mini",
     }
 
-    def __init__(self, lang_in, lang_out, model, base_url=None, api_key=None):
-        base_url = os.getenv(
-            "AZURE_OPENAI_BASE_URL", self.envs["AZURE_OPENAI_BASE_URL"]
-        )
+    def __init__(self, lang_in, lang_out, model, base_url=None, api_key=None, envs=None):
+        self.set_envs(envs)
+        base_url = self.envs["AZURE_OPENAI_BASE_URL"]
         if not model:
-            model = os.getenv("AZURE_OPENAI_MODEL", self.envs["AZURE_OPENAI_MODEL"])
+            model = self.envs["AZURE_OPENAI_MODEL"]
         super().__init__(lang_in, lang_out, model)
         self.options = {"temperature": 0}
         self.client = openai.AzureOpenAI(
@@ -260,11 +271,12 @@ class ModelScopeTranslator(OpenAITranslator):
         "MODELSCOPE_MODEL": "Qwen/Qwen2.5-32B-Instruct",
     }
 
-    def __init__(self, lang_in, lang_out, model, base_url=None, api_key=None):
+    def __init__(self, lang_in, lang_out, model, base_url=None, api_key=None, envs=None):
+        self.set_envs(envs)
         base_url = "https://api-inference.modelscope.cn/v1"
-        api_key = os.getenv("MODELSCOPE_API_KEY")
+        api_key = self.envs["MODELSCOPE_API_KEY"]
         if not model:
-            model = os.getenv("MODELSCOPE_MODEL", self.envs["MODELSCOPE_MODEL"])
+            model = self.envs["MODELSCOPE_MODEL"]
         super().__init__(lang_in, lang_out, model, base_url=base_url, api_key=api_key)
 
 
@@ -276,11 +288,12 @@ class ZhipuTranslator(OpenAITranslator):
         "ZHIPU_MODEL": "glm-4-flash",
     }
 
-    def __init__(self, lang_in, lang_out, model):
+    def __init__(self, lang_in, lang_out, model, envs=None):
+        self.set_envs(envs)
         base_url = "https://open.bigmodel.cn/api/paas/v4"
-        api_key = os.getenv("ZHIPU_API_KEY")
+        api_key = self.envs["ZHIPU_API_KEY"]
         if not model:
-            model = os.getenv("ZHIPU_MODEL", self.envs["ZHIPU_MODEL"])
+            model = self.envs["ZHIPU_MODEL"]
         super().__init__(lang_in, lang_out, model, base_url=base_url, api_key=api_key)
 
     def translate(self, text) -> str:
@@ -308,11 +321,12 @@ class SiliconTranslator(OpenAITranslator):
         "SILICON_MODEL": "Qwen/Qwen2.5-7B-Instruct",
     }
 
-    def __init__(self, lang_in, lang_out, model):
+    def __init__(self, lang_in, lang_out, model, envs=None):
+        self.set_envs(envs)
         base_url = "https://api.siliconflow.cn/v1"
-        api_key = os.getenv("SILICON_API_KEY")
+        api_key = self.envs["SILICON_API_KEY"]
         if not model:
-            model = os.getenv("SILICON_MODEL", self.envs["SILICON_MODEL"])
+            model = self.envs["SILICON_MODEL"]
         super().__init__(lang_in, lang_out, model, base_url=base_url, api_key=api_key)
 
 
@@ -324,11 +338,12 @@ class GeminiTranslator(OpenAITranslator):
         "GEMINI_MODEL": "gemini-1.5-flash",
     }
 
-    def __init__(self, lang_in, lang_out, model):
+    def __init__(self, lang_in, lang_out, model, envs=None):
+        self.set_envs(envs)
         base_url = "https://generativelanguage.googleapis.com/v1beta/openai/"
-        api_key = os.getenv("GEMINI_API_KEY")
+        api_key = self.envs["GEMINI_API_KEY"]
         if not model:
-            model = os.getenv("GEMINI_MODEL", self.envs["GEMINI_MODEL"])
+            model = self.envs["GEMINI_MODEL"]
         super().__init__(lang_in, lang_out, model, base_url=base_url, api_key=api_key)
 
 
@@ -341,9 +356,10 @@ class AzureTranslator(BaseTranslator):
     }
     lang_map = {"zh": "zh-Hans"}
 
-    def __init__(self, lang_in, lang_out, model):
+    def __init__(self, lang_in, lang_out, model, envs=None):
+        self.set_envs(envs)
         super().__init__(lang_in, lang_out, model)
-        endpoint = os.getenv("AZURE_ENDPOINT", self.envs["AZURE_ENDPOINT"])
+        endpoint = self.envs["AZURE_ENDPOINT"]
         api_key = os.getenv("AZURE_API_KEY")
         credential = AzureKeyCredential(api_key)
         self.client = TextTranslationClient(
@@ -371,7 +387,8 @@ class TencentTranslator(BaseTranslator):
         "TENCENTCLOUD_SECRET_KEY": None,
     }
 
-    def __init__(self, lang_in, lang_out, model):
+    def __init__(self, lang_in, lang_out, model, envs=None):
+        self.set_envs(envs)
         super().__init__(lang_in, lang_out, model)
         cred = credential.DefaultCredentialProvider().get_credential()
         self.client = TmtClient(cred, "ap-beijing")
@@ -393,10 +410,11 @@ class AnythingLLMTranslator(BaseTranslator):
         "AnythingLLM_APIKEY": "api_key",
     }
 
-    def __init__(self, lang_out, lang_in, model):
+    def __init__(self, lang_out, lang_in, model, envs=None):
+        self.set_envs(envs)
         super().__init__(lang_out, lang_in, model)
-        self.api_url = os.getenv("AnythingLLM_URL", self.envs["AnythingLLM_URL"])
-        self.api_key = os.getenv("AnythingLLM_APIKEY", self.envs["AnythingLLM_APIKEY"])
+        self.api_url = self.envs["AnythingLLM_URL"]
+        self.api_key = self.envs["AnythingLLM_APIKEY"]
         self.headers = {
             "accept": "application/json",
             "Authorization": f"Bearer {self.api_key}",
@@ -428,10 +446,11 @@ class DifyTranslator(BaseTranslator):
         "DIFY_API_KEY": "api_key",  # 替换为实际 API 密钥
     }
 
-    def __init__(self, lang_out, lang_in, model):
+    def __init__(self, lang_out, lang_in, model, envs=None):
+        self.set_envs(envs)
         super().__init__(lang_out, lang_in, model)
-        self.api_url = os.getenv("DIFY_API_URL", self.envs["DIFY_API_URL"])
-        self.api_key = os.getenv("DIFY_API_KEY", self.envs["DIFY_API_KEY"])
+        self.api_url = self.envs["DIFY_API_URL"]
+        self.api_key = self.envs["DIFY_API_KEY"]
 
     def translate(self, text):
         headers = {

+ 1 - 1
pyproject.toml

@@ -50,4 +50,4 @@ requires = ["hatchling"]
 build-backend = "hatchling.build"
 
 [project.scripts]
-pdf2zh = "pdf2zh.pdf2zh:main"
+pdf2zh = "pdf2zh.entrance:main"