|
@@ -1,3 +1,4 @@
|
|
|
|
|
+from abc import ABC, abstractmethod
|
|
|
import json
|
|
import json
|
|
|
from typing import Optional, Union
|
|
from typing import Optional, Union
|
|
|
from llama_index.core import PromptTemplate
|
|
from llama_index.core import PromptTemplate
|
|
@@ -20,29 +21,36 @@ from src.models.field_config import FieldConfig
|
|
|
load_dotenv()
|
|
load_dotenv()
|
|
|
logger = get_logger('ai')
|
|
logger = get_logger('ai')
|
|
|
|
|
|
|
|
class ConfigManager:
    """Process-wide singleton holding the active field-selection config."""

    _instance = None
    # Whitelist of model fields exposed to prompt building, keyed by model
    # class name. Replaces the former module-level DEFAULT_FIELD_CONFIG.
    _config = FieldConfig(
        include_fields={
            "ProductImageInfo": {"main_text"},
            "TrafficKeywordResult": {"traffic_keyword", "monthly_searches"},
            "ProductBaseInfo": {
                "name", "content", "material", "color", "size",
                "packaging_size", "weight", "main_usage", "selling_point"
            },
            "CompetitorCrawlData": {"asin"}
        }
    )

    def __new__(cls):
        # Lazily create the one shared instance on first construction.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    @classmethod
    def get_field_config(cls) -> FieldConfig:
        """Return the FieldConfig currently in effect."""
        return cls._config

    @classmethod
    def update_config(cls, new_config: FieldConfig):
        """Replace the active FieldConfig for all subsequent callers."""
        cls._config = new_config
