pyauto_windows.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454
  1. from datetime import datetime
  2. from paddleocr import PaddleOCR, draw_ocr
  3. from rapidfuzz import fuzz
  4. import json
  5. import time
  6. from typing import List, Optional,Coroutine, Dict, Tuple,Any
  7. import win32con
  8. import win32clipboard
  9. import win32gui
  10. from pathlib import Path
  11. from sqlmodel import SQLModel, Field, Column, JSON
  12. from DrissionPage import ChromiumPage
  13. from ai.driver.cv_common import ImageMatchResult
  14. from ai.driver import backend_win32com,bro_page_pyautogui,browser_win32,send_input,cv_common
  15. from ai.conf_ai.config import load_chrome_from_ini,get_logger,RESOURCE,get_browser,BRO_INI_FILE
  16. from ai.gpt_node.async_wraps import async_wrapper,thread_safe
  17. logger = get_logger(f'ai/gpt_node-driver-{BRO_INI_FILE}')
  18. class WindowsInfo(SQLModel, table=False):
  19. id: Optional[int] = Field(default=None, primary_key=True)
  20. win_rect: Optional[list] = Field(default=[],sa_column=Column(JSON))
  21. win_size: Optional[list] = Field(default=[],sa_column=Column(JSON))
  22. hwnd: Optional[int] = Field(default=None)
  23. window_title:Optional[str] = Field(default='')
  24. pid: Optional[int] = Field(default=None)
  25. tab_id: Optional[str] = Field(default='')
  26. class OCRMatch(SQLModel, table=False):
  27. top_left: Tuple[float, float] = Field(default=(0.0, 0.0))
  28. top_right: Tuple[float, float] = Field(default=(0.0, 0.0))
  29. bottom_right: Tuple[float, float] = Field(default=(0.0, 0.0))
  30. bottom_left: Tuple[float, float] = Field(default=(0.0, 0.0))
  31. find_txt: str = Field(default="")
  32. find_txt_similarity: int = Field(default=0)
  33. ocr_txt: str = Field(default="")
  34. ocr_confidence: float = Field(default=0.0)
  35. win_info: WindowsInfo = Field(default=None)
  36. def center(self):
  37. return (self.top_left[0] + self.bottom_right[0]) / 2, (self.top_left[1] + self.bottom_right[1]) / 2
  38. def is_match(self, similarity:int=90):
  39. return self.find_txt_similarity > similarity
  40. class GptImgMatch(SQLModel, table=False):
  41. img_path:Optional[str] = Field(default=None)
  42. init_pos:Optional[list] = Field(default=[0,0],sa_column=Column(JSON))
  43. match_res:Optional[ImageMatchResult] = Field(default=None)
  44. def img_match(self, img, thread:float=0.92) ->ImageMatchResult|None:
  45. self.match_res:ImageMatchResult = cv_common.CVPage.match_img_in_screen(img, self.img_path)
  46. self.init_pos = self.match_res.max_location
  47. if self.is_match(thread):
  48. return self.match_res
  49. def is_match(self, thread:float=0.92):
  50. if not self.match_res:
  51. return False
  52. return self.match_res.match_max_val > thread
  53. def abs_pos(self, x,y, win_rect:list=None):
  54. win_x,win_y,_ , _ = win_rect
  55. return win_x + x, win_y + y
  56. def click(self, pos:list=[], hwnd=None):
  57. if not pos:
  58. if self.match_res:
  59. pos = self.match_res.max_location_center
  60. else:
  61. raise Exception("未识别到目标,请先识别坐标以点击")
  62. backend_win32com.VirtualKeyboard(hwnd).mouse_move_press(*pos)
  63. logger.info(f"{Path(self.img_path).name} {pos}")
