| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786 |
- import abc
- import difflib
- import logging
- import platform
- from copy import deepcopy
- from dataclasses import asdict, dataclass
- from textwrap import dedent
- from typing import Literal, Union
- from warnings import warn
- from browsergym.core.action.base import AbstractActionSet
- from browsergym.core.action.highlevel import HighLevelActionSet
- from browsergym.core.action.python import PythonActionSet
- from openhands.agenthub.browsing_agent.utils import (
- ParseError,
- parse_html_tags_raise,
- )
- from openhands.runtime.browser.browser_env import BrowserEnv
@dataclass
class Flags:
    """Switches that select which prompt elements are rendered and how.

    The class is a plain dataclass so it can be copied, converted to a
    dictionary and rebuilt from one (see `copy`, `asdict` and `from_dict`).
    """

    use_html: bool = True
    use_ax_tree: bool = False
    drop_ax_tree_first: bool = True  # This flag is no longer active TODO delete
    use_thinking: bool = False
    use_error_logs: bool = False
    use_past_error_logs: bool = False
    use_history: bool = False
    use_action_history: bool = False
    use_memory: bool = False
    use_diff: bool = False
    html_type: str = 'pruned_html'
    use_concrete_example: bool = True
    use_abstract_example: bool = False
    multi_actions: bool = False
    action_space: Literal[
        'python', 'bid', 'coord', 'bid+coord', 'bid+nav', 'coord+nav', 'bid+coord+nav'
    ] = 'bid'
    is_strict: bool = False
    # This flag will be automatically disabled `if not chat_model_args.has_vision()`
    use_screenshot: bool = True
    enable_chat: bool = False
    max_prompt_tokens: int = 100_000
    extract_visible_tag: bool = False
    extract_coords: Literal['False', 'center', 'box'] = 'False'
    extract_visible_elements_only: bool = False
    demo_mode: Literal['off', 'default', 'only_visible_elements'] = 'off'

    def copy(self):
        """Return an independent deep copy of these flags."""
        return deepcopy(self)

    def asdict(self):
        """Helper for JSON serializable requirement."""
        # Inside a method the bare name resolves to `dataclasses.asdict`,
        # not to this method.
        return asdict(self)

    @classmethod
    def from_dict(cls, flags_dict):
        """Helper for JSON serializable requirement.

        Parameters
        ----------
        flags_dict : Flags or dict
            Either an existing `Flags` instance (returned unchanged) or a
            dictionary of field names to values.

        Returns:
        -------
        Flags : the given instance, or a new instance of `cls` built from the
            dictionary.

        Raises:
        ------
        ValueError : if `flags_dict` is neither a `Flags` nor a `dict`.
        """
        if isinstance(flags_dict, Flags):
            return flags_dict
        if not isinstance(flags_dict, dict):
            raise ValueError(
                f'Unrecognized type for flags_dict of type {type(flags_dict)}.'
            )
        # Build through `cls` so subclasses get an instance of their own type.
        return cls(**flags_dict)
class PromptElement:
    """Base class for all prompt elements. Prompt elements can be hidden.

    Prompt elements are used to build the prompt. Use flags to control which
    prompt elements are visible. We use class attributes as a convenient way
    to implement static prompts, but feel free to override them with instance
    attributes or @property decorator.
    """

    _prompt = ''
    _abstract_ex = ''
    _concrete_ex = ''

    def __init__(self, visible: bool = True) -> None:
        """Prompt element that can be hidden.

        Parameters
        ----------
        visible : bool, optional
            Whether the prompt element should be visible, by default True. Can
            be a callable that returns a bool. This is useful when a specific
            flag changes during a shrink iteration.
        """
        self._visible = visible

    @property
    def prompt(self):
        """Avoid overriding this method. Override _prompt instead."""
        return self._hide(self._prompt)

    @property
    def abstract_ex(self):
        """Useful when this prompt element is requesting an answer from the llm.

        Provide an abstract example of the answer here. See Memory for an
        example.

        Avoid overriding this method. Override _abstract_ex instead
        """
        return self._hide(self._abstract_ex)

    @property
    def concrete_ex(self):
        """Useful when this prompt element is requesting an answer from the llm.

        Provide a concrete example of the answer here. See Memory for an
        example.

        Avoid overriding this method. Override _concrete_ex instead
        """
        return self._hide(self._concrete_ex)

    @property
    def is_visible(self):
        """Handle the case where visible is a callable."""
        visible = self._visible
        if callable(visible):
            # Evaluated on every access so a flag flipped later is picked up.
            visible = visible()
        return visible

    def _hide(self, value):
        """Return value if visible is True, else return empty string."""
        return value if self.is_visible else ''

    def _parse_answer(self, text_answer) -> dict:
        """Extract this element's part of the LLM answer.

        The base element does not request anything from the LLM, so there is
        nothing to extract and an empty dict is returned. Elements that do
        request an answer override this method.

        Delegating to `self._parse_answer` here would call this very method
        again and recurse without bound, hence the direct return.
        """
        return {}
