|
|
@@ -315,8 +315,10 @@ class LLMASRNAR(nn.Module):
|
|
|
model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=None)
|
|
|
preds = torch.argmax(model_outputs.logits, -1)
|
|
|
text = tokenizer.batch_decode(preds, add_special_tokens=False, skip_special_tokens=True)
|
|
|
+
|
|
|
text = text[0].split(': ')[-1]
|
|
|
text = text.strip()
|
|
|
+
|
|
|
# preds = torch.argmax(model_outputs.logits, -1)
|
|
|
|
|
|
ibest_writer = None
|