# datasets.py

import torch
import json
import torch.distributed as dist
import numpy as np
import kaldiio
import librosa
import torchaudio
import time
import logging

from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.register import tables


@tables.register("dataset_classes", "AudioDataset")
class AudioDataset(torch.utils.data.Dataset):
    """
    AudioDataset: yields fbank features and tokenized transcripts, one
    utterance per item; `collator` pads items into batches.
    """

    def __init__(
        self,
        path,
        index_ds: str = None,
        frontend=None,
        tokenizer=None,
        int_pad_value: int = -1,
        float_pad_value: float = 0.0,
        **kwargs,
    ):
        super().__init__()
        # resolve the index dataset (the utterance list) from the registry
        index_ds_class = tables.index_ds_classes.get(index_ds)
        self.index_ds = index_ds_class(path)

        # optional registered preprocessor applied to the raw waveform
        preprocessor_speech = kwargs.get("preprocessor_speech", None)
        if preprocessor_speech:
            preprocessor_speech_class = tables.preprocessor_speech_classes.get(preprocessor_speech)
            preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf"))
        self.preprocessor_speech = preprocessor_speech

        # optional registered preprocessor applied to the transcript text
        preprocessor_text = kwargs.get("preprocessor_text", None)
        if preprocessor_text:
            preprocessor_text_class = tables.preprocessor_text_classes.get(preprocessor_text)
            preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf"))
        self.preprocessor_text = preprocessor_text

        self.frontend = frontend
        # default to 16 kHz when no frontend provides a sampling rate
        self.fs = 16000 if frontend is None else frontend.fs
        self.data_type = "sound"
        self.tokenizer = tokenizer
        self.int_pad_value = int_pad_value
        self.float_pad_value = float_pad_value

    def get_source_len(self, index):
        item = self.index_ds[index]
        return self.index_ds.get_source_len(item)

    def get_target_len(self, index):
        item = self.index_ds[index]
        return self.index_ds.get_target_len(item)

    def __len__(self):
        return len(self.index_ds)

    def __getitem__(self, index):
        item = self.index_ds[index]

        # load the raw waveform, resampled to the frontend's sampling rate
        source = item["source"]
        data_src = load_audio_text_image_video(source, fs=self.fs)
        if self.preprocessor_speech:
            data_src = self.preprocessor_speech(data_src)
        speech, speech_lengths = extract_fbank(
            data_src, data_type=self.data_type, frontend=self.frontend
        )  # speech: [b, T, d]

        # tokenize the transcript
        target = item["target"]
        if self.preprocessor_text:
            target = self.preprocessor_text(target)
        ids = self.tokenizer.encode(target)
        ids_lengths = len(ids)
        text = torch.tensor(ids, dtype=torch.int64)
        text_lengths = torch.tensor([ids_lengths], dtype=torch.int32)

        return {
            "speech": speech[0, :, :],
            "speech_lengths": speech_lengths,
            "text": text,
            "text_lengths": text_lengths,
        }
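
    # One item is a dict of tensors; with an 80-dim fbank frontend (an
    # assumption for illustration, the actual dim comes from the frontend):
    #   speech:         [T, 80]  float
    #   speech_lengths: [1]      int
    #   text:           [L]      int64
    #   text_lengths:   [1]      int32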

    def collator(self, samples: list = None):
        # gather per-key lists across the batch
        outputs = {}
        for sample in samples:
            for key in sample.keys():
                if key not in outputs:
                    outputs[key] = []
                outputs[key].append(sample[key])

        # pad each field to the longest item in the batch: integer tensors
        # (token ids) use int_pad_value, float tensors use float_pad_value
        for key, data_list in outputs.items():
            if data_list[0].dtype == torch.int64:
                pad_value = self.int_pad_value
            else:
                pad_value = self.float_pad_value
            outputs[key] = torch.nn.utils.rnn.pad_sequence(
                data_list, batch_first=True, padding_value=pad_value
            )
        return outputs
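

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original file). The jsonl path,
# "IndexDSJsonl", and the frontend/tokenizer objects below are assumptions:
# any registered index_ds class, a FunASR frontend, and a tokenizer exposing
# encode(str) -> list[int] would do.
# ---------------------------------------------------------------------------
# from torch.utils.data import DataLoader
#
# dataset = AudioDataset(
#     "data/train.jsonl",          # hypothetical index file
#     index_ds="IndexDSJsonl",     # assumed registered index_ds class
#     frontend=frontend,           # e.g. a WavFrontend instance
#     tokenizer=tokenizer,
# )
# loader = DataLoader(dataset, batch_size=8, collate_fn=dataset.collator)
# batch = next(iter(loader))      # dict of padded "speech", "text", ... tensors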