class Shrinkable(PromptElement, abc.ABC):
    """Abstract prompt element whose rendered text can be reduced on demand."""

    @abc.abstractmethod
    def shrink(self) -> None:
        """Implement shrinking of this prompt element.

        Implementations must forward the call to every shrinkable element
        they contain, and may additionally apply a shrinking strategy of
        their own.

        Shrinking can be called multiple times to progressively shrink the
        prompt until it fits max_tokens. Default max shrink iterations is 20.
        """
        return None
class Truncater(Shrinkable):
    """Prompt element that drops trailing lines so it fits the LLM context.

    Ideally `shrink()` never has to do anything. Subclass this for elements
    holding long observations, such as AxTree or HTML, that may be cut.
    """

    def __init__(self, visible, shrink_speed=0.3, start_truncate_iteration=10):
        """Store the truncation settings and reset the counters.

        Parameters
        ----------
        visible : bool or callable
            Forwarded to the base prompt element.
        shrink_speed : float, optional
            Fraction of the current lines removed by each truncating call.
        start_truncate_iteration : int, optional
            Number of `shrink()` calls that pass before truncation begins.
        """
        super().__init__(visible=visible)
        self.shrink_speed = shrink_speed
        self.start_truncate_iteration = start_truncate_iteration
        self.shrink_calls = 0
        self.deleted_lines = 0

    def shrink(self) -> None:
        """Cut a fraction of the trailing lines once enough calls were made."""
        warmup_done = self.shrink_calls >= self.start_truncate_iteration
        if self.is_visible and warmup_done:
            current_lines = self._prompt.splitlines()
            keep = int(len(current_lines) * (1 - self.shrink_speed))
            self.deleted_lines += len(current_lines) - keep
            kept_text = '\n'.join(current_lines[:keep])
            # The running total of removed lines is appended as a notice.
            self._prompt = (
                kept_text
                + f'\n... Deleted {self.deleted_lines} lines to reduce prompt size.'
            )
        self.shrink_calls += 1
def _prompt_char_count(prompt) -> int:
    """Return the number of text characters in a rendered prompt.

    A prompt is either a plain string or a list of message parts; for a list
    only the parts whose 'type' is 'text' are counted, joined by newlines.
    """
    if isinstance(prompt, str):
        return len(prompt)
    if isinstance(prompt, list):
        return len('\n'.join(p['text'] for p in prompt if p['type'] == 'text'))
    raise ValueError(f'Unrecognized type for prompt: {type(prompt)}')


def fit_tokens(
    shrinkable: Shrinkable,
    max_prompt_chars=None,
    max_iterations=20,
):
    """Shrink a prompt element until it fits max_prompt_chars.

    Parameters
    ----------
    shrinkable : Shrinkable
        The prompt element to shrink.
    max_prompt_chars : int, optional
        The maximum number of chars allowed. When None, no shrinking is done.
    max_iterations : int, optional
        The maximum number of shrink iterations, by default 20.

    Returns:
    -------
    str or list : the prompt after shrinking. This is always the prompt as
        rendered after the last `shrink()` call that was made.

    Raises:
    ------
    ValueError : if the rendered prompt is neither a str nor a list.
    """
    if max_prompt_chars is None:
        return shrinkable.prompt

    prompt = shrinkable.prompt
    n_chars = _prompt_char_count(prompt)
    for _ in range(max_iterations):
        if n_chars <= max_prompt_chars:
            return prompt
        shrinkable.shrink()
        # Re-render after every shrink so the value returned (and the length
        # reported below) reflects the latest state, including the last call.
        prompt = shrinkable.prompt
        n_chars = _prompt_char_count(prompt)

    if n_chars <= max_prompt_chars:
        return prompt

    logging.info(
        dedent(
            f"""\
            After {max_iterations} shrink iterations, the prompt is still
            {n_chars} chars (greater than {max_prompt_chars}). Returning the prompt as is."""
        )
    )
    return prompt
