|
|
@@ -2,31 +2,41 @@ from datetime import datetime
|
|
|
from beanie import Document, Link
|
|
|
from pydantic import BaseModel, Field
|
|
|
from typing import Any, List, Dict, Optional, Union
|
|
|
-from src.models.product_model import Product,AICompetitorAnalyzeMainKeywordsResult,MarketingInfo
|
|
|
+from src.models.product_model import Product,AICompetitorAnalyzeMainKeywordsResult,MarketingInfo,AIPromptConfig
|
|
|
|
|
|
class LLMConfig(BaseModel):
|
|
|
- executor_type='llm'
|
|
|
- kwargs: Optional[Dict[str, Any]] = Field(
|
|
|
- default_factory=None,
|
|
|
+ model_name: str = Field(
|
|
|
+ ...,
|
|
|
+ description="LLM的名称"
|
|
|
+ )
|
|
|
+ executor_type:str = Field(default='llm',description="执行器类型")
|
|
|
+ keywargs: Optional[Dict[str, Any]] = Field(
|
|
|
+ default=None,
|
|
|
description="LLM的配置参数"
|
|
|
)
|
|
|
|
|
|
class AgentConfig(BaseModel):
|
|
|
- executor_type='agent'
|
|
|
- agent_name: str
|
|
|
+ executor_type: str = Field(default='agent', description="执行器类型")
|
|
|
+ agent_name: Optional[str] = Field(
|
|
|
+ default=None,
|
|
|
+ description="Agent的名称"
|
|
|
+ )
|
|
|
agent_config: Dict = {}
|
|
|
|
|
|
-class SuperPromptMixin:
|
|
|
- executor_type='super_prompt'
|
|
|
- api_base_url: str
|
|
|
- provider_name: str
|
|
|
- provider_config: Dict = {}
|
|
|
+class SuperPromptMixin(BaseModel):
|
|
|
+ executor_type: str = Field(default='super_prompt', description="执行器类型")
|
|
|
+ api_base_url: Optional[str] = Field(
|
|
|
+ default=None,
|
|
|
+ description="API的base url"
|
|
|
+ )
|
|
|
+ provider_name: Optional[str] = Field(default=None,description="Provider的名称")
|
|
|
+ provider_config: Optional[Dict] = Field(default=None,description="Provider的配置")
|
|
|
|
|
|
|
|
|
class BaseAIExecution(Document):
|
|
|
"""AI执行结果基类"""
|
|
|
product: Optional[Link["Product"]] = None
|
|
|
- prompt_template: Optional[str] = None
|
|
|
+ prompting: Optional[AIPromptConfig] = None
|
|
|
input_data: Optional[Dict[str, Any]] = Field(
|
|
|
default_factory=dict,
|
|
|
description="完整的输入数据"
|
|
|
@@ -35,7 +45,7 @@ class BaseAIExecution(Document):
|
|
|
default_factory=dict,
|
|
|
description="原始的AI输出结果"
|
|
|
)
|
|
|
- executor_config: Optional[Dict[str, Union[Any,LLMConfig,Dict]]] = None
|
|
|
+ executor_config: Optional[Union[LLMConfig, AgentConfig, SuperPromptMixin]] = None
|
|
|
helpful_level: Optional[int] = None
|
|
|
created_at: datetime = Field(default_factory=datetime.now)
|
|
|
|
|
|
@@ -47,7 +57,7 @@ class CompetitorKeywordAnalysis(BaseAIExecution):
|
|
|
"""竞品关键词分析结果"""
|
|
|
task_type: str = "competitor_analysis"
|
|
|
result: Optional[AICompetitorAnalyzeMainKeywordsResult] = Field(
|
|
|
- default_factory=None,
|
|
|
+ default=None,
|
|
|
description="竞品关键词分析结果"
|
|
|
)
|
|
|
|
|
|
@@ -55,7 +65,7 @@ class MarketingContentGeneration(BaseAIExecution):
|
|
|
"""营销内容生成结果"""
|
|
|
task_type: str = "marketing_generation"
|
|
|
result: Optional[MarketingInfo] = Field(
|
|
|
- default_factory=None,
|
|
|
+ default=None,
|
|
|
description="生成的营销内容结果"
|
|
|
)
|
|
|
|