class BaseOCRMatch(SQLModel, table=False):
    # Text templates for context-menu entries; each default only carries the
    # ``find_txt`` label that OCR results are compared against.
    # NOTE(review): ``copy`` is also the name of a method on pydantic's
    # BaseModel — confirm the installed pydantic/SQLModel accepts this field name.
    copy: Optional[OCRMatch] = OCRMatch(find_txt='copy')
    paste: Optional[OCRMatch] = OCRMatch(find_txt='paste')
    select_all: Optional[OCRMatch] = OCRMatch(find_txt='select all')
# Browser window tab icons
class CvWinBroModel(SQLModel, table=False):
    # Image templates, each a GptImgMatch pointing at a PNG under RESOURCE.
    id: Optional[int] = Field(default=None, primary_key=True)
    # Tab icons, one template per site.
    chatgpt: Optional[GptImgMatch] = GptImgMatch(img_path=str(RESOURCE / 'openai' / 'tab-icon.png'))
    claude: Optional[GptImgMatch] = GptImgMatch(img_path=str(RESOURCE / 'claude' / 'tab-icon.png'))
    # Two templates for the refresh button; ``DriverBase.click_refresh`` tries them in order.
    refresh_btn: Optional[GptImgMatch] = GptImgMatch(img_path=str(RESOURCE / 'windows' / 'refrsh-btn.png'))
    refresh_btn2: Optional[GptImgMatch] = GptImgMatch(img_path=str(RESOURCE / 'windows' / 'refrsh-btn2.png'))
    # "Paste" entry looked up in the right-click popup by ``DriverBase.click_paste``.
    paste: Optional[GptImgMatch] = GptImgMatch(img_path=str(RESOURCE / 'windows' / 'paste.png'))
