| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454 |
- from datetime import datetime
- from paddleocr import PaddleOCR, draw_ocr
- from rapidfuzz import fuzz
- import json
- import time
- from typing import List, Optional,Coroutine, Dict, Tuple,Any
- import win32con
- import win32clipboard
- import win32gui
- from pathlib import Path
- from sqlmodel import SQLModel, Field, Column, JSON
- from DrissionPage import ChromiumPage
- from ai.driver.cv_common import ImageMatchResult
- from ai.driver import backend_win32com,bro_page_pyautogui,browser_win32,send_input,cv_common
- from ai.conf_ai.config import load_chrome_from_ini,get_logger,RESOURCE,get_browser,BRO_INI_FILE
- from ai.gpt_node.async_wraps import async_wrapper,thread_safe
# Module-level logger; its name embeds BRO_INI_FILE, the browser ini file that
# DriverBase also uses as its default configuration.
logger = get_logger(f'ai/gpt_node-driver-{BRO_INI_FILE}')
class WindowsInfo(SQLModel, table=False):
    """Snapshot of one top-level browser window (handle, geometry, title, owning tab)."""
    id: Optional[int] = Field(default=None, primary_key=True)
    # Window rectangle from backend_win32com.get_window_info, stored as a list;
    # the first two items are used as the window's x/y position.
    # NOTE(review): last two items are presumably right/bottom — confirm.
    win_rect: Optional[list] = Field(default=[],sa_column=Column(JSON))
    # [width, height] taken from the window's client rectangle.
    win_size: Optional[list] = Field(default=[],sa_column=Column(JSON))
    # Win32 window handle.
    hwnd: Optional[int] = Field(default=None)
    window_title:Optional[str] = Field(default='')
    # Id of the browser process that owns the window.
    pid: Optional[int] = Field(default=None)
    # Id of the DrissionPage tab shown in this window; empty while the window
    # has not been claimed by a driver instance.
    tab_id: Optional[str] = Field(default='')
class OCRMatch(SQLModel, table=False):
    """One OCR-detected text line plus its similarity to a searched-for text.

    Used both as a search template (only ``find_txt`` set) and as a hit
    returned by the OCR search (corner points, recognized text, scores).
    """
    # Corner points of the detected text box, as reported by the OCR engine.
    top_left: Tuple[float, float] = Field(default=(0.0, 0.0))
    top_right: Tuple[float, float] = Field(default=(0.0, 0.0))
    bottom_right: Tuple[float, float] = Field(default=(0.0, 0.0))
    bottom_left: Tuple[float, float] = Field(default=(0.0, 0.0))
    # Text being searched for (set on the template, copied onto a hit).
    find_txt: str = Field(default="")
    # Similarity (0-100) between find_txt and ocr_txt. rapidfuzz returns a
    # float, so an int annotation would reject or truncate real scores.
    find_txt_similarity: float = Field(default=0.0)
    # Recognized text and the OCR engine's confidence for it.
    ocr_txt: str = Field(default="")
    ocr_confidence: float = Field(default=0.0)
    # Window the text was read from, when known.
    win_info: Optional[WindowsInfo] = Field(default=None)

    def center(self) -> Tuple[float, float]:
        """Return the midpoint of the top-left / bottom-right diagonal of the box."""
        return (self.top_left[0] + self.bottom_right[0]) / 2, (self.top_left[1] + self.bottom_right[1]) / 2

    def is_match(self, similarity: float = 90) -> bool:
        """Return True when ``find_txt_similarity`` is strictly greater than ``similarity``."""
        return self.find_txt_similarity > similarity
