browser_env.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. import atexit
  2. import base64
  3. import io
  4. import json
  5. import multiprocessing
  6. import time
  7. import uuid
  8. import browsergym.core # noqa F401 (we register the openended task as a gym environment)
  9. import gymnasium as gym
  10. import html2text
  11. import numpy as np
  12. import tenacity
  13. from browsergym.utils.obs import flatten_dom_to_str
  14. from PIL import Image
  15. from opendevin.core.exceptions import BrowserInitException
  16. from opendevin.core.logger import opendevin_logger as logger
  17. BROWSER_EVAL_GET_GOAL_ACTION = 'GET_EVAL_GOAL'
  18. BROWSER_EVAL_GET_REWARDS_ACTION = 'GET_EVAL_REWARDS'
  19. class BrowserEnv:
  20. def __init__(self, browsergym_eval_env: str | None = None):
  21. self.html_text_converter = self.get_html_text_converter()
  22. self.eval_mode = False
  23. self.eval_dir = ''
  24. # EVAL only: browsergym_eval_env must be provided for evaluation
  25. self.browsergym_eval_env = browsergym_eval_env
  26. self.eval_mode = bool(browsergym_eval_env)
  27. # Initialize browser environment process
  28. multiprocessing.set_start_method('spawn', force=True)
  29. self.browser_side, self.agent_side = multiprocessing.Pipe()
  30. self.init_browser()
  31. atexit.register(self.close)
  32. def get_html_text_converter(self):
  33. html_text_converter = html2text.HTML2Text()
  34. # ignore links and images
  35. html_text_converter.ignore_links = False
  36. html_text_converter.ignore_images = True
  37. # use alt text for images
  38. html_text_converter.images_to_alt = True
  39. # disable auto text wrapping
  40. html_text_converter.body_width = 0
  41. return html_text_converter
  42. @tenacity.retry(
  43. wait=tenacity.wait_fixed(1),
  44. stop=tenacity.stop_after_attempt(5),
  45. retry=tenacity.retry_if_exception_type(BrowserInitException),
  46. )
  47. def init_browser(self):
  48. logger.info('Starting browser env...')
  49. try:
  50. self.process = multiprocessing.Process(target=self.browser_process)
  51. self.process.start()
  52. except Exception as e:
  53. logger.error(f'Failed to start browser process: {e}')
  54. raise
  55. if not self.check_alive():
  56. self.close()
  57. raise BrowserInitException('Failed to start browser environment.')
  58. def browser_process(self):
  59. if self.eval_mode:
  60. assert self.browsergym_eval_env is not None
  61. logger.info('Initializing browser env for web browsing evaluation.')
  62. if 'webarena' in self.browsergym_eval_env:
  63. import browsergym.webarena # noqa F401 register webarena tasks as gym environments
  64. elif 'miniwob' in self.browsergym_eval_env:
  65. import browsergym.miniwob # noqa F401 register miniwob tasks as gym environments
  66. else:
  67. raise ValueError(
  68. f'Unsupported browsergym eval env: {self.browsergym_eval_env}'
  69. )
  70. env = gym.make(self.browsergym_eval_env)
  71. else:
  72. env = gym.make(
  73. 'browsergym/openended',
  74. task_kwargs={'start_url': 'about:blank', 'goal': 'PLACEHOLDER_GOAL'},
  75. wait_for_user_message=False,
  76. headless=True,
  77. disable_env_checker=True,
  78. )
  79. obs, info = env.reset()
  80. # EVAL ONLY: save the goal into file for evaluation
  81. self.eval_goal = None
  82. self.eval_rewards: list[float] = []
  83. if self.eval_mode:
  84. logger.info(f"Browsing goal: {obs['goal']}")
  85. self.eval_goal = obs['goal']
  86. logger.info('Browser env started.')
  87. while True:
  88. try:
  89. if self.browser_side.poll(timeout=0.01):
  90. unique_request_id, action_data = self.browser_side.recv()
  91. # shutdown the browser environment
  92. if unique_request_id == 'SHUTDOWN':
  93. logger.info('SHUTDOWN recv, shutting down browser env...')
  94. env.close()
  95. return
  96. elif unique_request_id == 'IS_ALIVE':
  97. self.browser_side.send(('ALIVE', None))
  98. continue
  99. # EVAL ONLY: Get evaluation info
  100. if action_data['action'] == BROWSER_EVAL_GET_GOAL_ACTION:
  101. self.browser_side.send(
  102. (unique_request_id, {'text_content': self.eval_goal})
  103. )
  104. continue
  105. elif action_data['action'] == BROWSER_EVAL_GET_REWARDS_ACTION:
  106. self.browser_side.send(
  107. (
  108. unique_request_id,
  109. {'text_content': json.dumps(self.eval_rewards)},
  110. )
  111. )
  112. continue
  113. action = action_data['action']
  114. obs, reward, terminated, truncated, info = env.step(action)
  115. # EVAL ONLY: Save the rewards into file for evaluation
  116. if self.eval_mode:
  117. self.eval_rewards.append(reward)
  118. # add text content of the page
  119. html_str = flatten_dom_to_str(obs['dom_object'])
  120. obs['text_content'] = self.html_text_converter.handle(html_str)
  121. # make observation serializable
  122. obs['screenshot'] = self.image_to_png_base64_url(obs['screenshot'])
  123. obs['active_page_index'] = obs['active_page_index'].item()
  124. obs['elapsed_time'] = obs['elapsed_time'].item()
  125. self.browser_side.send((unique_request_id, obs))
  126. except KeyboardInterrupt:
  127. logger.info('Browser env process interrupted by user.')
  128. try:
  129. env.close()
  130. except Exception:
  131. pass
  132. return
  133. def step(self, action_str: str, timeout: float = 30) -> dict:
  134. """Execute an action in the browser environment and return the observation."""
  135. unique_request_id = str(uuid.uuid4())
  136. self.agent_side.send((unique_request_id, {'action': action_str}))
  137. start_time = time.time()
  138. while True:
  139. if time.time() - start_time > timeout:
  140. raise TimeoutError('Browser environment took too long to respond.')
  141. if self.agent_side.poll(timeout=0.01):
  142. response_id, obs = self.agent_side.recv()
  143. if response_id == unique_request_id:
  144. return obs
  145. def check_alive(self, timeout: float = 60):
  146. self.agent_side.send(('IS_ALIVE', None))
  147. if self.agent_side.poll(timeout=timeout):
  148. response_id, _ = self.agent_side.recv()
  149. if response_id == 'ALIVE':
  150. return True
  151. logger.info(f'Browser env is not alive. Response ID: {response_id}')
  152. def close(self):
  153. if not self.process.is_alive():
  154. return
  155. try:
  156. self.agent_side.send(('SHUTDOWN', None))
  157. self.process.join(5) # Wait for the process to terminate
  158. if self.process.is_alive():
  159. logger.error(
  160. 'Browser process did not terminate, forcefully terminating...'
  161. )
  162. self.process.terminate()
  163. self.process.join(5) # Wait for the process to terminate
  164. if self.process.is_alive():
  165. self.process.kill()
  166. self.process.join(5) # Wait for the process to terminate
  167. self.agent_side.close()
  168. self.browser_side.close()
  169. except Exception:
  170. logger.error('Encountered an error when closing browser env', exc_info=True)
  171. @staticmethod
  172. def image_to_png_base64_url(
  173. image: np.ndarray | Image.Image, add_data_prefix: bool = False
  174. ):
  175. """Convert a numpy array to a base64 encoded png image url."""
  176. if isinstance(image, np.ndarray):
  177. image = Image.fromarray(image)
  178. if image.mode in ('RGBA', 'LA'):
  179. image = image.convert('RGB')
  180. buffered = io.BytesIO()
  181. image.save(buffered, format='PNG')
  182. image_base64 = base64.b64encode(buffered.getvalue()).decode()
  183. return (
  184. f'data:image/png;base64,{image_base64}'
  185. if add_data_prefix
  186. else f'{image_base64}'
  187. )
  188. @staticmethod
  189. def image_to_jpg_base64_url(
  190. image: np.ndarray | Image.Image, add_data_prefix: bool = False
  191. ):
  192. """Convert a numpy array to a base64 encoded jpeg image url."""
  193. if isinstance(image, np.ndarray):
  194. image = Image.fromarray(image)
  195. if image.mode in ('RGBA', 'LA'):
  196. image = image.convert('RGB')
  197. buffered = io.BytesIO()
  198. image.save(buffered, format='JPEG')
  199. image_base64 = base64.b64encode(buffered.getvalue()).decode()
  200. return (
  201. f'data:image/jpeg;base64,{image_base64}'
  202. if add_data_prefix
  203. else f'{image_base64}'
  204. )