"""Prompt elements used to build the browsing agent's LLM prompt.

A prompt is assembled from `PromptElement` objects whose visibility is driven
by `Flags`. Elements that can be reduced in size implement `Shrinkable` so that
`fit_tokens` can progressively shrink the prompt until it fits a size budget.
"""

import abc
import difflib
import logging
import platform
from copy import deepcopy
from dataclasses import asdict, dataclass
from textwrap import dedent
from typing import Callable, Literal, Union
from warnings import warn

from browsergym.core.action.base import AbstractActionSet
from browsergym.core.action.highlevel import HighLevelActionSet
from browsergym.core.action.python import PythonActionSet

from agenthub.browsing_agent.utils import (
    ParseError,
    parse_html_tags_raise,
)
from openhands.runtime.browser.browser_env import BrowserEnv


@dataclass
class Flags:
    """Switches controlling which prompt elements are built and shown."""

    use_html: bool = True
    use_ax_tree: bool = False
    drop_ax_tree_first: bool = True  # This flag is no longer active TODO delete
    use_thinking: bool = False
    use_error_logs: bool = False
    use_past_error_logs: bool = False
    use_history: bool = False
    use_action_history: bool = False
    use_memory: bool = False
    use_diff: bool = False
    html_type: str = 'pruned_html'
    use_concrete_example: bool = True
    use_abstract_example: bool = False
    multi_actions: bool = False
    action_space: Literal[
        'python', 'bid', 'coord', 'bid+coord', 'bid+nav', 'coord+nav', 'bid+coord+nav'
    ] = 'bid'
    is_strict: bool = False

    # This flag will be automatically disabled `if not chat_model_args.has_vision()`
    use_screenshot: bool = True
    enable_chat: bool = False
    max_prompt_tokens: int = 100_000
    extract_visible_tag: bool = False
    extract_coords: Literal['False', 'center', 'box'] = 'False'
    extract_visible_elements_only: bool = False
    demo_mode: Literal['off', 'default', 'only_visible_elements'] = 'off'

    def copy(self):
        """Return a deep copy of these flags."""
        return deepcopy(self)

    def asdict(self):
        """Helper for JSON serializable requirement."""
        # Inside a method the bare name resolves to the module-level
        # `dataclasses.asdict`, not to this method.
        return asdict(self)

    @classmethod
    def from_dict(cls, flags_dict):
        """Helper for JSON serializable requirement.

        Parameters
        ----------
        flags_dict : dict or Flags
            Either keyword arguments for the constructor, or an existing
            `Flags` instance, which is returned unchanged.

        Raises
        ------
        ValueError
            If `flags_dict` is neither a `Flags` instance nor a dict.
        """
        if isinstance(flags_dict, Flags):
            return flags_dict

        if not isinstance(flags_dict, dict):
            raise ValueError(
                f'Unregcognized type for flags_dict of type {type(flags_dict)}.'
            )
        # Build through `cls` so subclasses get an instance of their own type.
        return cls(**flags_dict)


