|
@@ -7,6 +7,7 @@ from copy import copy
|
|
|
import deepl
|
|
import deepl
|
|
|
import ollama
|
|
import ollama
|
|
|
import openai
|
|
import openai
|
|
|
|
|
+import xinference_client
|
|
|
import requests
|
|
import requests
|
|
|
from pdf2zh.cache import TranslationCache
|
|
from pdf2zh.cache import TranslationCache
|
|
|
from azure.ai.translation.text import TextTranslationClient
|
|
from azure.ai.translation.text import TextTranslationClient
|
|
@@ -278,6 +279,57 @@ class OllamaTranslator(BaseTranslator):
|
|
|
raise Exception("All models failed")
|
|
raise Exception("All models failed")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class XinferenceTranslator(BaseTranslator):
    """Translator backed by a locally served Xinference model.

    Talks to an Xinference REST endpoint and tries each model listed in
    ``self.model`` (``;``-separated) until one produces an acceptable
    translation.

    See https://github.com/xorbitsai/inference
    """

    name = "xinference"
    envs = {
        "XINFERENCE_HOST": "http://127.0.0.1:9997",
        "XINFERENCE_MODEL": "gemma-2-it",
    }
    CustomPrompt = True

    def __init__(self, lang_in, lang_out, model, envs=None, prompt=None):
        """Create the translator.

        :param lang_in: source language code.
        :param lang_out: target language code.
        :param model: model name(s); falsy values fall back to the
            ``XINFERENCE_MODEL`` env default. Multiple candidates may be
            given separated by ``;``.
        :param envs: optional env overrides passed to ``set_envs``.
        :param prompt: optional custom prompt template; when set it also
            becomes part of the cache key.
        """
        self.set_envs(envs)
        if not model:
            model = self.envs["XINFERENCE_MODEL"]
        super().__init__(lang_in, lang_out, model)
        # Greedy decoding: random sampling may break formula markers.
        self.options = {"temperature": 0}
        self.client = xinference_client.RESTfulClient(self.envs["XINFERENCE_HOST"])
        self.prompttext = prompt
        self.add_cache_impact_parameters("temperature", self.options["temperature"])
        if prompt:
            self.add_cache_impact_parameters("prompt", prompt)

    def do_translate(self, text):
        """Translate ``text``, trying each configured model in order.

        :raises Exception: when every candidate model fails; the last
            model's error is attached as the exception's ``__cause__``.
        """
        # Reject runaway generations: responses longer than this are
        # treated as a failure for the current model.
        maxlen = max(2000, len(text) * 5)
        last_error = None
        for model in self.model.split(";"):
            try:
                xf_model = self.client.get_model(model)
                # Flatten the prompt into a single user turn; assumes
                # self.prompt(...) returns at least two messages with a
                # "content" key — TODO confirm against BaseTranslator.
                xf_prompt = self.prompt(text, self.prompttext)
                xf_prompt = [
                    {
                        "role": "user",
                        "content": xf_prompt[0]["content"]
                        + "\n"
                        + xf_prompt[1]["content"],
                    }
                ]
                response = xf_model.chat(
                    generate_config=self.options,
                    messages=xf_prompt,
                )

                # Strip the Gemma end-of-turn marker from the completion.
                response = response["choices"][0]["message"]["content"].replace(
                    "<end_of_turn>", ""
                )
                if len(response) > maxlen:
                    raise Exception("Response too long")
                return response.strip()
            except Exception as e:
                # Best-effort: report and fall through to the next model.
                print(e)
                last_error = e
        # Chain the last failure so the root cause is not silently lost.
        raise Exception("All models failed") from last_error
|
|
|
class OpenAITranslator(BaseTranslator):
|
|
class OpenAITranslator(BaseTranslator):
|
|
|
# https://github.com/openai/openai-python
|
|
# https://github.com/openai/openai-python
|
|
|
name = "openai"
|
|
name = "openai"
|
|
@@ -303,7 +355,10 @@ class OpenAITranslator(BaseTranslator):
|
|
|
model = self.envs["OPENAI_MODEL"]
|
|
model = self.envs["OPENAI_MODEL"]
|
|
|
super().__init__(lang_in, lang_out, model)
|
|
super().__init__(lang_in, lang_out, model)
|
|
|
self.options = {"temperature": 0} # 随机采样可能会打断公式标记
|
|
self.options = {"temperature": 0} # 随机采样可能会打断公式标记
|
|
|
- self.client = openai.OpenAI(base_url=base_url, api_key=api_key)
|
|
|
|
|
|
|
+ self.client = openai.OpenAI(
|
|
|
|
|
+ base_url=base_url or self.envs["OPENAI_BASE_URL"],
|
|
|
|
|
+ api_key=api_key or self.envs["OPENAI_API_KEY"],
|
|
|
|
|
+ )
|
|
|
self.prompttext = prompt
|
|
self.prompttext = prompt
|
|
|
self.add_cache_impact_parameters("temperature", self.options["temperature"])
|
|
self.add_cache_impact_parameters("temperature", self.options["temperature"])
|
|
|
if prompt:
|
|
if prompt:
|