Browse Source

feat: add batch translation with max_tokens support in OpenAITranslator

mrh (aider) 1 year ago
parent
commit
0305b4bd5a
1 changed file with 59 additions and 3 deletions
  1. 59 3
      mylib/pdfzh_translator.py

+ 59 - 3
mylib/pdfzh_translator.py

@@ -56,11 +56,12 @@ class GoogleTranslator(BaseTranslator):
 
 
 
 
 class OpenAITranslator(BaseTranslator):
 class OpenAITranslator(BaseTranslator):
-    def __init__(self, service, lang_out, lang_in, model):
+    def __init__(self, service, lang_out, lang_in, model, max_tokens=2000):
         lang_out = "zh-CN" if lang_out == "auto" else lang_out
         lang_out = "zh-CN" if lang_out == "auto" else lang_out
         lang_in = "en" if lang_in == "auto" else lang_in
         lang_in = "en" if lang_in == "auto" else lang_in
         super().__init__(service, lang_out, lang_in, model)
         super().__init__(service, lang_out, lang_in, model)
         self.options = {"temperature": 0}  # 随机采样可能会打断公式标记
         self.options = {"temperature": 0}  # 随机采样可能会打断公式标记
+        self.max_tokens = max_tokens
         # Configure OpenAI client with environment variables
         # Configure OpenAI client with environment variables
         self.client = openai.OpenAI(
         self.client = openai.OpenAI(
             api_key=os.getenv('OPENAI_API_KEY'),
             api_key=os.getenv('OPENAI_API_KEY'),
@@ -68,8 +69,13 @@ class OpenAITranslator(BaseTranslator):
         )
         )
 
 
     def translate(self, text) -> str:
     def translate(self, text) -> str:
+        if isinstance(text, list):
+            return self._batch_translate(text)
+        return self._single_translate(text)
+
+    def _single_translate(self, text) -> str:
         response = self.client.chat.completions.create(
         response = self.client.chat.completions.create(
-            model=os.getenv('LLM_MODEL', self.model),  # Use env var or fallback to default
+            model=os.getenv('LLM_MODEL', self.model),
             **self.options,
             **self.options,
             messages=[
             messages=[
                 {
                 {
@@ -78,12 +84,55 @@ class OpenAITranslator(BaseTranslator):
                 },
                 },
                 {
                 {
                     "role": "user",
                     "role": "user",
-                    "content": f"Translate the following markdown source text to {self.lang_out}. Keep the formula notation $v*$ unchanged. Output translation directly without any additional text.\nSource Text: {text}\nTranslated Text:",  # noqa: E501
+                    "content": f"Translate the following markdown source text to {self.lang_out}. Keep the formula notation $v*$ unchanged. Output translation directly without any additional text.\nSource Text: {text}\nTranslated Text:",
                 },
                 },
             ],
             ],
         )
         )
         return response.choices[0].message.content.strip()
         return response.choices[0].message.content.strip()
 
 
+    def _batch_translate(self, texts) -> list:
+        # 将文本列表转换为带索引的格式
+        indexed_texts = [f"{i}: {text}" for i, text in enumerate(texts)]
+        combined_text = "\n".join(indexed_texts)
+        
+        # 计算总token数并分块处理
+        total_length = len(combined_text)
+        if total_length > self.max_tokens:
+            # 如果超过最大token数,分成多个批次处理
+            batch_size = len(texts) // (total_length // self.max_tokens + 1)
+            results = []
+            for i in range(0, len(texts), batch_size):
+                batch = texts[i:i + batch_size]
+                results.extend(self._batch_translate(batch))
+            return results
+
+        response = self.client.chat.completions.create(
+            model=os.getenv('LLM_MODEL', self.model),
+            **self.options,
+            messages=[
+                {
+                    "role": "system",
+                    "content": "You are a professional,authentic machine translation engine.",
+                },
+                {
+                    "role": "user",
+                    "content": f"Translate the following list of texts to {self.lang_out}. Keep the formula notation $v*$ unchanged. Output translations in the same order with their original indexes. Each line should be in format 'index: translation'.\nSource Texts:\n{combined_text}\nTranslated Texts:",
+                },
+            ],
+        )
+        
+        # 解析返回结果并保持顺序
+        translated_lines = response.choices[0].message.content.strip().split("\n")
+        translations = [""] * len(texts)
+        for line in translated_lines:
+            try:
+                index, translation = line.split(":", 1)
+                translations[int(index)] = translation.strip()
+            except (ValueError, IndexError):
+                continue
+                
+        return translations
+
 if __name__ == "__main__":
 if __name__ == "__main__":
     # 测试翻译示例
     # 测试翻译示例
     translator = OpenAITranslator("openai", "zh-CN", "en", "gpt-3.5-turbo")
     translator = OpenAITranslator("openai", "zh-CN", "en", "gpt-3.5-turbo")
@@ -99,3 +148,10 @@ if __name__ == "__main__":
     translated_math = translator.translate(math_text)
     translated_math = translator.translate(math_text)
     print(f"\nOriginal with math: {math_text}")
     print(f"\nOriginal with math: {math_text}")
     print(f"Translated with math: {translated_math}")
     print(f"Translated with math: {translated_math}")
+    
+    # 测试批量翻译
+    batch_texts = ["apple", "banana", "orange", "grape"]
+    translated_batch = translator.translate(batch_texts)
+    print("\nBatch translation results:")
+    for original, translated in zip(batch_texts, translated_batch):
+        print(f"{original} -> {translated}")