|
|
@@ -2,14 +2,16 @@ from llama_index.llms.litellm import LiteLLM
|
|
|
from config.settings import OPENAI_API_KEY, OPENAI_API_BASE
|
|
|
from typing import List, Optional, Any
|
|
|
from pydantic import BaseModel, Field
|
|
|
+from sqlmodel import SQLModel, Field as SQLField
|
|
|
+from datetime import datetime
|
|
|
from llama_index.core.program import LLMTextCompletionProgram
|
|
|
from llama_index.core.output_parsers import PydanticOutputParser
|
|
|
from src.flow_task.crawl_asin_flow import parse_url_to_markdown_task
|
|
|
from llama_index.core.output_parsers.pydantic import extract_json_str
|
|
|
|
|
|
|
|
|
-class ProductForExtraction(BaseModel):
|
|
|
- """用于LLMTextCompletionProgram的产品信息模型"""
|
|
|
+class ProductBase(BaseModel):
|
|
|
+ """产品基础信息模型,用于共享字段定义"""
|
|
|
product_name: str = Field(..., description="产品名称")
|
|
|
material: Optional[str] = Field(default=None, description="材质")
|
|
|
color: Optional[str] = Field(default=None, description="颜色")
|
|
|
@@ -18,17 +20,49 @@ class ProductForExtraction(BaseModel):
|
|
|
competitor_list: List[str] = Field(default_factory=list, description="竞品ASIN列表")
|
|
|
|
|
|
|
|
|
-class Product(BaseModel):
|
|
|
- """产品信息Pydantic模型类"""
|
|
|
- product_name: str = Field(..., description="产品名称")
|
|
|
- competitor_list: List[str] = Field(default_factory=list, description="竞品ASIN列表")
|
|
|
- markdown_content: str = Field(..., description="Markdown格式的完整数据源文本")
|
|
|
- uri: str = Field(..., description="数据源的地址或路径")
|
|
|
- filename: str = Field(..., description="文件名")
|
|
|
- material: Optional[str] = Field(default=None, description="材质")
|
|
|
- color: Optional[str] = Field(default=None, description="颜色")
|
|
|
- main_usage: Optional[str] = Field(default=None, description="主要用途")
|
|
|
- main_selling_points: List[str] = Field(default_factory=list, description="主要卖点列表")
|
|
|
+class ProductForExtraction(ProductBase):
|
|
|
+ """用于LLMTextCompletionProgram的产品信息模型"""
|
|
|
+ pass
|
|
|
+
|
|
|
+
|
|
|
+class Product(SQLModel, table=True):
|
|
|
+ """产品信息SQLModel模型类"""
|
|
|
+ id: Optional[int] = SQLField(default=None, primary_key=True)
|
|
|
+ product_name: str = SQLField(..., description="产品名称")
|
|
|
+ product_data: str = SQLField(..., description="ProductForExtraction的JSON格式数据")
|
|
|
+ markdown_content: str = SQLField(..., description="Markdown格式的完整数据源文本")
|
|
|
+ uri: str = SQLField(..., description="数据源的地址或路径")
|
|
|
+ filename: str = SQLField(..., description="文件名")
|
|
|
+ created_at: Optional[datetime] = SQLField(default_factory=datetime.now)
|
|
|
+ updated_at: Optional[datetime] = SQLField(default_factory=datetime.now)
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def from_product_extraction(cls, extracted_product: ProductForExtraction, markdown_content: str, uri: str, filename: str) -> 'Product':
|
|
|
+ """从ProductForExtraction创建Product实例"""
|
|
|
+ import json
|
|
|
+
|
|
|
+ # 将ProductForExtraction转换为JSON字符串
|
|
|
+ product_data_json = extracted_product.model_dump_json()
|
|
|
+
|
|
|
+ return cls(
|
|
|
+ product_name=extracted_product.product_name,
|
|
|
+ product_data=product_data_json,
|
|
|
+ markdown_content=markdown_content,
|
|
|
+ uri=uri,
|
|
|
+ filename=filename
|
|
|
+ )
|
|
|
+
|
|
|
+ def get_product_extraction(self) -> ProductForExtraction:
|
|
|
+ """获取ProductForExtraction实例"""
|
|
|
+ import json
|
|
|
+
|
|
|
+ return ProductForExtraction.model_validate_json(self.product_data)
|
|
|
+
|
|
|
+ def get_product_base(self) -> ProductBase:
|
|
|
+ """获取ProductBase实例"""
|
|
|
+ import json
|
|
|
+
|
|
|
+ return ProductBase.model_validate_json(self.product_data)
|
|
|
|
|
|
|
|
|
|
|
|
@@ -75,13 +109,9 @@ def extract_product_from_text(text: str, uri: str = "", filename: str = "") -> P
|
|
|
|
|
|
extracted_product = program(text=text)
|
|
|
|
|
|
- return Product(
|
|
|
- product_name=extracted_product.product_name,
|
|
|
- material=extracted_product.material,
|
|
|
- color=extracted_product.color,
|
|
|
- main_usage=extracted_product.main_usage,
|
|
|
- main_selling_points=extracted_product.main_selling_points,
|
|
|
- competitor_list=extracted_product.competitor_list,
|
|
|
+ # 使用新的类方法创建Product实例
|
|
|
+ return Product.from_product_extraction(
|
|
|
+ extracted_product=extracted_product,
|
|
|
markdown_content=text,
|
|
|
uri=uri,
|
|
|
filename=filename
|
|
|
@@ -106,11 +136,17 @@ if __name__ == "__main__":
|
|
|
|
|
|
print("从URL提取成功!产品信息:")
|
|
|
print(f"产品名称: {product_from_url.product_name}")
|
|
|
- print(f"材质: {product_from_url.material}")
|
|
|
- print(f"颜色: {product_from_url.color}")
|
|
|
- print(f"主要用途: {product_from_url.main_usage}")
|
|
|
- print(f"主要卖点: {product_from_url.main_selling_points}")
|
|
|
- print(f"竞品列表: {product_from_url.competitor_list}")
|
|
|
+
|
|
|
+ # 使用新的方法获取完整的产品信息
|
|
|
+ product_extraction = product_from_url.get_product_extraction()
|
|
|
+ print(f"材质: {product_extraction.material}")
|
|
|
+ print(f"颜色: {product_extraction.color}")
|
|
|
+ print(f"主要用途: {product_extraction.main_usage}")
|
|
|
+ print(f"主要卖点: {product_extraction.main_selling_points}")
|
|
|
+ print(f"竞品列表: {product_extraction.competitor_list}")
|
|
|
print(f"数据源: {product_from_url.uri}")
|
|
|
print(f"文件名: {product_from_url.filename}")
|
|
|
+
|
|
|
+ # 显示存储的JSON数据
|
|
|
+ print(f"\n存储的JSON数据: {product_from_url.product_data}")
|
|
|
|