Browse Source

fix: cache model

Byaidu 1 year ago
parent
commit
d6f96334b1
3 changed files with 26 additions and 30 deletions
  1. 1 1
      pdf2zh/converter.py
  2. 1 0
      pdf2zh/pdf2zh.py
  3. 24 29
      pdf2zh/translator.py

+ 1 - 1
pdf2zh/converter.py

@@ -144,7 +144,7 @@ class TranslateConverter(PDFConverterEx):
         for translator in [GoogleTranslator, BingTranslator, DeepLTranslator, DeepLXTranslator, OllamaTranslator, AzureOpenAITranslator,
                            OpenAITranslator, ZhipuTranslator, SiliconTranslator, AzureTranslator, TencentTranslator]:
             if service_name == translator.name:
-                self.translator = translator(service, lang_out, lang_in, service_model)
+                self.translator = translator(lang_out, lang_in, service_model)
         if not self.translator:
             raise ValueError("Unsupported translation service")
 

+ 1 - 0
pdf2zh/pdf2zh.py

@@ -142,6 +142,7 @@ def main(args: Optional[List[str]] = None) -> int:
 
     if parsed_args.interactive:
         from pdf2zh.gui import setup_gui
+
         setup_gui(parsed_args.share)
         return 0
 

+ 24 - 29
pdf2zh/translator.py

@@ -25,10 +25,9 @@ class BaseTranslator:
     envs = {}
     lang_map = {}
 
-    def __init__(self, service, lang_out: str, lang_in: str, model):
+    def __init__(self, lang_out: str, lang_in: str, model):
         lang_out = self.lang_map.get(lang_out.lower(), lang_out)
         lang_in = self.lang_map.get(lang_in.lower(), lang_in)
-        self.service = service
         self.lang_out = lang_out
         self.lang_in = lang_in
         self.model = model
@@ -49,15 +48,15 @@ class BaseTranslator:
         ]
 
     def __str__(self):
-        return f"{self.service} {self.lang_out} {self.lang_in}"
+        return f"{self.name} {self.lang_out} {self.lang_in} {self.model}"
 
 
 class GoogleTranslator(BaseTranslator):
     name = "google"
     lang_map = {"zh": "zh-CN"}
 
-    def __init__(self, service, lang_out, lang_in, model):
-        super().__init__(service, lang_out, lang_in, model)
+    def __init__(self, lang_out, lang_in, model):
+        super().__init__(lang_out, lang_in, model)
         self.session = requests.Session()
         self.endpoint = "http://translate.google.com/m"
         self.headers = {
@@ -87,8 +86,8 @@ class BingTranslator(BaseTranslator):
     name = "bing"
     lang_map = {"zh": "zh-Hans"}
 
-    def __init__(self, service, lang_out, lang_in, model):
-        super().__init__(service, lang_out, lang_in, model)
+    def __init__(self, lang_out, lang_in, model):
+        super().__init__(lang_out, lang_in, model)
         self.session = requests.Session()
         self.endpoint = "https://www.bing.com/ttranslatev3"
         self.headers = {
@@ -130,8 +129,8 @@ class DeepLTranslator(BaseTranslator):
     }
     lang_map = {"zh": "zh-Hans"}
 
-    def __init__(self, service, lang_out, lang_in, model):
-        super().__init__(service, lang_out, lang_in, model)
+    def __init__(self, lang_out, lang_in, model):
+        super().__init__(lang_out, lang_in, model)
         self.session = requests.Session()
         server_url = os.getenv("DEEPL_SERVER_URL", self.envs["DEEPL_SERVER_URL"])
         auth_key = os.getenv("DEEPL_AUTH_KEY")
@@ -152,8 +151,8 @@ class DeepLXTranslator(BaseTranslator):
     }
     lang_map = {"zh": "zh-Hans"}
 
-    def __init__(self, service, lang_out, lang_in, model):
-        super().__init__(service, lang_out, lang_in, model)
+    def __init__(self, lang_out, lang_in, model):
+        super().__init__(lang_out, lang_in, model)
         self.endpoint = os.getenv("DEEPLX_ENDPOINT", self.envs["DEEPLX_ENDPOINT"])
         self.session = requests.Session()
 
@@ -177,10 +176,10 @@ class OllamaTranslator(BaseTranslator):
         "OLLAMA_MODEL": "gemma2",
     }
 
-    def __init__(self, service, lang_out, lang_in, model):
+    def __init__(self, lang_out, lang_in, model):
         if not model:
             model = os.getenv("OLLAMA_MODEL", self.envs["OLLAMA_MODEL"])