class PromptElement:
    """Base class for all prompt elements.

    Prompt elements can be hidden. Prompt elements are used to build the
    prompt. Use flags to control which prompt elements are visible. We use
    class attributes as a convenient way to implement static prompts, but feel
    free to override them with instance attributes or @property decorator.
    """

    _prompt = ''
    _abstract_ex = ''
    _concrete_ex = ''

    def __init__(self, visible: Union[bool, Callable[[], bool]] = True) -> None:
        """Prompt element that can be hidden.

        Parameters
        ----------
        visible : bool or callable, optional
            Whether the prompt element should be visible, by default True.
            Can be a callable that returns a bool. This is useful when a
            specific flag changes during a shrink iteration.
        """
        self._visible = visible

    @property
    def prompt(self):
        """Avoid overriding this method. Override _prompt instead."""
        return self._hide(self._prompt)

    @property
    def abstract_ex(self):
        """Useful when this prompt element is requesting an answer from the llm.

        Provide an abstract example of the answer here. See Memory for an
        example.

        Avoid overriding this method. Override _abstract_ex instead
        """
        return self._hide(self._abstract_ex)

    @property
    def concrete_ex(self):
        """Useful when this prompt element is requesting an answer from the llm.

        Provide a concrete example of the answer here. See Memory for an
        example.

        Avoid overriding this method. Override _concrete_ex instead
        """
        return self._hide(self._concrete_ex)

    @property
    def is_visible(self):
        """Handle the case where visible is a callable."""
        visible = self._visible
        if callable(visible):
            visible = visible()
        return visible

    def _hide(self, value):
        """Return value if visible is True, else return empty string."""
        return value if self.is_visible else ''

    def _parse_answer(self, text_answer) -> dict:
        """Extract this element's part of the LLM answer.

        The base element requests nothing from the LLM, so there is nothing to
        parse. Subclasses that request an answer override this method. The
        previous implementation called itself whenever the element was
        visible, which recursed without end.
        """
        return {}


class Shrinkable(PromptElement, abc.ABC):
    """A prompt element whose size can be reduced on demand."""

    @abc.abstractmethod
    def shrink(self) -> None:
        """Implement shrinking of this prompt element.

        You need to recursively call all shrinkable elements that are part of
        this prompt. You can also implement a shrinking strategy for this
        prompt. Shrinking can be called multiple times to progressively shrink
        the prompt until it fits max_tokens. Default max shrink iterations is
        20.
        """
        pass


class Truncater(Shrinkable):
    """A prompt element that can be truncated to fit the context length of the LLM.

    Of course, it will be great that we never have to use the functionality
    here to `shrink()` the prompt. Extend this class for prompt elements that
    can be truncated. Usually long observations such as AxTree or HTML.
    """

    def __init__(self, visible, shrink_speed=0.3, start_truncate_iteration=10):
        super().__init__(visible=visible)
        self.shrink_speed = shrink_speed  # the percentage shrunk in each iteration
        self.start_truncate_iteration = (
            start_truncate_iteration  # the iteration to start truncating
        )
        self.shrink_calls = 0
        self.deleted_lines = 0
        # True once `_prompt` ends with the "... Deleted N lines" notice.
        self._has_truncation_note = False

    def shrink(self) -> None:
        """Drop the trailing `shrink_speed` fraction of the prompt's lines.

        Nothing is removed until `shrink` has been called
        `start_truncate_iteration` times; every call is counted regardless.
        """
        if self.is_visible and self.shrink_calls >= self.start_truncate_iteration:
            lines = self._prompt.splitlines()
            # The notice appended by an earlier call is not content: remove it
            # first so it is neither counted nor reported as a deleted line.
            if self._has_truncation_note and lines:
                lines.pop()
            if lines:
                new_line_count = int(len(lines) * (1 - self.shrink_speed))
                self.deleted_lines += len(lines) - new_line_count
                self._prompt = '\n'.join(lines[:new_line_count])
                self._prompt += (
                    f'\n... Deleted {self.deleted_lines} lines to reduce prompt size.'
                )
                self._has_truncation_note = True

        self.shrink_calls += 1


def _prompt_char_count(prompt) -> int:
    """Return the number of text characters in a str or chat-style list prompt."""
    if isinstance(prompt, str):
        return len(prompt)
    if isinstance(prompt, list):
        # Only text parts count; image parts are ignored.
        return len('\n'.join(p['text'] for p in prompt if p['type'] == 'text'))
    raise ValueError(f'Unrecognized type for prompt: {type(prompt)}')


