prompt.py

import abc
import difflib
import logging
import platform
from copy import deepcopy
from dataclasses import asdict, dataclass
from textwrap import dedent
from typing import Literal, Union
from warnings import warn

from browsergym.core.action.base import AbstractActionSet
from browsergym.core.action.highlevel import HighLevelActionSet
from browsergym.core.action.python import PythonActionSet

from openhands.agenthub.browsing_agent.utils import (
    ParseError,
    parse_html_tags_raise,
)
from openhands.runtime.browser.browser_env import BrowserEnv

@dataclass
class Flags:
    use_html: bool = True
    use_ax_tree: bool = False
    drop_ax_tree_first: bool = True  # This flag is no longer active TODO delete
    use_thinking: bool = False
    use_error_logs: bool = False
    use_past_error_logs: bool = False
    use_history: bool = False
    use_action_history: bool = False
    use_memory: bool = False
    use_diff: bool = False
    html_type: str = 'pruned_html'
    use_concrete_example: bool = True
    use_abstract_example: bool = False
    multi_actions: bool = False
    action_space: Literal[
        'python', 'bid', 'coord', 'bid+coord', 'bid+nav', 'coord+nav', 'bid+coord+nav'
    ] = 'bid'
    is_strict: bool = False
    # This flag will be automatically disabled `if not chat_model_args.has_vision()`
    use_screenshot: bool = True
    enable_chat: bool = False
    max_prompt_tokens: int = 100_000
    extract_visible_tag: bool = False
    extract_coords: Literal['False', 'center', 'box'] = 'False'
    extract_visible_elements_only: bool = False
    demo_mode: Literal['off', 'default', 'only_visible_elements'] = 'off'

    def copy(self):
        return deepcopy(self)

    def asdict(self):
        """Helper for the JSON-serializable requirement."""
        return asdict(self)

    @classmethod
    def from_dict(cls, flags_dict):
        """Helper for the JSON-serializable requirement."""
        if isinstance(flags_dict, Flags):
            return flags_dict
        if not isinstance(flags_dict, dict):
            raise ValueError(
                f'Unrecognized type for flags_dict of type {type(flags_dict)}.'
            )
        return Flags(**flags_dict)

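# --- Illustrative example (not part of the original module) -----------------
# Flags is a plain dataclass, so it round-trips through asdict()/from_dict(),
# which is what keeps agent configs JSON-serializable. A minimal sketch:
#
#   flags = Flags(use_ax_tree=True, use_screenshot=False)
#   restored = Flags.from_dict(flags.asdict())
#   assert restored == flags                 # dataclass equality is field-wise
#   assert Flags.from_dict(flags) is flags   # passing a Flags instance is a no-op
# -----------------------------------------------------------------------------
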
class PromptElement:
    """Base class for all prompt elements. Prompt elements can be hidden.

    Prompt elements are used to build the prompt. Use flags to control which
    prompt elements are visible. We use class attributes as a convenient way
    to implement static prompts, but feel free to override them with instance
    attributes or the @property decorator.
    """

    _prompt = ''
    _abstract_ex = ''
    _concrete_ex = ''

    def __init__(self, visible: bool = True) -> None:
        """Prompt element that can be hidden.

        Parameters
        ----------
        visible : bool, optional
            Whether the prompt element should be visible, by default True. Can
            be a callable that returns a bool. This is useful when a specific
            flag changes during a shrink iteration.
        """
        self._visible = visible

    @property
    def prompt(self):
        """Avoid overriding this method. Override _prompt instead."""
        return self._hide(self._prompt)

    @property
    def abstract_ex(self):
        """Useful when this prompt element is requesting an answer from the LLM.
        Provide an abstract example of the answer here. See Memory for an
        example.

        Avoid overriding this method. Override _abstract_ex instead.
        """
        return self._hide(self._abstract_ex)

    @property
    def concrete_ex(self):
        """Useful when this prompt element is requesting an answer from the LLM.
        Provide a concrete example of the answer here. See Memory for an
        example.

        Avoid overriding this method. Override _concrete_ex instead.
        """
        return self._hide(self._concrete_ex)

    @property
    def is_visible(self):
        """Handle the case where visible is a callable."""
        visible = self._visible
        if callable(visible):
            visible = visible()
        return visible

    def _hide(self, value):
        """Return value if visible is True, else return an empty string."""
        if self.is_visible:
            return value
        else:
            return ''

    def _parse_answer(self, text_answer) -> dict:
        if self.is_visible:
            return self._parse_answer(text_answer)
        else:
            return {}

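# --- Illustrative example (not part of the original module) -----------------
# A PromptElement only contributes text when it is visible; `visible` may also
# be a callable, so visibility can depend on a flag that changes while
# shrinking. A minimal sketch using a hypothetical subclass:
#
#   class Hint(PromptElement):
#       def __init__(self, visible=True):
#           super().__init__(visible)
#           self._prompt = '\nHint: prefer bid-based actions.\n'
#
#   Hint(visible=False).prompt           # -> '' (hidden)
#   Hint(visible=lambda: True).prompt    # -> '\nHint: prefer bid-based actions.\n'
# -----------------------------------------------------------------------------
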
class Shrinkable(PromptElement, abc.ABC):
    @abc.abstractmethod
    def shrink(self) -> None:
        """Implement shrinking of this prompt element.

        You need to recursively call all shrinkable elements that are part of
        this prompt. You can also implement a shrinking strategy for this prompt.
        Shrinking can be called multiple times to progressively shrink the
        prompt until it fits max_tokens. The default max shrink iterations is 20.
        """
        pass

class Truncater(Shrinkable):
    """A prompt element that can be truncated to fit the context length of the LLM.

    Ideally, we would never have to use the functionality here to `shrink()` the prompt.
    Extend this class for prompt elements that can be truncated, usually long observations such as the AXTree or HTML.
    """

    def __init__(self, visible, shrink_speed=0.3, start_truncate_iteration=10):
        super().__init__(visible=visible)
        self.shrink_speed = shrink_speed  # the fraction of lines removed in each iteration
        self.start_truncate_iteration = (
            start_truncate_iteration  # the iteration at which to start truncating
        )
        self.shrink_calls = 0
        self.deleted_lines = 0

    def shrink(self) -> None:
        if self.is_visible and self.shrink_calls >= self.start_truncate_iteration:
            # remove a fraction of _prompt
            lines = self._prompt.splitlines()
            new_line_count = int(len(lines) * (1 - self.shrink_speed))
            self.deleted_lines += len(lines) - new_line_count
            self._prompt = '\n'.join(lines[:new_line_count])
            self._prompt += (
                f'\n... Deleted {self.deleted_lines} lines to reduce prompt size.'
            )
        self.shrink_calls += 1

def fit_tokens(
    shrinkable: Shrinkable,
    max_prompt_chars=None,
    max_iterations=20,
):
    """Shrink a prompt element until it fits max_prompt_chars.

    Parameters
    ----------
    shrinkable : Shrinkable
        The prompt element to shrink.
    max_prompt_chars : int, optional
        The maximum number of chars allowed. If None, the prompt is returned as is.
    max_iterations : int, optional
        The maximum number of shrink iterations, by default 20.

    Returns
    -------
    str : the prompt after shrinking.
    """
    if max_prompt_chars is None:
        return shrinkable.prompt

    for _ in range(max_iterations):
        prompt = shrinkable.prompt
        if isinstance(prompt, str):
            prompt_str = prompt
        elif isinstance(prompt, list):
            prompt_str = '\n'.join([p['text'] for p in prompt if p['type'] == 'text'])
        else:
            raise ValueError(f'Unrecognized type for prompt: {type(prompt)}')
        n_chars = len(prompt_str)
        if n_chars <= max_prompt_chars:
            return prompt
        shrinkable.shrink()

    logging.info(
        dedent(
            f"""\
            After {max_iterations} shrink iterations, the prompt is still
            {len(prompt_str)} chars (greater than {max_prompt_chars}). Returning the prompt as is."""
        )
    )
    return prompt

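# --- Illustrative example (not part of the original module) -----------------
# Despite its name, fit_tokens() budgets by characters: it repeatedly calls
# shrink() until the rendered prompt is at most max_prompt_chars or the
# iteration limit is reached. A rough sketch using the HTML Truncater below
# (truncation only kicks in after its start_truncate_iteration):
#
#   big_html = HTML('\n'.join(f'<p>row {i}</p>' for i in range(2_000)), visible=True)
#   small = fit_tokens(big_html, max_prompt_chars=5_000)
#   # once enough shrink iterations have run, `small` ends with
#   # '... Deleted N lines to reduce prompt size.'
# -----------------------------------------------------------------------------
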
class HTML(Truncater):
    def __init__(self, html, visible: bool = True, prefix='') -> None:
        super().__init__(visible=visible, start_truncate_iteration=5)
        self._prompt = f'\n{prefix}HTML:\n{html}\n'

class AXTree(Truncater):
    def __init__(
        self, ax_tree, visible: bool = True, coord_type=None, prefix=''
    ) -> None:
        super().__init__(visible=visible, start_truncate_iteration=10)
        if coord_type == 'center':
            coord_note = """\
Note: center coordinates are provided in parentheses and are
relative to the top left corner of the page.\n\n"""
        elif coord_type == 'box':
            coord_note = """\
Note: the bounding box of each object is provided in parentheses and is
relative to the top left corner of the page.\n\n"""
        else:
            coord_note = ''
        self._prompt = f'\n{prefix}AXTree:\n{coord_note}{ax_tree}\n'

class Error(PromptElement):
    def __init__(self, error, visible: bool = True, prefix='') -> None:
        super().__init__(visible=visible)
        self._prompt = f'\n{prefix}Error from previous action:\n{error}\n'

class Observation(Shrinkable):
    """Observation of the current step.

    Contains the html, the accessibility tree and the error logs.
    """

    def __init__(self, obs, flags: Flags) -> None:
        super().__init__()
        self.flags = flags
        self.obs = obs
        self.html = HTML(obs[flags.html_type], visible=flags.use_html, prefix='## ')
        self.ax_tree = AXTree(
            obs['axtree_txt'],
            visible=flags.use_ax_tree,
            coord_type=flags.extract_coords,
            prefix='## ',
        )
        self.error = Error(
            obs['last_action_error'],
            visible=flags.use_error_logs and obs['last_action_error'],
            prefix='## ',
        )

    def shrink(self):
        self.ax_tree.shrink()
        self.html.shrink()

    @property
    def _prompt(self) -> str:  # type: ignore
        return f'\n# Observation of current step:\n{self.html.prompt}{self.ax_tree.prompt}{self.error.prompt}\n\n'

    def add_screenshot(self, prompt):
        if self.flags.use_screenshot:
            if isinstance(prompt, str):
                prompt = [{'type': 'text', 'text': prompt}]
            img_url = BrowserEnv.image_to_jpg_base64_url(
                self.obs['screenshot'], add_data_prefix=True
            )
            prompt.append({'type': 'image_url', 'image_url': img_url})
        return prompt

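# --- Illustrative note (not part of the original module) --------------------
# add_screenshot() turns a plain-text prompt into the multimodal message shape:
# a list with one {'type': 'text', ...} part plus one {'type': 'image_url', ...}
# part. Assuming flags.use_screenshot is True and image_to_jpg_base64_url
# returns a data: URL (an assumption based on add_data_prefix=True), the
# result looks roughly like:
#
#   [
#       {'type': 'text', 'text': '# Instructions\n...'},
#       {'type': 'image_url', 'image_url': 'data:image/jpeg;base64,...'},
#   ]
# -----------------------------------------------------------------------------
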
class MacNote(PromptElement):
    def __init__(self) -> None:
        super().__init__(visible=platform.system() == 'Darwin')
        self._prompt = '\nNote: you are on a Mac, so you should use Meta instead of Control for Control+C, etc.\n'

class BeCautious(PromptElement):
    def __init__(self, visible: bool = True) -> None:
        super().__init__(visible=visible)
        self._prompt = """\
\nBe very cautious. Avoid submitting anything before verifying the effect of your
actions. Take the time to explore the effect of safe actions first. For example
you can fill a few elements of a form, but don't click submit before verifying
that everything was filled correctly.\n"""

class GoalInstructions(PromptElement):
    def __init__(self, goal, visible: bool = True) -> None:
        super().__init__(visible)
        self._prompt = f"""\
# Instructions
Review the current state of the page and all other information to find the best
possible next action to accomplish your goal. Your answer will be interpreted
and executed by a program, make sure to follow the formatting instructions.

## Goal:
{goal}
"""

class ChatInstructions(PromptElement):
    def __init__(self, chat_messages, visible: bool = True) -> None:
        super().__init__(visible)
        self._prompt = """\
# Instructions
You are a UI Assistant, your goal is to help the user perform tasks using a web browser. You can
communicate with the user via a chat, in which the user gives you instructions and in which you
can send back messages. You have access to a web browser that both you and the user can see,
and with which only you can interact via specific commands.

Review the instructions from the user, the current state of the page and all other information
to find the best possible next action to accomplish your goal. Your answer will be interpreted
and executed by a program, make sure to follow the formatting instructions.

## Chat messages:
"""
        self._prompt += '\n'.join(
            [
                f"""\
- [{msg['role']}], {msg['message']}"""
                for msg in chat_messages
            ]
        )

class SystemPrompt(PromptElement):
    _prompt = """\
You are an agent trying to solve a web task based on the content of the page and
user instructions. You can interact with the page and explore. Each time you
submit an action it will be sent to the browser and you will receive a new page."""

class MainPrompt(Shrinkable):
    def __init__(
        self,
        obs_history,
        actions,
        memories,
        thoughts,
        flags: Flags,
    ) -> None:
        super().__init__()
        self.flags = flags
        self.history = History(obs_history, actions, memories, thoughts, flags)
        if self.flags.enable_chat:
            self.instructions: Union[ChatInstructions, GoalInstructions] = (
                ChatInstructions(obs_history[-1]['chat_messages'])
            )
        else:
            if (
                'chat_messages' in obs_history[-1]
                and sum(
                    [msg['role'] == 'user' for msg in obs_history[-1]['chat_messages']]
                )
                > 1
            ):
                logging.warning(
                    'Agent is in goal mode, but multiple user messages are present in the chat. Consider switching to `enable_chat=True`.'
                )
            self.instructions = GoalInstructions(obs_history[-1]['goal'])

        self.obs = Observation(obs_history[-1], self.flags)
        self.action_space = ActionSpace(self.flags)
        self.think = Think(visible=flags.use_thinking)
        self.memory = Memory(visible=flags.use_memory)

    @property
    def _prompt(self) -> str:  # type: ignore
        prompt = f"""\
{self.instructions.prompt}\
{self.obs.prompt}\
{self.history.prompt}\
{self.action_space.prompt}\
{self.think.prompt}\
{self.memory.prompt}\
"""

        if self.flags.use_abstract_example:
            prompt += f"""
# Abstract Example

Here is an abstract version of the answer with description of the content of
each tag. Make sure you follow this structure, but replace the content with your
answer:
{self.think.abstract_ex}\
{self.memory.abstract_ex}\
{self.action_space.abstract_ex}\
"""

        if self.flags.use_concrete_example:
            prompt += f"""
# Concrete Example

Here is a concrete example of how to format your answer.
Make sure to follow the template with proper tags:
{self.think.concrete_ex}\
{self.memory.concrete_ex}\
{self.action_space.concrete_ex}\
"""
        return self.obs.add_screenshot(prompt)

    def shrink(self):
        self.history.shrink()
        self.obs.shrink()

    def _parse_answer(self, text_answer):
        ans_dict = {}
        ans_dict.update(self.think._parse_answer(text_answer))
        ans_dict.update(self.memory._parse_answer(text_answer))
        ans_dict.update(self.action_space._parse_answer(text_answer))
        return ans_dict

class ActionSpace(PromptElement):
    def __init__(self, flags: Flags) -> None:
        super().__init__()
        self.flags = flags
        self.action_space = _get_action_space(flags)

        self._prompt = (
            f'# Action space:\n{self.action_space.describe()}{MacNote().prompt}\n'
        )
        self._abstract_ex = f"""
<action>
{self.action_space.example_action(abstract=True)}
</action>
"""
        self._concrete_ex = f"""
<action>
{self.action_space.example_action(abstract=False)}
</action>
"""

    def _parse_answer(self, text_answer):
        ans_dict = parse_html_tags_raise(
            text_answer, keys=['action'], merge_multiple=True
        )

        try:
            # just check that the action can be mapped to python code, but keep the action as is;
            # the environment will be responsible for mapping it to python
            self.action_space.to_python_code(ans_dict['action'])
        except Exception as e:
            raise ParseError(
                f'Error while parsing action:\n{e}\n'
                'Make sure your answer is restricted to the allowed actions.'
            )

        return ans_dict

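# --- Illustrative example (not part of the original module) -----------------
# ActionSpace._parse_answer() expects the action wrapped in <action> tags and
# validates it by attempting to map it to python code, without executing it;
# anything to_python_code() cannot map is re-raised as ParseError. A rough
# sketch, assuming the default 'bid' action space:
#
#   space = ActionSpace(Flags(action_space='bid'))
#   space._parse_answer("<action>click('12')</action>")
#   # -> roughly {'action': "click('12')"}
# -----------------------------------------------------------------------------
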
def _get_action_space(flags: Flags) -> AbstractActionSet:
    match flags.action_space:
        case 'python':
            action_space = PythonActionSet(strict=flags.is_strict)
            if flags.multi_actions:
                warn(
                    f'Flag action_space={repr(flags.action_space)} incompatible with multi_actions={repr(flags.multi_actions)}.',
                    stacklevel=2,
                )
            if flags.demo_mode != 'off':
                warn(
                    f'Flag action_space={repr(flags.action_space)} incompatible with demo_mode={repr(flags.demo_mode)}.',
                    stacklevel=2,
                )
            return action_space
        case 'bid':
            action_subsets = ['chat', 'bid']
        case 'coord':
            action_subsets = ['chat', 'coord']
        case 'bid+coord':
            action_subsets = ['chat', 'bid', 'coord']
        case 'bid+nav':
            action_subsets = ['chat', 'bid', 'nav']
        case 'coord+nav':
            action_subsets = ['chat', 'coord', 'nav']
        case 'bid+coord+nav':
            action_subsets = ['chat', 'bid', 'coord', 'nav']
        case _:
            raise NotImplementedError(
                f'Unknown action_space {repr(flags.action_space)}'
            )

    action_space = HighLevelActionSet(
        subsets=action_subsets,
        multiaction=flags.multi_actions,
        strict=flags.is_strict,
        demo_mode=flags.demo_mode,
    )

    return action_space

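# --- Illustrative note (not part of the original module) --------------------
# The action_space flag maps onto BrowserGym action sets: 'python' selects the
# raw PythonActionSet, while every other value selects HighLevelActionSet
# subsets, with 'chat' always included. A minimal sketch:
#
#   space = _get_action_space(Flags(action_space='bid+nav', multi_actions=True))
#   print(space.describe())                       # lists chat, bid and nav actions
#   print(space.example_action(abstract=False))   # same call used by ActionSpace above
# -----------------------------------------------------------------------------
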
class Memory(PromptElement):
    _prompt = ''  # provided in the abstract and concrete examples

    _abstract_ex = """
<memory>
Write down anything you need to remember for next steps. You will be presented
with the list of previous memories and past actions.
</memory>
"""
    _concrete_ex = """
<memory>
I clicked on bid 32 to activate tab 2. The accessibility tree should mention
focusable for elements of the form at next step.
</memory>
"""

    def _parse_answer(self, text_answer):
        return parse_html_tags_raise(
            text_answer, optional_keys=['memory'], merge_multiple=True
        )

class Think(PromptElement):
    _prompt = ''

    _abstract_ex = """
<think>
Think step by step. If you need to make calculations such as coordinates, write them here. Describe the effect
that your previous action had on the current content of the page.
</think>
"""
    _concrete_ex = """
<think>
My memory says that I filled the first name and last name, but I can't see any
content in the form. I need to explore different ways to fill the form. Perhaps
the form is not visible yet or some fields are disabled. I need to replan.
</think>
"""

    def _parse_answer(self, text_answer):
        return parse_html_tags_raise(
            text_answer, optional_keys=['think'], merge_multiple=True
        )

def diff(previous, new):
    """Return a (header, diff_lines) pair describing the difference between previous and new.

    The header summarizes how many lines were added and removed; diff_lines
    contains only the added and removed lines from the ndiff output.
    """
    if previous == new:
        return 'Identical', []

    if previous is None or len(previous) == 0:
        return 'previous is empty', []

    diff_gen = difflib.ndiff(previous.splitlines(), new.splitlines())

    diff_lines = []
    plus_count = 0
    minus_count = 0
    for line in diff_gen:
        if line.strip().startswith('+'):
            diff_lines.append(line)
            plus_count += 1
        elif line.strip().startswith('-'):
            diff_lines.append(line)
            minus_count += 1
        else:
            continue

    header = f'{plus_count} lines added and {minus_count} lines removed:'
    return header, diff_lines

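# --- Illustrative example (not part of the original module) -----------------
# diff() keeps only the added/removed lines from difflib.ndiff; unchanged lines
# are dropped. A minimal sketch:
#
#   header, lines = diff('a\nb\nc', 'a\nB\nc')
#   # header -> '1 lines added and 1 lines removed:'
#   # lines  -> roughly ['- b', '+ B']
# -----------------------------------------------------------------------------
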
class Diff(Shrinkable):
    def __init__(
        self, previous, new, prefix='', max_line_diff=20, shrink_speed=2, visible=True
    ) -> None:
        super().__init__(visible=visible)
        self.max_line_diff = max_line_diff
        self.header, self.diff_lines = diff(previous, new)
        self.shrink_speed = shrink_speed
        self.prefix = prefix

    def shrink(self):
        self.max_line_diff -= self.shrink_speed
        self.max_line_diff = max(1, self.max_line_diff)

    @property
    def _prompt(self) -> str:  # type: ignore
        diff_str = '\n'.join(self.diff_lines[: self.max_line_diff])
        if len(self.diff_lines) > self.max_line_diff:
            original_count = len(self.diff_lines)
            diff_str = f'{diff_str}\nDiff truncated, {original_count - self.max_line_diff} changes not shown.'
        return f'{self.prefix}{self.header}\n{diff_str}\n'

class HistoryStep(Shrinkable):
    def __init__(
        self, previous_obs, current_obs, action, memory, flags: Flags, shrink_speed=1
    ) -> None:
        super().__init__()
        self.html_diff = Diff(
            previous_obs[flags.html_type],
            current_obs[flags.html_type],
            prefix='\n### HTML diff:\n',
            shrink_speed=shrink_speed,
            visible=lambda: flags.use_html and flags.use_diff,
        )
        self.ax_tree_diff = Diff(
            previous_obs['axtree_txt'],
            current_obs['axtree_txt'],
            prefix='\n### Accessibility tree diff:\n',
            shrink_speed=shrink_speed,
            visible=lambda: flags.use_ax_tree and flags.use_diff,
        )
        self.error = Error(
            current_obs['last_action_error'],
            visible=(
                flags.use_error_logs
                and current_obs['last_action_error']
                and flags.use_past_error_logs
            ),
            prefix='### ',
        )
        self.shrink_speed = shrink_speed
        self.action = action
        self.memory = memory
        self.flags = flags

    def shrink(self):
        super().shrink()
        self.html_diff.shrink()
        self.ax_tree_diff.shrink()

    @property
    def _prompt(self) -> str:  # type: ignore
        prompt = ''

        if self.flags.use_action_history:
            prompt += f'\n### Action:\n{self.action}\n'

        prompt += (
            f'{self.error.prompt}{self.html_diff.prompt}{self.ax_tree_diff.prompt}'
        )

        if self.flags.use_memory and self.memory is not None:
            prompt += f'\n### Memory:\n{self.memory}\n'

        return prompt

class History(Shrinkable):
    def __init__(
        self, history_obs, actions, memories, thoughts, flags: Flags, shrink_speed=1
    ) -> None:
        super().__init__(visible=flags.use_history)
        assert len(history_obs) == len(actions) + 1
        assert len(history_obs) == len(memories) + 1

        self.shrink_speed = shrink_speed
        self.history_steps: list[HistoryStep] = []

        for i in range(1, len(history_obs)):
            self.history_steps.append(
                HistoryStep(
                    history_obs[i - 1],
                    history_obs[i],
                    actions[i - 1],
                    memories[i - 1],
                    flags,
                )
            )

    def shrink(self):
        """Shrink individual steps."""
        # TODO set the shrink speed of older steps to be higher
        super().shrink()
        for step in self.history_steps:
            step.shrink()

    @property
    def _prompt(self):
        prompts = ['# History of interaction with the task:\n']
        for i, step in enumerate(self.history_steps):
            prompts.append(f'## step {i}')
            prompts.append(step.prompt)
        return '\n'.join(prompts) + '\n'

if __name__ == '__main__':
    html_template = """
    <html>
    <body>
    <div>
    Hello World.
    Step {}.
    </div>
    </body>
    </html>
    """

    OBS_HISTORY = [
        {
            'goal': 'do this and that',
            'pruned_html': html_template.format(1),
            'axtree_txt': '[1] Click me',
            'last_action_error': '',
        },
        {
            'goal': 'do this and that',
            'pruned_html': html_template.format(2),
            'axtree_txt': '[1] Click me',
            'last_action_error': '',
        },
        {
            'goal': 'do this and that',
            'pruned_html': html_template.format(3),
            'axtree_txt': '[1] Click me',
            'last_action_error': 'Hey, there is an error now',
        },
    ]
    ACTIONS = ["click('41')", "click('42')"]
    MEMORIES = ['memory A', 'memory B']
    THOUGHTS = ['thought A', 'thought B']

    flags = Flags(
        use_html=True,
        use_ax_tree=True,
        use_thinking=True,
        use_error_logs=True,
        use_past_error_logs=True,
        use_history=True,
        use_action_history=True,
        use_memory=True,
        use_diff=True,
        html_type='pruned_html',
        use_concrete_example=True,
        use_abstract_example=True,
        use_screenshot=False,
        multi_actions=True,
    )

    print(
        MainPrompt(
            obs_history=OBS_HISTORY,
            actions=ACTIONS,
            memories=MEMORIES,
            thoughts=THOUGHTS,
            flags=flags,
        ).prompt
    )