class HTML(Truncater):
    """Truncatable prompt section holding the HTML of a page."""

    def __init__(self, html, visible: bool = True, prefix='') -> None:
        """Render `html` under an 'HTML:' heading preceded by `prefix`."""
        # Truncation starts after 5 shrink calls for this element.
        super().__init__(visible=visible, start_truncate_iteration=5)
        heading = f'{prefix}HTML:'
        self._prompt = f'\n{heading}\n{html}\n'
class AXTree(Truncater):
    """Truncatable prompt section holding the accessibility tree text."""

    def __init__(
        self, ax_tree, visible: bool = True, coord_type=None, prefix=''
    ) -> None:
        """Render `ax_tree` under an 'AXTree:' heading.

        A note explaining the coordinates is inserted when `coord_type` is
        'center' or 'box'; any other value adds no note.
        """
        super().__init__(visible=visible, start_truncate_iteration=10)
        coord_note = ''
        if coord_type == 'center':
            coord_note = (
                'Note: center coordinates are provided in parenthesis and are\n'
                'relative to the top left corner of the page.\n\n'
            )
        elif coord_type == 'box':
            coord_note = (
                'Note: bounding box of each object are provided in parenthesis and are\n'
                'relative to the top left corner of the page.\n\n'
            )
        heading = f'{prefix}AXTree:'
        self._prompt = f'\n{heading}\n{coord_note}{ax_tree}\n'
class Error(PromptElement):
    """Prompt section reporting the error raised by the previous action."""

    def __init__(self, error, visible: bool = True, prefix='') -> None:
        """Render `error` under a heading preceded by `prefix`."""
        super().__init__(visible=visible)
        heading = f'{prefix}Error from previous action:'
        self._prompt = f'\n{heading}\n{error}\n'
class Observation(Shrinkable):
    """Observation of the current step.

    Contains the html, the accessibility tree and the error logs.
    """

    def __init__(self, obs, flags: Flags) -> None:
        """Build the HTML, AXTree and error sub-elements from `obs`.

        `obs` is indexed with `flags.html_type`, 'axtree_txt' and
        'last_action_error'.
        """
        super().__init__()
        self.flags = flags
        self.obs = obs

        section_prefix = '## '
        self.html = HTML(
            obs[flags.html_type], visible=flags.use_html, prefix=section_prefix
        )
        self.ax_tree = AXTree(
            obs['axtree_txt'],
            visible=flags.use_ax_tree,
            coord_type=flags.extract_coords,
            prefix=section_prefix,
        )
        last_error = obs['last_action_error']
        # The error section is shown only when error logs are enabled and
        # the last action actually produced an error.
        self.error = Error(
            last_error,
            visible=flags.use_error_logs and last_error,
            prefix=section_prefix,
        )

    def shrink(self):
        """Shrink the accessibility tree first, then the HTML."""
        self.ax_tree.shrink()
        self.html.shrink()

    @property
    def _prompt(self) -> str:  # type: ignore
        sections = f'{self.html.prompt}{self.ax_tree.prompt}{self.error.prompt}'
        return f'\n# Observation of current step:\n{sections}\n\n'

    def add_screenshot(self, prompt):
        """Append the screenshot as an image part when screenshots are enabled.

        A string prompt is first wrapped into a one-element list of text
        parts; a list prompt is extended in place. With screenshots disabled
        the prompt is returned unchanged.
        """
        if not self.flags.use_screenshot:
            return prompt
        if isinstance(prompt, str):
            prompt = [{'type': 'text', 'text': prompt}]
        img_url = BrowserEnv.image_to_jpg_base64_url(
            self.obs['screenshot'], add_data_prefix=True
        )
        prompt.append({'type': 'image_url', 'image_url': img_url})
        return prompt
class MacNote(PromptElement):
    """Keyboard-shortcut note that is visible only when running on macOS."""

    def __init__(self) -> None:
        on_mac = platform.system() == 'Darwin'
        super().__init__(visible=on_mac)
        self._prompt = (
            '\nNote: you are on mac so you should use Meta instead of Control'
            ' for Control+C etc.\n'
        )