def fit_tokens(
    shrinkable: Shrinkable,
    max_prompt_chars=None,
    max_iterations=20,
):
    """Shrink a prompt element until it fits `max_prompt_chars`.

    Parameters
    ----------
    shrinkable : Shrinkable
        The prompt element to shrink.
    max_prompt_chars : int, optional
        The maximum number of chars allowed. If None, the prompt is returned
        without any shrinking.
    max_iterations : int, optional
        The maximum number of shrink iterations, by default 20.

    Returns:
    -------
    str or list : the prompt after shrinking.

    Raises
    ------
    ValueError
        If the prompt is neither a str nor a list.
    """
    if max_prompt_chars is None:
        return shrinkable.prompt

    for _ in range(max_iterations):
        prompt = shrinkable.prompt
        if _prompt_char_count(prompt) <= max_prompt_chars:
            return prompt
        shrinkable.shrink()

    # Re-read the prompt so the result reflects the last `shrink()` call; this
    # also covers `max_iterations == 0`, where the loop body never runs.
    prompt = shrinkable.prompt
    n_chars = _prompt_char_count(prompt)
    if n_chars > max_prompt_chars:
        logging.info(
            dedent(
                f"""\
                After {max_iterations} shrink iterations, the prompt is still
                {n_chars} chars (greater than {max_prompt_chars}). Returning the prompt as is."""
            )
        )
    return prompt


class HTML(Truncater):
    """Truncatable prompt section holding the page HTML."""

    def __init__(self, html, visible: bool = True, prefix='') -> None:
        super().__init__(visible=visible, start_truncate_iteration=5)
        self._prompt = f'\n{prefix}HTML:\n{html}\n'


class AXTree(Truncater):
    """Truncatable prompt section holding the accessibility tree."""

    def __init__(
        self, ax_tree, visible: bool = True, coord_type=None, prefix=''
    ) -> None:
        super().__init__(visible=visible, start_truncate_iteration=10)
        if coord_type == 'center':
            coord_note = (
                'Note: center coordinates are provided in parenthesis and are '
                'relative to the top left corner of the page.\n\n'
            )
        elif coord_type == 'box':
            coord_note = (
                'Note: bounding box of each object are provided in parenthesis and are '
                'relative to the top left corner of the page.\n\n'
            )
        else:
            coord_note = ''
        self._prompt = f'\n{prefix}AXTree:\n{coord_note}{ax_tree}\n'


class Error(PromptElement):
    """Prompt section reporting the error raised by the previous action."""

    def __init__(self, error, visible: bool = True, prefix='') -> None:
        super().__init__(visible=visible)
        self._prompt = f'\n{prefix}Error from previous action:\n{error}\n'


class Observation(Shrinkable):
    """Observation of the current step.

    Contains the html, the accessibility tree and the error logs.
    """

    def __init__(self, obs, flags: Flags) -> None:
        super().__init__()
        self.flags = flags
        self.obs = obs
        self.html = HTML(obs[flags.html_type], visible=flags.use_html, prefix='## ')
        self.ax_tree = AXTree(
            obs['axtree_txt'],
            visible=flags.use_ax_tree,
            coord_type=flags.extract_coords,
            prefix='## ',
        )
        self.error = Error(
            obs['last_action_error'],
            # Show the section only when error logs are on and there is an error.
            visible=bool(flags.use_error_logs and obs['last_action_error']),
            prefix='## ',
        )

    def shrink(self):
        """Shrink the accessibility tree and the HTML sections."""
        self.ax_tree.shrink()
        self.html.shrink()

    @property
    def _prompt(self) -> str:  # type: ignore
        return f'\n# Observation of current step:\n{self.html.prompt}{self.ax_tree.prompt}{self.error.prompt}\n\n'

    def add_screenshot(self, prompt):
        """Append the screenshot to `prompt` when `use_screenshot` is set.

        A str prompt is first converted to a one-item list of text parts; a
        list prompt is extended in place. Without the flag the prompt is
        returned unchanged.
        """
        if self.flags.use_screenshot:
            if isinstance(prompt, str):
                prompt = [{'type': 'text', 'text': prompt}]
            img_url = BrowserEnv.image_to_jpg_base64_url(
                self.obs['screenshot'], add_data_prefix=True
            )
            prompt.append({'type': 'image_url', 'image_url': img_url})

        return prompt


