Просмотр исходного кода

feat (gui): add custom prompt (#275)

* gui支持自定义页码。

* 修改页数显示bug

* GUI支持自定义prompt。

* format

---------

Co-authored-by: Byaidu <909756245@qq.com>
hellofinch 1 год назад
Родитель
Сommit
10d9cf19e8
2 измененных файлов с 52 добавлено и 4 удалено
  1. 43 4
      pdf2zh/gui.py
  2. 9 0
      pdf2zh/translator.py

+ 43 - 4
pdf2zh/gui.py

@@ -64,6 +64,7 @@ page_map = {
     "All": None,
     "First": [0],
     "First 5 pages": list(range(0, 5)),
+    "Others": None,
 }
 
 flag_demo = False
@@ -125,6 +126,9 @@ def translate_file(
     lang_from,
     lang_to,
     page_range,
+    page_input,
+    prompt,
+    threads,
     recaptcha_response,
     state,
     progress=gr.Progress(),
@@ -161,7 +165,16 @@ def translate_file(
     file_dual = output / f"{filename}-dual.pdf"
 
     translator = service_map[service]
-    selected_page = page_map[page_range]
+    if page_range != "Others":
+        selected_page = page_map[page_range]
+    else:
+        selected_page = []
+        for p in page_input.split(","):
+            if "-" in p:
+                start, end = p.split("-")
+                selected_page.extend(range(int(start) - 1, int(end)))
+            else:
+                selected_page.append(int(p) - 1)
     lang_from = lang_map[lang_from]
     lang_to = lang_map[lang_to]
 
@@ -181,10 +194,11 @@ def translate_file(
         "lang_out": lang_to,
         "service": f"{translator.name}",
         "output": output,
-        "thread": 4,
+        "thread": int(threads),
         "callback": progress_bar,
         "cancellation_event": cancellation_event_map[session_id],
         "envs": _envs,
+        "prompt": prompt,
     }
     try:
         translate(**param)
@@ -319,15 +333,30 @@ with gr.Blocks(
                 value=list(page_map.keys())[0],
             )
 
+            page_input = gr.Textbox(
+                label="Page range",
+                visible=False,
+                interactive=True,
+            )
+
+            with gr.Accordion("Open for More Experimental Options!", open=False):
+                gr.Markdown("#### Experimental")
+                threads = gr.Textbox(label="number of threads", interactive=True)
+                prompt = gr.Textbox(
+                    label="Custom Prompt for llm", interactive=True, visible=False
+                )
+                envs.append(prompt)
+
             def on_select_service(service, evt: gr.EventData):
                 translator = service_map[service]
                 _envs = []
-                for i in range(3):
+                for i in range(4):
                     _envs.append(gr.update(visible=False, value=""))
                 for i, env in enumerate(translator.envs.items()):
                     _envs[i] = gr.update(
                         visible=True, label=env[0], value=os.getenv(env[0], env[1])
                     )
+                _envs[-1] = gr.update(visible=translator.CustomPrompt)
                 return _envs
 
             def on_select_filetype(file_type):
@@ -336,6 +365,12 @@ with gr.Blocks(
                     gr.update(visible=file_type == "Link"),
                 )
 
+            def on_select_page(choice):
+                if choice == "Others":
+                    return gr.update(visible=True)
+                else:
+                    return gr.update(visible=False)
+
             output_title = gr.Markdown("## Translated", visible=False)
             output_file_mono = gr.File(
                 label="Download Translation (Mono)", visible=False
@@ -358,6 +393,7 @@ with gr.Blocks(
                 """,
                 elem_classes=["secondary-text"],
             )
+            page_range.select(on_select_page, page_range, page_input)
             service.select(
                 on_select_service,
                 service,
@@ -422,6 +458,9 @@ with gr.Blocks(
             lang_from,
             lang_to,
             page_range,
+            page_input,
+            prompt,
+            threads,
             recaptcha_response,
             state,
             *envs,
@@ -445,7 +484,7 @@ with gr.Blocks(
 def readuserandpasswd(file_path):
     tuple_list = []
     content = ""
-    if file_path is None:
+    if not file_path:
         return tuple_list, content
     if len(file_path) == 2:
         try:

+ 9 - 0
pdf2zh/translator.py

@@ -26,6 +26,7 @@ class BaseTranslator:
     name = "base"
     envs = {}
     lang_map = {}
+    CustomPrompt = False
 
     def __init__(self, lang_in, lang_out, model):
         lang_in = self.lang_map.get(lang_in.lower(), lang_in)
@@ -200,6 +201,7 @@ class OllamaTranslator(BaseTranslator):
         "OLLAMA_HOST": "http://127.0.0.1:11434",
         "OLLAMA_MODEL": "gemma2",
     }
+    CustomPrompt = True
 
     def __init__(self, lang_in, lang_out, model, envs=None, prompt=None):
         self.set_envs(envs)
@@ -230,6 +232,7 @@ class OpenAITranslator(BaseTranslator):
         "OPENAI_API_KEY": None,
         "OPENAI_MODEL": "gpt-4o-mini",
     }
+    CustomPrompt = True
 
     def __init__(
         self,
@@ -265,6 +268,7 @@ class AzureOpenAITranslator(BaseTranslator):
         "AZURE_OPENAI_API_KEY": None,
         "AZURE_OPENAI_MODEL": "gpt-4o-mini",
     }
+    CustomPrompt = True
 
     def __init__(
         self,
@@ -306,6 +310,7 @@ class ModelScopeTranslator(OpenAITranslator):
         "MODELSCOPE_API_KEY": None,
         "MODELSCOPE_MODEL": "Qwen/Qwen2.5-32B-Instruct",
     }
+    CustomPrompt = True
 
     def __init__(
         self,
@@ -333,6 +338,7 @@ class ZhipuTranslator(OpenAITranslator):
         "ZHIPU_API_KEY": None,
         "ZHIPU_MODEL": "glm-4-flash",
     }
+    CustomPrompt = True
 
     def __init__(self, lang_in, lang_out, model, envs=None, prompt=None):
         self.set_envs(envs)
@@ -367,6 +373,7 @@ class SiliconTranslator(OpenAITranslator):
         "SILICON_API_KEY": None,
         "SILICON_MODEL": "Qwen/Qwen2.5-7B-Instruct",
     }
+    CustomPrompt = True
 
     def __init__(self, lang_in, lang_out, model, envs=None, prompt=None):
         self.set_envs(envs)
@@ -385,6 +392,7 @@ class GeminiTranslator(OpenAITranslator):
         "GEMINI_API_KEY": None,
         "GEMINI_MODEL": "gemini-1.5-flash",
     }
+    CustomPrompt = True
 
     def __init__(self, lang_in, lang_out, model, envs=None, prompt=None):
         self.set_envs(envs)
@@ -458,6 +466,7 @@ class AnythingLLMTranslator(BaseTranslator):
         "AnythingLLM_URL": None,
         "AnythingLLM_APIKEY": None,
     }
+    CustomPrompt = True
 
     def __init__(self, lang_out, lang_in, model, envs=None, prompt=None):
         self.set_envs(envs)