Browse Source

feat(pdf2zh): add support for yadt experimental backend

- import and use `yadt_translate` for yadt backend
- add `--yadt` option to enable yadt backend
- implement `yadt_main` function to handle yadt translation process
- download remote fonts for yadt translation
- configure and use appropriate translator based on service name

refactor(translator): add placeholder methods for rich text and formulas

- add `get_rich_text_left_placeholder` and `get_rich_text_right_placeholder` methods
- add `get_formular_placeholder` method
- update `OpenAITranslator` to override placeholder methods
awwaawwa 1 năm trước cách đây
mục cha
commit
591812af0c
2 tập tin đã thay đổi với 110 bổ sung và 2 xóa
  1. 91 2
      pdf2zh/pdf2zh.py
  2. 19 0
      pdf2zh/translator.py

+ 91 - 2
pdf2zh/pdf2zh.py

@@ -12,12 +12,13 @@ from string import Template
 from typing import List, Optional
 
 from pdf2zh import __version__, log
-from pdf2zh.high_level import translate
+from pdf2zh.high_level import translate, download_remote_fonts
 from pdf2zh.doclayout import OnnxModel, ModelInstance
 import os
 
 from pdf2zh.config import ConfigManager
-
+from yadt.translation_config import TranslationConfig as YadtConfig
+from yadt.high_level import translate as yadt_translate
 
 def create_parser() -> argparse.ArgumentParser:
     parser = argparse.ArgumentParser(description=__doc__, add_help=True)
@@ -164,6 +165,13 @@ def create_parser() -> argparse.ArgumentParser:
         help="config file.",
     )
 
+    parse_params.add_argument(
+        "--yadt",
+        default=False,
+        action="store_true",
+        help="Use experimental backend yadt.",
+    )
+
     return parser
 
 
@@ -178,6 +186,7 @@ def parse_args(args: Optional[List[str]]) -> argparse.Namespace:
                 pages.extend(range(int(start) - 1, int(end)))
             else:
                 pages.append(int(p) - 1)
+        parsed_args.raw_pages = parsed_args.pages
         parsed_args.pages = pages
 
     return parsed_args
@@ -255,6 +264,8 @@ def main(args: Optional[List[str]] = None) -> int:
             raise ValueError("prompt error.")
 
     print(parsed_args)
+    if parsed_args.yadt:
+        return yadt_main(parsed_args)
     if parsed_args.dir:
         untranlate_file = find_all_files_in_directory(parsed_args.files[0])
         parsed_args.files = untranlate_file
@@ -265,5 +276,83 @@ def main(args: Optional[List[str]] = None) -> int:
     return 0
 
 
