Browse Source

创建模板时,自动添加被渲染的变量

mrh 11 months ago
parent
commit
82b05601f3

+ 20 - 3
src/manager/template_manager.py

@@ -1,7 +1,7 @@
 from typing import Optional, List, Dict, Any
 from typing import Optional, List, Dict, Any
 from src.models.template_model import Template, TemplateType
 from src.models.template_model import Template, TemplateType
 from src.manager.core.db_mongo import BaseMongoManager
 from src.manager.core.db_mongo import BaseMongoManager
-from jinja2 import Environment, BaseLoader
+from jinja2 import Environment, BaseLoader, meta
 import json
 import json
 from bson import json_util
 from bson import json_util
 from utils.logu import get_logger
 from utils.logu import get_logger
@@ -23,12 +23,21 @@ class TemplateManager(BaseMongoManager):
                             description: str = None,
                             description: str = None,
                             collection_name: str = None) -> Template:
                             collection_name: str = None) -> Template:
         """创建新模板"""
         """创建新模板"""
+        # 解析模板中的变量
+        try:
+            parsed_content = self.env.parse(template_str)
+            variables = list(meta.find_undeclared_variables(parsed_content))
+        except Exception as e:
+            logger.error(f"Failed to parse template variables: {e}")
+            variables = []
+
         template = Template(
         template = Template(
             name=name,
             name=name,
             template_str=template_str,
             template_str=template_str,
             template_type=template_type,
             template_type=template_type,
             description=description,
             description=description,
-            collection_name=collection_name
+            collection_name=collection_name,
+            variables=variables
         )
         )
         await template.insert()
         await template.insert()
         return template
         return template
@@ -84,9 +93,17 @@ class TemplateManager(BaseMongoManager):
         template = await self.get_template(name)
         template = await self.get_template(name)
         if not template:
         if not template:
             return None
             return None
-            
+             
         if new_template_str:
         if new_template_str:
+            # 更新模板字符串时同时更新变量列表
+            try:
+                parsed_content = self.env.parse(new_template_str)
+                variables = list(meta.find_undeclared_variables(parsed_content))
+                template.variables = variables
+            except Exception as e:
+                logger.error(f"Failed to parse template variables: {e}")
             template.template_str = new_template_str
             template.template_str = new_template_str
+            
         if new_description:
         if new_description:
             template.description = new_description
             template.description = new_description
         if new_collection_name is not None:
         if new_collection_name is not None:

+ 2 - 1
src/models/template_model.py

@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, List
 from beanie import Document
 from beanie import Document
 from pydantic import BaseModel, Field
 from pydantic import BaseModel, Field
 from enum import Enum
 from enum import Enum
@@ -15,6 +15,7 @@ class BaseTemplate(BaseModel):
     template_type: TemplateType = Field(..., description="模板类型")
     template_type: TemplateType = Field(..., description="模板类型")
     template_str: str = Field(..., description="模板字符串")
     template_str: str = Field(..., description="模板字符串")
     collection_name: Optional[str] = Field(None, description="默认集合名称")
     collection_name: Optional[str] = Field(None, description="默认集合名称")
+    variables: List[str] = Field(default_factory=list, description="模板中使用的变量列表")
     created_at: datetime = Field(default_factory=datetime.now, description="创建时间")
     created_at: datetime = Field(default_factory=datetime.now, description="创建时间")
     updated_at: datetime = Field(default_factory=datetime.now, description="更新时间")
     updated_at: datetime = Field(default_factory=datetime.now, description="更新时间")
 
 

+ 2 - 0
tests/mytest/llamaindex_t/t_llm_to_pydantic.py

@@ -0,0 +1,2 @@
+# https://docs.llamaindex.ai/en/stable/examples/metadata_extraction/PydanticExtractor/
+# https://docs.llamaindex.ai/en/stable/examples/output_parsing/guidance_pydantic_program/

+ 58 - 0
tests/mytest/llamaindex_t/t_simple_mongo.py

@@ -0,0 +1,58 @@
+# pip install llama-index-readers-mongodb
+from llama_index.readers.mongodb import SimpleMongoReader
+from config.settings import MONGO_URL, MONGO_DB_NAME
+from src.manager.template_manager import TemplateManager, TemplateService
+
+
+# Initialize SimpleMongoReader
+reader = SimpleMongoReader(
+    uri=MONGO_URL,  # Provide the URI if not using host and port
+)
+def simple_load_mongodata():
+# Lazy load data from MongoDB
+    documents = reader.load_data(
+        db_name="test",  # Name of the database
+        collection_name="Product",  # Name of the collection
+        field_names=[
+            "competitor_analyze"
+        ],  # Names of the fields to concatenate (default: ["text"])
+        separator="",  # Separator between fields (default: "")
+        query_dict=None,  # Query to filter documents (default: None)
+        max_docs=0,  # Maximum number of documents to load (default: 0)
+        metadata_names=None,  # Names of the fields to add to metadata attribute (default: None)
+    )
+    for doc in documents:
+        print(doc.get_content())
+
+def get_jinja2_env():
+    from jinja2 import Environment, meta
+
+    env = Environment()
+    template_source = '[{"$match": {"basic_info.name": "{{product_name}}"}}, {"$project": {"basic_info": 1, "_id": {{show_id}}}}]'
+    parsed_content = env.parse(template_source)
+    variables = meta.find_undeclared_variables(parsed_content)
+    print(variables)  # 输出: {'product_name'}
+
+async def query_dict_load_mongodata():
+    manager = TemplateManager()
+    await manager.initialize()
+    tempalte_mmodel = await manager.get_template("product_info")
+    manager.render_template(tempalte_mmodel, {"product_name": "测试"})
+    await reader.aload_data(
+        db_name="test",  # Name of the database
+        collection_name="Product",  # Name of the collection 
+
+    )
+
+import asyncio
+import aiofiles
+import os
+import sys
+
+async def task():
+    get_jinja2_env()
+def main():
+    asyncio.run(task())
+
+if __name__ == "__main__":
+    main()

+ 1 - 1
tests/mytest/models/t_mongo_template_service.py

@@ -108,7 +108,7 @@ async def task():
     )
     )
     logger.info(f"competitor_for_llm result: {result_list}")
     logger.info(f"competitor_for_llm result: {result_list}")
 def main():
 def main():
-    asyncio.run(task())
+    asyncio.run(create_or_update_template())
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
     main()
     main()