|
|
@@ -1,408 +0,0 @@
|
|
|
-from abc import ABC, abstractmethod
|
|
|
-import json
|
|
|
-from typing import Optional, Set, Union
|
|
|
-from llama_index.core import PromptTemplate
|
|
|
-import asyncio
|
|
|
-import aiofiles
|
|
|
-import os
|
|
|
-import sys
|
|
|
-from dotenv import load_dotenv
|
|
|
-from pydantic import BaseModel
|
|
|
-from src.models.product_model import (
|
|
|
- AIAnalyzeCompare, Product, CompetitorCrawlData, AICompetitorAnalyzeMainKeywords,
|
|
|
- TrafficKeywordResult, ProductImageInfo,AICompetitorAnalyzeMainKeywordsResult,
|
|
|
- SearchAmazoneKeyResult, ProductBaseInfo, Variant,MarketingInfo,
|
|
|
-)
|
|
|
-from src.models.config_model import (UserConfig, AIPromptConfig, )
|
|
|
-from llama_index.llms.openai import OpenAI
|
|
|
-from llama_index.llms.litellm import LiteLLM
|
|
|
-from src.manager.core.db_mongo import BaseMongoManager
|
|
|
-from utils.logu import get_logger
|
|
|
-from src.models.field_config import FieldConfig,get_field_descriptions
|
|
|
-load_dotenv()
|
|
|
-logger = get_logger('ai')
|
|
|
-
|
|
|
class ConfigManager:
    """Process-wide singleton that holds the active :class:`FieldConfig`.

    The configuration lives on the class itself, so the singleton instance
    is mostly a formality; reads and writes go through the classmethods.
    """

    _instance = None
    # Default projection: which fields of each model are exposed to the LLM.
    _config = FieldConfig(
        include_fields={
            "ProductImageInfo": {"main_text"},
            "TrafficKeywordResult": {"traffic_keyword", "monthly_searches"},
            "ProductBaseInfo": {
                "name", "content", "material", "color", "size",
                "packaging_size", "weight", "main_usage", "selling_point"
            },
            "CompetitorCrawlData": {"asin"}
        }
    )

    def __new__(cls):
        # Lazily create and cache the single shared instance.
        existing = cls._instance
        if existing is not None:
            return existing
        instance = super().__new__(cls)
        cls._instance = instance
        return instance

    @classmethod
    def get_field_config(cls) -> FieldConfig:
        """Return the currently active field configuration."""
        return cls._config

    @classmethod
    def update_config(cls, new_config: FieldConfig):
        """Replace the active field configuration for all future callers."""
        cls._config = new_config
|
|
|
-
|
|
|
def get_competitor_prompt_data(
    product: Product,
    field_config: Optional[FieldConfig] = None,
) -> list:
    """
    Build the structured competitor payload embedded in LLM prompts.

    Args:
        product: Product whose ``competitor_crawl_data`` is summarised.
        field_config: Field projection to apply. Defaults to the *current*
            ``ConfigManager`` config, resolved at call time. (The previous
            signature evaluated ``ConfigManager.get_field_config()`` once at
            import time, so later ``ConfigManager.update_config()`` calls were
            silently ignored by this function.)

    Returns:
        One dict per competitor ASIN that has an ``extra_result``, containing
        the filtered product info and traffic-keyword table.
    """
    if field_config is None:
        # Resolve lazily so config updates made after import are honoured.
        field_config = ConfigManager.get_field_config()

    list_data = []
    for asin, crawl_data in product.competitor_crawl_data.items():
        # Competitors without extra crawl results contribute nothing.
        if not crawl_data.extra_result:
            continue

        structured_result = {"asin": asin}

        if crawl_data.extra_result.product_info:
            structured_result["product_info"] = field_config.filter_model_dump(
                crawl_data.extra_result.product_info,
                "ProductImageInfo"
            )

        if crawl_data.extra_result.result_table:
            structured_result["result_table"] = [
                field_config.filter_model_dump(item, "TrafficKeywordResult")
                for item in crawl_data.extra_result.result_table
            ]

        logger.debug(f"Structured result for LLM: {json.dumps(structured_result, indent=4, ensure_ascii=False)}")
        list_data.append(structured_result)

    return list_data
