|
|
@@ -1,4 +1,6 @@
|
|
|
import asyncio
|
|
|
+from datetime import datetime
|
|
|
+import re
|
|
|
import aiofiles
|
|
|
import os
|
|
|
import sys
|
|
|
@@ -8,13 +10,19 @@ from pydantic import BaseModel
|
|
|
from typing import List
|
|
|
from llama_index.core.program import LLMTextCompletionProgram
|
|
|
from llama_index.core.output_parsers import PydanticOutputParser
|
|
|
+from llama_index.program.lmformatenforcer import (
|
|
|
+ LMFormatEnforcerPydanticProgram,
|
|
|
+)
|
|
|
+from llama_index.core.prompts import PromptTemplate
|
|
|
+from llama_index.llms.llama_cpp import LlamaCPP
|
|
|
from llama_index.llms.openai import OpenAI
|
|
|
from llama_index.llms.litellm import LiteLLM
|
|
|
from llama_index.core.llms.llm import LLM
|
|
|
from src.models.product_model import Product
|
|
|
from src.manager.template_manager import TemplateManager, TemplateService, TemplateType
|
|
|
from src.models.ai_execution_record import MarketingInfo, LLMConfig, SuperPromptMixin, AgentConfig, AgentContent, AICompetitorAnalyzeMainKeywords, AICompetitorAnalyzeMainKeywordsResult, MarketingContentGeneration
|
|
|
-from config.settings import MONGO_URL, MONGO_DB_NAME,LITELLM_API_BASE, LITELLM_API_KEY
|
|
|
+from utils.file import save_to_file, read_file
|
|
|
+from config.settings import MONGO_URL, MONGO_DB_NAME,LITELLM_API_BASE, LITELLM_API_KEY,OPENAI_API_KEY,OPENAI_API_BASE
|
|
|
from utils.logu import get_logger
|
|
|
logger = get_logger("ai")
|
|
|
|
|
|
@@ -25,9 +33,9 @@ class BaseAgent:
|
|
|
|
|
|
def get_mainkeys_tailkeys(self, template_str: str):
|
|
|
pass
|
|
|
-
|
|
|
+
|
|
|
class MarketingAgent(BaseAgent):
|
|
|
- async def get_mainkeys_tailkeys_prompt(self, product_name, prompt: str='', verbose=False):
|
|
|
+ async def get_mainkeys_tailkeys_prompt(self, product_name, prompt: str='',output_type='markdown', verbose=False):
|
|
|
base_prompt = "{{product_info}}\n{{competitor_info}}\n"
|
|
|
prompt_mainkyes = prompt or '''\
|
|
|
你是日本站的亚马逊运营,请你根据产品信息为用户选出主要关键词和长尾关键词。
|
|
|
@@ -47,34 +55,52 @@ class MarketingAgent(BaseAgent):
|
|
|
生成的内容满足以下要求:
|
|
|
- reason 、 suggestions 必须写成中文
|
|
|
- monthly_searches 必须是 int ,按照从大到小排序,别的字段按照源数据填写即可。
|
|
|
+- 内容格式必须是 {output_type}
|
|
|
'''
|
|
|
variables = {'product_name': product_name}
|
|
|
product_info = await self.template_manager.execute_template("product_info", variables)
|
|
|
competitor_info = await self.template_manager.execute_template("competitor_for_llm", variables)
|
|
|
+ prompt_tmpl = PromptTemplate(template=base_prompt + prompt_mainkyes, )
|
|
|
+ return prompt_tmpl.partial_format(
|
|
|
+ product_info=product_info,
|
|
|
+ competitor_info=competitor_info,
|
|
|
+ output_type=output_type,
|
|
|
+ )
|
|
|
+    def llm_mainkeys_tailkeys_to_model(self, prompt_template: PromptTemplate, prompt_kwargs: dict = {}, verbose=False):
|
|
|
program = LLMTextCompletionProgram.from_defaults(
|
|
|
- output_parser=PydanticOutputParser(output_cls=AICompetitorAnalyzeMainKeywordsResult),
|
|
|
+ output_parser=PydanticOutputParser(output_cls=MarketingInfo),
|
|
|
+            prompt_template_str=prompt_template.format(**prompt_kwargs),
|
|
|
llm=self.llm,
|
|
|
- prompt_template_str=base_prompt + prompt_mainkyes,
|
|
|
verbose=verbose,
|
|
|
)
|
|
|
- competitor = program(product_info=product_info, competitor_info=competitor_info)
|
|
|
+ competitor = program(**prompt_kwargs)
|
|
|
logger.info(f"{competitor}")
|
|
|
return competitor
|
|
|
- async def gen_mainkeys_tailkeys(self, product_name, prompt: str='', verbose=False, overwrite=False):
|
|
|
- agent_model = await AgentContent.find_one(AgentContent.model_name == self.llm._get_model_name() and AgentContent.product_name == product_name)
|
|
|
- if not overwrite and agent_model:
|
|
|
- logger.info(f"agent_model exist")
|
|
|
+ async def gen_mainkeys_tailkeys(self, product_name, prompt: str='',output_type='markdown', verbose=False, overwrite=False):
|
|
|
+ logger.info(f"start llm model: {self.llm._get_model_name()}, output_type: {output_type}")
|
|
|
+ agent_model = await AgentContent.find_one(AgentContent.model_name == self.llm._get_model_name(),AgentContent.product_name == product_name)
|
|
|
+ if agent_model and not overwrite:
|
|
|
+ logger.info(f"agent_model exist {agent_model}")
|
|
|
return agent_model
|
|
|
- competitor = await self.get_mainkeys_tailkeys_prompt(product_name, prompt, verbose)
|
|
|
- product_model = await Product.find_one(Product.basic_info.name == product_name)
|
|
|
- agent_model = AgentContent(
|
|
|
- model_name=self.llm._get_model_name(),
|
|
|
- product=product_model,
|
|
|
- product_name=product_name,
|
|
|
- competitor=competitor,
|
|
|
- )
|
|
|
- await agent_model.save()
|
|
|
- return agent_model
|
|
|
+ elif not agent_model:
|
|
|
+ agent_model = AgentContent(model_name=self.llm._get_model_name(), product_name=product_name)
|
|
|
+
|
|
|
+ prompt_template = await self.get_mainkeys_tailkeys_prompt(product_name, prompt,output_type, verbose)
|
|
|
+ logger.info(f"{prompt_template.format()}")
|
|
|
+ if output_type == 'json':
|
|
|
+            competitor = self.llm_mainkeys_tailkeys_to_model(prompt_template, verbose=verbose)
|
|
|
+ elif output_type == 'markdown':
|
|
|
+ response = await self.llm.acomplete(prompt_template.format())
|
|
|
+ pattern = r'```markdown(.*?)```'
|
|
|
+ matches = re.findall(pattern, response.text, re.DOTALL)
|
|
|
+ if not matches:
|
|
|
+ competitor = response.text
|
|
|
+ else:
|
|
|
+ competitor = matches[0]
|
|
|
+ agent_model.product = await Product.find_one(Product.basic_info.name == product_name)
|
|
|
+ agent_model.competitor[output_type] = competitor
|
|
|
+ agent_model.update_time = datetime.now()
|
|
|
+ return await agent_model.save()
|
|
|
|
|
|
async def get_marketing_prompt(self, product_name, prompt: str='', verbose=False, llm=None):
|
|
|
prompt_marketing = prompt or '''\
|
|
|
@@ -98,31 +124,104 @@ class MarketingAgent(BaseAgent):
|
|
|
all_keywords = self.template_manager.execute_template('agent.mainkeys_tailkeys')
|
|
|
variables = {'product_name': product_name}
|
|
|
product_info = await self.template_manager.execute_template("product_info", variables)
|
|
|
- program = LLMTextCompletionProgram.from_defaults(
|
|
|
- output_parser=PydanticOutputParser(output_cls=MarketingInfo),
|
|
|
+ program = LMFormatEnforcerPydanticProgram(
|
|
|
+ output_cls=MarketingInfo,
|
|
|
llm=llm or self.llm,
|
|
|
prompt_template_str="{{product_info}}\n{{all_keywords}}\n" + prompt_marketing,
|
|
|
verbose=verbose,
|
|
|
)
|
|
|
competitor = program(product_info=product_info, all_keywords=all_keywords)
|
|
|
- logger.info(f"{competitor}")
|
|
|
return competitor
|
|
|
|
|
|
- async def gen_marketing_content(self, product_name, prompt: str='', llm=None, verbose=False):
|
|
|
- pass
|
|
|
-
|
|
|
-async def task():
|
|
|
+ async def gen_marketing_file(self, product_name, output_path: str, llm_models: List[str] = []):
|
|
|
+ models = await AgentContent.find(AgentContent.product_name == product_name).to_list()
|
|
|
+
|
|
|
+ # 创建一个字典来存储模型及其基础名称
|
|
|
+ model_dict = {}
|
|
|
+ unsorted_models = []
|
|
|
+
|
|
|
+ for model in models:
|
|
|
+ # 提取基础名称(去掉前面的xxx/xxx)
|
|
|
+ base_name = model.model_name.split('/')[-1] if '/' in model.model_name else model.model_name
|
|
|
+
|
|
|
+ # 检查是否在优先级列表中
|
|
|
+ found = False
|
|
|
+ for priority_model in llm_models:
|
|
|
+ priority_base = priority_model.split('/')[-1] if '/' in priority_model else priority_model
|
|
|
+ if base_name == priority_base:
|
|
|
+ model_dict[priority_model] = (model, base_name) # 存储模型和显示名称
|
|
|
+ found = True
|
|
|
+ break
|
|
|
+
|
|
|
+ if not found:
|
|
|
+ # 对于不在列表中的模型,也存储基础名称
|
|
|
+ unsorted_models.append((model, base_name))
|
|
|
+
|
|
|
+ # 按照llm_models顺序生成内容
|
|
|
+ sorted_content = ''
|
|
|
+ for priority_model in llm_models:
|
|
|
+ if priority_model in model_dict:
|
|
|
+ model, display_name = model_dict[priority_model]
|
|
|
+ markdown = model.competitor.get('markdown', '')
|
|
|
+ logger.info(f"llm_name: {model.model_name} , {markdown[:100]}")
|
|
|
+ sorted_content += f"# {display_name}\n{markdown}\n\n" # 使用基础名称
|
|
|
+
|
|
|
+ # 添加未排序的模型
|
|
|
+ unsorted_content = ''
|
|
|
+ for model, display_name in unsorted_models:
|
|
|
+ markdown = model.competitor.get('markdown', '')
|
|
|
+ logger.info(f"llm_name: {model.model_name} , {markdown[:100]}")
|
|
|
+ unsorted_content += f"# {display_name}\n{markdown}\n\n" # 使用基础名称
|
|
|
+
|
|
|
+ content = sorted_content + unsorted_content
|
|
|
+ return save_to_file(content, output_path)
|
|
|
+async def llm_task():
|
|
|
m = TemplateManager(MONGO_URL, MONGO_DB_NAME)
|
|
|
await m.initialize()
|
|
|
- model = 'groq/groq/llama-3.1-8b-instant'
|
|
|
+ model = 'openai/groq/llama-3.1-8b-instant'
|
|
|
# model = 'groq/groq/qwen-2.5-coder-32b'
|
|
|
- model = 'groq/groq/qwen-2.5-32b'
|
|
|
- llm = LiteLLM(model=model, api_key=LITELLM_API_KEY, api_base=LITELLM_API_BASE)
|
|
|
+ # model = 'openai/glm-4-flash'
|
|
|
+ # model = 'openai/deepseek-v3'
|
|
|
+ # model = 'openai/groq/qwen-2.5-32b'
|
|
|
+ # model = 'openai/deepseek-chat'
|
|
|
+ model = 'openai/deepseek-reasoner'
|
|
|
+ # model = 'openai/doubao-pro-32k-241215'
|
|
|
+ llm_models = [
|
|
|
+ # 'openai/deepseek-v3',
|
|
|
+ 'openai/QwQ-32B',
|
|
|
+ 'openai/deepseek-reasoner',
|
|
|
+ 'openai/doubao-pro-32k-241215',
|
|
|
+ ]
|
|
|
+ task_list = []
|
|
|
+ for model in llm_models:
|
|
|
+ llm = LiteLLM(model=model, api_key=OPENAI_API_KEY, api_base=OPENAI_API_BASE)
|
|
|
+ agent = MarketingAgent(llm=llm, template_manager=m)
|
|
|
+ agent_model = agent.gen_mainkeys_tailkeys(product_name='大尺寸厚款卸妆棉240片', verbose=True, overwrite=True)
|
|
|
+ task_list.append(agent_model)
|
|
|
+ # logger.info(f"{agent_model.competitor.items()}")
|
|
|
+ await asyncio.gather(*task_list)
|
|
|
+ # llm = LiteLLM(model=model, api_key=OPENAI_API_KEY, api_base=OPENAI_API_BASE)
|
|
|
+ # agent = MarketingAgent(llm=llm, template_manager=m)
|
|
|
+ # agent_model = await agent.gen_mainkeys_tailkeys(product_name='大尺寸厚款卸妆棉240片', verbose=True, overwrite=True)
|
|
|
+ # logger.info(f"{agent_model.competitor.items()}")
|
|
|
+async def gen_marketing_file():
|
|
|
+ m = TemplateManager(MONGO_URL, MONGO_DB_NAME)
|
|
|
+ await m.initialize()
|
|
|
+ model = 'openai/deepseek-reasoner'
|
|
|
+ llm = LiteLLM(model=model, api_key=OPENAI_API_KEY, api_base=OPENAI_API_BASE)
|
|
|
+ product_name = '大尺寸厚款卸妆棉240片'
|
|
|
agent = MarketingAgent(llm=llm, template_manager=m)
|
|
|
- agent_model = await agent.gen_mainkeys_tailkeys(product_name='电线保护套')
|
|
|
- logger.info(f"{agent_model.competitor}")
|
|
|
+ output_path = r'G:\code\amazone\copywriting_production\output\temp' + f"\\{product_name}-营销文案.md"
|
|
|
+ llm_models = [
|
|
|
+'openai/doubao-pro-32k-241215',
|
|
|
+'openai/deepseek-reasoner',
|
|
|
+'openai/deepseek-v3',
|
|
|
+'openai/QwQ-32B',
|
|
|
+ ]
|
|
|
+ await agent.gen_marketing_file(product_name=product_name, output_path=output_path, llm_models=llm_models)
|
|
|
+ logger.info(f"{output_path}")
|
|
|
def main():
|
|
|
- asyncio.run(task())
|
|
|
+ asyncio.run(gen_marketing_file())
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
main()
|