|
|
@@ -1,10 +1,11 @@
|
|
|
import json
|
|
|
-from typing import Optional
|
|
|
+from typing import Optional, Union
|
|
|
from llama_index.core import PromptTemplate
|
|
|
import asyncio
|
|
|
import aiofiles
|
|
|
import os
|
|
|
import sys
|
|
|
+from dotenv import load_dotenv
|
|
|
from pydantic import BaseModel
|
|
|
from src.models.product_model import (
|
|
|
Product, CompetitorCrawlData, AICompetitorAnalyzeMainKeywords,
|
|
|
@@ -12,10 +13,11 @@ from src.models.product_model import (
|
|
|
SearchAmazoneKeyResult, ProductBaseInfo, Variant
|
|
|
)
|
|
|
from llama_index.llms.openai import OpenAI
|
|
|
+from llama_index.llms.litellm import LiteLLM
|
|
|
from src.manager.core.db_mongo import BaseMongoManager
|
|
|
from utils.logu import get_logger
|
|
|
from src.models.field_config import FieldConfig
|
|
|
-
|
|
|
+load_dotenv()
|
|
|
logger = get_logger('ai')
|
|
|
|
|
|
# 默认包含的字段配置
|
|
|
@@ -94,27 +96,36 @@ def get_field_descriptions(
|
|
|
"""
|
|
|
return field_config.get_model_fields(model_class, model_name)
|
|
|
|
|
|
-def format_output(fields_desc: dict, format_type: str = "json"):
|
|
|
- """根据字段描述生成输出格式
|
|
|
+def format_output(fields_desc: dict, format_type: str = "json", notes: Optional[dict] = None):
|
|
|
+ """根据字段描述生成输出格式(支持嵌套字典结构)
|
|
|
|
|
|
Args:
|
|
|
- fields_desc: 字段描述字典
|
|
|
+ fields_desc: 字段描述字典(支持嵌套)
|
|
|
format_type: 输出格式类型(json/human)
|
|
|
+ notes: 格式特定的额外说明信息字典,如 {"json": "JSON格式说明", "human": "人类可读说明"}
|
|
|
|
|
|
Returns:
|
|
|
格式化后的输出模板
|
|
|
"""
|
|
|
- if format_type == "human":
|
|
|
- # 动态生成human-readable格式
|
|
|
- lines = []
|
|
|
- for field, desc in fields_desc.items():
|
|
|
- lines.append(f"{desc}: {{{field}}}")
|
|
|
- return "\n".join(lines)
|
|
|
- else:
|
|
|
- # JSON格式返回字段名到空值的映射
|
|
|
- return [{k: "" for k in fields_desc.keys()}]
|
|
|
+ def process_dict(d, format_type):
|
|
|
+ if format_type == "human":
|
|
|
+ lines = []
|
|
|
+ for key, value in d.items():
|
|
|
+ if isinstance(value, dict):
|
|
|
+ nested = process_dict(value, format_type)
|
|
|
+ lines.append(f"{key}:\n{nested}")
|
|
|
+ else:
|
|
|
+ lines.append(f"{value}: {{{key}}}")
|
|
|
+ return "\n".join(lines)
|
|
|
+ else:
|
|
|
+ return f"```json\n{json.dumps(d, indent=2, ensure_ascii=False)}\n```" if format_type == "json" else d
|
|
|
+
|
|
|
+ result = process_dict(fields_desc, format_type)
|
|
|
+ if notes and notes.get(format_type):
|
|
|
+ result += f"\n{notes[format_type]}"
|
|
|
+ return result
|
|
|
|
|
|
-async def test_product_mongo(format_type: str = "json"):
|
|
|
+async def test_product_mongo(main_key_num=3, format_type: str = "json"):
|
|
|
db_mongo = BaseMongoManager()
|
|
|
await db_mongo.initialize()
|
|
|
product = await Product.find_one(Product.basic_info.name == "电线保护套")
|
|
|
@@ -127,14 +138,22 @@ async def test_product_mongo(format_type: str = "json"):
|
|
|
|
|
|
# 定义输出字段描述
|
|
|
output_fields = {
|
|
|
- "asin": "商品(竞品)编号",
|
|
|
- "main_key": "主要关键词",
|
|
|
- "monthly_searches": "月搜索量",
|
|
|
- "reason": "选择理由"
|
|
|
+ "results":{
|
|
|
+ "asin": "商品(竞品)编号",
|
|
|
+ "main_key": "主要关键词",
|
|
|
+ "monthly_searches": "月搜索量",
|
|
|
+ "reason": "分析你选出的关键词理由"
|
|
|
+ },
|
|
|
+ "supplement": "非必填,仅当你觉得有必要时才提供补充说明。"
|
|
|
}
|
|
|
- format_type = 'human'
|
|
|
+ # 格式特定的额外说明信息
|
|
|
+ notes = {
|
|
|
+ # "json": "json 信息时必须",
|
|
|
+ # "human": "人类可读信息时必须"
|
|
|
+ }
|
|
|
+
|
|
|
# 生成输出格式
|
|
|
- output_format = format_output(output_fields, format_type)
|
|
|
+ output_format = format_output(output_fields, format_type, notes)
|
|
|
|
|
|
logger.info(f"competitor_desc {competitor_desc}")
|
|
|
logger.info(f"product_info_desc {product_info_desc}")
|
|
|
@@ -147,10 +166,11 @@ async def test_product_mongo(format_type: str = "json"):
|
|
|
竞品数据:
|
|
|
{competitor_data}
|
|
|
----
|
|
|
-我是亚马逊运营,我在给产品名称为 {product_name} 选主要关键词,以上数据是我从同类竞品的关键词搜索量数据,总共有 {competitor_count} 个竞品数据。
|
|
|
+我是日本站的亚马逊运营,我在给产品名称为 {product_name} 选主要关键词,以上数据是我从同类竞品的关键词搜索量数据,总共有 {competitor_count} 个竞品数据。
|
|
|
请帮我分析这些竞品数据,选出搜索量在1万以上的相同关键词作为主要关键词3个。
|
|
|
-如果竞品的搜索量都不足1万,则从排名前十的关键词中筛选三个搜索量最大且相关性最强的词。
|
|
|
-输出格式:
|
|
|
+如果竞品的搜索量都不足1万,则从排名前十的关键词中筛选 {main_key_num} 个搜索量最大且相关性最强的词。
|
|
|
+建议结合日本的市场和用户习惯来分析关键词。
|
|
|
+请输出以下格式:
|
|
|
{output_format}
|
|
|
'''
|
|
|
text_qa_template = PromptTemplate(analyz_main_keyword_template_str)
|
|
|
@@ -162,8 +182,96 @@ async def test_product_mongo(format_type: str = "json"):
|
|
|
output_format=output_format,
|
|
|
)
|
|
|
logger.info(formatted_output)
|
|
|
-def main():
|
|
|
- asyncio.run(test_product_mongo())
|
|
|
+ return formatted_output
|
|
|
+
|
|
|
+async def analyze_with_llm(
|
|
|
+ prompt: str,
|
|
|
+ model: str = "openai/deepseek-chat",
|
|
|
+ max_retries: int = 3,
|
|
|
+ retry_delay: float = 1.0,
|
|
|
+ format_type: str = "json"
|
|
|
+) -> Union[dict, str]:
|
|
|
+ """使用LLM分析数据并返回结果
|
|
|
+
|
|
|
+ Args:
|
|
|
+ prompt: 完整的提示词
|
|
|
+ model: 模型名称
|
|
|
+ max_retries: 最大重试次数
|
|
|
+ retry_delay: 重试延迟(秒)
|
|
|
+ format_type: 输出格式类型(json/human)
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ dict: 当format_type为json时的解析结果
|
|
|
+ str: 当format_type为human时的原始文本
|
|
|
+
|
|
|
+ Raises:
|
|
|
+ ValueError: 当无法获取有效响应时
|
|
|
+ """
|
|
|
+ llm_kwargs = {}
|
|
|
+ if format_type == "json":
|
|
|
+ llm_kwargs["additional_kwargs"] = {"response_format": {"type": "json_object"}}
|
|
|
+
|
|
|
+ llm = LiteLLM(model=model, **llm_kwargs)
|
|
|
+
|
|
|
+ for attempt in range(max_retries):
|
|
|
+ try:
|
|
|
+ logger.info(f"尝试第 {attempt + 1} 次LLM调用...")
|
|
|
+ completion = await llm.acomplete(prompt)
|
|
|
+ response_text = completion.text
|
|
|
+
|
|
|
+ if format_type == "json":
|
|
|
+ # 尝试从Markdown代码块中提取JSON
|
|
|
+ if "```json" in response_text:
|
|
|
+ json_str = response_text.split("```json")[1].split("```")[0].strip()
|
|
|
+ else:
|
|
|
+ json_str = response_text
|
|
|
+
|
|
|
+ result = json.loads(json_str)
|
|
|
+
|
|
|
+ if not isinstance(result, dict):
|
|
|
+ raise ValueError("响应不是有效的JSON对象")
|
|
|
+
|
|
|
+ logger.debug(f"LLM响应验证通过: {json.dumps(result, indent=2, ensure_ascii=False)}")
|
|
|
+ return result
|
|
|
+ else:
|
|
|
+ # 直接返回原始文本
|
|
|
+ return response_text
|
|
|
+
|
|
|
+ except (json.JSONDecodeError, ValueError) as e:
|
|
|
+ if format_type == "json":
|
|
|
+ logger.warning(f"JSON解析失败(尝试 {attempt + 1}/{max_retries}): {str(e)}")
|
|
|
+ if attempt < max_retries - 1:
|
|
|
+ await asyncio.sleep(retry_delay)
|
|
|
+ else:
|
|
|
+ raise ValueError(f"无法获取有效的JSON响应: {str(e)}")
|
|
|
+ else:
|
|
|
+ raise
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"LLM调用失败: {str(e)}")
|
|
|
+ raise
|
|
|
+
|
|
|
+async def main():
|
|
|
+ logger.info(f"base url {os.environ.get('OPENAI_API_BASE')}")
|
|
|
+ format_type = 'human' # 可以从配置或参数获取
|
|
|
+ format_type = 'json'
|
|
|
+ analyze_competitor_main_keyword_prompt = await test_product_mongo(format_type=format_type)
|
|
|
+ try:
|
|
|
+ result = await analyze_with_llm(
|
|
|
+ analyze_competitor_main_keyword_prompt,
|
|
|
+ format_type=format_type
|
|
|
+ )
|
|
|
+
|
|
|
+ if format_type == "json":
|
|
|
+ logger.info(f"分析结果: {json.dumps(result, indent=2, ensure_ascii=False)}")
|
|
|
+ else:
|
|
|
+ logger.info(f"分析结果:\n{result}")
|
|
|
+
|
|
|
+ # 这里可以添加结果处理逻辑
|
|
|
+ # 例如保存到数据库或进一步处理
|
|
|
+
|
|
|
+ except ValueError as e:
|
|
|
+ logger.error(f"分析失败: {str(e)}")
|
|
|
+ # 可以添加失败处理逻辑
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
- main()
|
|
|
+ asyncio.run(main())
|