Sfoglia il codice sorgente

只对智谱检查1301的问题。

hellofinch 1 anno fa
parent
commit
14c2508bf1
1 ha cambiato i file con 23 aggiunte e 9 eliminazioni
  1. 23 9
      pdf2zh/translator.py

+ 23 - 9
pdf2zh/translator.py

@@ -15,6 +15,8 @@ from tencentcloud.tmt.v20180321.tmt_client import TmtClient
 from tencentcloud.tmt.v20180321.models import TextTranslateRequest
 from tencentcloud.tmt.v20180321.models import TextTranslateResponse
 
+import json
+
 
 def remove_control_characters(s):
     return "".join(ch for ch in s if unicodedata.category(ch)[0] != "C")
@@ -212,15 +214,11 @@ class OpenAITranslator(BaseTranslator):
         self.client = openai.OpenAI(base_url=base_url, api_key=api_key)
 
     def translate(self, text) -> str:
-        try:
-            response = self.client.chat.completions.create(
-                model=self.model,
-                **self.options,
-                messages=self.prompt(text),
-            )
-        except openai.BadRequestError:
-            print("400 API BadRequestError")
-            return ""
+        response = self.client.chat.completions.create(
+            model=self.model,
+            **self.options,
+            messages=self.prompt(text),
+        )
         return response.choices[0].message.content.strip()
 
 
@@ -271,6 +269,22 @@ class ZhipuTranslator(OpenAITranslator):
             model = os.getenv("ZHIPU_MODEL", self.envs["ZHIPU_MODEL"])
         super().__init__(lang_in, lang_out, model, base_url=base_url, api_key=api_key)
 
+    def translate(self, text) -> str:
+        try:
+            response = self.client.chat.completions.create(
+                model=self.model,
+                **self.options,
+                messages=self.prompt(text),
+            )
+        except openai.APIError:
+            if (
+                json.loads(response.choices[0].message.content.strip())["error"]["code"]
+                == "1301"
+            ):
+                return ""
+            print("openai api error.")
+        return response.choices[0].message.content.strip()
+
 
 class SiliconTranslator(OpenAITranslator):
     # https://docs.siliconflow.cn/quickstart