瀏覽代碼

add ollama

Byaidu 1 年之前
父節點
當前提交
cef7512295
共有 7 個文件被更改,包括 80 次插入27 次删除
  1. 6 0
      README.md
  2. 1 1
      pdf2zh/__init__.py
  3. 9 25
      pdf2zh/converter.py
  4. 3 1
      pdf2zh/high_level.py
  5. 8 0
      pdf2zh/pdf2zh.py
  6. 52 0
      pdf2zh/translator.py
  7. 1 0
      setup.py

+ 6 - 0
README.md

@@ -45,6 +45,12 @@ pdf2zh example.pdf -p 1-3,5
 pdf2zh example.pdf -li en -lo ja
 pdf2zh example.pdf -li en -lo ja
 ```
 ```
 
 
+### Translate with Ollama
+
+```bash
+pdf2zh example.pdf -s gemma2
+```
+
 ### Use regex to specify formula fonts and characters that need to be preserved
 ### Use regex to specify formula fonts and characters that need to be preserved
 
 
 Hint: Starting from `\ufb00` is English style ligature.
 Hint: Starting from `\ufb00` is English style ligature.

+ 1 - 1
pdf2zh/__init__.py

@@ -1,2 +1,2 @@
-__version__ = "1.5.9"
+__version__ = "1.6.0"
 __author__ = "Byaidu"
 __author__ = "Byaidu"

+ 9 - 25
pdf2zh/converter.py

@@ -16,12 +16,10 @@ from typing import (
 )
 )
 import concurrent.futures
 import concurrent.futures
 import numpy as np
 import numpy as np
-import html
-import requests
 import unicodedata
 import unicodedata
-import tqdm.auto
 from tenacity import retry
 from tenacity import retry
 from pdf2zh import cache
 from pdf2zh import cache
+from pdf2zh.translator import BaseTranslator, GoogleTranslator, OllamaTranslator
 def remove_control_characters(s):
 def remove_control_characters(s):
     return "".join(ch for ch in s if unicodedata.category(ch)[0]!="C")
     return "".join(ch for ch in s if unicodedata.category(ch)[0]!="C")
 
 
@@ -71,22 +69,6 @@ from pdf2zh.utils import (
 
 
 log = logging.getLogger(__name__)
 log = logging.getLogger(__name__)
 
 
-class Translator:
-    def __init__(self):
-        self.session = requests.Session()
-        self.base_link = "http://translate.google.com/m"
-        self.headers = {'User-Agent':'Mozilla/4.0 (compatible;MSIE 6.0;Windows NT 5.1;SV1;.NET CLR 1.1.4322;.NET CLR 2.0.50727;.NET CLR 3.0.04506.30)'}
-    
-    def translate(self, to_translate, to_language="auto", from_language="auto"):
-        to_translate=to_translate[:5000] # Max Length
-        response = self.session.get(self.base_link, params={'tl':to_language,'sl':from_language,'q':to_translate}, headers=self.headers)
-        re_result = re.findall(r'(?s)class="(?:t0|result-container)">(.*?)<', response.text)
-        if len(re_result) == 0:
-            raise ValueError('Empty translation result')
-        else:
-            result = html.unescape(re_result[0])
-        return result
-
 class PDFLayoutAnalyzer(PDFTextDevice):
 class PDFLayoutAnalyzer(PDFTextDevice):
     cur_item: LTLayoutContainer
     cur_item: LTLayoutContainer
     ctm: Matrix
     ctm: Matrix
@@ -368,6 +350,7 @@ class TextConverter(PDFConverter[AnyIO]):
         layout = {},
         layout = {},
         lang_in: str = "",
         lang_in: str = "",
         lang_out: str = "",
         lang_out: str = "",
+        service: str = "",
     ) -> None:
     ) -> None:
         super().__init__(rsrcmgr, outfp, codec=codec, pageno=pageno, laparams=laparams)
         super().__init__(rsrcmgr, outfp, codec=codec, pageno=pageno, laparams=laparams)
         self.showpageno = showpageno
         self.showpageno = showpageno
