utils.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. import json
  2. import os
  3. import re
  4. import string
  5. import zipfile
  6. import requests
  7. def download_data(dir):
  8. import gdown
  9. data_path = os.path.join(dir, 'data/external_corpus')
  10. if os.path.exists(data_path):
  11. return data_path
  12. url = 'https://drive.google.com/uc?id=1zRbHzPW2x4dDcfmphBWlan8cxUCRNmqk'
  13. zip_path = os.path.join(dir, 'data.zip')
  14. gdown.download(url, zip_path, quiet=False)
  15. with zipfile.ZipFile(zip_path, 'r') as zip_ref:
  16. zip_ref.extractall(os.path.join(dir, 'data'))
  17. if os.path.exists(zip_path):
  18. os.remove(zip_path)
  19. print(f'Data saved to {data_path}')
  20. return data_path
  21. def download_tools(dir, wolfram_alpha_appid='YOUR_WOLFRAMALPHA_APPID'):
  22. tool_path = os.path.join(dir, 'tools')
  23. if os.path.exists(tool_path):
  24. return tool_path
  25. os.mkdir(tool_path)
  26. tools = [
  27. 'code/sql_interpreter.py',
  28. 'graph/graphtools.py',
  29. 'math/calculator.py',
  30. 'table/mysql_db_create.py',
  31. 'table/tabtools.py',
  32. 'text/agenda_retriever.py',
  33. 'text/scirex_retriever.py',
  34. ]
  35. for tool in tools:
  36. url = f'https://raw.githubusercontent.com/night-chen/ToolQA/main/benchmark/ReAct/code/tools/{tool}'
  37. response = requests.get(url)
  38. output_file = os.path.join(tool_path, tool.split('/')[1])
  39. with open(output_file, 'wb') as f:
  40. f.write(response.content)
  41. print(f'Tool saved to {output_file}')
  42. with open(os.path.join(tool_path, 'calculator.py'), 'r') as f:
  43. content = f.read()
  44. new_content = content.replace('YOUR_WOLFRAMALPHA_APPID', wolfram_alpha_appid)
  45. with open(os.path.join(tool_path, 'calculator.py'), 'w') as f:
  46. f.write(new_content)
  47. with open(os.path.join(tool_path, 'agenda_retriever.py'), 'r') as f:
  48. content = f.read()
  49. new_content = content.replace('/<YOUR_OWN_PATH>/ToolQA/', '')
  50. with open(os.path.join(tool_path, 'agenda_retriever.py'), 'w') as f:
  51. f.write(new_content)
  52. with open(os.path.join(tool_path, 'mysql_db_create.py'), 'r') as f:
  53. content = f.read()
  54. new_content = content.replace('/<YOUR_OWN_PATH>/ToolQA/', '')
  55. with open(os.path.join(tool_path, 'mysql_db_create.py'), 'w') as f:
  56. f.write(new_content)
  57. with open(os.path.join(tool_path, 'scirex_retriever.py'), 'r') as f:
  58. content = f.read()
  59. new_content = content.replace('/<YOUR_OWN_PATH>/ToolQA/', '')
  60. with open(os.path.join(tool_path, 'scirex_retriever.py'), 'w') as f:
  61. f.write(new_content)
  62. LOCAL_DATA_DIR = os.path.join(os.path.dirname(__file__), 'data')
  63. def get_data(dataset, hardness):
  64. data_path = os.path.join(LOCAL_DATA_DIR, f'{dataset}-{hardness}.jsonl')
  65. if os.path.exists(data_path):
  66. print(f'Loading data from {data_path}')
  67. with open(data_path, 'r') as f:
  68. return json.load(f)
  69. else:
  70. print(
  71. f'Downloading data from https://raw.githubusercontent.com/night-chen/ToolQA/main/data/questions/{hardness}/{dataset}-{hardness}.jsonl'
  72. )
  73. data = []
  74. url = f'https://raw.githubusercontent.com/night-chen/ToolQA/main/data/questions/{hardness}/{dataset}-{hardness}.jsonl'
  75. url = requests.get(url)
  76. if url.status_code == 200:
  77. lines = url.text.splitlines()
  78. for line in lines:
  79. data.append(json.loads(line))
  80. with open(data_path, 'w') as f:
  81. json.dump(data, f)
  82. print(f'Data saved to {data_path}')
  83. return data
  84. REACT_INSTRUCTION = """Use tools in the tools directory to solve the task: {question}
  85. You could use all tools which are under the tools/ directory and all the data under the data/ directory.
  86. When you think you finished the task, respond with `Finish[answer]` where you include your answer in `[]`.
  87. IMPORTANT: Make sure that in your final answer, you should not print any additional text/instructions other than the actual answer, which should be a word or a simple phrase.
  88. """
  89. def encode_question(question):
  90. return REACT_INSTRUCTION.format(question=question)
  91. # imported from https://github.com/night-chen/ToolQA/tree/main/benchmark/ReAct/code/agents_chatgpt.py
  92. def normalize_answer(s):
  93. def remove_articles(text):
  94. return re.sub(r'\b(a|an|the|usd)\b', ' ', text)
  95. def white_space_fix(text):
  96. return ' '.join(text.split())
  97. def remove_punc(text):
  98. exclude = set(string.punctuation)
  99. return ''.join(ch for ch in text if ch not in exclude)
  100. def lower(text):
  101. return text.lower()
  102. return white_space_fix(remove_articles(remove_punc(lower(s))))
  103. def eval_answer(pred, answer):
  104. pattern = r'Finish\[(.*?)\]'
  105. match = re.search(pattern, pred)
  106. if match:
  107. pred = match.group(1)
  108. return normalize_answer(pred) == normalize_answer(answer)