浏览代码

feat!: make torch an optional dependency

忘忧北萱草 1 年之前
父节点
当前提交
679a4b25dd
共有 4 个文件被更改,包括 26 次插入 和 11 次删除
  1. 3 1
      pdf2zh/doclayout.py
  2. 2 7
      pdf2zh/high_level.py
  3. 13 0
      pdf2zh/utils.py
  4. 8 3
      pyproject.toml

+ 3 - 1
pdf2zh/doclayout.py

@@ -206,5 +206,7 @@ class OnnxModel(DocLayoutModel):
 
 
         # Postprocess predictions
         # Postprocess predictions
         preds = preds[preds[..., 4] > 0.25]
         preds = preds[preds[..., 4] > 0.25]
-        preds[..., :4] = self.scale_boxes((new_h, new_w), preds[..., :4], (orig_h, orig_w))
+        preds[..., :4] = self.scale_boxes(
+            (new_h, new_w), preds[..., :4], (orig_h, orig_w)
+        )
         return [YoloResult(boxes=preds, names=self._names)]
         return [YoloResult(boxes=preds, names=self._names)]

+ 2 - 7
pdf2zh/high_level.py

@@ -4,7 +4,6 @@ import logging
 import sys
 import sys
 from io import StringIO
 from io import StringIO
 from typing import Any, BinaryIO, Container, Iterator, Optional, cast
 from typing import Any, BinaryIO, Container, Iterator, Optional, cast
-import torch
 import numpy as np
 import numpy as np
 import tqdm
 import tqdm
 from pymupdf import Document
 from pymupdf import Document
@@ -22,7 +21,7 @@ from pdf2zh.pdfdevice import PDFDevice, TagExtractor
 from pdf2zh.pdfexceptions import PDFValueError
 from pdf2zh.pdfexceptions import PDFValueError
 from pdf2zh.pdfinterp import PDFPageInterpreter, PDFResourceManager
 from pdf2zh.pdfinterp import PDFPageInterpreter, PDFResourceManager
 from pdf2zh.pdfpage import PDFPage
 from pdf2zh.pdfpage import PDFPage
-from pdf2zh.utils import AnyIO, FileOrName, open_filename
+from pdf2zh.utils import AnyIO, FileOrName, open_filename, get_device
 
 
 
 
 def extract_text_to_fp(
 def extract_text_to_fp(
@@ -176,11 +175,7 @@ def extract_text_to_fp(
                 pix.height, pix.width, 3
                 pix.height, pix.width, 3
             )[:, :, ::-1]
             )[:, :, ::-1]
             page_layout = model.predict(
             page_layout = model.predict(
-                image,
-                imgsz=int(pix.height / 32) * 32,
-                device=(
-                    "cuda:0" if torch.cuda.is_available() else "cpu"
-                ),  # Auto-select GPU if available
+                image, imgsz=int(pix.height / 32) * 32, device=get_device()
             )[0]
             )[0]
             # kdtree 是不可能 kdtree 的,不如直接渲染成图片,用空间换时间
             # kdtree 是不可能 kdtree 的,不如直接渲染成图片,用空间换时间
             box = np.ones((pix.height, pix.width))
             box = np.ones((pix.height, pix.width))

+ 13 - 0
pdf2zh/utils.py

@@ -819,3 +819,16 @@ def format_int_alpha(value: int) -> str:
 
 
     result.reverse()
     result.reverse()
     return "".join(result)
     return "".join(result)
+
+
+def get_device():
+    """Return "cuda:0" if torch is importable and CUDA is available, else "cpu"."""
+    try:
+        import torch  # deferred import: torch may not be installed
+
+        if torch.cuda.is_available():
+            return "cuda:0"
+    except ImportError:
+        pass  # torch is missing; fall through to the CPU default
+
+    return "cpu"

+ 8 - 3
pyproject.toml

@@ -5,7 +5,7 @@ description = "Latex PDF Translator"
 authors = [{ name = "Byaidu", email = "byaidux@gmail.com" }]
 authors = [{ name = "Byaidu", email = "byaidux@gmail.com" }]
 license = "AGPL-3.0"
 license = "AGPL-3.0"
 readme = "README.md"
 readme = "README.md"
-requires-python = ">=3.8,<3.13"
+requires-python = ">=3.9,<3.13"
 classifiers = [
 classifiers = [
     "Programming Language :: Python :: 3",
     "Programming Language :: Python :: 3",
     "Operating System :: OS Independent",
     "Operating System :: OS Independent",
@@ -17,7 +17,6 @@ dependencies = [
     "pymupdf",
     "pymupdf",
     "tqdm",
     "tqdm",
     "tenacity",
     "tenacity",
-    "doclayout-yolo",
     "numpy",
     "numpy",
     "ollama",
     "ollama",
     "deepl<1.19.1",
     "deepl<1.19.1",
@@ -25,10 +24,16 @@ dependencies = [
     "azure-ai-translation-text<=1.0.1",
     "azure-ai-translation-text<=1.0.1",
     "gradio",
     "gradio",
     "huggingface_hub",
     "huggingface_hub",
-    "torch",
+    "onnx",
+    "onnxruntime",
+    "opencv-python-headless",
 ]
 ]
 
 
 [project.optional-dependencies]
 [project.optional-dependencies]
+torch = [
+    "doclayout-yolo",
+    "torch",
+]
 dev = [
 dev = [
     "black",
     "black",
     "flake8",
     "flake8",