Browse Source

remove: torch

Byaidu 1 year ago
parent
commit
d1f3b0e9c3
5 changed files with 19 additions and 81 deletions
  1. 3 0
      pdf2zh/__init__.py
  2. 4 54
      pdf2zh/doclayout.py
  3. 1 17
      pdf2zh/high_level.py
  4. 11 6
      pdf2zh/pdf2zh.py
  5. 0 4
      pyproject.toml

+ 3 - 0
pdf2zh/__init__.py

@@ -1,2 +1,5 @@
+import logging
+log = logging.getLogger(__name__)
+
 __version__ = "1.8.1"
 __author__ = "Byaidu"

+ 4 - 54
pdf2zh/doclayout.py

@@ -1,19 +1,13 @@
 import abc
 import cv2
 import numpy as np
-import contextlib
+import ast
+import onnx
+import onnxruntime
 from huggingface_hub import hf_hub_download
 
 
 class DocLayoutModel(abc.ABC):
-    @staticmethod
-    def load_torch():
-        model = TorchModel.from_pretrained(
-            repo_id="juliozhao/DocLayout-YOLO-DocStructBench",
-            filename="doclayout_yolo_docstructbench_imgsz1024.pt",
-        )
-        return model
-
     @staticmethod
     def load_onnx():
         model = OnnxModel.from_pretrained(
@@ -24,15 +18,7 @@ class DocLayoutModel(abc.ABC):
 
     @staticmethod
     def load_available():
-        with contextlib.suppress(ImportError):
-            return DocLayoutModel.load_torch()
-
-        with contextlib.suppress(ImportError):
-            return DocLayoutModel.load_onnx()
-
-        raise ImportError(
-            "Please install the `torch` or `onnx` feature to use the DocLayout model."
-        )
+        return DocLayoutModel.load_onnx()
 
     @property
     @abc.abstractmethod
@@ -53,31 +39,6 @@ class DocLayoutModel(abc.ABC):
         pass
 
 
-class TorchModel(DocLayoutModel):
-    def __init__(self, model_path: str):
-        try:
-            import doclayout_yolo
-        except ImportError:
-            raise ImportError(
-                "Please install the `torch` feature to use the Torch model."
-            )
-
-        self.model_path = model_path
-        self.model = doclayout_yolo.YOLOv10(model_path)
-
-    @staticmethod
-    def from_pretrained(repo_id: str, filename: str):
-        pth = hf_hub_download(repo_id=repo_id, filename=filename)
-        return TorchModel(pth)
-
-    @property
-    def stride(self):
-        return 32
-
-    def predict(self, *args, **kwargs):
-        return self.model.predict(*args, **kwargs)
-
-
 class YoloResult:
     """Helper class to store detection results from ONNX model."""
 
@@ -98,17 +59,6 @@ class YoloBox:
 
 class OnnxModel(DocLayoutModel):
     def __init__(self, model_path: str):
-        import ast
-
-        try:
-
-            import onnx
-            import onnxruntime
-        except ImportError:
-            raise ImportError(
-                "Please install the `onnx` feature to use the ONNX model."
-            )
-
         self.model_path = model_path
 
         model = onnx.load(model_path)

+ 1 - 17
pdf2zh/high_level.py

@@ -13,19 +13,6 @@ from pdf2zh.converter import TranslateConverter
 from pdf2zh.pdfinterp import PDFPageInterpreterEx
 
 
-def get_device():
-    """Get the device to use for computation."""
-    try:
-        import torch
-
-        if torch.cuda.is_available():
-            return "cuda:0"
-    except ImportError:
-        pass
-
-    return "cpu"
-
-
 def extract_text_to_fp(
     inf: BinaryIO,
     pages=None,
@@ -43,9 +30,6 @@ def extract_text_to_fp(
     callback: object = None,
     **kwarg,
 ) -> None:
-    if debug:
-        logging.getLogger().setLevel(logging.DEBUG)
-
     rsrcmgr = PDFResourceManager()
     layout = {}
     device = TranslateConverter(
@@ -77,7 +61,7 @@ def extract_text_to_fp(
                 pix.height, pix.width, 3
             )[:, :, ::-1]
             page_layout = model.predict(
-                image, imgsz=int(pix.height / 32) * 32, device=get_device()
+                image, imgsz=int(pix.height / 32) * 32
             )[0]
             # kdtree 是不可能 kdtree 的,不如直接渲染成图片,用空间换时间
             box = np.ones((pix.height, pix.width))

+ 11 - 6
pdf2zh/pdf2zh.py

@@ -8,6 +8,7 @@ from __future__ import annotations
 import argparse
 import os
 import sys
+import logging
 from pathlib import Path
 from typing import Any, Container, Iterable, List, Optional
 from pdfminer.pdfexceptions import PDFValueError
@@ -15,7 +16,13 @@ from pdfminer.pdfexceptions import PDFValueError
 import pymupdf
 import requests
 
-from pdf2zh import __version__
+from pdf2zh import __version__, log
+from pdf2zh.high_level import extract_text_to_fp
+from pdf2zh.doclayout import DocLayoutModel
+
+logging.basicConfig()
+
+model = DocLayoutModel.load_available()
 
 
 def check_files(files: List[str]) -> List[str]:
@@ -44,14 +51,12 @@ def extract_text(
     output: str = "",
     **kwargs: Any,
 ):
-    import pdf2zh.high_level
-    from pdf2zh.doclayout import DocLayoutModel
+    if debug:
+        log.setLevel(logging.DEBUG)
 
     if not files:
         raise PDFValueError("Must provide files to work upon!")
 
-    model = DocLayoutModel.load_available()
-
     for file in files:
         if file is str and (file.startswith("http://") or file.startswith("https://")):
             print("Online files detected, downloading...")
@@ -99,7 +104,7 @@ def extract_text(
         doc_en.save(Path(output) / f"{filename}-en.pdf")
 
         with open(Path(output) / f"{filename}-en.pdf", "rb") as fp:
-            obj_patch: dict = pdf2zh.high_level.extract_text_to_fp(fp, **locals())
+            obj_patch: dict = extract_text_to_fp(fp, model=model, **locals())
 
         for obj_id, ops_new in obj_patch.items():
             # ops_old=doc_en.xref_stream(obj_id)

+ 0 - 4
pyproject.toml

@@ -29,10 +29,6 @@ dependencies = [
 ]
 
 [project.optional-dependencies]
-torch = [
-    "doclayout-yolo",
-    "torch",
-]
 dev = [
     "black",
     "flake8",