Selaa lähdekoodia

format code and add pre commit

Yadomin Jinta 1 vuosi sitten
vanhempi
sitoutus
8514737ca6

+ 15 - 0
.pre-commit-config.yaml

@@ -0,0 +1,15 @@
+# See https://pre-commit.com for more information
+# See https://pre-commit.com/hooks.html for more hooks
+repos:
+-   repo: local
+    hooks:
+    - id: black
+      name: black
+      entry: black --check --diff --color
+      language: python
+      files: ".py"
+    - id: flake8
+      name: flake8
+      entry: flake8
+      language: python
+      files: ".py"

+ 18 - 10
pdf2zh/cache.py

@@ -3,9 +3,10 @@ import os
 import time
 import hashlib
 import shutil
-cache_dir = os.path.join(tempfile.gettempdir(), 'cache')
+
+cache_dir = os.path.join(tempfile.gettempdir(), "cache")
 os.makedirs(cache_dir, exist_ok=True)
-time_filename = 'update_time'
+time_filename = "update_time"
 max_cache = 5
 
 
@@ -16,25 +17,30 @@ def deterministic_hash(obj):
 
 
 def get_dirs():
-    dirs = [os.path.join(cache_dir, dir) for dir in os.listdir(cache_dir) if os.path.isdir(os.path.join(cache_dir, dir))]
+    dirs = [
+        os.path.join(cache_dir, dir)
+        for dir in os.listdir(cache_dir)
+        if os.path.isdir(os.path.join(cache_dir, dir))
+    ]
     return dirs
 
 
 def get_time(dir):
     try:
         timefile = os.path.join(dir, time_filename)
-        t = float(open(timefile, encoding='utf-8').read())
+        t = float(open(timefile, encoding="utf-8").read())
         return t
     except FileNotFoundError:
         # handle the error as needed, for now we'll just return a default value
-        return float('inf')  # This ensures that this directory will be the first to be removed if required
-
+        return float(
+            "inf"
+        )  # This ensures that this directory will be the first to be removed if required
 
 
 def write_time(dir):
     timefile = os.path.join(dir, time_filename)
     t = time.time()
-    print(t, file=open(timefile, "w", encoding='utf-8'), end='')
+    print(t, file=open(timefile, "w", encoding="utf-8"), end="")
 
 
 def argmin(iterable):
@@ -44,7 +50,9 @@ def argmin(iterable):
 def remove_extra():
     dirs = get_dirs()
     for dir in dirs:
-        if not os.path.isdir(dir):  # This line might be redundant now, as get_dirs() ensures only directories are returned
+        if not os.path.isdir(
+            dir
+        ):  # This line might be redundant now, as get_dirs() ensures only directories are returned
             os.remove(dir)
         try:
             get_time(dir)
@@ -73,11 +81,11 @@ def create_cache(hash_key):
 def load_paragraph(hash_key, hash_key_paragraph):
     filename = os.path.join(cache_dir, hash_key, hash_key_paragraph)
     if os.path.exists(filename):
-        return open(filename, encoding='utf-8').read()
+        return open(filename, encoding="utf-8").read()
     else:
         return None
 
 
 def write_paragraph(hash_key, hash_key_paragraph, paragraph):
     filename = os.path.join(cache_dir, hash_key, hash_key_paragraph)
-    print(paragraph, file=open(filename, "w", encoding='utf-8'), end='')
+    print(paragraph, file=open(filename, "w", encoding="utf-8"), end="")

+ 324 - 214
pdf2zh/converter.py

@@ -1,3 +1,45 @@
+from pdf2zh.utils import (
+    AnyIO,
+    Matrix,
+    PathSegment,
+    Point,
+    Rect,
+    apply_matrix_pt,
+    bbox2str,
+    enc,
+    make_compat_str,
+    mult_matrix,
+    matrix_scale,
+)
+from pdf2zh.pdftypes import PDFStream
+from pdf2zh.pdfpage import PDFPage
+from pdf2zh.pdfinterp import PDFGraphicState, PDFResourceManager
+from pdf2zh.pdffont import PDFFont, PDFUnicodeNotDefined, PDFCIDFont
+from pdf2zh.pdfexceptions import PDFValueError
+from pdf2zh.pdfdevice import PDFTextDevice
+from pdf2zh.pdfcolor import PDFColorSpace
+from pdf2zh.layout import (
+    LAParams,
+    LTAnno,
+    LTChar,
+    LTComponent,
+    LTCurve,
+    LTFigure,
+    LTImage,
+    LTItem,
+    LTLayoutContainer,
+    LTLine,
+    LTPage,
+    LTRect,
+    LTText,
+    LTTextBox,
+    LTTextBoxVertical,
+    LTTextGroup,
+    LTTextLine,
+    TextGroupElement,
+)
+from pdf2zh.image import ImageWriter
+from pdf2zh import utils
 import io
 import logging
 import re
@@ -28,55 +70,15 @@ from pdf2zh.translator import (
     OpenAITranslator,
     AzureTranslator,
 )
+
+
 def remove_control_characters(s):
-    return "".join(ch for ch in s if unicodedata.category(ch)[0]!="C")
+    return "".join(ch for ch in s if unicodedata.category(ch)[0] != "C")
 
-from pdf2zh import utils
-from pdf2zh.image import ImageWriter
-from pdf2zh.layout import (
-    LAParams,
-    LTAnno,
-    LTChar,
-    LTComponent,
-    LTContainer,
-    LTCurve,
-    LTFigure,
-    LTImage,
-    LTItem,
-    LTLayoutContainer,
-    LTLine,
-    LTPage,
-    LTRect,
-    LTText,
-    LTTextBox,
-    LTTextBoxVertical,
-    LTTextGroup,
-    LTTextLine,
-    TextGroupElement,
-)
-from pdf2zh.pdfcolor import PDFColorSpace
-from pdf2zh.pdfdevice import PDFTextDevice
-from pdf2zh.pdfexceptions import PDFValueError
-from pdf2zh.pdffont import PDFFont, PDFUnicodeNotDefined, PDFCIDFont
-from pdf2zh.pdfinterp import PDFGraphicState, PDFResourceManager
-from pdf2zh.pdfpage import PDFPage
-from pdf2zh.pdftypes import PDFStream
-from pdf2zh.utils import (
-    AnyIO,
-    Matrix,
-    PathSegment,
-    Point,
-    Rect,
-    apply_matrix_pt,
-    bbox2str,
-    enc,
-    make_compat_str,
-    mult_matrix,
-    matrix_scale,
-)
 
 log = logging.getLogger(__name__)
 
+
 class PDFLayoutAnalyzer(PDFTextDevice):
     cur_item: LTLayoutContainer
     ctm: Matrix
