# browser_env.py
  1. import atexit
  2. import base64
  3. import io
  4. import json
  5. import multiprocessing
  6. import os
  7. import threading
  8. import time
  9. import uuid
  10. import browsergym.core # noqa F401 (we register the openended task as a gym environment)
  11. import gymnasium as gym
  12. import html2text
  13. import numpy as np
  14. from browsergym.utils.obs import flatten_dom_to_str
  15. from PIL import Image
  16. from opendevin.core.exceptions import BrowserInitException
  17. from opendevin.core.logger import opendevin_logger as logger
  18. class BrowserEnv:
  19. def __init__(
  20. self,
  21. is_async: bool = True,
  22. browsergym_eval: str = '',
  23. browsergym_eval_save_dir: str = '',
  24. ):
  25. self.html_text_converter = self.get_html_text_converter()
  26. self.eval_mode = False
  27. self.eval_dir = ''
  28. # EVAL only: browsergym_eval and browsergym_eval_save_dir must be provided for evaluation
  29. self.browsergym_eval = browsergym_eval
  30. self.browsergym_eval_save_dir = browsergym_eval_save_dir
  31. if self.browsergym_eval:
  32. assert (
  33. self.browsergym_eval_save_dir
  34. ), 'browsergym_eval_save_dir must be provided for evaluation.'
  35. self.eval_mode = True
  36. self.eval_dir = os.path.join(
  37. self.browsergym_eval_save_dir, self.browsergym_eval.split('/')[1]
  38. )
  39. os.makedirs(self.eval_dir, exist_ok=True)
  40. # Initialize browser environment process
  41. multiprocessing.set_start_method('spawn', force=True)
  42. self.browser_side, self.agent_side = multiprocessing.Pipe()
  43. self.process = multiprocessing.Process(
  44. target=self.browser_process,
  45. )
  46. if is_async:
  47. threading.Thread(target=self.init_browser).start()
  48. else:
  49. self.init_browser()
  50. atexit.register(self.close)
  51. def get_html_text_converter(self):
  52. html_text_converter = html2text.HTML2Text()
  53. # ignore links and images
  54. html_text_converter.ignore_links = False
  55. html_text_converter.ignore_images = True
  56. # use alt text for images
  57. html_text_converter.images_to_alt = True
  58. # disable auto text wrapping
  59. html_text_converter.body_width = 0
  60. return html_text_converter
  61. def init_browser(self):
  62. logger.info('Starting browser env...')
  63. self.process.start()
  64. if not self.check_alive():
  65. self.close()
  66. raise BrowserInitException('Failed to start browser environment.')
  67. def browser_process(self):
  68. if self.eval_mode:
  69. logger.info('Creating browser env for evaluation purpose.')
  70. env = gym.make(self.browsergym_eval)
  71. else:
  72. env = gym.make(
  73. 'browsergym/openended',
  74. task_kwargs={'start_url': 'about:blank', 'goal': 'PLACEHOLDER_GOAL'},
  75. wait_for_user_message=False,
  76. headless=True,
  77. disable_env_checker=True,
  78. )
  79. obs, info = env.reset()
  80. # EVAL only: save the goal into file for evaluation
  81. if self.eval_mode:
  82. rewards = [] # store rewards if in eval mode
  83. logger.info(obs['goal'])
  84. with open(
  85. os.path.join(self.eval_dir, 'goal.txt'), 'w', encoding='utf-8'
  86. ) as f:
  87. f.write(obs['goal'])
  88. logger.info('Browser env started.')
  89. while True:
  90. try:
  91. if self.browser_side.poll(timeout=0.01):
  92. unique_request_id, action_data = self.browser_side.recv()
  93. # shutdown the browser environment
  94. if unique_request_id == 'SHUTDOWN':
  95. logger.info('SHUTDOWN recv, shutting down browser env...')
  96. env.close()
  97. return
  98. elif unique_request_id == 'IS_ALIVE':
  99. self.browser_side.send(('ALIVE', None))
  100. continue
  101. action = action_data['action']
  102. obs, reward, terminated, truncated, info = env.step(action)
  103. # EVAL only: save the rewards into file for evaluation
  104. if self.eval_mode:
  105. rewards.append(reward)
  106. with open(
  107. os.path.join(self.eval_dir, 'rewards.json'),
  108. 'w',
  109. encoding='utf-8',
  110. ) as f:
  111. f.write(json.dumps(rewards))
  112. # add text content of the page
  113. html_str = flatten_dom_to_str(obs['dom_object'])
  114. obs['text_content'] = self.html_text_converter.handle(html_str)
  115. # make observation serializable
  116. obs['screenshot'] = self.image_to_png_base64_url(obs['screenshot'])
  117. obs['active_page_index'] = obs['active_page_index'].item()
  118. obs['elapsed_time'] = obs['elapsed_time'].item()
  119. self.browser_side.send((unique_request_id, obs))
  120. except KeyboardInterrupt:
  121. logger.info('Browser env process interrupted by user.')
  122. try:
  123. env.close()
  124. except Exception:
  125. pass
  126. return
  127. def step(self, action_str: str, timeout: float = 30) -> dict:
  128. unique_request_id = str(uuid.uuid4())
  129. self.agent_side.send((unique_request_id, {'action': action_str}))
  130. start_time = time.time()
  131. while True:
  132. if time.time() - start_time > timeout:
  133. raise TimeoutError('Browser environment took too long to respond.')
  134. if self.agent_side.poll(timeout=0.01):
  135. response_id, obs = self.agent_side.recv()
  136. if response_id == unique_request_id:
  137. return obs
  138. def check_alive(self, timeout: float = 60):
  139. self.agent_side.send(('IS_ALIVE', None))
  140. if self.agent_side.poll(timeout=timeout):
  141. response_id, _ = self.agent_side.recv()
  142. if response_id == 'ALIVE':
  143. return True
  144. logger.info(f'Browser env is not alive. Response ID: {response_id}')
  145. def close(self):
  146. if not self.process.is_alive():
  147. return
  148. try:
  149. self.agent_side.send(('SHUTDOWN', None))
  150. self.process.join(5) # Wait for the process to terminate
  151. if self.process.is_alive():
  152. logger.error(
  153. 'Browser process did not terminate, forcefully terminating...'
  154. )
  155. self.process.terminate()
  156. self.process.join(5) # Wait for the process to terminate
  157. if self.process.is_alive():
  158. self.process.kill()
  159. self.process.join(5) # Wait for the process to terminate
  160. self.agent_side.close()
  161. self.browser_side.close()
  162. except Exception:
  163. logger.error('Encountered an error when closing browser env', exc_info=True)
  164. @staticmethod
  165. def image_to_png_base64_url(
  166. image: np.ndarray | Image.Image, add_data_prefix: bool = False
  167. ):
  168. """Convert a numpy array to a base64 encoded png image url."""
  169. if isinstance(image, np.ndarray):
  170. image = Image.fromarray(image)
  171. if image.mode in ('RGBA', 'LA'):
  172. image = image.convert('RGB')
  173. buffered = io.BytesIO()
  174. image.save(buffered, format='PNG')
  175. image_base64 = base64.b64encode(buffered.getvalue()).decode()
  176. return (
  177. f'data:image/png;base64,{image_base64}'
  178. if add_data_prefix
  179. else f'{image_base64}'
  180. )
  181. @staticmethod
  182. def image_to_jpg_base64_url(
  183. image: np.ndarray | Image.Image, add_data_prefix: bool = False
  184. ):
  185. """Convert a numpy array to a base64 encoded jpeg image url."""
  186. if isinstance(image, np.ndarray):
  187. image = Image.fromarray(image)
  188. if image.mode in ('RGBA', 'LA'):
  189. image = image.convert('RGB')
  190. buffered = io.BytesIO()
  191. image.save(buffered, format='JPEG')
  192. image_base64 = base64.b64encode(buffered.getvalue()).decode()
  193. return (
  194. f'data:image/jpeg;base64,{image_base64}'
  195. if add_data_prefix
  196. else f'{image_base64}'
  197. )