class BeCautious(PromptElement):
    """Prompt section asking the agent to verify effects before submitting."""

    def __init__(self, visible: bool = True) -> None:
        super().__init__(visible=visible)
        self._prompt = (
            '\nBe very cautious. Avoid submitting anything before verifying the effect of your\n'
            'actions. Take the time to explore the effect of safe actions first. For example\n'
            "you can fill a few elements of a form, but don't click submit before verifying\n"
            'that everything was filled correctly.\n'
        )
class GoalInstructions(PromptElement):
    """Instruction section used in goal mode; ends with the goal text."""

    def __init__(self, goal, visible: bool = True) -> None:
        super().__init__(visible=visible)
        preamble = (
            '# Instructions\n'
            'Review the current state of the page and all other information to find the best\n'
            'possible next action to accomplish your goal. Your answer will be interpreted\n'
            'and executed by a program, make sure to follow the formatting instructions.\n'
            '## Goal:\n'
        )
        self._prompt = preamble + f'{goal}\n'
class ChatInstructions(PromptElement):
    """Instruction section used in chat mode; ends with the chat transcript."""

    def __init__(self, chat_messages, visible: bool = True) -> None:
        """Render the instructions followed by one line per chat message.

        Each message is indexed with 'role' and 'message'.
        """
        super().__init__(visible=visible)
        preamble = """\
# Instructions
You are a UI Assistant, your goal is to help the user perform tasks using a web browser. You can
communicate with the user via a chat, in which the user gives you instructions and in which you
can send back messages. You have access to a web browser that both you and the user can see,
and with which only you can interact via specific commands.
Review the instructions from the user, the current state of the page and all other information
to find the best possible next action to accomplish your goal. Your answer will be interpreted
and executed by a program, make sure to follow the formatting instructions.
## Chat messages:
"""
        transcript = []
        for msg in chat_messages:
            transcript.append(f"- [{msg['role']}], {msg['message']}")
        self._prompt = preamble + '\n'.join(transcript)
class SystemPrompt(PromptElement):
    """Static system message describing the agent's role."""

    _prompt = (
        'You are an agent trying to solve a web task based on the content of the page and\n'
        'a user instructions. You can interact with the page and explore. Each time you\n'
        'submit an action it will be sent to the browser and you will receive a new page.'
    )
class MainPrompt(Shrinkable):
    """Top-level prompt assembled from instructions, observation and history."""

    def __init__(
        self,
        obs_history,
        actions,
        memories,
        thoughts,
        flags: Flags,
    ) -> None:
        """Create every sub-element of the prompt.

        The last entry of `obs_history` is used as the current observation
        and as the source of the goal or chat messages.
        """
        super().__init__()
        self.flags = flags
        self.history = History(obs_history, actions, memories, thoughts, flags)

        latest_obs = obs_history[-1]
        if self.flags.enable_chat:
            self.instructions: Union[ChatInstructions, GoalInstructions] = (
                ChatInstructions(latest_obs['chat_messages'])
            )
        else:
            # Goal mode: several user messages hint that chat mode was meant.
            if 'chat_messages' in latest_obs:
                user_message_count = sum(
                    msg['role'] == 'user' for msg in latest_obs['chat_messages']
                )
                if user_message_count > 1:
                    logging.warning(
                        'Agent is in goal mode, but multiple user messages are present in the chat. Consider switching to `enable_chat=True`.'
                    )
            self.instructions = GoalInstructions(latest_obs['goal'])

        self.obs = Observation(latest_obs, self.flags)
        self.action_space = ActionSpace(self.flags)
        self.think = Think(visible=flags.use_thinking)
        self.memory = Memory(visible=flags.use_memory)

    @property
    def _prompt(self) -> str:  # type: ignore
        # Sub-elements are concatenated without any separator.
        prompt = (
            f'{self.instructions.prompt}'
            f'{self.obs.prompt}'
            f'{self.history.prompt}'
            f'{self.action_space.prompt}'
            f'{self.think.prompt}'
            f'{self.memory.prompt}'
        )

        if self.flags.use_abstract_example:
            abstract_header = (
                '\n# Abstract Example\n'
                'Here is an abstract version of the answer with description of the content of\n'
                'each tag. Make sure you follow this structure, but replace the content with your\n'
                'answer:\n'
            )
            prompt += (
                abstract_header
                + f'{self.think.abstract_ex}'
                + f'{self.memory.abstract_ex}'
                + f'{self.action_space.abstract_ex}'
            )

        if self.flags.use_concrete_example:
            concrete_header = (
                '\n# Concrete Example\n'
                'Here is a concrete example of how to format your answer.\n'
                'Make sure to follow the template with proper tags:\n'
            )
            prompt += (
                concrete_header
                + f'{self.think.concrete_ex}'
                + f'{self.memory.concrete_ex}'
                + f'{self.action_space.concrete_ex}'
            )

        return self.obs.add_screenshot(prompt)

    def shrink(self):
        """Shrink the history first, then the current observation."""
        self.history.shrink()
        self.obs.shrink()

    def _parse_answer(self, text_answer):
        """Merge the parsed results of think, memory and action, in that order."""
        merged = {}
        for element in (self.think, self.memory, self.action_space):
            merged.update(element._parse_answer(text_answer))
        return merged