class MacNote(PromptElement):
    """Keyboard-shortcut note, visible only when running on macOS."""

    def __init__(self) -> None:
        super().__init__(visible=platform.system() == 'Darwin')
        self._prompt = '\nNote: you are on mac so you should use Meta instead of Control for Control+C etc.\n'


class BeCautious(PromptElement):
    """Prompt section asking the agent to verify effects before submitting."""

    def __init__(self, visible: bool = True) -> None:
        super().__init__(visible=visible)
        self._prompt = """\
\nBe very cautious. Avoid submitting anything before verifying the effect of your
actions. Take the time to explore the effect of safe actions first. For example
you can fill a few elements of a form, but don't click submit before verifying
that everything was filled correctly.\n"""


class GoalInstructions(PromptElement):
    """Instructions section used in goal mode."""

    def __init__(self, goal, visible: bool = True) -> None:
        super().__init__(visible)
        self._prompt = f"""\
# Instructions
Review the current state of the page and all other information to find the best
possible next action to accomplish your goal. Your answer will be interpreted
and executed by a program, make sure to follow the formatting instructions.

## Goal:
{goal}
"""


class ChatInstructions(PromptElement):
    """Instructions section used in chat mode; lists the chat messages."""

    def __init__(self, chat_messages, visible: bool = True) -> None:
        super().__init__(visible)
        self._prompt = """\
# Instructions

You are a UI Assistant, your goal is to help the user perform tasks using a web browser. You can
communicate with the user via a chat, in which the user gives you instructions and in which you
can send back messages. You have access to a web browser that both you and the user can see,
and with which only you can interact via specific commands.

Review the instructions from the user, the current state of the page and all other information
to find the best possible next action to accomplish your goal. Your answer will be interpreted
and executed by a program, make sure to follow the formatting instructions.

## Chat messages:

"""
        self._prompt += '\n'.join(
            f" - [{msg['role']}], {msg['message']}" for msg in chat_messages
        )


class SystemPrompt(PromptElement):
    """Static system prompt describing the agent's role."""

    _prompt = """\
You are an agent trying to solve a web task based on the content of the page and
a user instructions. You can interact with the page and explore. Each time you
submit an action it will be sent to the browser and you will receive a new page."""


class MainPrompt(Shrinkable):
    """Full user prompt: instructions, observation, history, action space, examples."""

    def __init__(
        self,
        obs_history,
        actions,
        memories,
        thoughts,
        flags: Flags,
    ) -> None:
        super().__init__()
        self.flags = flags
        self.history = History(obs_history, actions, memories, thoughts, flags)
        last_obs = obs_history[-1]
        if self.flags.enable_chat:
            self.instructions: Union[ChatInstructions, GoalInstructions] = (
                ChatInstructions(last_obs['chat_messages'])
            )
        else:
            if (
                'chat_messages' in last_obs
                and sum(msg['role'] == 'user' for msg in last_obs['chat_messages']) > 1
            ):
                logging.warning(
                    'Agent is in goal mode, but multiple user messages are present in the chat. Consider switching to `enable_chat=True`.'
                )
            self.instructions = GoalInstructions(last_obs['goal'])

        self.obs = Observation(last_obs, self.flags)
        self.action_space = ActionSpace(self.flags)
        self.think = Think(visible=flags.use_thinking)
        self.memory = Memory(visible=flags.use_memory)

    @property
    def _prompt(self) -> str:  # type: ignore
        prompt = f"""\
{self.instructions.prompt}\
{self.obs.prompt}\
{self.history.prompt}\
{self.action_space.prompt}\
{self.think.prompt}\
{self.memory.prompt}\
"""

        if self.flags.use_abstract_example:
            prompt += f"""
# Abstract Example

Here is an abstract version of the answer with description of the content of
each tag. Make sure you follow this structure, but replace the content with your
answer:
{self.think.abstract_ex}\
{self.memory.abstract_ex}\
{self.action_space.abstract_ex}\
"""

        if self.flags.use_concrete_example:
            prompt += f"""
# Concrete Example

Here is a concrete example of how to format your answer.
Make sure to follow the template with proper tags:
{self.think.concrete_ex}\
{self.memory.concrete_ex}\
{self.action_space.concrete_ex}\
"""
        return self.obs.add_screenshot(prompt)

    def shrink(self):
        """Shrink the history first, then the current observation."""
        self.history.shrink()
        self.obs.shrink()

    def _parse_answer(self, text_answer):
        """Merge the think, memory and action parts parsed from the answer."""
        ans_dict = {}
        ans_dict.update(self.think._parse_answer(text_answer))
        ans_dict.update(self.memory._parse_answer(text_answer))
        ans_dict.update(self.action_space._parse_answer(text_answer))
        return ans_dict