class DriverBase:
    # Substring checked against tab URLs by ``find_or_new_tab``; also used as a log prefix.
    name = 'About Version'
    # URL used by ``find_or_new_tab`` when the caller passes none.
    url = ''
    # Browser object: the main browser page obtained by launching Chrome from the
    # main process. There is only one main process; every other child window or
    # tab is a tab page of a child process.
    page:ChromiumPage = None
    # All window records obtained via the browser PID. Shared by every instance;
    # ``init_all_windows_info`` and ``update_all_windows_info`` key it by hwnd.
    all_windows_info:Dict[int,WindowsInfo] = {}
    # OCR engine, created once when the class body is executed.
    paddle_ocr = PaddleOCR(use_angle_cls=True, lang="en",)
    # Boolean flag polled by ``new_windows`` so that one window is opened at a time.
    _lock = False
    def __init__(self, ini_file=BRO_INI_FILE) -> None:
        # Per-instance state starts empty; ``tab``, ``windows_info`` and ``vb``
        # are filled in by ``new_windows``.
        self.tab = None
        self.bromodel = CvWinBroModel()
        self.windows_info:WindowsInfo = None
        self.vb:backend_win32com.VirtualKeyboard = None
        # Not assigned anywhere else in the visible code; ``select_all`` reads it.
        self.ocr_model:BaseOCRMatch = None
        self.ini_file = ini_file
  92. @thread_safe
  93. def init(self):
  94. if not DriverBase.page:
  95. DriverBase.page: ChromiumPage = get_browser(self.ini_file)
  96. # logger.info(f"tab.rect {tab.title} tab.rect {tab.tab_id} tab.window_location {tab.rect.window_location} tab.size {tab.rect.size}")
  97. # DriverBase.page.get('chrome://version/')
  98. DriverBase.page.wait.doc_loaded(timeout=10)
  99. self.init_all_windows_info()
  100. return DriverBase.page
  101. # 获取所有窗口句柄和PID
  102. results = backend_win32com.enum_chrome_windows_by_pid(DriverBase.page.process_id)
  103. while results:
  104. hwnd, pid, title = results.pop()
  105. logger.info(f"{DriverBase.name} init {hwnd, pid, title} {DriverBase.name in title}")
  106. if main_windows is None and DriverBase.name in title:
  107. main_windows = (hwnd, pid, title)
  108. else:
  109. logger.info(f"close other {hwnd, pid, title}")
  110. win32gui.PostMessage(hwnd, win32con.WM_CLOSE, 0, 0)
  111. hwnd, pid, title = main_windows
  112. rect, client_rect, window_title = backend_win32com.get_window_info(hwnd)
  113. _, _, win_w, win_h = client_rect
  114. window_info = WindowsInfo(pid=pid, hwnd=hwnd, win_rect=list(rect), win_size=[win_w, win_h], window_title=window_title, tab_id=DriverBase.page.tab_id)
  115. DriverBase.all_windows_info.update({DriverBase.page.tab_id: window_info})
  116. logger.info(f"init page browser_version {DriverBase.page.browser_version} get_tabs {DriverBase.page.get_tabs()}")
  117. def init_all_windows_info(self, process_id=None, cache=False):
  118. if cache:
  119. return self.all_windows_info
  120. results = backend_win32com.enum_chrome_windows_by_pid(DriverBase.page.process_id)
  121. tabs = DriverBase.page.get_tabs()
  122. for hwnd, pid, title in results:
  123. rect, client_rect, window_title = backend_win32com.get_window_info(hwnd)
  124. _, _, win_w, win_h = client_rect
  125. tab_id = None
  126. window_info = WindowsInfo(pid=pid, hwnd=hwnd, win_rect=list(rect), win_size=[win_w, win_h], window_title=window_title, tab_id=tab_id)
  127. DriverBase.all_windows_info.update({hwnd: window_info})
  128. return DriverBase.all_windows_info
  129. def update_all_windows_info(self, hwnd, window_info):
  130. DriverBase.all_windows_info.update({hwnd: window_info})
  131. return DriverBase.all_windows_info
  132. def new_windows(self, url:str=''):
  133. while DriverBase._lock:
  134. time.sleep(0.1)
  135. DriverBase._lock = True
  136. try:
  137. self.tab = DriverBase.page.new_tab(url=url,new_window=True)
  138. # 获取所有窗口句柄和PID
  139. results = backend_win32com.enum_chrome_windows_by_pid(DriverBase.page.process_id)
  140. hwnd, pid, title = results.pop(0)
  141. self.windows_info = self.get_windows_info(hwnd, tab=self.tab)
  142. self.update_all_windows_info(hwnd, self.windows_info)
  143. self.vb = backend_win32com.VirtualKeyboard(hwnd)
  144. logger.info(f"{self.name} 打开新窗口 {self.windows_info.model_dump()}")
  145. logger.debug(f"self.all_windows_info {self.all_windows_info}")
  146. return self.windows_info
  147. except Exception as e:
  148. logger.exception(f"new_windows {e}")
  149. finally:
  150. DriverBase._lock = False
  151. @thread_safe
  152. def find_or_new_windows(self, title:str, url:str='') -> list[WindowsInfo] | WindowsInfo:
  153. res = []
  154. win = None
  155. for windows_info in DriverBase.all_windows_info.values():
  156. if title.lower() in windows_info.window_title.lower():
  157. if not windows_info.tab_id:
  158. win = windows_info
  159. if win:
  160. self.tab = self.find_shifted_tab(win)
  161. win.tab_id = self.tab.tab_id
  162. backend_win32com.show_win(win.hwnd)
  163. self.update_all_windows_info(win.hwnd, win)
  164. logger.info(f"{self.name} 已存在窗口 {win.model_dump()}")
  165. else:
  166. win = self.new_windows(url=url)
  167. return win
  168. @thread_safe
  169. def find_shifted_tab(self, windows_info: WindowsInfo):
  170. # Get initial tab positions
  171. initial_tabs = DriverBase.page.get_tabs()
  172. initial_positions = {tab.title: tab.rect.window_location for tab in initial_tabs}
  173. # Log initial positions
  174. # for tab in initial_tabs:
  175. # logger.info(f"Initial: tab.title {tab.title} tab.rect.window_location {tab.rect.window_location}")
  176. # Move window to the right
  177. x, y, _, _ = windows_info.win_rect
  178. win32gui.SetWindowPos(windows_info.hwnd, win32con.HWND_TOP, x + 100, y, 0, 0, win32con.SWP_NOSIZE)
  179. # Get new tab positions
  180. shifted_tabs = DriverBase.page.get_tabs()
  181. # Find the tab that has shifted
  182. shifted_tab = None
  183. for tab in shifted_tabs:
  184. new_position = tab.rect.window_location
  185. initial_position = initial_positions.get(tab.title)
  186. if initial_position and new_position != initial_position:
  187. shifted_tab = tab
  188. break
  189. # Move window back to original position
  190. win32gui.SetWindowPos(windows_info.hwnd, win32con.HWND_TOP, x, y, 0, 0, win32con.SWP_NOSIZE)
  191. return shifted_tab
  192. def close_windows(self):
  193. if not self.windows_info:
  194. return
  195. win32gui.PostMessage(self.windows_info.hwnd, win32con.WM_CLOSE, 0, 0)
  196. self.all_windows_info.pop(self.windows_info.hwnd)
  197. self.tab = None
  198. self.windows_info = None
  199. def close_no_tab_id():
  200. keys_to_remove = []
  201. for key in list(DriverBase.all_windows_info.keys()):
  202. windows_info = DriverBase.all_windows_info[key]
  203. if not windows_info.tab_id:
  204. logger.info(f"close windows {windows_info.model_dump()}")
  205. win32gui.PostMessage(windows_info.hwnd, win32con.WM_CLOSE, 0, 0)
  206. time.sleep(0.1)
  207. keys_to_remove.append(key)
  208. for key in keys_to_remove:
  209. DriverBase.all_windows_info.pop(key)
  210. def get_windows_info(self, hwnd:int, tab:ChromiumPage=None):
  211. pid = None if not self.page else self.page.process_id
  212. tab_id = None if not tab else tab.tab_id
  213. rect, client_rect, window_title = backend_win32com.get_window_info(hwnd)
  214. _, _, win_w, win_h = client_rect
  215. windows_info = WindowsInfo(pid=pid, hwnd=hwnd, win_rect=list(rect), win_size=[win_w, win_h], window_title=window_title,tab_id=tab_id)
  216. return windows_info
  217. def find(self, locator: GptImgMatch, screenshot=None, thread:float=0.98, region=None, raise_error=False) -> ImageMatchResult:
  218. if screenshot is None or not screenshot.any():
  219. screenshot = self.screenshot(region=region)
  220. result = locator.img_match(screenshot, thread=thread)
  221. if result:
  222. log = f"识别到目标 {Path(locator.img_path).name}"
  223. logger.debug(f"{log}")
  224. else:
  225. log = f"未识别到目标 {Path(locator.img_path).name} {locator}"
  226. if raise_error:
  227. raise Exception(log)
  228. logger.debug(f"{log}")
  229. return result
  230. @async_wrapper
  231. def click(self, locator: GptImgMatch, screenshot=None, thread:float=0.92, raise_error=False, button: str = "L"):
  232. if not locator.match_res:
  233. locator.match_res = self.find(locator,screenshot,thread)
  234. if not locator.match_res:
  235. logger.error(f"Failed to find {locator.img_path}")
  236. return False
  237. if locator.is_match(thread=thread):
  238. logger.info(f"{locator}")
  239. self.vb.mouse_move_press(*locator.match_res.max_location_center, button=button)
  240. return True
  241. if raise_error:
  242. raise Exception(f"Failed to find {locator}")
  243. def get_click_right_windows(self, pos:list, hwnd=None, wait_pop_time:float=0.3):
  244. if not hwnd:
  245. hwnd = self.windows_info.hwnd
  246. vb = backend_win32com.VirtualKeyboard(hwnd)
  247. else:
  248. vb = self.vb
  249. logger.info(f"hwnd {hwnd} pos {pos}")
  250. before_click_windows = backend_win32com.get_child_windows(hwnd)
  251. vb.mouse_move_press(*pos, button= 'R')
  252. time.sleep(wait_pop_time)
  253. after_click_windows = backend_win32com.get_child_windows(hwnd)
  254. new_windows = [hwnd for hwnd in after_click_windows if hwnd not in before_click_windows]
  255. if not new_windows:
  256. logger.error(f"右键找到新窗口 click")
  257. return
  258. pop_hwnd = new_windows[0]
  259. return pop_hwnd
  260. @async_wrapper
  261. def click_paste(self, txt_area_locator: GptImgMatch, hwnd=None, thread:float=0.92, raise_error=False):
  262. if not hwnd:
  263. hwnd = self.windows_info.hwnd
  264. before_click_windows = backend_win32com.get_child_windows(hwnd)
  265. self.click(txt_area_locator, thread=thread, raise_error=True, button="R")
  266. time.sleep(0.3)
  267. after_click_windows = backend_win32com.get_child_windows(hwnd)
  268. new_windows = [hwnd for hwnd in after_click_windows if hwnd not in before_click_windows]
  269. if not new_windows:
  270. logger.error(f"右键找到新窗口 click {txt_area_locator} ")
  271. return
  272. pop_hwnd = new_windows[0]
  273. screenshot = backend_win32com.back_end_screenshot(pop_hwnd)
  274. child_windows = backend_win32com.get_window_info(pop_hwnd)
  275. logger.info(f"child_windows {child_windows}")
  276. paste = self.bromodel.paste.img_match(screenshot, thread=thread)
  277. logger.info(f"paste img_match {paste}")
  278. vb = backend_win32com.VirtualKeyboard(pop_hwnd)
  279. vb.mouse_move_press(*self.bromodel.paste.match_res.max_location_center)
  280. time.sleep(0.3)
  281. def select_all(self, txt_area_pos:list, hwnd=None):
  282. pop_hwnd = self.get_click_right_windows(txt_area_pos, hwnd=hwnd)
  283. vb = backend_win32com.VirtualKeyboard(pop_hwnd)
  284. self.ocr_model.select_all = self.ocr_find_txt(self.ocr_model.select_all, pop_hwnd)
  285. vb.mouse_move_press(*self.ocr_model.select_all.center(), button= 'R')
  286. def paste_str(self, text: str, txt_area_locator: GptImgMatch, hwnd=None):
  287. logger.info(f"send_str {text}")
  288. win32clipboard.OpenClipboard()
  289. win32clipboard.EmptyClipboard()
  290. win32clipboard.SetClipboardText(text)
  291. win32clipboard.CloseClipboard()
  292. try:
  293. self.click_paste(txt_area_locator)
  294. except Exception as e:
  295. logger.error(f"send_str {text} error {e}")
  296. # Clear clipboard after pasting
  297. win32clipboard.OpenClipboard()
  298. win32clipboard.EmptyClipboard()
  299. win32clipboard.CloseClipboard()
  300. @async_wrapper
  301. def wait_for(self, locator: GptImgMatch, timeout: float = 10, interval: float = 0.5, thread: float = 0.98) -> ImageMatchResult:
  302. start_time = time.time()
  303. while time.time() - start_time < timeout:
  304. result = self.find(locator, thread=thread, raise_error=False)
  305. if result:
  306. return result
  307. time.sleep(interval)
  308. # If the timeout is reached without finding the image
  309. log = f"Timeout reached. Unable to find {Path(locator.img_path).name} within {timeout} seconds"
  310. logger.warning(log)
  311. return None
  312. def find_or_new_tab(self, url:str=''):
  313. if not url:
  314. url = self.url
  315. for tab_id in self.page.get_tabs():
  316. tab:ChromiumPage = self.page.get_tab(tab_id)
  317. if self.name in tab.url:
  318. return tab
  319. return self.page.new_tab(url=self.url,new_window=True)
  320. @async_wrapper
  321. def click_refresh(self):
  322. screenshot = self.screenshot()
  323. return (self.click(self.bromodel.refresh_btn, screenshot=screenshot) or
  324. self.click(self.bromodel.refresh_btn2, screenshot=screenshot))
  325. # return self.click(self.bromodel.refresh_btn)
  326. def get_tabs_icon(self, hwnd_list:List[int]):
  327. for hwnd in hwnd_list:
  328. self.find(self.bromodel.chatgpt,screen=backend_win32com.back_end_screenshot(hwnd))
  329. if self.bromodel.chatgpt.is_match():
  330. return
  331. def screenshot(self, hwnd=None, filename:str=None, region=None):
  332. if not hwnd:
  333. hwnd = self.windows_info.hwnd
  334. if region:
  335. return backend_win32com.back_end_screenshot_region(hwnd,filename, region)
  336. return backend_win32com.back_end_screenshot(hwnd,filename)
  337. def get_all_windows_info(self):
  338. results = backend_win32com.enum_chrome_windows_by_pid(self.page.process_id)
  339. ret = {}
  340. hwnd_to_tab_ids: Dict[int, List[str]] = {}
  341. for hwnd, pid, title in results:
  342. rect, client_rect, window_title = backend_win32com.get_window_info(hwnd)
  343. _, _, win_w, win_h = client_rect
  344. # logger.info(f"window_location {rect}")
  345. # logger.info(f"window_size {client_rect}")
  346. tab = self.find_tab_title(window_title)
  347. logger.info(f"window_title {window_title} hwnd {hwnd} client_rect {client_rect}")
  348. logger.info(f"tab title {tab.title}")
  349. window_info = WindowsInfo(pid=pid, hwnd=hwnd, win_rect=list(rect), win_size=[win_w, win_h], window_title=window_title,tab_id=tab.tab_id)
  350. ret.update({tab.tab_id:window_info})
  351. return ret
  352. def find_tab_title(self, window_title:str):
  353. for tab in self.page.get_tabs():
  354. logger.debug(f"{tab.title.lower()}")
  355. if tab.title.lower() in window_title.lower():
  356. return tab
  357. def ocr(self, screen)->List[OCRMatch]:
  358. '''
  359. 例如`ch`, `en`, `fr`, `german`, `korean`, `japan`
  360. need to run only once to download and load model into memory
  361. '''
  362. result = self.paddle_ocr.ocr(screen, cls=True)
  363. lines_model = []
  364. for idx in range(len(result)):
  365. res = result[idx]
  366. for line in res:
  367. coordinates, (text, confidence) = line
  368. ocr_match = OCRMatch(
  369. top_left=tuple(coordinates[0]),
  370. top_right=tuple(coordinates[1]),
  371. bottom_right=tuple(coordinates[2]),
  372. bottom_left=tuple(coordinates[3]),
  373. ocr_txt=text,
  374. ocr_confidence=confidence
  375. )
  376. lines_model.append(ocr_match)
  377. return lines_model
  378. def ocr_find_txt(self, ocr_model: OCRMatch, hwnd=None, ocr_result=None, threshold: float = 0.90, similarity_threshold: float = 80):
  379. if not ocr_result:
  380. screen = self.screenshot(hwnd)
  381. ocr_result = self.ocr(screen)
  382. assert ocr_model is not None
  383. logger.debug(f"ocr_model {ocr_model}")
  384. for ocr_match in ocr_result:
  385. if not ocr_match.ocr_confidence > threshold:
  386. continue
  387. similarity = fuzz.ratio(ocr_model.find_txt.lower(), ocr_match.ocr_txt.lower())
  388. # logger.info(f"ocr_match.ocr_txt {ocr_match.ocr_txt} - ocr_model.find_txt {ocr_model.find_txt} - similarity {similarity}")
  389. if similarity > similarity_threshold:
  390. ocr_match.find_txt_similarity = similarity
  391. ocr_match.find_txt = ocr_model.find_txt
  392. return ocr_match
  393. return None
  394. @async_wrapper
  395. def quit():
  396. if DriverBase.page:
  397. DriverBase.page.quit()