class ActionSpace(PromptElement):
    """Prompt section describing the available actions and parsing the reply."""

    def __init__(self, flags: Flags) -> None:
        """Select the action set for `flags` and render its description."""
        super().__init__()
        self.flags = flags
        self.action_space = _get_action_space(flags)

        self._prompt = (
            f'# Action space:\n{self.action_space.describe()}{MacNote().prompt}\n'
        )
        self._abstract_ex = f"""
<action>
{self.action_space.example_action(abstract=True)}
</action>
"""
        self._concrete_ex = f"""
<action>
{self.action_space.example_action(abstract=False)}
</action>
"""

    def _parse_answer(self, text_answer):
        """Extract the `action` tag from the answer and validate it.

        Returns:
        -------
        dict : the parsed tags, with the action text left unchanged.

        Raises:
        ------
        ParseError : if the action cannot be mapped to python code. The
            underlying exception is attached as the cause.
        """
        ans_dict = parse_html_tags_raise(
            text_answer, keys=['action'], merge_multiple=True
        )

        try:
            # just check if action can be mapped to python code but keep action as is
            # the environment will be responsible for mapping it to python
            self.action_space.to_python_code(ans_dict['action'])
        except Exception as e:
            # Chain explicitly so the original failure stays in the traceback.
            raise ParseError(
                f'Error while parsing action\n: {e}\n'
                'Make sure your answer is restricted to the allowed actions.'
            ) from e

        return ans_dict
- def _get_action_space(flags: Flags) -> AbstractActionSet:
- match flags.action_space:
- case 'python':
- action_space = PythonActionSet(strict=flags.is_strict)
- if flags.multi_actions:
- warn(
- f'Flag action_space={repr(flags.action_space)} incompatible with multi_actions={repr(flags.multi_actions)}.',
- stacklevel=2,
- )
- if flags.demo_mode != 'off':
- warn(
- f'Flag action_space={repr(flags.action_space)} incompatible with demo_mode={repr(flags.demo_mode)}.',
- stacklevel=2,
- )
- return action_space
- case 'bid':
- action_subsets = ['chat', 'bid']
- case 'coord':
- action_subsets = ['chat', 'coord']
- case 'bid+coord':
- action_subsets = ['chat', 'bid', 'coord']
- case 'bid+nav':
- action_subsets = ['chat', 'bid', 'nav']
- case 'coord+nav':
- action_subsets = ['chat', 'coord', 'nav']
- case 'bid+coord+nav':
- action_subsets = ['chat', 'bid', 'coord', 'nav']
- case _:
- raise NotImplementedError(
- f'Unknown action_space {repr(flags.action_space)}'
- )
- action_space = HighLevelActionSet(
- subsets=action_subsets,
- multiaction=flags.multi_actions,
- strict=flags.is_strict,
- demo_mode=flags.demo_mode,
- )
- return action_space
class Memory(PromptElement):
    """Element asking the LLM for an optional `<memory>` tag in its answer."""

    _prompt = ''  # provided in the abstract and concrete examples

    _abstract_ex = (
        '\n<memory>\n'
        'Write down anything you need to remember for next steps. You will be presented\n'
        'with the list of previous memories and past actions.\n'
        '</memory>\n'
    )

    _concrete_ex = (
        '\n<memory>\n'
        'I clicked on bid 32 to activate tab 2. The accessibility tree should mention\n'
        'focusable for elements of the form at next step.\n'
        '</memory>\n'
    )

    def _parse_answer(self, text_answer):
        """Parse the optional `memory` tag out of `text_answer`."""
        optional = ['memory']
        return parse_html_tags_raise(
            text_answer, optional_keys=optional, merge_multiple=True
        )
