Explorar el Código

fix : prompt not work with GUI.

hellofinch hace 10 meses
padre
commit
167b3666d5
Se han modificado 4 ficheros con 8 adiciones y 34 borrados
  1. 3 4
      pdf2zh/converter.py
  2. 1 1
      pdf2zh/gui.py
  3. 4 3
      pdf2zh/high_level.py
  4. 0 26
      pdf2zh/translator.py

+ 3 - 4
pdf2zh/converter.py

@@ -1,4 +1,4 @@
-from typing import Dict, List
+from typing import Dict
 from enum import Enum
 
 from pdfminer.pdfinterp import PDFGraphicState, PDFResourceManager
@@ -17,6 +17,7 @@ import re
 import concurrent.futures
 import numpy as np
 import unicodedata
+from string import Template
 from tenacity import retry, wait_fixed
 from pdf2zh.translator import (
     AzureOpenAITranslator,
@@ -144,7 +145,7 @@ class TranslateConverter(PDFConverterEx):
         noto_name: str = "",
         noto: Font = None,
         envs: Dict = None,
-        prompt: List = None,
+        prompt: Template = None,
     ) -> None:
         super().__init__(rsrcmgr)
         self.vfont = vfont
@@ -159,8 +160,6 @@ class TranslateConverter(PDFConverterEx):
         service_model = param[1] if len(param) > 1 else None
         if not envs:
             envs = {}
-        if not prompt:
-            prompt = []
         for translator in [GoogleTranslator, BingTranslator, DeepLTranslator, DeepLXTranslator, OllamaTranslator, XinferenceTranslator, AzureOpenAITranslator,
                            OpenAITranslator, ZhipuTranslator, ModelScopeTranslator, SiliconTranslator, GeminiTranslator, AzureTranslator, TencentTranslator, DifyTranslator, AnythingLLMTranslator, ArgosTranslator, GorkTranslator, GroqTranslator, DeepseekTranslator, OpenAIlikedTranslator,]:
             if service_name == translator.name:

+ 1 - 1
pdf2zh/gui.py

@@ -277,7 +277,7 @@ def translate_file(
         "callback": progress_bar,
         "cancellation_event": cancellation_event_map[session_id],
         "envs": _envs,
-        "prompt": Template(prompt),
+        "prompt": Template(prompt) if prompt else None,
         "model": ModelInstance.value,
     }
     try:

+ 4 - 3
pdf2zh/high_level.py

@@ -8,6 +8,7 @@ import tempfile
 import urllib.request
 from asyncio import CancelledError
 from pathlib import Path
+from string import Template
 from typing import Any, BinaryIO, List, Optional, Dict
 
 import numpy as np
@@ -78,7 +79,7 @@ def translate_patch(
     cancellation_event: asyncio.Event = None,
     model: OnnxModel = None,
     envs: Dict = None,
-    prompt: List = None,
+    prompt: Template = None,
     **kwarg: Any,
 ) -> None:
     rsrcmgr = PDFResourceManager()
@@ -172,7 +173,7 @@ def translate_stream(
     cancellation_event: asyncio.Event = None,
     model: OnnxModel = None,
     envs: Dict = None,
-    prompt: List = None,
+    prompt: Template = None,
     **kwarg: Any,
 ):
     font_list = [("tiro", None)]
@@ -297,7 +298,7 @@ def translate(
     cancellation_event: asyncio.Event = None,
     model: OnnxModel = None,
     envs: Dict = None,
-    prompt: List = None,
+    prompt: Template = None,
     **kwarg: Any,
 ):
     if not files:

+ 0 - 26
pdf2zh/translator.py

@@ -277,8 +277,6 @@ class OllamaTranslator(BaseTranslator):
         self.client = ollama.Client()
         self.prompttext = prompt
         self.add_cache_impact_parameters("temperature", self.options["temperature"])
-        if prompt:
-            self.add_cache_impact_parameters("prompt", prompt.template)
 
     def do_translate(self, text):
         maxlen = max(2000, len(text) * 5)
@@ -320,8 +318,6 @@ class XinferenceTranslator(BaseTranslator):
         self.client = xinference_client.RESTfulClient(self.envs["XINFERENCE_HOST"])
         self.prompttext = prompt
         self.add_cache_impact_parameters("temperature", self.options["temperature"])
-        if prompt:
-            self.add_cache_impact_parameters("prompt", prompt.template)
 
     def do_translate(self, text):
         maxlen = max(2000, len(text) * 5)