class ActionSpace(PromptElement):
    """Prompt section describing the allowed actions, with answer examples."""

    def __init__(self, flags: Flags) -> None:
        super().__init__()
        self.flags = flags
        self.action_space = _get_action_space(flags)

        self._prompt = (
            f'# Action space:\n{self.action_space.describe()}{MacNote().prompt}\n'
        )
        # The examples carry the <action> tag that `_parse_answer` extracts.
        self._abstract_ex = f"""
<action>
{self.action_space.example_action(abstract=True)}
</action>
"""
        self._concrete_ex = f"""
<action>
{self.action_space.example_action(abstract=False)}
</action>
"""

    def _parse_answer(self, text_answer):
        """Extract the `action` tag and check that it is an allowed action.

        Raises
        ------
        ParseError
            If the action cannot be mapped to python code by the action set.
        """
        ans_dict = parse_html_tags_raise(
            text_answer, keys=['action'], merge_multiple=True
        )

        try:
            # just check if action can be mapped to python code but keep action as is
            # the environment will be responsible for mapping it to python
            self.action_space.to_python_code(ans_dict['action'])
        except Exception as e:
            raise ParseError(
                f'Error while parsing action\n: {e}\n'
                'Make sure your answer is restricted to the allowed actions.'
            ) from e

        return ans_dict


# High-level action subsets enabled for each non-python `Flags.action_space`.
_ACTION_SUBSETS = {
    'bid': ['chat', 'bid'],
    'coord': ['chat', 'coord'],
    'bid+coord': ['chat', 'bid', 'coord'],
    'bid+nav': ['chat', 'bid', 'nav'],
    'coord+nav': ['chat', 'coord', 'nav'],
    'bid+coord+nav': ['chat', 'bid', 'coord', 'nav'],
}


def _get_action_space(flags: Flags) -> AbstractActionSet:
    """Build the action set selected by `flags.action_space`.

    Raises
    ------
    NotImplementedError
        If `flags.action_space` is not a known action space.
    """
    if flags.action_space == 'python':
        # The python action set takes neither multi-action nor demo mode.
        if flags.multi_actions:
            warn(
                f'Flag action_space={repr(flags.action_space)} incompatible with multi_actions={repr(flags.multi_actions)}.',
                stacklevel=2,
            )
        if flags.demo_mode != 'off':
            warn(
                f'Flag action_space={repr(flags.action_space)} incompatible with demo_mode={repr(flags.demo_mode)}.',
                stacklevel=2,
            )
        return PythonActionSet(strict=flags.is_strict)

    try:
        action_subsets = _ACTION_SUBSETS[flags.action_space]
    except (KeyError, TypeError):
        raise NotImplementedError(
            f'Unknown action_space {repr(flags.action_space)}'
        ) from None

    return HighLevelActionSet(
        subsets=list(action_subsets),
        multiaction=flags.multi_actions,
        strict=flags.is_strict,
        demo_mode=flags.demo_mode,
    )


