|
@@ -56,11 +56,12 @@ class GoogleTranslator(BaseTranslator):
|
|
|
|
|
|
|
|
|
|
|
|
|
class OpenAITranslator(BaseTranslator):
|
|
class OpenAITranslator(BaseTranslator):
|
|
|
- def __init__(self, service, lang_out, lang_in, model):
|
|
|
|
|
|
|
+ def __init__(self, service, lang_out, lang_in, model, max_tokens=2000):
|
|
|
lang_out = "zh-CN" if lang_out == "auto" else lang_out
|
|
lang_out = "zh-CN" if lang_out == "auto" else lang_out
|
|
|
lang_in = "en" if lang_in == "auto" else lang_in
|
|
lang_in = "en" if lang_in == "auto" else lang_in
|
|
|
super().__init__(service, lang_out, lang_in, model)
|
|
super().__init__(service, lang_out, lang_in, model)
|
|
|
self.options = {"temperature": 0} # 随机采样可能会打断公式标记
|
|
self.options = {"temperature": 0} # 随机采样可能会打断公式标记
|
|
|
|
|
+ self.max_tokens = max_tokens
|
|
|
# Configure OpenAI client with environment variables
|
|
# Configure OpenAI client with environment variables
|
|
|
self.client = openai.OpenAI(
|
|
self.client = openai.OpenAI(
|
|
|
api_key=os.getenv('OPENAI_API_KEY'),
|
|
api_key=os.getenv('OPENAI_API_KEY'),
|
|
@@ -68,8 +69,13 @@ class OpenAITranslator(BaseTranslator):
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
def translate(self, text) -> str:
|
|
def translate(self, text) -> str:
|
|
|
|
|
+ if isinstance(text, list):
|
|
|
|
|
+ return self._batch_translate(text)
|
|
|
|
|
+ return self._single_translate(text)
|
|
|
|
|
+
|
|
|
|
|
+ def _single_translate(self, text) -> str:
|
|
|
response = self.client.chat.completions.create(
|
|
response = self.client.chat.completions.create(
|
|
|
- model=os.getenv('LLM_MODEL', self.model), # Use env var or fallback to default
|
|
|
|
|
|
|
+ model=os.getenv('LLM_MODEL', self.model),
|
|
|
**self.options,
|
|
**self.options,
|
|
|
messages=[
|
|
messages=[
|
|
|
{
|
|
{
|
|
@@ -78,12 +84,55 @@ class OpenAITranslator(BaseTranslator):
|
|
|
},
|
|
},
|
|
|
{
|
|
{
|
|
|
"role": "user",
|
|
"role": "user",
|
|
|
- "content": f"Translate the following markdown source text to {self.lang_out}. Keep the formula notation $v*$ unchanged. Output translation directly without any additional text.\nSource Text: {text}\nTranslated Text:", # noqa: E501
|
|
|
|
|
|
|
+ "content": f"Translate the following markdown source text to {self.lang_out}. Keep the formula notation $v*$ unchanged. Output translation directly without any additional text.\nSource Text: {text}\nTranslated Text:",
|
|
|
},
|
|
},
|
|
|
],
|
|
],
|
|
|
)
|
|
)
|
|
|
return response.choices[0].message.content.strip()
|
|
return response.choices[0].message.content.strip()
|
|
|
|
|
|
|
|
|
|
+ def _batch_translate(self, texts) -> list:
|
|
|
|
|
+ # 将文本列表转换为带索引的格式
|
|
|
|
|
+ indexed_texts = [f"{i}: {text}" for i, text in enumerate(texts)]
|
|
|
|
|
+ combined_text = "\n".join(indexed_texts)
|
|
|
|
|
+
|
|
|
|
|
+ # 计算总token数并分块处理
|
|
|
|
|
+ total_length = len(combined_text)
|
|
|
|
|
+ if total_length > self.max_tokens:
|
|
|
|
|
+ # 如果超过最大token数,分成多个批次处理
|
|
|
|
|
+ batch_size = len(texts) // (total_length // self.max_tokens + 1)
|
|
|
|
|
+ results = []
|
|
|
|
|
+ for i in range(0, len(texts), batch_size):
|
|
|
|
|
+ batch = texts[i:i + batch_size]
|
|
|
|
|
+ results.extend(self._batch_translate(batch))
|
|
|
|
|
+ return results
|
|
|
|
|
+
|
|
|
|
|
+ response = self.client.chat.completions.create(
|
|
|
|
|
+ model=os.getenv('LLM_MODEL', self.model),
|
|
|
|
|
+ **self.options,
|
|
|
|
|
+ messages=[
|
|
|
|
|
+ {
|
|
|
|
|
+ "role": "system",
|
|
|
|
|
+ "content": "You are a professional,authentic machine translation engine.",
|
|
|
|
|
+ },
|
|
|
|
|
+ {
|
|
|
|
|
+ "role": "user",
|
|
|
|
|
+ "content": f"Translate the following list of texts to {self.lang_out}. Keep the formula notation $v*$ unchanged. Output translations in the same order with their original indexes. Each line should be in format 'index: translation'.\nSource Texts:\n{combined_text}\nTranslated Texts:",
|
|
|
|
|
+ },
|
|
|
|
|
+ ],
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ # 解析返回结果并保持顺序
|
|
|
|
|
+ translated_lines = response.choices[0].message.content.strip().split("\n")
|
|
|
|
|
+ translations = [""] * len(texts)
|
|
|
|
|
+ for line in translated_lines:
|
|
|
|
|
+ try:
|
|
|
|
|
+ index, translation = line.split(":", 1)
|
|
|
|
|
+ translations[int(index)] = translation.strip()
|
|
|
|
|
+ except (ValueError, IndexError):
|
|
|
|
|
+ continue
|
|
|
|
|
+
|
|
|
|
|
+ return translations
|
|
|
|
|
+
|
|
|
if __name__ == "__main__":
|
|
if __name__ == "__main__":
|
|
|
# 测试翻译示例
|
|
# 测试翻译示例
|
|
|
translator = OpenAITranslator("openai", "zh-CN", "en", "gpt-3.5-turbo")
|
|
translator = OpenAITranslator("openai", "zh-CN", "en", "gpt-3.5-turbo")
|
|
@@ -99,3 +148,10 @@ if __name__ == "__main__":
|
|
|
translated_math = translator.translate(math_text)
|
|
translated_math = translator.translate(math_text)
|
|
|
print(f"\nOriginal with math: {math_text}")
|
|
print(f"\nOriginal with math: {math_text}")
|
|
|
print(f"Translated with math: {translated_math}")
|
|
print(f"Translated with math: {translated_math}")
|
|
|
|
|
+
|
|
|
|
|
+ # 测试批量翻译
|
|
|
|
|
+ batch_texts = ["apple", "banana", "orange", "grape"]
|
|
|
|
|
+ translated_batch = translator.translate(batch_texts)
|
|
|
|
|
+ print("\nBatch translation results:")
|
|
|
|
|
+ for original, translated in zip(batch_texts, translated_batch):
|
|
|
|
|
+ print(f"{original} -> {translated}")
|