
Merge pull request #116 from Wybxc/onnx

feat: onnx support
Byaidu 1 year ago
parent
commit
0a0cb709d7
5 changed files with 245 additions and 25 deletions
  1. pdf2zh/doclayout.py (+213 -0)
  2. pdf2zh/high_level.py (+2 -7)
  3. pdf2zh/pdf2zh.py (+9 -15)
  4. pdf2zh/utils.py (+13 -0)
  5. pyproject.toml (+8 -3)

+ 213 - 0
pdf2zh/doclayout.py

@@ -0,0 +1,213 @@
+import abc
+import cv2
+import numpy as np
+import contextlib
+from huggingface_hub import hf_hub_download
+
+
+class DocLayoutModel(abc.ABC):
+    @staticmethod
+    def load_torch():
+        model = TorchModel.from_pretrained(
+            repo_id="juliozhao/DocLayout-YOLO-DocStructBench",
+            filename="doclayout_yolo_docstructbench_imgsz1024.pt",
+        )
+        return model
+
+    @staticmethod
+    def load_onnx():
+        model = OnnxModel.from_pretrained(
+            repo_id="wybxc/DocLayout-YOLO-DocStructBench-onnx",
+            filename="doclayout_yolo_docstructbench_imgsz1024.onnx",
+        )
+        return model
+
+    @staticmethod
+    def load_available():
+        with contextlib.suppress(ImportError):
+            return DocLayoutModel.load_torch()
+
+        with contextlib.suppress(ImportError):
+            return DocLayoutModel.load_onnx()
+
+        raise ImportError(
+            "Please install the `torch` or `onnx` feature to use the DocLayout model."
+        )
+
+    @property
+    @abc.abstractmethod
+    def stride(self) -> int:
+        """Stride of the model input."""
+        pass
+
+    @abc.abstractmethod
+    def predict(self, image, imgsz=1024, **kwargs) -> list:
+        """
+        Predict the layout of a document page.
+
+        Args:
+            image: The image of the document page.
+            imgsz: Resize the image to this size. Must be a multiple of the stride.
+            **kwargs: Additional arguments.
+        """
+        pass
+
+
+class TorchModel(DocLayoutModel):
+    def __init__(self, model_path: str):
+        try:
+            import doclayout_yolo
+        except ImportError:
+            raise ImportError(
+                "Please install the `torch` feature to use the Torch model."
+            )
+
+        self.model_path = model_path
+        self.model = doclayout_yolo.YOLOv10(model_path)
+
+    @staticmethod
+    def from_pretrained(repo_id: str, filename: str):
+        pth = hf_hub_download(repo_id=repo_id, filename=filename)
+        return TorchModel(pth)
+
+    @property
+    def stride(self):
+        return 32
+
+    def predict(self, *args, **kwargs):
+        return self.model.predict(*args, **kwargs)
+
+
+class YoloResult:
+    """Helper class to store detection results from ONNX model."""
+
+    def __init__(self, boxes, names):
+        self.boxes = [YoloBox(data=d) for d in boxes]
+        self.boxes.sort(key=lambda x: x.conf, reverse=True)
+        self.names = names
+
+
+class YoloBox:
+    """Helper class to store detection results from ONNX model."""
+
+    def __init__(self, data):
+        self.xyxy = data[:4]
+        self.conf = data[-2]
+        self.cls = data[-1]
+
+
+class OnnxModel(DocLayoutModel):
+    def __init__(self, model_path: str):
+        import ast
+
+        try:
+            import onnx
+            import onnxruntime
+        except ImportError:
+            raise ImportError(
+                "Please install the `onnx` feature to use the ONNX model."
+            )
+
+        self.model_path = model_path
+
+        model = onnx.load(model_path)
+        metadata = {d.key: d.value for d in model.metadata_props}
+        self._stride = ast.literal_eval(metadata["stride"])
+        self._names = ast.literal_eval(metadata["names"])
+
+        self.model = onnxruntime.InferenceSession(model.SerializeToString())
+
+    @staticmethod
+    def from_pretrained(repo_id: str, filename: str):
+        pth = hf_hub_download(repo_id=repo_id, filename=filename)
+        return OnnxModel(pth)
+
+    @property
+    def stride(self):
+        return self._stride
+
+    def resize_and_pad_image(self, image, new_shape):
+        """
+        Resize and pad the image to the specified size, ensuring dimensions are multiples of stride.
+
+        Parameters:
+        - image: Input image
+        - new_shape: Target size (integer or (height, width) tuple)
+
+        Padding is aligned to multiples of `self.stride`.
+
+        Returns:
+        - Processed image
+        """
+        if isinstance(new_shape, int):
+            new_shape = (new_shape, new_shape)
+
+        h, w = image.shape[:2]
+        new_h, new_w = new_shape
+
+        # Calculate scaling ratio
+        r = min(new_h / h, new_w / w)
+        resized_h, resized_w = int(round(h * r)), int(round(w * r))
+
+        # Resize image
+        image = cv2.resize(
+            image, (resized_w, resized_h), interpolation=cv2.INTER_LINEAR
+        )
+
+        # Calculate padding size and align to stride multiple
+        pad_w = (new_w - resized_w) % self.stride
+        pad_h = (new_h - resized_h) % self.stride
+        top, bottom = pad_h // 2, pad_h - pad_h // 2
+        left, right = pad_w // 2, pad_w - pad_w // 2
+
+        # Add padding
+        image = cv2.copyMakeBorder(
+            image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)
+        )
+
+        return image
+
+    def scale_boxes(self, img1_shape, boxes, img0_shape):
+        """
+        Rescales bounding boxes (in the format of xyxy by default) from the shape of the image they were originally
+        specified in (img1_shape) to the shape of a different image (img0_shape).
+
+        Args:
+            img1_shape (tuple): The shape of the image that the bounding boxes are for,
+                in the format of (height, width).
+            boxes (np.ndarray): the bounding boxes of the objects in the image, in the format of (x1, y1, x2, y2)
+            img0_shape (tuple): the shape of the target image, in the format of (height, width).
+
+        Returns:
+            boxes (np.ndarray): The scaled bounding boxes, in the format of (x1, y1, x2, y2)
+        """
+
+        # Calculate scaling ratio
+        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])
+
+        # Calculate padding size
+        pad_x = round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1)
+        pad_y = round((img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1)
+
+        # Remove padding and scale boxes
+        boxes[..., :4] = (boxes[..., :4] - [pad_x, pad_y, pad_x, pad_y]) / gain
+        return boxes
+
+    def predict(self, image, imgsz=1024, **kwargs):
+        # Preprocess input image
+        orig_h, orig_w = image.shape[:2]
+        pix = self.resize_and_pad_image(image, new_shape=imgsz)
+        pix = np.transpose(pix, (2, 0, 1))  # CHW
+        pix = np.expand_dims(pix, axis=0)  # BCHW
+        pix = pix.astype(np.float32) / 255.0  # Normalize to [0, 1]
+        new_h, new_w = pix.shape[2:]
+
+        # Run inference
+        preds = self.model.run(None, {"images": pix})[0]
+
+        # Postprocess predictions
+        preds = preds[preds[..., 4] > 0.25]
+        preds[..., :4] = self.scale_boxes(
+            (new_h, new_w), preds[..., :4], (orig_h, orig_w)
+        )
+        return [YoloResult(boxes=preds, names=self._names)]
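To check that the letterbox preprocessing and scale_boxes invert each other, a worked example (the 792x612 px page size is a hypothetical input):

# resize_and_pad_image with new_shape=1024 on a 792x612 image:
#   r = min(1024/792, 1024/612) ≈ 1.2929 -> resized (h, w) = (1024, 791)
#   pad_h = 0; pad_w = (1024 - 791) % 32 = 9 -> left = 4, right = 5
#   so the network input is 1024x800, not a full 1024x1024
# scale_boxes((1024, 800), boxes, (792, 612)) undoes it:
#   gain = min(1024/792, 800/612) ≈ 1.2929
#   pad_x = round((800 - 612 * 1.2929) / 2 - 0.1) = 4, pad_y = 0
#   boxes shift left by 4 px and divide by gain, landing back in page coordinates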

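A minimal usage sketch of the new interface; page.png is a hypothetical page render, and attribute access follows the ONNX YoloResult/YoloBox helpers above (the torch backend returns doclayout_yolo's own result objects with a compatible surface):

import cv2
from pdf2zh.doclayout import DocLayoutModel

model = DocLayoutModel.load_available()  # prefers the torch backend, falls back to ONNX

image = cv2.imread("page.png")  # HWC uint8 image of a document page
result = model.predict(image, imgsz=1024)[0]  # imgsz must be a multiple of model.stride

for box in result.boxes:  # YoloBox entries, sorted by confidence descending
    x1, y1, x2, y2 = box.xyxy
    # integer class keys are an assumption about the exported metadata
    print(result.names[int(box.cls)], float(box.conf), (x1, y1, x2, y2))
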
+ 2 - 7
pdf2zh/high_level.py

@@ -4,7 +4,6 @@ import logging
 import sys
 from io import StringIO
 from typing import Any, BinaryIO, Container, Iterator, Optional, cast
-import torch
 import numpy as np
 import tqdm
 from pymupdf import Document
@@ -22,7 +21,7 @@ from pdf2zh.pdfdevice import PDFDevice, TagExtractor
 from pdf2zh.pdfexceptions import PDFValueError
 from pdf2zh.pdfinterp import PDFPageInterpreter, PDFResourceManager
 from pdf2zh.pdfpage import PDFPage
-from pdf2zh.utils import AnyIO, FileOrName, open_filename
+from pdf2zh.utils import AnyIO, FileOrName, open_filename, get_device
 
 
 def extract_text_to_fp(
@@ -176,11 +175,7 @@ def extract_text_to_fp(
                 pix.height, pix.width, 3
             )[:, :, ::-1]
             page_layout = model.predict(
-                image,
-                imgsz=int(pix.height / 32) * 32,
-                device=(
-                    "cuda:0" if torch.cuda.is_available() else "cpu"
-                ),  # Auto-select GPU if available
+                image, imgsz=int(pix.height / 32) * 32, device=get_device()
             )[0]
             # a kd-tree is out of the question; just render to an image instead, trading space for time
             box = np.ones((pix.height, pix.width))
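Note that imgsz here is the page height rounded down to a multiple of the stride: a 1650 px tall pixmap gives int(1650 / 32) * 32 == 1632, which satisfies the "multiple of the stride" contract documented on DocLayoutModel.predict.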

+ 9 - 15
pdf2zh/pdf2zh.py

@@ -14,7 +14,6 @@ from pathlib import Path
 from typing import TYPE_CHECKING, Any, Container, Iterable, List, Optional
 
 import pymupdf
-from huggingface_hub import hf_hub_download
 
 from pdf2zh import __version__
 from pdf2zh.pdfexceptions import PDFValueError
@@ -27,10 +26,14 @@ OUTPUT_TYPES = ((".htm", "html"), (".html", "html"), (".xml", "xml"), (".tag", "
 
 
 def setup_log() -> None:
-    import doclayout_yolo
-
     logging.basicConfig()
-    doclayout_yolo.utils.LOGGER.setLevel(logging.WARNING)
+
+    try:
+        import doclayout_yolo
+
+        doclayout_yolo.utils.LOGGER.setLevel(logging.WARNING)
+    except ImportError:
+        pass
 
 
 def check_files(files: List[str]) -> List[str]:
@@ -73,8 +76,7 @@ def extract_text(
     output: str = "",
     **kwargs: Any,
 ) -> AnyIO:
-    import doclayout_yolo
-
+    from pdf2zh.doclayout import DocLayoutModel
    import pdf2zh.high_level
 
     if not files:
@@ -86,15 +88,7 @@ def extract_text(
                 output_type = alttype
 
     outfp: AnyIO = sys.stdout
-    # pth = os.path.join(tempfile.gettempdir(), 'doclayout_yolo_docstructbench_imgsz1024.pt')
-    # if not os.path.exists(pth):
-    #     print('Downloading...')
-    #     urllib.request.urlretrieve("http://huggingface.co/juliozhao/DocLayout-YOLO-DocStructBench/resolve/main/doclayout_yolo_docstructbench_imgsz1024.pt",pth)
-    pth = hf_hub_download(
-        repo_id="juliozhao/DocLayout-YOLO-DocStructBench",
-        filename="doclayout_yolo_docstructbench_imgsz1024.pt",
-    )
-    model = doclayout_yolo.YOLOv10(pth)
+    model = DocLayoutModel.load_available()
 
     for file in files:
         filename = os.path.splitext(os.path.basename(file))[0]
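Both hunks above serve the backend fallback: setup_log only quiets the doclayout_yolo logger when that package is importable, and extract_text defers model selection to DocLayoutModel.load_available() instead of downloading the torch checkpoint itself.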

+ 13 - 0
pdf2zh/utils.py

@@ -819,3 +819,16 @@ def format_int_alpha(value: int) -> str:
 
     result.reverse()
     return "".join(result)
+
+
+def get_device():
+    """Get the device to use for computation."""
+    try:
+        import torch
+
+        if torch.cuda.is_available():
+            return "cuda:0"
+    except ImportError:
+        pass
+
+    return "cpu"
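A quick sketch of how the helper behaves (the "cpu" outcome assumes a machine without CUDA):

from pdf2zh.utils import get_device

device = get_device()  # "cuda:0" if torch is importable and sees a GPU, else "cpu"

Because OnnxModel.predict accepts **kwargs, passing device=get_device() is a harmless no-op on the ONNX path, while TorchModel forwards it to doclayout_yolo.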

+ 8 - 3
pyproject.toml

@@ -5,7 +5,7 @@ description = "Latex PDF Translator"
 authors = [{ name = "Byaidu", email = "byaidux@gmail.com" }]
 license = "AGPL-3.0"
 readme = "README.md"
-requires-python = ">=3.8,<3.13"
+requires-python = ">=3.9,<3.13"
 classifiers = [
     "Programming Language :: Python :: 3",
     "Operating System :: OS Independent",
     "Operating System :: OS Independent",
     "Operating System :: OS Independent",
@@ -17,7 +17,6 @@ dependencies = [
     "pymupdf",
     "pymupdf",
     "tqdm",
     "tqdm",
     "tenacity",
     "tenacity",
-    "doclayout-yolo",
     "numpy",
     "numpy",
     "ollama",
     "ollama",
     "deepl<1.19.1",
     "deepl<1.19.1",
@@ -25,10 +24,16 @@ dependencies = [
     "azure-ai-translation-text<=1.0.1",
     "gradio",
     "huggingface_hub",
-    "torch",
+    "onnx",
+    "onnxruntime",
+    "opencv-python-headless",
 ]
 
 [project.optional-dependencies]
+torch = [
+    "doclayout-yolo",
+    "torch",
+]
 dev = [
     "black",
     "flake8",