游雁 2 жил өмнө
parent
commit
1ef8117213

+ 2 - 0
funasr/models/llm_asr_nar/model.py

@@ -315,8 +315,10 @@ class LLMASRNAR(nn.Module):
         model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=None)
         preds = torch.argmax(model_outputs.logits, -1)
         text = tokenizer.batch_decode(preds, add_special_tokens=False, skip_special_tokens=True)
+
         text = text[0].split(': ')[-1]
         text = text.strip()
+        
         # preds = torch.argmax(model_outputs.logits, -1)
         
         ibest_writer = None