utils.py

import collections
import re
from warnings import warn

import yaml


def yaml_parser(message):
    """Parse a yaml message for the retry function."""
    # Saves gpt-3.5 from some yaml parsing errors: re-join a key and a value
    # that was pushed to the next line.
    message = re.sub(r':\s*\n(?=\S|\n)', ': ', message)
    try:
        value = yaml.safe_load(message)
        valid = True
        retry_message = ''
    except yaml.YAMLError as e:
        warn(str(e), stacklevel=2)
        value = {}
        valid = False
        retry_message = "Your response is not valid YAML. Please try again and be careful with the format. Don't add any apology or comment, just the answer."
    return value, valid, retry_message
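

# Illustrative usage sketch, not part of the original module: retrying a model
# call until its reply parses as YAML. `query_model` is a hypothetical
# stand-in for whatever LLM call the caller uses.
def _example_yaml_retry(query_model, prompt, max_retries=3):
    """Sketch: re-query the model until its reply is valid YAML."""
    message = query_model(prompt)
    for _ in range(max_retries):
        value, valid, retry_message = yaml_parser(message)
        if valid:
            return value
        # Feed the retry message back to the model and try again
        message = query_model(prompt + '\n' + retry_message)
    raise ValueError('Model never produced valid YAML.')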


def _compress_chunks(text, identifier, skip_list, split_regex='\n\n+'):
    """Compress a string by replacing redundant chunks with identifiers. Chunks are defined by the split_regex."""
    text_list = re.split(split_regex, text)
    text_list = [chunk.strip() for chunk in text_list]
    counter = collections.Counter(text_list)
    def_dict = {}
    chunk_id = 0

    # Store chunks that occur more than once in a dictionary
    for item, count in counter.items():
        if count > 1 and item not in skip_list and len(item) > 10:
            def_dict[f'{identifier}-{chunk_id}'] = item
            chunk_id += 1

    # Replace redundant chunks with their identifiers in the text
    compressed_text = '\n'.join(text_list)
    for key, value in def_dict.items():
        compressed_text = compressed_text.replace(value, key)

    return def_dict, compressed_text
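

# Illustrative example, not part of the original module: with the default
# paragraph regex, a repeated chunk longer than 10 characters is moved to
# def_dict and replaced by an identifier. Note that chunks are re-joined with
# single newlines.
def _example_compress_chunks():
    def_dict, compressed = _compress_chunks(
        'repeated paragraph\n\nunique text\n\nrepeated paragraph',
        identifier='§', skip_list=[],
    )
    assert def_dict == {'§-0': 'repeated paragraph'}
    assert compressed == '§-0\nunique text\n§-0'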


def compress_string(text):
    """Compress a string by replacing redundant paragraphs and lines with identifiers."""
    # Perform paragraph-level compression
    def_dict, compressed_text = _compress_chunks(
        text, identifier='§', skip_list=[], split_regex='\n\n+'
    )

    # Perform line-level compression, skipping any paragraph identifiers
    line_dict, compressed_text = _compress_chunks(
        compressed_text, '¶', list(def_dict.keys()), split_regex='\n+'
    )
    def_dict.update(line_dict)

    # Create a definitions section
    def_lines = ['<definitions>']
    for key, value in def_dict.items():
        def_lines.append(f'{key}:\n{value}')
    def_lines.append('</definitions>')
    definitions = '\n'.join(def_lines)

    return definitions + '\n' + compressed_text
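

# Illustrative example, not part of the original module: repeated paragraphs
# are moved to the <definitions> block and replaced by '§-...' identifiers in
# the body (repeated single lines would get '¶-...' identifiers).
def _example_compress_string():
    text = 'Accept all cookies\n\nactual content\n\nAccept all cookies'
    compressed = compress_string(text)
    assert '§-0:\nAccept all cookies' in compressed
    assert compressed.endswith('§-0\nactual content\n§-0')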


def extract_html_tags(text, keys):
    """Extract the content within HTML tags for a list of keys.

    Parameters
    ----------
    text : str
        The input string containing the HTML tags.
    keys : list of str
        The HTML tags to extract the content from.

    Returns
    -------
    dict
        A dictionary mapping each key to a list of substrings of `text` that
        match the key.

    Notes
    -----
    Matching is case-sensitive.
    """
    content_dict = {}
    for key in keys:
        pattern = f'<{key}>(.*?)</{key}>'
        matches = re.findall(pattern, text, re.DOTALL)
        if matches:
            content_dict[key] = [match.strip() for match in matches]
    return content_dict
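

# Illustrative example, not part of the original module: extraction tolerates
# newlines inside a tag (re.DOTALL) and strips surrounding whitespace.
def _example_extract_html_tags():
    text = '<action>click</action>\n<action>\nscroll\n</action>'
    assert extract_html_tags(text, ['action']) == {'action': ['click', 'scroll']}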


class ParseError(Exception):
    pass


def parse_html_tags_raise(text, keys=(), optional_keys=(), merge_multiple=False):
    """A version of parse_html_tags that raises a ParseError if the parsing is not successful."""
    content_dict, valid, retry_message = parse_html_tags(
        text, keys, optional_keys, merge_multiple=merge_multiple
    )
    if not valid:
        raise ParseError(retry_message)
    return content_dict
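

# Illustrative usage sketch, not part of the original module: a retry loop
# built on ParseError. `query_model` is a hypothetical stand-in for the
# caller's LLM call.
def _example_parse_retry(query_model, prompt, max_retries=3):
    """Sketch: re-query the model until its reply contains the required tags."""
    for _ in range(max_retries):
        try:
            return parse_html_tags_raise(query_model(prompt), keys=['action'])
        except ParseError as e:
            # Append the error message so the model can correct its format
            prompt = prompt + '\n' + str(e)
    raise ParseError('Model never produced the required tags.')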


def parse_html_tags(text, keys=(), optional_keys=(), merge_multiple=False):
    """Satisfy the parse API: extract one match per key and validate that all required keys are present.

    Parameters
    ----------
    text : str
        The input string containing the HTML tags.
    keys : list of str
        The HTML tags to extract the content from.
    optional_keys : list of str
        The HTML tags to extract the content from, but whose absence is not an error.
    merge_multiple : bool
        If True, join multiple instances of a key with newlines instead of flagging them as an error.

    Returns
    -------
    dict
        A dictionary mapping each key to the substring of `text` that matches the key.
    bool
        Whether the parsing was successful.
    str
        A message to be displayed to the agent if the parsing was not successful.
    """
    all_keys = tuple(keys) + tuple(optional_keys)
    content_dict = extract_html_tags(text, all_keys)
    retry_messages = []

    for key in all_keys:
        if key not in content_dict:
            if key not in optional_keys:
                retry_messages.append(f'Missing the key <{key}> in the answer.')
        else:
            val = content_dict[key]
            content_dict[key] = val[0]
            if len(val) > 1:
                if not merge_multiple:
                    retry_messages.append(
                        f'Found multiple instances of the key {key}. You should have only one of them.'
                    )
                else:
                    # Merge the multiple instances into a single newline-joined value
                    content_dict[key] = '\n'.join(val)

    valid = len(retry_messages) == 0
    retry_message = '\n'.join(retry_messages)
    return content_dict, valid, retry_message
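

# Illustrative example, not part of the original module: the three return
# values, with and without a missing required key.
def _example_parse_html_tags():
    text = '<think>look around</think>\n<action>click(42)</action>'
    content, valid, retry = parse_html_tags(text, keys=['action'], optional_keys=['think'])
    assert valid and content == {'action': 'click(42)', 'think': 'look around'}
    # A missing required key flips `valid` and fills the retry message:
    content, valid, retry = parse_html_tags('<think>hmm</think>', keys=['action'])
    assert not valid and 'Missing the key <action>' in retry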