Sfoglia il codice sorgente

Merge pull request #397 from timelic/fix/circle-import-model

Byaidu 1 anno fa
parent
commit
3c2b7e3d32
4 ha cambiato i file con 15 aggiunte e 13 eliminazioni
  1. 2 2
      pdf2zh/backend.py
  2. 5 0
      pdf2zh/doclayout.py
  3. 2 2
      pdf2zh/gui.py
  4. 6 9
      pdf2zh/pdf2zh.py

+ 2 - 2
pdf2zh/backend.py

@@ -6,7 +6,7 @@ from pdf2zh import translate_stream
 import tqdm
 import json
 import io
-from pdf2zh.pdf2zh import model
+from pdf2zh.doclayout import ModelInstance
 
 flask_app = Flask("pdf2zh")
 flask_app.config.from_mapping(
@@ -48,7 +48,7 @@ def translate_task(
     doc_mono, doc_dual = translate_stream(
         stream,
         callback=progress_bar,
-        model=model,
+        model=ModelInstance.value,
         **args,
     )
     return doc_mono, doc_dual

+ 5 - 0
pdf2zh/doclayout.py

@@ -60,6 +60,7 @@ class YoloBox:
 
 
 class OnnxModel(DocLayoutModel):
+
     def __init__(self, model_path: str):
         self.model_path = model_path
 
@@ -173,3 +174,7 @@ class OnnxModel(DocLayoutModel):
             (new_h, new_w), preds[..., :4], (orig_h, orig_w)
         )
         return [YoloResult(boxes=preds, names=self._names)]
+
+
+class ModelInstance:
+    value: OnnxModel = None

+ 2 - 2
pdf2zh/gui.py

@@ -13,7 +13,7 @@ from gradio_pdf import PDF
 
 from pdf2zh import __version__
 from pdf2zh.high_level import translate
-from pdf2zh.pdf2zh import model
+from pdf2zh.doclayout import ModelInstance
 from pdf2zh.translator import (
     AnythingLLMTranslator,
     AzureOpenAITranslator,
@@ -274,7 +274,7 @@ def translate_file(
         "cancellation_event": cancellation_event_map[session_id],
         "envs": _envs,
         "prompt": prompt,
-        "model": model,
+        "model": ModelInstance.value,
     }
     try:
         translate(**param)

+ 6 - 9
pdf2zh/pdf2zh.py

@@ -13,7 +13,7 @@ from typing import List, Optional
 
 from pdf2zh import __version__, log
 from pdf2zh.high_level import translate
-from pdf2zh.doclayout import OnnxModel
+from pdf2zh.doclayout import OnnxModel, ModelInstance
 import os
 
 
@@ -199,9 +199,6 @@ def find_all_files_in_directory(directory_path):
     return file_paths
 
 
-model = None
-
-
 def main(args: Optional[List[str]] = None) -> int:
     logging.basicConfig()
 
@@ -209,11 +206,11 @@ def main(args: Optional[List[str]] = None) -> int:
 
     if parsed_args.debug:
         log.setLevel(logging.DEBUG)
-    global model
+
     if parsed_args.onnx:
-        model = OnnxModel(parsed_args.onnx)
+        ModelInstance.value = OnnxModel(parsed_args.onnx)
     else:
-        model = OnnxModel.load_available()
+        ModelInstance.value = OnnxModel.load_available()
 
     if parsed_args.interactive:
         from pdf2zh.gui import setup_gui
@@ -250,10 +247,10 @@ def main(args: Optional[List[str]] = None) -> int:
         untranlate_file = find_all_files_in_directory(parsed_args.files[0])
         parsed_args.files = untranlate_file
         print(parsed_args)
-        translate(model=model, **vars(parsed_args))
+        translate(model=ModelInstance.value, **vars(parsed_args))
         return 0
     # print(parsed_args)
-    translate(model=model, **vars(parsed_args))
+    translate(model=ModelInstance.value, **vars(parsed_args))
     return 0