Prechádzať zdrojové kódy

创建模板时,自动添加被渲染的变量

mrh 11 mesiacov pred
rodič
commit
82b05601f3

+ 20 - 3
src/manager/template_manager.py

@@ -1,7 +1,7 @@
 from typing import Optional, List, Dict, Any
 from src.models.template_model import Template, TemplateType
 from src.manager.core.db_mongo import BaseMongoManager
-from jinja2 import Environment, BaseLoader
+from jinja2 import Environment, BaseLoader, meta
 import json
 from bson import json_util
 from utils.logu import get_logger
@@ -23,12 +23,21 @@ class TemplateManager(BaseMongoManager):
                             description: str = None,
                             collection_name: str = None) -> Template:
         """创建新模板"""
+        # 解析模板中的变量
+        try:
+            parsed_content = self.env.parse(template_str)
+            variables = list(meta.find_undeclared_variables(parsed_content))
+        except Exception as e:
+            logger.error(f"Failed to parse template variables: {e}")
+            variables = []
+
         template = Template(
             name=name,
             template_str=template_str,
             template_type=template_type,
             description=description,
-            collection_name=collection_name
+            collection_name=collection_name,
+            variables=variables
         )
         await template.insert()
         return template
@@ -84,9 +93,17 @@ class TemplateManager(BaseMongoManager):
         template = await self.get_template(name)
         if not template:
             return None
-            
+             
         if new_template_str:
+            # 更新模板字符串时同时更新变量列表
+            try:
+                parsed_content = self.env.parse(new_template_str)
+                variables = list(meta.find_undeclared_variables(parsed_content))
+                template.variables = variables
+            except Exception as e:
+                logger.error(f"Failed to parse template variables: {e}")
             template.template_str = new_template_str
+            
         if new_description:
             template.description = new_description
         if new_collection_name is not None:

+ 2 - 1
src/models/template_model.py

@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, List
 from beanie import Document
 from pydantic import BaseModel, Field
 from enum import Enum
@@ -15,6 +15,7 @@ class BaseTemplate(BaseModel):
     template_type: TemplateType = Field(..., description="模板类型")
     template_str: str = Field(..., description="模板字符串")
     collection_name: Optional[str] = Field(None, description="默认集合名称")
+    variables: List[str] = Field(default_factory=list, description="模板中使用的变量列表")
     created_at: datetime = Field(default_factory=datetime.now, description="创建时间")
     updated_at: datetime = Field(default_factory=datetime.now, description="更新时间")
 

+ 2 - 0
tests/mytest/llamaindex_t/t_llm_to_pydantic.py

@@ -0,0 +1,2 @@
+# https://docs.llamaindex.ai/en/stable/examples/metadata_extraction/PydanticExtractor/
+# https://docs.llamaindex.ai/en/stable/examples/output_parsing/guidance_pydantic_program/

+ 58 - 0
tests/mytest/llamaindex_t/t_simple_mongo.py

@@ -0,0 +1,58 @@
+# pip install llama-index-readers-mongodb
+from llama_index.readers.mongodb import SimpleMongoReader
+from config.settings import MONGO_URL, MONGO_DB_NAME
+from src.manager.template_manager import TemplateManager, TemplateService
+
+
+# Initialize SimpleMongoReader
+reader = SimpleMongoReader(
+    uri=MONGO_URL,  # Provide the URI if not using host and port
+)
+def simple_load_mongodata():
+# Lazy load data from MongoDB
+    documents = reader.load_data(
+        db_name="test",  # Name of the database
+        collection_name="Product",  # Name of the collection
+        field_names=[
+            "competitor_analyze"
+        ],  # Names of the fields to concatenate (default: ["text"])
+        separator="",  # Separator between fields (default: "")
+        query_dict=None,  # Query to filter documents (default: None)
+        max_docs=0,  # Maximum number of documents to load (default: 0)
+        metadata_names=None,  # Names of the fields to add to metadata attribute (default: None)
+    )
+    for doc in documents:
+        print(doc.get_content())
+
+def get_jinja2_env():
+    from jinja2 import Environment, meta
+
+    env = Environment()
+    template_source = '[{"$match": {"basic_info.name": "{{product_name}}"}}, {"$project": {"basic_info": 1, "_id": {{show_id}}}}]'
+    parsed_content = env.parse(template_source)
+    variables = meta.find_undeclared_variables(parsed_content)
+    print(variables)  # 输出: {'product_name'}
+
+async def query_dict_load_mongodata():
+    manager = TemplateManager()
+    await manager.initialize()
+    tempalte_mmodel = await manager.get_template("product_info")
+    manager.render_template(tempalte_mmodel, {"product_name": "测试"})
+    await reader.aload_data(
+        db_name="test",  # Name of the database
+        collection_name="Product",  # Name of the collection 
+
+    )
+
+import asyncio
+import aiofiles
+import os
+import sys
+
+async def task():
+    get_jinja2_env()
+def main():
+    asyncio.run(task())
+
+if __name__ == "__main__":
+    main()

+ 1 - 1
tests/mytest/models/t_mongo_template_service.py

@@ -108,7 +108,7 @@ async def task():
     )
     logger.info(f"competitor_for_llm result: {result_list}")
 def main():
-    asyncio.run(task())
+    asyncio.run(create_or_update_template())
 
 if __name__ == "__main__":
     main()