| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127 |
- # Copyright 2023 https://github.com/ShishirPatil/gorilla
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # This file is modified from https://github.com/ShishirPatil/gorilla/blob/main/eval/eval-scripts/ast_eval_tf.py
- from tree_sitter import Language, Parser
# Collect every subtree reachable from a root node.
def get_all_sub_trees(root_node):
    """Walk ``root_node`` depth-first and describe each visited subtree.

    Traversal uses an explicit LIFO stack seeded with the root at depth 1.
    A child is only pushed when it has children of its own, so the only
    childless node that can ever be reported is the root itself.

    Args:
        root_node: Root of a parsed tree (e.g. the node returned by
            ``ast_parse``).

    Returns:
        A list of ``[sexp, depth, node, first_child_text]`` rows in visit
        order, where ``first_child_text`` is ``node.children[0].text`` or
        ``None`` when ``node`` has no children.
    """
    rows = []
    pending = [[root_node, 1]]
    while pending:
        node, level = pending.pop()
        head_text = node.children[0].text if node.child_count > 0 else None
        rows.append([node.sexp(), level, node, head_text])
        # Leaf children are skipped; only nodes that have children are queued.
        pending.extend([child, level + 1] for child in node.children if child.children)
    return rows
# Parse the program into AST trees
_DEFAULT_LANGUAGE_LIBRARY = 'evaluation/gorilla/my-languages.so'
# Loaded grammars keyed by (library_path, lang), reused across ast_parse calls.
_LANGUAGE_CACHE = {}


def ast_parse(candidate, lang='python', library_path=_DEFAULT_LANGUAGE_LIBRARY):
    """Parse ``candidate`` with tree-sitter and return the tree's root node.

    Args:
        candidate: Source text to parse; it is encoded as UTF-8 first.
        lang: Name of the grammar to load from the compiled language library.
        library_path: Path of the compiled tree-sitter language bundle. The
            default is a relative path, so it is resolved against the current
            working directory; pass an explicit path when running elsewhere.

    Returns:
        The ``root_node`` of the parsed tree.
    """
    # Loading the grammar from the shared library on every call repeats the
    # same work for each evaluated response; load it once per (path, lang).
    # A failed load raises before anything is cached, so it is retried next time.
    cache_key = (library_path, lang)
    language = _LANGUAGE_CACHE.get(cache_key)
    if language is None:
        language = Language(library_path, lang)
        _LANGUAGE_CACHE[cache_key] = language
    parser = Parser()
    parser.set_language(language)
    return parser.parse(bytes(candidate, 'utf8')).root_node
# Get all the arguments in the ast tree
def get_args(node):
    """Collect the argument texts of the call in the first statement of ``node``.

    ``node`` is expected to be a parsed program whose first statement is a
    call, i.e. ``node > statement > call > argument_list``. Every child of the
    argument list except the ``(``, ``)`` and ``,`` tokens is returned as its
    source text (``bytes``). For a keyword argument whose text contains
    ``model=`` or ``model =``, only the value's text is returned, so the model
    identifier is compared without its keyword.

    Args:
        node: Root node of a parsed reference API call.

    Returns:
        A list of ``bytes`` argument texts. It is empty when ``node`` has no
        children or does not have the expected call shape.
    """
    if node.child_count == 0:
        return []
    try:
        argument_nodes = node.children[0].children[0].children[1].children
    except IndexError:
        # The first statement is not shaped like a call with an argument list.
        return []
    args_list = []
    for child in argument_nodes:
        text = child.text.decode()
        if text in ('(', ')', ','):
            continue
        mentions_model = 'model=' in text or 'model =' in text
        # Only a real ``name = value`` keyword argument has its value at
        # children[2]. Other nodes that merely contain "model=" (a nested call,
        # a string, a comment) are kept whole instead of being mis-indexed.
        if mentions_model and child.type == 'keyword_argument' and child.child_count >= 3:
            args_list.append(child.children[2].text)
        else:
            args_list.append(child.text)
    return args_list