|
|
|
-
|
|
|
class PromptFormatter:
    """Chainable builder that assembles and fills an LLM prompt template."""

    def __init__(self, template: str, **kwargs):
        self.template = template
        self.kwargs = kwargs
        self.partial_kwargs = {}
        self.var_mappings = {}
        self.function_mappings = {}

    def partial_format(self, **kwargs) -> "PromptFormatter":
        """Pre-bind some template variables; ``format()`` may add more later."""
        self.partial_kwargs.update(kwargs)
        return self

    def map_variables(self, **mappings) -> "PromptFormatter":
        """Register caller-name -> template-name renames."""
        self.var_mappings.update(mappings)
        return self

    def map_functions(self, **functions) -> "PromptFormatter":
        """Register per-variable post-processing callables."""
        self.function_mappings.update(functions)
        return self

    def format(self, **kwargs) -> str:
        """Resolve bindings, renames and functions, then fill the template."""
        # Call-time kwargs win over previously bound partials.
        merged = dict(self.partial_kwargs)
        merged.update(kwargs)

        # Rename keys according to the registered variable mappings.
        resolved = {self.var_mappings.get(name, name): value
                    for name, value in merged.items()}

        # Each mapped function receives the full current variable set and
        # replaces its own key's value; earlier results feed later functions.
        for name, transform in self.function_mappings.items():
            if name in resolved:
                resolved[name] = transform(**resolved)

        return self.template.format(**resolved)
|
|
|
-
|
|
|
class Formatter(ABC):
    """Abstract base class for prompt-output formatters."""

    def __init__(self, notes: Optional[dict] = None):
        # Optional per-format trailing notes, e.g. {'json': '...'}.
        self.notes = notes or {}

    @abstractmethod
    def format(self, fields_desc: dict) -> str:
        """Render *fields_desc* into this formatter's output string."""
        pass
|
|
|
-
|
|
|
class JSONFormatter(Formatter):
    """Renders field descriptions as a fenced JSON code block."""

    def format(self, fields_desc: dict) -> str:
        rendered = json.dumps(fields_desc, indent=2, ensure_ascii=False)
        block = f"```json\n{rendered}\n```"
        note = self.notes.get('json')
        if note:
            # Append the caller-supplied note below the fenced block.
            return f"{block}\n{note}"
        return block
|
|
|
-
|
|
|
class HumanFormatter(Formatter):
    """Renders field descriptions as indented, human-readable lines."""

    def format(self, fields_desc: dict) -> str:
        def walk(node, depth=0):
            # Recursively flatten nested dicts; leaves become
            # "<description>: {<field_name>}" placeholder lines.
            pad = " " * depth
            out = []
            for name, desc in node.items():
                if isinstance(desc, dict):
                    out.append(f"{pad}{name}:")
                    out.append(walk(desc, depth + 2))
                else:
                    out.append(f"{pad}{desc}: {{{name}}}")
            return "\n".join(out)

        text = walk(fields_desc)
        note = self.notes.get('human')
        if note:
            text += f"\n{note}"
        return text
|
|
|
-
|
|
|
class FormatterFactory:
    """Creates concrete :class:`Formatter` instances by format name."""

    @staticmethod
    def create_formatter(format_type: str, notes: Optional[dict] = None) -> Formatter:
        """Return the formatter for *format_type*; raise on unknown names."""
        if format_type == "human":
            return HumanFormatter(notes)
        if format_type == "json":
            return JSONFormatter(notes)
        raise ValueError(f"Unsupported format type: {format_type}")
|
|
|
-
|
|
|
-
|
|
|
class LLMService(ABC):
    """Abstract LLM service interface.

    Fix: the original class used ``@abstractmethod`` but did not inherit from
    ``ABC``, so the decorator was never enforced and the "abstract" class was
    instantiable. Subclassing ``ABC`` makes instantiating an incomplete
    implementation raise ``TypeError``, as intended.
    """

    @abstractmethod
    async def analyze(self, prompt: str) -> Union[dict, str]:
        """Send *prompt* to the model; return a dict (JSON mode) or raw text."""
        pass
