t_pymongo_template.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. from jinja2 import Environment, BaseLoader
  2. import json
  3. from bson import json_util
  4. from src.manager.core.db_mongo import BaseMongoManager
  5. import asyncio
  6. from utils.logu import get_logger
  7. from src.models.field_config import FieldConfig, get_field_descriptions
  8. logger = get_logger('test')
  9. class MongoAggregationTemplate:
  10. def __init__(self):
  11. self.env = Environment(loader=BaseLoader())
  12. # 添加自定义过滤器
  13. self.env.filters['tojson'] = lambda v: json.dumps(v, default=json_util.default)
  14. def to_template_string(self, pipeline):
  15. """将聚合管道转换为模板字符串"""
  16. return json.dumps(pipeline, default=json_util.default)
  17. def render(self, template_str, context):
  18. """渲染模板字符串为可执行的聚合管道"""
  19. template = self.env.from_string(template_str)
  20. rendered = template.render(**context)
  21. try:
  22. return json.loads(rendered, object_hook=json_util.object_hook)
  23. except json.JSONDecodeError as e:
  24. raise ValueError(f"Invalid JSON after rendering: {e}\nRendered content:\n{rendered}")
  25. async def filter_aggregate_demo(product_name="电线保护套"):
  26. # 初始化模板处理器
  27. template_processor = MongoAggregationTemplate()
  28. # 定义聚合管道模板
  29. filter_competior_by_name = [
  30. {
  31. '$match': {
  32. 'basic_info.name': "{{ product_name }}"
  33. }
  34. },
  35. {
  36. '$project': {
  37. 'competitor_crawl_data': 1
  38. }
  39. }, {
  40. '$addFields': {
  41. 'competitors': {
  42. '$objectToArray': '$competitor_crawl_data'
  43. }
  44. }
  45. }, {
  46. '$unwind': '$competitors'
  47. }, {
  48. '$project': {
  49. '_id': 0,
  50. 'asin': '$competitors.k',
  51. 'product_info': {
  52. 'main_text': '$competitors.v.extra_result.product_info.main_text'
  53. },
  54. 'result_table': {
  55. '$map': {
  56. 'input': '$competitors.v.extra_result.result_table',
  57. 'as': 'item',
  58. 'in': {
  59. 'traffic_keyword': '$$item.traffic_keyword',
  60. 'monthly_searches': '$$item.monthly_searches'
  61. }
  62. }
  63. }
  64. }
  65. }
  66. ]
  67. # 将聚合管道转换为模板字符串
  68. template_str = template_processor.to_template_string(filter_competior_by_name)
  69. logger.info(f"Template string: {template_str}")
  70. # 初始化数据库
  71. db_manager = BaseMongoManager(db_name='test')
  72. await db_manager.initialize()
  73. db = db_manager.db
  74. # 渲染模板
  75. try:
  76. pipeline = template_processor.render(template_str, {
  77. "product_name": product_name,
  78. # 可以添加更多变量
  79. })
  80. logger.info(f"Rendered pipeline: {json.dumps(pipeline, indent=2, ensure_ascii=False)}")
  81. except ValueError as e:
  82. logger.error(f"Template rendering failed: {e}")
  83. return
  84. # 执行聚合查询
  85. result = await db.Product.aggregate(pipeline).to_list()
  86. logger.info(json.dumps(result, ensure_ascii=False))
  87. # 测试代码
  88. if __name__ == "__main__":
  89. asyncio.run(filter_aggregate_demo())