@@ -188,7 +190,7 @@ class PDFLayoutAnalyzer(PDFTextDevice):
                 # Note: 'ml', in conditional above, is a frequent anomaly
                 # that we want to support.
                 line = LTLine(
-                    gstate.linewidth*matrix_scale(self.ctm),
+                    gstate.linewidth * matrix_scale(self.ctm),
                     pts[0],
                     pts[1],
                     stroke,
@@ -210,7 +212,7 @@ class PDFLayoutAnalyzer(PDFTextDevice):
                 ) or (y0 == y1 and x1 == x2 and y2 == y3 and x3 == x0)
                 if is_closed_loop and has_square_coordinates:
                     rect = LTRect(
-                        gstate.linewidth*matrix_scale(self.ctm),
+                        gstate.linewidth * matrix_scale(self.ctm),
                         (*pts[0], *pts[2]),
                         stroke,
                         fill,
@@ -223,7 +225,7 @@ class PDFLayoutAnalyzer(PDFTextDevice):
                     self.cur_item.add(rect)
                 else:
                     curve = LTCurve(
-                        gstate.linewidth*matrix_scale(self.ctm),
+                        gstate.linewidth * matrix_scale(self.ctm),
                         pts,
                         stroke,
                         fill,
@@ -236,7 +238,7 @@ class PDFLayoutAnalyzer(PDFTextDevice):
                     self.cur_item.add(curve)
             else:
                 curve = LTCurve(
-                    gstate.linewidth*matrix_scale(self.ctm),
+                    gstate.linewidth * matrix_scale(self.ctm),
                     pts,
                     stroke,
                     fill,
@@ -279,7 +281,7 @@ class PDFLayoutAnalyzer(PDFTextDevice):
             graphicstate,
         )
         self.cur_item.add(item)
-        item.cid=cid # hack
+        item.cid = cid  # hack
         return item.adv
 
     def handle_undefined_char(self, font: PDFFont, cid: int) -> str:
@@ -355,7 +357,7 @@ class TextConverter(PDFConverter[AnyIO]):
         vfont: str = None,
         vchar: str = None,
         thread: int = 0,
-        layout = {},
+        layout={},
         lang_in: str = "",
         lang_out: str = "",
         service: str = "",
@@ -367,7 +369,7 @@ class TextConverter(PDFConverter[AnyIO]):
         self.vchar = vchar
         self.thread = thread
         self.layout = layout
-        param=service.split(':',1)
+        param = service.split(":", 1)
         if param[0] == "google":
             self.translator: BaseTranslator = GoogleTranslator(
                 service, lang_out, lang_in, None
@@ -384,11 +386,11 @@ class TextConverter(PDFConverter[AnyIO]):
             self.translator: BaseTranslator = OllamaTranslator(
                 service, lang_out, lang_in, param[1]
             )
-        elif param[0] == 'openai':
+        elif param[0] == "openai":
             self.translator: BaseTranslator = OpenAITranslator(
                 service, lang_out, lang_in, param[1]
             )
-        elif param[0] == 'azure':
+        elif param[0] == "azure":
             self.translator: BaseTranslator = AzureTranslator(
                 service, lang_out, lang_in, None
             )
@@ -404,173 +406,255 @@ class TextConverter(PDFConverter[AnyIO]):
 
     def receive_layout(self, ltpage: LTPage):
         def render(item: LTItem) -> None:
-            xt=None # 上一个字符
-            sstk=[] # 段落文字栈
-            vstk=[] # 公式符号组
-            vlstk=[] # 公式线条组
-            vfix=0 # 公式纵向偏移
-            vbkt=0 # 段落公式括号计数
-            pstk=[] # 段落属性栈
-            lstk=[] # 全局线条栈
-            var=[] # 公式符号组栈
-            varl=[] # 公式线条组栈
-            varf=[] # 公式纵向偏移栈
-            vlen=[] # 公式宽度栈
-            xt_cls=-1 # 上一个字符所属段落
-            vmax=ltpage.width/4 # 行内公式最大宽度
-            ops="" # 渲染结果
-            def vflag(font,char): # 匹配公式(和角标)字体
-                if re.match(r'\(cid:',char):
+            xt = None  # 上一个字符
+            sstk = []  # 段落文字栈
+            vstk = []  # 公式符号组
+            vlstk = []  # 公式线条组
+            vfix = 0  # 公式纵向偏移
+            vbkt = 0  # 段落公式括号计数
+            pstk = []  # 段落属性栈
+            lstk = []  # 全局线条栈
+            var = []  # 公式符号组栈
+            varl = []  # 公式线条组栈
+            varf = []  # 公式纵向偏移栈
+            vlen = []  # 公式宽度栈
+            xt_cls = -1  # 上一个字符所属段落
+            vmax = ltpage.width / 4  # 行内公式最大宽度
+            ops = ""  # 渲染结果
+
+            def vflag(font, char):  # 匹配公式(和角标)字体
+                if re.match(r"\(cid:", char):
                     return True
                 if self.vfont:
-                    if re.match(self.vfont,font):
+                    if re.match(self.vfont, font):
                         return True
                 else:
-                    if re.match(r'(CM[^R]|MS|XY|MT|BL|RM|EU|LA|RS|LINE|TeX-|rsfs|txsy|wasy|.*Mono|.*Code|.*Ital|.*Sym)',font):
+                    if re.match(
+                        r"(CM[^R]|MS|XY|MT|BL|RM|EU|LA|RS|LINE|TeX-|rsfs|txsy|wasy|.*Mono|.*Code|.*Ital|.*Sym)",
+                        font,
+                    ):
                         return True
                 if self.vchar:
-                    if re.match(self.vchar,char):
+                    if re.match(self.vchar, char):
                         return True
                 else:
-                    if char and char!=' ' and (unicodedata.category(char[0]) in ['Lm','Mn','Sk','Sm','Zl','Zp','Zs'] or ord(char[0]) in range(0x370,0x400)): # 文字修饰符、数学符号、分隔符号、希腊字母
+                    if (
+                        char
+                        and char != " "
+                        and (
+                            unicodedata.category(char[0])
+                            in ["Lm", "Mn", "Sk", "Sm", "Zl", "Zp", "Zs"]
+                            or ord(char[0]) in range(0x370, 0x400)
+                        )
+                    ):  # 文字修饰符、数学符号、分隔符号、希腊字母
                         return True
                 return False
-            ptr=0
-            item=list(item)
-            while ptr<len(item): # 识别文字和公式
-                child=item[ptr]
+
+            ptr = 0
+            item = list(item)
+            while ptr < len(item):  # 识别文字和公式
+                child = item[ptr]
                 if isinstance(child, LTChar):
-                    cur_v=False # 公式
-                    fontname=child.fontname.split('+')[-1]
-                    layout=self.layout[ltpage.pageid]
-                    h,w=layout.shape # ltpage.height 可能是 fig 里面的高度,这里统一用 layout.shape
-                    cx,cy=np.clip(int(child.x0),0,w-1),np.clip(int(child.y0),0,h-1)
-                    cls=layout[cy,cx]
+                    cur_v = False  # 公式
+                    fontname = child.fontname.split("+")[-1]
+                    layout = self.layout[ltpage.pageid]
+                    h, w = (
+                        layout.shape
+                    )  # ltpage.height 可能是 fig 里面的高度,这里统一用 layout.shape
+                    cx, cy = np.clip(int(child.x0), 0, w - 1), np.clip(
+                        int(child.y0), 0, h - 1
+                    )
+                    cls = layout[cy, cx]
                     # if log.isEnabledFor(logging.DEBUG):
-                    #     ops+=f'ET [] 0 d 0 J 0.1 w {child.x0:f} {child.y0:f} {child.x1-child.x0:f} {child.y1-child.y0:f} re S Q BT '
-                    if cls==0 or (cls==xt_cls and child.size<pstk[-1][4]*0.79) or vflag(fontname,child.get_text()) or (child.matrix[0]==0 and child.matrix[3]==0): # 有 0.76 的角标和 0.799 的大写,这里用 0.79 取中
-                        cur_v=True
-                    if not cur_v: # 判定括号组是否属于公式
-                        if vstk and child.get_text()=='(':
-                            cur_v=True
-                            vbkt+=1
-                        if vbkt and child.get_text()==')':
-                            cur_v=True
-                            vbkt-=1
-                    if not cur_v or cls!=xt_cls or (abs(child.x0-xt.x0)>vmax and cls!=0): # 公式结束、段落边界、公式换行
-                        if vstk: # 公式出栈
-                            sstk[-1]+=f'$v{len(var)}$'
-                            if not cur_v and cls==xt_cls and child.x0>max([vch.x0 for vch in vstk]): # and child.y1>vstk[0].y0: # 段落内公式转文字,行内公式修正
-                                vfix=vstk[0].y0-child.y0
+                    # ops+=f'ET [] 0 d 0 J 0.1 w {child.x0:f}
+                    # {child.y0:f} {child.x1-child.x0:f} {child.y1-child.y0:f} re S Q BT '
+                    if (
+                        cls == 0
+                        or (cls == xt_cls and child.size < pstk[-1][4] * 0.79)
+                        or vflag(fontname, child.get_text())
+                        or (child.matrix[0] == 0 and child.matrix[3] == 0)
+                    ):  # 有 0.76 的角标和 0.799 的大写,这里用 0.79 取中
+                        cur_v = True
+                    if not cur_v:  # 判定括号组是否属于公式
+                        if vstk and child.get_text() == "(":
+                            cur_v = True
+                            vbkt += 1
+                        if vbkt and child.get_text() == ")":
+                            cur_v = True
+                            vbkt -= 1
+                    if (
+                        not cur_v
+                        or cls != xt_cls
+                        or (abs(child.x0 - xt.x0) > vmax and cls != 0)
+                    ):  # 公式结束、段落边界、公式换行
+                        if vstk:  # 公式出栈
+                            sstk[-1] += f"$v{len(var)}$"
+                            if (
+                                not cur_v
+                                and cls == xt_cls
+                                and child.x0 > max([vch.x0 for vch in vstk])
+                            ):  # and child.y1>vstk[0].y0: # 段落内公式转文字,行内公式修正
+                                vfix = vstk[0].y0 - child.y0
                             var.append(vstk)
                             varl.append(vlstk)
                             varf.append(vfix)
-                            vstk=[]
-                            vlstk=[]
-                            vfix=0
-                    if not vstk: # 非公式或是公式开头
-                        if cls==xt_cls: # 同一段落
-                            if child.x0 > xt.x1 + 1: # 行内空格
-                                sstk[-1]+=' '
-                            elif child.x1 < xt.x0: # 换行空格
-                                sstk[-1]+=' '
-                                pstk[-1][6]=True # 标记原文段落存在换行
+                            vstk = []
+                            vlstk = []
+                            vfix = 0
+                    if not vstk:  # 非公式或是公式开头
+                        if cls == xt_cls:  # 同一段落
+                            if child.x0 > xt.x1 + 1:  # 行内空格
+                                sstk[-1] += " "
+                            elif child.x1 < xt.x0:  # 换行空格
+                                sstk[-1] += " "
+                                pstk[-1][6] = True  # 标记原文段落存在换行
                         else:
                             sstk.append("")
-                            pstk.append([child.y0,child.x0,child.x0,child.x0,child.size,child.font,False])
-                    if not cur_v: # 文字入栈
-                        if child.size>pstk[-1][4]/0.79 or vflag(pstk[-1][5].fontname.split('+')[-1],'') or re.match(r'(.*Medi|.*Bold)',pstk[-1][5].fontname.split('+')[-1],re.IGNORECASE): # 小字体、公式或粗体开头,后续接文字,需要校正字体
-                            pstk[-1][0]-=child.size-pstk[-1][4]
-                            pstk[-1][4]=child.size
-                            pstk[-1][5]=child.font
-                        sstk[-1]+=child.get_text()
-                    else: # 公式入栈
-                        if not vstk and cls==xt_cls and child.x0>xt.x0: # and child.y1>xt.y0: # 段落内文字转公式,行内公式修正
-                            vfix=child.y0-xt.y0
+                            pstk.append(
+                                [
+                                    child.y0,
+                                    child.x0,
+                                    child.x0,
+                                    child.x0,
+                                    child.size,
+                                    child.font,
+                                    False,
+                                ]
+                            )
+                    if not cur_v:  # 文字入栈
+                        if (
+                            child.size > pstk[-1][4] / 0.79
+                            or vflag(pstk[-1][5].fontname.split("+")[-1], "")
+                            or re.match(
+                                r"(.*Medi|.*Bold)",
+                                pstk[-1][5].fontname.split("+")[-1],
+                                re.IGNORECASE,
+                            )
+                        ):  # 小字体、公式或粗体开头,后续接文字,需要校正字体
+                            pstk[-1][0] -= child.size - pstk[-1][4]
+                            pstk[-1][4] = child.size
+                            pstk[-1][5] = child.font
+                        sstk[-1] += child.get_text()
+                    else:  # 公式入栈
+                        if (
+                            not vstk and cls == xt_cls and child.x0 > xt.x0
+                        ):  # and child.y1>xt.y0: # 段落内文字转公式,行内公式修正
+                            vfix = child.y0 - xt.y0
                         vstk.append(child)
                     # 更新段落边界,段落内换行之后可能是公式开头
-                    pstk[-1][2]=min(pstk[-1][2],child.x0)
-                    pstk[-1][3]=max(pstk[-1][3],child.x1)
-                    xt=child
-                    xt_cls=cls
-                elif isinstance(child, LTFigure): # 图表
+                    pstk[-1][2] = min(pstk[-1][2], child.x0)
+                    pstk[-1][3] = max(pstk[-1][3], child.x1)
+                    xt = child
+                    xt_cls = cls
+                elif isinstance(child, LTFigure):  # 图表
                     pass
-                elif isinstance(child, LTLine): # 线条
-                    layout=self.layout[ltpage.pageid]
-                    h,w=layout.shape # ltpage.height 可能是 fig 里面的高度,这里统一用 layout.shape
-                    cx,cy=np.clip(int(child.x0),0,w-1),np.clip(int(child.y0),0,h-1)
-                    cls=layout[cy,cx]
-                    if vstk and cls==xt_cls: # 公式线条
+                elif isinstance(child, LTLine):  # 线条
+                    layout = self.layout[ltpage.pageid]
+                    h, w = (
+                        layout.shape
+                    )  # ltpage.height 可能是 fig 里面的高度,这里统一用 layout.shape
+                    cx, cy = np.clip(int(child.x0), 0, w - 1), np.clip(
+                        int(child.y0), 0, h - 1
+                    )
+                    cls = layout[cy, cx]
+                    if vstk and cls == xt_cls:  # 公式线条
                         vlstk.append(child)
-                    else: # 全局线条
+                    else:  # 全局线条
                         lstk.append(child)
                 else:
                     # print(child)
                     pass
-                ptr+=1
+                ptr += 1
             # 处理结尾
-            if vstk: # 公式出栈
-                sstk[-1]+=f'$v{len(var)}$'
+            if vstk:  # 公式出栈
+                sstk[-1] += f"$v{len(var)}$"
                 var.append(vstk)
                 varl.append(vlstk)
                 varf.append(vfix)
-            log.debug('\n==========[VSTACK]==========\n')
-            for id,v in enumerate(var): # 计算公式宽度
-                l=max([vch.x1 for vch in v])-v[0].x0
-                log.debug(f'< {l:.1f} {v[0].x0:.1f} {v[0].y0:.1f} {v[0].cid} {v[0].fontname} {len(varl[id])} > $v{id}$ = {"".join([ch.get_text() for ch in v])}')
+            log.debug("\n==========[VSTACK]==========\n")
+            for id, v in enumerate(var):  # 计算公式宽度
+                l = max([vch.x1 for vch in v]) - v[0].x0  # noqa: E741
+                log.debug(
+                    f'< {l:.1f} {v[0].x0:.1f} {v[0].y0:.1f} {v[0].cid} {v[0].fontname} {len(varl[id])} > $v{id}$ = {"".join([ch.get_text() for ch in v])}'  # noqa: E501
+                )
                 vlen.append(l)
-            log.debug('\n==========[SSTACK]==========\n')
-            hash_key=cache.deterministic_hash("PDFMathTranslate")
+            log.debug("\n==========[SSTACK]==========\n")
+            hash_key = cache.deterministic_hash("PDFMathTranslate")
             cache.create_cache(hash_key)
+
             @retry(wait=wait_fixed(1))
-            def worker(s): # 多线程翻译
+            def worker(s):  # 多线程翻译
                 try:
-                    hash_key_paragraph = cache.deterministic_hash((s,str(self.translator)))
-                    new = cache.load_paragraph(hash_key, hash_key_paragraph) # 查询缓存
+                    hash_key_paragraph = cache.deterministic_hash(
+                        (s, str(self.translator))
+                    )
+                    new = cache.load_paragraph(hash_key, hash_key_paragraph)  # 查询缓存
                     if new is None:
-                        new=self.translator.translate(s)
-                        new=remove_control_characters(new)
+                        new = self.translator.translate(s)
+                        new = remove_control_characters(new)
                         cache.write_paragraph(hash_key, hash_key_paragraph, new)
                     return new
                 except BaseException as e:
                     if log.isEnabledFor(logging.DEBUG):
                         log.exception(e)
                     else:
-                        log.exception(e,exc_info=False)
+                        log.exception(e, exc_info=False)
                     raise e
-            with concurrent.futures.ThreadPoolExecutor(max_workers=self.thread) as executor:
+
+            with concurrent.futures.ThreadPoolExecutor(
+                max_workers=self.thread
+            ) as executor:
                 news = list(executor.map(worker, sstk))
-            def raw_string(fcur,cstk): # 编码字符串
-                if isinstance(self.fontmap[fcur],PDFCIDFont): # 判断编码长度
+
+            def raw_string(fcur, cstk):  # 编码字符串
+                if isinstance(self.fontmap[fcur], PDFCIDFont):  # 判断编码长度
                     return "".join(["%04x" % ord(c) for c in cstk])
                 else:
                     return "".join(["%02x" % ord(c) for c in cstk])
-            _x,_y=0,0
-            for id,new in enumerate(news): # 排版文字和公式
-                tx=x=pstk[id][1];y=pstk[id][0];lt=pstk[id][2];rt=pstk[id][3];ptr=0;size=pstk[id][4];font=pstk[id][5];lb=pstk[id][6] # 段落属性
-                cstk='' # 单行文字栈
-                fcur=fcur_=None # 单行字体
-                log.debug(f"< {y} {x} {lt} {rt} {size} {font.fontname} {lb} > {sstk[id]} | {new}")
+
+            _x, _y = 0, 0
+            for id, new in enumerate(news):  # 排版文字和公式
+                tx = x = pstk[id][1]
+                y = pstk[id][0]
+                lt = pstk[id][2]
+                rt = pstk[id][3]
+                ptr = 0
+                size = pstk[id][4]
+                font = pstk[id][5]
+                lb = pstk[id][6]  # 段落属性
+                cstk = ""  # 单行文字栈
+                fcur = fcur_ = None  # 单行字体
+                log.debug(
+                    f"< {y} {x} {lt} {rt} {size} {font.fontname} {lb} > {sstk[id]} | {new}"
+                )
                 while True:
-                    if ptr==len(new): # 到达段落结尾
+                    if ptr == len(new):  # 到达段落结尾
                         if cstk:
-                            ops+=f'/{fcur} {size:f} Tf 1 0 0 1 {tx:f} {y:f} Tm [<{raw_string(fcur,cstk)}>] TJ '
+                            ops += f"/{fcur} {size:f} Tf 1 0 0 1 {tx:f} {y:f} Tm [<{raw_string(fcur, cstk)}>] TJ "
                         break
-                    vy_regex=re.match(r'\$?\s*v([\d\s]+)\$',new[ptr:],re.IGNORECASE) # 匹配 $vn$ 公式标记,前面的 $ 有的时候会被丢掉
-                    mod=False # 当前公式是否为文字修饰符
-                    if vy_regex: # 加载公式
-                        ptr+=len(vy_regex.group(0))
+                    vy_regex = re.match(
+                        r"\$?\s*v([\d\s]+)\$", new[ptr:], re.IGNORECASE
+                    )  # 匹配 $vn$ 公式标记,前面的 $ 有的时候会被丢掉
+                    mod = False  # 当前公式是否为文字修饰符
+                    if vy_regex:  # 加载公式
+                        ptr += len(vy_regex.group(0))
                         try:
-                            vid=int(vy_regex.group(1).replace(' ',''))
-                            adv=vlen[vid]
-                        except:
-                            continue # 翻译器可能会自动补个越界的公式标记
-                        if len(var[vid])==1 and unicodedata.category(var[vid][0].get_text()[0]) in ['Lm','Mn','Sk']: # 文字修饰符
-                            mod=True
-                    else: # 加载文字
-                        ch=new[ptr]
+                            vid = int(vy_regex.group(1).replace(" ", ""))
+                            adv = vlen[vid]
+                        except Exception:
+                            continue  # 翻译器可能会自动补个越界的公式标记
+                        if len(var[vid]) == 1 and unicodedata.category(
+                            var[vid][0].get_text()[0]
+                        ) in [
+                            "Lm",
+                            "Mn",
+                            "Sk",
+                        ]:  # 文字修饰符
+                            mod = True
+                    else:  # 加载文字
+                        ch = new[ptr]
                         # if font.char_width(ord(ch)):
-                        fcur_=None
+                        fcur_ = None
                         # 原字体编码容易出问题,这里直接放弃掉
                         # try:
                         #     if font.widths.get(ord(ch)) and font.to_unichr(ord(ch))==ch:
@@ -578,58 +662,84 @@ class TextConverter(PDFConverter[AnyIO]):
                         # except:
                         #     pass
                         try:
-                            if fcur_==None and self.fontmap['tiro'].to_unichr(ord(ch))==ch:
-                                fcur_='tiro' # 默认英文字体
-                        except:
+                            if (
+                                fcur_ is None
+                                and self.fontmap["tiro"].to_unichr(ord(ch)) == ch
+                            ):
+                                fcur_ = "tiro"  # 默认英文字体
+                        except Exception:
                             pass
-                        if fcur_==None:
-                            fcur_='china-ss' # 默认中文字体
+                        if fcur_ is None:
+                            fcur_ = "china-ss"  # 默认中文字体
                         # print(self.fontid[font],fcur_,ch,font.char_width(ord(ch)))
-                        adv=self.fontmap[fcur_].char_width(ord(ch))*size
-                        ptr+=1
-                    if fcur_!=fcur or vy_regex or x+adv>rt+0.1*size: # 输出文字缓冲区:1.字体更新 2.插入公式 3.到达右边界(可能一整行都被符号化,这里需要考虑浮点误差)
+                        adv = self.fontmap[fcur_].char_width(ord(ch)) * size
+                        ptr += 1
+                    if (
+                        fcur_ != fcur or vy_regex or x + adv > rt + 0.1 * size
+                    ):  # 输出文字缓冲区:1.字体更新 2.插入公式 3.到达右边界(可能一整行都被符号化,这里需要考虑浮点误差)
                         if cstk:
-                            ops+=f'/{fcur} {size:f} Tf 1 0 0 1 {tx:f} {y:f} Tm [<{raw_string(fcur,cstk)}>] TJ '
-                            cstk=''
-                    if lb and x+adv>rt+0.1*size: # 到达右边界且原文段落存在换行
-                        x=lt
-                        lang_space={'zh-CN':1.4,'zh-TW':1.4,'ja':1.1,'ko':1.2,'en':1.2} # CJK
-                        y-=size*lang_space.get(self.translator.lang_out,1.1) # 小语种大多适配 1.1
-                    if vy_regex: # 插入公式
-                        fix=0
-                        if fcur!=None: # 段落内公式修正纵向偏移
-                            fix=varf[vid]
-                        for vch in var[vid]: # 排版公式字符
-                            vc=chr(vch.cid)
-                            ops+=f"/{self.fontid[vch.font]} {vch.size:f} Tf 1 0 0 1 {x+vch.x0-var[vid][0].x0:f} {fix+y+vch.y0-var[vid][0].y0:f} Tm [<{raw_string(self.fontid[vch.font],vc)}>] TJ "
+                            ops += f"/{fcur} {size:f} Tf 1 0 0 1 {tx:f} {y:f} Tm [<{raw_string(fcur, cstk)}>] TJ "
+                            cstk = ""
+                    if lb and x + adv > rt + 0.1 * size:  # 到达右边界且原文段落存在换行
+                        x = lt
+                        lang_space = {
+                            "zh-CN": 1.4,
+                            "zh-TW": 1.4,
+                            "ja": 1.1,
+                            "ko": 1.2,
+                            "en": 1.2,
+                        }  # CJK
+                        y -= size * lang_space.get(
+                            self.translator.lang_out, 1.1
+                        )  # 小语种大多适配 1.1
+                    if vy_regex:  # 插入公式
+                        fix = 0
+                        if fcur is not None:  # 段落内公式修正纵向偏移
+                            fix = varf[vid]
+                        for vch in var[vid]:  # 排版公式字符
+                            vc = chr(vch.cid)
+                            ops += f"/{self.fontid[vch.font]} {vch.size:f} Tf 1 0 0 1 {x + vch.x0 - var[vid][0].x0:f} {fix + y + vch.y0 - var[vid][0].y0:f} Tm [<{raw_string(self.fontid[vch.font], vc)}>] TJ "  # noqa: E501
                             if log.isEnabledFor(logging.DEBUG):
-                                lstk.append(LTLine(0.1,(_x,_y),(x+vch.x0-var[vid][0].x0,fix+y+vch.y0-var[vid][0].y0)))
-                                _x,_y=x+vch.x0-var[vid][0].x0,fix+y+vch.y0-var[vid][0].y0
-                        for l in varl[vid]: # 排版公式线条
-                            if l.linewidth<5: # hack
-                                ops+=f"ET q 1 0 0 1 {l.pts[0][0]+x-var[vid][0].x0:f} {l.pts[0][1]+fix+y-var[vid][0].y0:f} cm [] 0 d 0 J {l.linewidth:f} w 0 0 m {l.pts[1][0]-l.pts[0][0]:f} {l.pts[1][1]-l.pts[0][1]:f} l S Q BT "
-                    else: # 插入文字缓冲区
-                        if not cstk: # 单行开头
-                            tx=x
-                            if x==lt and ch==' ': # 消除段落换行空格
-                                adv=0
+                                lstk.append(
+                                    LTLine(
+                                        0.1,
+                                        (_x, _y),
+                                        (
+                                            x + vch.x0 - var[vid][0].x0,
+                                            fix + y + vch.y0 - var[vid][0].y0,
+                                        ),
+                                    )
+                                )
+                                _x, _y = (
+                                    x + vch.x0 - var[vid][0].x0,
+                                    fix + y + vch.y0 - var[vid][0].y0,
+                                )
+                        for l in varl[vid]:  # 排版公式线条 # noqa: E741
+                            if l.linewidth < 5:  # hack
+                                ops += f"ET q 1 0 0 1 {l.pts[0][0] + x - var[vid][0].x0:f} {l.pts[0][1] + fix + y - var[vid][0].y0:f} cm [] 0 d 0 J {l.linewidth:f} w 0 0 m {l.pts[1][0] - l.pts[0][0]:f} {l.pts[1][1] - l.pts[0][1]:f} l S Q BT "  # noqa: E501
+                    else:  # 插入文字缓冲区
+                        if not cstk:  # 单行开头
+                            tx = x
+                            if x == lt and ch == " ":  # 消除段落换行空格
+                                adv = 0
                             else:
-                                cstk+=ch
+                                cstk += ch
                         else:
-                            cstk+=ch
-                    if mod: # 文字修饰符
-                        adv=0
-                    fcur=fcur_
-                    x+=adv
+                            cstk += ch
+                    if mod:  # 文字修饰符
+                        adv = 0
+                    fcur = fcur_
+                    x += adv
                     if log.isEnabledFor(logging.DEBUG):
-                        lstk.append(LTLine(0.1,(_x,_y),(x,y)))
-                        _x,_y=x,y
-            for l in lstk: # 排版全局线条
-                if l.linewidth<5: # hack
-                    ops+=f"ET q 1 0 0 1 {l.pts[0][0]:f} {l.pts[0][1]:f} cm [] 0 d 0 J {l.linewidth:f} w 0 0 m {l.pts[1][0]-l.pts[0][0]:f} {l.pts[1][1]-l.pts[0][1]:f} l S Q BT "
-            ops=f'BT {ops}ET '
+                        lstk.append(LTLine(0.1, (_x, _y), (x, y)))
+                        _x, _y = x, y
+            for l in lstk:  # 排版全局线条 # noqa: E741
+                if l.linewidth < 5:  # hack
+                    ops += f"ET q 1 0 0 1 {l.pts[0][0]:f} {l.pts[0][1]:f} cm [] 0 d 0 J {l.linewidth:f} w 0 0 m {l.pts[1][0] - l.pts[0][0]:f} {l.pts[1][1] - l.pts[0][1]:f} l S Q BT "  # noqa: E501
+            ops = f"BT {ops}ET "
             return ops
-        ops=render(ltpage)
+
+        ops = render(ltpage)
         return ops
 
     # Some dummy functions to save memory/CPU when all that is wanted

+ 1 - 1
pdf2zh/encodingdb.py

@@ -120,7 +120,7 @@ class EncodingDB:
                 elif isinstance(x, PSLiteral):
                     try:
                         cid2unicode[cid] = name2unicode(cast(str, x.name))
-                    except (KeyError, ValueError) as e:
+                    except (KeyError, ValueError):
                         # log.debug(str(e))
                         pass
                     cid += 1

+ 2 - 2
pdf2zh/fontmetrics.py

@@ -9,7 +9,7 @@ The following data were extracted from the AFM files:
 
 """
 
-###  BEGIN Verbatim copy of the license part
+# BEGIN Verbatim copy of the license part
 
 #
 # Adobe Core 35 AFM Files with 314 Glyph Entries - ReadMe
@@ -24,7 +24,7 @@ The following data were extracted from the AFM files:
 # obligation to support the use of the AFM files.
 #
 
-###  END Verbatim copy of the license part
+# END Verbatim copy of the license part
 
 # flake8: noqa
 from typing import Dict

+ 80 - 45
pdf2zh/gui.py

@@ -33,32 +33,32 @@ lang_map = {
 page_map = {
     "All": None,
     "First": [0],
-    "First 5 pages": list(range(0,5)),
+    "First 5 pages": list(range(0, 5)),
 }
 
-flag_demo=False
-if os.environ.get('PDF2ZH_DEMO'):
-    flag_demo=True
+flag_demo = False
+if os.environ.get("PDF2ZH_DEMO"):
+    flag_demo = True
     service_map = {
         "Google": "google",
     }
     page_map = {
         "First": [0],
-        "First 20 pages": list(range(0,20)),
+        "First 20 pages": list(range(0, 20)),
     }
-    client_key=os.environ.get('PDF2ZH_CLIENT_KEY')
-    server_key=os.environ.get('PDF2ZH_SERVER_KEY')
+    client_key = os.environ.get("PDF2ZH_CLIENT_KEY")
+    server_key = os.environ.get("PDF2ZH_SERVER_KEY")
 
 
 def verify_recaptcha(response):
     recaptcha_url = "https://www.google.com/recaptcha/api/siteverify"
 
-    print('reCAPTCHA',server_key,response)
+    print("reCAPTCHA", server_key, response)
 
     data = {"secret": server_key, "response": response}
     result = requests.post(recaptcha_url, data=data).json()
 
-    print('reCAPTCHA',result.get("success"))
+    print("reCAPTCHA", result.get("success"))
 
     return result.get("success")
 
@@ -87,14 +87,20 @@ def upload_file(file, service, progress=gr.Progress()):
 
 
 def translate(
-    file_path, service, model_id, lang, page_range, recaptcha_response, progress=gr.Progress()
+    file_path,
+    service,
+    model_id,
+    lang,
+    page_range,
+    recaptcha_response,
+    progress=gr.Progress(),
 ):
     """Translate PDF content using selected service."""
     if not file_path:
-        raise gr.Error('No input')
+        raise gr.Error("No input")
 
     if flag_demo and not verify_recaptcha(recaptcha_response):
-        raise gr.Error('reCAPTCHA fail')
+        raise gr.Error("reCAPTCHA fail")
 
     progress(0, desc="Starting translation...")
 
@@ -113,30 +119,31 @@ def translate(
         lang_to = "zh-CN" if lang_to == "zh" else lang_to
 
     print(f"Files before translation: {os.listdir(output)}")
-    def progress_bar(t:tqdm.tqdm):
-        progress(t.n/t.total, desc="Translating...")
-
-    param={
-            'files':[file_en],
-            'pages':selected_page,
-            'lang_in':'auto',
-            'lang_out':lang_to,
-            'service':f"{selected_service}:{model_id}",
-            'output':output,
-            'thread':4,
-            'callback':progress_bar,
-           }
+
+    def progress_bar(t: tqdm.tqdm):
+        progress(t.n / t.total, desc="Translating...")
+
+    param = {
+        "files": [file_en],
+        "pages": selected_page,
+        "lang_in": "auto",
+        "lang_out": lang_to,
+        "service": f"{selected_service}:{model_id}",
+        "output": output,
+        "thread": 4,
+        "callback": progress_bar,
+    }
     print(param)
     extract_text(**param)
     print(f"Files after translation: {os.listdir(output)}")
 
     if not file_zh.exists() or not file_dual.exists():
-        raise gr.Error('No output')
+        raise gr.Error("No output")
 
     try:
         translated_preview = pdf_preview(str(file_zh))
-    except Exception as e:
-        raise gr.Error('No preview')
+    except Exception:
+        raise gr.Error("No preview")
 
     progress(1.0, desc="Translation complete!")
 
@@ -175,7 +182,7 @@ with gr.Blocks(
     footer {visibility: hidden}
     .env-warning {color: #dd5500 !important;}
     .env-success {color: #559900 !important;}
-    
+
     @keyframes pulse-background {
         0% { background-color: #FFFFFF; }
         25% { background-color: #FFFFFF; }
@@ -183,7 +190,7 @@ with gr.Blocks(
         75% { background-color: #FFFFFF; }
         100% { background-color: #FFFFFF; }
     }
-    
+
     /* Add dashed border to input-file class */
     .input-file {
         border: 1.2px dashed #165DFF !important;
@@ -221,7 +228,8 @@ with gr.Blocks(
     #     color: #165DFF !important;
     # }
     """,
-    head='''
+    head=(
+        """
     <script src="https://www.google.com/recaptcha/api.js" async defer></script>
     <script type="text/javascript">
         var onVerify = function(token) {
@@ -230,9 +238,14 @@ with gr.Blocks(
             el.dispatchEvent(new Event('input'));
         };
     </script>
-    ''' if flag_demo else None
+    """
+        if flag_demo
+        else None
+    ),
 ) as demo:
-    gr.Markdown("# [PDFMathTranslate @ Github](https://github.com/Byaidu/PDFMathTranslate)")
+    gr.Markdown(
+        "# [PDFMathTranslate @ Github](https://github.com/Byaidu/PDFMathTranslate)"
+    )
 
     with gr.Row():
         with gr.Column(scale=1):
@@ -265,14 +278,15 @@ with gr.Blocks(
             )
             model_id = gr.Textbox(
                 label="Model ID",
-                info="Please enter the identifier of the model you wish to use (e.g., gemma2). This identifier will be used to specify the particular model for translation.",
+                info="Please enter the identifier of the model you wish to use (e.g., gemma2). "
+                "This identifier will be used to specify the particular model for translation.",
                 # value="gemma2",
                 visible=False,  # hide by default
             )
             envs_status = "<span class='env-success'>- Properly configured.</span><br>"
 
             def details_wrapper(text_markdown):
-                text = f""" 
+                text = f"""
                 <details>
                     <summary>Technical details</summary>
                     {text_markdown}
@@ -287,7 +301,11 @@ with gr.Blocks(
                     not os.environ.get(env_var_name)
                     or os.environ.get(env_var_name) == ""
                 ):
-                    envs_status = f"<span class='env-warning'>- Warning: environmental not found or error ({env_var_name}).</span><br>- Please make sure that the environment variables are properly configured (<a href='https://github.com/Byaidu/PDFMathTranslate'>guide</a>).<br>"
+                    envs_status = (
+                        f"<span class='env-warning'>- Warning: environmental not found or error ({env_var_name})."
+                        + "</span><br>- Please make sure that the environment variables are properly configured "
+                        + "(<a href='https://github.com/Byaidu/PDFMathTranslate'>guide</a>).<br>"
+                    )
                 else:
                     value = str(os.environ.get(env_var_name))
                     envs_status = (
@@ -327,7 +345,11 @@ with gr.Blocks(
                     )  # show model id when service is selected
                     envs_status = env_var_checker("OLLAMA_HOST")
                 else:
-                    envs_status = "<span class='env-warning'>- Warning: model not in the list.</span><br>- Please report via (<a href='https://github.com/Byaidu/PDFMathTranslate'>guide</a>).<br>"
+                    envs_status = (
+                        "<span class='env-warning'>- Warning: model not in the list."
+                        "</span><br>- Please report via "
+                        "(<a href='https://github.com/Byaidu/PDFMathTranslate'>guide</a>).<br>"
+                    )
                 return envs_status, model_visibility
 
             output_title = gr.Markdown("## Translated", visible=False)
@@ -335,11 +357,16 @@ with gr.Blocks(
             output_file_dual = gr.File(
                 label="Download Translation (Dual)", visible=False
             )
-            recaptcha_response = gr.Textbox(label="reCAPTCHA Response", elem_id='verify', visible=False)
+            recaptcha_response = gr.Textbox(
+                label="reCAPTCHA Response", elem_id="verify", visible=False
+            )
             if flag_demo:
-                recaptcha_box=gr.HTML(f'<div class="g-recaptcha" data-sitekey="{client_key}" data-callback="onVerify"></div>', visible=False)
+                recaptcha_box = gr.HTML(
+                    f'<div class="g-recaptcha" data-sitekey="{client_key}" data-callback="onVerify"></div>',
+                    visible=False,
+                )
             else:
-                recaptcha_box=gr.HTML()
+                recaptcha_box = gr.HTML()
             translate_btn = gr.Button("Translate", variant="primary", visible=False)
             tech_details_tog = gr.Markdown(
                 details_wrapper(envs_status),
@@ -373,20 +400,28 @@ with gr.Blocks(
 
 
 def setup_gui(share=False):
-    import doclayout_yolo # cache
+    import doclayout_yolo  # cache # noqa: F401
+
     if flag_demo:
-        demo.launch(server_name="0.0.0.0", max_file_size='5mb', inbrowser=True)
+        demo.launch(server_name="0.0.0.0", max_file_size="5mb", inbrowser=True)
     else:
         try:
             demo.launch(server_name="0.0.0.0", debug=True, inbrowser=True, share=share)
         except Exception:
-            print("Error launching GUI using 0.0.0.0.\nThis may be caused by global mode of proxy software.")
+            print(
+                "Error launching GUI using 0.0.0.0.\nThis may be caused by global mode of proxy software."
+            )
             try:
-                demo.launch(server_name="127.0.0.1", debug=True, inbrowser=True, share=share)
+                demo.launch(
+                    server_name="127.0.0.1", debug=True, inbrowser=True, share=share
+                )
             except Exception:
-                print("Error launching GUI using 127.0.0.1.\nThis may be caused by global mode of proxy software.")
+                print(
+                    "Error launching GUI using 127.0.0.1.\nThis may be caused by global mode of proxy software."
+                )
                 demo.launch(debug=True, inbrowser=True, share=True)
 
+
 # For auto-reloading while developing
 if __name__ == "__main__":
     setup_gui()

+ 48 - 30
pdf2zh/high_level.py

@@ -46,7 +46,7 @@ def extract_text_to_fp(
     vchar: str = "",
     thread: int = 0,
     doc_en: Document = None,
-    model = None,
+    model=None,
     lang_in: str = "",
     lang_out: str = "",
     service: str = "",
@@ -91,7 +91,7 @@ def extract_text_to_fp(
 
     rsrcmgr = PDFResourceManager(caching=not disable_caching)
     device: Optional[PDFDevice] = None
-    layout={}
+    layout = {}
 
     if output_type != "text" and outfp == sys.stdout:
         outfp = sys.stdout.buffer
@@ -151,50 +151,68 @@ def extract_text_to_fp(
         raise PDFValueError(msg)
 
     assert device is not None
-    obj_patch={}
+    obj_patch = {}
     interpreter = PDFPageInterpreter(rsrcmgr, device, obj_patch)
     if pages:
-        total_pages=len(pages)
+        total_pages = len(pages)
     else:
-        total_pages=page_count
-    with tqdm.tqdm(PDFPage.get_pages(
-        inf,
-        pages,
-        maxpages=maxpages,
-        password=password,
-        caching=not disable_caching,
-    ), total=total_pages, position=0) as progress:
+        total_pages = page_count
+    with tqdm.tqdm(
+        PDFPage.get_pages(
+            inf,
+            pages,
+            maxpages=maxpages,
+            password=password,
+            caching=not disable_caching,
+        ),
+        total=total_pages,
+        position=0,
+    ) as progress:
         for page in progress:
             if callback:
                 callback(progress)
             pix = doc_en[page.pageno].get_pixmap()
-            image = np.fromstring(pix.samples, np.uint8).reshape(pix.height, pix.width, 3)[:, :, ::-1]
-            page_layout=model.predict(
+            image = np.fromstring(pix.samples, np.uint8).reshape(
+                pix.height, pix.width, 3
+            )[:, :, ::-1]
+            page_layout = model.predict(
                 image,
-                imgsz=int(pix.height/32)*32,
-                device="cuda:0" if torch.cuda.is_available() else "cpu", # Auto-select GPU if available
+                imgsz=int(pix.height / 32) * 32,
+                device=(
+                    "cuda:0" if torch.cuda.is_available() else "cpu"
+                ),  # Auto-select GPU if available
             )[0]
             # kdtree 是不可能 kdtree 的,不如直接渲染成图片,用空间换时间
-            box=np.ones((pix.height, pix.width))
-            h,w=box.shape
-            vcls=['abandon','figure','table','isolate_formula','formula_caption']
-            for i,d in enumerate(page_layout.boxes):
+            box = np.ones((pix.height, pix.width))
+            h, w = box.shape
+            vcls = ["abandon", "figure", "table", "isolate_formula", "formula_caption"]
+            for i, d in enumerate(page_layout.boxes):
                 if not page_layout.names[int(d.cls)] in vcls:
-                    x0,y0,x1,y1=d.xyxy.squeeze()
-                    x0,y0,x1,y1=np.clip(int(x0-1),0,w-1),np.clip(int(h-y1-1),0,h-1),np.clip(int(x1+1),0,w-1),np.clip(int(h-y0+1),0,h-1)
-                    box[y0:y1,x0:x1]=i+2
-            for i,d in enumerate(page_layout.boxes):
+                    x0, y0, x1, y1 = d.xyxy.squeeze()
+                    x0, y0, x1, y1 = (
+                        np.clip(int(x0 - 1), 0, w - 1),
+                        np.clip(int(h - y1 - 1), 0, h - 1),
+                        np.clip(int(x1 + 1), 0, w - 1),
+                        np.clip(int(h - y0 + 1), 0, h - 1),
+                    )
+                    box[y0:y1, x0:x1] = i + 2
+            for i, d in enumerate(page_layout.boxes):
                 if page_layout.names[int(d.cls)] in vcls:
-                    x0,y0,x1,y1=d.xyxy.squeeze()
-                    x0,y0,x1,y1=np.clip(int(x0-1),0,w-1),np.clip(int(h-y1-1),0,h-1),np.clip(int(x1+1),0,w-1),np.clip(int(h-y0+1),0,h-1)
-                    box[y0:y1,x0:x1]=0
-            layout[page.pageno]=box
+                    x0, y0, x1, y1 = d.xyxy.squeeze()
+                    x0, y0, x1, y1 = (
+                        np.clip(int(x0 - 1), 0, w - 1),
+                        np.clip(int(h - y1 - 1), 0, h - 1),
+                        np.clip(int(x1 + 1), 0, w - 1),
+                        np.clip(int(h - y0 + 1), 0, h - 1),
+                    )
+                    box[y0:y1, x0:x1] = 0
+            layout[page.pageno] = box
             # print(page.number,page_layout)
             page.rotate = (page.rotate + rotation) % 360
             # 新建一个 xref 存放新指令流
-            page.page_xref = doc_en.get_new_xref() # hack
+            page.page_xref = doc_en.get_new_xref()  # hack
             doc_en.update_object(page.page_xref, "<<>>")
-            doc_en.update_stream(page.page_xref,b'')
+            doc_en.update_stream(page.page_xref, b"")
             doc_en[page.pageno].set_contents(page.page_xref)
             interpreter.process_page(page)
 

+ 10 - 3
pdf2zh/layout.py

@@ -368,7 +368,7 @@ class LTChar(LTComponent, LTText):
         LTText.__init__(self)
         self._text = text
         self.matrix = matrix
-        self.font=font
+        self.font = font
         self.fontname = font.fontname
         self.ncs = ncs
         self.graphicstate = graphicstate
@@ -387,7 +387,7 @@ class LTChar(LTComponent, LTText):
             bbox_upper_right = (-vx + fontsize, vy + rise)
         else:
             # horizontal
-            descent = 0 # descent = font.get_descent() * fontsize
+            descent = 0  # descent = font.get_descent() * fontsize
             bbox_lower_left = (0, descent + rise)
             bbox_upper_right = (self.adv, descent + rise + fontsize)
         (a, b, c, d, e, f) = self.matrix
@@ -405,7 +405,14 @@ class LTChar(LTComponent, LTText):
             self.size = self.height
 
     def __repr__(self) -> str:
-        return f"<{self.__class__.__name__} {bbox2str(self.bbox)} matrix={matrix2str(self.matrix)} font={self.fontname!r} adv={self.adv} text={self.get_text()!r}>"
+        return "<{} {} matrix={} font={} adv={} text={}>".format(
+            self.__class__.__name__,
+            bbox2str(self.bbox),
+            matrix2str(self.matrix),
+            repr(self.fontname),
+            self.adv,
+            repr(self.get_text()),
+        )
 
     def get_text(self) -> str:
         return self._text

+ 2 - 1
pdf2zh/pdf2zh.py

@@ -119,7 +119,7 @@ def extract_text(
                                 doc_en.xref_set_key(
                                     xref, f"{label}Font/{font}", f"{font_id[font]} 0 R"
                                 )
-                except:
+                except Exception:
                     pass
         doc_en.save(Path(output) / f"{filename}-en.pdf")
 
@@ -277,6 +277,7 @@ def main(args: Optional[List[str]] = None) -> int:
         return -1
     if parsed_args.interactive:
         from pdf2zh.gui import setup_gui
+
         setup_gui(parsed_args.share)
         return 0
 

+ 1 - 1
pdf2zh/pdfdocument.py

@@ -706,7 +706,7 @@ class PDFDocument:
         try:
             # print('FIND XREF')
             pos = self.find_xref(parser)
-            self.pos=pos
+            self.pos = pos
             self.read_xref_from(parser, pos, self.xrefs)
         except PDFNoValidXRef:
             if fallback:

+ 1 - 1
pdf2zh/pdffont.py

@@ -140,7 +140,7 @@ class Type1FontHeaderParser(PSStackParser[int]):
                 break
             try:
                 self._cid2unicode[cid] = name2unicode(cast(str, name))
-            except KeyError as e:
+            except KeyError:
                 # log.debug(str(e))
                 pass
         return self._cid2unicode

+ 72 - 36
pdf2zh/pdfinterp.py

@@ -368,7 +368,9 @@ class PDFPageInterpreter:
     Reference: PDF Reference, Appendix A, Operator Summary
     """
 
-    def __init__(self, rsrcmgr: PDFResourceManager, device: PDFDevice, obj_patch) -> None:
+    def __init__(
+        self, rsrcmgr: PDFResourceManager, device: PDFDevice, obj_patch
+    ) -> None:
         self.rsrcmgr = rsrcmgr
         self.device = device
         self.obj_patch = obj_patch
@@ -407,7 +409,7 @@ class PDFPageInterpreter:
                         objid = spec.objid
                     spec = dict_value(spec)
                     self.fontmap[fontid] = self.rsrcmgr.get_font(objid, spec)
-                    self.fontid[self.fontmap[fontid]]=fontid
+                    self.fontid[self.fontmap[fontid]] = fontid
             elif k == "ColorSpace":
                 for csid, spec in dict_value(v).items():
                     colorspace = get_colorspace(resolve1(spec))
@@ -570,16 +572,25 @@ class PDFPageInterpreter:
 
     def do_S(self) -> None:
         """Stroke path"""
+
         def is_black(color: Color) -> bool:
             if isinstance(color, Tuple):
-                return sum(color)==0
+                return sum(color) == 0
             else:
-                return color==0
-        if len(self.curpath)==2 and self.curpath[0][0]=='m' and self.curpath[1][0]=='l' and apply_matrix_pt(self.ctm,self.curpath[0][-2:])[1]==apply_matrix_pt(self.ctm,self.curpath[1][-2:])[1] and is_black(self.graphicstate.scolor): # 独立直线,水平,黑色
+                return color == 0
+
+        if (
+            len(self.curpath) == 2
+            and self.curpath[0][0] == "m"
+            and self.curpath[1][0] == "l"
+            and apply_matrix_pt(self.ctm, self.curpath[0][-2:])[1]
+            == apply_matrix_pt(self.ctm, self.curpath[1][-2:])[1]
+            and is_black(self.graphicstate.scolor)
+        ):  # 独立直线,水平,黑色
             # print(apply_matrix_pt(self.ctm,self.curpath[0][-2:]),apply_matrix_pt(self.ctm,self.curpath[1][-2:]),self.graphicstate.scolor)
             self.device.paint_path(self.graphicstate, True, False, False, self.curpath)
             self.curpath = []
-            return 'n'
+            return "n"
         else:
             self.curpath = []
 
@@ -698,7 +709,7 @@ class PDFPageInterpreter:
             if settings.STRICT:
                 raise PDFInterpreterError("No colorspace specified!")
             n = 1
-        args=self.pop(n)
+        args = self.pop(n)
         self.graphicstate.scolor = cast(Color, args)
         return args
 
@@ -710,7 +721,7 @@ class PDFPageInterpreter:
             if settings.STRICT:
                 raise PDFInterpreterError("No colorspace specified!")
             n = 1
-        args=self.pop(n)
+        args = self.pop(n)
         self.graphicstate.ncolor = cast(Color, args)
         return args
 
@@ -963,22 +974,24 @@ class PDFPageInterpreter:
             else:
                 resources = self.resources.copy()
             self.device.begin_figure(xobjid, bbox, matrix)
-            ctm=mult_matrix(matrix, self.ctm)
-            ops_base=interpreter.render_contents(
+            ctm = mult_matrix(matrix, self.ctm)
+            ops_base = interpreter.render_contents(
                 resources,
                 [xobj],
                 ctm=ctm,
             )
-            try: # 有的时候 form 字体加不上这里会烂掉
-                self.device.fontid=interpreter.fontid
-                self.device.fontmap=interpreter.fontmap
-                ops_new=self.device.end_figure(xobjid)
-                ctm_inv=np.linalg.inv(np.array(ctm[:4]).reshape(2,2))
-                pos_inv=-np.mat(ctm[4:])*ctm_inv
-                a,b,c,d=ctm_inv.reshape(4).tolist()
-                e,f=pos_inv.tolist()[0]
-                self.obj_patch[self.xobjmap[xobjid].objid]=f'q {ops_base}Q {a} {b} {c} {d} {e} {f} cm {ops_new}'
-            except:
+            try:  # 有的时候 form 字体加不上这里会烂掉
+                self.device.fontid = interpreter.fontid
+                self.device.fontmap = interpreter.fontmap
+                ops_new = self.device.end_figure(xobjid)
+                ctm_inv = np.linalg.inv(np.array(ctm[:4]).reshape(2, 2))
+                pos_inv = -np.mat(ctm[4:]) * ctm_inv
+                a, b, c, d = ctm_inv.reshape(4).tolist()
+                e, f = pos_inv.tolist()[0]
+                self.obj_patch[self.xobjmap[xobjid].objid] = (
+                    f"q {ops_base}Q {a} {b} {c} {d} {e} {f} cm {ops_new}"
+                )
+            except Exception:
                 pass
         elif subtype is LITERAL_IMAGE and "Width" in xobj and "Height" in xobj:
             self.device.begin_figure(xobjid, (0, 0, 1, 1), MATRIX_IDENTITY)
@@ -1002,14 +1015,16 @@ class PDFPageInterpreter:
         else:
             ctm = (1, 0, 0, 1, -x0, -y0)
         self.device.begin_page(page, ctm)
-        ops_base=self.render_contents(page.resources, page.contents, ctm=ctm)
-        self.device.fontid=self.fontid
-        self.device.fontmap=self.fontmap
-        ops_new=self.device.end_page(page)
+        ops_base = self.render_contents(page.resources, page.contents, ctm=ctm)
+        self.device.fontid = self.fontid
+        self.device.fontmap = self.fontmap
+        ops_new = self.device.end_page(page)
         # 上面渲染的时候会根据 cropbox 减掉页面偏移得到真实坐标,这里输出的时候需要用 cm 把页面偏移加回来
-        self.obj_patch[page.page_xref]=f'q {ops_base}Q 1 0 0 1 {x0} {y0} cm {ops_new}' # ops_base 里可能有图,需要让 ops_new 里的文字覆盖在上面,使用 q/Q 重置位置矩阵
+        self.obj_patch[page.page_xref] = (
+            f"q {ops_base}Q 1 0 0 1 {x0} {y0} cm {ops_new}"  # ops_base 里可能有图,需要让 ops_new 里的文字覆盖在上面,使用 q/Q 重置位置矩阵
+        )
         for obj in page.contents:
-            self.obj_patch[obj.objid]=''
+            self.obj_patch[obj.objid] = ""
 
     def render_contents(
         self,
@@ -1032,7 +1047,7 @@ class PDFPageInterpreter:
         return self.execute(list_value(streams))
 
     def execute(self, streams: Sequence[object]) -> None:
-        ops=''
+        ops = ""
         try:
             parser = PDFContentParser(streams)
         except PSEOF:
@@ -1057,17 +1072,38 @@ class PDFPageInterpreter:
                         # log.debug("exec: %s %r", name, args)
                         if len(args) == nargs:
                             func(*args)
-                            if not (name[0]=='T' or name in ['"',"'",'EI','MP','DP','BMC','BDC']): # 过滤 T 系列文字指令,因为 EI 的参数是 obj 所以也需要过滤(只在少数文档中画横线时使用),过滤 marked 系列指令
-                                p=" ".join([f'{x:f}' if isinstance(x,float) else str(x).replace("'","") for x in args])
-                                ops+=f'{p} {name} '
+                            if not (
+                                name[0] == "T"
+                                or name in ['"', "'", "EI", "MP", "DP", "BMC", "BDC"]
+                            ):  # 过滤 T 系列文字指令,因为 EI 的参数是 obj 所以也需要过滤(只在少数文档中画横线时使用),过滤 marked 系列指令
+                                p = " ".join(
+                                    [
+                                        (
+                                            f"{x:f}"
+                                            if isinstance(x, float)
+                                            else str(x).replace("'", "")
+                                        )
+                                        for x in args
+                                    ]
+                                )
+                                ops += f"{p} {name} "
                     else:
                         # log.debug("exec: %s", name)
-                        targs=func()
-                        if targs==None:
-                            targs=[]
-                        if not (name[0]=='T' or name in ['BI','ID','EMC']):
-                            p=" ".join([f'{x:f}' if isinstance(x,float) else str(x).replace("'","") for x in targs])
-                            ops+=f'{p} {name} '
+                        targs = func()
+                        if targs is None:
+                            targs = []
+                        if not (name[0] == "T" or name in ["BI", "ID", "EMC"]):
+                            p = " ".join(
+                                [
+                                    (
+                                        f"{x:f}"
+                                        if isinstance(x, float)
+                                        else str(x).replace("'", "")
+                                    )
+                                    for x in targs
+                                ]
+                            )
+                            ops += f"{p} {name} "
                 elif settings.STRICT:
                     error_msg = "Unknown operator: %r" % name
                     raise PDFInterpreterError(error_msg)

+ 1 - 1
pdf2zh/pdfpage.py

@@ -188,7 +188,7 @@ class PDFPage:
                 log.warning(warning_msg)
         # Process each page contained in the document.
         for pageno, page in enumerate(cls.create_pages(doc)):
-            page.pageno=pageno
+            page.pageno = pageno
             if pagenos and (pageno not in pagenos):
                 continue
             yield page

+ 4 - 4
pdf2zh/psparser.py

@@ -580,7 +580,7 @@ class PSStackParser(PSBaseParser, Generic[ExtraT]):
 
         :return: keywords, literals, strings, numbers, arrays and dictionaries.
         """
-        end=None
+        end = None
         while not self.results:
             (pos, token) = self.nexttoken()
             if isinstance(token, (int, float, bool, str, bytes, PSLiteral)):
@@ -632,8 +632,8 @@ class PSStackParser(PSBaseParser, Generic[ExtraT]):
                 #     token,
                 #     self.curstack,
                 # )
-                if token.name==b'endobj':
-                    end=pos+7
+                if token.name == b"endobj":
+                    end = pos + 7
                 self.do_keyword(pos, token)
             else:
                 log.error(
@@ -653,4 +653,4 @@ class PSStackParser(PSBaseParser, Generic[ExtraT]):
         #     log.debug("nextobject: %r", obj)
         # except Exception:
         #     log.debug("nextobject: (unprintable object)")
-        return end,obj
+        return end, obj

+ 21 - 24
pdf2zh/translator.py

@@ -19,11 +19,7 @@ class BaseTranslator:
         self.lang_in = lang_in
         self.model = model
 
-    def translate(self, text) -> str:
-        ...
-
-    def __str__(self):
-        pass
+    def translate(self, text) -> str: ...  # noqa: E704
 
     def __str__(self):
         return f"{self.service} {self.lang_out} {self.lang_in}"
@@ -37,7 +33,7 @@ class GoogleTranslator(BaseTranslator):
         self.session = requests.Session()
         self.base_link = "http://translate.google.com/m"
         self.headers = {
-            "User-Agent": "Mozilla/4.0 (compatible;MSIE 6.0;Windows NT 5.1;SV1;.NET CLR 1.1.4322;.NET CLR 2.0.50727;.NET CLR 3.0.04506.30)"
+            "User-Agent": "Mozilla/4.0 (compatible;MSIE 6.0;Windows NT 5.1;SV1;.NET CLR 1.1.4322;.NET CLR 2.0.50727;.NET CLR 3.0.04506.30)"  # noqa: E501
         }
 
     def translate(self, text):
@@ -51,7 +47,7 @@ class GoogleTranslator(BaseTranslator):
             r'(?s)class="(?:t0|result-container)">(.*?)<', response.text
         )
         if response.status_code == 400:
-            result = 'IRREPARABLE TRANSLATION ERROR'
+            result = "IRREPARABLE TRANSLATION ERROR"
         elif len(re_result) == 0:
             raise ValueError("Empty translation result")
         else:
@@ -80,7 +76,7 @@ class DeepLXTranslator(BaseTranslator):
         self.session = requests.Session()
         self.base_link = f"{server_url}/{auth_key}/translate"
         self.headers = {
-            "User-Agent": "Mozilla/4.0 (compatible;MSIE 6.0;Windows NT 5.1;SV1;.NET CLR 1.1.4322;.NET CLR 2.0.50727;.NET CLR 3.0.04506.30)"
+            "User-Agent": "Mozilla/4.0 (compatible;MSIE 6.0;Windows NT 5.1;SV1;.NET CLR 1.1.4322;.NET CLR 2.0.50727;.NET CLR 3.0.04506.30)"  # noqa: E501
         }
 
     def translate(self, text):
@@ -115,27 +111,25 @@ class DeepLXTranslator(BaseTranslator):
 
 class DeepLTranslator(BaseTranslator):
     def __init__(self, service, lang_out, lang_in, model):
-        lang_out='ZH' if lang_out=='auto' else lang_out
-        lang_in='EN' if lang_in=='auto' else lang_in
+        lang_out = "ZH" if lang_out == "auto" else lang_out
+        lang_in = "EN" if lang_in == "auto" else lang_in
         super().__init__(service, lang_out, lang_in, model)
         self.session = requests.Session()
-        auth_key = os.getenv('DEEPL_AUTH_KEY')
-        server_url = os.getenv('DEEPL_SERVER_URL')
+        auth_key = os.getenv("DEEPL_AUTH_KEY")
+        server_url = os.getenv("DEEPL_SERVER_URL")
         self.client = deepl.Translator(auth_key, server_url=server_url)
 
     def translate(self, text):
         response = self.client.translate_text(
-            text,
-            target_lang=self.lang_out,
-            source_lang=self.lang_in
+            text, target_lang=self.lang_out, source_lang=self.lang_in
         )
         return response.text
 
 
 class OllamaTranslator(BaseTranslator):
     def __init__(self, service, lang_out, lang_in, model):
-        lang_out='zh-CN' if lang_out=='auto' else lang_out
-        lang_in='en' if lang_in=='auto' else lang_in
+        lang_out = "zh-CN" if lang_out == "auto" else lang_out
+        lang_in = "en" if lang_in == "auto" else lang_in
         super().__init__(service, lang_out, lang_in, model)
         self.options = {"temperature": 0}  # 随机采样可能会打断公式标记
         # OLLAMA_HOST
@@ -152,16 +146,17 @@ class OllamaTranslator(BaseTranslator):
                 },
                 {
                     "role": "user",
-                    "content": f"Translate the following markdown source text to {self.lang_out}. Keep the formula notation $v*$ unchanged. Output translation directly without any additional text.\nSource Text: {text}\nTranslated Text:",
+                    "content": f"Translate the following markdown source text to {self.lang_out}. Keep the formula notation $v*$ unchanged. Output translation directly without any additional text.\nSource Text: {text}\nTranslated Text:",  # noqa: E501
                 },
             ],
         )
         return response["message"]["content"].strip()
 
+
 class OpenAITranslator(BaseTranslator):
     def __init__(self, service, lang_out, lang_in, model):
-        lang_out='zh-CN' if lang_out=='auto' else lang_out
-        lang_in='en' if lang_in=='auto' else lang_in
+        lang_out = "zh-CN" if lang_out == "auto" else lang_out
+        lang_in = "en" if lang_in == "auto" else lang_in
         super().__init__(service, lang_out, lang_in, model)
         self.options = {"temperature": 0}  # 随机采样可能会打断公式标记
         # OPENAI_BASE_URL
@@ -179,7 +174,7 @@ class OpenAITranslator(BaseTranslator):
                 },
                 {
                     "role": "user",
-                    "content": f"Translate the following markdown source text to {self.lang_out}. Keep the formula notation $v*$ unchanged. Output translation directly without any additional text.\nSource Text: {text}\nTranslated Text:",
+                    "content": f"Translate the following markdown source text to {self.lang_out}. Keep the formula notation $v*$ unchanged. Output translation directly without any additional text.\nSource Text: {text}\nTranslated Text:",  # noqa: E501
                 },
             ],
         )
@@ -188,8 +183,8 @@ class OpenAITranslator(BaseTranslator):
 
 class AzureTranslator(BaseTranslator):
     def __init__(self, service, lang_out, lang_in, model):
-        lang_out='zh-Hans' if lang_out=='auto' else lang_out
-        lang_in='en' if lang_in=='auto' else lang_in
+        lang_out = "zh-Hans" if lang_out == "auto" else lang_out
+        lang_in = "en" if lang_in == "auto" else lang_in
         super().__init__(service, lang_out, lang_in, model)
 
         try:
@@ -198,7 +193,9 @@ class AzureTranslator(BaseTranslator):
             region = os.environ["AZURE_REGION"]
         except KeyError as e:
             missing_var = e.args[0]
-            raise ValueError(f"The environment variable '{missing_var}' is required but not set.") from e
+            raise ValueError(
+                f"The environment variable '{missing_var}' is required but not set."
+            ) from e
 
         credential = AzureKeyCredential(api_key)
         self.client = TextTranslationClient(

+ 3 - 1
pdf2zh/utils.py

@@ -284,9 +284,11 @@ def apply_matrix_norm(m: Matrix, v: Point) -> Point:
     (p, q) = v
     return a * p + c * q, b * p + d * q
 
+
 def matrix_scale(m: Matrix) -> float:
     (a, b, c, d, e, f) = m
-    return (a**2+c**2)**0.5
+    return (a**2 + c**2) ** 0.5
+
 
 #  Utility functions
 

+ 4 - 0
setup.cfg

@@ -0,0 +1,4 @@
+[flake8]
+max-line-length = 120
+ignore = E203,W503,E261
+exclude = .git,build,dist,docs