Browse Source

feat!: make torch an optional dependency

忘忧北萱草 1 year ago
parent
commit
679a4b25dd
4 changed files with 26 additions and 11 deletions
  1. 3 1
      pdf2zh/doclayout.py
  2. 2 7
      pdf2zh/high_level.py
  3. 13 0
      pdf2zh/utils.py
  4. 8 3
      pyproject.toml

+ 3 - 1
pdf2zh/doclayout.py

@@ -206,5 +206,7 @@ class OnnxModel(DocLayoutModel):
 
         # Postprocess predictions
         preds = preds[preds[..., 4] > 0.25]
-        preds[..., :4] = self.scale_boxes((new_h, new_w), preds[..., :4], (orig_h, orig_w))
+        preds[..., :4] = self.scale_boxes(
+            (new_h, new_w), preds[..., :4], (orig_h, orig_w)
+        )
         return [YoloResult(boxes=preds, names=self._names)]

+ 2 - 7
pdf2zh/high_level.py

@@ -4,7 +4,6 @@ import logging
 import sys
 from io import StringIO
 from typing import Any, BinaryIO, Container, Iterator, Optional, cast
-import torch
 import numpy as np
 import tqdm
 from pymupdf import Document
@@ -22,7 +21,7 @@ from pdf2zh.pdfdevice import PDFDevice, TagExtractor
 from pdf2zh.pdfexceptions import PDFValueError
 from pdf2zh.pdfinterp import PDFPageInterpreter, PDFResourceManager
 from pdf2zh.pdfpage import PDFPage
-from pdf2zh.utils import AnyIO, FileOrName, open_filename
+from pdf2zh.utils import AnyIO, FileOrName, open_filename, get_device
 
 
 def extract_text_to_fp(
@@ -176,11 +175,7 @@ def extract_text_to_fp(
                 pix.height, pix.width, 3
             )[:, :, ::-1]
             page_layout = model.predict(
-                image,
-                imgsz=int(pix.height / 32) * 32,
-                device=(
-                    "cuda:0" if torch.cuda.is_available() else "cpu"
-                ),  # Auto-select GPU if available
+                image, imgsz=int(pix.height / 32) * 32, device=get_device()
             )[0]
             # kdtree 是不可能 kdtree 的,不如直接渲染成图片,用空间换时间
             box = np.ones((pix.height, pix.width))

+ 13 - 0
pdf2zh/utils.py

@@ -819,3 +819,16 @@ def format_int_alpha(value: int) -> str:
 
     result.reverse()
     return "".join(result)
+
+
+def get_device():
+    """Get the device to use for computation."""
+    try:
+        import torch
+
+        if torch.cuda.is_available():
+            return "cuda:0"
+    except ImportError:
+        pass
+
+    return "cpu"

+ 8 - 3
pyproject.toml

@@ -5,7 +5,7 @@ description = "Latex PDF Translator"
 authors = [{ name = "Byaidu", email = "byaidux@gmail.com" }]
 license = "AGPL-3.0"
 readme = "README.md"
-requires-python = ">=3.8,<3.13"
+requires-python = ">=3.9,<3.13"
 classifiers = [
     "Programming Language :: Python :: 3",
     "Operating System :: OS Independent",
@@ -17,7 +17,6 @@ dependencies = [
     "pymupdf",
     "tqdm",
     "tenacity",
-    "doclayout-yolo",
     "numpy",
     "ollama",
     "deepl<1.19.1",
@@ -25,10 +24,16 @@ dependencies = [
     "azure-ai-translation-text<=1.0.1",
     "gradio",
     "huggingface_hub",
-    "torch",
+    "onnx",
+    "onnxruntime",
+    "opencv-python-headless",
 ]
 
 [project.optional-dependencies]
+torch = [
+    "doclayout-yolo",
+    "torch",
+]
 dev = [
     "black",
     "flake8",