@@ -294,24 +294,29 @@ class LLMASRNAR(nn.Module):
         inputs_embeds = torch.cat((inputs_embeds[None, :, :], encoder_out), dim=1)  # [prompt, audio]
         attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to(kwargs["device"])
-        model_outputs = self.llm.generate(
-            inputs_embeds=inputs_embeds,
-            max_length=kwargs.get("max_length", 200),
-            max_new_tokens=kwargs.get("max_new_tokens", 200),
-            num_beams=kwargs.get("num_beams", 4),
-            do_sample=kwargs.get("do_sample", False),
-            min_length=kwargs.get("min_length", 1),
-            top_p=kwargs.get("top_p", 1.0),
-            repetition_penalty=kwargs.get("repetition_penalty", 1.0),
-            length_penalty=kwargs.get("length_penalty", 1.0),
-            temperature=kwargs.get("temperature", 1.0),
-            attention_mask=attention_mask,
-            bos_token_id=tokenizer.bos_token_id,
-            eos_token_id=tokenizer.eos_token_id,
-            pad_token_id=tokenizer.pad_token_id
-        )
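+        # The autoregressive beam-search generate() call is kept below,
+        # commented out, for reference against the NAR decoding path.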
+        # model_outputs = self.llm.generate(
+        #     inputs_embeds=inputs_embeds,
+        #     max_length=kwargs.get("max_length", 200),
+        #     max_new_tokens=kwargs.get("max_new_tokens", 200),
+        #     num_beams=kwargs.get("num_beams", 4),
+        #     do_sample=kwargs.get("do_sample", False),
+        #     min_length=kwargs.get("min_length", 1),
+        #     top_p=kwargs.get("top_p", 1.0),
+        #     repetition_penalty=kwargs.get("repetition_penalty", 1.0),
+        #     length_penalty=kwargs.get("length_penalty", 1.0),
+        #     temperature=kwargs.get("temperature", 1.0),
+        #     attention_mask=attention_mask,
+        #     bos_token_id=tokenizer.bos_token_id,
+        #     eos_token_id=tokenizer.eos_token_id,
+        #     pad_token_id=tokenizer.pad_token_id
+        # )
+
-        text = tokenizer.batch_decode(model_outputs, add_special_tokens=False, skip_special_tokens=True)
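+        # Non-autoregressive decoding: a single forward pass over the
+        # [prompt, audio] embeddings, then greedy argmax over the logits
+        # at every position instead of token-by-token generation.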
+        model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=None)
+        preds = torch.argmax(model_outputs.logits, -1)
+        text = tokenizer.batch_decode(preds, skip_special_tokens=True)
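+        # batch_decode returns a list of strings; take the single hypothesis
+        # and drop everything up to the prompt delimiter so only the
+        # transcript remains.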
+        text = text[0].split(': "\n')[-1]
         ibest_writer = None
         if kwargs.get("output_dir") is not None: