Quellcode durchsuchen

完成关键词长尾词分析、并输出成单个文件报告

mrh vor 11 Monaten
Ursprung
Commit
549ad925bf
3 geänderte Dateien mit 137 neuen und 36 gelöschten Zeilen
  1. 132 33
      ai/marketting_agent.py
  2. 2 1
      config/settings.py
  3. 3 2
      src/models/ai_execution_record.py

+ 132 - 33
ai/marketting_agent.py

@@ -1,4 +1,6 @@
 import asyncio
+from datetime import datetime
+import re
 import aiofiles
 import os
 import sys
@@ -8,13 +10,19 @@ from pydantic import BaseModel
 from typing import List
 from llama_index.core.program import LLMTextCompletionProgram
 from llama_index.core.output_parsers import PydanticOutputParser
+from llama_index.program.lmformatenforcer import (
+    LMFormatEnforcerPydanticProgram,
+)
+from llama_index.core.prompts import PromptTemplate
+from llama_index.llms.llama_cpp import LlamaCPP
 from llama_index.llms.openai import OpenAI
 from llama_index.llms.litellm import LiteLLM
 from llama_index.core.llms.llm import LLM
 from src.models.product_model import Product
 from src.manager.template_manager import TemplateManager, TemplateService, TemplateType
 from src.models.ai_execution_record import MarketingInfo, LLMConfig, SuperPromptMixin, AgentConfig, AgentContent, AICompetitorAnalyzeMainKeywords, AICompetitorAnalyzeMainKeywordsResult, MarketingContentGeneration
-from config.settings import MONGO_URL, MONGO_DB_NAME,LITELLM_API_BASE, LITELLM_API_KEY
+from utils.file import save_to_file, read_file
+from config.settings import MONGO_URL, MONGO_DB_NAME,LITELLM_API_BASE, LITELLM_API_KEY,OPENAI_API_KEY,OPENAI_API_BASE
 from utils.logu import get_logger
 logger = get_logger("ai")
 
@@ -25,9 +33,9 @@ class BaseAgent:
     
     def get_mainkeys_tailkeys(self, template_str: str):
         pass
-
+    
 class MarketingAgent(BaseAgent):
-    async def get_mainkeys_tailkeys_prompt(self, product_name, prompt: str='', verbose=False):
+    async def get_mainkeys_tailkeys_prompt(self, product_name, prompt: str='',output_type='markdown', verbose=False):
         base_prompt = "{{product_info}}\n{{competitor_info}}\n"
         prompt_mainkyes = prompt or '''\
 你是日本站的亚马逊运营,请你根据产品信息为用户选出主要关键词和长尾关键词。
@@ -47,34 +55,52 @@ class MarketingAgent(BaseAgent):
 生成的内容满足以下要求:
 - reason 、 suggestions 必须写成中文
 - monthly_searches 必须是 int ,按照从大到小排序,别的字段按照源数据填写即可。
+- 内容格式必须是 {output_type} 
 '''
         variables = {'product_name': product_name}
         product_info = await self.template_manager.execute_template("product_info", variables)
         competitor_info = await self.template_manager.execute_template("competitor_for_llm", variables)
+        prompt_tmpl = PromptTemplate(template=base_prompt + prompt_mainkyes, )
+        return prompt_tmpl.partial_format(
+            product_info=product_info, 
+            competitor_info=competitor_info,
+            output_type=output_type,
+            )
+    def llm_mainkeys_tailkeys_to_model(self, prompt_template:PromptTemplate, prompt_kwargs:Dict={}, verbose=False):
         program = LLMTextCompletionProgram.from_defaults(
-            output_parser=PydanticOutputParser(output_cls=AICompetitorAnalyzeMainKeywordsResult),
+            output_parser=PydanticOutputParser(output_cls=MarketingInfo),
+            prompt_template_str=prompt_template.format(prompt_kwargs),
             llm=self.llm,
-            prompt_template_str=base_prompt + prompt_mainkyes,
             verbose=verbose,
         )
-        competitor = program(product_info=product_info, competitor_info=competitor_info)
+        competitor = program(**prompt_kwargs)
         logger.info(f"{competitor}")
         return competitor