class Memory(PromptElement):
    """Asks the LLM for an optional `memory` note to carry to later steps."""

    _prompt = ''  # provided in the abstract and concrete examples

    _abstract_ex = """
<memory>
Write down anything you need to remember for next steps. You will be presented
with the list of previous memories and past actions.
</memory>
"""

    _concrete_ex = """
<memory>
I clicked on bid 32 to activate tab 2. The accessibility tree should mention
focusable for elements of the form at next step.
</memory>
"""

    def _parse_answer(self, text_answer):
        """Extract the optional `memory` tag from the answer."""
        return parse_html_tags_raise(
            text_answer, optional_keys=['memory'], merge_multiple=True
        )


class Think(PromptElement):
    """Asks the LLM for an optional `think` section with its reasoning."""

    _prompt = ''

    _abstract_ex = """
<think>
Think step by step. If you need to make calculations such as coordinates, write them here. Describe the effect
that your previous action had on the current content of the page.
</think>
"""
    _concrete_ex = """
<think>
My memory says that I filled the first name and last name, but I can't see any
content in the form. I need to explore different ways to fill the form. Perhaps
the form is not visible yet or some fields are disabled. I need to replan.
</think>
"""

    def _parse_answer(self, text_answer):
        """Extract the optional `think` tag from the answer."""
        return parse_html_tags_raise(
            text_answer, optional_keys=['think'], merge_multiple=True
        )


def diff(previous, new):
    """Return a header and the changed lines between `previous` and `new`.

    Returns
    -------
    tuple[str, list[str]]
        A summary header and the ndiff lines marking additions and removals.
        The list is empty when the inputs are equal or `previous` is empty or
        None.
    """
    if previous == new:
        return 'Identical', []

    # Test for None before anything that needs a length: `len(None)` raises.
    if not previous:
        return 'previous is empty', []

    added = []
    removed = []
    diff_lines = []
    for line in difflib.ndiff(previous.splitlines(), new.splitlines()):
        # ndiff marks every line with a two-character code. Matching that code
        # exactly keeps unchanged lines whose own text starts with '+' or '-'
        # from being counted as changes.
        if line.startswith('+ '):
            added.append(line)
        elif line.startswith('- '):
            removed.append(line)
        else:
            continue
        diff_lines.append(line)

    header = f'{len(added)} lines added and {len(removed)} lines removed:'

    return header, diff_lines


class Diff(Shrinkable):
    """Prompt section showing a line diff, capped at `max_line_diff` lines."""

    def __init__(
        self, previous, new, prefix='', max_line_diff=20, shrink_speed=2, visible=True
    ) -> None:
        super().__init__(visible=visible)
        self.max_line_diff = max_line_diff
        self.header, self.diff_lines = diff(previous, new)
        self.shrink_speed = shrink_speed
        self.prefix = prefix

    def shrink(self):
        """Lower the line cap by `shrink_speed`, never below one line."""
        self.max_line_diff = max(1, self.max_line_diff - self.shrink_speed)

    @property
    def _prompt(self) -> str:  # type: ignore
        diff_str = '\n'.join(self.diff_lines[: self.max_line_diff])
        hidden_count = len(self.diff_lines) - self.max_line_diff
        if hidden_count > 0:
            diff_str = f'{diff_str}\nDiff truncated, {hidden_count} changes not shown.'
        return f'{self.prefix}{self.header}\n{diff_str}\n'


