|
|
@@ -37,13 +37,13 @@ class AudioLLMDataset(torch.utils.data.Dataset):
|
|
|
self.data_type = "sound"
|
|
|
self.tokenizer = tokenizer
|
|
|
|
|
|
- self.int_pad_value = int_pad_value
|
|
|
self.float_pad_value = float_pad_value
|
|
|
self.prompt = kwargs.get("prompt", "Transcribe speech to text.")
|
|
|
self.prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(
|
|
|
self.prompt) # "USER: \nINSTRUCTION: {}\nnINPUT: {}\nASSISTANT: "
|
|
|
self.prompt_af = ""
|
|
|
self.IGNORE_INDEX = kwargs.get("IGNORE_INDEX", -100)
|
|
|
+ self.int_pad_value = self.IGNORE_INDEX
|
|
|
|
|
|
def get_source_len(self, index):
|
|
|
item = self.index_ds[index]
|