| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124 |
import json
import os
from functools import partial

import pandas as pd
import requests

from ast_eval_hf import ast_eval_hf, ast_parse
from ast_eval_tf import ast_eval_tf
from ast_eval_th import ast_eval_th
# This function is modified from Gorilla's APIBench implementations (https://github.com/ShishirPatil/gorilla/blob/main/eval/get_llm_responses.py).
def encode_question(question, api_name):
    """Encode the zero-shot APIBench instructions for ``question`` into one prompt string.

    Args:
        question: The natural-language task description.
        api_name: Short hub key: ``'torch'`` (TorchHub), ``'hf'`` (HuggingFace)
            or ``'tf'`` (TensorFlow Hub).

    Returns:
        A single string consisting of a fixed "helpful API writer" preamble,
        the question, the requested answer format, and the numbered
        requirements (the first of which lists the allowed $DOMAIN values for
        the selected hub).

    Raises:
        ValueError: If ``api_name`` is not one of the supported hub keys.
    """
    # NOTE(review): domain labels (including spellings such as "Zero-Shor" and
    # "enhancemenmt") are kept verbatim; they appear to mirror the reference
    # prompt, so confirm against upstream before "fixing" them.
    if api_name == 'torch':
        api_name = 'torchhub'
        domains = '1. $DOMAIN is inferred from the task description and should include one of {Classification, Semantic Segmentation, Object Detection, Audio Separation, Video Classification, Text-to-Speech}.'
    elif api_name == 'hf':
        api_name = 'huggingface'
        # Adjacent literals (not backslash continuations) so source indentation
        # never leaks into the prompt and every label is separated by ", ".
        domains = (
            '1. $DOMAIN should include one of {Multimodal Feature Extraction, Multimodal Text-to-Image, '
            'Multimodal Image-to-Text, Multimodal Text-to-Video, Multimodal Visual Question Answering, '
            'Multimodal Document Question Answer, Multimodal Graph Machine Learning, '
            'Computer Vision Depth Estimation, Computer Vision Image Classification, '
            'Computer Vision Object Detection, Computer Vision Image Segmentation, '
            'Computer Vision Image-to-Image, Computer Vision Unconditional Image Generation, '
            'Computer Vision Video Classification, Computer Vision Zero-Shor Image Classification, '
            'Natural Language Processing Text Classification, Natural Language Processing Token Classification, '
            'Natural Language Processing Table Question Answering, Natural Language Processing Question Answering, '
            'Natural Language Processing Zero-Shot Classification, Natural Language Processing Translation, '
            'Natural Language Processing Summarization, Natural Language Processing Conversational, '
            'Natural Language Processing Text Generation, Natural Language Processing Fill-Mask, '
            'Natural Language Processing Text2Text Generation, Natural Language Processing Sentence Similarity, '
            'Audio Text-to-Speech, Audio Automatic Speech Recognition, Audio Audio-to-Audio, '
            'Audio Audio Classification, Audio Voice Activity Detection, Tabular Tabular Classification, '
            'Tabular Tabular Regression, Reinforcement Learning Reinforcement Learning, '
            'Reinforcement Learning Robotics }'
        )
    elif api_name == 'tf':
        api_name = 'tensorhub'
        domains = '1. $DOMAIN is inferred from the task description and should include one of {text-sequence-alignment, text-embedding, text-language-model, text-preprocessing, text-classification, text-generation, text-question-answering, text-retrieval-question-answering, text-segmentation, text-to-mel, image-classification, image-feature-vector, image-object-detection, image-segmentation, image-generator, image-pose-detection, image-rnn-agent, image-augmentation, image-classifier, image-style-transfer, image-aesthetic-quality, image-depth-estimation, image-super-resolution, image-deblurring, image-extrapolation, image-text-recognition, image-dehazing, image-deraining, image-enhancemenmt, image-classification-logits, image-frame-interpolation, image-text-detection, image-denoising, image-others, video-classification, video-feature-extraction, video-generation, video-audio-text, video-text, audio-embedding, audio-event-classification, audio-command-detection, audio-paralinguists-classification, audio-speech-to-text, audio-speech-synthesis, audio-synthesis, audio-pitch-extraction}'
    else:
        # Previously this only printed an error and then failed with an
        # UnboundLocalError on `domains`; fail fast with a clear message instead.
        raise ValueError(
            f"Error: API name is not supported: {api_name!r} (expected 'torch', 'hf' or 'tf')"
        )
    prompt = (
        question
        + '\nWrite a python program in 1 to 2 lines to call API in '
        + api_name
        + '.\n\nThe answer should follow the format: <<<domain>>> $DOMAIN, <<<api_call>>>: $API_CALL, <<<api_provider>>>: $API_PROVIDER, <<<explanation>>>: $EXPLANATION, <<<code>>>: $CODE}. Here are the requirements:\n'
        + domains
        + '\n2. The $API_CALL should have only 1 line of code that calls api.\n3. The $API_PROVIDER should be the programming framework used.\n4. $EXPLANATION should be a step-by-step explanation.\n5. The $CODE is the python code.\n6. Do not repeat the format in your answer.'
    )
    return (
        'You are a helpful API writer who can write APIs based on requirements.\n'
        + prompt
    )