|
|
|
def get_competitor_prompt_data(
|
|
def get_competitor_prompt_data(
|
|
|
product: Product,
|
|
product: Product,
|
|
|
- field_config: FieldConfig = DEFAULT_FIELD_CONFIG
|
|
|
|
|
|
|
+ field_config: FieldConfig = ConfigManager.get_field_config()
|
|
|
) -> list:
|
|
) -> list:
|
|
|
"""
|
|
"""
|
|
|
获取竞品提示数据
|
|
获取竞品提示数据
|
|
@@ -80,7 +88,7 @@ def get_competitor_prompt_data(
|
|
|
|
|
|
|
|
def get_field_descriptions(
    model_class: "BaseModel",
    field_config: "Optional[FieldConfig]" = None,
    model_name: Optional[str] = None
) -> dict:
    """Return field name -> description for *model_class*.

    Args:
        model_class: Pydantic model class to introspect.
        field_config: Field whitelist to apply. Defaults to the live
            ``ConfigManager`` config, resolved at call time so that
            ``ConfigManager.update_config`` takes effect for later calls.
        model_name: Optional override for the model's lookup name.

    Returns:
        dict: Field descriptions filtered by *field_config*.
    """
    # Fix: the previous default (``ConfigManager.get_field_config()`` in the
    # signature) was evaluated once at function-definition time, freezing the
    # config forever and silently ignoring later updates. Resolve it lazily.
    if field_config is None:
        field_config = ConfigManager.get_field_config()
    return field_config.get_model_fields(model_class, model_name)
|
|
|
|
|
|
|
|
class Formatter(ABC):
    """Abstract base class for prompt output formatters.

    Subclasses render a (possibly nested) field-description dict into the
    text that is embedded in the LLM prompt.
    """

    def __init__(self, notes: Optional[dict] = None):
        # Format-specific trailer notes, keyed by format name,
        # e.g. {"json": "...", "human": "..."}.
        self.notes = notes or {}

    @abstractmethod
    def format(self, fields_desc: dict) -> str:
        """Render *fields_desc* into a prompt-ready string."""
        ...
|
|
|
|
|
+
|
|
|
|
|
class JSONFormatter(Formatter):
    """Renders field descriptions as a fenced ```json code block."""

    def format(self, fields_desc: dict) -> str:
        rendered = json.dumps(fields_desc, indent=2, ensure_ascii=False)
        fenced = f"```json\n{rendered}\n```"
        # Append the JSON-specific trailer note, when one was supplied.
        note = self.notes.get('json')
        return f"{fenced}\n{note}" if note else fenced
|
|
|
|
|
+
|
|
|
|
|
class HumanFormatter(Formatter):
    """Renders field descriptions as an indented, human-readable outline."""

    def format(self, fields_desc: dict) -> str:
        rendered = self._render(fields_desc, 0)
        note = self.notes.get('human')
        if note:
            rendered += f"\n{note}"
        return rendered

    def _render(self, section: dict, indent: int) -> str:
        # One line per leaf: "<description>: {<field_name>}"; nested dicts
        # become a "<key>:" heading with their children indented two spaces.
        pad = " " * indent
        parts = []
        for key, value in section.items():
            if isinstance(value, dict):
                parts.append(f"{pad}{key}:")
                parts.append(self._render(value, indent + 2))
            else:
                parts.append(f"{pad}{value}: {{{key}}}")
        return "\n".join(parts)
|
|
|
|
|
+
|
|
|
|
|
class FormatterFactory:
    """Factory producing the Formatter for a given format-type name."""

    @staticmethod
    def create_formatter(format_type: str, notes: Optional[dict] = None) -> "Formatter":
        """Return a formatter for *format_type* ('json' or 'human').

        Raises:
            ValueError: for any unrecognized format type.
        """
        if format_type == "json":
            return JSONFormatter(notes)
        if format_type == "human":
            return HumanFormatter(notes)
        raise ValueError(f"Unsupported format type: {format_type}")
|
|
|
|
|
|
|
|
async def test_product_mongo(main_key_num=3, format_type: str = "json"):
|
|
async def test_product_mongo(main_key_num=3, format_type: str = "json"):
|
|
|
db_mongo = BaseMongoManager()
|
|
db_mongo = BaseMongoManager()
|
|
@@ -184,94 +209,117 @@ async def test_product_mongo(main_key_num=3, format_type: str = "json"):
|
|
|
logger.info(formatted_output)
|
|
logger.info(formatted_output)
|
|
|
return formatted_output
|
|
return formatted_output
|
|
|
|
|
|
|
|
class LLMService(ABC):
    """Abstract interface for LLM analysis backends.

    Fix: the class documented itself as abstract and used ``@abstractmethod``
    but did not inherit ``ABC``, so abstractness was never enforced and the
    bare class could be instantiated. Inheriting ``ABC`` (already imported at
    the top of this file) makes ``@abstractmethod`` effective.
    """

    @abstractmethod
    async def analyze(self, prompt: str) -> Union[dict, str]:
        """Run *prompt* through the model and return the parsed result."""
        ...
|
|
|
|
|
+
|
|
|
|
|
class LiteLLMService(LLMService):
    """LLM service backed by LiteLLM, with JSON-aware retry handling."""

    def __init__(self, model: str = "openai/deepseek-chat", max_retries: int = 3,
                 retry_delay: float = 1.0, format_type: str = "json"):
        self.model = model
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        # "json" -> parse/validate a JSON object; anything else -> raw text.
        self.format_type = format_type

    async def analyze(self, prompt: str) -> Union[dict, str]:
        """Call the model, retrying on invalid JSON when in JSON mode.

        Returns:
            dict: parsed response when ``format_type == "json"``.
            str: raw response text otherwise.

        Raises:
            ValueError: when no valid JSON response could be obtained.
        """
        llm_kwargs = {}
        if self.format_type == "json":
            # Ask the provider to emit a JSON object directly.
            llm_kwargs["additional_kwargs"] = {"response_format": {"type": "json_object"}}

        llm = LiteLLM(model=self.model, **llm_kwargs)

        for attempt in range(self.max_retries):
            try:
                logger.info(f"尝试第 {attempt + 1} 次LLM调用...")
                completion = await llm.acomplete(prompt)
                return self._process_response(completion.text)
            except (json.JSONDecodeError, ValueError) as e:
                if self.format_type == "json":
                    logger.warning(f"JSON解析失败(尝试 {attempt + 1}/{self.max_retries}): {str(e)}")
                    if attempt < self.max_retries - 1:
                        await asyncio.sleep(self.retry_delay)
                    else:
                        raise ValueError(f"无法获取有效的JSON响应: {str(e)}")
                else:
                    # Fix: non-JSON mode previously swallowed the error,
                    # retried without sleeping, and could fall off the loop
                    # returning None implicitly. Re-raise as the pre-refactor
                    # analyze_with_llm did.
                    raise
            except Exception as e:
                logger.error(f"LLM调用失败: {str(e)}")
                raise
        # Defensive: only reachable when max_retries is misconfigured to < 1;
        # keep the "never return None" contract explicit.
        raise ValueError("无法获取有效的JSON响应: 重试次数耗尽")

    def _process_response(self, response_text: str) -> Union[dict, str]:
        """Parse and validate the raw model text per the configured format."""
        if self.format_type == "json":
            # Strip a Markdown ```json fence if the model added one.
            if "```json" in response_text:
                json_str = response_text.split("```json")[1].split("```")[0].strip()
            else:
                json_str = response_text

            result = json.loads(json_str)
            if not isinstance(result, dict):
                raise ValueError("响应不是有效的JSON对象")

            logger.debug(f"LLM响应验证通过: {json.dumps(result, indent=2, ensure_ascii=False)}")
            return result
        return response_text
|
|
|
|
|
+
|
|
|
|
|
class AnalysisService:
    """Domain service: loads a product and runs competitor-keyword analysis."""

    def __init__(self, llm_service: LLMService, db_manager: BaseMongoManager):
        self.llm_service = llm_service
        self.db_manager = db_manager

    async def execute_analysis(self, product_name: str, format_type: str = "json") -> Union[dict, str]:
        """Analyze the named product's competitors via the LLM.

        Args:
            product_name: exact product name looked up in MongoDB.
            format_type: output template style for the prompt (json/human).

        Raises:
            ValueError: if no product matches *product_name*.
        """
        # NOTE(review): format_type here only shapes the prompt's output
        # template; response parsing follows the llm_service's own
        # format_type — confirm callers keep the two in sync.
        await self.db_manager.initialize()
        product = await Product.find_one(Product.basic_info.name == product_name)
        if not product:
            raise ValueError(f"未找到产品: {product_name}")

        prompt = await self._prepare_prompt(product, format_type)
        return await self.llm_service.analyze(prompt)

    async def _prepare_prompt(self, product: Product, format_type: str) -> str:
        """Assemble the full analysis prompt (field docs + competitor data)."""
        competitor_data = get_competitor_prompt_data(product)
        # Output schema the LLM is asked to fill in (keys are placeholders
        # referenced by the formatter; values are human descriptions).
        output_fields = {
            "results": {
                "asin": "商品(竞品)编号",
                "main_key": "主要关键词",
                "monthly_searches": "月搜索量",
                "reason": "分析理由"
            },
            "supplement": "补充说明"
        }

        formatter = FormatterFactory.create_formatter(format_type)
        output_format = formatter.format(output_fields)

        return f'''各个字段说明:
{get_field_descriptions(CompetitorCrawlData)}
{get_field_descriptions(ProductImageInfo)}
{get_field_descriptions(TrafficKeywordResult)}

竞品数据:
{competitor_data}
----
我是日本站的亚马逊运营,正在为产品 {product.basic_info.name} 选主要关键词。
请根据以上 {len(competitor_data)} 个竞品数据,按以下规则分析:
1. 优先选择搜索量1万以上的相同关键词
2. 不足时选择搜索量前十且相关性强的关键词
3. 结合日本市场特点分析

输出格式:
{output_format}'''
|
|
|
|
|
|
|
|
async def main():
    """Entry point: wires the services together and runs one analysis."""
    logger.info(f"base url {os.environ.get('OPENAI_API_BASE')}")
    db_manager = BaseMongoManager()
    llm_service = LiteLLMService(format_type='json')
    analysis_service = AnalysisService(llm_service, db_manager)

    try:
        # NOTE(review): product name is hard-coded — presumably a smoke-test
        # target; consider taking it from CLI args or configuration.
        result = await analysis_service.execute_analysis("电线保护套")
        # JSON mode is configured above, so result is expected to be a dict.
        logger.info(f"分析结果: {json.dumps(result, indent=2, ensure_ascii=False)}")
    except ValueError as e:
        logger.error(f"分析失败: {str(e)}")
|
|
|
|
if __name__ == "__main__":
    # Script entry: run the async main under a fresh event loop.
    asyncio.run(main())
|