浏览代码

remove: torch

Byaidu 1 年之前
父节点
当前提交
d1f3b0e9c3
共有 5 个文件被更改,包括 19 次插入和 81 次删除
  1. 3 0
      pdf2zh/__init__.py
  2. 4 54
      pdf2zh/doclayout.py
  3. 1 17
      pdf2zh/high_level.py
  4. 11 6
      pdf2zh/pdf2zh.py
  5. 0 4
      pyproject.toml

+ 3 - 0
pdf2zh/__init__.py

@@ -1,2 +1,5 @@
+import logging
+log = logging.getLogger(__name__)
+
 __version__ = "1.8.1"
 __version__ = "1.8.1"
 __author__ = "Byaidu"
 __author__ = "Byaidu"

+ 4 - 54
pdf2zh/doclayout.py

@@ -1,19 +1,13 @@
 import abc
 import abc
 import cv2
 import cv2
 import numpy as np
 import numpy as np
-import contextlib
+import ast
+import onnx
+import onnxruntime
 from huggingface_hub import hf_hub_download
 from huggingface_hub import hf_hub_download
 
 
 
 
 class DocLayoutModel(abc.ABC):
 class DocLayoutModel(abc.ABC):
-    @staticmethod
-    def load_torch():
-        model = TorchModel.from_pretrained(
-            repo_id="juliozhao/DocLayout-YOLO-DocStructBench",
-            filename="doclayout_yolo_docstructbench_imgsz1024.pt",
-        )
-        return model
-
     @staticmethod
     @staticmethod
     def load_onnx():
     def load_onnx():
         model = OnnxModel.from_pretrained(
         model = OnnxModel.from_pretrained(
@@ -24,15 +18,7 @@ class DocLayoutModel(abc.ABC):
 
 
     @staticmethod
     @staticmethod
     def load_available():
     def load_available():
-        with contextlib.suppress(ImportError):
-            return DocLayoutModel.load_torch()
-
-        with contextlib.suppress(ImportError):
-            return DocLayoutModel.load_onnx()
-
-        raise ImportError(
-            "Please install the `torch` or `onnx` feature to use the DocLayout model."
-        )
+        return DocLayoutModel.load_onnx()
 
 
     @property
     @property
     @abc.abstractmethod
     @abc.abstractmethod
@@ -53,31 +39,6 @@ class DocLayoutModel(abc.ABC):
         pass
         pass
 
 
 
 
-class TorchModel(DocLayoutModel):
-    def __init__(self, model_path: str):
-        try:
-            import doclayout_yolo
-        except ImportError:
-            raise ImportError(
-                "Please install the `torch` feature to use the Torch model."
-            )
-
-        self.model_path = model_path
-        self.model = doclayout_yolo.YOLOv10(model_path)
-
-    @staticmethod
-    def from_pretrained(repo_id: str, filename: str):
-        pth = hf_hub_download(repo_id=repo_id, filename=filename)
-        return TorchModel(pth)
-
-    @property
-    def stride(self):
-        return 32
-
-    def predict(self, *args, **kwargs):
-        return self.model.predict(*args, **kwargs)
-
-
 class YoloResult:
 class YoloResult:
     """Helper class to store detection results from ONNX model."""
     """Helper class to store detection results from ONNX model."""
 
 
@@ -98,17 +59,6 @@ class YoloBox:
 
 
 class OnnxModel(DocLayoutModel):
 class OnnxModel(DocLayoutModel):
     def __init__(self, model_path: str):
     def __init__(self, model_path: str):
-        import ast
-
-        try:
-
-            import onnx
-            import onnxruntime
-        except ImportError:
-            raise ImportError(
-                "Please install the `onnx` feature to use the ONNX model."
-            )
-
         self.model_path = model_path
         self.model_path = model_path
 
 
         model = onnx.load(model_path)
         model = onnx.load(model_path)

+ 1 - 17
pdf2zh/high_level.py

@@ -13,19 +13,6 @@ from pdf2zh.converter import TranslateConverter
 from pdf2zh.pdfinterp import PDFPageInterpreterEx
 from pdf2zh.pdfinterp import PDFPageInterpreterEx
 
 
 
 
-def get_device():
-    """Get the device to use for computation."""
-    try:
-        import torch
-
-        if torch.cuda.is_available():
-            return "cuda:0"
-    except ImportError:
-        pass
-
-    return "cpu"
-
-
 def extract_text_to_fp(
 def extract_text_to_fp(
     inf: BinaryIO,
     inf: BinaryIO,
     pages=None,
     pages=None,
@@ -43,9 +30,6 @@ def extract_text_to_fp(
     callback: object = None,
     callback: object = None,
     **kwarg,
     **kwarg,
 ) -> None:
 ) -> None:
-    if debug:
-        logging.getLogger().setLevel(logging.DEBUG)
-
     rsrcmgr = PDFResourceManager()
     rsrcmgr = PDFResourceManager()
     layout = {}
     layout = {}
     device = TranslateConverter(
     device = TranslateConverter(
@@ -77,7 +61,7 @@ def extract_text_to_fp(
                 pix.height, pix.width, 3
                 pix.height, pix.width, 3
             )[:, :, ::-1]
             )[:, :, ::-1]
             page_layout = model.predict(
             page_layout = model.predict(
-                image, imgsz=int(pix.height / 32) * 32, device=get_device()
+                image, imgsz=int(pix.height / 32) * 32
             )[0]
             )[0]
             # kdtree 是不可能 kdtree 的,不如直接渲染成图片,用空间换时间
             # kdtree 是不可能 kdtree 的,不如直接渲染成图片,用空间换时间
             box = np.ones((pix.height, pix.width))
             box = np.ones((pix.height, pix.width))

+ 11 - 6
pdf2zh/pdf2zh.py

@@ -8,6 +8,7 @@ from __future__ import annotations
 import argparse
 import argparse
 import os
 import os
 import sys
 import sys
+import logging
 from pathlib import Path
 from pathlib import Path
 from typing import Any, Container, Iterable, List, Optional
 from typing import Any, Container, Iterable, List, Optional
 from pdfminer.pdfexceptions import PDFValueError
 from pdfminer.pdfexceptions import PDFValueError
@@ -15,7 +16,13 @@ from pdfminer.pdfexceptions import PDFValueError
 import pymupdf
 import pymupdf
 import requests
 import requests
 
 
-from pdf2zh import __version__
+from pdf2zh import __version__, log
+from pdf2zh.high_level import extract_text_to_fp
+from pdf2zh.doclayout import DocLayoutModel
+
+logging.basicConfig()
+
+model = DocLayoutModel.load_available()
 
 
 
 
 def check_files(files: List[str]) -> List[str]:
 def check_files(files: List[str]) -> List[str]:
@@ -44,14 +51,12 @@ def extract_text(
     output: str = "",
     output: str = "",
     **kwargs: Any,
     **kwargs: Any,
 ):
 ):
-    import pdf2zh.high_level
-    from pdf2zh.doclayout import DocLayoutModel
+    if debug:
+        log.setLevel(logging.DEBUG)
 
 
     if not files:
     if not files:
         raise PDFValueError("Must provide files to work upon!")
         raise PDFValueError("Must provide files to work upon!")
 
 
-    model = DocLayoutModel.load_available()
-
     for file in files:
     for file in files:
         if file is str and (file.startswith("http://") or file.startswith("https://")):
         if file is str and (file.startswith("http://") or file.startswith("https://")):
             print("Online files detected, downloading...")
             print("Online files detected, downloading...")
@@ -99,7 +104,7 @@ def extract_text(
         doc_en.save(Path(output) / f"{filename}-en.pdf")
         doc_en.save(Path(output) / f"{filename}-en.pdf")
 
 
         with open(Path(output) / f"{filename}-en.pdf", "rb") as fp:
         with open(Path(output) / f"{filename}-en.pdf", "rb") as fp:
-            obj_patch: dict = pdf2zh.high_level.extract_text_to_fp(fp, **locals())
+            obj_patch: dict = extract_text_to_fp(fp, model=model, **locals())
 
 
         for obj_id, ops_new in obj_patch.items():
         for obj_id, ops_new in obj_patch.items():
             # ops_old=doc_en.xref_stream(obj_id)
             # ops_old=doc_en.xref_stream(obj_id)

+ 0 - 4
pyproject.toml

@@ -29,10 +29,6 @@ dependencies = [
 ]
 ]
 
 
 [project.optional-dependencies]
 [project.optional-dependencies]
-torch = [
-    "doclayout-yolo",
-    "torch",
-]
 dev = [
 dev = [
     "black",
     "black",
     "flake8",
     "flake8",