Browse Source

feat (main) : switch os.environ to ConfigManager.

hellofinch 1 year ago
parent
commit
e1d32cc09a
6 changed files with 32 additions and 13 deletions
  1. 3 3
      pdf2zh/backend.py
  2. 17 1
      pdf2zh/config.py
  3. 3 1
      pdf2zh/doclayout.py
  4. 6 5
      pdf2zh/gui.py
  5. 1 1
      pdf2zh/high_level.py
  6. 2 2
      pdf2zh/translator.py

+ 3 - 3
pdf2zh/backend.py

@@ -1,4 +1,3 @@
-import os
 from flask import Flask, request, send_file
 from celery import Celery, Task
 from celery.result import AsyncResult
@@ -7,12 +6,13 @@ import tqdm
 import json
 import io
 from pdf2zh.doclayout import ModelInstance
+from pdf2zh.config import ConfigManager
 
 flask_app = Flask("pdf2zh")
 flask_app.config.from_mapping(
     CELERY=dict(
-        broker_url=os.environ.get("CELERY_BROKER", "redis://127.0.0.1:6379/0"),
-        result_backend=os.environ.get("CELERY_RESULT", "redis://127.0.0.1:6379/0"),
+        broker_url=ConfigManager.get("CELERY_BROKER", "redis://127.0.0.1:6379/0"),
+        result_backend=ConfigManager.get("CELERY_RESULT", "redis://127.0.0.1:6379/0"),
     )
 )
 

+ 17 - 1
pdf2zh/config.py

@@ -1,6 +1,7 @@
 import json
 from pathlib import Path
 from threading import Lock
+import os
 
 class ConfigManager:
     _instance = None
@@ -62,7 +63,22 @@ class ConfigManager:
     def get(cls, key, default=None):
         """获取配置值"""
         instance = cls.get_instance()
-        return instance._config_data.get(key, default)
+        ret = instance._config_data.get(key)
+        if not ret:
+            env_get = os.environ.get(key)
+            if not env_get:
+                if not default:
+                    raise ValueError(f"{key} is not found in environment or config file.")
+                else:
+                    instance._config_data[key] = default
+                    instance._save_config()
+                    return default
+            else:
+                instance._config_data[key] = env_get
+                instance._save_config()
+                return env_get
+        else:
+            return ret
 
     @classmethod
     def set(cls, key, value):

+ 3 - 1
pdf2zh/doclayout.py

@@ -8,6 +8,8 @@ import onnx
 import onnxruntime
 from huggingface_hub import hf_hub_download
 
+from pdf2zh.config import ConfigManager
+
 
 class DocLayoutModel(abc.ABC):
     @staticmethod
@@ -73,7 +75,7 @@ class OnnxModel(DocLayoutModel):
 
     @staticmethod
     def from_pretrained(repo_id: str, filename: str):
-        if os.environ.get("USE_MODELSCOPE", "0") == "1":
+        if ConfigManager.get("USE_MODELSCOPE", "0") == "1":
             repo_mapping = {
                 # Edit here to add more models
                 "wybxc/DocLayout-YOLO-DocStructBench-onnx": "AI-ModelScope/DocLayout-YOLO-DocStructBench-onnx"

+ 6 - 5
pdf2zh/gui.py

@@ -14,6 +14,7 @@ from gradio_pdf import PDF
 from pdf2zh import __version__
 from pdf2zh.high_level import translate
 from pdf2zh.doclayout import ModelInstance
+from pdf2zh.config import ConfigManager
 from pdf2zh.translator import (
     AnythingLLMTranslator,
     AzureOpenAITranslator,
@@ -90,7 +91,7 @@ page_map = {
 flag_demo = False
 
 # Limit resources
-if os.getenv("PDF2ZH_DEMO"):
+if ConfigManager.get("PDF2ZH_DEMO"):
     flag_demo = True
     service_map = {
         "Google": GoogleTranslator,
@@ -99,8 +100,8 @@ if os.getenv("PDF2ZH_DEMO"):
         "First": [0],
         "First 20 pages": list(range(0, 20)),
     }
-    client_key = os.getenv("PDF2ZH_CLIENT_KEY")
-    server_key = os.getenv("PDF2ZH_SERVER_KEY")
+    client_key = ConfigManager.get("PDF2ZH_CLIENT_KEY")
+    server_key = ConfigManager.get("PDF2ZH_SERVER_KEY")
 
 
 # Public demo control
@@ -411,12 +412,12 @@ with gr.Blocks(
                 lang_from = gr.Dropdown(
                     label="Translate from",
                     choices=lang_map.keys(),
-                    value=os.getenv("PDF2ZH_LANG_FROM", "English"),
+                    value=ConfigManager.get("PDF2ZH_LANG_FROM", "English"),
                 )
                 lang_to = gr.Dropdown(
                     label="Translate to",
                     choices=lang_map.keys(),
-                    value=os.getenv("PDF2ZH_LANG_TO", "Simplified Chinese"),
+                    value=ConfigManager.get("PDF2ZH_LANG_TO", "Simplified Chinese"),
                 )
             page_range = gr.Radio(
                 choices=page_map.keys(),

+ 1 - 1
pdf2zh/high_level.py

@@ -383,7 +383,7 @@ def download_remote_fonts(lang: str):
     font_name = LANG_NAME_MAP.get(lang, "GoNotoKurrent-Regular.ttf")
 
     # docker
-    font_path = os.environ.get("NOTO_FONT_PATH", Path("/app", font_name).as_posix())
+    font_path = ConfigManager.get("NOTO_FONT_PATH", Path("/app", font_name).as_posix())
     if not Path(font_path).exists():
         font_path = Path(tempfile.gettempdir(), font_name).as_posix()
     if not Path(font_path).exists():

+ 2 - 2
pdf2zh/translator.py

@@ -230,7 +230,7 @@ class DeepLXTranslator(BaseTranslator):
         super().__init__(lang_in, lang_out, model)
         self.endpoint = self.envs["DEEPLX_ENDPOINT"]
         self.session = requests.Session()
-        auth_key = os.getenv("DEEPLX_ACCESS_TOKEN", self.envs["DEEPLX_ACCESS_TOKEN"])
+        auth_key = self.envs["DEEPLX_ACCESS_TOKEN"]
         if auth_key:
             self.endpoint = f"{self.endpoint}?token={auth_key}"
 
@@ -551,7 +551,7 @@ class AzureTranslator(BaseTranslator):
         self.set_envs(envs)
         super().__init__(lang_in, lang_out, model)
         endpoint = self.envs["AZURE_ENDPOINT"]
-        api_key = os.getenv("AZURE_API_KEY")
+        api_key = self.envs["AZURE_API_KEY"]
         credential = AzureKeyCredential(api_key)
         self.client = TextTranslationClient(
             endpoint=endpoint, credential=credential, region="chinaeast2"