Просмотр исходного кода

feat: translator default model

Byaidu 1 год назад
Родитель
Commit
5612c41379
2 измененных файлов с 31 добавлено и 29 удалено
  1. +11 −9
      pdf2zh/converter.py
  2. +20 −20
      pdf2zh/translator.py

+ 11 - 9
pdf2zh/converter.py

@@ -136,19 +136,21 @@ class TranslateConverter(PDFConverterEx):
         self.noto = noto
         self.translator: BaseTranslator = None
         param = service.split(":", 1)
-        if param[0] == "google":
+        service_id = param[0]
+        service_model = param[1] if len(param) > 1 else None
+        if service_id == "google":
             self.translator = GoogleTranslator(service, lang_out, lang_in, None)
-        elif param[0] == "deepl":
+        elif service_id == "deepl":
             self.translator = DeepLTranslator(service, lang_out, lang_in, None)
-        elif param[0] == "deeplx":
+        elif service_id == "deeplx":
             self.translator = DeepLXTranslator(service, lang_out, lang_in, None)
-        elif param[0] == "ollama":
-            self.translator = OllamaTranslator(service, lang_out, lang_in, param[1])
-        elif param[0] == "openai":
-            self.translator = OpenAITranslator(service, lang_out, lang_in, param[1])
-        elif param[0] == "azure":
+        elif service_id == "ollama":
+            self.translator = OllamaTranslator(service, lang_out, lang_in, service_model)
+        elif service_id == "openai":
+            self.translator = OpenAITranslator(service, lang_out, lang_in, service_model)
+        elif service_id == "azure":
             self.translator = AzureTranslator(service, lang_out, lang_in, None)
-        elif param[0] == "tencent":
+        elif service_id == "tencent":
             self.translator = TencentTranslator(service, lang_out, lang_in, None)
         else:
             raise ValueError("Unsupported translation service")

+ 20 - 20
pdf2zh/translator.py

@@ -32,6 +32,18 @@ class BaseTranslator:
     def translate(self, text):
         pass
 
+    def prompt(self, text):
+        return [
+            {
+                "role": "system",
+                "content": "You are a professional,authentic machine translation engine.",
+            },
+            {
+                "role": "user",
+                "content": f"Translate the following markdown source text to {self.lang_out}. Keep the formula notation $v*$ unchanged. Output translation directly without any additional text.\nSource Text: {text}\nTranslated Text:",  # noqa: E501
+            },
+        ]
+
     def __str__(self):
         return f"{self.service} {self.lang_out} {self.lang_in}"
 
@@ -140,11 +152,14 @@ class OllamaTranslator(BaseTranslator):
     # https://github.com/ollama/ollama-python
     envs = {
         "OLLAMA_HOST": "http://127.0.0.1:11434",
+        "OLLAMA_MODEL": "gemma2",
     }
 
     def __init__(self, service, lang_out, lang_in, model):
         lang_out = "zh-CN" if lang_out == "auto" else lang_out
         lang_in = "en" if lang_in == "auto" else lang_in
+        if not model:
+            model = os.getenv("OLLAMA_MODEL", self.envs["OLLAMA_MODEL"])
         super().__init__(service, lang_out, lang_in, model)
         self.options = {"temperature": 0}  # 随机采样可能会打断公式标记
         self.client = ollama.Client()
@@ -153,16 +168,7 @@ class OllamaTranslator(BaseTranslator):
         response = self.client.chat(
             model=self.model,
             options=self.options,
-            messages=[
-                {
-                    "role": "system",
-                    "content": "You are a professional,authentic machine translation engine.",
-                },
-                {
-                    "role": "user",
-                    "content": f"Translate the following markdown source text to {self.lang_out}. Keep the formula notation $v*$ unchanged. Output translation directly without any additional text.\nSource Text: {text}\nTranslated Text:",  # noqa: E501
-                },
-            ],
+            messages=self.prompt(text),
         )
         return response["message"]["content"].strip()
 
@@ -172,11 +178,14 @@ class OpenAITranslator(BaseTranslator):
     envs = {
         "OPENAI_BASE_URL": "https://api.openai.com/v1",
         "OPENAI_API_KEY": None,
+        "OPENAI_MODEL": "gpt-4o",
     }
 
     def __init__(self, service, lang_out, lang_in, model):
         lang_out = "zh-CN" if lang_out == "auto" else lang_out
         lang_in = "en" if lang_in == "auto" else lang_in
+        if not model:
+            model = os.getenv("OPENAI_MODEL", self.envs["OPENAI_MODEL"])
         super().__init__(service, lang_out, lang_in, model)
         self.options = {"temperature": 0}  # 随机采样可能会打断公式标记
         self.client = openai.OpenAI()
@@ -185,16 +194,7 @@ class OpenAITranslator(BaseTranslator):
         response = self.client.chat.completions.create(
             model=self.model,
             **self.options,
-            messages=[
-                {
-                    "role": "system",
-                    "content": "You are a professional,authentic machine translation engine.",
-                },
-                {
-                    "role": "user",
-                    "content": f"Translate the following markdown source text to {self.lang_out}. Keep the formula notation $v*$ unchanged. Output translation directly without any additional text.\nSource Text: {text}\nTranslated Text:",  # noqa: E501
-                },
-            ],
+            messages=self.prompt(text),
         )
         return response.choices[0].message.content.strip()