# browser_env.py (9.1 KB)
  1. import atexit
  2. import base64
  3. import io
  4. import json
  5. import multiprocessing
  6. import time
  7. import uuid
  8. import browsergym.core # noqa F401 (we register the openended task as a gym environment)
  9. import gymnasium as gym
  10. import html2text
  11. import numpy as np
  12. import tenacity
  13. from browsergym.utils.obs import flatten_dom_to_str
  14. from PIL import Image
  15. from openhands.core.exceptions import BrowserInitException
  16. from openhands.core.logger import openhands_logger as logger
  17. from openhands.runtime.utils.shutdown_listener import should_continue, should_exit
  18. BROWSER_EVAL_GET_GOAL_ACTION = 'GET_EVAL_GOAL'
  19. BROWSER_EVAL_GET_REWARDS_ACTION = 'GET_EVAL_REWARDS'
  20. class BrowserEnv:
  21. def __init__(self, browsergym_eval_env: str | None = None):
  22. self.html_text_converter = self.get_html_text_converter()
  23. self.eval_mode = False
  24. self.eval_dir = ''
  25. # EVAL only: browsergym_eval_env must be provided for evaluation
  26. self.browsergym_eval_env = browsergym_eval_env
  27. self.eval_mode = bool(browsergym_eval_env)
  28. # Initialize browser environment process
  29. multiprocessing.set_start_method('spawn', force=True)
  30. self.browser_side, self.agent_side = multiprocessing.Pipe()
  31. self.init_browser()
  32. atexit.register(self.close)
  33. def get_html_text_converter(self):
  34. html_text_converter = html2text.HTML2Text()
  35. # ignore links and images
  36. html_text_converter.ignore_links = False
  37. html_text_converter.ignore_images = True
  38. # use alt text for images
  39. html_text_converter.images_to_alt = True
  40. # disable auto text wrapping
  41. html_text_converter.body_width = 0
  42. return html_text_converter
  43. @tenacity.retry(
  44. wait=tenacity.wait_fixed(1),
  45. stop=tenacity.stop_after_attempt(5),
  46. retry=tenacity.retry_if_exception_type(BrowserInitException),
  47. )
  48. def init_browser(self):
  49. logger.info('Starting browser env...')
  50. try:
  51. self.process = multiprocessing.Process(target=self.browser_process)
  52. self.process.start()
  53. except Exception as e:
  54. logger.error(f'Failed to start browser process: {e}')
  55. raise
  56. if not self.check_alive():
  57. self.close()
  58. raise BrowserInitException('Failed to start browser environment.')
  59. def browser_process(self):
  60. if self.eval_mode:
  61. assert self.browsergym_eval_env is not None
  62. logger.info('Initializing browser env for web browsing evaluation.')
  63. if 'webarena' in self.browsergym_eval_env:
  64. import browsergym.webarena # noqa F401 register webarena tasks as gym environments
  65. elif 'miniwob' in self.browsergym_eval_env:
  66. import browsergym.miniwob # noqa F401 register miniwob tasks as gym environments
  67. else:
  68. raise ValueError(
  69. f'Unsupported browsergym eval env: {self.browsergym_eval_env}'
  70. )
  71. env = gym.make(self.browsergym_eval_env)
  72. else:
  73. env = gym.make(
  74. 'browsergym/openended',
  75. task_kwargs={'start_url': 'about:blank', 'goal': 'PLACEHOLDER_GOAL'},
  76. wait_for_user_message=False,
  77. headless=True,
  78. disable_env_checker=True,
  79. )
  80. obs, info = env.reset()
  81. # EVAL ONLY: save the goal into file for evaluation
  82. self.eval_goal = None
  83. self.eval_rewards: list[float] = []
  84. if self.eval_mode:
  85. logger.info(f"Browsing goal: {obs['goal']}")
  86. self.eval_goal = obs['goal']
  87. logger.info('Browser env started.')
  88. while should_continue():
  89. try:
  90. if self.browser_side.poll(timeout=0.01):
  91. unique_request_id, action_data = self.browser_side.recv()
  92. # shutdown the browser environment
  93. if unique_request_id == 'SHUTDOWN':
  94. logger.info('SHUTDOWN recv, shutting down browser env...')
  95. env.close()
  96. return
  97. elif unique_request_id == 'IS_ALIVE':
  98. self.browser_side.send(('ALIVE', None))
  99. continue
  100. # EVAL ONLY: Get evaluation info
  101. if action_data['action'] == BROWSER_EVAL_GET_GOAL_ACTION:
  102. self.browser_side.send(
  103. (unique_request_id, {'text_content': self.eval_goal})
  104. )
  105. continue
  106. elif action_data['action'] == BROWSER_EVAL_GET_REWARDS_ACTION:
  107. self.browser_side.send(
  108. (
  109. unique_request_id,
  110. {'text_content': json.dumps(self.eval_rewards)},
  111. )
  112. )
  113. continue
  114. action = action_data['action']
  115. obs, reward, terminated, truncated, info = env.step(action)
  116. # EVAL ONLY: Save the rewards into file for evaluation
  117. if self.eval_mode:
  118. self.eval_rewards.append(reward)
  119. # add text content of the page
  120. html_str = flatten_dom_to_str(obs['dom_object'])
  121. obs['text_content'] = self.html_text_converter.handle(html_str)
  122. # make observation serializable
  123. obs['screenshot'] = self.image_to_png_base64_url(obs['screenshot'])
  124. obs['active_page_index'] = obs['active_page_index'].item()
  125. obs['elapsed_time'] = obs['elapsed_time'].item()
  126. self.browser_side.send((unique_request_id, obs))
  127. except KeyboardInterrupt:
  128. logger.info('Browser env process interrupted by user.')
  129. try:
  130. env.close()
  131. except Exception:
  132. pass
  133. return
  134. def step(self, action_str: str, timeout: float = 30) -> dict:
  135. """Execute an action in the browser environment and return the observation."""
  136. unique_request_id = str(uuid.uuid4())
  137. self.agent_side.send((unique_request_id, {'action': action_str}))
  138. start_time = time.time()
  139. while True:
  140. if should_exit() or time.time() - start_time > timeout:
  141. raise TimeoutError('Browser environment took too long to respond.')
  142. if self.agent_side.poll(timeout=0.01):
  143. response_id, obs = self.agent_side.recv()
  144. if response_id == unique_request_id:
  145. return obs
  146. def check_alive(self, timeout: float = 60):
  147. self.agent_side.send(('IS_ALIVE', None))
  148. if self.agent_side.poll(timeout=timeout):
  149. response_id, _ = self.agent_side.recv()
  150. if response_id == 'ALIVE':
  151. return True
  152. logger.info(f'Browser env is not alive. Response ID: {response_id}')
  153. def close(self):
  154. if not self.process.is_alive():
  155. return
  156. try:
  157. self.agent_side.send(('SHUTDOWN', None))
  158. self.process.join(5) # Wait for the process to terminate
  159. if self.process.is_alive():
  160. logger.error(
  161. 'Browser process did not terminate, forcefully terminating...'
  162. )
  163. self.process.terminate()
  164. self.process.join(5) # Wait for the process to terminate
  165. if self.process.is_alive():
  166. self.process.kill()
  167. self.process.join(5) # Wait for the process to terminate
  168. self.agent_side.close()
  169. self.browser_side.close()
  170. except Exception:
  171. logger.error('Encountered an error when closing browser env', exc_info=True)
  172. @staticmethod
  173. def image_to_png_base64_url(
  174. image: np.ndarray | Image.Image, add_data_prefix: bool = False
  175. ):
  176. """Convert a numpy array to a base64 encoded png image url."""
  177. if isinstance(image, np.ndarray):
  178. image = Image.fromarray(image)
  179. if image.mode in ('RGBA', 'LA'):
  180. image = image.convert('RGB')
  181. buffered = io.BytesIO()
  182. image.save(buffered, format='PNG')
  183. image_base64 = base64.b64encode(buffered.getvalue()).decode()
  184. return (
  185. f'data:image/png;base64,{image_base64}'
  186. if add_data_prefix
  187. else f'{image_base64}'
  188. )
  189. @staticmethod
  190. def image_to_jpg_base64_url(
  191. image: np.ndarray | Image.Image, add_data_prefix: bool = False
  192. ):
  193. """Convert a numpy array to a base64 encoded jpeg image url."""
  194. if isinstance(image, np.ndarray):
  195. image = Image.fromarray(image)
  196. if image.mode in ('RGBA', 'LA'):
  197. image = image.convert('RGB')
  198. buffered = io.BytesIO()
  199. image.save(buffered, format='JPEG')
  200. image_base64 = base64.b64encode(buffered.getvalue()).decode()
  201. return (
  202. f'data:image/jpeg;base64,{image_base64}'
  203. if add_data_prefix
  204. else f'{image_base64}'
  205. )