class GptImgMatch(SQLModel, table=False):
    """Template-image locator: finds ``img_path`` in a screenshot and can click the hit.

    The most recent match result is cached in ``match_res``, including results
    that fell below the threshold.
    """
    # Path of the template image to search for.
    img_path: Optional[str] = Field(default=None)
    # ``max_location`` of the most recent match.
    init_pos: Optional[list] = Field(default=[0, 0], sa_column=Column(JSON))
    # Cached result of the most recent img_match call.
    match_res: Optional[ImageMatchResult] = Field(default=None)

    def img_match(self, img, thread: float = 0.92) -> ImageMatchResult | None:
        """Match the template against ``img`` and cache the result.

        Args:
            img: screenshot to search in.
            thread: threshold; the match counts only if its ``match_max_val``
                is strictly greater than this value.

        Returns:
            The match result when it passes ``thread``, else None. A
            below-threshold result is still stored in ``match_res``.
        """
        self.match_res = cv_common.CVPage.match_img_in_screen(img, self.img_path)
        self.init_pos = self.match_res.max_location
        return self.match_res if self.is_match(thread) else None

    def is_match(self, thread: float = 0.92) -> bool:
        """Return True when a cached match exists and its score is strictly above ``thread``."""
        if not self.match_res:
            return False
        return self.match_res.match_max_val > thread

    def abs_pos(self, x, y, win_rect: Optional[list] = None):
        """Offset window-relative (x, y) by the origin of ``win_rect``.

        ``win_rect`` is presumably the window's screen rectangle (confirm with
        get_window_info). Without it, the point is returned unchanged instead
        of failing to unpack None.
        """
        if not win_rect:
            return x, y
        win_x, win_y, _, _ = win_rect
        return win_x + x, win_y + y

    def click(self, pos: Optional[list] = None, hwnd=None):
        """Click ``pos`` in window ``hwnd``; defaults to the centre of the cached match.

        Raises:
            Exception: no ``pos`` was given and nothing has been matched yet.
        """
        # None replaces the former mutable [] default; both mean "use the match".
        if not pos:
            if not self.match_res:
                raise Exception("未识别到目标,请先识别坐标以点击")
            pos = self.match_res.max_location_center
        backend_win32com.VirtualKeyboard(hwnd).mouse_move_press(*pos)
        logger.info(f"{Path(self.img_path).name} {pos}")
class BaseOCRMatch(SQLModel, table=False):
    """OCR search templates for common context-menu entries, matched by their ``find_txt``."""
    # NOTE(review): the field name `copy` shadows BaseModel.copy(); pydantic
    # warns about this — consider renaming if the method is ever needed.
    copy: Optional[OCRMatch] = OCRMatch(find_txt='copy')
    paste: Optional[OCRMatch] = OCRMatch(find_txt='paste')
    # Searched for (and overwritten with the search result) by DriverBase.select_all.
    select_all: Optional[OCRMatch] = OCRMatch(find_txt='select all')
# Browser window tab icons and other browser-UI template images.
class CvWinBroModel(SQLModel, table=False):
    """Template-image locators that DriverBase matches against browser-window screenshots."""
    id: Optional[int] = Field(default=None, primary_key=True)
    # Tab icons of the ChatGPT and Claude pages.
    chatgpt: Optional[GptImgMatch] = GptImgMatch(img_path=str(RESOURCE / 'openai' / 'tab-icon.png'))
    claude: Optional[GptImgMatch] = GptImgMatch(img_path=str(RESOURCE / 'claude' / 'tab-icon.png'))
    # Two variants of the refresh button; DriverBase.click_refresh tries them in order.
    # NOTE(review): "refrsh" presumably matches the resource file names on disk —
    # rename the files and these paths together if it is ever corrected.
    refresh_btn: Optional[GptImgMatch] = GptImgMatch(img_path=str(RESOURCE / 'windows' / 'refrsh-btn.png'))
    refresh_btn2: Optional[GptImgMatch] = GptImgMatch(img_path=str(RESOURCE / 'windows' / 'refrsh-btn2.png'))
    # "Paste" entry of the right-click menu, matched by DriverBase.click_paste.
    paste: Optional[GptImgMatch] = GptImgMatch(img_path=str(RESOURCE / 'windows' / 'paste.png'))
