|
|
@@ -1,3 +1,4 @@
|
|
|
+import asyncio
|
|
|
from datetime import datetime
|
|
|
from typing import Optional, List
|
|
|
from beanie import Document, Indexed, init_beanie
|
|
|
@@ -9,11 +10,44 @@ from src.models.product_model import Product
|
|
|
from beanie.operators import Set, Rename
|
|
|
|
|
|
class BaseMongoManager:
|
|
|
+ _instance = None
|
|
|
+ _init_lock = asyncio.Lock()
|
|
|
+ _is_initialized = False
|
|
|
+
|
|
|
+ def __new__(cls):
|
|
|
+ if not cls._instance:
|
|
|
+ cls._instance = super().__new__(cls)
|
|
|
+ cls._instance.client = None
|
|
|
+ cls._instance.db = None
|
|
|
+ return cls._instance
|
|
|
+
|
|
|
def __init__(self, mongo_url: str = None, db_name: str = None):
|
|
|
- self.client = AsyncIOMotorClient(mongo_url or MONGO_URL)
|
|
|
- self.db = self.client[db_name or MONGO_DB_NAME]
|
|
|
+ if not hasattr(self, 'client') or self.client is None:
|
|
|
+ self.client = AsyncIOMotorClient(mongo_url or MONGO_URL)
|
|
|
+ self.db = self.client[db_name or MONGO_DB_NAME]
|
|
|
+
|
|
|
+ async def _ensure_initialized(self):
|
|
|
+ """确保数据库已初始化"""
|
|
|
+ if not self._is_initialized:
|
|
|
+ await self.initialize()
|
|
|
+
|
|
|
async def initialize(self):
|
|
|
- await init_beanie(database=self.db, document_models=[Product])
|
|
|
+ async with self._init_lock:
|
|
|
+ if not self._is_initialized:
|
|
|
+ if not hasattr(self, 'db') or self.db is None:
|
|
|
+ self.__init__() # 确保client和db已初始化
|
|
|
+ await init_beanie(database=self.db, document_models=[Product])
|
|
|
+ self._is_initialized = True
|
|
|
+
|
|
|
+ async def check_connection(self) -> bool:
|
|
|
+ """检查数据库连接是否健康"""
|
|
|
+ if not self._is_initialized:
|
|
|
+ await self.initialize()
|
|
|
+ try:
|
|
|
+ await self.client.admin.command('ping')
|
|
|
+ return True
|
|
|
+ except Exception:
|
|
|
+ return False
|
|
|
|
|
|
async def backup(self, source_collection_name: Document, backup_collection_name:str, backup_db_name: str = f"{MONGO_DB_NAME}_backup"):
|
|
|
backup_db = self.client[backup_db_name]
|
|
|
@@ -28,10 +62,12 @@ class ProductManagerMongo(BaseMongoManager):
|
|
|
return await product.insert()
|
|
|
|
|
|
async def get_product(self, product_id: int) -> Optional[Product]:
|
|
|
+ await self._ensure_initialized()
|
|
|
return await Product.find_one(Product.id == product_id)
|
|
|
|
|
|
async def get_product_by_name(self, name: str) -> Optional[Product]:
|
|
|
- return await Product.find_one(Product.basic_info.name == name)
|
|
|
+ await self._ensure_initialized()
|
|
|
+ return await Product.find_one(Product.basic_info["name"] == name)
|
|
|
async def migrate_field():
|
|
|
# 迁移字段名 competitor_analysis 到 competitor_crawl_data
|
|
|
await Product.find(
|