Byaidu 1 год назад
Родитель
Сommit
0b7a789f5f
3 измененных файлов с 20 добавлено и 14 удалено
  1. 3 1
      pdf2zh/gui.py
  2. 2 1
      pdf2zh/high_level.py
  3. 15 12
      pdf2zh/translator.py

+ 3 - 1
pdf2zh/gui.py

@@ -105,6 +105,7 @@ def download_with_limit(url, save_path, size_limit):
                 file.write(chunk)
     return save_path / filename
 
+
 def stop_translate_file(state):
     session_id = state["session_id"]
     if session_id is None:
@@ -112,6 +113,7 @@ def stop_translate_file(state):
     if session_id in cancellation_event_map:
         cancellation_event_map[session_id].set()
 
+
 def translate_file(
     file_type,
     file_input,
@@ -182,7 +184,7 @@ def translate_file(
     print(param)
     try:
         translate(**param)
-    except CancelledError as e:
+    except CancelledError:
         del cancellation_event_map[session_id]
         raise gr.Error("Translation cancelled")
     print(f"Files after translation: {os.listdir(output)}")

+ 2 - 1
pdf2zh/high_level.py

@@ -1,4 +1,5 @@
 """Functions that can be used for the most common use-cases for pdf2zh.six"""
+
 import asyncio
 from asyncio import CancelledError
 from typing import BinaryIO
@@ -85,7 +86,7 @@ def translate_patch(
     resfont: str = "",
     noto: Font = None,
     callback: object = None,
-    cancellation_event : asyncio.Event = None,
+    cancellation_event: asyncio.Event = None,
     **kwarg: Any,
 ) -> None:
     rsrcmgr = PDFResourceManager()

+ 15 - 12
pdf2zh/translator.py

@@ -274,13 +274,13 @@ class ZhipuTranslator(OpenAITranslator):
                 **self.options,
                 messages=self.prompt(text),
             )
-        except openai.APIError:
+        except openai.BadRequestError as e:
             if (
                 json.loads(response.choices[0].message.content.strip())["error"]["code"]
                 == "1301"
             ):
-                return ""
-            raise ValueError("openai api error.")
+                return "IRREPARABLE TRANSLATION ERROR"
+            raise e
         return response.choices[0].message.content.strip()
 
 
@@ -368,6 +368,8 @@ class TencentTranslator(BaseTranslator):
         self.req.SourceText = text
         resp: TextTranslateResponse = self.client.TextTranslate(self.req)
         return resp.TargetText
+
+
 class AnythingLLMTranslator(BaseTranslator):
     name = "anythingllm"
     envs = {
@@ -393,18 +395,21 @@ class AnythingLLMTranslator(BaseTranslator):
             "sessionId": "translation_expert",
         }
 
-        response = requests.post(self.api_url, headers=self.headers, data=json.dumps(payload))
+        response = requests.post(
+            self.api_url, headers=self.headers, data=json.dumps(payload)
+        )
         response.raise_for_status()
         data = response.json()
 
         if "textResponse" in data:
             return data["textResponse"].strip()
 
+
 class DifyTranslator(BaseTranslator):
     name = "dify"
     envs = {
         "DIFY_API_URL": None,  # 填写实际 Dify API 地址
-        "DIFY_API_KEY": "api_key"  # 替换为实际 API 密钥
+        "DIFY_API_KEY": "api_key",  # 替换为实际 API 密钥
     }
 
     def __init__(self, lang_out, lang_in, model):
@@ -415,27 +420,25 @@ class DifyTranslator(BaseTranslator):
     def translate(self, text):
         headers = {
             "Authorization": f"Bearer {self.api_key}",
-            "Content-Type": "application/json"
+            "Content-Type": "application/json",
         }
 
         payload = {
             "inputs": {
                 "lang_out": self.lang_out,
                 "lang_in": self.lang_in,
-                "text": text
+                "text": text,
             },
             "response_mode": "blocking",
-            "user": "translator-service"
+            "user": "translator-service",
         }
 
         # 向 Dify 服务器发送请求
         response = requests.post(
-            self.api_url,
-            headers=headers,
-            data=json.dumps(payload)
+            self.api_url, headers=headers, data=json.dumps(payload)
         )
         response.raise_for_status()
         response_data = response.json()
 
         # 解析响应
-        return response_data.get('data', {}).get('outputs', {}).get('text', [])
+        return response_data.get("data", {}).get("outputs", {}).get("text", [])