@@ -1,17 +1,28 @@
 from datetime import datetime, timedelta
 from enum import StrEnum
 from pathlib import Path
-from typing import Optional
+from typing import Optional, Any, Union, List
 from pydantic import BaseModel, Field
+import re
 from prefect import flow, task
 from prefect.states import Failed, Running, Completed
 from prefect.cache_policies import INPUTS
+from prefect.futures import wait
 from src.browser.crawl_asin import Crawler
 from utils.drission_page import ChromeOptions
-from config.settings import CFG, read_config, get_config_path, TEMP_PAGE_DIR
+from config.settings import CFG, read_config, get_config_path, TEMP_PAGE_DIR, OPENAI_API_KEY, OPENAI_API_BASE
 from utils.logu import get_logger
-from utils.file import save_to_file, check_exists
+from utils.file import save_to_file, check_exists, extract_excel_text_from_url
 from utils.file import s3
+from utils.url_utils import extract_urls_from_text, extract_filename_from_url
+from llama_index.llms.litellm import LiteLLM
+from llama_index.core.program import LLMTextCompletionProgram
+from llama_index.core.output_parsers import PydanticOutputParser
+from llama_index.core.output_parsers.pydantic import extract_json_str
+from src.flow_task.db.product_import_db import product_import_manager
+from src.flow_task.db.models.product_models import ProductImport, ProductForExtraction
+from src.manager.core.db import DbManager, AsinSeed
+from markitdown import MarkItDown
 import tempfile
 import os
 
@@ -169,8 +180,6 @@ def task_save_page(crawler: Crawler, asin: str, asin_area: AsinAreaEnum,
     )
 def task_save_to_db(local_file_path: str, asin: str, mthml_type: bool, asin_area: str = 'JP'):
     """Task that uploads a file from the temp directory to S3, checking the database for an existing record first"""
-    from src.manager.core.db import DbManager, AsinSeed
-
     logger.info(f"Start processing file: {local_file_path}")
 
     # Initialize the database manager
@@ -281,9 +290,6 @@ def parse_url_to_markdown_task(url: str):
        if url.lower().endswith(('.xlsx', '.xls')):
            logger.info(f"Excel file detected, reading it with the pandas-based helper: {url}")
 
-            # Import the Excel helper
-            from utils.file import extract_excel_text_from_url
-
            # Read the Excel file with the pandas-based helper
            all_cells_text_dict = extract_excel_text_from_url(url)
 
@@ -305,7 +311,6 @@ def parse_url_to_markdown_task(url: str):
        else:
            # Non-Excel files keep using the original markitdown approach
            logger.info(f"Non-Excel file detected, reading it with markitdown: {url}")
-            from markitdown import MarkItDown
 
            # Create a MarkItDown instance
            md = MarkItDown(enable_plugins=False)
@@ -323,3 +328,167 @@ def parse_url_to_markdown_task(url: str):
        logger.error(f"Error while parsing the tabular file from the URL: {e}")
        raise Exception(f"Failed to parse the tabular file from the URL: {e}")
 
+
+class DebugPydanticOutputParser(PydanticOutputParser):
+    """Debug variant of PydanticOutputParser that logs the raw LLM output before parsing"""
+
+    def parse(self, text: str) -> Any:
+        """Parse, validate, and correct errors programmatically."""
+        logger.info("=== Raw LLM output ===")
+        logger.info(text)
+        logger.info("=== End of raw LLM output ===")
+
+        # Strip markdown code-fence formatting
+        cleaned_text = text
+        if "```json" in text:
+            # Remove the ```json fence markers
+            cleaned_text = text.split("```json")[1].split("```")[0]
+        elif "```" in text:
+            # Remove generic ``` fence markers
+            cleaned_text = text.split("```")[1].split("```")[0]
+
+        # Clean up escaped characters
+        cleaned_text = cleaned_text.replace("\\n", "\n").replace("\\\"", "\"")
+
+        json_str = extract_json_str(cleaned_text)
+        return self._output_cls.model_validate_json(json_str)
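+
+# Illustrative note: given a fenced reply such as
+#   ```json
+#   {"title": "Widget"}
+#   ```
+# parse() logs the raw text, strips the fences, unescapes literal "\n" / "\""
+# sequences, and validates the remaining JSON against the configured output_cls.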
+
+
+def extract_product_from_text(text: str, uri: str = "", filename: str = "") -> ProductImport:
+    """Extract product information from text with an LLMTextCompletionProgram"""
+    llm = LiteLLM(model='openai/GLM-4-Flash', api_key=OPENAI_API_KEY, api_base=OPENAI_API_BASE)
+
+    # Use the custom DebugPydanticOutputParser
+    output_parser = DebugPydanticOutputParser(output_cls=ProductForExtraction)
+
+    program = LLMTextCompletionProgram.from_defaults(
+        prompt_template_str=f"Please extract the product information from the following text:\n\nurl: {uri} \n\n{{text}}",
+        llm=llm,
+        verbose=True,
+        output_parser=output_parser
+    )
+
+    extracted_product = program(text=text)
+
+    # Build the ProductImport instance via the class method
+    return ProductImport.from_product_extraction(
+        extracted_product=extracted_product,
+        markdown_content=text,
+        uri=uri,
+        filename=filename
+    )
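+
+# Minimal usage sketch, assuming OPENAI_API_KEY / OPENAI_API_BASE point at an
+# OpenAI-compatible endpoint serving GLM-4-Flash (the URL below is hypothetical):
+#   product = extract_product_from_text(
+#       text="| name | price |\n| Widget | 9.99 |",
+#       uri="https://example.com/products.xlsx",
+#       filename="products.xlsx",
+#   )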
+
+
+# Results are cached by task inputs for 31 days, so re-running with the same
+# file_url returns the cached record instead of re-parsing the Excel file.
+@task(name="Excel processing",
+      persist_result=True,
+      cache_expiration=timedelta(days=31),
+      cache_policy=INPUTS
+)
+def get_or_create_product_import_by_url(file_url: str):
+    """Fetch the ProductImport record for a file URL from the database; if none exists, parse the Excel file and save a new record
+
+    Args:
+        file_url (str): URL or local path of the file
+
+    Returns:
+        ProductImport: the ProductImport record from the database
+    """
+    # Extract the file name from the URL
+    file_name = extract_filename_from_url(file_url)
+
+    logger.info(f"Start processing file: {file_name} (URL: {file_url})")
+
+    # First check whether a record for this file name already exists in the database
+    existing_record = product_import_manager.get_product_import_by_filename(file_name)
+
+    if existing_record:
+        logger.info(f"Record for file {file_name} already exists in the database, returning it directly")
+        return existing_record
+
+    logger.info(f"No record for file {file_name} in the database, parsing the Excel file and saving it")
+
+    try:
+        # Parse the Excel file into Markdown
+        markdown_content = parse_url_to_markdown_task(file_url)
+
+        if not markdown_content:
+            logger.warning(f"Excel parsing failed or returned empty content: {file_url}")
+            raise Exception(f"Excel parsing failed or returned empty content: {file_url}")
+
+        # Extract product information from the Markdown content with the LLM
+        product_import = extract_product_from_text(
+            text=markdown_content,
+            uri=file_url,
+            filename=file_name
+        )
+
+        # Save to the database
+        saved_record = product_import_manager.save_product_import(product_import)
+
+        logger.info(f"Parsed the Excel file and saved it to the database: {file_name}")
+        return saved_record
+
+    except Exception as e:
+        logger.error(f"Error while processing file {file_name}: {e}")
+        raise Exception(f"Failed to process file: {e}")
+
+
+class ProductImportInput(BaseModel):
+    """Input model for product import"""
+    file_url: Union[str, List[str]] = Field(description="File URL or local path; a single string or a list of them")
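+
+# Both input shapes are accepted (URLs below are hypothetical):
+#   ProductImportInput(file_url="https://example.com/products.xlsx")
+#   ProductImportInput(file_url=["https://example.com/a.xlsx", "https://example.com/b.xls"])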
+
+
+@flow(
+    name="Product import flow",
+    persist_result=True,
+    result_serializer="json",
+)
+def product_import_flow(flow_input: ProductImportInput):
+    """Prefect flow for product import; accepts a string or a list and runs the parsing concurrently"""
+    # Normalize the input into a list of URLs
+    if isinstance(flow_input.file_url, str):
+        logger.info(f"Input is a string, trying to extract URLs from it: {flow_input.file_url}")
+        # For a string input, try to extract URLs from the text
+        urls = extract_urls_from_text(flow_input.file_url)
+        if not urls:
+            # If no URL was extracted, assume the whole string is a single URL
+            urls = [flow_input.file_url]
+        logger.info(f"Extracted {len(urls)} URL(s): {urls}")
+    else:
+        # A list input is used as-is
+        urls = flow_input.file_url
+        logger.info(f"Input is a list of {len(urls)} URL(s): {urls}")
+
+    # Run the parsing for all URLs concurrently
+    all_futures = []
+    for url in urls:
+        future = get_or_create_product_import_by_url.with_options(
+            task_run_name=f"Processing URL: {url}",
+        ).submit(url)
+        all_futures.append(future)
+
+    # Wait for all tasks to finish
+    logger.info(f"Waiting for {len(all_futures)} task(s) to finish...")
+    results = [future.result() for future in wait(all_futures).done]
+
+    logger.info(f"All tasks finished; {len(results)} file(s) processed successfully")
+
+    return {
+        'status': 'success',
+        'product_imports': results,
+        'file_urls': urls,
+        'total_count': len(results)
+    }
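+
+# Example invocation (hypothetical URL):
+#   product_import_flow(ProductImportInput(file_url="https://example.com/products.xlsx"))
+# Note that wait(all_futures).done is an unordered set, so the order of
+# `results` may differ from the order of `urls`.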
+
+
+@task
+def product_import_task(flow_input: ProductImportInput):
+    """Product import task"""
+    return product_import_flow(flow_input)
+