-
-
class DriverBase:
    """Drive a Chromium browser through DrissionPage plus Win32 back-end input and screenshots.

    ``page``, ``all_windows_info`` and ``paddle_ocr`` are class attributes, so
    they are shared by every instance. Each instance binds to one top-level
    browser window (``windows_info`` / ``vb``) through ``new_windows`` or
    ``find_or_new_windows``.
    """

    name = 'About Version'
    url = ''
    # Browser object: the main page obtained when Chrome is launched for the
    # main process. There is only one main process; every other child window
    # or tab is a tab page of a child process.
    page: Optional[ChromiumPage] = None
    # Info for every top-level window of the browser process, keyed by window
    # handle (hwnd). Shared by all instances.
    all_windows_info: Dict[int, WindowsInfo] = {}
    # Built when the class is defined, so importing this module loads the OCR model.
    paddle_ocr = PaddleOCR(use_angle_cls=True, lang="en",)
    # Flag that new_windows uses to serialize window creation.
    # NOTE(review): the check-then-set in new_windows is not atomic, so two
    # threads can still slip through together; a threading.Lock would be airtight.
    _lock = False

    def __init__(self, ini_file=BRO_INI_FILE) -> None:
        """Create a driver that is not yet bound to any window.

        Args:
            ini_file: browser configuration passed to ``get_browser`` by ``init``.
        """
        # Tab shown in this instance's window.
        self.tab = None
        # Template-image locators (tab icons, refresh buttons, paste menu item).
        self.bromodel = CvWinBroModel()
        # Window this instance drives; None until one is opened or found.
        self.windows_info: Optional[WindowsInfo] = None
        # Back-end input sender bound to this instance's window handle.
        self.vb: Optional[backend_win32com.VirtualKeyboard] = None
        # OCR search templates; created on first use by select_all.
        self.ocr_model: Optional[BaseOCRMatch] = None
        self.ini_file = ini_file

    @thread_safe
    def init(self):
        """Start the shared browser on first call, index its windows, and return the page."""
        if not DriverBase.page:
            DriverBase.page = get_browser(self.ini_file)
            DriverBase.page.wait.doc_loaded(timeout=10)
            self.init_all_windows_info()
        return DriverBase.page

    def init_all_windows_info(self, process_id=None, cache=False):
        """(Re)build the shared hwnd -> WindowsInfo registry for the browser process.

        Args:
            process_id: PID whose Chrome windows are enumerated; defaults to the
                shared page's process.
            cache: when true, return the current registry without enumerating.

        Returns:
            The shared ``all_windows_info`` dict.
        """
        if cache:
            return self.all_windows_info
        scan_pid = process_id if process_id is not None else DriverBase.page.process_id
        for hwnd, pid, _title in backend_win32com.enum_chrome_windows_by_pid(scan_pid):
            rect, client_rect, window_title = backend_win32com.get_window_info(hwnd)
            _, _, win_w, win_h = client_rect
            # Freshly enumerated windows are not yet bound to any tab.
            window_info = WindowsInfo(pid=pid, hwnd=hwnd, win_rect=list(rect), win_size=[win_w, win_h], window_title=window_title, tab_id=None)
            DriverBase.all_windows_info.update({hwnd: window_info})
        return DriverBase.all_windows_info

    def update_all_windows_info(self, hwnd, window_info):
        """Store ``window_info`` under ``hwnd`` in the shared registry and return the registry."""
        DriverBase.all_windows_info.update({hwnd: window_info})
        return DriverBase.all_windows_info

    def new_windows(self, url: str = ''):
        """Open ``url`` in a new browser window and bind this instance to it.

        Returns:
            The new window's WindowsInfo, or None if anything failed (the
            exception is logged, not raised).
        """
        # Poll until no other caller is creating a window (see the note on _lock).
        while DriverBase._lock:
            time.sleep(0.1)
        DriverBase._lock = True
        try:
            self.tab = DriverBase.page.new_tab(url=url, new_window=True)
            # Enumerate the browser's top-level windows as (hwnd, pid, title).
            results = backend_win32com.enum_chrome_windows_by_pid(DriverBase.page.process_id)
            # NOTE(review): assumes the first enumerated window is the one just opened.
            hwnd, _pid, _title = results.pop(0)
            self.windows_info = self.get_windows_info(hwnd, tab=self.tab)
            self.update_all_windows_info(hwnd, self.windows_info)
            self.vb = backend_win32com.VirtualKeyboard(hwnd)
            logger.info(f"{self.name} 打开新窗口 {self.windows_info.model_dump()}")
            logger.debug(f"self.all_windows_info {self.all_windows_info}")
            return self.windows_info
        except Exception as e:
            logger.exception(f"new_windows {e}")
        finally:
            DriverBase._lock = False

    @thread_safe
    def find_or_new_windows(self, title: str, url: str = '') -> WindowsInfo | None:
        """Bind to an unclaimed window whose title contains ``title``, else open a new one.

        A window is unclaimed while its ``tab_id`` is empty; if several match,
        the last one in the registry wins. When the window's tab cannot be
        identified, a new window is opened instead.

        Returns:
            The bound WindowsInfo, or None when opening a new window failed.
        """
        wanted = title.lower()
        win = None
        for windows_info in DriverBase.all_windows_info.values():
            if wanted in windows_info.window_title.lower() and not windows_info.tab_id:
                win = windows_info
        if win is None:
            return self.new_windows(url=url)
        tab = self.find_shifted_tab(win)
        if tab is None:
            logger.warning(f"{self.name} no tab found for window {win.model_dump()}, opening a new one")
            return self.new_windows(url=url)
        self.tab = tab
        win.tab_id = tab.tab_id
        backend_win32com.show_win(win.hwnd)
        self.update_all_windows_info(win.hwnd, win)
        # Bind this instance to the reused window, exactly as new_windows does.
        self.windows_info = win
        self.vb = backend_win32com.VirtualKeyboard(win.hwnd)
        logger.info(f"{self.name} 已存在窗口 {win.model_dump()}")
        return win

    @thread_safe
    def find_shifted_tab(self, windows_info: WindowsInfo):
        """Identify which tab lives in ``windows_info`` by nudging the window.

        The window is moved 100px right; the tab whose ``rect.window_location``
        changes is taken to be the window's tab. The window is always moved
        back, even if querying tabs fails.

        Returns:
            The shifted tab, or None when no tab position changed.
        """
        # Key by tab_id so tabs with identical titles cannot collide.
        initial_positions = {tab.tab_id: tab.rect.window_location for tab in DriverBase.page.get_tabs()}
        x, y, _, _ = windows_info.win_rect
        win32gui.SetWindowPos(windows_info.hwnd, win32con.HWND_TOP, x + 100, y, 0, 0, win32con.SWP_NOSIZE)
        try:
            for tab in DriverBase.page.get_tabs():
                initial_position = initial_positions.get(tab.tab_id)
                if initial_position is not None and tab.rect.window_location != initial_position:
                    return tab
            return None
        finally:
            win32gui.SetWindowPos(windows_info.hwnd, win32con.HWND_TOP, x, y, 0, 0, win32con.SWP_NOSIZE)

    def close_windows(self):
        """Ask this instance's window to close and unbind the instance from it."""
        if not self.windows_info:
            return
        win32gui.PostMessage(self.windows_info.hwnd, win32con.WM_CLOSE, 0, 0)
        # The window may never have been registered, or may already be removed.
        self.all_windows_info.pop(self.windows_info.hwnd, None)
        self.tab = None
        self.windows_info = None

    @staticmethod
    def close_no_tab_id():
        """Close every registered window that no tab has claimed, and drop it from the registry."""
        for key, windows_info in list(DriverBase.all_windows_info.items()):
            if windows_info.tab_id:
                continue
            logger.info(f"close windows {windows_info.model_dump()}")
            win32gui.PostMessage(windows_info.hwnd, win32con.WM_CLOSE, 0, 0)
            time.sleep(0.1)
            DriverBase.all_windows_info.pop(key)

    def get_windows_info(self, hwnd: int, tab: Optional[ChromiumPage] = None):
        """Build a WindowsInfo for ``hwnd``, tagged with the browser pid and ``tab``'s id."""
        pid = None if not self.page else self.page.process_id
        tab_id = None if not tab else tab.tab_id
        rect, client_rect, window_title = backend_win32com.get_window_info(hwnd)
        _, _, win_w, win_h = client_rect
        return WindowsInfo(pid=pid, hwnd=hwnd, win_rect=list(rect), win_size=[win_w, win_h], window_title=window_title, tab_id=tab_id)

    def find(self, locator: GptImgMatch, screenshot=None, thread: float = 0.98, region=None, raise_error=False) -> Optional[ImageMatchResult]:
        """Template-match ``locator`` against a screenshot of this instance's window.

        Args:
            locator: template to look for; its ``match_res`` is updated.
            screenshot: image to search; when None or all-zero, a fresh
                screenshot (of ``region``, if given) is taken instead.
            thread: match threshold forwarded to ``GptImgMatch.img_match``.
            raise_error: raise instead of returning None when nothing matches.

        Returns:
            The match result, or None when below threshold.
        """
        if screenshot is None or not screenshot.any():
            screenshot = self.screenshot(region=region)
        result = locator.img_match(screenshot, thread=thread)
        if result:
            log = f"识别到目标 {Path(locator.img_path).name}"
            logger.debug(f"{log}")
        else:
            log = f"未识别到目标 {Path(locator.img_path).name} {locator}"
            if raise_error:
                raise Exception(log)
            logger.debug(f"{log}")
        return result

    @async_wrapper
    def click(self, locator: GptImgMatch, screenshot=None, thread: float = 0.92, raise_error=False, button: str = "L"):
        """Click the centre of ``locator``'s match in this instance's window.

        A cached match on the locator is reused unless a new ``screenshot`` is
        supplied, in which case the template is matched against it again.

        Returns:
            True when clicked, False when nothing matched (unless ``raise_error``).

        Raises:
            Exception: nothing matched and ``raise_error`` is true.
        """
        if screenshot is not None or not locator.match_res:
            locator.match_res = self.find(locator, screenshot, thread)
        if locator.match_res and locator.is_match(thread=thread):
            logger.info(f"{locator}")
            self.vb.mouse_move_press(*locator.match_res.max_location_center, button=button)
            return True
        logger.error(f"Failed to find {locator.img_path}")
        if raise_error:
            raise Exception(f"Failed to find {locator}")
        return False

    def get_click_right_windows(self, pos: list, hwnd=None, wait_pop_time: float = 0.3):
        """Right-click ``pos`` in ``hwnd`` and return the handle of the popup it opens.

        Args:
            pos: click position, forwarded to ``mouse_move_press``.
            hwnd: target window; defaults to this instance's window.
            wait_pop_time: seconds to wait for the popup to appear.

        Returns:
            The first child window that appeared after the click, or None.
        """
        if not hwnd:
            hwnd = self.windows_info.hwnd
        # Bind input to the window actually targeted so an explicit hwnd is honored.
        vb = backend_win32com.VirtualKeyboard(hwnd)
        logger.info(f"hwnd {hwnd} pos {pos}")
        before_click_windows = set(backend_win32com.get_child_windows(hwnd))
        vb.mouse_move_press(*pos, button='R')
        time.sleep(wait_pop_time)
        new_windows = [child for child in backend_win32com.get_child_windows(hwnd) if child not in before_click_windows]
        if not new_windows:
            logger.error("右键未找到新窗口 click")
            return None
        return new_windows[0]

    @async_wrapper
    def click_paste(self, txt_area_locator: GptImgMatch, hwnd=None, thread: float = 0.92, raise_error=False):
        """Right-click the text area found by ``txt_area_locator`` and click the menu's "paste".

        Does nothing further when no popup appears or the paste entry is not
        matched above ``thread``; with ``raise_error`` those cases raise.
        """
        if not hwnd:
            hwnd = self.windows_info.hwnd
        before_click_windows = set(backend_win32com.get_child_windows(hwnd))
        self.click(txt_area_locator, thread=thread, raise_error=raise_error, button="R")
        time.sleep(0.3)
        new_windows = [child for child in backend_win32com.get_child_windows(hwnd) if child not in before_click_windows]
        if not new_windows:
            msg = f"右键未找到新窗口 click {txt_area_locator} "
            logger.error(msg)
            if raise_error:
                raise Exception(msg)
            return
        pop_hwnd = new_windows[0]
        screenshot = backend_win32com.back_end_screenshot(pop_hwnd)
        logger.info(f"child_windows {backend_win32com.get_window_info(pop_hwnd)}")
        paste = self.bromodel.paste.img_match(screenshot, thread=thread)
        logger.info(f"paste img_match {paste}")
        # img_match caches below-threshold results too; never click a non-match.
        if paste is None:
            msg = f"paste menu item not found in popup {pop_hwnd}"
            logger.error(msg)
            if raise_error:
                raise Exception(msg)
            return
        backend_win32com.VirtualKeyboard(pop_hwnd).mouse_move_press(*paste.max_location_center)
        time.sleep(0.3)

    def select_all(self, txt_area_pos: list, hwnd=None):
        """Right-click ``txt_area_pos`` and click the "select all" entry located by OCR.

        Logs and returns without clicking when no popup appears or the entry
        cannot be read.
        """
        if self.ocr_model is None:
            # __init__ leaves ocr_model unset; create the default templates on first use.
            self.ocr_model = BaseOCRMatch()
        pop_hwnd = self.get_click_right_windows(txt_area_pos, hwnd=hwnd)
        if not pop_hwnd:
            return
        match = self.ocr_find_txt(self.ocr_model.select_all, pop_hwnd)
        if match is None:
            logger.error(f"select_all: '{self.ocr_model.select_all.find_txt}' not found in popup {pop_hwnd}")
            return
        # Only store successful hits so the template survives a failed search.
        self.ocr_model.select_all = match
        vb = backend_win32com.VirtualKeyboard(pop_hwnd)
        # NOTE(review): the menu entry is clicked with the right button, whereas
        # click_paste uses the default button — confirm this is intended.
        vb.mouse_move_press(*match.center(), button='R')

    @staticmethod
    def _set_clipboard(text: Optional[str]) -> None:
        """Empty the clipboard and, if ``text`` is given, put it there as Unicode text."""
        win32clipboard.OpenClipboard()
        try:
            win32clipboard.EmptyClipboard()
            if text is not None:
                # CF_UNICODETEXT keeps characters outside the ANSI code page.
                win32clipboard.SetClipboardText(text, win32con.CF_UNICODETEXT)
        finally:
            win32clipboard.CloseClipboard()

    def paste_str(self, text: str, txt_area_locator: GptImgMatch, hwnd=None):
        """Paste ``text`` into the area found by ``txt_area_locator`` via the clipboard.

        The clipboard is overwritten with ``text`` and emptied afterwards;
        errors during the paste are logged, not raised.
        """
        logger.info(f"send_str {text}")
        self._set_clipboard(text)
        try:
            self.click_paste(txt_area_locator, hwnd=hwnd)
        except Exception as e:
            logger.error(f"send_str {text} error {e}")
        finally:
            # Clear clipboard after pasting.
            self._set_clipboard(None)

    @async_wrapper
    def wait_for(self, locator: GptImgMatch, timeout: float = 10, interval: float = 0.5, thread: float = 0.98) -> Optional[ImageMatchResult]:
        """Poll ``find`` every ``interval`` seconds until ``locator`` matches or ``timeout`` passes.

        Returns:
            The match result, or None on timeout.
        """
        start_time = time.time()
        while time.time() - start_time < timeout:
            result = self.find(locator, thread=thread, raise_error=False)
            if result:
                return result
            time.sleep(interval)
        log = f"Timeout reached. Unable to find {Path(locator.img_path).name} within {timeout} seconds"
        logger.warning(log)
        return None

    def find_or_new_tab(self, url: str = ''):
        """Return the first tab whose URL contains ``self.name``, else open ``url`` in a new window.

        ``url`` defaults to ``self.url``.
        """
        if not url:
            url = self.url
        for tab in self.page.get_tabs():
            if self.name in tab.url:
                return tab
        return self.page.new_tab(url=url, new_window=True)

    @async_wrapper
    def click_refresh(self):
        """Click the refresh button, trying both known button images on one screenshot."""
        screenshot = self.screenshot()
        return (self.click(self.bromodel.refresh_btn, screenshot=screenshot) or
                self.click(self.bromodel.refresh_btn2, screenshot=screenshot))

    def get_tabs_icon(self, hwnd_list: List[int]):
        """Return the first window in ``hwnd_list`` showing the ChatGPT tab icon, or None."""
        for hwnd in hwnd_list:
            self.find(self.bromodel.chatgpt, screenshot=backend_win32com.back_end_screenshot(hwnd))
            if self.bromodel.chatgpt.is_match():
                return hwnd
        return None

    def screenshot(self, hwnd=None, filename: str = None, region=None):
        """Back-end screenshot of ``hwnd`` (default: own window), optionally of ``region`` only."""
        if not hwnd:
            hwnd = self.windows_info.hwnd
        if region:
            return backend_win32com.back_end_screenshot_region(hwnd, filename, region)
        return backend_win32com.back_end_screenshot(hwnd, filename)

    def get_all_windows_info(self):
        """Map tab_id -> WindowsInfo for each browser window whose title matches a tab.

        Windows without a matching tab are logged and skipped. The shared
        registry is not modified.
        """
        ret = {}
        for hwnd, pid, _title in backend_win32com.enum_chrome_windows_by_pid(self.page.process_id):
            rect, client_rect, window_title = backend_win32com.get_window_info(hwnd)
            _, _, win_w, win_h = client_rect
            tab = self.find_tab_title(window_title)
            logger.info(f"window_title {window_title} hwnd {hwnd} client_rect {client_rect}")
            if tab is None:
                logger.warning(f"no tab matches window_title {window_title} hwnd {hwnd}")
                continue
            logger.info(f"tab title {tab.title}")
            ret[tab.tab_id] = WindowsInfo(pid=pid, hwnd=hwnd, win_rect=list(rect), win_size=[win_w, win_h], window_title=window_title, tab_id=tab.tab_id)
        return ret

    def find_tab_title(self, window_title: str):
        """Return the first tab whose (non-empty) title occurs in ``window_title``, else None."""
        window_title_lower = window_title.lower()
        for tab in self.page.get_tabs():
            tab_title = tab.title.lower()
            logger.debug(f"{tab_title}")
            # An empty title is a substring of every window title; never match on it.
            if tab_title and tab_title in window_title_lower:
                return tab
        return None

    def ocr(self, screen) -> List[OCRMatch]:
        """Run PaddleOCR on ``screen`` and wrap each detected text line in an OCRMatch."""
        result = self.paddle_ocr.ocr(screen, cls=True)
        lines_model = []
        for page in result or []:
            # PaddleOCR can return None for an image with no detected text.
            if not page:
                continue
            for coordinates, (text, confidence) in page:
                lines_model.append(OCRMatch(
                    top_left=tuple(coordinates[0]),
                    top_right=tuple(coordinates[1]),
                    bottom_right=tuple(coordinates[2]),
                    bottom_left=tuple(coordinates[3]),
                    ocr_txt=text,
                    ocr_confidence=confidence,
                ))
        return lines_model

    def ocr_find_txt(self, ocr_model: OCRMatch, hwnd=None, ocr_result=None, threshold: float = 0.90, similarity_threshold: float = 80):
        """Find the first OCR line fuzzily matching ``ocr_model.find_txt``.

        Args:
            ocr_model: template whose ``find_txt`` is searched for.
            hwnd: window to screenshot when ``ocr_result`` is not given.
            ocr_result: precomputed OCR lines; an empty list is searched as-is.
            threshold: minimum (exclusive) OCR confidence for a line to be considered.
            similarity_threshold: minimum (exclusive) rapidfuzz ratio, 0-100.

        Returns:
            The matching OCRMatch with ``find_txt``/``find_txt_similarity`` set, or None.
        """
        assert ocr_model is not None
        if ocr_result is None:
            ocr_result = self.ocr(self.screenshot(hwnd))
        logger.debug(f"ocr_model {ocr_model}")
        wanted = ocr_model.find_txt.lower()
        for ocr_match in ocr_result:
            if not ocr_match.ocr_confidence > threshold:
                continue
            similarity = fuzz.ratio(wanted, ocr_match.ocr_txt.lower())
            if similarity > similarity_threshold:
                ocr_match.find_txt_similarity = similarity
                ocr_match.find_txt = ocr_model.find_txt
                return ocr_match
        return None

    @staticmethod
    @async_wrapper
    def quit():
        """Quit the shared browser, if one was started."""
        if DriverBase.page:
            DriverBase.page.quit()
|