browser_env.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. import atexit
  2. import base64
  3. import io
  4. import multiprocessing
  5. import time
  6. import uuid
  7. import browsergym.core # noqa F401 (we register the openended task as a gym environment)
  8. import gymnasium as gym
  9. import html2text
  10. import numpy as np
  11. from browsergym.utils.obs import flatten_dom_to_str
  12. from PIL import Image
  13. from opendevin.core.logger import opendevin_logger as logger
  14. class BrowserException(Exception):
  15. pass
  16. class BrowserEnv:
  17. def __init__(self):
  18. self.html_text_converter = html2text.HTML2Text()
  19. # ignore links and images
  20. self.html_text_converter.ignore_links = True
  21. self.html_text_converter.ignore_images = True
  22. # use alt text for images
  23. self.html_text_converter.images_to_alt = True
  24. # disable auto text wrapping
  25. self.html_text_converter.body_width = 0
  26. # Initialize browser environment process
  27. multiprocessing.set_start_method('spawn', force=True)
  28. self.browser_side, self.agent_side = multiprocessing.Pipe()
  29. self.process = multiprocessing.Process(
  30. target=self.browser_process,
  31. )
  32. logger.info('Starting browser env...')
  33. self.process.start()
  34. atexit.register(self.close)
  35. def browser_process(self):
  36. env = gym.make(
  37. 'browsergym/openended',
  38. start_url='about:blank',
  39. wait_for_user_message=False,
  40. headless=True,
  41. disable_env_checker=True,
  42. )
  43. obs, info = env.reset()
  44. logger.info('Browser env started.')
  45. while True:
  46. try:
  47. if self.browser_side.poll(timeout=0.01):
  48. unique_request_id, action_data = self.browser_side.recv()
  49. # shutdown the browser environment
  50. if unique_request_id == 'SHUTDOWN':
  51. logger.info('SHUTDOWN recv, shutting down browser env...')
  52. env.close()
  53. return
  54. action = action_data['action']
  55. obs, reward, terminated, truncated, info = env.step(action)
  56. # add text content of the page
  57. html_str = flatten_dom_to_str(obs['dom_object'])
  58. obs['text_content'] = self.html_text_converter.handle(html_str)
  59. # make observation serializable
  60. obs['screenshot'] = self.image_to_png_base64_url(obs['screenshot'])
  61. obs['active_page_index'] = obs['active_page_index'].item()
  62. obs['elapsed_time'] = obs['elapsed_time'].item()
  63. self.browser_side.send((unique_request_id, obs))
  64. except KeyboardInterrupt:
  65. logger.info('Browser env process interrupted by user.')
  66. try:
  67. env.close()
  68. except Exception:
  69. pass
  70. return
  71. def step(self, action_str: str, timeout: float = 10) -> dict:
  72. unique_request_id = str(uuid.uuid4())
  73. self.agent_side.send((unique_request_id, {'action': action_str}))
  74. start_time = time.time()
  75. while True:
  76. if time.time() - start_time > timeout:
  77. raise TimeoutError('Browser environment took too long to respond.')
  78. if self.agent_side.poll(timeout=0.01):
  79. response_id, obs = self.agent_side.recv()
  80. if response_id == unique_request_id:
  81. if obs['last_action_error']:
  82. raise BrowserException(obs['last_action_error'])
  83. return obs
  84. def close(self):
  85. try:
  86. self.agent_side.send(('SHUTDOWN', None))
  87. self.process.join()
  88. self.agent_side.close()
  89. self.browser_side.close()
  90. except Exception:
  91. pass
  92. @staticmethod
  93. def image_to_png_base64_url(image: np.ndarray | Image.Image):
  94. """Convert a numpy array to a base64 encoded png image url."""
  95. if isinstance(image, np.ndarray):
  96. image = Image.fromarray(image)
  97. if image.mode in ('RGBA', 'LA'):
  98. image = image.convert('RGB')
  99. buffered = io.BytesIO()
  100. image.save(buffered, format='PNG')
  101. image_base64 = base64.b64encode(buffered.getvalue()).decode()
  102. return f'{image_base64}'