@@ -384,8 +380,6 @@ class OpenAITranslator(BaseTranslator):
         )
         self.prompttext = prompt
         self.add_cache_impact_parameters("temperature", self.options["temperature"])
-        if prompt:
-            self.add_cache_impact_parameters("prompt", prompt.template)
 
     def do_translate(self, text) -> str:
         response = self.client.chat.completions.create(
@@ -438,8 +432,6 @@ class AzureOpenAITranslator(BaseTranslator):
         )
         self.prompttext = prompt
         self.add_cache_impact_parameters("temperature", self.options["temperature"])
-        if prompt:
-            self.add_cache_impact_parameters("prompt", prompt.template)
 
     def do_translate(self, text) -> str:
         response = self.client.chat.completions.create(
@@ -476,8 +468,6 @@ class ModelScopeTranslator(OpenAITranslator):
             model = self.envs["MODELSCOPE_MODEL"]
         super().__init__(lang_in, lang_out, model, base_url=base_url, api_key=api_key)
         self.prompttext = prompt
-        if prompt:
-            self.add_cache_impact_parameters("prompt", prompt.template)
 
 
 class ZhipuTranslator(OpenAITranslator):
@@ -497,8 +487,6 @@ class ZhipuTranslator(OpenAITranslator):
             model = self.envs["ZHIPU_MODEL"]
         super().__init__(lang_in, lang_out, model, base_url=base_url, api_key=api_key)
         self.prompttext = prompt
-        if prompt:
-            self.add_cache_impact_parameters("prompt", prompt.template)
 
     def do_translate(self, text) -> str:
         try:
@@ -534,8 +522,6 @@ class SiliconTranslator(OpenAITranslator):
             model = self.envs["SILICON_MODEL"]
         super().__init__(lang_in, lang_out, model, base_url=base_url, api_key=api_key)
         self.prompttext = prompt
-        if prompt:
-            self.add_cache_impact_parameters("prompt", prompt.template)
 
 
 class GeminiTranslator(OpenAITranslator):
@@ -555,8 +541,6 @@ class GeminiTranslator(OpenAITranslator):
             model = self.envs["GEMINI_MODEL"]
         super().__init__(lang_in, lang_out, model, base_url=base_url, api_key=api_key)
         self.prompttext = prompt
-        if prompt:
-            self.add_cache_impact_parameters("prompt", prompt.template)
 
 
 class AzureTranslator(BaseTranslator):
@@ -634,8 +618,6 @@ class AnythingLLMTranslator(BaseTranslator):
             "Content-Type": "application/json",
         }
         self.prompttext = prompt
-        if prompt:
-            self.add_cache_impact_parameters("prompt", prompt.template)
 
     def do_translate(self, text):
         messages = self.prompt(text, self.prompttext)
@@ -752,8 +734,6 @@ class GorkTranslator(OpenAITranslator):
             model = self.envs["GORK_MODEL"]
         super().__init__(lang_in, lang_out, model, base_url=base_url, api_key=api_key)
         self.prompttext = prompt
-        if prompt:
-            self.add_cache_impact_parameters("prompt", prompt.template)
 
 
 class GroqTranslator(OpenAITranslator):
@@ -772,8 +752,6 @@ class GroqTranslator(OpenAITranslator):
             model = self.envs["GROQ_MODEL"]
         super().__init__(lang_in, lang_out, model, base_url=base_url, api_key=api_key)
         self.prompttext = prompt
-        if prompt:
-            self.add_cache_impact_parameters("prompt", prompt.template)
 
 
 class DeepseekTranslator(OpenAITranslator):
@@ -792,8 +770,6 @@ class DeepseekTranslator(OpenAITranslator):
             model = self.envs["DEEPSEEK_MODEL"]
         super().__init__(lang_in, lang_out, model, base_url=base_url, api_key=api_key)
         self.prompttext = prompt
-        if prompt:
-            self.add_cache_impact_parameters("prompt", prompt.template)
 
 
 class OpenAIlikedTranslator(OpenAITranslator):
@@ -822,5 +798,3 @@ class OpenAIlikedTranslator(OpenAITranslator):
             api_key = self.envs["OPENAILIKED_API_KEY"]
         super().__init__(lang_in, lang_out, model, base_url=base_url, api_key=api_key)
         self.prompttext = prompt
-        if prompt:
-            self.add_cache_impact_parameters("prompt", prompt.template)