瀏覽代碼

完善 MongoDB模板化添加集合名称

mrh 8 月之前
父節點
當前提交
4e078cdf9f
共有 3 個文件被更改,包括 41 次插入12 次删除
  1. 17 8
      src/manager/template_manager.py
  2. 1 0
      src/models/template_model.py
  3. 23 4
      tests/mytest/models/t_mongo_template_service.py

+ 17 - 8
src/manager/template_manager.py

@@ -20,13 +20,15 @@ class TemplateManager(BaseMongoManager):
         
     async def create_template(self, name: str, template_str: str,
                             template_type: TemplateType,
-                            description: str = None) -> Template:
+                            description: str = None,
+                            collection_name: str = None) -> Template:
         """创建新模板"""
         template = Template(
             name=name,
             template_str=template_str,
             template_type=template_type,
-            description=description
+            description=description,
+            collection_name=collection_name
         )
         await template.insert()
         return template
@@ -34,6 +36,7 @@ class TemplateManager(BaseMongoManager):
     async def create_or_update_template(self, name: str, template_str: str,
                                      template_type: TemplateType,
                                      description: str = None,
+                                     collection_name: str = None,
                                      if_exists: str = "update") -> Template:
         """
         创建模板,如果已存在则根据if_exists参数处理
@@ -56,7 +59,8 @@ class TemplateManager(BaseMongoManager):
                 return await self.update_template(
                     name=name,
                     new_template_str=template_str,
-                    new_description=description
+                    new_description=description,
+                    new_collection_name=collection_name
                 )
             else:
                 raise ValueError(f"Invalid if_exists value: {if_exists}. Must be 'update' or 'ignore'")
@@ -64,7 +68,8 @@ class TemplateManager(BaseMongoManager):
             name=name,
             template_str=template_str,
             template_type=template_type,
-            description=description
+            description=description,
+            collection_name=collection_name
         )
 
     async def get_template(self, name: str) -> Optional[Template]:
@@ -73,7 +78,8 @@ class TemplateManager(BaseMongoManager):
         return await Template.find_one(Template.name == name)
 
     async def update_template(self, name: str, new_template_str: str = None,
-                            new_description: str = None) -> Optional[Template]:
+                            new_description: str = None,
+                            new_collection_name: str = None) -> Optional[Template]:
         """更新模板"""
         template = await self.get_template(name)
         if not template:
@@ -83,6 +89,8 @@ class TemplateManager(BaseMongoManager):
             template.template_str = new_template_str
         if new_description:
             template.description = new_description
+        if new_collection_name is not None:
+            template.collection_name = new_collection_name
             
         await template.update_timestamp()
         return template
@@ -122,9 +130,10 @@ class TemplateManager(BaseMongoManager):
             
         pipeline = self.render_template(template.template_str, context)
         
-        if not collection_name:
-            if template.template_type == TemplateType.AGGREGATION:
-                collection_name = "Product"  # 默认集合
+        if not collection_name and template.collection_name:
+            collection_name = template.collection_name
+        elif not collection_name and template.template_type == TemplateType.AGGREGATION:
+            collection_name = "Product"  # 默认集合
             
         if not collection_name:
             raise ValueError("Collection name is required for this template type")

+ 1 - 0
src/models/template_model.py

@@ -14,6 +14,7 @@ class BaseTemplate(BaseModel):
     description: Optional[str] = Field(None, description="模板描述")
     template_type: TemplateType = Field(..., description="模板类型")
     template_str: str = Field(..., description="模板字符串")
+    collection_name: Optional[str] = Field(None, description="默认集合名称")
     created_at: datetime = Field(default_factory=datetime.now, description="创建时间")
     updated_at: datetime = Field(default_factory=datetime.now, description="更新时间")
 

+ 23 - 4
tests/mytest/models/t_mongo_template_service.py

@@ -68,26 +68,45 @@ async def create_or_update_template():
     await manager.initialize()
     template = await manager.create_or_update_template(
         name="product_info",
+        collection_name="Product",
         template_str=json.dumps(product_info_pipeline),
         template_type=TemplateType.AGGREGATION,
         description="产品信息查询模板"
     )
     await manager.create_or_update_template(
         name="competitor_for_llm",
+        collection_name="Product",
         template_str=json.dumps(filter_competior_by_name),
         template_type=TemplateType.AGGREGATION,
         description="竞品数据查询模板,筛选出主要信息"
     )
     logger.info(f"Created template: {template}")
 
-async def task():
+async def render():
     manager = TemplateManager()
     await manager.initialize()
     template = await manager.get_template("product_info")
-    logger.info(f"product_info template: {template}")
+    rendered_template = manager.render_template(template_str=template.template_str, context={"product_name": "电线保护套"})
+    logger.info(f"product_info template: {rendered_template}")
     template = await manager.get_template("competitor_for_llm")
-    logger.info(f"competitor_for_llm template: {template}")
-    
+    rendered_template = manager.render_template(template_str=template.template_str, context={"product_name": "电线保护套"})
+    logger.info(f"competitor_for_llm template: {rendered_template}")
+
+async def task():
+    manager = TemplateManager()
+    await manager.initialize()
+    result_list = await manager.execute_template(
+        name="product_info",
+        context={"product_name": "电线保护套"},
+        collection_name="Product"
+    )
+    logger.info(f"product_info result: {result_list}")
+    result_list = await manager.execute_template(
+        name="competitor_for_llm",
+        context={"product_name": "电线保护套"},
+        collection_name="Product"
+    )
+    logger.info(f"competitor_for_llm result: {result_list}")
 def main():
     asyncio.run(task())