hellofinch hace 1 año
padre
commit
2b3734e4cc
Se han modificado 4 ficheros con 110 adiciones y 55 borrados
  1. 106 53
      pdf2zh/config.py
  2. 1 1
      pdf2zh/gui.py
  3. 2 0
      pdf2zh/high_level.py
  4. 1 1
      pdf2zh/translator.py

+ 106 - 53
pdf2zh/config.py

@@ -1,15 +1,17 @@
 import json
 from pathlib import Path
-from threading import Lock
+from threading import RLock  # 改成 RLock
 import os
+import copy
 
 class ConfigManager:
     _instance = None
-    _lock = Lock()  # 用于线程安全
+    _lock = RLock()  # 用 RLock 替换 Lock,允许在同一个线程中重复获取锁
 
     @classmethod
     def get_instance(cls):
         """获取单例实例"""
+        # 先判断是否存在实例,如果不存在再加锁进行初始化
         if cls._instance is None:
             with cls._lock:
                 if cls._instance is None:
@@ -17,15 +19,21 @@ class ConfigManager:
         return cls._instance
 
     def __init__(self):
+        # 防止重复初始化
         if hasattr(self, "_initialized") and self._initialized:
-            return  # 防止重复初始化
+            return
         self._initialized = True
+
         self._config_path = Path.home() / ".config" / "PDFMathTranslate" / "config.json"
         self._config_data = {}
+
+        # 这里不要再加锁,因为外层可能已经加了锁 (get_instance), RLock也无妨
         self._ensure_config_exists()
 
     def _ensure_config_exists(self, isInit=True):
         """确保配置文件存在,如果不存在则创建默认配置"""
+        # 这里也不需要显式再次加锁,原因同上,方法体中再调用 _load_config(),
+        # 而 _load_config() 内部会加锁。因为 RLock 是可重入的,不会阻塞。
         if not self._config_path.exists():
             if isInit:
                 self._config_path.parent.mkdir(parents=True, exist_ok=True)
@@ -37,14 +45,35 @@ class ConfigManager:
             self._load_config()
 
     def _load_config(self):
-        """从config.json中加载配置"""
-        with self._config_path.open("r", encoding="utf-8") as f:
-            self._config_data = json.load(f)
+        """从 config.json 中加载配置"""
+        with self._lock:  # 加锁确保线程安全
+            with self._config_path.open("r", encoding="utf-8") as f:
+                self._config_data = json.load(f)
 
     def _save_config(self):
-        """保存配置到config.json"""
-        with self._config_path.open("w", encoding="utf-8") as f:
-            json.dump(self._config_data, f, indent=4, ensure_ascii=False)
+        """保存配置到 config.json"""
+        with self._lock:  # 加锁确保线程安全
+            # 移除循环引用并写入
+            cleaned_data = self._remove_circular_references(self._config_data)
+            with self._config_path.open("w", encoding="utf-8") as f:
+                json.dump(cleaned_data, f, indent=4, ensure_ascii=False)
+
+    def _remove_circular_references(self, obj, seen=None):
+        """递归移除循环引用"""
+        if seen is None:
+            seen = set()
+        obj_id = id(obj)
+        if obj_id in seen:
+            return None  # 遇到已处理过的对象,视为循环引用
+        seen.add(obj_id)
+
+        if isinstance(obj, dict):
+            return {
+                k: self._remove_circular_references(v, seen) for k, v in obj.items()
+            }
+        elif isinstance(obj, list):
+            return [self._remove_circular_references(i, seen) for i in obj]
+        return obj
 
     @classmethod
     def custome_config(cls, file_path):
@@ -52,10 +81,11 @@ class ConfigManager:
         custom_path = Path(file_path)
         if not custom_path.exists():
             raise ValueError(f"Config file {custom_path} not found!")
-        # 销毁现有的实例并重新初始化
+        # 加锁
         with cls._lock:
             instance = cls()
             instance._config_path = custom_path
+            # 此处传 isInit=False,若不存在则报错;若存在则正常 _load_config()
             instance._ensure_config_exists(isInit=False)
             cls._instance = instance
 
@@ -63,29 +93,35 @@ class ConfigManager:
     def get(cls, key, default=None):
         """获取配置值"""
         instance = cls.get_instance()
-        ret = instance._config_data.get(key)
-        if not ret:
-            env_get = os.environ.get(key)
-            if not env_get:
-                if not default:
-                    raise ValueError(f"{key} is not found in environment or config file.")
-                else:
-                    instance._config_data[key] = default
-                    instance._save_config()
-                    return default
-            else:
-                instance._config_data[key] = env_get
-                instance._save_config()
-                return env_get
-        else:
-            return ret
+        # 读取时,加锁或不加锁都行。但为了统一,我们在修改配置前后都要加锁。
+        # get 只要最终需要保存,则会加锁 -> _save_config()
+        if key in instance._config_data:
+            return instance._config_data[key]
+
+        # 若环境变量中存在该 key,则使用环境变量并写回 config
+        if key in os.environ:
+            value = os.environ[key]
+            instance._config_data[key] = value
+            instance._save_config()
+            return value
+
+        # 若 default 不为 None,则设置并保存
+        if default is not None:
+            instance._config_data[key] = default
+            instance._save_config()
+            return default
+
+        # 找不到则抛出异常
+        # raise KeyError(f"{key} is not found in config file or environment variables.")
+        return default
 
     @classmethod
     def set(cls, key, value):
         """设置配置值并保存"""
         instance = cls.get_instance()