# Local cache directory for downloaded benchmark files: a `data/` folder that
# sits next to this module. It is created eagerly, at import time, so that
# fetch_data can write cache files into it without checking first.
DATA_DIR = os.path.join(os.path.split(__file__)[0], 'data')
os.makedirs(DATA_DIR, exist_ok=True)
def fetch_data(url, filename, timeout=60):
    """Return the text at ``url``, caching it as ``DATA_DIR/filename``.

    If the cache file already exists its contents are returned and no network
    request is made. Otherwise ``url`` is downloaded, the body is written to
    the cache file, and the body is returned.

    Args:
        url: URL to download on a cache miss.
        filename: Cache file name inside ``DATA_DIR``. It must be unique per
            URL; two URLs sharing a name will share (and return) one cached body.
        timeout: Seconds passed to ``requests.get`` so a stalled connection
            cannot hang the caller forever.

    Returns:
        The cached or freshly downloaded text.

    Raises:
        RuntimeError: If the server answers with a status other than 200.
        requests.RequestException: On connection failures or timeouts.
    """
    cache_path = os.path.join(DATA_DIR, filename)
    if os.path.exists(cache_path):
        with open(cache_path, 'r', encoding='utf-8') as f:
            return f.read()

    response = requests.get(url, timeout=timeout)
    if response.status_code != 200:
        raise RuntimeError(
            f'Failed to fetch data from {url} (HTTP {response.status_code})'
        )

    os.makedirs(DATA_DIR, exist_ok=True)
    # Write to a temporary sibling and rename it into place, so an interrupted
    # run never leaves a truncated file that later runs would trust as a cache.
    tmp_path = cache_path + '.tmp'
    with open(tmp_path, 'w', encoding='utf-8') as f:
        f.write(response.text)
    os.replace(tmp_path, cache_path)
    return response.text
def _parse_jsonl(text):
    """Parse JSON Lines ``text`` into a list of objects, ignoring blank lines."""
    return [json.loads(line) for line in text.splitlines() if line.strip()]


def get_data_for_hub(hub: str) -> pd.DataFrame:
    """Load the APIBench evaluation data for one hub as a DataFrame.

    Three Gorilla files are fetched (or read from the local cache) for ``hub``:
    the zero-shot questions, the hub's API database, and the APIBench
    question/answer pairs. Every ``api_call`` in the API database is parsed
    with ``ast_parse``. The hub-specific AST evaluator is then partially
    applied to ``(api_database, qa_pairs, ast_database)``.

    Args:
        hub: ``'hf'`` (HuggingFace), ``'torch'`` (TorchHub) or ``'tf'``
            (TensorFlow Hub).

    Returns:
        A DataFrame with one row per question and columns ``question_id``,
        ``question``, ``api_database``, ``qa_pairs``, ``ast_database``,
        ``ast_eval`` and ``hub``. Every column except the first two repeats
        the same per-hub object on each row. ``ast_eval`` is the partially
        applied evaluator; its remaining arguments are defined by the
        ``ast_eval_*`` modules (presumably the model response — confirm there).

    Raises:
        ValueError: If ``hub`` is not a supported key.
    """
    # hub -> (questions URL, API database URL, APIBench QA URL, AST evaluator)
    sources = {
        'hf': (
            'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/eval/eval-data/questions/huggingface/questions_huggingface_0_shot.jsonl',
            'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/api/huggingface_api.jsonl',
            'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/apibench/huggingface_eval.json',
            ast_eval_hf,
        ),
        'torch': (
            'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/eval/eval-data/questions/torchhub/questions_torchhub_0_shot.jsonl',
            'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/api/torchhub_api.jsonl',
            'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/apibench/torchhub_eval.json',
            ast_eval_th,
        ),
        'tf': (
            'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/eval/eval-data/questions/tensorflowhub/questions_tensorflowhub_0_shot.jsonl',
            'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/api/tensorflowhub_api.jsonl',
            'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/apibench/tensorflow_eval.json',
            ast_eval_tf,
        ),
    }
    if hub not in sources:
        raise ValueError(
            f'Unsupported hub {hub!r}; expected one of {sorted(sources)}'
        )
    question_url, api_url, apibench_url, evaluator = sources[hub]

    # The three files have the same cache names for every hub, so the hub must
    # be part of the name; otherwise loading a second hub would silently
    # return the first hub's cached data.
    question_rows = _parse_jsonl(
        fetch_data(question_url, f'{hub}_question_data.jsonl')
    )
    api_database = _parse_jsonl(fetch_data(api_url, f'{hub}_api_dataset.jsonl'))
    # Despite the .json name, the APIBench file is parsed one JSON object per line.
    qa_pairs = [
        record['api_data']
        for record in _parse_jsonl(fetch_data(apibench_url, f'{hub}_apibench.json'))
    ]

    questions = [row['text'] for row in question_rows]
    question_ids = [row['question_id'] for row in question_rows]

    # Parse every reference API call once; the trees are bound into ast_eval below.
    ast_database = [ast_parse(entry['api_call']) for entry in api_database]
    ast_eval = partial(evaluator, api_database, qa_pairs, ast_database)

    n_rows = len(questions)
    return pd.DataFrame(
        {
            'question_id': question_ids,
            'question': questions,
            'api_database': [api_database] * n_rows,
            'qa_pairs': [qa_pairs] * n_rows,
            'ast_database': [ast_database] * n_rows,
            'ast_eval': [ast_eval] * n_rows,
            'hub': [hub] * n_rows,
        }
    )
|