فهرست منبع

完成AI分析关键词模块化可扩展性

mrh 8 ماه پیش
والد
کامیت
b9e74a95f7
2 فایل تغییر یافته به همراه 172 افزوده شده و 124 حذف شده
  1. 1 1
      .clinerules
  2. 171 123
      src/ai/agent_product.py

+ 1 - 1
.clinerules

@@ -1,3 +1,3 @@
 重要:
-- 编码遵循高内聚、低耦合、可扩展,符合最佳程序设计,符合最佳实践。
+- 编码遵循模块化、高内聚、低耦合、可扩展,符合最佳程序设计,符合最佳实践。
 - 不要硬编码,必须要通用可复用、可扩展。

+ 171 - 123
src/ai/agent_product.py

@@ -1,3 +1,4 @@
+from abc import ABC, abstractmethod
 import json
 from typing import Optional, Union
 from llama_index.core import PromptTemplate
@@ -20,29 +21,36 @@ from src.models.field_config import FieldConfig
 load_dotenv()
 logger = get_logger('ai')
 
-# 默认包含的字段配置
-DEFAULT_FIELD_CONFIG = FieldConfig(
-    include_fields={
-        "ProductImageInfo": {
-            "main_text"  # 产品图片主要文字
-        },
-        "TrafficKeywordResult": {
-            "traffic_keyword",  # 流量关键词名称
-            "monthly_searches"  # 关键词月搜索量
-        },
-        "ProductBaseInfo": {
-            "name", "content", "material", "color", "size",
-            "packaging_size", "weight", "main_usage", "selling_point"
-        },
-        "CompetitorCrawlData": {
-            "asin",
class ConfigManager:
    """Singleton holder for the process-wide :class:`FieldConfig`.

    NOTE(review): functions in this module use ``ConfigManager.get_field_config()``
    as a *default argument value*; Python evaluates defaults once at function
    definition time, so a later :meth:`update_config` does NOT change those
    defaults — confirm this is intended.
    """

    _instance = None
    # Built-in default field selection used when building competitor prompts.
    _default_config = FieldConfig(
        include_fields={
            "ProductImageInfo": {"main_text"},
            "TrafficKeywordResult": {"traffic_keyword", "monthly_searches"},
            "ProductBaseInfo": {
                "name", "content", "material", "color", "size",
                "packaging_size", "weight", "main_usage", "selling_point"
            },
            "CompetitorCrawlData": {"asin"}
        }
    )
    # Currently active configuration (starts as the default).
    _config = _default_config

    def __new__(cls):
        # Classic singleton: every instantiation yields the same object.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    @classmethod
    def get_field_config(cls) -> FieldConfig:
        """Return the currently active field configuration."""
        return cls._config

    @classmethod
    def update_config(cls, new_config: FieldConfig) -> None:
        """Replace the active field configuration."""
        cls._config = new_config

    @classmethod
    def reset_config(cls) -> None:
        """Restore the built-in default configuration.

        Fix: previously ``update_config`` discarded the default with no way
        to recover it. New, backward-compatible addition.
        """
        cls._config = cls._default_config
 
 def get_competitor_prompt_data(
     product: Product,
-    field_config: FieldConfig = DEFAULT_FIELD_CONFIG
+    field_config: FieldConfig = ConfigManager.get_field_config()
 ) -> list:
     """
     获取竞品提示数据
@@ -80,7 +88,7 @@ def get_competitor_prompt_data(
 
 def get_field_descriptions(
     model_class: BaseModel,
-    field_config: FieldConfig = DEFAULT_FIELD_CONFIG,
+    field_config: FieldConfig = ConfigManager.get_field_config(),
     model_name: Optional[str] = None
 ) -> dict:
     """
@@ -96,34 +104,51 @@ def get_field_descriptions(
     """
     return field_config.get_model_fields(model_class, model_name)
 
-def format_output(fields_desc: dict, format_type: str = "json", notes: Optional[dict] = None):
-    """根据字段描述生成输出格式(支持嵌套字典结构)
-    
-    Args:
-        fields_desc: 字段描述字典(支持嵌套)
-        format_type: 输出格式类型(json/human)
-        notes: 格式特定的额外说明信息字典,如 {"json": "JSON格式说明", "human": "人类可读说明"}
-        
-    Returns:
-        格式化后的输出模板
-    """
-    def process_dict(d, format_type):
-        if format_type == "human":
class Formatter(ABC):
    """Abstract base for prompt-output formatters.

    Subclasses render a (possibly nested) field-description dict into the
    textual output template embedded in the LLM prompt.
    """

    def __init__(self, notes: Optional[dict] = None):
        # Per-format extra remarks, e.g. {"json": "...", "human": "..."}.
        self.notes = notes or {}

    @abstractmethod
    def format(self, fields_desc: dict) -> str:
        """Render *fields_desc* to a string."""

class JSONFormatter(Formatter):
    """Renders the field descriptions as a fenced JSON code block."""

    def format(self, fields_desc: dict) -> str:
        body = f"```json\n{json.dumps(fields_desc, indent=2, ensure_ascii=False)}\n```"
        # Single append point (the original duplicated the fenced-block
        # f-string across two branches).
        note = self.notes.get('json')
        return f"{body}\n{note}" if note else body

class HumanFormatter(Formatter):
    """Renders the field descriptions as an indented, human-readable outline."""

    def format(self, fields_desc: dict) -> str:
        def render(section: dict, indent: int = 0) -> str:
            pad = " " * indent
            lines = []
            for key, value in section.items():
                if isinstance(value, dict):
                    # Nested section: header line, children indented by 2.
                    lines.append(f"{pad}{key}:")
                    lines.append(render(value, indent + 2))
                else:
                    # Leaf: "<description>: {<field_name>}" placeholder line.
                    lines.append(f"{pad}{value}: {{{key}}}")
            return "\n".join(lines)

        result = render(fields_desc)
        note = self.notes.get('human')
        return f"{result}\n{note}" if note else result

class FormatterFactory:
    """Creates the :class:`Formatter` matching a format-type string."""

    # Registry dispatch instead of an if/elif chain, so new formats can be
    # registered without editing this method.
    _formatters = {
        "json": JSONFormatter,
        "human": HumanFormatter,
    }

    @staticmethod
    def create_formatter(format_type: str, notes: Optional[dict] = None) -> Formatter:
        """Return a formatter for *format_type*.

        Raises:
            ValueError: if *format_type* is not a registered format.
        """
        try:
            return FormatterFactory._formatters[format_type](notes)
        except KeyError:
            raise ValueError(f"Unsupported format type: {format_type}") from None
 
 async def test_product_mongo(main_key_num=3, format_type: str = "json"):
     db_mongo = BaseMongoManager()
@@ -184,94 +209,117 @@ async def test_product_mongo(main_key_num=3, format_type: str = "json"):
     logger.info(formatted_output)
     return formatted_output
 
-async def analyze_with_llm(
-    prompt: str,
-    model: str = "openai/deepseek-chat",
-    max_retries: int = 3,
-    retry_delay: float = 1.0,
-    format_type: str = "json"
-) -> Union[dict, str]:
-    """使用LLM分析数据并返回结果
-    
-    Args:
-        prompt: 完整的提示词
-        model: 模型名称
-        max_retries: 最大重试次数
-        retry_delay: 重试延迟(秒)
-        format_type: 输出格式类型(json/human)
-        
-    Returns:
-        dict: 当format_type为json时的解析结果
-        str: 当format_type为human时的原始文本
-        
-    Raises:
-        ValueError: 当无法获取有效响应时
-    """
-    llm_kwargs = {}
-    if format_type == "json":
-        llm_kwargs["additional_kwargs"] = {"response_format": {"type": "json_object"}}
-    
-    llm = LiteLLM(model=model, **llm_kwargs)
-    
-    for attempt in range(max_retries):
-        try:
-            logger.info(f"尝试第 {attempt + 1} 次LLM调用...")
-            completion = await llm.acomplete(prompt)
-            response_text = completion.text
-            
-            if format_type == "json":
-                # 尝试从Markdown代码块中提取JSON
-                if "```json" in response_text:
-                    json_str = response_text.split("```json")[1].split("```")[0].strip()
-                else:
-                    json_str = response_text
-                    
-                result = json.loads(json_str)
-                
-                if not isinstance(result, dict):
-                    raise ValueError("响应不是有效的JSON对象")
-                    
-                logger.debug(f"LLM响应验证通过: {json.dumps(result, indent=2, ensure_ascii=False)}")
-                return result
-            else:
-                # 直接返回原始文本
-                return response_text
-                
-        except (json.JSONDecodeError, ValueError) as e:
-            if format_type == "json":
-                logger.warning(f"JSON解析失败(尝试 {attempt + 1}/{max_retries}): {str(e)}")
-                if attempt < max_retries - 1:
-                    await asyncio.sleep(retry_delay)
-                else:
-                    raise ValueError(f"无法获取有效的JSON响应: {str(e)}")
-            else:
class LLMService(ABC):
    """Abstract interface for LLM analysis backends.

    Fix: the original declared ``@abstractmethod`` without inheriting
    :class:`abc.ABC`, so the abstract contract was never enforced and the
    "abstract" class could be instantiated with a no-op ``analyze``.
    """

    @abstractmethod
    async def analyze(self, prompt: str) -> Union[dict, str]:
        """Run the LLM on *prompt*; return a parsed dict (json) or raw text."""
+
class LiteLLMService(LLMService):
    """:class:`LLMService` backed by LiteLLM, with retry + JSON extraction.

    Constructor args mirror the old ``analyze_with_llm()`` free function:
        model: LiteLLM model identifier.
        max_retries: attempts allowed for JSON parse failures.
        retry_delay: seconds to sleep between attempts.
        format_type: "json" (parse and validate) or "human" (raw text).
    """

    def __init__(self, model: str = "openai/deepseek-chat", max_retries: int = 3,
                 retry_delay: float = 1.0, format_type: str = "json"):
        self.model = model
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.format_type = format_type

    async def analyze(self, prompt: str) -> Union[dict, str]:
        """Call the model, retrying on invalid JSON.

        Raises:
            ValueError: when no valid JSON response is obtained within
                ``max_retries`` attempts (json mode).
        """
        llm_kwargs = {}
        if self.format_type == "json":
            # Ask the provider for a JSON object directly when supported.
            llm_kwargs["additional_kwargs"] = {"response_format": {"type": "json_object"}}

        llm = LiteLLM(model=self.model, **llm_kwargs)

        for attempt in range(self.max_retries):
            try:
                logger.info(f"尝试第 {attempt + 1} 次LLM调用...")
                completion = await llm.acomplete(prompt)
                return self._process_response(completion.text)
            except (json.JSONDecodeError, ValueError) as e:
                if self.format_type != "json":
                    # Bug fix: non-JSON mode previously swallowed the error
                    # (no re-raise), fell out of the loop and returned None
                    # implicitly. The pre-refactor function re-raised here.
                    raise
                logger.warning(f"JSON解析失败(尝试 {attempt + 1}/{self.max_retries}): {str(e)}")
                if attempt < self.max_retries - 1:
                    await asyncio.sleep(self.retry_delay)
                else:
                    raise ValueError(f"无法获取有效的JSON响应: {str(e)}")
            except Exception as e:
                logger.error(f"LLM调用失败: {str(e)}")
                raise
        # Defensive: only reachable when max_retries <= 0 — never return None.
        raise ValueError("无法获取有效的JSON响应: 重试次数为0")

    def _process_response(self, response_text: str) -> Union[dict, str]:
        """Validate/parse the raw completion according to ``format_type``."""
        if self.format_type != "json":
            return response_text
        # Strip a Markdown ```json fence if the model wrapped its answer.
        if "```json" in response_text:
            json_str = response_text.split("```json")[1].split("```")[0].strip()
        else:
            json_str = response_text

        result = json.loads(json_str)
        if not isinstance(result, dict):
            raise ValueError("响应不是有效的JSON对象")

        logger.debug(f"LLM响应验证通过: {json.dumps(result, indent=2, ensure_ascii=False)}")
        return result
+
class AnalysisService:
    """Domain service orchestrating competitor main-keyword analysis."""

    def __init__(self, llm_service: LLMService, db_manager: BaseMongoManager):
        # Injected collaborators: any LLMService implementation + Mongo manager.
        self.llm_service = llm_service
        self.db_manager = db_manager

    async def execute_analysis(self, product_name: str, format_type: str = "json") -> Union[dict, str]:
        """Look up the product by name, build the prompt, and run the LLM.

        Raises:
            ValueError: if no product named *product_name* exists.

        NOTE(review): *format_type* only shapes the prompt's output template;
        response parsing is governed by the llm_service's own format_type set
        at construction — confirm callers keep the two in sync.
        """
        await self.db_manager.initialize()
        product = await Product.find_one(Product.basic_info.name == product_name)
        if not product:
            raise ValueError(f"未找到产品: {product_name}")

        prompt = await self._prepare_prompt(product, format_type)
        return await self.llm_service.analyze(prompt)

    async def _prepare_prompt(self, product: Product, format_type: str) -> str:
        """Assemble the Japan-market keyword-selection prompt for *product*."""
        competitor_data = get_competitor_prompt_data(product)
        # Output schema the LLM is asked to fill (keys become placeholders,
        # values are the human-readable descriptions shown in the template).
        output_fields = {
            "results": {
                "asin": "商品(竞品)编号",
                "main_key": "主要关键词",
                "monthly_searches": "月搜索量",
                "reason": "分析理由"
            },
            "supplement": "补充说明"
        }

        formatter = FormatterFactory.create_formatter(format_type)
        output_format = formatter.format(output_fields)

        # Prompt text is user-facing model input — kept byte-identical.
        return f'''各个字段说明:
{get_field_descriptions(CompetitorCrawlData)}
{get_field_descriptions(ProductImageInfo)}
{get_field_descriptions(TrafficKeywordResult)}

竞品数据:
{competitor_data}
----
我是日本站的亚马逊运营,正在为产品 {product.basic_info.name} 选主要关键词。
请根据以上 {len(competitor_data)} 个竞品数据,按以下规则分析:
1. 优先选择搜索量1万以上的相同关键词
2. 不足时选择搜索量前十且相关性强的关键词
3. 结合日本市场特点分析

输出格式:
{output_format}'''
 
 async def main():
     logger.info(f"base url {os.environ.get('OPENAI_API_BASE')}")
-    format_type = 'human'  # 可以从配置或参数获取
-    format_type = 'json'  
-    analyze_competitor_main_keyword_prompt = await test_product_mongo(format_type=format_type)
+    db_manager = BaseMongoManager()
+    llm_service = LiteLLMService(format_type='json')
+    analysis_service = AnalysisService(llm_service, db_manager)
+
     try:
-        result = await analyze_with_llm(
-            analyze_competitor_main_keyword_prompt,
-            format_type=format_type
-        )
-        
-        if format_type == "json":
-            logger.info(f"分析结果: {json.dumps(result, indent=2, ensure_ascii=False)}")
-        else:
-            logger.info(f"分析结果:\n{result}")
-        
-        # 这里可以添加结果处理逻辑
-        # 例如保存到数据库或进一步处理
-        
+        result = await analysis_service.execute_analysis("电线保护套")
+        logger.info(f"分析结果: {json.dumps(result, indent=2, ensure_ascii=False)}")
     except ValueError as e:
         logger.error(f"分析失败: {str(e)}")
-        # 可以添加失败处理逻辑
 
 if __name__ == "__main__":
     asyncio.run(main())