Explorar o código

Merge pull request #240 from charles7668/feat/cancel-support

Add cancel support for gui
Byaidu hai 1 ano
pai
achega
9c203c0050
Modificáronse 2 ficheiros con 37 adicións e 2 borrados
  1. 30 1
      pdf2zh/gui.py
  2. 7 1
      pdf2zh/high_level.py

+ 30 - 1
pdf2zh/gui.py

@@ -1,5 +1,8 @@
 import os
 import shutil
+import uuid
+import asyncio
+from asyncio import CancelledError
 from pathlib import Path
 from pdf2zh import __version__
 from pdf2zh.high_level import translate
@@ -102,6 +105,12 @@ def download_with_limit(url, save_path, size_limit):
                 file.write(chunk)
     return save_path / filename
 
+def stop_translate_file(state):
+    session_id = state["session_id"]
+    if session_id is None:
+        return
+    if session_id in cancellation_event_map:
+        cancellation_event_map[session_id].set()
 
 def translate_file(
     file_type,
@@ -112,9 +121,13 @@ def translate_file(
     lang_to,
     page_range,
     recaptcha_response,
+    state,
     progress=gr.Progress(),
     *envs,
 ):
+    session_id = uuid.uuid4()
+    state["session_id"] = session_id
+    cancellation_event_map[session_id] = asyncio.Event()
     """Translate PDF content using selected service."""
     if flag_demo and not verify_recaptcha(recaptcha_response):
         raise gr.Error("reCAPTCHA fail")
@@ -164,9 +177,14 @@ def translate_file(
         "output": output,
         "thread": 4,
         "callback": progress_bar,
+        "cancellation_event": cancellation_event_map[session_id],
     }
     print(param)
-    translate(**param)
+    try:
+        translate(**param)
+    except CancelledError as e:
+        del cancellation_event_map[session_id]
+        raise gr.Error("Translation cancelled")
     print(f"Files after translation: {os.listdir(output)}")
 
     if not file_mono.exists() or not file_dual.exists():
@@ -199,6 +217,8 @@ custom_blue = gr.themes.Color(
     c950="#020B33",
 )
 
+cancellation_event_map = {}
+
 with gr.Blocks(
     title="PDFMathTranslate - PDF Translation with preserved formats",
     theme=gr.themes.Default(
@@ -322,6 +342,7 @@ with gr.Blocks(
             )
             recaptcha_box = gr.HTML('<div id="recaptcha-box"></div>')
             translate_btn = gr.Button("Translate", variant="primary")
+            cancellation_btn = gr.Button("Cancel", variant="secondary")
             tech_details_tog = gr.Markdown(
                 f"""
                     <summary>Technical details</summary>
@@ -383,6 +404,8 @@ with gr.Blocks(
         ),
     )
 
+    state = gr.State({"session_id": None})
+
     translate_btn.click(
         translate_file,
         inputs=[
@@ -394,6 +417,7 @@ with gr.Blocks(
             lang_to,
             page_range,
             recaptcha_response,
+            state,
             *envs,
         ],
         outputs=[
@@ -406,6 +430,11 @@ with gr.Blocks(
         ],
     ).then(lambda: None, js="()=>{grecaptcha.reset()}" if flag_demo else "")
 
+    cancellation_btn.click(
+        stop_translate_file,
+        inputs=[state],
+    )
+
 
 def setup_gui(share=False):
     if flag_demo:

+ 7 - 1
pdf2zh/high_level.py

@@ -1,5 +1,6 @@
 """Functions that can be used for the most common use-cases for pdf2zh.six"""
-
+import asyncio
+from asyncio import CancelledError
 from typing import BinaryIO
 import numpy as np
 import tqdm
@@ -84,6 +85,7 @@ def translate_patch(
     resfont: str = "",
     noto: Font = None,
     callback: object = None,
+    cancellation_event : asyncio.Event = None,
     **kwarg: Any,
 ) -> None:
     rsrcmgr = PDFResourceManager()
@@ -104,6 +106,8 @@ def translate_patch(
     doc = PDFDocument(parser)
     with tqdm.tqdm(total=total_pages) as progress:
         for pageno, page in enumerate(PDFPage.create_pages(doc)):
+            if cancellation_event and cancellation_event.is_set():
+                raise CancelledError("task cancelled")
             if pages and (pageno not in pages):
                 continue
             progress.update()
@@ -161,6 +165,7 @@ def translate_stream(
     vfont: str = "",
     vchar: str = "",
     callback: object = None,
+    cancellation_event: asyncio.Event = None,
     **kwarg: Any,
 ):
     font_list = [("tiro", None)]
@@ -237,6 +242,7 @@ def translate(
     vfont: str = "",
     vchar: str = "",
     callback: object = None,
+    cancellation_event: asyncio.Event = None,
     **kwarg: Any,
 ):
     if not files: