浏览代码

格式化的信息保存为SQLmodel

mrh 4 个月前
父节点
当前提交
bafc9fc7a9
共有 1 个文件被更改,包括 61 次插入和 25 次删除
  1. 61 25
      tests/mytest/llamaindex_t/t_llm_to_pydantic.py

+ 61 - 25
tests/mytest/llamaindex_t/t_llm_to_pydantic.py

@@ -2,14 +2,16 @@ from llama_index.llms.litellm import LiteLLM
 from config.settings import OPENAI_API_KEY, OPENAI_API_BASE
 from typing import List, Optional, Any
 from pydantic import BaseModel, Field
+from sqlmodel import SQLModel, Field as SQLField
+from datetime import datetime
 from llama_index.core.program import LLMTextCompletionProgram
 from llama_index.core.output_parsers import PydanticOutputParser
 from src.flow_task.crawl_asin_flow import parse_url_to_markdown_task
 from llama_index.core.output_parsers.pydantic import extract_json_str
 
 
-class ProductForExtraction(BaseModel):
-    """用于LLMTextCompletionProgram的产品信息模型"""
+class ProductBase(BaseModel):
+    """产品基础信息模型,用于共享字段定义"""
     product_name: str = Field(..., description="产品名称")
     material: Optional[str] = Field(default=None, description="材质")
     color: Optional[str] = Field(default=None, description="颜色")
@@ -18,17 +20,49 @@ class ProductForExtraction(BaseModel):
     competitor_list: List[str] = Field(default_factory=list, description="竞品ASIN列表")
 
 
-class Product(BaseModel):
-    """产品信息Pydantic模型类"""
-    product_name: str = Field(..., description="产品名称")
-    competitor_list: List[str] = Field(default_factory=list, description="竞品ASIN列表")
-    markdown_content: str = Field(..., description="Markdown格式的完整数据源文本")
-    uri: str = Field(..., description="数据源的地址或路径")
-    filename: str = Field(..., description="文件名")
-    material: Optional[str] = Field(default=None, description="材质")
-    color: Optional[str] = Field(default=None, description="颜色")
-    main_usage: Optional[str] = Field(default=None, description="主要用途")
-    main_selling_points: List[str] = Field(default_factory=list, description="主要卖点列表")
+class ProductForExtraction(ProductBase):
+    """用于LLMTextCompletionProgram的产品信息模型"""
+    pass
+
+
+class Product(SQLModel, table=True):
+    """产品信息SQLModel模型类"""
+    id: Optional[int] = SQLField(default=None, primary_key=True)
+    product_name: str = SQLField(..., description="产品名称")
+    product_data: str = SQLField(..., description="ProductForExtraction的JSON格式数据")
+    markdown_content: str = SQLField(..., description="Markdown格式的完整数据源文本")
+    uri: str = SQLField(..., description="数据源的地址或路径")
+    filename: str = SQLField(..., description="文件名")
+    created_at: Optional[datetime] = SQLField(default_factory=datetime.now)
+    updated_at: Optional[datetime] = SQLField(default_factory=datetime.now)
+    
+    @classmethod
+    def from_product_extraction(cls, extracted_product: ProductForExtraction, markdown_content: str, uri: str, filename: str) -> 'Product':
+        """从ProductForExtraction创建Product实例"""
+        import json
+        
+        # 将ProductForExtraction转换为JSON字符串
+        product_data_json = extracted_product.model_dump_json()
+        
+        return cls(
+            product_name=extracted_product.product_name,
+            product_data=product_data_json,
+            markdown_content=markdown_content,
+            uri=uri,
+            filename=filename
+        )
+    
+    def get_product_extraction(self) -> ProductForExtraction:
+        """获取ProductForExtraction实例"""
+        import json
+        
+        return ProductForExtraction.model_validate_json(self.product_data)
+    
+    def get_product_base(self) -> ProductBase:
+        """获取ProductBase实例"""
+        import json
+        
+        return ProductBase.model_validate_json(self.product_data)
 
 
 
@@ -75,13 +109,9 @@ def extract_product_from_text(text: str, uri: str = "", filename: str = "") -> P
     
     extracted_product = program(text=text)
     
-    return Product(
-        product_name=extracted_product.product_name,
-        material=extracted_product.material,
-        color=extracted_product.color,
-        main_usage=extracted_product.main_usage,
-        main_selling_points=extracted_product.main_selling_points,
-        competitor_list=extracted_product.competitor_list,
+    # 使用新的类方法创建Product实例
+    return Product.from_product_extraction(
+        extracted_product=extracted_product,
         markdown_content=text,
         uri=uri,
         filename=filename
@@ -106,11 +136,17 @@ if __name__ == "__main__":
     
     print("从URL提取成功!产品信息:")
     print(f"产品名称: {product_from_url.product_name}")
-    print(f"材质: {product_from_url.material}")
-    print(f"颜色: {product_from_url.color}")
-    print(f"主要用途: {product_from_url.main_usage}")
-    print(f"主要卖点: {product_from_url.main_selling_points}")
-    print(f"竞品列表: {product_from_url.competitor_list}")
+    
+    # 使用新的方法获取完整的产品信息
+    product_extraction = product_from_url.get_product_extraction()
+    print(f"材质: {product_extraction.material}")
+    print(f"颜色: {product_extraction.color}")
+    print(f"主要用途: {product_extraction.main_usage}")
+    print(f"主要卖点: {product_extraction.main_selling_points}")
+    print(f"竞品列表: {product_extraction.competitor_list}")
     print(f"数据源: {product_from_url.uri}")
     print(f"文件名: {product_from_url.filename}")
+    
+    # 显示存储的JSON数据
+    print(f"\n存储的JSON数据: {product_from_url.product_data}")