@@ -7,6 +7,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.cuda.amp import autocast
+from funasr.models.scama.utils import sequence_mask
 from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
 from funasr.models.ctc.ctc import CTC
 from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
@@ -19,8 +20,8 @@ from funasr.utils.datadir_writer import DatadirWriter
 from funasr.register import tables


-@tables.register("model_classes", "LLMASR")
-class LLMASR(nn.Module):
+@tables.register("model_classes", "LLMASRNAR")
+class LLMASRNAR(nn.Module):
     """ """

     def __init__(
@@ -72,15 +73,13 @@ class LLMASR(nn.Module):
         hub = encoder_conf.get("hub", None)
         if hub == "funasr":
             from funasr import AutoModel
-            from funasr.models.scama.utils import sequence_mask
             init_param_path = encoder_conf.get("hub", "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
             model = AutoModel(model=init_param_path, model_revision="v2.0.4")
-            frontend = model.kwargs.get("frontend")
+            # frontend = model.kwargs.get("frontend")
             model.model.decoder = None

-            self.model = model.model
-            self.frontend = frontend
-            self.mask_fn = sequence_mask
+            self.audio_encoder = model.model
+            # self.frontend = frontend

         elif hub == "hf":
             pass
@@ -102,8 +101,8 @@ class LLMASR(nn.Module):
             device_map=None,
             use_cache=None,
         )
-        freeze_llm = llm_conf.get("freeze_llm", True)
-        if freeze_llm:
+        freeze = llm_conf.get("freeze", True)
+        if freeze:
             for name, param in model.named_parameters():
                 param.requires_grad = False
             model.eval()
@@ -151,9 +150,9 @@ class LLMASR(nn.Module):
         text_lengths: torch.Tensor,
         input_ids: torch.Tensor,
         attention_mask:torch.Tensor,
-        labels_ids:torch.Tensor,
+        labels_ids: torch.Tensor,
         label_mask: torch.Tensor,
-        audio_mask:torch.Tensor,
+        audio_mask: torch.Tensor,
         **kwargs,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         """Encoder + Decoder + Calc loss
@@ -173,7 +172,7 @@ class LLMASR(nn.Module):
         batch_size = speech.shape[0]

         # audio encoder
-        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, audio_mask)
+        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, audio_mask=audio_mask)

         # adaptor
         encoder_out = self.adaptor(encoder_out)
@@ -194,18 +193,18 @@ class LLMASR(nn.Module):
         inputs_embeds = encoder_outs_pad * audio_mask[:, :, None] + inputs_embeds * (~audio_mask[:, :, None])
         inputs_embeds = F.pad(inputs_embeds[:, 1:, :], (0, 0, 0, 1, 0, 0), value=0.0)

-        model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)
+        model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids)
         loss = model_outputs.loss

-        acc_att = -1
+
+        stats = {}
         if self.metric:
             with torch.no_grad():
                 preds = torch.argmax(model_outputs.logits, -1)
                 acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100)
+                stats["acc"] = acc_att

-        stats = {}
-        # Collect Attn branch stats
-        stats["acc"] = acc_att.detach()
+        stats["loss"] = torch.clone(loss.detach())

         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         if self.length_normalized_loss:
@@ -221,47 +220,15 @@ class LLMASR(nn.Module):
         audio_token_lengths = audio_mask.sum(-1)

         batch = {"speech": speech, "speech_lengths": speech_lengths}
-        enc, enc_lens = self.model.encode(**batch)
-        enc_mask = self.mask_fn(enc_lens, enc.size(1), device=enc.device)[:, None, :]
-        pre_acoustic_embeds, pre_token_length, _, _ = self.model.predictor(enc,
+        enc, enc_lens = self.audio_encoder.encode(**batch)
+        enc_mask = sequence_mask(enc_lens, enc.size(1), device=enc.device)[:, None, :]
+        pre_acoustic_embeds, pre_token_length, _, _ = self.audio_encoder.predictor(enc,
                                                                             mask=enc_mask,
                                                                             target_label_length=audio_token_lengths,
                                                                             )

         return pre_acoustic_embeds, pre_token_length
-
-    def _calc_att_loss(
-        self,
-        encoder_out: torch.Tensor,
-        encoder_out_lens: torch.Tensor,
-        ys_pad: torch.Tensor,
-        ys_pad_lens: torch.Tensor,
-    ):
-        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
-        ys_in_lens = ys_pad_lens + 1
-
-        # 1. Forward decoder
-        decoder_out, _ = self.decoder(
-            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
-        )
-
-        # 2. Compute attention loss
-        loss_att = self.criterion_att(decoder_out, ys_out_pad)
-        acc_att = th_accuracy(
-            decoder_out.view(-1, self.vocab_size),
-            ys_out_pad,
-            ignore_label=self.ignore_id,
-        )
-
-        # Compute cer/wer using attention-decoder
-        if self.training or self.error_calculator is None:
-            cer_att, wer_att = None, None
-        else:
-            ys_hat = decoder_out.argmax(dim=-1)
-            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
-
-        return loss_att, acc_att, cer_att, wer_att
-
+

     def inference(self,
                   data_in,
|