# dataset_jsonl.py
  1. import torch
  2. import json
  3. import torch.distributed as dist
  4. import numpy as np
  5. import kaldiio
  6. import librosa
  7. def load_audio(audio_path: str, fs: int=16000):
  8. audio = None
  9. if audio_path.startswith("oss:"):
  10. pass
  11. elif audio_path.startswith("odps:"):
  12. pass
  13. else:
  14. if ".ark:" in audio_path:
  15. audio = kaldiio.load_mat(audio_path)
  16. else:
  17. audio, fs = librosa.load(audio_path, sr=fs)
  18. return audio
  19. def extract_features(data, date_type: str="sound", frontend=None):
  20. if date_type == "sound":
  21. feat, feats_lens = frontend(data, len(data))
  22. feat = feat[0, :, :]
  23. else:
  24. feat, feats_lens = torch.from_numpy(data).to(torch.float32), torch.tensor([data.shape[0]]).to(torch.int32)
  25. return feat, feats_lens
  26. class IndexedDatasetJsonl(torch.utils.data.Dataset):
  27. def __init__(self, path):
  28. super().__init__()
  29. # data_parallel_size = dist.get_world_size()
  30. data_parallel_size = 1
  31. contents = []
  32. with open(path, encoding='utf-8') as fin:
  33. for line in fin:
  34. data = json.loads(line.strip())
  35. if "text" in data: # for sft
  36. self.contents.append(data['text'])
  37. if "source" in data: # for speech lab pretrain
  38. prompt = data["prompt"]
  39. source = data["source"]
  40. target = data["target"]
  41. source_len = data["source_len"]
  42. target_len = data["target_len"]
  43. contents.append({"source": source,
  44. "prompt": prompt,
  45. "target": target,
  46. "source_len": source_len,
  47. "target_len": target_len,
  48. }
  49. )
  50. self.contents = []
  51. total_num = len(contents)
  52. num_per_rank = total_num // data_parallel_size
  53. # rank = dist.get_rank()
  54. rank = 0
  55. # import ipdb; ipdb.set_trace()
  56. self.contents = contents[rank * num_per_rank:(rank + 1) * num_per_rank]
  57. def __len__(self):
  58. return len(self.contents)
  59. def __getitem__(self, index):
  60. return self.contents[index]
  61. class AudioDataset(torch.utils.data.Dataset):
  62. def __init__(self, path, frontend=None, tokenizer=None):
  63. super().__init__()
  64. self.indexed_dataset = IndexedDatasetJsonl(path)
  65. self.frontend = frontend.forward
  66. self.fs = 16000 if frontend is None else frontend.fs
  67. self.data_type = "sound"
  68. self.tokenizer = tokenizer
  69. self.int_pad_value = -1
  70. self.float_pad_value = 0.0
  71. def __len__(self):
  72. return len(self.indexed_dataset)
  73. def __getitem__(self, index):
  74. item = self.indexed_dataset[index]
  75. source = item["source"]
  76. data_src = load_audio(source, fs=self.fs)
  77. speech, speech_lengths = extract_features(data_src, self.data_type, self.frontend)
  78. target = item["target"]
  79. text = self.tokenizer.encode(target)
  80. text_lengths = len(text)
  81. text, text_lengths = torch.tensor(text, dtype=torch.int64), torch.tensor([text_lengths], dtype=torch.int32)
  82. return {"speech": speech,
  83. "speech_lengths": speech_lengths,
  84. "text": text,
  85. "text_lengths": text_lengths,
  86. }
  87. def collator(self, samples: list=None):
  88. outputs = {}
  89. for sample in samples:
  90. for key in sample.keys():
  91. if key not in outputs:
  92. outputs[key] = []
  93. outputs[key].append(sample[key])
  94. for key, data_list in outputs.items():
  95. if data_list[0].dtype.kind == "i":
  96. pad_value = self.int_pad_value
  97. else:
  98. pad_value = self.float_pad_value
  99. outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value)
  100. return samples