datasets.py

import copy

import torch

from funasr.register import tables
from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video


@tables.register("dataset_classes", "AudioLLMDataset")
class AudioLLMDataset(torch.utils.data.Dataset):
    """
    AudioLLMDataset
    """

    def __init__(
        self,
        path,
        index_ds: str = None,
        frontend=None,
        tokenizer=None,
        int_pad_value: int = -1,
        float_pad_value: float = 0.0,
        **kwargs
    ):
        super().__init__()
        index_ds_class = tables.index_ds_classes.get(index_ds)
        self.index_ds = index_ds_class(path, **kwargs)

        preprocessor_speech = kwargs.get("preprocessor_speech", None)
        if preprocessor_speech:
            preprocessor_speech_class = tables.preprocessor_classes.get(preprocessor_speech)
            preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf", {}))
        self.preprocessor_speech = preprocessor_speech

        preprocessor_text = kwargs.get("preprocessor_text", None)
        if preprocessor_text:
            preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text)
            preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf", {}))
        self.preprocessor_text = preprocessor_text

        self.frontend = frontend
        self.fs = 16000 if frontend is None else frontend.fs
        self.data_type = "sound"
        self.tokenizer = tokenizer
        self.float_pad_value = float_pad_value

        self.prompt = kwargs.get("prompt", "Transcribe speech to text.")
        # Full template: "USER: \nINSTRUCTION: {}\nINPUT: {}\nASSISTANT: "
        self.prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(self.prompt)
        self.prompt_af = ""
        self.IGNORE_INDEX = kwargs.get("IGNORE_INDEX", -100)
        # Pad integer tensors with the ignore index so padded positions are
        # skipped by the loss; this overrides the int_pad_value argument.
        self.int_pad_value = self.IGNORE_INDEX

    def get_source_len(self, index):
        item = self.index_ds[index]
        return self.index_ds.get_source_len(item)

    def get_target_len(self, index):
        item = self.index_ds[index]
        return self.index_ds.get_target_len(item)

    def __len__(self):
        return len(self.index_ds)

    def __getitem__(self, index):
        item = self.index_ds[index]
        source = item["source"]
        data_src = load_audio_text_image_video(source, fs=self.fs)
        if self.preprocessor_speech:
            data_src = self.preprocessor_speech(data_src, fs=self.fs)
        speech, speech_lengths = extract_fbank(
            data_src, data_type=self.data_type, frontend=self.frontend, is_final=True
        )  # speech: [b, T, d]
        speech = speech.squeeze(0)

        target = item["target"]
        if self.preprocessor_text:
            target = self.preprocessor_text(target)

        prompt_ids_pre = self.tokenizer.encode(self.prompt_pre)  # [bos, prompt]
        prompt_pre_length = len(prompt_ids_pre)
        prompt_input = "{}{}".format(self.prompt_pre, target)
        prompt_input_ids = self.tokenizer.encode(prompt_input)
        audio_length = len(prompt_input_ids) - prompt_pre_length

        input_ids = prompt_input_ids + [self.tokenizer.pad_token_id]
        input_ids = torch.tensor(input_ids, dtype=torch.int64)  # [bos, prompt, input, pad]
        input_ids[prompt_pre_length:] = -1  # [bos, prompt, -1, -1]
        attention_mask = input_ids.ge(-1)  # [True, True, True, True], length mask

        prompt_answer = "{}{}".format(self.prompt_pre, target)
        prompt_answer_ids = self.tokenizer.encode(prompt_answer)
        answer_length = len(prompt_answer_ids) - prompt_pre_length  # currently unused

        labels_ids = copy.deepcopy(prompt_input_ids) + [self.tokenizer.eos_token_id]
        labels_ids = torch.tensor(labels_ids, dtype=torch.int64)  # [bos, prompt, input, eos]
        labels_ids[:prompt_pre_length] = -1  # [-1, -1, input, eos]
        label_mask = labels_ids.ge(0)  # [False, False, True, True]
        labels_ids[~label_mask] = self.IGNORE_INDEX  # [-100, -100, input, eos]

        audio_mask = [0] * prompt_pre_length + [1] * audio_length + [0]
        audio_mask = torch.tensor(audio_mask, dtype=torch.float32)

        ids = self.tokenizer.encode(target)  # target-only token ids, distinct from labels_ids
        text = torch.tensor(ids, dtype=torch.int64)
        text_lengths = torch.tensor([len(ids)], dtype=torch.int32)

        return {
            "speech": speech,
            "speech_lengths": speech_lengths,
            "text": text,
            "text_lengths": text_lengths,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels_ids": labels_ids,
            "label_mask": label_mask,
            "audio_mask": audio_mask,
        }

    def collator(self, samples: list = None):
        outputs = {}
        for sample in samples:
            for key in sample.keys():
                if key not in outputs:
                    outputs[key] = []
                outputs[key].append(sample[key])

        for key, data_list in outputs.items():
            if isinstance(data_list[0], torch.Tensor):
                if data_list[0].dtype == torch.int64:
                    pad_value = self.int_pad_value
                else:
                    pad_value = self.float_pad_value
                outputs[key] = torch.nn.utils.rnn.pad_sequence(
                    data_list, batch_first=True, padding_value=pad_value
                )
        return outputs
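

# A minimal, hedged sketch (not part of the FunASR API) of the padding rule
# that ``collator`` applies above: torch.nn.utils.rnn.pad_sequence right-pads
# every key to the longest item in the batch, using the ignore index (-100)
# for int64 tensors such as input_ids/labels_ids so the loss skips padding,
# and float_pad_value (0.0) for float tensors such as speech.
def _demo_collator_padding():
    int_batch = [
        torch.tensor([11, 12, 13], dtype=torch.int64),
        torch.tensor([21, 22], dtype=torch.int64),
    ]
    padded = torch.nn.utils.rnn.pad_sequence(int_batch, batch_first=True, padding_value=-100)
    # padded -> [[11, 12, 13], [21, 22, -100]]
    return padded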


@tables.register("dataset_classes", "AudioLLMARDataset")
class AudioLLMARDataset(torch.utils.data.Dataset):
    """
    AudioLLMARDataset
    """

    def __init__(
        self,
        path,
        index_ds: str = None,
        frontend=None,
        tokenizer=None,
        int_pad_value: int = -1,
        float_pad_value: float = 0.0,
        **kwargs
    ):
        super().__init__()
        index_ds_class = tables.index_ds_classes.get(index_ds)
        self.index_ds = index_ds_class(path, **kwargs)

        preprocessor_speech = kwargs.get("preprocessor_speech", None)
        if preprocessor_speech:
            preprocessor_speech_class = tables.preprocessor_classes.get(preprocessor_speech)
            preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf", {}))
        self.preprocessor_speech = preprocessor_speech

        preprocessor_text = kwargs.get("preprocessor_text", None)
        if preprocessor_text:
            preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text)
            preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf", {}))
        self.preprocessor_text = preprocessor_text

        self.frontend = frontend
        self.fs = 16000 if frontend is None else frontend.fs
        self.data_type = "sound"
        self.tokenizer = tokenizer
        self.float_pad_value = float_pad_value

        self.prompt = kwargs.get("prompt", "Transcribe speech to text.")
        # Full template: "USER: \nINSTRUCTION: {}\nINPUT: {}\nASSISTANT: "
        self.prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(self.prompt)
        self.prompt_af = ""
        self.IGNORE_INDEX = kwargs.get("IGNORE_INDEX", -100)
        # Pad integer tensors with the ignore index so padded positions are
        # skipped by the loss; this overrides the int_pad_value argument.
        self.int_pad_value = self.IGNORE_INDEX

    def get_source_len(self, index):
        item = self.index_ds[index]
        return self.index_ds.get_source_len(item)

    def get_target_len(self, index):
        item = self.index_ds[index]
        return self.index_ds.get_target_len(item)

    def __len__(self):
        return len(self.index_ds)

    def __getitem__(self, index):
        item = self.index_ds[index]
        source = item["source"]
        data_src = load_audio_text_image_video(source, fs=self.fs)
        if self.preprocessor_speech:
            data_src = self.preprocessor_speech(data_src, fs=self.fs)
        speech, speech_lengths = extract_fbank(
            data_src, data_type=self.data_type, frontend=self.frontend, is_final=True
        )  # speech: [b, T, d]
        speech = speech.squeeze(0)

        target = item["target"]
        if self.preprocessor_text:
            target = self.preprocessor_text(target)

        prompt_ids_pre = self.tokenizer.encode(self.prompt_pre)  # [bos, prompt]
        prompt_pre_length = len(prompt_ids_pre)
        prompt_input = "{}{}".format(self.prompt_pre, target)
        prompt_input_ids = self.tokenizer.encode(prompt_input)
        audio_length = len(prompt_input_ids) - prompt_pre_length

        input_ids = prompt_input_ids + [self.tokenizer.pad_token_id]
        input_ids = torch.tensor(input_ids, dtype=torch.int64)  # [bos, prompt, input, pad]
        input_ids[prompt_pre_length:] = -1  # [bos, prompt, -1, -1]
        attention_mask = input_ids.ge(-1)  # [True, True, True, True], length mask

        prompt_answer = "{}{}".format(self.prompt_pre, target)
        prompt_answer_ids = self.tokenizer.encode(prompt_answer)
        answer_length = len(prompt_answer_ids) - prompt_pre_length  # currently unused

        labels_ids = copy.deepcopy(prompt_input_ids) + [self.tokenizer.eos_token_id]
        labels_ids = torch.tensor(labels_ids, dtype=torch.int64)  # [bos, prompt, input, eos]
        labels_ids[:prompt_pre_length] = -1  # [-1, -1, input, eos]
        label_mask = labels_ids.ge(0)  # [False, False, True, True]
        labels_ids[~label_mask] = self.IGNORE_INDEX  # [-100, -100, input, eos]

        audio_mask = [0] * prompt_pre_length + [1] * audio_length + [0]
        audio_mask = torch.tensor(audio_mask, dtype=torch.float32)

        ids = self.tokenizer.encode(target)  # target-only token ids, distinct from labels_ids
        text = torch.tensor(ids, dtype=torch.int64)
        text_lengths = torch.tensor([len(ids)], dtype=torch.int32)

        return {
            "speech": speech,
            "speech_lengths": speech_lengths,
            "text": text,
            "text_lengths": text_lengths,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels_ids": labels_ids,
            "label_mask": label_mask,
            "audio_mask": audio_mask,
        }

    def collator(self, samples: list = None):
        outputs = {}
        for sample in samples:
            for key in sample.keys():
                if key not in outputs:
                    outputs[key] = []
                outputs[key].append(sample[key])

        for key, data_list in outputs.items():
            if isinstance(data_list[0], torch.Tensor):
                if data_list[0].dtype == torch.int64:
                    pad_value = self.int_pad_value
                else:
                    pad_value = self.float_pad_value
                outputs[key] = torch.nn.utils.rnn.pad_sequence(
                    data_list, batch_first=True, padding_value=pad_value
                )
        return outputs
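

if __name__ == "__main__":
    # Self-contained, hedged sketch of the label construction in __getitem__,
    # using hand-made token ids instead of a real tokenizer (the ids below are
    # illustrative, not drawn from any vocabulary).
    IGNORE_INDEX = -100
    prompt_pre_length = 2                   # e.g. [bos, prompt]
    prompt_input_ids = [1, 7, 21, 22, 23]   # [bos, prompt, target tokens...]
    eos_token_id = 2

    labels_ids = torch.tensor(prompt_input_ids + [eos_token_id], dtype=torch.int64)
    labels_ids[:prompt_pre_length] = -1     # blank out the prompt positions
    label_mask = labels_ids.ge(0)           # True only on target tokens + eos
    labels_ids[~label_mask] = IGNORE_INDEX
    print(labels_ids.tolist())              # [-100, -100, 21, 22, 23, 2]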