class HistoryStep(Shrinkable):
    """One past step: its action, error, observation diffs and memory."""

    def __init__(
        self, previous_obs, current_obs, action, memory, flags: Flags, shrink_speed=1
    ) -> None:
        super().__init__()
        # Visibility is a callable so the flags are read each time the prompt
        # is rendered rather than once at construction.
        self.html_diff = Diff(
            previous_obs[flags.html_type],
            current_obs[flags.html_type],
            prefix='\n### HTML diff:\n',
            shrink_speed=shrink_speed,
            visible=lambda: flags.use_html and flags.use_diff,
        )
        self.ax_tree_diff = Diff(
            previous_obs['axtree_txt'],
            current_obs['axtree_txt'],
            prefix='\n### Accessibility tree diff:\n',
            shrink_speed=shrink_speed,
            visible=lambda: flags.use_ax_tree and flags.use_diff,
        )
        self.error = Error(
            current_obs['last_action_error'],
            visible=bool(
                flags.use_error_logs
                and current_obs['last_action_error']
                and flags.use_past_error_logs
            ),
            prefix='### ',
        )
        self.shrink_speed = shrink_speed
        self.action = action
        self.memory = memory
        self.flags = flags

    def shrink(self):
        """Shrink both diffs of this step."""
        super().shrink()
        self.html_diff.shrink()
        self.ax_tree_diff.shrink()

    @property
    def _prompt(self) -> str:  # type: ignore
        prompt = ''

        if self.flags.use_action_history:
            prompt += f'\n### Action:\n{self.action}\n'

        prompt += (
            f'{self.error.prompt}{self.html_diff.prompt}{self.ax_tree_diff.prompt}'
        )

        if self.flags.use_memory and self.memory is not None:
            prompt += f'\n### Memory:\n{self.memory}\n'

        return prompt


class History(Shrinkable):
    """Prompt section listing every past step of the interaction."""

    def __init__(
        self, history_obs, actions, memories, thoughts, flags: Flags, shrink_speed=1
    ) -> None:
        super().__init__(visible=flags.use_history)
        # One observation per step plus the current one.
        assert len(history_obs) == len(actions) + 1
        assert len(history_obs) == len(memories) + 1

        self.shrink_speed = shrink_speed
        # Pair each observation with its successor; `thoughts` is not used.
        self.history_steps: list[HistoryStep] = [
            HistoryStep(
                history_obs[i - 1],
                history_obs[i],
                actions[i - 1],
                memories[i - 1],
                flags,
            )
            for i in range(1, len(history_obs))
        ]

    def shrink(self):
        """Shrink individual steps"""
        # TODO set the shrink speed of older steps to be higher
        super().shrink()
        for step in self.history_steps:
            step.shrink()

    @property
    def _prompt(self):
        prompts = ['# History of interaction with the task:\n']
        for i, step in enumerate(self.history_steps):
            prompts.append(f'## step {i}')
            prompts.append(step.prompt)
        return '\n'.join(prompts) + '\n'


if __name__ == '__main__':
    # Small manual demo: print the prompt built from a three-step history.
    html_template = """
Hello World. Step {}.
"""

    OBS_HISTORY = [
        {
            'goal': 'do this and that',
            'pruned_html': html_template.format(1),
            'axtree_txt': '[1] Click me',
            'last_action_error': '',
        },
        {
            'goal': 'do this and that',
            'pruned_html': html_template.format(2),
            'axtree_txt': '[1] Click me',
            'last_action_error': '',
        },
        {
            'goal': 'do this and that',
            'pruned_html': html_template.format(3),
            'axtree_txt': '[1] Click me',
            'last_action_error': 'Hey, there is an error now',
        },
    ]
    ACTIONS = ["click('41')", "click('42')"]
    MEMORIES = ['memory A', 'memory B']
    THOUGHTS = ['thought A', 'thought B']

    flags = Flags(
        use_html=True,
        use_ax_tree=True,
        use_thinking=True,
        use_error_logs=True,
        use_past_error_logs=True,
        use_history=True,
        use_action_history=True,
        use_memory=True,
        use_diff=True,
        html_type='pruned_html',
        use_concrete_example=True,
        use_abstract_example=True,
        use_screenshot=False,
        multi_actions=True,
    )

    print(
        MainPrompt(
            obs_history=OBS_HISTORY,
            actions=ACTIONS,
            memories=MEMORIES,
            thoughts=THOUGHTS,
            flags=flags,
        ).prompt
    )