browser_env.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. import atexit
  2. import base64
  3. import io
  4. import json
  5. import multiprocessing
  6. import time
  7. import uuid
  8. import browsergym.core # noqa F401 (we register the openended task as a gym environment)
  9. import gymnasium as gym
  10. import html2text
  11. import numpy as np
  12. import tenacity
  13. from browsergym.utils.obs import flatten_dom_to_str
  14. from PIL import Image
  15. from openhands.core.exceptions import BrowserInitException
  16. from openhands.core.logger import openhands_logger as logger
  17. from openhands.runtime.utils.shutdown_listener import should_continue, should_exit
  18. from openhands.utils.tenacity_stop import stop_if_should_exit
# Sentinel action strings used in evaluation runs. The browser process
# intercepts them in its request loop instead of forwarding them to the gym
# environment: the first returns the stored task goal, the second returns the
# JSON-encoded list of rewards collected so far.
BROWSER_EVAL_GET_GOAL_ACTION = 'GET_EVAL_GOAL'
BROWSER_EVAL_GET_REWARDS_ACTION = 'GET_EVAL_REWARDS'
  21. class BrowserEnv:
  22. def __init__(self, browsergym_eval_env: str | None = None):
  23. self.html_text_converter = self.get_html_text_converter()
  24. self.eval_mode = False
  25. self.eval_dir = ''
  26. # EVAL only: browsergym_eval_env must be provided for evaluation
  27. self.browsergym_eval_env = browsergym_eval_env
  28. self.eval_mode = bool(browsergym_eval_env)
  29. # Initialize browser environment process
  30. multiprocessing.set_start_method('spawn', force=True)
  31. self.browser_side, self.agent_side = multiprocessing.Pipe()
  32. self.init_browser()
  33. atexit.register(self.close)
  34. def get_html_text_converter(self):
  35. html_text_converter = html2text.HTML2Text()
  36. # ignore links and images
  37. html_text_converter.ignore_links = False
  38. html_text_converter.ignore_images = True
  39. # use alt text for images
  40. html_text_converter.images_to_alt = True
  41. # disable auto text wrapping
  42. html_text_converter.body_width = 0
  43. return html_text_converter
  44. @tenacity.retry(
  45. wait=tenacity.wait_fixed(1),
  46. stop=tenacity.stop_after_attempt(5) | stop_if_should_exit(),
  47. retry=tenacity.retry_if_exception_type(BrowserInitException),
  48. )
  49. def init_browser(self):
  50. logger.debug('Starting browser env...')
  51. try:
  52. self.process = multiprocessing.Process(target=self.browser_process)
  53. self.process.start()
  54. except Exception as e:
  55. logger.error(f'Failed to start browser process: {e}')
  56. raise
  57. if not self.check_alive():
  58. self.close()
  59. raise BrowserInitException('Failed to start browser environment.')
  60. def browser_process(self):
  61. if self.eval_mode:
  62. assert self.browsergym_eval_env is not None
  63. logger.debug('Initializing browser env for web browsing evaluation.')
  64. if 'webarena' in self.browsergym_eval_env:
  65. import browsergym.webarena # noqa F401 register webarena tasks as gym environments
  66. elif 'miniwob' in self.browsergym_eval_env:
  67. import browsergym.miniwob # noqa F401 register miniwob tasks as gym environments
  68. else:
  69. raise ValueError(
  70. f'Unsupported browsergym eval env: {self.browsergym_eval_env}'
  71. )
  72. env = gym.make(
  73. self.browsergym_eval_env,
  74. tags_to_mark='all',
  75. )
  76. else:
  77. env = gym.make(
  78. 'browsergym/openended',
  79. task_kwargs={'start_url': 'about:blank', 'goal': 'PLACEHOLDER_GOAL'},
  80. wait_for_user_message=False,
  81. headless=True,
  82. disable_env_checker=True,
  83. tags_to_mark='all',
  84. )
  85. obs, info = env.reset()
  86. # EVAL ONLY: save the goal into file for evaluation
  87. self.eval_goal = None
  88. self.eval_rewards: list[float] = []
  89. if self.eval_mode:
  90. logger.debug(f"Browsing goal: {obs['goal']}")
  91. self.eval_goal = obs['goal']
  92. logger.debug('Browser env started.')
  93. while should_continue():
  94. try:
  95. if self.browser_side.poll(timeout=0.01):
  96. unique_request_id, action_data = self.browser_side.recv()
  97. # shutdown the browser environment
  98. if unique_request_id == 'SHUTDOWN':
  99. logger.debug('SHUTDOWN recv, shutting down browser env...')
  100. env.close()
  101. return
  102. elif unique_request_id == 'IS_ALIVE':
  103. self.browser_side.send(('ALIVE', None))
  104. continue
  105. # EVAL ONLY: Get evaluation info
  106. if action_data['action'] == BROWSER_EVAL_GET_GOAL_ACTION:
  107. self.browser_side.send(
  108. (unique_request_id, {'text_content': self.eval_goal})
  109. )
  110. continue
  111. elif action_data['action'] == BROWSER_EVAL_GET_REWARDS_ACTION:
  112. self.browser_side.send(
  113. (
  114. unique_request_id,
  115. {'text_content': json.dumps(self.eval_rewards)},
  116. )
  117. )
  118. continue
  119. action = action_data['action']
  120. obs, reward, terminated, truncated, info = env.step(action)
  121. # EVAL ONLY: Save the rewards into file for evaluation
  122. if self.eval_mode:
  123. self.eval_rewards.append(reward)
  124. # add text content of the page
  125. html_str = flatten_dom_to_str(obs['dom_object'])
  126. obs['text_content'] = self.html_text_converter.handle(html_str)
  127. # make observation serializable
  128. obs['screenshot'] = self.image_to_png_base64_url(obs['screenshot'])
  129. obs['active_page_index'] = obs['active_page_index'].item()
  130. obs['elapsed_time'] = obs['elapsed_time'].item()
  131. self.browser_side.send((unique_request_id, obs))
  132. except KeyboardInterrupt:
  133. logger.debug('Browser env process interrupted by user.')
  134. try:
  135. env.close()
  136. except Exception:
  137. pass
  138. return
  139. def step(self, action_str: str, timeout: float = 30) -> dict:
  140. """Execute an action in the browser environment and return the observation."""
  141. unique_request_id = str(uuid.uuid4())
  142. self.agent_side.send((unique_request_id, {'action': action_str}))
  143. start_time = time.time()
  144. while True:
  145. if should_exit() or time.time() - start_time > timeout:
  146. raise TimeoutError('Browser environment took too long to respond.')
  147. if self.agent_side.poll(timeout=0.01):
  148. response_id, obs = self.agent_side.recv()
  149. if response_id == unique_request_id:
  150. return obs
  151. def check_alive(self, timeout: float = 60):
  152. self.agent_side.send(('IS_ALIVE', None))
  153. if self.agent_side.poll(timeout=timeout):
  154. response_id, _ = self.agent_side.recv()
  155. if response_id == 'ALIVE':
  156. return True
  157. logger.debug(f'Browser env is not alive. Response ID: {response_id}')
  158. def close(self):
  159. if not self.process.is_alive():
  160. return
  161. try:
  162. self.agent_side.send(('SHUTDOWN', None))
  163. self.process.join(5) # Wait for the process to terminate
  164. if self.process.is_alive():
  165. logger.error(
  166. 'Browser process did not terminate, forcefully terminating...'
  167. )
  168. self.process.terminate()
  169. self.process.join(5) # Wait for the process to terminate
  170. if self.process.is_alive():
  171. self.process.kill()
  172. self.process.join(5) # Wait for the process to terminate
  173. self.agent_side.close()
  174. self.browser_side.close()
  175. except Exception:
  176. logger.error('Encountered an error when closing browser env', exc_info=True)
  177. @staticmethod
  178. def image_to_png_base64_url(
  179. image: np.ndarray | Image.Image, add_data_prefix: bool = False
  180. ):
  181. """Convert a numpy array to a base64 encoded png image url."""
  182. if isinstance(image, np.ndarray):
  183. image = Image.fromarray(image)
  184. if image.mode in ('RGBA', 'LA'):
  185. image = image.convert('RGB')
  186. buffered = io.BytesIO()
  187. image.save(buffered, format='PNG')
  188. image_base64 = base64.b64encode(buffered.getvalue()).decode()
  189. return (
  190. f'data:image/png;base64,{image_base64}'
  191. if add_data_prefix
  192. else f'{image_base64}'
  193. )
  194. @staticmethod
  195. def image_to_jpg_base64_url(
  196. image: np.ndarray | Image.Image, add_data_prefix: bool = False
  197. ):
  198. """Convert a numpy array to a base64 encoded jpeg image url."""
  199. if isinstance(image, np.ndarray):
  200. image = Image.fromarray(image)
  201. if image.mode in ('RGBA', 'LA'):
  202. image = image.convert('RGB')
  203. buffered = io.BytesIO()
  204. image.save(buffered, format='JPEG')
  205. image_base64 = base64.b64encode(buffered.getvalue()).decode()
  206. return (
  207. f'data:image/jpeg;base64,{image_base64}'
  208. if add_data_prefix
  209. else f'{image_base64}'
  210. )