Преглед изворни кода

Feat: Support Gorilla APIBench (#2081)

* removed unused files from gorilla

* Update run_infer.py, removed unused imports

* Update utils.py

* Update ast_eval_hf.py

* Update ast_eval_tf.py

* Update ast_eval_th.py

* Create README.md

* Update run_infer.py

* make lint

* Update run_infer.py

* fix lint

---------

Co-authored-by: yufansong <yufan@risingwave-labs.com>
yueqis пре 1 година
родитељ
комит
68d9ad61cf

+ 41 - 0
evaluation/gorilla/README.md

@@ -0,0 +1,41 @@
+# Gorilla APIBench Evaluation with OpenDevin
+
+This folder contains evaluation harness we built on top of the original [Gorilla APIBench](https://github.com/ShishirPatil/gorilla) ([paper](https://arxiv.org/pdf/2305.15334)).
+
+## Setup Environment
+
+Please follow [this document](https://github.com/OpenDevin/OpenDevin/blob/main/Development.md) to setup local development environment for OpenDevin.
+
+## Configure OpenDevin and your LLM
+
+Run `make setup-config` to set up the `config.toml` file if it does not exist at the root of the workspace.
+
+## Run Inference on APIBench Instances
+
+Make sure your Docker daemon is running, then run this bash script:
+
+```bash
+bash evaluation/gorilla/scripts/run_infer.sh [model_config] [agent] [eval_limit] [hubs]
+```
+
+where `model_config` is mandatory, while all other arguments are optional.
+
+`model_config`, e.g. `llm`, is the config group name for your
+LLM settings, as defined in your `config.toml`.
+
+`agent`, e.g. `CodeActAgent`, is the name of the agent for benchmarks, defaulting
+to `CodeActAgent`.
+
+`eval_limit`, e.g. `10`, limits the evaluation to the first `eval_limit` instances.
+By default, the script evaluates 1 instance.
+
+`hubs`, the hub from APIBench to evaluate from. You could choose one or more from `torch` or `th` (which is abbreviation of torch), `hf` (which is abbreviation of huggingface), and `tf` (which is abbreviation of tensorflow),  for `hubs`. The default is `hf,torch,tf`.
+
+Note: in order to use `eval_limit`, you must also set `agent`; in order to use `hubs`, you must also set `eval_limit`.
+
+Let's say you'd like to run 10 instances using `llm` and CodeActAgent on `th` test,
+then your command would be:
+
+```bash
+bash evaluation/gorilla/scripts/run_infer.sh llm CodeActAgent 10 th
+```

+ 127 - 0
evaluation/gorilla/ast_eval_hf.py

@@ -0,0 +1,127 @@
+# Copyright 2023 https://github.com/ShishirPatil/gorilla
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is modifed from https://github.com/ShishirPatil/gorilla/blob/main/eval/eval-scripts/ast_eval_hf.py
+
+from tree_sitter import Language, Parser
+
+
+# Get all the subtrees given a root_node
+def get_all_sub_trees(root_node):
+    node_stack = []
+    sub_tree_sexp_list = []
+    depth = 1
+    # text = root_node.text
+    node_stack.append([root_node, depth])
+    while len(node_stack) != 0:
+        cur_node, cur_depth = node_stack.pop()
+        if cur_node.child_count > 0:
+            sub_tree_sexp_list.append(
+                [cur_node.sexp(), cur_depth, cur_node, cur_node.children[0].text]
+            )
+        else:
+            sub_tree_sexp_list.append([cur_node.sexp(), cur_depth, cur_node, None])
+        for child_node in cur_node.children:
+            if len(child_node.children) != 0:
+                depth = cur_depth + 1
+                node_stack.append([child_node, depth])
+    return sub_tree_sexp_list
+
+
+# Parse the program into AST trees
+def ast_parse(candidate, lang='python'):
+    LANGUAGE = Language('evaluation/gorilla/my-languages.so', lang)
+    parser = Parser()
+    parser.set_language(LANGUAGE)
+
+    candidate_tree = parser.parse(bytes(candidate, 'utf8')).root_node
+    return candidate_tree
+
+
+# Get all the arguments in the ast tree
+def get_args(node):
+    if node.child_count == 0:
+        return []
+    args_list = []
+    for child in node.children[0].children[0].children[1].children:
+        if '=' in child.text.decode():
+            args_list.append(child.children[2].text)
+        elif (
+            child.text.decode() != '('
+            and child.text.decode() != ')'
+            and child.text.decode() != ','
+        ):
+            args_list.append(child.text)
+    return args_list
+
+
+# Check if there is an api match
+def ast_check(candidate_subtree_list, base_tree_list):
+    for idx, base_tree in enumerate(base_tree_list):
+        if base_tree.children[0].children[0].child_count == 0:
+            continue
+        api_name = base_tree.children[0].children[0].children[0].text
+        for candidate_tree in candidate_subtree_list:
+            if candidate_tree[3] == api_name:
+                break
+        # Now we have a sub-tree
+        candidate_tree = candidate_tree[2]
+        args_list = get_args(base_tree)
+        if len(args_list) == 0:
+            continue
+        ast_match = True
+        for arg in args_list:
+            if arg.decode().lstrip("'").rstrip("'") not in candidate_tree.text.decode():
+                ast_match = False
+                break
+        if ast_match:
+            return idx
+    return -1
+
+
+def ast_eval_hf(api_database, qa_pairs, ast_database, question_id, response):
+    # Check correctness
+    correct = False
+    hallucination = False
+    output = response
+    # Index the "api_call" domain
+    output = output.split('api_call')
+    if len(output) == 1:
+        api_call = output[0]
+    else:
+        # Parse the output
+        output = output[1].split('api_provider')[0]
+        if ':' not in output:
+            start = 0
+        else:
+            start = output.index(':')
+        if ')' not in output:
+            end = -2
+        else:
+            end = output.rindex(')')
+        api_call = output[start + 2 : end + 1]
+    # Parse the api_call into AST tree
+    ast_tree = ast_parse(api_call)
+    # Search for a subtree
+    ast_subtree_list = get_all_sub_trees(ast_tree)
+    # Check which ast tree is matching
+    database_index = ast_check(ast_subtree_list, ast_database)
+    # We cannot index this ast in our database
+    if database_index == -1:
+        hallucination = True
+    # We index our reference api_call
+    ref_api_call = api_database[database_index]
+    # Check for functionality
+    if ref_api_call['domain'] == qa_pairs[question_id - 1]['domain']:
+        correct = True
+    return correct, hallucination

+ 127 - 0
evaluation/gorilla/ast_eval_tf.py

@@ -0,0 +1,127 @@
+# Copyright 2023 https://github.com/ShishirPatil/gorilla
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is modifed from https://github.com/ShishirPatil/gorilla/blob/main/eval/eval-scripts/ast_eval_tf.py
+
+from tree_sitter import Language, Parser
+
+
+# Get all the subtrees given a root_node
+def get_all_sub_trees(root_node):
+    node_stack = []
+    sub_tree_sexp_list = []
+    depth = 1
+    # text = root_node.text
+    node_stack.append([root_node, depth])
+    while len(node_stack) != 0:
+        cur_node, cur_depth = node_stack.pop()
+        if cur_node.child_count > 0:
+            sub_tree_sexp_list.append(
+                [cur_node.sexp(), cur_depth, cur_node, cur_node.children[0].text]
+            )
+        else:
+            sub_tree_sexp_list.append([cur_node.sexp(), cur_depth, cur_node, None])
+        for child_node in cur_node.children:
+            if len(child_node.children) != 0:
+                depth = cur_depth + 1
+                node_stack.append([child_node, depth])
+    return sub_tree_sexp_list
+
+
+# Parse the program into AST trees
+def ast_parse(candidate, lang='python'):
+    LANGUAGE = Language('evaluation/gorilla/my-languages.so', lang)
+    parser = Parser()
+    parser.set_language(LANGUAGE)
+
+    candidate_tree = parser.parse(bytes(candidate, 'utf8')).root_node
+    return candidate_tree
+
+
+# Get all the arguments in the ast tree
+def get_args(node):
+    if node.child_count == 0:
+        return []
+    args_list = []
+    for child in node.children[0].children[0].children[1].children:
+        if 'model=' in child.text.decode() or 'model =' in child.text.decode():
+            args_list.append(child.children[2].text)
+        elif (
+            child.text.decode() != '('
+            and child.text.decode() != ')'
+            and child.text.decode() != ','
+        ):
+            args_list.append(child.text)
+    return args_list
+
+
+# Check if there is an api match
+def ast_check(candidate_subtree_list, base_tree_list):
+    for idx, base_tree in enumerate(base_tree_list):
+        if base_tree.children[0].children[0].child_count == 0:
+            continue
+        api_name = base_tree.children[0].children[0].children[0].text
+        for candidate_tree in candidate_subtree_list:
+            if candidate_tree[3] == api_name:
+                break
+        # Now we have a sub-tree
+        candidate_tree = candidate_tree[2]
+        args_list = get_args(base_tree)
+        if len(args_list) == 0:
+            continue
+        ast_match = True
+        for arg in args_list:
+            if arg.decode().lstrip("'").rstrip("'") not in candidate_tree.text.decode():
+                ast_match = False
+                break
+        if ast_match:
+            return idx
+    return -1
+
+
+def ast_eval_tf(api_database, qa_pairs, ast_database, question_id, response):
+    # Check correctness
+    correct = False
+    hallucination = False
+    output = response
+    # Index the "api_call" domain
+    output = output.split('api_call')
+    if len(output) == 1:
+        api_call = output[0]
+    else:
+        # Parse the output
+        output = output[1].split('api_provider')[0]
+        if ':' not in output:
+            start = 0
+        else:
+            start = output.index(':')
+        if ')' not in output:
+            end = -2
+        else:
+            end = output.rindex(')')
+        api_call = output[start + 2 : end + 1]
+    # Parse the api_call into AST tree
+    ast_tree = ast_parse(api_call)
+    # Search for a subtree
+    ast_subtree_list = get_all_sub_trees(ast_tree)
+    # Check which ast tree is matching
+    database_index = ast_check(ast_subtree_list, ast_database)
+    # We cannot index this ast in our database
+    if database_index == -1:
+        hallucination = True
+    # We index our reference api_call
+    ref_api_call = api_database[database_index]
+    # Check for functionality
+    if ref_api_call['domain'] == qa_pairs[question_id - 1]['domain']:
+        correct = True
+    return correct, hallucination

+ 123 - 0
evaluation/gorilla/ast_eval_th.py

@@ -0,0 +1,123 @@
+# Copyright 2023 https://github.com/ShishirPatil/gorilla
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is modifed from https://github.com/ShishirPatil/gorilla/blob/main/eval/eval-scripts/ast_eval_th.py
+
+from tree_sitter import Language, Parser
+
+
+# Get all the subtrees given a root_node
+def get_all_sub_trees(root_node):
+    node_stack = []
+    sub_tree_sexp_list = []
+    depth = 1
+    # text = root_node.text
+    node_stack.append([root_node, depth])
+    while len(node_stack) != 0:
+        cur_node, cur_depth = node_stack.pop()
+        if cur_node.child_count > 0:
+            sub_tree_sexp_list.append(
+                [cur_node.sexp(), cur_depth, cur_node, cur_node.children[0].text]
+            )
+        else:
+            sub_tree_sexp_list.append([cur_node.sexp(), cur_depth, cur_node, None])
+        for child_node in cur_node.children:
+            if len(child_node.children) != 0:
+                depth = cur_depth + 1
+                node_stack.append([child_node, depth])
+    return sub_tree_sexp_list
+
+
+# Parse the program into AST trees
+def ast_parse(candidate, lang='python'):
+    LANGUAGE = Language('evaluation/gorilla/my-languages.so', lang)
+    parser = Parser()
+    parser.set_language(LANGUAGE)
+
+    candidate_tree = parser.parse(bytes(candidate, 'utf8')).root_node
+    return candidate_tree
+
+
+# Get all the arguments in the ast tree
+def get_args(node):
+    if node.child_count == 0:
+        return []
+    args_list = []
+    for child in node.children[0].children[0].children[1].children:
+        if 'repo_or_dir' in child.text.decode() or 'model' in child.text.decode():
+            args_list.append(child.children[2].text)
+    return args_list
+
+
+# Check if there is an api match
+def ast_check(candidate_subtree_list, base_tree_list):
+    for idx, base_tree in enumerate(base_tree_list):
+        if base_tree.children[0].children[0].child_count == 0:
+            continue
+        api_name = base_tree.children[0].children[0].children[0].text
+        for candidate_tree in candidate_subtree_list:
+            if candidate_tree[3] == api_name:
+                break
+        # Now we have a sub-tree
+        candidate_tree = candidate_tree[2]
+        args_list = get_args(base_tree)
+        if len(args_list) == 0:
+            continue
+        ast_match = True
+        for arg in args_list:
+            if arg.decode().lstrip("'").rstrip("'") not in candidate_tree.text.decode():
+                ast_match = False
+                break
+        if ast_match:
+            return idx
+    return -1
+
+
+def process_response(question_id, output, api_database, qa_pairs, ast_database):
+    # Index the "api_call" domain
+    output = output.split('api_call')
+    if len(output) == 1:
+        return False, False
+    else:
+        output = output[1].split('api_provider')[0]
+    if ':' not in output:
+        start = 0
+    else:
+        start = output.index(':')
+    if ')' not in output:
+        end = -2
+    else:
+        end = output.rindex(')')
+    api_call = output[start + 2 : end + 1]
+
+    # Parse the api_call into AST tree
+    ast_tree = ast_parse(api_call)
+    # Search for a subtree
+    ast_subtree_list = get_all_sub_trees(ast_tree)
+    # Check which ast tree is matching
+    database_index = ast_check(ast_subtree_list, ast_database)
+    # We cannot index this ast in our database
+    if database_index == -1:
+        return False, True
+    # We index our reference api_call
+    ref_api_call = api_database[database_index]
+    # Check for functionality
+    if ref_api_call['domain'] == qa_pairs[question_id - 1]['domain']:
+        return True, False
+    else:
+        return False, False
+
+
+def ast_eval_th(api_database, qa_pairs, ast_database, question_id, response):
+    # Check correctness
+    return process_response(question_id, response, api_database, qa_pairs, ast_database)

+ 355 - 0
evaluation/gorilla/run_infer.py

@@ -0,0 +1,355 @@
+import asyncio
+import json
+import logging
+import multiprocessing as mp
+import os
+import pathlib
+import subprocess
+import time
+from concurrent.futures import ProcessPoolExecutor
+
+from tqdm import tqdm
+from utils import encode_question, get_data
+
+from opendevin.controller.state.state import State
+from opendevin.core.config import config, get_llm_config_arg, get_parser
+from opendevin.core.logger import get_console_handler
+from opendevin.core.logger import opendevin_logger as logger
+from opendevin.core.main import main
+from opendevin.events.action import MessageAction
+from opendevin.events.serialization.event import event_to_dict
+
+
+def cleanup():
+    print('Cleaning up child processes...')
+    for process in mp.active_children():
+        print(f'Terminating child process: {process.name}')
+        process.terminate()
+        process.join()
+
+
+def codeact_user_response(state: State) -> str:
+    msg = (
+        #'Please continue working on the task on whatever approach you think is suitable.\n'
+        'Please run the following command: <execute_bash> exit </execute_bash>.\n'
+        #'IMPORTANT: YOU SHOULD NEVER ASK FOR HUMAN HELP OR USE THE INTERNET TO SOLVE THIS TASK.\n'
+    )
+    if state.history:
+        user_msgs = [
+            action
+            for action, _ in state.history
+            if isinstance(action, MessageAction) and action.source == 'user'
+        ]
+        if len(user_msgs) >= 2:
+            # let the agent know that it can give up when it has tried 3 times
+            return (
+                msg
+                + 'If you want to give up, run: <execute_bash> exit </execute_bash>.\n'
+            )
+    return msg
+
+
+def monologue_user_response(state: State) -> str:
+    raise NotImplementedError('MonologueAgent should never ask for user responses.')
+
+
+AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
+    'CodeActAgent': codeact_user_response,
+    'MonologueAgent': monologue_user_response,
+}
+
+AGENT_CLS_TO_INST_SUFFIX = {
+    'CodeActAgent': 'When you think you have completed the request, please run the following command: <execute_bash> exit </execute_bash>.\n'
+}
+
+
+def process_instance(
+    question_id, question, agent_class, metadata, reset_logger: bool = True
+):
+    # create process-specific workspace dir
+    # we will create a workspace directory for EACH process
+    # so that different agent don't interfere with each other.
+    old_workspace_mount_path = config.workspace_mount_path
+    try:
+        workspace_mount_path = os.path.join(
+            config.workspace_mount_path, '_eval_workspace'
+        )
+        workspace_mount_path = os.path.join(workspace_mount_path, str(os.getpid()))
+        pathlib.Path(workspace_mount_path).mkdir(parents=True, exist_ok=True)
+        config.workspace_mount_path = workspace_mount_path
+
+        # Setup the logger properly, so you can run multi-processing to parallize the evaluation
+        eval_output_dir = metadata['eval_output_dir']
+        if reset_logger:
+            # Set up logger
+            log_file = os.path.join(
+                eval_output_dir, 'logs', f'instance_{question_id}.log'
+            )
+            # Remove all existing handlers from logger
+            for handler in logger.handlers[:]:
+                logger.removeHandler(handler)
+            # add back the console handler to print ONE line
+            logger.addHandler(get_console_handler())
+            logger.info(
+                f'Starting evaluation for instance {question_id}.\nLOG:   tail -f {log_file}'
+            )
+            # Remove all existing handlers from logger
+            for handler in logger.handlers[:]:
+                logger.removeHandler(handler)
+            file_handler = logging.FileHandler(log_file)
+            file_handler.setFormatter(
+                logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
+            )
+            logger.addHandler(file_handler)
+        logger.info(f'Process-specific workspace mounted at {workspace_mount_path}')
+
+        # Prepare instruction
+        instruction = encode_question(question, metadata['hub'])
+        instruction += 'IMPORTANT: You should ONLY interact with the environment provided to you AND NEVER ASK FOR HUMAN HELP.\n'
+        # NOTE: You can actually set slightly different instruction for different agents
+        instruction += AGENT_CLS_TO_INST_SUFFIX.get(agent_class, '')
+        # logger.info(f'Instruction:\n{instruction}', extra={'msg_type': 'OBSERVATION'})
+
+        # Here's how you can run the agent (similar to the `main` function) and get the final task state
+        state: State = asyncio.run(
+            main(
+                instruction,
+                fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
+                    agent_class
+                ),
+            )
+        )
+        # ======= Attempt to evaluate the agent's edits =======
+        # If you are working on simplier benchmark that only evaluates the final model output (e.g., in a MessageAction)
+        # You can simply get the LAST `MessageAction` from the returned `state.history` and parse it for evaluation.
+
+        if state is None:
+            raise ValueError('State should not be None.')
+
+        model_answer_raw = ''
+        for act, _ in reversed(state.history):
+            if isinstance(act, MessageAction) and act.source == 'agent':
+                model_answer_raw = act.content
+                break
+        # attempt to parse model_answer
+        _, _, ast_eval = get_data(metadata['hub'])
+        correct, hallucination = ast_eval(question_id, model_answer_raw)
+        metrics = state.metrics.get() if state.metrics else None
+        logger.info(
+            f'Final message: {model_answer_raw} | Correctness: {correct} | Hallucination: {hallucination}'
+        )
+        # Save the output
+        output = {
+            'question_id': question_id,
+            'text': model_answer_raw,
+            'correct': correct,
+            'hallucination': hallucination,
+            'answer_id': 'None',
+            'model_id': metadata['model_name'],
+            'metadata': metadata,
+            'history': [
+                (event_to_dict(action), event_to_dict(obs))
+                for action, obs in state.history
+            ],
+            'metrics': metrics,
+            'error': state.error if state and state.error else None,
+        }
+    except Exception:
+        logger.error('Process instance failed')
+        raise
+    finally:
+        config.workspace_mount_path = old_workspace_mount_path
+    return output
+
+
+if __name__ == '__main__':
+    parser = get_parser()
+    parser.add_argument(
+        '--hubs',
+        type=str,
+        help='Which hubs to evaluate from APIBench. APIBench contains 3 hubs, namely huggingface, torch, and tensorflow. You could choose one or more from hf, torch, or tf, seperated by commas. For example, the default is --hub hf,torch,tf.',
+        default='hf,torch,tf',
+    )
+    args, _ = parser.parse_known_args()
+    if args.directory:
+        config.workspace_base = os.path.abspath(args.directory)
+        print(f'Setting workspace base to {config.workspace_base}')
+
+    # Check https://github.com/OpenDevin/OpenDevin/blob/main/evaluation/swe_bench/README.md#configure-opendevin-and-your-llm
+    # for details of how to set `llm_config`
+    if args.llm_config:
+        specified_llm_config = get_llm_config_arg(args.llm_config)
+        if specified_llm_config:
+            config.llm = specified_llm_config
+    logger.info(f'Config for evaluation: {config}')
+    agent_class = args.agent_cls
+    assert (
+        agent_class in AGENT_CLS_TO_FAKE_USER_RESPONSE_FN
+    ), f'Unsupported agent class: {agent_class}'
+    model_name = config.llm.model.split('/')[-1]
+    max_iterations = args.max_iterations
+    eval_note = ''
+    if args.eval_note is not None:
+        eval_note += '_N_' + args.eval_note
+    eval_output_dir = os.path.join(
+        args.eval_output_dir,
+        'gorilla',
+        agent_class,
+        model_name + '_maxiter_' + str(max_iterations) + eval_note,
+    )
+    pathlib.Path(eval_output_dir).mkdir(parents=True, exist_ok=True)
+    pathlib.Path(os.path.join(eval_output_dir, 'logs')).mkdir(
+        parents=True, exist_ok=True
+    )
+    logger.info(f'Using evaluation output directory: {eval_output_dir}')
+
+    hubs = []
+    if 'hf' in args.hubs:
+        hubs.append('hf')
+    if 'torch' in args.hubs or 'th' in args.hubs:
+        hubs.append('torch')
+    if 'tf' in args.hubs:
+        hubs.append('tf')
+    if hubs == []:
+        raise ValueError('Please choose at least one from hf, torch, and tf for hubs.')
+
+    for hub in hubs:
+        logger.info(f'Evaluating APIBench {hub} test')
+        questions, question_ids, ast_eval = get_data(hub)
+
+        # TEST METADATA
+        metadata = {
+            'hub': hub,
+            'agent_class': agent_class,
+            'model_name': model_name,
+            'max_iterations': max_iterations,
+            'eval_output_dir': eval_output_dir,
+            'start_time': time.strftime('%Y-%m-%d %H:%M:%S'),
+            # get the commit id of current repo for reproduciblity
+            'git_commit': subprocess.check_output(['git', 'rev-parse', 'HEAD'])
+            .decode('utf-8')
+            .strip(),
+        }
+        logger.info(f'Metadata: {metadata}')
+        with open(os.path.join(eval_output_dir, f'metadata_{hub}.json'), 'w') as f:
+            json.dump(metadata, f)
+
+        # LIMIT EVALUATION
+        eval_n_limit = args.eval_n_limit
+        if eval_n_limit:
+            questions = questions[: (eval_n_limit // len(hubs))]
+            question_ids = question_ids[: (eval_n_limit // len(hubs))]
+            logger.info(
+                f'Limiting evaluation to a total of first {eval_n_limit} instances -> first {eval_n_limit//len(hubs)} instances per hub.'
+            )
+        output_file = os.path.join(eval_output_dir, f'output_{model_name}_{hub}.jsonl')
+        logger.info(f'Writing evaluation output to {output_file}')
+        finished_task_ids = set()
+        if os.path.exists(output_file):
+            with open(output_file, 'r') as f:
+                for line in f:
+                    data = json.loads(line)
+                    for i in range(len(question_ids)):
+                        if question_ids[i] == int(data['question_id']):
+                            finished_task_ids.add(data['question_id'])
+            logger.warning(
+                f'Output file {output_file} already exists. Loaded {len(finished_task_ids)} finished instances.'
+            )
+        output_fp = open(output_file, 'a')
+        logger.info(
+            f'Evaluation started with Agent {agent_class}, model {model_name}, max iterations {max_iterations}.'
+        )
+        # =============================================
+        # filter out finished instances
+        new_questions = []
+        new_question_ids = []
+        for i in range(len(question_ids)):
+            if question_ids[i] in finished_task_ids:
+                logger.info(
+                    f'Skipping instance {question_ids[i]} as it is already finished.'
+                )
+                continue
+            new_questions.append(questions[i])
+            new_question_ids.append(question_ids[i])
+
+        finished_task_number = len(finished_task_ids)
+        questions = new_questions
+        question_ids = new_question_ids
+        logger.info(
+            f'Finished instances: {finished_task_number}, Remaining instances: {len(question_ids)}'
+        )
+        # =============================================
+        pbar = tqdm(total=len(question_ids))
+
+        # This function tracks the progress AND write the output to a JSONL file
+        def update_progress(future, pbar, output_fp, finished_task_ids):
+            pbar.update(1)
+            output = future.result()
+            pbar.set_description(f'Instance {output["question_id"]}')
+            pbar.set_postfix_str(f'Test Result: {output["correct"]}')
+            logger.info(
+                f'Finished evaluation for instance {output["question_id"]}: {output["correct"]}'
+            )
+            output_fp.write(json.dumps(output) + '\n')
+            output_fp.flush()
+            finished_task_ids.add(output['question_id'])
+
+        # This sets the multi-processing
+        num_workers = args.eval_num_workers
+        logger.info(f'Using {num_workers} workers for evaluation.')
+        try:
+            with ProcessPoolExecutor(num_workers) as executor:
+                futures = []
+                # This is how we perform multi-processing
+                for i in range(len(question_ids)):
+                    try:
+                        question_id = question_ids[i]
+                        question = questions[i]
+                        future = executor.submit(
+                            process_instance,
+                            question_id,
+                            question,
+                            agent_class,
+                            metadata,
+                            reset_logger=bool(num_workers > 1),
+                        )
+                        future.add_done_callback(
+                            update_progress, pbar, output_fp, finished_task_ids
+                        )
+                        futures.append(future)
+                    except Exception:
+                        continue
+
+                # Wait for all futures to complete
+                for future in futures:
+                    try:
+                        future.result()
+                    except Exception:
+                        continue
+        except KeyboardInterrupt:
+            logger.info('KeyboardInterrupt received. Cleaning up...')
+            cleanup()
+
+        output_fp.close()
+        total_correct = 0
+        total_hallucination = 0
+        output = []
+        with open(output_file, 'r') as f:
+            for line in f:
+                data = json.loads(line)
+                output.append(data)
+                if int(data['question_id']) in finished_task_ids:
+                    if str(data['correct']).lower() == 'true':
+                        total_correct += 1
+                    if str(data['hallucination']).lower() == 'true':
+                        total_hallucination += 1
+        # sort all output by question_id
+        output = sorted(output, key=lambda x: x['question_id'])
+        with open(output_file, 'w') as f:
+            for dat in output:
+                f.write(json.dumps(dat) + '\n')
+                f.flush()
+
+        logger.info(
+            f'Evaluation finished for {hub}. Total: {len(question_ids)+finished_task_number}; Correct: {total_correct}; Hallucination: {total_hallucination}. Accuracy: {total_correct / (len(question_ids)+finished_task_number)}'
+        )

+ 42 - 0
evaluation/gorilla/scripts/run_infer.sh

@@ -0,0 +1,42 @@
+#!/bin/bash
+MODEL_CONFIG=$1
+AGENT=$2
+EVAL_LIMIT=$3
+HUBS=$4
+
+if [ -z "$AGENT" ]; then
+  echo "Agent not specified, use default CodeActAgent"
+  AGENT="CodeActAgent"
+fi
+
+if [ -z "$HUBS" ]; then
+  HUBS="hf,torch,tf"
+  echo "Hubs not specified, use default $HUBS"
+fi
+
+# IMPORTANT: Because Agent's prompt changes fairly often in the rapidly evolving codebase of OpenDevin
+# We need to track the version of Agent in the evaluation to make sure results are comparable
+AGENT_VERSION=v$(poetry run python -c "import agenthub; from opendevin.controller.agent import Agent; print(Agent.get_cls('$AGENT').VERSION)")
+
+echo "AGENT: $AGENT"
+echo "AGENT_VERSION: $AGENT_VERSION"
+echo "MODEL_CONFIG: $MODEL_CONFIG"
+echo "HUBS: $HUBS"
+
+COMMAND="poetry run python evaluation/gorilla/run_infer.py \
+  --agent-cls $AGENT \
+  --llm-config $MODEL_CONFIG \
+  --max-iterations 30 \
+  --hubs $HUBS \
+  --data-split validation \
+  --max-chars 10000000 \
+  --eval-num-workers 1 \
+  --eval-note ${AGENT_VERSION}_${LEVELS}"
+
+if [ -n "$EVAL_LIMIT" ]; then
+  echo "EVAL_LIMIT: $EVAL_LIMIT"
+  COMMAND="$COMMAND --eval-n-limit $EVAL_LIMIT"
+fi
+
+# Run the command
+eval $COMMAND

+ 101 - 0
evaluation/gorilla/utils.py

@@ -0,0 +1,101 @@
+import json
+from functools import partial
+
+import requests
+from ast_eval_hf import ast_eval_hf, ast_parse
+from ast_eval_tf import ast_eval_tf
+from ast_eval_th import ast_eval_th
+
+
+# This function is modified from Gorilla's APIBench implementations (https://github.com/ShishirPatil/gorilla/blob/main/eval/get_llm_responses.py).
+def encode_question(question, api_name):
+    """Encode multiple prompt instructions into a single string."""
+
+    prompts = []
+    if api_name == 'torch':
+        api_name = 'torchhub'
+        domains = '1. $DOMAIN is inferred from the task description and should include one of {Classification, Semantic Segmentation, Object Detection, Audio Separation, Video Classification, Text-to-Speech}.'
+    elif api_name == 'hf':
+        api_name = 'huggingface'
+        domains = '1. $DOMAIN should include one of {Multimodal Feature Extraction, Multimodal Text-to-Image, Multimodal Image-to-Text, Multimodal Text-to-Video, \
+        Multimodal Visual Question Answering, Multimodal Document Question Answer, Multimodal Graph Machine Learning, Computer Vision Depth Estimation,\
+        Computer Vision Image Classification, Computer Vision Object Detection, Computer Vision Image Segmentation, Computer Vision Image-to-Image, \
+        Computer Vision Unconditional Image Generation, Computer Vision Video Classification, Computer Vision Zero-Shor Image Classification, \
+        Natural Language Processing Text Classification, Natural Language Processing Token Classification, Natural Language Processing Table Question Answering, \
+        Natural Language Processing Question Answering, Natural Language Processing Zero-Shot Classification, Natural Language Processing Translation, \
+        Natural Language Processing Summarization, Natural Language Processing Conversational, Natural Language Processing Text Generation, Natural Language Processing Fill-Mask,\
+        Natural Language Processing Text2Text Generation, Natural Language Processing Sentence Similarity, Audio Text-to-Speech, Audio Automatic Speech Recognition, \
+        Audio Audio-to-Audio, Audio Audio Classification, Audio Voice Activity Detection, Tabular Tabular Classification, Tabular Tabular Regression, \
+        Reinforcement Learning Reinforcement Learning, Reinforcement Learning Robotics }'
+    elif api_name == 'tf':
+        api_name = 'tensorhub'
+        domains = '1. $DOMAIN is inferred from the task description and should include one of {text-sequence-alignment, text-embedding, text-language-model, text-preprocessing, text-classification, text-generation, text-question-answering, text-retrieval-question-answering, text-segmentation, text-to-mel, image-classification, image-feature-vector, image-object-detection, image-segmentation, image-generator, image-pose-detection, image-rnn-agent, image-augmentation, image-classifier, image-style-transfer, image-aesthetic-quality, image-depth-estimation, image-super-resolution, image-deblurring, image-extrapolation, image-text-recognition, image-dehazing, image-deraining, image-enhancemenmt, image-classification-logits, image-frame-interpolation, image-text-detection, image-denoising, image-others, video-classification, video-feature-extraction, video-generation, video-audio-text, video-text, audio-embedding, audio-event-classification, audio-command-detection, audio-paralinguists-classification, audio-speech-to-text, audio-speech-synthesis, audio-synthesis, audio-pitch-extraction}'
+    else:
+        print('Error: API name is not supported.')
+
+    prompt = (
+        question
+        + '\nWrite a python program in 1 to 2 lines to call API in '
+        + api_name
+        + '.\n\nThe answer should follow the format: <<<domain>>> $DOMAIN, <<<api_call>>>: $API_CALL, <<<api_provider>>>: $API_PROVIDER, <<<explanation>>>: $EXPLANATION, <<<code>>>: $CODE}. Here are the requirements:\n'
+        + domains
+        + '\n2. The $API_CALL should have only 1 line of code that calls api.\n3. The $API_PROVIDER should be the programming framework used.\n4. $EXPLANATION should be a step-by-step explanation.\n5. The $CODE is the python code.\n6. Do not repeat the format in your answer.'
+    )
+    # prompts.append({"role": "system", "content": ""})
+    prompts = (
+        'You are a helpful API writer who can write APIs based on requirements.\n'
+        + prompt
+    )
+    return prompts
+
+
+def get_data(hub):
+    if hub == 'hf':
+        question_data = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/eval/eval-data/questions/huggingface/questions_huggingface_0_shot.jsonl'
+        api_dataset = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/api/huggingface_api.jsonl'
+        apibench = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/apibench/huggingface_eval.json'
+        ast_eval = ast_eval_hf
+    if hub == 'torch':
+        question_data = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/eval/eval-data/questions/torchhub/questions_torchhub_0_shot.jsonl'
+        api_dataset = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/api/torchhub_api.jsonl'
+        apibench = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/apibench/torchhub_eval.json'
+        ast_eval = ast_eval_th
+    if hub == 'tf':
+        question_data = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/eval/eval-data/questions/tensorflowhub/questions_tensorflowhub_0_shot.jsonl'
+        api_dataset = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/api/tensorflowhub_api.jsonl'
+        apibench = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/apibench/tensorflow_eval.json'
+        ast_eval = ast_eval_tf
+
+    # get questions and question_ids
+    questions = []
+    question_ids = []
+    question_data = requests.get(question_data)
+    if question_data.status_code == 200:
+        lines = question_data.text.splitlines()
+        for line in lines:
+            questions.append(json.loads(line)['text'])
+            question_ids.append(json.loads(line)['question_id'])
+
+    # get the api datasest
+    api_database = []
+    api_dataset = requests.get(api_dataset)
+    if api_dataset.status_code == 200:
+        lines = api_dataset.text.splitlines()
+        for line in lines:
+            api_database.append(json.loads(line))
+
+    # get the question answer pair datasest
+    qa_pairs = []
+    apibench = requests.get(apibench)
+    if apibench.status_code == 200:
+        lines = apibench.text.splitlines()
+        for line in lines:
+            qa_pairs.append(json.loads(line)['api_data'])
+
+    # Parse all apis to ast trees
+    ast_database = []
+    for data in api_database:
+        ast_tree = ast_parse(data['api_call'])
+        ast_database.append(ast_tree)
+    ast_eval = partial(ast_eval, api_database, qa_pairs, ast_database)
+    return questions, question_ids, ast_eval