+def _format_yadt_pages(parsed_args) -> str:
+    """Build the ``pages`` string handed to ``YadtConfig``.
+
+    ``raw_pages`` is assigned inside a conditional branch of ``parse_args``,
+    so the attribute may be missing from the namespace. When present it
+    holds whatever ``parsed_args.pages`` contained before it was replaced
+    by the expanded page list.
+
+    Args:
+        parsed_args: Namespace returned by ``parse_args``.
+
+    Returns:
+        The page selection as a comma-separated string, or ``""`` when no
+        selection is available.
+    """
+    raw_pages = getattr(parsed_args, "raw_pages", None)
+    if not raw_pages:
+        # NOTE(review): assumes yadt reads an empty string as "all pages";
+        # confirm against YadtConfig.
+        return ""
+    if isinstance(raw_pages, str):
+        # Iterating a string yields single characters, which would turn
+        # "1-3" into "1,-,3"; forward the text unchanged instead.
+        # NOTE(review): presumes yadt accepts the same "a-b,c" syntax as
+        # the --pages option; verify.
+        return raw_pages
+    return ",".join(str(page) for page in raw_pages)
+
+
+def yadt_main(parsed_args) -> int:
+    """Translate the requested files with the experimental yadt backend.
+
+    Args:
+        parsed_args: Namespace returned by ``parse_args``. Reads ``dir``,
+            ``files``, ``lang_in``, ``lang_out``, ``output``, ``service``,
+            ``prompt``, ``debug``, ``thread`` and, if set, ``raw_pages``.
+
+    Returns:
+        ``0`` once every file has been handed to ``yadt_translate``.
+
+    Raises:
+        ValueError: If the prompt file cannot be loaded, or if the service
+            name matches none of the known translators.
+    """
+    if parsed_args.dir:
+        untranslated_files = find_all_files_in_directory(parsed_args.files[0])
+    else:
+        untranslated_files = parsed_args.files
+    lang_in = parsed_args.lang_in
+    lang_out = parsed_args.lang_out
+    outputdir = parsed_args.output or None
+    font_path = download_remote_fonts(lang_out.lower())
+
+    # "--service name:model" -> ("name", "model"); no colon means no model.
+    service_name, separator, model_part = parsed_args.service.partition(":")
+    service_model = model_part if separator else None
+
+    envs = {}
+    prompt = []
+
+    if isinstance(parsed_args.prompt, Template):
+        # NOTE(review): main() appears to load the prompt before dispatching
+        # here; reuse an already built template rather than treating it as
+        # a file path (which could only fail). Confirm against main().
+        prompt = parsed_args.prompt
+    elif parsed_args.prompt:
+        try:
+            with open(parsed_args.prompt, "r", encoding="utf-8") as prompt_file:
+                prompt = Template(prompt_file.read())
+        except Exception as err:
+            # Keep the historical exception type and message, but chain the
+            # cause so the real failure stays visible in the traceback.
+            raise ValueError("prompt error.") from err
+
+    from pdf2zh.translator import (
+        AzureOpenAITranslator,
+        GoogleTranslator,
+        BingTranslator,
+        DeepLTranslator,
+        DeepLXTranslator,
+        OllamaTranslator,
+        OpenAITranslator,
+        ZhipuTranslator,
+        ModelScopeTranslator,
+        SiliconTranslator,
+        GeminiTranslator,
+        AzureTranslator,
+        TencentTranslator,
+        DifyTranslator,
+        AnythingLLMTranslator,
+        XinferenceTranslator,
+        ArgosTranslator,
+        GorkTranslator,
+        GroqTranslator,
+        DeepseekTranslator,
+        OpenAIlikedTranslator,
+    )
+
+    # Candidates are scanned in this order; the first class whose ``name``
+    # equals the requested service wins.
+    translator_classes = (
+        GoogleTranslator,
+        BingTranslator,
+        DeepLTranslator,
+        DeepLXTranslator,
+        OllamaTranslator,
+        XinferenceTranslator,
+        AzureOpenAITranslator,
+        OpenAITranslator,
+        ZhipuTranslator,
+        ModelScopeTranslator,
+        SiliconTranslator,
+        GeminiTranslator,
+        AzureTranslator,
+        TencentTranslator,
+        DifyTranslator,
+        AnythingLLMTranslator,
+        ArgosTranslator,
+        GorkTranslator,
+        GroqTranslator,
+        DeepseekTranslator,
+        OpenAIlikedTranslator,
+    )
+    translator_class = next(
+        (cls for cls in translator_classes if cls.name == service_name),
+        None,
+    )
+    if translator_class is None:
+        raise ValueError("Unsupported translation service")
+    translator = translator_class(
+        lang_in, lang_out, service_model, envs=envs, prompt=prompt
+    )
+
+    # The page selection is the same for every file, so build it once.
+    pages = _format_yadt_pages(parsed_args)
+
+    for file in untranslated_files:
+        yadt_config = YadtConfig(
+            # Drop quotes that may wrap a path on the command line.
+            input_file=file.strip("\"'"),
+            font=font_path,
+            pages=pages,
+            output_dir=outputdir,
+            translator=translator,
+            debug=parsed_args.debug,
+            lang_in=lang_in,
+            lang_out=lang_out,
+            no_dual=False,
+            no_mono=False,
+            qps=parsed_args.thread,
+        )
+        yadt_translate(yadt_config)
+    return 0
+
+
 if __name__ == "__main__":
     sys.exit(main())

+ 19 - 0
pdf2zh/translator.py

@@ -123,6 +123,16 @@ class BaseTranslator:
     def __str__(self):
         return f"{self.name} {self.lang_in} {self.lang_out} {self.model}"
 
+    def get_rich_text_left_placeholder(self, id: int):
+        """Return the left rich-text placeholder for ``id``, e.g. ``<b3>``."""
+        return "<b{}>".format(id)
+
+    def get_rich_text_right_placeholder(self, id: int):
+        """Return the right rich-text placeholder for ``id``, e.g. ``</b3>``."""
+        return "</b{}>".format(id)
+
+    def get_formular_placeholder(self, id: int):
+        """Return the formula placeholder for ``id``.
+
+        It is the left rich-text placeholder immediately followed by the
+        right one, both obtained through ``self`` so that subclass
+        overrides of either method are honoured.
+        """
+        left = self.get_rich_text_left_placeholder(id)
+        right = self.get_rich_text_right_placeholder(id)
+        return left + right
 
 class GoogleTranslator(BaseTranslator):
     name = "google"
@@ -384,6 +394,15 @@ class OpenAITranslator(BaseTranslator):
         )
         return response.choices[0].message.content.strip()
 
+    def get_formular_placeholder(self, id: int):
+        """Return the formula placeholder for ``id``, e.g. ``{{v3}}``."""
+        return "".join(("{{v", str(id), "}}"))
+
+    def get_rich_text_left_placeholder(self, id: int):
+        """Return the left rich-text placeholder for ``id``.
+
+        This translator reuses the formula placeholder format, so the
+        result is whatever ``get_formular_placeholder(id)`` yields.
+        """
+        placeholder = self.get_formular_placeholder(id)
+        return placeholder
+
+    def get_rich_text_right_placeholder(self, id: int):
+        """Return the right rich-text placeholder for ``id``.
+
+        The result is the formula placeholder of the following id, i.e.
+        ``get_formular_placeholder(id + 1)``.
+        """
+        # NOTE(review): this relies on id + 1 not being used for another
+        # placeholder in the same text; confirm how the caller allocates ids.
+        following_id = id + 1
+        return self.get_formular_placeholder(following_id)
+
 
 class AzureOpenAITranslator(BaseTranslator):
     name = "azure-openai"