class Think(PromptElement):
    """Element asking the LLM for an optional `<think>` tag in its answer."""

    _prompt = ''

    _abstract_ex = (
        '\n<think>\n'
        'Think step by step. If you need to make calculations such as coordinates, write them here. Describe the effect\n'
        'that your previous action had on the current content of the page.\n'
        '</think>\n'
    )

    _concrete_ex = (
        '\n<think>\n'
        "My memory says that I filled the first name and last name, but I can't see any\n"
        'content in the form. I need to explore different ways to fill the form. Perhaps\n'
        'the form is not visible yet or some fields are disabled. I need to replan.\n'
        '</think>\n'
    )

    def _parse_answer(self, text_answer):
        """Parse the optional `think` tag out of `text_answer`."""
        optional = ['think']
        return parse_html_tags_raise(
            text_answer, optional_keys=optional, merge_multiple=True
        )
def diff(previous, new):
    """Summarise the line-level difference between two texts.

    Parameters
    ----------
    previous : str or None
        The earlier text. None or an empty string means there is nothing to
        compare against.
    new : str
        The later text.

    Returns:
    -------
    tuple[str, list[str]] : a header and the list of changed lines. The
        header is 'Identical' or 'previous is empty' (both with an empty
        list), or a count of added and removed lines. Changed lines keep
        their `difflib.ndiff` prefix ('+ ' or '- ').
    """
    if previous == new:
        return 'Identical', []

    # Test for None before anything that needs a length.
    if not previous:
        return 'previous is empty', []

    diff_lines = []
    plus_count = 0
    minus_count = 0
    for line in difflib.ndiff(previous.splitlines(), new.splitlines()):
        # Match the two-character ndiff code exactly: stripping first would
        # misread an unchanged line whose own text starts with '+' or '-'.
        if line.startswith('+ '):
            diff_lines.append(line)
            plus_count += 1
        elif line.startswith('- '):
            diff_lines.append(line)
            minus_count += 1

    header = f'{plus_count} lines added and {minus_count} lines removed:'
    return header, diff_lines
class Diff(Shrinkable):
    """Prompt element showing a bounded number of changed lines."""

    def __init__(
        self, previous, new, prefix='', max_line_diff=20, shrink_speed=2, visible=True
    ) -> None:
        """Compute the diff between `previous` and `new` once, up front.

        Parameters
        ----------
        previous, new
            The two texts handed to `diff`.
        prefix : str, optional
            Text placed before the diff header.
        max_line_diff : int, optional
            Maximum number of changed lines rendered.
        shrink_speed : int, optional
            Amount subtracted from `max_line_diff` by each `shrink()` call.
        visible : bool or callable, optional
            Forwarded to the base prompt element.
        """
        super().__init__(visible=visible)
        self.max_line_diff = max_line_diff
        self.header, self.diff_lines = diff(previous, new)
        self.shrink_speed = shrink_speed
        self.prefix = prefix

    def shrink(self):
        """Lower the number of rendered lines, never going below one."""
        self.max_line_diff = max(1, self.max_line_diff - self.shrink_speed)

    @property
    def _prompt(self) -> str:  # type: ignore
        diff_str = '\n'.join(self.diff_lines[: self.max_line_diff])
        if len(self.diff_lines) > self.max_line_diff:
            hidden_count = len(self.diff_lines) - self.max_line_diff
            # The number reported is the count of lines left out.
            diff_str = f'{diff_str}\nDiff truncated, {hidden_count} changes not shown.'
        return f'{self.prefix}{self.header}\n{diff_str}\n'
