scp2jsonl.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. import os
  2. import json
  3. import torch
  4. import logging
  5. import hydra
  6. from omegaconf import DictConfig, OmegaConf
  7. import concurrent.futures
  8. import librosa
  9. import torch.distributed as dist
  10. def gen_jsonl_from_wav_text_list(path, data_type_list=("source", "target"), jsonl_file_out:str=None, **kwargs):
  11. try:
  12. rank = dist.get_rank()
  13. world_size = dist.get_world_size()
  14. except:
  15. rank = 0
  16. world_size = 1
  17. cpu_cores = os.cpu_count() or 1
  18. print(f"convert wav.scp text to jsonl, ncpu: {cpu_cores}")
  19. if rank == 0:
  20. json_dict = {}
  21. for data_type, data_file in zip(data_type_list, path):
  22. json_dict[data_type] = {}
  23. with open(data_file, "r") as f:
  24. data_file_lists = f.readlines()
  25. lines_for_each_th = (len(data_file_lists)-1)//cpu_cores + 1
  26. task_num = cpu_cores if len(data_file_lists) > cpu_cores else 1
  27. with concurrent.futures.ThreadPoolExecutor(max_workers=cpu_cores) as executor:
  28. futures = [executor.submit(parse_context_length, data_file_lists[i*lines_for_each_th:(i+1)*lines_for_each_th], data_type) for i in range(task_num)]
  29. for future in concurrent.futures.as_completed(futures):
  30. json_dict[data_type].update(future.result())
  31. # print(json_dict)
  32. with open(jsonl_file_out, "w") as f:
  33. for key in json_dict[data_type_list[0]].keys():
  34. jsonl_line = {"key": key}
  35. for data_file in data_type_list:
  36. jsonl_line.update(json_dict[data_file][key])
  37. jsonl_line = json.dumps(jsonl_line, ensure_ascii=False)
  38. f.write(jsonl_line+"\n")
  39. f.flush()
  40. else:
  41. pass
  42. if world_size > 1:
  43. dist.barrier()
  44. def parse_context_length(data_list: list, data_type: str):
  45. res = {}
  46. for i, line in enumerate(data_list):
  47. key, line = line.strip().split(maxsplit=1)
  48. line = line.strip()
  49. if os.path.exists(line):
  50. waveform, _ = librosa.load(line, sr=16000)
  51. sample_num = len(waveform)
  52. context_len = int(sample_num//16000*1000/10)
  53. else:
  54. context_len = len(line.split()) if " " in line else len(line)
  55. res[key] = {data_type: line, f"{data_type}_len": context_len}
  56. return res
  57. @hydra.main(config_name=None, version_base=None)
  58. def main_hydra(cfg: DictConfig):
  59. kwargs = OmegaConf.to_container(cfg, resolve=True)
  60. scp_file_list = kwargs.get("scp_file_list", ("/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"))
  61. if isinstance(scp_file_list, str):
  62. scp_file_list = eval(scp_file_list)
  63. data_type_list = kwargs.get("data_type_list", ("source", "target"))
  64. jsonl_file_out = kwargs.get("jsonl_file_out", "/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl")
  65. gen_jsonl_from_wav_text_list(scp_file_list, data_type_list=data_type_list, jsonl_file_out=jsonl_file_out)
  66. """
  67. python -m funasr.datasets.audio_datasets.scp2jsonl \
  68. ++scp_file_list='["/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"]' \
  69. ++data_type_list='["source", "target"]' \
  70. ++jsonl_file_out=/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl
  71. """
  72. if __name__ == "__main__":
  73. main_hydra()