|
|
|
-
|
|
|
class LiteLLMService(LLMService):
    """LiteLLM-backed implementation of :class:`LLMService`.

    In JSON mode (``format_type == "json"``) the response is parsed and
    validated as a JSON object, with up to ``max_retries`` attempts; in any
    other mode the raw completion text is returned unchanged.
    """

    def __init__(self, model: str = "openai/deepseek-chat", max_retries: int = 3,
                 retry_delay: float = 1.0, format_type: str = "json"):
        # model: LiteLLM model identifier (provider/model string).
        self.model = model
        # max_retries: total attempts when JSON parsing/validation fails.
        self.max_retries = max_retries
        # retry_delay: seconds to sleep between failed attempts.
        self.retry_delay = retry_delay
        # format_type: "json" enables strict JSON parsing of the response.
        self.format_type = format_type

    async def analyze(self, prompt: str) -> Union[dict, str]:
        """Call the LLM with *prompt*, retrying on invalid JSON responses.

        Returns a dict in JSON mode, otherwise the raw response text.
        Raises ValueError when no valid JSON object is obtained within
        ``max_retries`` attempts; re-raises any other LLM call failure.
        """
        llm_kwargs = {}
        if self.format_type == "json":
            # NOTE: structured-output request via response_format is disabled;
            # the JSON requirement is pushed into the prompt text instead.
            # if 'deepseek-r' not in self.model:
            #     llm_kwargs["additional_kwargs"] = {"response_format": {"type": "json_object"}}
            prompt += "\n请确保输出的是有效的JSON对象。"
        logger.info(f"{self.model} 调用参数: {llm_kwargs}")
        llm = LiteLLM(model=self.model, **llm_kwargs)

        for attempt in range(self.max_retries):
            try:
                logger.info(f"尝试第 {attempt + 1} 次LLM调用...")
                completion = await llm.acomplete(prompt)
                return self._process_response(completion.text)
            except (json.JSONDecodeError, ValueError) as e:
                # JSON parse/validation failure: warn, back off, retry;
                # surface a ValueError only after the final attempt.
                if self.format_type == "json":
                    logger.warning(f"JSON解析失败(尝试 {attempt + 1}/{self.max_retries}): {str(e)}")
                if attempt < self.max_retries - 1:
                    await asyncio.sleep(self.retry_delay)
                else:
                    raise ValueError(f"无法获取有效的JSON响应: {str(e)}")
            except Exception as e:
                # Any other failure (network, provider error) is not retried.
                logger.exception(f"LLM调用失败: {str(e)}")
                raise

    def _process_response(self, response_text: str) -> Union[dict, str]:
        """Extract and validate JSON from *response_text* in JSON mode.

        Strips an optional ```json fenced block, parses it, and requires the
        result to be a JSON object (dict). In non-JSON mode the text is
        returned untouched.
        """
        if self.format_type == "json":
            if "```json" in response_text:
                # Model wrapped the payload in a fenced code block.
                json_str = response_text.split("```json")[1].split("```")[0].strip()
            else:
                json_str = response_text

            result = json.loads(json_str)
            if not isinstance(result, dict):
                raise ValueError("响应不是有效的JSON对象")

            logger.debug(f"LLM响应验证通过: {json.dumps(result, indent=2, ensure_ascii=False)}")
            return result
        return response_text
