utils.py

import json
import os
from functools import partial

import pandas as pd
import requests

from ast_eval_hf import ast_eval_hf, ast_parse
from ast_eval_tf import ast_eval_tf
from ast_eval_th import ast_eval_th


# This function is modified from Gorilla's APIBench implementation
# (https://github.com/ShishirPatil/gorilla/blob/main/eval/get_llm_responses.py).
def encode_question(question, api_name):
    """Encode multiple prompt instructions into a single string."""
    prompts = []
    if api_name == 'torch':
        api_name = 'torchhub'
        domains = '1. $DOMAIN is inferred from the task description and should include one of {Classification, Semantic Segmentation, Object Detection, Audio Separation, Video Classification, Text-to-Speech}.'
    elif api_name == 'hf':
        api_name = 'huggingface'
        domains = '1. $DOMAIN should include one of {Multimodal Feature Extraction, Multimodal Text-to-Image, Multimodal Image-to-Text, Multimodal Text-to-Video, \
Multimodal Visual Question Answering, Multimodal Document Question Answer, Multimodal Graph Machine Learning, Computer Vision Depth Estimation,\
Computer Vision Image Classification, Computer Vision Object Detection, Computer Vision Image Segmentation, Computer Vision Image-to-Image, \
Computer Vision Unconditional Image Generation, Computer Vision Video Classification, Computer Vision Zero-Shor Image Classification, \
Natural Language Processing Text Classification, Natural Language Processing Token Classification, Natural Language Processing Table Question Answering, \
Natural Language Processing Question Answering, Natural Language Processing Zero-Shot Classification, Natural Language Processing Translation, \
Natural Language Processing Summarization, Natural Language Processing Conversational, Natural Language Processing Text Generation, Natural Language Processing Fill-Mask,\
Natural Language Processing Text2Text Generation, Natural Language Processing Sentence Similarity, Audio Text-to-Speech, Audio Automatic Speech Recognition, \
Audio Audio-to-Audio, Audio Audio Classification, Audio Voice Activity Detection, Tabular Tabular Classification, Tabular Tabular Regression, \
Reinforcement Learning Reinforcement Learning, Reinforcement Learning Robotics }'
    elif api_name == 'tf':
        api_name = 'tensorhub'
        domains = '1. $DOMAIN is inferred from the task description and should include one of {text-sequence-alignment, text-embedding, text-language-model, text-preprocessing, text-classification, text-generation, text-question-answering, text-retrieval-question-answering, text-segmentation, text-to-mel, image-classification, image-feature-vector, image-object-detection, image-segmentation, image-generator, image-pose-detection, image-rnn-agent, image-augmentation, image-classifier, image-style-transfer, image-aesthetic-quality, image-depth-estimation, image-super-resolution, image-deblurring, image-extrapolation, image-text-recognition, image-dehazing, image-deraining, image-enhancemenmt, image-classification-logits, image-frame-interpolation, image-text-detection, image-denoising, image-others, video-classification, video-feature-extraction, video-generation, video-audio-text, video-text, audio-embedding, audio-event-classification, audio-command-detection, audio-paralinguists-classification, audio-speech-to-text, audio-speech-synthesis, audio-synthesis, audio-pitch-extraction}'
    else:
        # Raise instead of only printing: otherwise `domains` would be undefined below.
        raise ValueError(f'API name {api_name!r} is not supported.')

    prompt = (
        question
        + '\nWrite a python program in 1 to 2 lines to call API in '
        + api_name
        + '.\n\nThe answer should follow the format: <<<domain>>> $DOMAIN, <<<api_call>>>: $API_CALL, <<<api_provider>>>: $API_PROVIDER, <<<explanation>>>: $EXPLANATION, <<<code>>>: $CODE}. Here are the requirements:\n'
        + domains
        + '\n2. The $API_CALL should have only 1 line of code that calls api.\n3. The $API_PROVIDER should be the programming framework used.\n4. $EXPLANATION should be a step-by-step explanation.\n5. The $CODE is the python code.\n6. Do not repeat the format in your answer.'
    )
    # prompts.append({"role": "system", "content": ""})
    prompts = (
        'You are a helpful API writer who can write APIs based on requirements.\n'
        + prompt
    )
    return prompts
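
# Illustrative usage (the question text below is made up; 'torch' selects the TorchHub prompt):
#   prompt = encode_question('I need an API to classify images of cats and dogs.', 'torch')
# The result is a single string: the "helpful API writer" preamble, the question, and the
# numbered formatting requirements, ready to be sent to a model as one message.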

DATA_DIR = os.path.join(os.path.dirname(__file__), 'data')
os.makedirs(DATA_DIR, exist_ok=True)


def fetch_data(url, filename):
    """Download `url` and cache it as DATA_DIR/filename; later calls reuse the cached copy."""
    cache_path = os.path.join(DATA_DIR, filename)
    if os.path.exists(cache_path):
        with open(cache_path, 'r') as f:
            return f.read()
    else:
        response = requests.get(url)
        if response.status_code == 200:
            with open(cache_path, 'w') as f:
                f.write(response.text)
            return response.text
        else:
            raise Exception(f'Failed to fetch data from {url}')
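
# Illustrative usage (real Gorilla URL from below; the cache filename is arbitrary):
#   text = fetch_data(
#       'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/api/huggingface_api.jsonl',
#       'huggingface_api.jsonl',
#   )
# A second call with the same filename reads DATA_DIR/huggingface_api.jsonl instead of
# hitting the network.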


def get_data_for_hub(hub: str):
    """Build a per-question DataFrame of Gorilla eval data for one hub: 'hf', 'torch', or 'tf'."""
    if hub == 'hf':
        question_data = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/eval/eval-data/questions/huggingface/questions_huggingface_0_shot.jsonl'
        api_dataset = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/api/huggingface_api.jsonl'
        apibench = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/apibench/huggingface_eval.json'
        ast_eval = ast_eval_hf
    elif hub == 'torch':
        question_data = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/eval/eval-data/questions/torchhub/questions_torchhub_0_shot.jsonl'
        api_dataset = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/api/torchhub_api.jsonl'
        apibench = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/apibench/torchhub_eval.json'
        ast_eval = ast_eval_th
    elif hub == 'tf':
        question_data = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/eval/eval-data/questions/tensorflowhub/questions_tensorflowhub_0_shot.jsonl'
        api_dataset = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/api/tensorflowhub_api.jsonl'
        apibench = 'https://raw.githubusercontent.com/ShishirPatil/gorilla/main/data/apibench/tensorflow_eval.json'
        ast_eval = ast_eval_tf
    else:
        raise ValueError(f'Hub {hub!r} is not supported; expected "hf", "torch", or "tf".')

    # Cache under hub-specific filenames so data for different hubs does not collide.
    question_data = fetch_data(question_data, f'{hub}_question_data.jsonl')
    api_dataset = fetch_data(api_dataset, f'{hub}_api_dataset.jsonl')
    apibench = fetch_data(apibench, f'{hub}_apibench.json')

    # Parse question data
    questions = []
    question_ids = []
    for line in question_data.splitlines():
        data = json.loads(line)
        questions.append(data['text'])
        question_ids.append(data['question_id'])

    # Parse API dataset
    api_database = [json.loads(line) for line in api_dataset.splitlines()]

    # Parse question-answer pairs
    qa_pairs = [json.loads(line)['api_data'] for line in apibench.splitlines()]

    # Parse all APIs to AST trees
    ast_database = []
    for data in api_database:
        ast_tree = ast_parse(data['api_call'])
        ast_database.append(ast_tree)

    # Bind the reference data to the hub-specific evaluator.
    ast_eval = partial(ast_eval, api_database, qa_pairs, ast_database)

    return pd.DataFrame(
        {
            'question_id': question_ids,
            'question': questions,
            'api_database': [api_database] * len(questions),
            'qa_pairs': [qa_pairs] * len(questions),
            'ast_database': [ast_database] * len(questions),
            'ast_eval': [ast_eval] * len(questions),
            'hub': [hub] * len(questions),
        }
    )
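

# Minimal end-to-end sketch (illustrative: it assumes the ast_eval_* modules are importable
# and the Gorilla raw.githubusercontent.com URLs above are reachable).
if __name__ == '__main__':
    df = get_data_for_hub('hf')
    row = df.iloc[0]
    prompt = encode_question(row['question'], row['hub'])
    print(prompt[:200])
    # A real run would generate a model response for `prompt` and score it with
    # row['ast_eval'](...); the remaining arguments depend on ast_eval_hf's signature,
    # which lives in ast_eval_hf.py and is not shown here.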