browser_env.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233
  1. import atexit
  2. import base64
  3. import io
  4. import json
  5. import multiprocessing
  6. import time
  7. import uuid
  8. import browsergym.core # noqa F401 (we register the openended task as a gym environment)
  9. import gymnasium as gym
  10. import html2text
  11. import numpy as np
  12. import tenacity
  13. from browsergym.utils.obs import flatten_dom_to_str
  14. from PIL import Image
  15. from openhands.core.exceptions import BrowserInitException
  16. from openhands.core.logger import openhands_logger as logger
  17. from openhands.runtime.utils.shutdown_listener import should_continue, should_exit
  18. from openhands.utils.tenacity_stop import stop_if_should_exit
  # EVAL only: special action strings recognised by BrowserEnv.browser_process.
  # Instead of being executed in the browser, they make the browser process
  # reply with the task goal or with the JSON-encoded list of rewards so far.
  BROWSER_EVAL_GET_GOAL_ACTION: str = 'GET_EVAL_GOAL'
  BROWSER_EVAL_GET_REWARDS_ACTION: str = 'GET_EVAL_REWARDS'
  class BrowserEnv:
      """Run a BrowserGym environment in a child process and talk to it over a pipe.

      The gym env is created inside a child process (started with the ``spawn``
      start method, see ``__init__``). The parent sends ``(request_id, payload)``
      tuples through ``self.agent_side``; the child replies through
      ``self.browser_side`` with ``(request_id, observation)`` tuples. The request
      ids ``'IS_ALIVE'`` and ``'SHUTDOWN'`` are reserved for health checks and
      teardown.
      """

      def __init__(self, browsergym_eval_env: str | None = None):
          """Start the browser child process.

          Args:
              browsergym_eval_env: EVAL only. Gym id of a webarena or miniwob
                  task. When given, the env runs in evaluation mode and tracks
                  the goal and per-step rewards. When ``None``, an open-ended
                  browsing env on ``about:blank`` is used.

          Raises:
              Whatever ``init_browser`` raises once its tenacity retry budget is
              exhausted. With the decorator as configured (no ``reraise``),
              that is presumably ``tenacity.RetryError`` wrapping
              ``BrowserInitException``; verify against the tenacity version in use.
          """
          self.html_text_converter = self.get_html_text_converter()
          self.eval_dir = ''
          # EVAL only: browsergym_eval_env must be provided for evaluation
          self.browsergym_eval_env = browsergym_eval_env
          self.eval_mode = bool(browsergym_eval_env)
          # NOTE: this forces the start method for the whole interpreter.
          multiprocessing.set_start_method('spawn', force=True)
          # init_browser creates the pipe (self.browser_side / self.agent_side)
          # afresh on every attempt.
          self.init_browser()
          atexit.register(self.close)

      def get_html_text_converter(self):
          """Return an ``html2text.HTML2Text`` used to turn page HTML into text."""
          html_text_converter = html2text.HTML2Text()
          # keep links, ignore images
          html_text_converter.ignore_links = False
          html_text_converter.ignore_images = True
          # use alt text for images
          # NOTE(review): html2text may skip <img> tags entirely when
          # ignore_images is True, which would make this a no-op; confirm.
          html_text_converter.images_to_alt = True
          # disable auto text wrapping
          html_text_converter.body_width = 0
          return html_text_converter

      @tenacity.retry(
          wait=tenacity.wait_fixed(1),
          stop=tenacity.stop_after_attempt(5) | stop_if_should_exit(),
          retry=tenacity.retry_if_exception_type(BrowserInitException),
      )
      def init_browser(self):
          """Spawn the browser child process and wait until it reports alive.

          Retried by tenacity on ``BrowserInitException`` (up to 5 attempts,
          1 s apart, or until ``stop_if_should_exit`` fires).

          Raises:
              BrowserInitException: the child did not answer the liveness check.
              Exception: re-raised unchanged if the process cannot be started.
          """
          logger.info('Starting browser env...')
          # A fresh pipe on every attempt: a failed attempt calls self.close(),
          # which closes the previous pipe ends. A closed Connection can be
          # neither reused nor pickled into the spawned child. Reusing it made
          # every retry fail with OSError, and OSError is not retried.
          self.browser_side, self.agent_side = multiprocessing.Pipe()
          try:
              self.process = multiprocessing.Process(target=self.browser_process)
              self.process.start()
          except Exception as e:
              logger.error(f'Failed to start browser process: {e}')
              self._close_pipes()
              raise
          if not self.check_alive():
              self.close()
              raise BrowserInitException('Failed to start browser environment.')

      def _make_env(self):
          """Create the gym env: the eval task in eval mode, else open-ended browsing.

          Raises:
              ValueError: the eval env id mentions neither webarena nor miniwob.
          """
          if self.eval_mode:
              assert self.browsergym_eval_env is not None
              logger.info('Initializing browser env for web browsing evaluation.')
              if 'webarena' in self.browsergym_eval_env:
                  import browsergym.webarena  # noqa F401 register webarena tasks as gym environments
              elif 'miniwob' in self.browsergym_eval_env:
                  import browsergym.miniwob  # noqa F401 register miniwob tasks as gym environments
              else:
                  raise ValueError(
                      f'Unsupported browsergym eval env: {self.browsergym_eval_env}'
                  )
              return gym.make(self.browsergym_eval_env)
          return gym.make(
              'browsergym/openended',
              task_kwargs={'start_url': 'about:blank', 'goal': 'PLACEHOLDER_GOAL'},
              wait_for_user_message=False,
              headless=True,
              disable_env_checker=True,
          )

      def _make_obs_serializable(self, obs: dict) -> dict:
          """Add ``text_content`` to ``obs`` and convert numpy fields; mutates and returns ``obs``.

          The screenshot becomes a base64 PNG string. ``active_page_index`` and
          ``elapsed_time`` are converted with ``.item()`` to plain Python values.
          """
          # add text content of the page
          html_str = flatten_dom_to_str(obs['dom_object'])
          obs['text_content'] = self.html_text_converter.handle(html_str)
          # make observation serializable
          obs['screenshot'] = self.image_to_png_base64_url(obs['screenshot'])
          obs['active_page_index'] = obs['active_page_index'].item()
          obs['elapsed_time'] = obs['elapsed_time'].item()
          return obs

      @staticmethod
      def _close_env_quietly(env) -> None:
          """Best-effort ``env.close()`` used when the child process is exiting."""
          try:
              env.close()
          except Exception:
              logger.warning('Error while closing browser env', exc_info=True)

      def browser_process(self):
          """Entry point of the browser child process: build the env and serve requests.

          Requests handled (as ``(request_id, payload)`` tuples):

          * ``('SHUTDOWN', _)``: close the env and return.
          * ``('IS_ALIVE', _)``: reply ``('ALIVE', None)``.
          * ``(id, {'action': BROWSER_EVAL_GET_GOAL_ACTION})``: reply with the goal.
          * ``(id, {'action': BROWSER_EVAL_GET_REWARDS_ACTION})``: reply with the
            rewards list as a JSON string.
          * ``(id, {'action': <str>})``: run ``env.step`` and reply ``(id, obs)``.

          The env is also closed when the loop ends because ``should_continue()``
          returned False, or when the process is interrupted.
          """
          env = self._make_env()
          obs, _info = env.reset()
          # EVAL ONLY: goal and rewards live in this child process and are handed
          # to the agent through the special eval actions below.
          self.eval_goal = None
          self.eval_rewards: list[float] = []
          if self.eval_mode:
              logger.info(f"Browsing goal: {obs['goal']}")
              self.eval_goal = obs['goal']
          logger.info('Browser env started.')
          while should_continue():
              try:
                  if not self.browser_side.poll(timeout=0.01):
                      continue
                  unique_request_id, action_data = self.browser_side.recv()
                  # shutdown the browser environment
                  if unique_request_id == 'SHUTDOWN':
                      logger.info('SHUTDOWN recv, shutting down browser env...')
                      env.close()
                      return
                  if unique_request_id == 'IS_ALIVE':
                      self.browser_side.send(('ALIVE', None))
                      continue
                  action = action_data['action']
                  # EVAL ONLY: Get evaluation info
                  if action == BROWSER_EVAL_GET_GOAL_ACTION:
                      self.browser_side.send(
                          (unique_request_id, {'text_content': self.eval_goal})
                      )
                      continue
                  if action == BROWSER_EVAL_GET_REWARDS_ACTION:
                      self.browser_side.send(
                          (
                              unique_request_id,
                              {'text_content': json.dumps(self.eval_rewards)},
                          )
                      )
                      continue
                  obs, reward, _terminated, _truncated, _info = env.step(action)
                  # EVAL ONLY: record the reward of every executed action
                  if self.eval_mode:
                      self.eval_rewards.append(reward)
                  self.browser_side.send(
                      (unique_request_id, self._make_obs_serializable(obs))
                  )
              except KeyboardInterrupt:
                  logger.info('Browser env process interrupted by user.')
                  self._close_env_quietly(env)
                  return
          # should_continue() returned False: release the browser instead of
          # leaking it when the process exits.
          self._close_env_quietly(env)

      def step(self, action_str: str, timeout: float = 30) -> dict:
          """Execute an action in the browser environment and return the observation.

          Args:
              action_str: Action forwarded to the child as ``{'action': action_str}``.
              timeout: Seconds to wait for the matching reply.

          Returns:
              The observation dict the child sent back for this request.

          Raises:
              TimeoutError: No matching reply arrived within ``timeout`` seconds,
                  or ``should_exit()`` became true while waiting.
          """
          unique_request_id = str(uuid.uuid4())
          self.agent_side.send((unique_request_id, {'action': action_str}))
          # monotonic clock: immune to wall-clock adjustments during the wait
          start_time = time.monotonic()
          while True:
              if should_exit() or time.monotonic() - start_time > timeout:
                  raise TimeoutError('Browser environment took too long to respond.')
              if self.agent_side.poll(timeout=0.01):
                  response_id, obs = self.agent_side.recv()
                  # Replies to earlier, abandoned requests are discarded.
                  if response_id == unique_request_id:
                      return obs

      def check_alive(self, timeout: float = 60) -> bool:
          """Ping the child and wait up to ``timeout`` seconds for ``'ALIVE'``.

          Returns:
              True if the child answered ``'ALIVE'``; False on timeout or on any
              other reply.
          """
          self.agent_side.send(('IS_ALIVE', None))
          if not self.agent_side.poll(timeout=timeout):
              logger.info(f'Browser env did not respond within {timeout} seconds.')
              return False
          response_id, _ = self.agent_side.recv()
          if response_id == 'ALIVE':
              return True
          logger.info(f'Browser env is not alive. Response ID: {response_id}')
          return False

      def _close_pipes(self) -> None:
          """Close both pipe ends if they exist; repeated calls are harmless."""
          for conn in (
              getattr(self, 'agent_side', None),
              getattr(self, 'browser_side', None),
          ):
              if conn is not None:
                  conn.close()

      def close(self):
          """Shut down the browser child process and release the pipe.

          Also registered with ``atexit``, so it may run more than once. It first
          asks the child to shut down, then escalates to ``terminate()`` and
          finally ``kill()``, waiting up to 5 s at each stage. The pipe is closed
          even if the child had already exited.
          """
          process = getattr(self, 'process', None)
          try:
              if process is not None and process.is_alive():
                  self.agent_side.send(('SHUTDOWN', None))
                  process.join(5)  # Wait for the process to terminate
                  if process.is_alive():
                      logger.error(
                          'Browser process did not terminate, forcefully terminating...'
                      )
                      process.terminate()
                      process.join(5)  # Wait for the process to terminate
                      if process.is_alive():
                          process.kill()
                          process.join(5)  # Wait for the process to terminate
          except Exception:
              logger.error('Encountered an error when closing browser env', exc_info=True)
          finally:
              self._close_pipes()

      @staticmethod
      def _image_to_base64(
          image: np.ndarray | Image.Image,
          image_format: str,
          mime_type: str,
          add_data_prefix: bool,
          modes_to_rgb: tuple[str, ...],
      ) -> str:
          """Encode ``image`` with PIL as ``image_format`` and return base64 text.

          A numpy array is first wrapped with ``Image.fromarray``. Images whose
          mode is in ``modes_to_rgb`` are converted to RGB before saving, which
          drops any alpha channel. With ``add_data_prefix`` the result is a
          ``data:<mime_type>;base64,...`` URL.
          """
          if isinstance(image, np.ndarray):
              image = Image.fromarray(image)
          if image.mode in modes_to_rgb:
              image = image.convert('RGB')
          buffered = io.BytesIO()
          image.save(buffered, format=image_format)
          image_base64 = base64.b64encode(buffered.getvalue()).decode()
          if add_data_prefix:
              return f'data:{mime_type};base64,{image_base64}'
          return image_base64

      @staticmethod
      def image_to_png_base64_url(
          image: np.ndarray | Image.Image, add_data_prefix: bool = False
      ):
          """Convert a numpy array or PIL image to a base64-encoded PNG (optionally a data URL)."""
          return BrowserEnv._image_to_base64(
              image, 'PNG', 'image/png', add_data_prefix, ('RGBA', 'LA')
          )

      @staticmethod
      def image_to_jpg_base64_url(
          image: np.ndarray | Image.Image, add_data_prefix: bool = False
      ):
          """Convert a numpy array or PIL image to a base64-encoded JPEG (optionally a data URL)."""
          # JPEG has no alpha channel and no palette mode. Pillow refuses to
          # write 'P' images as JPEG, so palette images are converted as well.
          return BrowserEnv._image_to_base64(
              image, 'JPEG', 'image/jpeg', add_data_prefix, ('RGBA', 'LA', 'P')
          )