@@ -376,9 +359,10 @@ class TextConverter(PDFConverter[AnyIO]):
         self.vchar = vchar
         self.vchar = vchar
         self.thread = thread
         self.thread = thread
         self.layout = layout
         self.layout = layout
-        self.lang_in = lang_in
-        self.lang_out = lang_out
-        self.translator=Translator()
+        if service=='google':
+            self.translator: BaseTranslator = GoogleTranslator(service,lang_out,lang_in)
+        else:
+            self.translator: BaseTranslator = OllamaTranslator(service,lang_out,lang_in)
 
 
     def write_text(self, text: str) -> None:
     def write_text(self, text: str) -> None:
         text = utils.compatible_encode_method(text, self.codec, "ignore")
         text = utils.compatible_encode_method(text, self.codec, "ignore")
@@ -510,10 +494,10 @@ class TextConverter(PDFConverter[AnyIO]):
             @retry
             @retry
             def worker(s): # 多线程翻译
             def worker(s): # 多线程翻译
                 try:
                 try:
-                    hash_key_paragraph = cache.deterministic_hash((s,self.lang_in,self.lang_out))
+                    hash_key_paragraph = cache.deterministic_hash((s,str(self.translator)))
                     new = cache.load_paragraph(hash_key, hash_key_paragraph) # 查询缓存
                     new = cache.load_paragraph(hash_key, hash_key_paragraph) # 查询缓存
                     if new is None:
                     if new is None:
-                        new=self.translator.translate(s,self.lang_out,self.lang_in)
+                        new=self.translator.translate(s)
                         new=remove_control_characters(new)
                         new=remove_control_characters(new)
                         cache.write_paragraph(hash_key, hash_key_paragraph, new)
                         cache.write_paragraph(hash_key, hash_key_paragraph, new)
                     return new
                     return new
@@ -575,7 +559,7 @@ class TextConverter(PDFConverter[AnyIO]):
                     if lb and x+adv>rt+0.1*size: # 到达右边界且原文段落存在换行
                     if lb and x+adv>rt+0.1*size: # 到达右边界且原文段落存在换行
                         x=lt
                         x=lt
                         lang_space={'zh-CN':1.4,'zh-TW':1.4,'ja':1.1,'ko':1.2,'en':1.2,'it':1.1}
                         lang_space={'zh-CN':1.4,'zh-TW':1.4,'ja':1.1,'ko':1.2,'en':1.2,'it':1.1}
-                        y-=size*lang_space.get(self.lang_out,1.2)
+                        y-=size*lang_space.get(self.translator.lang_out,1.2)
                     if vy_regex: # 插入公式
                     if vy_regex: # 插入公式
                         fix=0
                         fix=0
                         if fcur!=None: # 段落内公式修正纵向偏移
                         if fcur!=None: # 段落内公式修正纵向偏移

+ 3 - 1
pdf2zh/high_level.py

@@ -47,6 +47,7 @@ def extract_text_to_fp(
     model = None,
     model = None,
     lang_in: str = "",
     lang_in: str = "",
     lang_out: str = "",
     lang_out: str = "",
+    service: str = "",
     **kwargs: Any,
     **kwargs: Any,
 ) -> None:
 ) -> None:
     """Parses text from inf-file and writes to outfp file-like object.
     """Parses text from inf-file and writes to outfp file-like object.
@@ -105,6 +106,7 @@ def extract_text_to_fp(
             layout=layout,
             layout=layout,
             lang_in=lang_in,
             lang_in=lang_in,
             lang_out=lang_out,
             lang_out=lang_out,
+            service=service,
         )
         )
 
 
     elif output_type == "xml":
     elif output_type == "xml":