class HistoryStep(Shrinkable):
    """One past step: its action, resulting error, diffs and memory."""

    def __init__(
        self, previous_obs, current_obs, action, memory, flags: Flags, shrink_speed=1
    ) -> None:
        """Build the diff and error sub-elements for a single step.

        Both observations are indexed with `flags.html_type` and
        'axtree_txt'; `current_obs` is also indexed with 'last_action_error'.
        """
        super().__init__()
        html_key = flags.html_type
        # Visibility is passed as a callable so it is evaluated at render time.
        self.html_diff = Diff(
            previous_obs[html_key],
            current_obs[html_key],
            prefix='\n### HTML diff:\n',
            shrink_speed=shrink_speed,
            visible=lambda: flags.use_html and flags.use_diff,
        )
        self.ax_tree_diff = Diff(
            previous_obs['axtree_txt'],
            current_obs['axtree_txt'],
            prefix='\n### Accessibility tree diff:\n',
            shrink_speed=shrink_speed,
            visible=lambda: flags.use_ax_tree and flags.use_diff,
        )
        step_error = current_obs['last_action_error']
        show_error = (
            flags.use_error_logs and step_error and flags.use_past_error_logs
        )
        self.error = Error(step_error, visible=show_error, prefix='### ')
        self.shrink_speed = shrink_speed
        self.action = action
        self.memory = memory
        self.flags = flags

    def shrink(self):
        """Shrink the HTML diff, then the accessibility tree diff."""
        super().shrink()
        for element in (self.html_diff, self.ax_tree_diff):
            element.shrink()

    @property
    def _prompt(self) -> str:  # type: ignore
        pieces = []
        if self.flags.use_action_history:
            pieces.append(f'\n### Action:\n{self.action}\n')
        pieces.append(
            f'{self.error.prompt}{self.html_diff.prompt}{self.ax_tree_diff.prompt}'
        )
        if self.flags.use_memory and self.memory is not None:
            pieces.append(f'\n### Memory:\n{self.memory}\n')
        return ''.join(pieces)
class History(Shrinkable):
    """Prompt section listing every past step of the interaction."""

    def __init__(
        self, history_obs, actions, memories, thoughts, flags: Flags, shrink_speed=1
    ) -> None:
        """Pair consecutive observations with the action and memory between them.

        `history_obs` must hold exactly one more entry than `actions` and
        `memories`. `thoughts` is accepted but not used here.
        """
        super().__init__(visible=flags.use_history)
        assert len(history_obs) == len(actions) + 1
        assert len(history_obs) == len(memories) + 1

        self.shrink_speed = shrink_speed
        self.history_steps: list[HistoryStep] = [
            HistoryStep(
                history_obs[idx],
                history_obs[idx + 1],
                actions[idx],
                memories[idx],
                flags,
            )
            for idx in range(len(history_obs) - 1)
        ]

    def shrink(self):
        """Shrink individual steps"""
        # TODO set the shrink speed of older steps to be higher
        super().shrink()
        for past_step in self.history_steps:
            past_step.shrink()

    @property
    def _prompt(self):
        chunks = ['# History of interaction with the task:\n']
        for number, past_step in enumerate(self.history_steps):
            chunks.extend((f'## step {number}', past_step.prompt))
        return '\n'.join(chunks) + '\n'
if __name__ == '__main__':
    # Demo: render a prompt from three observations, two actions and two
    # memories, then print it.
    html_template = """
<html>
<body>
<div>
Hello World.
Step {}.
</div>
</body>
</html>
"""

    # Only the last observation carries an action error.
    step_errors = ['', '', 'Hey, there is an error now']
    OBS_HISTORY = [
        {
            'goal': 'do this and that',
            'pruned_html': html_template.format(step_number),
            'axtree_txt': '[1] Click me',
            'last_action_error': step_error,
        }
        for step_number, step_error in enumerate(step_errors, start=1)
    ]
    ACTIONS = ["click('41')", "click('42')"]
    MEMORIES = ['memory A', 'memory B']
    THOUGHTS = ['thought A', 'thought B']

    flags = Flags(
        use_html=True,
        use_ax_tree=True,
        use_thinking=True,
        use_error_logs=True,
        use_past_error_logs=True,
        use_history=True,
        use_action_history=True,
        use_memory=True,
        use_diff=True,
        html_type='pruned_html',
        use_concrete_example=True,
        use_abstract_example=True,
        use_screenshot=False,
        multi_actions=True,
    )

    main_prompt = MainPrompt(
        obs_history=OBS_HISTORY,
        actions=ACTIONS,
        memories=MEMORIES,
        thoughts=THOUGHTS,
        flags=flags,
    )
    print(main_prompt.prompt)
|