-    async def gen_mainkeys_tailkeys(self, product_name, prompt: str='', verbose=False, overwrite=False):
-        agent_model = await AgentContent.find_one(AgentContent.model_name == self.llm._get_model_name() and AgentContent.product_name == product_name)
-        if not overwrite and agent_model:
-            logger.info(f"agent_model exist")
+    async def gen_mainkeys_tailkeys(self, product_name, prompt: str='',output_type='markdown', verbose=False, overwrite=False):
+        logger.info(f"start llm model: {self.llm._get_model_name()}, output_type: {output_type}")
+        agent_model = await AgentContent.find_one(AgentContent.model_name == self.llm._get_model_name(),AgentContent.product_name == product_name)
+        if agent_model and not overwrite:
+            logger.info(f"agent_model exist {agent_model}")
             return agent_model
-        competitor = await self.get_mainkeys_tailkeys_prompt(product_name, prompt, verbose)
-        product_model = await Product.find_one(Product.basic_info.name == product_name)
-        agent_model = AgentContent(
-            model_name=self.llm._get_model_name(), 
-            product=product_model,
-            product_name=product_name,
-            competitor=competitor,
-        )
-        await agent_model.save() 
-        return agent_model
+        elif not agent_model:
+            agent_model = AgentContent(model_name=self.llm._get_model_name(), product_name=product_name)
+
+        prompt_template = await self.get_mainkeys_tailkeys_prompt(product_name, prompt,output_type, verbose)
+        logger.info(f"{prompt_template.format()}")
+        if output_type == 'json':
+            competitor = self.llm_mainkeys_tailkeys_to_model(prompt_template, verbose)
+        elif output_type == 'markdown':
+            response = await self.llm.acomplete(prompt_template.format())
+            pattern = r'```markdown(.*?)```'
+            matches = re.findall(pattern, response.text, re.DOTALL)
+            if not matches:
+                competitor = response.text
+            else:
+                competitor = matches[0]
+        agent_model.product = await Product.find_one(Product.basic_info.name == product_name)
+        agent_model.competitor[output_type] = competitor
+        agent_model.update_time = datetime.now()
+        return await agent_model.save() 
 
     async def get_marketing_prompt(self, product_name, prompt: str='', verbose=False, llm=None):
         prompt_marketing = prompt or '''\
@@ -98,31 +124,104 @@ class MarketingAgent(BaseAgent):
         all_keywords = self.template_manager.execute_template('agent.mainkeys_tailkeys')
         variables = {'product_name': product_name}
         product_info = await self.template_manager.execute_template("product_info", variables)
-        program = LLMTextCompletionProgram.from_defaults(
-            output_parser=PydanticOutputParser(output_cls=MarketingInfo),
+        program = LMFormatEnforcerPydanticProgram(
+            output_cls=MarketingInfo,
             llm=llm or self.llm,
             prompt_template_str="{{product_info}}\n{{all_keywords}}\n" + prompt_marketing,
             verbose=verbose,
         )
         competitor = program(product_info=product_info, all_keywords=all_keywords)
-        logger.info(f"{competitor}")
         return competitor
 
-    async def gen_marketing_content(self, product_name, prompt: str='', llm=None, verbose=False):
-        pass
-
-async def task():
+    async def gen_marketing_file(self, product_name, output_path: str, llm_models: List[str] = []):
+        models = await AgentContent.find(AgentContent.product_name == product_name).to_list()
+        
+        # 创建一个字典来存储模型及其基础名称
+        model_dict = {}
+        unsorted_models = []
+        
+        for model in models:
+            # 提取基础名称(去掉前面的xxx/xxx)
+            base_name = model.model_name.split('/')[-1] if '/' in model.model_name else model.model_name
+            
+            # 检查是否在优先级列表中
+            found = False
+            for priority_model in llm_models:
+                priority_base = priority_model.split('/')[-1] if '/' in priority_model else priority_model
+                if base_name == priority_base:
+                    model_dict[priority_model] = (model, base_name)  # 存储模型和显示名称
+                    found = True
+                    break
+            
+            if not found:
+                # 对于不在列表中的模型,也存储基础名称
+                unsorted_models.append((model, base_name))
+        
+        # 按照llm_models顺序生成内容
+        sorted_content = ''
+        for priority_model in llm_models:
+            if priority_model in model_dict:
+                model, display_name = model_dict[priority_model]
+                markdown = model.competitor.get('markdown', '')
+                logger.info(f"llm_name: {model.model_name} , {markdown[:100]}")
+                sorted_content += f"# {display_name}\n{markdown}\n\n"  # 使用基础名称
+        
+        # 添加未排序的模型
+        unsorted_content = ''
+        for model, display_name in unsorted_models:
+            markdown = model.competitor.get('markdown', '')
+            logger.info(f"llm_name: {model.model_name} , {markdown[:100]}")
+            unsorted_content += f"# {display_name}\n{markdown}\n\n"  # 使用基础名称
+        
+        content = sorted_content + unsorted_content
+        return save_to_file(content, output_path)
+async def llm_task():
     m = TemplateManager(MONGO_URL, MONGO_DB_NAME)
     await m.initialize()
-    model = 'groq/groq/llama-3.1-8b-instant'
+    model = 'openai/groq/llama-3.1-8b-instant'
     # model = 'groq/groq/qwen-2.5-coder-32b'
-    model = 'groq/groq/qwen-2.5-32b'
-    llm = LiteLLM(model=model, api_key=LITELLM_API_KEY, api_base=LITELLM_API_BASE)
+    # model = 'openai/glm-4-flash'
+    # model = 'openai/deepseek-v3'
+    # model = 'openai/groq/qwen-2.5-32b'
+    # model = 'openai/deepseek-chat'
+    model = 'openai/deepseek-reasoner'
+    # model = 'openai/doubao-pro-32k-241215'
+    llm_models = [
+        # 'openai/deepseek-v3',
+        'openai/QwQ-32B',
+        'openai/deepseek-reasoner',
+        'openai/doubao-pro-32k-241215',
+    ]
+    task_list = []
+    for model in llm_models:
+        llm = LiteLLM(model=model, api_key=OPENAI_API_KEY, api_base=OPENAI_API_BASE)
+        agent = MarketingAgent(llm=llm, template_manager=m)
+        agent_model = agent.gen_mainkeys_tailkeys(product_name='大尺寸厚款卸妆棉240片', verbose=True, overwrite=True)
+        task_list.append(agent_model)
+        # logger.info(f"{agent_model.competitor.items()}")
+    await asyncio.gather(*task_list)
+    # llm = LiteLLM(model=model, api_key=OPENAI_API_KEY, api_base=OPENAI_API_BASE)
+    # agent = MarketingAgent(llm=llm, template_manager=m)
+    # agent_model = await agent.gen_mainkeys_tailkeys(product_name='大尺寸厚款卸妆棉240片', verbose=True, overwrite=True)
+    # logger.info(f"{agent_model.competitor.items()}")
+async def gen_marketing_file():
+    m = TemplateManager(MONGO_URL, MONGO_DB_NAME)
+    await m.initialize()
+    model = 'openai/deepseek-reasoner'
+    llm = LiteLLM(model=model, api_key=OPENAI_API_KEY, api_base=OPENAI_API_BASE)
+    product_name = '大尺寸厚款卸妆棉240片'
     agent = MarketingAgent(llm=llm, template_manager=m)
-    agent_model = await agent.gen_mainkeys_tailkeys(product_name='电线保护套')
-    logger.info(f"{agent_model.competitor}")
+    output_path = r'G:\code\amazone\copywriting_production\output\temp' + f"\\{product_name}-营销文案.md"
+    llm_models = [
+'openai/doubao-pro-32k-241215',
+'openai/deepseek-reasoner',
+'openai/deepseek-v3', 
+'openai/QwQ-32B',
+ ]
+    await agent.gen_marketing_file(product_name=product_name, output_path=output_path, llm_models=llm_models)
+    logger.info(f"{output_path}")
 def main():
-    asyncio.run(task())
+    asyncio.run(gen_marketing_file())
 
 if __name__ == "__main__":
     main()

+ 2 - 1
config/settings.py

@@ -21,7 +21,8 @@ TEMP_PAGE_DIR.mkdir(parents=True, exist_ok=True)
 
 LITELLM_API_KEY = os.environ.get('LITELLM_API_KEY')
 LITELLM_API_BASE=os.environ.get('LITELLM_API_BASE')
-
+OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY')
+OPENAI_API_BASE=os.environ.get('OPENAI_API_BASE')
 class Config(BaseModel):
     storage: str = "local"
     s3_access_key: Optional[str] = os.environ.get("S3_ACCESS_KEY", 'bh9LbfsPHRJgQ44wXIlv')

+ 3 - 2
src/models/ai_execution_record.py

@@ -118,10 +118,11 @@ class AgentContent(Document):
         default=None,
         description="生成的营销内容结果"
     )
-    competitor:Optional[AICompetitorAnalyzeMainKeywordsResult] = Field(
-        default=None,
+    competitor:Optional[Dict[str,Union[AICompetitorAnalyzeMainKeywordsResult,str]]] = Field(
+        default={},
         description="竞品关键词分析结果" 
     )
+    update_time:Optional[datetime] = Field(default_factory=datetime.now)
     create_time:Optional[datetime] = Field(default_factory=datetime.now)
     class Settings:
         name = "agent.product"