-        super().__init__(service, lang_out, lang_in, model)
+        super().__init__(lang_out, lang_in, model)
         self.options = {"temperature": 0}  # 随机采样可能会打断公式标记
         self.client = ollama.Client()
 
@@ -202,10 +201,10 @@ class OpenAITranslator(BaseTranslator):
         "OPENAI_MODEL": "gpt-4o-mini",
     }
 
-    def __init__(self, service, lang_out, lang_in, model, base_url=None, api_key=None):
+    def __init__(self, lang_out, lang_in, model, base_url=None, api_key=None):
         if not model:
             model = os.getenv("OPENAI_MODEL", self.envs["OPENAI_MODEL"])
-        super().__init__(service, lang_out, lang_in, model)
+        super().__init__(lang_out, lang_in, model)
         self.options = {"temperature": 0}  # 随机采样可能会打断公式标记
         self.client = openai.OpenAI(base_url=base_url, api_key=api_key)
 
@@ -226,14 +225,14 @@ class AzureOpenAITranslator(BaseTranslator):
         "AZURE_OPENAI_MODEL": "gpt-4o-mini",
     }
 
-    def __init__(self, service, lang_out, lang_in, model, base_url=None, api_key=None):
+    def __init__(self, lang_out, lang_in, model, base_url=None, api_key=None):
         base_url = os.getenv(
             "AZURE_OPENAI_BASE_URL", self.envs["AZURE_OPENAI_BASE_URL"]
         )
         api_version = os.getenv("AZURE_OPENAI_API_VERSION", "2024-06-01")
         if not model:
             model = os.getenv("AZURE_OPENAI_MODEL", self.envs["AZURE_OPENAI_MODEL"])
-        super().__init__(service, lang_out, lang_in, model)
+        super().__init__(lang_out, lang_in, model)
         self.options = {"temperature": 0}
         self.client = openai.AzureOpenAI(
             azure_endpoint=base_url,
@@ -259,14 +258,12 @@ class ZhipuTranslator(OpenAITranslator):
         "ZHIPU_MODEL": "glm-4-flash",
     }
 
-    def __init__(self, service, lang_out, lang_in, model):
+    def __init__(self, lang_out, lang_in, model):
         base_url = "https://open.bigmodel.cn/api/paas/v4"
         api_key = os.getenv("ZHIPU_API_KEY")
         if not model:
             model = os.getenv("ZHIPU_MODEL", self.envs["ZHIPU_MODEL"])
-        super().__init__(
-            service, lang_out, lang_in, model, base_url=base_url, api_key=api_key
-        )
+        super().__init__(lang_out, lang_in, model, base_url=base_url, api_key=api_key)
 
 
 class SiliconTranslator(OpenAITranslator):
@@ -277,14 +274,12 @@ class SiliconTranslator(OpenAITranslator):
         "SILICON_MODEL": "Qwen/Qwen2.5-7B-Instruct",
     }
 
-    def __init__(self, service, lang_out, lang_in, model):
+    def __init__(self, lang_out, lang_in, model):
         base_url = "https://api.siliconflow.cn/v1"
         api_key = os.getenv("SILICON_API_KEY")
         if not model:
             model = os.getenv("SILICON_MODEL", self.envs["SILICON_MODEL"])
-        super().__init__(
-            service, lang_out, lang_in, model, base_url=base_url, api_key=api_key
-        )
+        super().__init__(lang_out, lang_in, model, base_url=base_url, api_key=api_key)
 
 
 class AzureTranslator(BaseTranslator):
@@ -296,8 +291,8 @@ class AzureTranslator(BaseTranslator):
     }
     lang_map = {"zh": "zh-Hans"}
 
-    def __init__(self, service, lang_out, lang_in, model):
-        super().__init__(service, lang_out, lang_in, model)
+    def __init__(self, lang_out, lang_in, model):
+        super().__init__(lang_out, lang_in, model)
         endpoint = os.getenv("AZURE_ENDPOINT", self.envs["AZURE_ENDPOINT"])
         api_key = os.getenv("AZURE_API_KEY")
         credential = AzureKeyCredential(api_key)
@@ -326,8 +321,8 @@ class TencentTranslator(BaseTranslator):
         "TENCENTCLOUD_SECRET_KEY": None,
     }
 
-    def __init__(self, service, lang_out, lang_in, model):
-        super().__init__(service, lang_out, lang_in, model)
+    def __init__(self, lang_out, lang_in, model):
+        super().__init__(lang_out, lang_in, model)
         cred = credential.DefaultCredentialProvider().get_credential()
         self.client = TmtClient(cred, "ap-beijing")
         self.req = TextTranslateRequest()