@@ -152,7 +154,7 @@ def extract_text_to_fp(
         total_pages=len(pages)
         total_pages=len(pages)
     else:
     else:
         total_pages=page_count
         total_pages=page_count
-    for page in tqdm.auto.tqdm(PDFPage.get_pages(
+    for page in tqdm.tqdm(PDFPage.get_pages(
         inf,
         inf,
         pages,
         pages,
         maxpages=maxpages,
         maxpages=maxpages,

+ 8 - 0
pdf2zh/pdf2zh.py

@@ -55,6 +55,7 @@ def extract_text(
     thread: int = 0,
     thread: int = 0,
     lang_in: str = "",
     lang_in: str = "",
     lang_out: str = "",
     lang_out: str = "",
+    service: str = "",
     **kwargs: Any,
     **kwargs: Any,
 ) -> AnyIO:
 ) -> AnyIO:
     if not files:
     if not files:
@@ -185,6 +186,13 @@ def create_parser() -> argparse.ArgumentParser:
         default="zh-CN",
         default="zh-CN",
         help="The code of target language.",
         help="The code of target language.",
     )
     )
+    parse_params.add_argument(
+        "--service",
+        "-s",
+        type=str,
+        default="google",
+        help="The service to use for translating.",
+    )
     parse_params.add_argument(
     parse_params.add_argument(
         "--thread",
         "--thread",
         "-t",
         "-t",

+ 52 - 0
pdf2zh/translator.py

@@ -0,0 +1,52 @@
+import re
+import html
+import requests
+import ollama
+
+class BaseTranslator:
+    def __init__(self,service,lang_out,lang_in):
+        self.service=service
+        self.lang_out=lang_out
+        self.lang_in=lang_in
+
+    def translate(self,text):
+        pass
+
+    def __str__(self):
+        pass
+
+    def __str__(self):
+        return f'{self.service} {self.lang_out} {self.lang_in}'
+
+class GoogleTranslator(BaseTranslator):
+    def __init__(self,service,lang_out,lang_in):
+        super().__init__(service,lang_out,lang_in)
+        self.session=requests.Session()
+        self.base_link="http://translate.google.com/m"
+        self.headers={'User-Agent':'Mozilla/4.0 (compatible;MSIE 6.0;Windows NT 5.1;SV1;.NET CLR 1.1.4322;.NET CLR 2.0.50727;.NET CLR 3.0.04506.30)'}
+
+    def translate(self,text):
+        text=text[:5000] # Max Length
+        response=self.session.get(self.base_link,params={'tl':self.lang_out,'sl':self.lang_in,'q':text},headers=self.headers)
+        re_result=re.findall(r'(?s)class="(?:t0|result-container)">(.*?)<',response.text)
+        if len(re_result) == 0:
+            raise ValueError('Empty translation result')
+        else:
+            result=html.unescape(re_result[0])
+        return result
+
+class OllamaTranslator(BaseTranslator):
+    def __init__(self,service,lang_out,lang_in):
+        super().__init__(service,lang_out,lang_in)
+        self.model=service
+        self.options={'temperature':0} # 随机采样可能会打断公式标记
+
+    def translate(self,text):
+        result=ollama.chat(model=self.model,options=self.options,messages=[
+            {
+                'role': 'system',
+                'content': 'You are a professional,authentic machine translation engine.',
+            },
+            { 'role': 'user','content': f'Translate the following markdown source text to {self.lang_out}. Keep the formula notation $v*$ unchanged. Output translation directly without any additional text.\nSource Text: {text}\nTranslated Text:' }
+        ])['message']['content'].strip()
+        return result

+ 1 - 0
setup.py

@@ -26,6 +26,7 @@ setup(
         "tenacity",
         "tenacity",
         "doclayout-yolo",
         "doclayout-yolo",
         "numpy",
         "numpy",
+        "ollama",
     ],
     ],
     classifiers=[
     classifiers=[
         "Programming Language :: Python :: 3",
         "Programming Language :: Python :: 3",