# Check if there is an api match
def ast_check(candidate_subtree_list, base_tree_list):
    """Return the index of the first reference call the candidate matches.

    For each reference tree, the API name is the text of the first child of
    its first statement's first node. A reference matches when some candidate
    sub-tree row has that API name as its first-child text and every argument
    from ``get_args(base_tree)`` occurs as a substring of that sub-tree's
    source, after stripping surrounding single quotes from the argument.
    References with no arguments never match.

    Args:
        candidate_subtree_list: Rows ``[sexp, depth, node, first_child_text]``
            describing the candidate program (as built by ``get_all_sub_trees``).
        base_tree_list: Root nodes of the reference API calls.

    Returns:
        The index into ``base_tree_list`` of the first match, or ``-1``.
    """
    for idx, base_tree in enumerate(base_tree_list):
        # Skip references that have no first statement with children.
        if base_tree.child_count == 0 or base_tree.children[0].child_count == 0:
            continue
        call_node = base_tree.children[0].children[0]
        if call_node.child_count == 0:
            continue
        api_name = call_node.children[0].text
        # Compare only against a sub-tree that actually starts with api_name.
        # With no such sub-tree there is nothing to check the arguments against.
        matched_node = next(
            (row[2] for row in candidate_subtree_list if row[3] == api_name),
            None,
        )
        if matched_node is None:
            continue
        args_list = get_args(base_tree)
        if not args_list:
            continue
        candidate_text = matched_node.text.decode()
        if all(arg.decode().strip("'") in candidate_text for arg in args_list):
            return idx
    return -1
def _extract_api_call(response):
    """Return the API-call snippet embedded in a model ``response``.

    If ``response`` does not contain ``'api_call'``, it is returned unchanged.
    Otherwise the text after the first ``'api_call'`` and before any
    ``'api_provider'`` is taken. From that text, the slice starting two
    characters after its first ``':'`` (or at index 2 when there is no ``':'``)
    and ending at its last ``')'`` inclusive is returned.
    """
    parts = response.split('api_call')
    if len(parts) == 1:
        return parts[0]
    section = parts[1].split('api_provider')[0]
    # Skip the ':' and the character after it (presumably a space).
    start = section.index(':') if ':' in section else 0
    # Keep up to and including the last ')'. Without one, end + 1 == -1,
    # which drops the final character.
    end = section.rindex(')') if ')' in section else -2
    return section[start + 2 : end + 1]


def ast_eval_tf(api_database, qa_pairs, ast_database, question_id, response):
    """Grade one model response by AST-matching its API call.

    The API call is extracted from ``response``, parsed, and matched against
    ``ast_database`` with ``ast_check``. The matched index is used to look up
    ``api_database``, so the two lists are presumably index-aligned (verify
    against the caller).

    Args:
        api_database: Reference API entries; each has a ``'domain'`` key.
        qa_pairs: Question records; each has a ``'domain'`` key.
        ast_database: Parsed reference API calls, in ``api_database`` order.
        question_id: Question number; ``qa_pairs[question_id - 1]`` is used,
            so it is treated as 1-based.
        response: Raw model output text.

    Returns:
        ``(correct, hallucination)``. When no reference call matches,
        ``hallucination`` is True and ``correct`` is False. Otherwise
        ``correct`` is True when the matched API's domain equals the
        question's domain.
    """
    api_call = _extract_api_call(response)
    ast_tree = ast_parse(api_call)
    ast_subtree_list = get_all_sub_trees(ast_tree)
    database_index = ast_check(ast_subtree_list, ast_database)
    if database_index == -1:
        # With no match there is no reference to grade against. Falling
        # through would index api_database[-1] and could mark a hallucinated
        # call as correct.
        return False, True
    ref_api_call = api_database[database_index]
    correct = ref_api_call['domain'] == qa_pairs[question_id - 1]['domain']
    return correct, False
|