prompt.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. import os
  2. from itertools import islice
  3. from jinja2 import Template
  4. from openhands.controller.state.state import State
  5. from openhands.core.message import Message, TextContent
  6. from openhands.utils.microagent import MicroAgent
  7. class PromptManager:
  8. """
  9. Manages prompt templates and micro-agents for AI interactions.
  10. This class handles loading and rendering of system and user prompt templates,
  11. as well as loading micro-agent specifications. It provides methods to access
  12. rendered system and initial user messages for AI interactions.
  13. Attributes:
  14. prompt_dir (str): Directory containing prompt templates.
  15. microagent_dir (str): Directory containing microagent specifications.
  16. disabled_microagents (list[str] | None): List of microagents to disable. If None, all microagents are enabled.
  17. """
  18. def __init__(
  19. self,
  20. prompt_dir: str,
  21. microagent_dir: str | None = None,
  22. disabled_microagents: list[str] | None = None,
  23. ):
  24. self.prompt_dir: str = prompt_dir
  25. self.system_template: Template = self._load_template('system_prompt')
  26. self.user_template: Template = self._load_template('user_prompt')
  27. self.microagents: dict = {}
  28. microagent_files = []
  29. if microagent_dir:
  30. microagent_files = [
  31. os.path.join(microagent_dir, f)
  32. for f in os.listdir(microagent_dir)
  33. if f.endswith('.md')
  34. ]
  35. for microagent_file in microagent_files:
  36. microagent = MicroAgent(path=microagent_file)
  37. if (
  38. disabled_microagents is None
  39. or microagent.name not in disabled_microagents
  40. ):
  41. self.microagents[microagent.name] = microagent
  42. def load_microagent_files(self, microagent_files: list[str]):
  43. for microagent_file in microagent_files:
  44. microagent = MicroAgent(content=microagent_file)
  45. self.microagents[microagent.name] = microagent
  46. def _load_template(self, template_name: str) -> Template:
  47. if self.prompt_dir is None:
  48. raise ValueError('Prompt directory is not set')
  49. template_path = os.path.join(self.prompt_dir, f'{template_name}.j2')
  50. if not os.path.exists(template_path):
  51. raise FileNotFoundError(f'Prompt file {template_path} not found')
  52. with open(template_path, 'r') as file:
  53. return Template(file.read())
  54. def get_system_message(self) -> str:
  55. return self.system_template.render().strip()
  56. def get_example_user_message(self) -> str:
  57. """This is the initial user message provided to the agent
  58. before *actual* user instructions are provided.
  59. It is used to provide a demonstration of how the agent
  60. should behave in order to solve the user's task. And it may
  61. optionally contain some additional context about the user's task.
  62. These additional context will convert the current generic agent
  63. into a more specialized agent that is tailored to the user's task.
  64. """
  65. return self.user_template.render().strip()
  66. def enhance_message(self, message: Message) -> None:
  67. """Enhance the user message with additional context.
  68. This method is used to enhance the user message with additional context
  69. about the user's task. The additional context will convert the current
  70. generic agent into a more specialized agent that is tailored to the user's task.
  71. """
  72. if not message.content:
  73. return
  74. message_content = message.content[0].text
  75. for microagent in self.microagents.values():
  76. trigger = microagent.get_trigger(message_content)
  77. if trigger:
  78. micro_text = f'<extra_info>\nThe following information has been included based on a keyword match for "{trigger}". It may or may not be relevant to the user\'s request.'
  79. micro_text += '\n\n' + microagent.content
  80. micro_text += '\n</extra_info>'
  81. message.content.append(TextContent(text=micro_text))
  82. def add_turns_left_reminder(self, messages: list[Message], state: State) -> None:
  83. latest_user_message = next(
  84. islice(
  85. (
  86. m
  87. for m in reversed(messages)
  88. if m.role == 'user'
  89. and any(isinstance(c, TextContent) for c in m.content)
  90. ),
  91. 1,
  92. ),
  93. None,
  94. )
  95. if latest_user_message:
  96. reminder_text = f'\n\nENVIRONMENT REMINDER: You have {state.max_iterations - state.iteration} turns left to complete the task. When finished reply with <finish></finish>.'
  97. latest_user_message.content.append(TextContent(text=reminder_text))