utils.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. import json
  2. import os
  3. import re
  4. import string
  5. import zipfile
  6. import gdown
  7. import requests
  8. def download_data(dir):
  9. data_path = os.path.join(dir, 'data/external_corpus')
  10. if os.path.exists(data_path):
  11. return data_path
  12. url = 'https://drive.google.com/uc?id=1zRbHzPW2x4dDcfmphBWlan8cxUCRNmqk'
  13. zip_path = os.path.join(dir, 'data.zip')
  14. gdown.download(url, zip_path, quiet=False)
  15. with zipfile.ZipFile(zip_path, 'r') as zip_ref:
  16. zip_ref.extractall(os.path.join(dir, 'data'))
  17. if os.path.exists(zip_path):
  18. os.remove(zip_path)
  19. return data_path
  20. def download_tools(dir, wolfram_alpha_appid='YOUR_WOLFRAMALPHA_APPID'):
  21. tool_path = os.path.join(dir, 'tools')
  22. if os.path.exists(tool_path):
  23. return tool_path
  24. os.mkdir(tool_path)
  25. tools = [
  26. 'code/sql_interpreter.py',
  27. 'graph/graphtools.py',
  28. 'math/calculator.py',
  29. 'table/mysql_db_create.py',
  30. 'table/tabtools.py',
  31. 'text/agenda_retriever.py',
  32. 'text/scirex_retriever.py',
  33. ]
  34. for tool in tools:
  35. url = f'https://raw.githubusercontent.com/night-chen/ToolQA/main/benchmark/ReAct/code/tools/{tool}'
  36. response = requests.get(url)
  37. output_file = os.path.join(tool_path, tool.split('/')[1])
  38. with open(output_file, 'wb') as f:
  39. f.write(response.content)
  40. with open(os.path.join(tool_path, 'calculator.py'), 'r') as f:
  41. content = f.read()
  42. new_content = content.replace('YOUR_WOLFRAMALPHA_APPID', wolfram_alpha_appid)
  43. with open(os.path.join(tool_path, 'calculator.py'), 'w') as f:
  44. f.write(new_content)
  45. with open(os.path.join(tool_path, 'agenda_retriever.py'), 'r') as f:
  46. content = f.read()
  47. new_content = content.replace('/<YOUR_OWN_PATH>/ToolQA/', '')
  48. with open(os.path.join(tool_path, 'agenda_retriever.py'), 'w') as f:
  49. f.write(new_content)
  50. with open(os.path.join(tool_path, 'mysql_db_create.py'), 'r') as f:
  51. content = f.read()
  52. new_content = content.replace('/<YOUR_OWN_PATH>/ToolQA/', '')
  53. with open(os.path.join(tool_path, 'mysql_db_create.py'), 'w') as f:
  54. f.write(new_content)
  55. with open(os.path.join(tool_path, 'scirex_retriever.py'), 'r') as f:
  56. content = f.read()
  57. new_content = content.replace('/<YOUR_OWN_PATH>/ToolQA/', '')
  58. with open(os.path.join(tool_path, 'scirex_retriever.py'), 'w') as f:
  59. f.write(new_content)
  60. def get_data(dataset, hardness):
  61. data = []
  62. url = f'https://raw.githubusercontent.com/night-chen/ToolQA/main/data/questions/{hardness}/{dataset}-{hardness}.jsonl'
  63. url = requests.get(url)
  64. if url.status_code == 200:
  65. lines = url.text.splitlines()
  66. for line in lines:
  67. data.append(json.loads(line))
  68. return data
  69. REACT_INSTRUCTION = """Use tools in the tools directory to solve the task: {question}
  70. You could use all tools which are under the tools/ directory and all the data under the data/ directory.
  71. When you think you finished the task, respond with `Finish[answer]` where you include your answer in `[]`.
  72. IMPORTANT: Make sure that in your final answer, you should not print any additional text/instructions other than the actual answer, which should be a word or a simple phrase.
  73. """
  74. def encode_question(question):
  75. return REACT_INSTRUCTION.format(question=question)
  76. # imported from https://github.com/night-chen/ToolQA/tree/main/benchmark/ReAct/code/agents_chatgpt.py
  77. def normalize_answer(s):
  78. def remove_articles(text):
  79. return re.sub(r'\b(a|an|the|usd)\b', ' ', text)
  80. def white_space_fix(text):
  81. return ' '.join(text.split())
  82. def remove_punc(text):
  83. exclude = set(string.punctuation)
  84. return ''.join(ch for ch in text if ch not in exclude)
  85. def lower(text):
  86. return text.lower()
  87. return white_space_fix(remove_articles(remove_punc(lower(s))))
  88. def eval_answer(pred, answer):
  89. pattern = r'Finish\[(.*?)\]'
  90. match = re.search(pattern, pred)
  91. if match:
  92. pred = match.group(1)
  93. return normalize_answer(pred) == normalize_answer(answer)