|
|
|
-
|
|
|
class AnalysisService:
    """Domain service that builds prompts and runs LLM product analyses."""

    def __init__(self, llm_service: LLMService, db_manager: BaseMongoManager):
        # Annotated as LiteLLMService because the prompt builders read
        # ``llm_service.format_type``, which the abstract base does not declare.
        self.llm_service: LiteLLMService = llm_service
        self.db_manager = db_manager

    async def execute_analysis(self, product: Product, format_type: str = "json", dry_run=False, template: Optional[AIPromptConfig] = None) -> tuple[dict, str]:
        """Stub: currently does nothing and returns None.

        NOTE(review): the previous implementation (kept below, commented out)
        built a prompt and returned ``(analysis_result, prompt)``; callers such
        as ``main()`` still expect that behavior — confirm before use.
        """
        # if template:
        #     formatter = PromptFormatter(template.template)
        #     if template.keywords:
        #         formatter.partial_format(**template.keywords)
        #     prompt = formatter.format(
        #         product=product,
        #         format_type=format_type
        #     )
        # else:
        #     prompt = await self._prepare_prompt(product, format_type)

        # logger.info(f"prompt: {prompt}")
        # analysis_result = await self.llm_service.analyze(prompt)
        # return analysis_result, prompt
        pass

    async def execute_marketing_analysis(self, product: Product, format_type: str = "json", template: Optional[AIPromptConfig] = None) -> tuple[MarketingInfo, str]:
        """
        Run the marketing-copy analysis for a product.

        Args:
            product: Product to generate marketing copy for.
            format_type: Output format hint forwarded to the template.
            template: Optional custom prompt template; when absent, a default
                Japanese-market prompt is built from the product's basic info.

        Returns:
            Tuple of (parsed MarketingInfo, the prompt that was used).

        Raises:
            ValueError: If the LLM response cannot be parsed into MarketingInfo.
        """
        if template:
            formatter = PromptFormatter(template.template)
            if template.keywords:
                formatter.partial_format(**template.keywords)
            prompt = formatter.format(
                product=product,
                format_type=format_type
            )
        else:
            # Default prompt (runtime string — intentionally left in Chinese).
            prompt = f'''我是亚马逊运营,请为产品 {product.basic_info.name} 生成营销文案。

产品信息:
{product.basic_info.model_dump_json(indent=2)}

要求:
- 突出产品卖点: {', '.join(product.basic_info.selling_point)}
- 适合日本市场风格
- 包含吸引人的标题和详细描述'''

        logger.info(f"营销分析提示词: {prompt}")
        analysis_result = await self.llm_service.analyze(prompt)

        try:
            # In JSON mode the LLM service returns a dict; validate it by
            # constructing the MarketingInfo model.
            marketing_info = MarketingInfo(**analysis_result)
            return marketing_info, prompt
        except Exception as e:
            logger.error(f"营销分析结果解析失败: {str(e)}")
            raise ValueError("营销分析结果格式不正确") from e

    async def _prepare_competitor_prompt(self, product: Product, template: AIPromptConfig) -> str:
        """Prepend the structured competitor/product context to *template*.

        Args:
            product: Product whose competitor data and basic info are embedded.
            template: Prompt template config; its ``template`` attribute is
                mutated in place (the context block is prepended).

        Returns:
            The combined prompt string (also stored back on the template).
        """
        output_fields = get_field_descriptions(
            AICompetitorAnalyzeMainKeywordsResult,
            exclude=['results.crawl_result', 'results.created_at']
        )
        formatter = FormatterFactory.create_formatter(self.llm_service.format_type)
        output_format = formatter.format(output_fields)

        competitor_data = get_competitor_prompt_data(product)
        basic_template =f'''各个字段说明:
{get_field_descriptions(CompetitorCrawlData, include=['asin'])}
{get_field_descriptions(ProductImageInfo, include=['main_text'])}
{get_field_descriptions(TrafficKeywordResult, include=['traffic_keyword', 'monthly_searches'])}

竞品数据:
{competitor_data}

我的产品信息如下:
{product.basic_info.model_dump_json(indent=2)}

返回格式:
{output_format}
----
'''
        # NOTE: mutates template.template; a second call would prepend again.
        template.template = basic_template + template.template
        return template.template

    @staticmethod
    def convert_monthly_searches(value):
        """Normalize a monthly-searches value to an int (or None).

        Accepts ints, None, and strings like "12,345"; blank strings become
        None. Non-string, non-None values are returned unchanged.
        """
        if value is None:
            return None
        if isinstance(value, str):
            if not value.strip():
                return None
            # Strip thousands separators before converting.
            return int(value.replace(',', ''))
        return value

    async def run_competitor_analysis(self, product: Product,
                                      ai_analyze_compare_model: AIAnalyzeCompare,
                                      format_type: str = 'json',):
        """Run the competitor keyword analysis and normalize search counts.

        Builds the competitor prompt from ``ai_analyze_compare_model``'s
        template, calls the LLM, then coerces every ``monthly_searches``
        field in ``results`` and ``tail_keys`` to an int.
        """
        prompt = await self._prepare_competitor_prompt(product, ai_analyze_compare_model.competitor_template)
        logger.info(f"_prepare_competitor_prompt {prompt}")
        # analyze_result is expected to be a dict here (JSON mode);
        # NOTE(review): in non-JSON mode it would be a str and the `in`
        # checks below would scan substrings — confirm format_type usage.
        analyze_result = await self.llm_service.analyze(prompt)
        if 'results' in analyze_result:
            for result in analyze_result['results']:
                if 'monthly_searches' in result:
                    result['monthly_searches'] = self.convert_monthly_searches(result['monthly_searches'])

        if 'tail_keys' in analyze_result:
            for tail_key in analyze_result['tail_keys']:
                if 'monthly_searches' in tail_key:
                    tail_key['monthly_searches'] = self.convert_monthly_searches(tail_key['monthly_searches'])
        return analyze_result

    async def _prepare_prompt(self, product: Product, format_type: str = "json", main_key_num: int = 3, tail_key_num: int = 12) -> str:
        """Build the default main/tail keyword-selection prompt.

        Args:
            product: Product whose competitor data and basic info are embedded.
            format_type: Formatter used for the expected-output section.
            main_key_num: Number of main keywords to request.
            tail_key_num: Minimum number of tail keywords to request.

        Returns:
            The fully rendered prompt string.
        """
        competitor_data = get_competitor_prompt_data(product)
        # Pull the expected-output field descriptions from the data model.
        output_fields = get_field_descriptions(
            AICompetitorAnalyzeMainKeywordsResult,
            exclude=['results.crawl_result', 'results.created_at']
        )
        formatter = FormatterFactory.create_formatter(format_type)
        output_format = formatter.format(output_fields)

        # Runtime prompt text — intentionally left in Chinese.
        return f'''各个字段说明:
{get_field_descriptions(CompetitorCrawlData, include=['asin'])}
{get_field_descriptions(ProductImageInfo, include=['main_text'])}
{get_field_descriptions(TrafficKeywordResult, include=['traffic_keyword', 'monthly_searches'])}

竞品数据:
{competitor_data}

我的产品信息如下:
{product.basic_info.model_dump_json(indent=2)}
----
我是日本站的亚马逊运营,我在给产品名称为 {product.basic_info.name} 选主要关键词和长尾关键词。

请根据以上 {len(competitor_data)} 个竞品数据,按以下规则分析:
- 选出搜索量在1万以上的相同关键词作为主要关键词{main_key_num}个。
- 如果竞品的搜索量都不足1万,则从排名前十的关键词中筛选 {main_key_num} 个搜索量最大且相关性最强的词。
- 结合日本市场特点分析
- 根据我的产品基本信息,从竞品的主要信息和同类竞品的相似关键词中,筛选出最符合我产品的长尾关键词 tail_keys {tail_key_num} 个以上

筛选长尾词的示例:
- 假设我的产品是电线保护,那么竞品关键词中,“隐藏排线管” 就不符合长尾词
- 假设我的产品是“防老化、防动物咬”用途,你就不能在竞品数据中选择不属于我这个使用场景的长尾关键词。

输出格式:
{output_format}'''
|
|
|
-
|
|
|
async def main():
    """Manual smoke-test entry point: wires the services and runs one analysis."""
    logger.info(f"base url {os.environ.get('OPENAI_API_BASE')}")
    db_manager = BaseMongoManager()
    llm_service = LiteLLMService(format_type='json')
    analysis_service = AnalysisService(llm_service, db_manager)

    try:
        # NOTE(review): execute_analysis is annotated to take a Product and is
        # currently a stub returning None, so this call passes the wrong type
        # and would log `null` — confirm intended behavior before relying on it.
        result = await analysis_service.execute_analysis("电线保护套")
        logger.info(f"分析结果: {json.dumps(result, indent=2, ensure_ascii=False)}")
    except ValueError as e:
        logger.error(f"分析失败: {str(e)}")

if __name__ == "__main__":
    asyncio.run(main())
|