ast_eval_hf.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. # Copyright 2023 https://github.com/ShishirPatil/gorilla
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # This file is modified from https://github.com/ShishirPatil/gorilla/blob/main/eval/eval-scripts/ast_eval_hf.py
  15. from tree_sitter import Language, Parser
  16. # Get all the subtrees given a root_node
  17. def get_all_sub_trees(root_node):
  18. node_stack = []
  19. sub_tree_sexp_list = []
  20. depth = 1
  21. # text = root_node.text
  22. node_stack.append([root_node, depth])
  23. while len(node_stack) != 0:
  24. cur_node, cur_depth = node_stack.pop()
  25. if cur_node.child_count > 0:
  26. sub_tree_sexp_list.append(
  27. [cur_node.sexp(), cur_depth, cur_node, cur_node.children[0].text]
  28. )
  29. else:
  30. sub_tree_sexp_list.append([cur_node.sexp(), cur_depth, cur_node, None])
  31. for child_node in cur_node.children:
  32. if len(child_node.children) != 0:
  33. depth = cur_depth + 1
  34. node_stack.append([child_node, depth])
  35. return sub_tree_sexp_list
  36. # Parse the program into AST trees
  37. def ast_parse(candidate, lang='python'):
  38. LANGUAGE = Language('evaluation/gorilla/my-languages.so', lang)
  39. parser = Parser()
  40. parser.set_language(LANGUAGE)
  41. candidate_tree = parser.parse(bytes(candidate, 'utf8')).root_node
  42. return candidate_tree
  43. # Get all the arguments in the ast tree
  44. def get_args(node):
  45. if node.child_count == 0:
  46. return []
  47. args_list = []
  48. for child in node.children[0].children[0].children[1].children:
  49. if '=' in child.text.decode():
  50. args_list.append(child.children[2].text)
  51. elif (
  52. child.text.decode() != '('
  53. and child.text.decode() != ')'
  54. and child.text.decode() != ','
  55. ):
  56. args_list.append(child.text)
  57. return args_list
  58. # Check if there is an api match
  59. def ast_check(candidate_subtree_list, base_tree_list):
  60. for idx, base_tree in enumerate(base_tree_list):
  61. if base_tree.children[0].children[0].child_count == 0:
  62. continue
  63. api_name = base_tree.children[0].children[0].children[0].text
  64. for candidate_tree in candidate_subtree_list:
  65. if candidate_tree[3] == api_name:
  66. break
  67. # Now we have a sub-tree
  68. candidate_tree = candidate_tree[2]
  69. args_list = get_args(base_tree)
  70. if len(args_list) == 0:
  71. continue
  72. ast_match = True
  73. for arg in args_list:
  74. if arg.decode().lstrip("'").rstrip("'") not in candidate_tree.text.decode():
  75. ast_match = False
  76. break
  77. if ast_match:
  78. return idx
  79. return -1
  80. def ast_eval_hf(api_database, qa_pairs, ast_database, question_id, response):
  81. # Check correctness
  82. correct = False
  83. hallucination = False
  84. output = response
  85. # Index the "api_call" domain
  86. output = output.split('api_call')
  87. if len(output) == 1:
  88. api_call = output[0]
  89. else:
  90. # Parse the output
  91. output = output[1].split('api_provider')[0]
  92. if ':' not in output:
  93. start = 0
  94. else:
  95. start = output.index(':')
  96. if ')' not in output:
  97. end = -2
  98. else:
  99. end = output.rindex(')')
  100. api_call = output[start + 2 : end + 1]
  101. # Parse the api_call into AST tree
  102. ast_tree = ast_parse(api_call)
  103. # Search for a subtree
  104. ast_subtree_list = get_all_sub_trees(ast_tree)
  105. # Check which ast tree is matching
  106. database_index = ast_check(ast_subtree_list, ast_database)
  107. # We cannot index this ast in our database
  108. if database_index == -1:
  109. hallucination = True
  110. # We index our reference api_call
  111. ref_api_call = api_database[database_index]
  112. # Check for functionality
  113. if ref_api_call['domain'] == qa_pairs[question_id - 1]['domain']:
  114. correct = True
  115. return correct, hallucination