browser_env.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. import atexit
  2. import base64
  3. import io
  4. import json
  5. import multiprocessing
  6. import os
  7. import threading
  8. import time
  9. import uuid
  10. import browsergym.core # noqa F401 (we register the openended task as a gym environment)
  11. import gymnasium as gym
  12. import html2text
  13. import numpy as np
  14. from browsergym.utils.obs import flatten_dom_to_str
  15. from PIL import Image
  16. from opendevin.core.exceptions import BrowserInitException
  17. from opendevin.core.logger import opendevin_logger as logger
  18. class BrowserEnv:
  19. def __init__(
  20. self,
  21. is_async: bool = True,
  22. browsergym_eval: str = '',
  23. browsergym_eval_save_dir: str = '',
  24. ):
  25. self.html_text_converter = self.get_html_text_converter()
  26. self.eval_mode = False
  27. self.eval_dir = ''
  28. # EVAL only: browsergym_eval and browsergym_eval_save_dir must be provided for evaluation
  29. self.browsergym_eval = browsergym_eval
  30. self.browsergym_eval_save_dir = browsergym_eval_save_dir
  31. if self.browsergym_eval:
  32. assert (
  33. self.browsergym_eval_save_dir
  34. ), 'browsergym_eval_save_dir must be provided for evaluation.'
  35. self.eval_mode = True
  36. self.eval_dir = os.path.join(
  37. self.browsergym_eval_save_dir, self.browsergym_eval.split('/')[1]
  38. )
  39. os.makedirs(self.eval_dir, exist_ok=True)
  40. # Initialize browser environment process
  41. multiprocessing.set_start_method('spawn', force=True)
  42. self.browser_side, self.agent_side = multiprocessing.Pipe()
  43. self.process = multiprocessing.Process(
  44. target=self.browser_process,
  45. )
  46. try:
  47. self.original_cwd = os.getcwd()
  48. except FileNotFoundError:
  49. logger.warning(
  50. 'Current working directory does not exist. Using /tmp as fallback.'
  51. )
  52. self.original_cwd = '/tmp'
  53. os.chdir('/tmp')
  54. if is_async:
  55. threading.Thread(target=self.init_browser).start()
  56. else:
  57. self.init_browser()
  58. atexit.register(self.close)
  59. def get_html_text_converter(self):
  60. html_text_converter = html2text.HTML2Text()
  61. # ignore links and images
  62. html_text_converter.ignore_links = False
  63. html_text_converter.ignore_images = True
  64. # use alt text for images
  65. html_text_converter.images_to_alt = True
  66. # disable auto text wrapping
  67. html_text_converter.body_width = 0
  68. return html_text_converter
  69. def init_browser(self):
  70. logger.info('Starting browser env...')
  71. # Ensure we're in a valid directory before starting the process
  72. try:
  73. os.chdir(self.original_cwd)
  74. logger.debug(f'Changed back to original directory: {self.original_cwd}')
  75. except Exception as e:
  76. logger.error(f'Failed to change to original directory: {e}')
  77. # If we can't change to the original directory, try to use a known valid directory
  78. os.chdir('/tmp')
  79. logger.debug('Changed to /tmp directory as fallback')
  80. try:
  81. self.process.start()
  82. except Exception as e:
  83. logger.error(f'Failed to start browser process: {e}')
  84. raise
  85. if not self.check_alive():
  86. self.close()
  87. raise BrowserInitException('Failed to start browser environment.')
  88. def browser_process(self):
  89. if self.eval_mode:
  90. logger.info('Creating browser env for evaluation purpose.')
  91. env = gym.make(self.browsergym_eval)
  92. else:
  93. env = gym.make(
  94. 'browsergym/openended',
  95. task_kwargs={'start_url': 'about:blank', 'goal': 'PLACEHOLDER_GOAL'},
  96. wait_for_user_message=False,
  97. headless=True,
  98. disable_env_checker=True,
  99. )
  100. obs, info = env.reset()
  101. # EVAL only: save the goal into file for evaluation
  102. if self.eval_mode:
  103. rewards = [] # store rewards if in eval mode
  104. logger.info(obs['goal'])
  105. with open(
  106. os.path.join(self.eval_dir, 'goal.txt'), 'w', encoding='utf-8'
  107. ) as f:
  108. f.write(obs['goal'])
  109. logger.info('Browser env started.')
  110. while True:
  111. try:
  112. if self.browser_side.poll(timeout=0.01):
  113. unique_request_id, action_data = self.browser_side.recv()
  114. # shutdown the browser environment
  115. if unique_request_id == 'SHUTDOWN':
  116. logger.info('SHUTDOWN recv, shutting down browser env...')
  117. env.close()
  118. return
  119. elif unique_request_id == 'IS_ALIVE':
  120. self.browser_side.send(('ALIVE', None))
  121. continue
  122. action = action_data['action']
  123. obs, reward, terminated, truncated, info = env.step(action)
  124. # EVAL only: save the rewards into file for evaluation
  125. if self.eval_mode:
  126. rewards.append(reward)
  127. with open(
  128. os.path.join(self.eval_dir, 'rewards.json'),
  129. 'w',
  130. encoding='utf-8',
  131. ) as f:
  132. f.write(json.dumps(rewards))
  133. # add text content of the page
  134. html_str = flatten_dom_to_str(obs['dom_object'])
  135. obs['text_content'] = self.html_text_converter.handle(html_str)
  136. # make observation serializable
  137. obs['screenshot'] = self.image_to_png_base64_url(obs['screenshot'])
  138. obs['active_page_index'] = obs['active_page_index'].item()
  139. obs['elapsed_time'] = obs['elapsed_time'].item()
  140. self.browser_side.send((unique_request_id, obs))
  141. except KeyboardInterrupt:
  142. logger.info('Browser env process interrupted by user.')
  143. try:
  144. env.close()
  145. except Exception:
  146. pass
  147. return
  148. def step(self, action_str: str, timeout: float = 30) -> dict:
  149. unique_request_id = str(uuid.uuid4())
  150. self.agent_side.send((unique_request_id, {'action': action_str}))
  151. start_time = time.time()
  152. while True:
  153. if time.time() - start_time > timeout:
  154. raise TimeoutError('Browser environment took too long to respond.')
  155. if self.agent_side.poll(timeout=0.01):
  156. response_id, obs = self.agent_side.recv()
  157. if response_id == unique_request_id:
  158. return obs
  159. def check_alive(self, timeout: float = 60):
  160. self.agent_side.send(('IS_ALIVE', None))
  161. if self.agent_side.poll(timeout=timeout):
  162. response_id, _ = self.agent_side.recv()
  163. if response_id == 'ALIVE':
  164. return True
  165. logger.info(f'Browser env is not alive. Response ID: {response_id}')
  166. def close(self):
  167. if not self.process.is_alive():
  168. return
  169. try:
  170. self.agent_side.send(('SHUTDOWN', None))
  171. self.process.join(5) # Wait for the process to terminate
  172. if self.process.is_alive():
  173. logger.error(
  174. 'Browser process did not terminate, forcefully terminating...'
  175. )
  176. self.process.terminate()
  177. self.process.join(5) # Wait for the process to terminate
  178. if self.process.is_alive():
  179. self.process.kill()
  180. self.process.join(5) # Wait for the process to terminate
  181. self.agent_side.close()
  182. self.browser_side.close()
  183. except Exception:
  184. logger.error('Encountered an error when closing browser env', exc_info=True)
  185. @staticmethod
  186. def image_to_png_base64_url(
  187. image: np.ndarray | Image.Image, add_data_prefix: bool = False
  188. ):
  189. """Convert a numpy array to a base64 encoded png image url."""
  190. if isinstance(image, np.ndarray):
  191. image = Image.fromarray(image)
  192. if image.mode in ('RGBA', 'LA'):
  193. image = image.convert('RGB')
  194. buffered = io.BytesIO()
  195. image.save(buffered, format='PNG')
  196. image_base64 = base64.b64encode(buffered.getvalue()).decode()
  197. return (
  198. f'data:image/png;base64,{image_base64}'
  199. if add_data_prefix
  200. else f'{image_base64}'
  201. )
  202. @staticmethod
  203. def image_to_jpg_base64_url(
  204. image: np.ndarray | Image.Image, add_data_prefix: bool = False
  205. ):
  206. """Convert a numpy array to a base64 encoded jpeg image url."""
  207. if isinstance(image, np.ndarray):
  208. image = Image.fromarray(image)
  209. if image.mode in ('RGBA', 'LA'):
  210. image = image.convert('RGB')
  211. buffered = io.BytesIO()
  212. image.save(buffered, format='JPEG')
  213. image_base64 = base64.b64encode(buffered.getvalue()).decode()
  214. return (
  215. f'data:image/jpeg;base64,{image_base64}'
  216. if add_data_prefix
  217. else f'{image_base64}'
  218. )