-        instance._config_data[key] = value
-        instance._save_config()
+        with instance._lock:
+            instance._config_data[key] = value
+            instance._save_config()
 
     @classmethod
     def get_translator_by_name(cls, name):
@@ -94,46 +130,63 @@ class ConfigManager:
         translators = instance._config_data.get("translators", [])
         for translator in translators:
             if translator.get("name") == name:
-                return translator
+                return translator["envs"]
         return None
-    
+
     @classmethod
     def set_translator_by_name(cls, name, new_translator_envs):
         """根据 name 设置或更新 translator 配置"""
         instance = cls.get_instance()
+        with instance._lock:
+            translators = instance._config_data.get("translators", [])
+            for translator in translators:
+                if translator.get("name") == name:
+                    translator["envs"] = copy.deepcopy(new_translator_envs)
+                    instance._save_config()
+                    return
+            translators.append({"name": name, "envs": copy.deepcopy(new_translator_envs)})
+            instance._config_data["translators"] = translators
+            instance._save_config()
+
+    @classmethod
+    def get_env_by_translatername(cls, translater_name, name, default=None):
+        """根据 name 获取对应的 translator 配置"""
+        instance = cls.get_instance()
         translators = instance._config_data.get("translators", [])
-        
         for translator in translators:
-            if translator.get("name") == name:
-                translator.update({"envs": new_translator_envs})
-                instance._save_config()
-                return
-        
-        # 如果未找到匹配的 name,则添加新的 translator
-        translators.append({"name": name, "envs": new_translator_envs})
-        instance._config_data["translators"] = translators
-        instance._save_config()
-
+            if translator.get("name") == translater_name.name:
+                if translator["envs"][name]:
+                    return translator["envs"][name]
+                else:
+                    with instance._lock:
+                        translator["envs"][name] = default
+                        instance._save_config()
+                        return default
+                    
+        with instance._lock:
+            translators = instance._config_data.get("translators", [])
+            for translator in translators:
+                if translator.get("name") == translater_name.name:
+                    translator["envs"][name] = default
+                    instance._save_config()
+                    return default
+            translators.append({"name": translater_name.name, "envs": copy.deepcopy(translater_name.envs)})
+            instance._config_data["translators"] = translators
+            instance._save_config()
+            return default
 
     @classmethod
     def delete(cls, key):
         """删除配置值并保存"""
         instance = cls.get_instance()
-        if key in instance._config_data:
-            del instance._config_data[key]
-            instance._save_config()
+        with instance._lock:
+            if key in instance._config_data:
+                del instance._config_data[key]
+                instance._save_config()
 
     @classmethod
     def all(cls):
         """返回所有配置项"""
         instance = cls.get_instance()
+        # 这里只做读取操作,一般可不加锁。不过为了保险也可以加锁。
         return instance._config_data
-
-# 使用示例
-# 默认路径加载
-# ConfigManager.set("username", "admin")
-# print(ConfigManager.get("username"))
-
-# 自定义路径加载
-# ConfigManager.custome_config("/path/to/custom_config.json")
-# print(ConfigManager.get("custom_key"))

+ 1 - 1
pdf2zh/gui.py

@@ -448,7 +448,7 @@ with gr.Blocks(
                     _envs.append(gr.update(visible=False, value=""))
                 for i, env in enumerate(translator.envs.items()):
                     _envs[i] = gr.update(
-                        visible=True, label=env[0], value=os.getenv(env[0], env[1])
+                        visible=True, label=env[0], value=ConfigManager.get_env_by_translatername(translator,env[0],env[1])
                     )
                 _envs[-1] = gr.update(visible=translator.CustomPrompt)
                 return _envs

+ 2 - 0
pdf2zh/high_level.py

@@ -24,6 +24,8 @@ from pdf2zh.converter import TranslateConverter
 from pdf2zh.doclayout import OnnxModel
 from pdf2zh.pdfinterp import PDFPageInterpreterEx
 
+from pdf2zh.config import ConfigManager
+
 NOTO_NAME = "noto"
 
 noto_list = [

+ 1 - 1
pdf2zh/translator.py

@@ -54,7 +54,7 @@ class BaseTranslator:
         # Cannot use self.envs = copy(self.__class__.envs)
         # because if set_envs called twice, the second call will override the first call
         self.envs = copy(self.envs)
-        if not ConfigManager.get_translator_by_name(self.name):
+        if ConfigManager.get_translator_by_name(self.name):
             self.envs = ConfigManager.get_translator_by_name(self.name)
         needUpdate=False
         for key in self.envs: