|
|
@@ -1,19 +1,13 @@
|
|
|
import abc
|
|
|
import cv2
|
|
|
import numpy as np
|
|
|
-import contextlib
|
|
|
+import ast
|
|
|
+import onnx
|
|
|
+import onnxruntime
|
|
|
from huggingface_hub import hf_hub_download
|
|
|
|
|
|
|
|
|
class DocLayoutModel(abc.ABC):
|
|
|
- @staticmethod
|
|
|
- def load_torch():
|
|
|
- model = TorchModel.from_pretrained(
|
|
|
- repo_id="juliozhao/DocLayout-YOLO-DocStructBench",
|
|
|
- filename="doclayout_yolo_docstructbench_imgsz1024.pt",
|
|
|
- )
|
|
|
- return model
|
|
|
-
|
|
|
@staticmethod
|
|
|
def load_onnx():
|
|
|
model = OnnxModel.from_pretrained(
|
|
|
@@ -24,15 +18,7 @@ class DocLayoutModel(abc.ABC):
|
|
|
|
|
|
@staticmethod
|
|
|
def load_available():
|
|
|
- with contextlib.suppress(ImportError):
|
|
|
- return DocLayoutModel.load_torch()
|
|
|
-
|
|
|
- with contextlib.suppress(ImportError):
|
|
|
- return DocLayoutModel.load_onnx()
|
|
|
-
|
|
|
- raise ImportError(
|
|
|
- "Please install the `torch` or `onnx` feature to use the DocLayout model."
|
|
|
- )
|
|
|
+ return DocLayoutModel.load_onnx()
|
|
|
|
|
|
@property
|
|
|
@abc.abstractmethod
|
|
|
@@ -53,31 +39,6 @@ class DocLayoutModel(abc.ABC):
|
|
|
pass
|
|
|
|
|
|
|
|
|
-class TorchModel(DocLayoutModel):
|
|
|
- def __init__(self, model_path: str):
|
|
|
- try:
|
|
|
- import doclayout_yolo
|
|
|
- except ImportError:
|
|
|
- raise ImportError(
|
|
|
- "Please install the `torch` feature to use the Torch model."
|
|
|
- )
|
|
|
-
|
|
|
- self.model_path = model_path
|
|
|
- self.model = doclayout_yolo.YOLOv10(model_path)
|
|
|
-
|
|
|
- @staticmethod
|
|
|
- def from_pretrained(repo_id: str, filename: str):
|
|
|
- pth = hf_hub_download(repo_id=repo_id, filename=filename)
|
|
|
- return TorchModel(pth)
|
|
|
-
|
|
|
- @property
|
|
|
- def stride(self):
|
|
|
- return 32
|
|
|
-
|
|
|
- def predict(self, *args, **kwargs):
|
|
|
- return self.model.predict(*args, **kwargs)
|
|
|
-
|
|
|
-
|
|
|
class YoloResult:
|
|
|
"""Helper class to store detection results from ONNX model."""
|
|
|
|
|
|
@@ -98,17 +59,6 @@ class YoloBox:
|
|
|
|
|
|
class OnnxModel(DocLayoutModel):
|
|
|
def __init__(self, model_path: str):
|
|
|
- import ast
|
|
|
-
|
|
|
- try:
|
|
|
-
|
|
|
- import onnx
|
|
|
- import onnxruntime
|
|
|
- except ImportError:
|
|
|
- raise ImportError(
|
|
|
- "Please install the `onnx` feature to use the ONNX model."
|
|
|
- )
|
|
|
-
|
|
|
self.model_path = model_path
|
|
|
|
|
|
model = onnx.load(model_path)
|