split_data.py 1.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. import os
  2. import sys
  3. import random
  # Usage: split_data.py <in_dir> <out_dir> <num_split>
  #   in_dir    directory holding the "wav.scp" and "text" files to split
  #   out_dir   directory under which split<num_split>/<k>/ subdirectories are written
  #   num_split number of parts; must be a positive integer
  # Report a usage line instead of an IndexError traceback when arguments are missing.
  if len(sys.argv) < 4:
      sys.exit("usage: {} <in_dir> <out_dir> <num_split>".format(sys.argv[0]))
  in_dir = sys.argv[1]
  out_dir = sys.argv[2]
  try:
      num_split = int(sys.argv[3])
  except ValueError:
      sys.exit("num_split must be an integer, got {!r}".format(sys.argv[3]))
  # Zero would divide by zero when chunking, and a negative count would silently
  # produce no output, so reject both here.
  if num_split < 1:
      sys.exit("num_split must be >= 1, got {}".format(num_split))
  def split_scp(scp, num):
      """Split the sequence *scp* into *num* consecutive chunks.

      Every chunk except the last holds ``len(scp) // num`` items. The last
      chunk also takes the remainder, so it may be larger than the others.

      Args:
          scp: Sliceable sequence (e.g. list or tuple) of entries to split.
          num: Number of chunks; must satisfy ``1 <= num <= len(scp)``.

      Returns:
          A list of ``num`` slices of *scp*, in order, that together cover
          every item exactly once. Each slice has the same type as *scp*.

      Raises:
          ValueError: If ``num < 1`` or ``num > len(scp)``.
      """
      # Use explicit checks rather than ``assert``, which is stripped under
      # ``python -O``. num <= 0 must be rejected: 0 would divide by zero, and a
      # negative count would silently yield no chunks.
      if num < 1:
          raise ValueError("num must be >= 1, got {}".format(num))
      if len(scp) < num:
          raise ValueError(
              "cannot split {} entries into {} chunks".format(len(scp), num))
      avg = len(scp) // num
      # Chunks 0 .. num-2 take exactly ``avg`` items; the final chunk takes the rest.
      out = [scp[i * avg:(i + 1) * avg] for i in range(num - 1)]
      out.append(scp[(num - 1) * avg:])
      return out
  def _read_entries(path):
      """Return the lines of *path* as raw bytes, each guaranteed to end in b"\\n".

      Binary mode copies the bytes through unchanged, so the result does not
      depend on the locale's default text encoding.
      """
      with open(path, "rb") as infile:
          lines = infile.readlines()
      # Only the final line can lack a newline. After shuffling it may land in the
      # middle of a chunk and be glued onto the next entry, so add one if missing.
      return [line if line.endswith(b"\n") else line + b"\n" for line in lines]


  # Check inputs up front so a missing file produces a clear message.
  wav_path = os.path.join(in_dir, "wav.scp")
  text_path = os.path.join(in_dir, "text")
  for required in (wav_path, text_path):
      if not os.path.isfile(required):
          sys.exit("input file not found: {}".format(required))

  wav_list = _read_entries(wav_path)
  text_list = _read_entries(text_path)
  if len(wav_list) != len(text_list):
      sys.exit("line count mismatch: {} has {} lines, {} has {}".format(
          wav_path, len(wav_list), text_path, len(text_list)))

  # Entries are paired by position: line i of wav.scp goes with line i of text.
  # NOTE(review): the pairing is positional only and IDs are not cross-checked,
  # so both files must list entries in the same order -- confirm upstream.
  # Shuffling the pairs (rather than each list) keeps every chunk aligned.
  # No random seed is set here, so each run produces a different split.
  pairs = list(zip(wav_list, text_list))
  random.shuffle(pairs)

  num_split = int(num_split)
  if num_split < 1 or num_split > len(pairs):
      sys.exit("cannot split {} entries into {} parts".format(len(pairs), num_split))

  # Output layout: <out_dir>/split<num_split>/<k>/{wav.scp,text} for k = 1..num_split.
  split_dir = os.path.join(out_dir, "split" + str(num_split))
  for idx, chunk in enumerate(split_scp(pairs, num_split), 1):
      path = os.path.join(split_dir, str(idx))
      if not os.path.isdir(path):
          os.makedirs(path)
      with open(os.path.join(path, "wav.scp"), "wb") as wav_writer:
          wav_writer.writelines(wav for wav, _ in chunk)
      with open(os.path.join(path, "text"), "wb") as text_writer:
          text_writer.writelines(text for _, text in chunk)