@@ -14,14 +14,13 @@ from funasr.models.ctc.ctc import CTC
 from funasr.utils import postprocess_utils
 from funasr.metrics.compute_acc import th_accuracy
 from funasr.utils.datadir_writer import DatadirWriter
-from funasr.models.paraformer.search import Hypothesis
 from funasr.models.paraformer.cif_predictor import mae_loss
 from funasr.train_utils.device_funcs import force_gatherable
 from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
 from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
 from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
 from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
-
+from funasr.models.scama.utils import sequence_mask
 
 @tables.register("model_classes", "UniASR")
 class UniASR(torch.nn.Module):
@@ -31,19 +30,37 @@ class UniASR(torch.nn.Module):
 
     def __init__(
         self,
-        specaug: Optional[str] = None,
-        specaug_conf: Optional[Dict] = None,
+        specaug: str = None,
+        specaug_conf: dict = None,
         normalize: str = None,
-        normalize_conf: Optional[Dict] = None,
+        normalize_conf: dict = None,
         encoder: str = None,
-        encoder_conf: Optional[Dict] = None,
+        encoder_conf: dict = None,
+        encoder2: str = None,
+        encoder2_conf: dict = None,
         decoder: str = None,
-        decoder_conf: Optional[Dict] = None,
-        ctc: str = None,
-        ctc_conf: Optional[Dict] = None,
+        decoder_conf: dict = None,
+        decoder2: str = None,
+        decoder2_conf: dict = None,
         predictor: str = None,
-        predictor_conf: Optional[Dict] = None,
+        predictor_conf: dict = None,
+        predictor_bias: int = 0,
+        predictor_weight: float = 0.0,
+        predictor2: str = None,
+        predictor2_conf: dict = None,
+        predictor2_bias: int = 0,
+        predictor2_weight: float = 0.0,
+        ctc: str = None,
+        ctc_conf: dict = None,
         ctc_weight: float = 0.5,
+        ctc2: str = None,
+        ctc2_conf: dict = None,
+        ctc2_weight: float = 0.5,
+        decoder_attention_chunk_type: str = 'chunk',
+        decoder_attention_chunk_type2: str = 'chunk',
+        stride_conv=None,
+        stride_conv_conf: dict = None,
+        loss_weight_model1: float = 0.5,
         input_size: int = 80,
         vocab_size: int = -1,
         ignore_id: int = -1,
@@ -52,60 +69,72 @@ class UniASR(torch.nn.Module):
         eos: int = 2,
         lsm_weight: float = 0.0,
         length_normalized_loss: bool = False,
-        # report_cer: bool = True,
-        # report_wer: bool = True,
-        # sym_space: str = "<space>",
-        # sym_blank: str = "<blank>",
-        # extract_feats_in_collect_stats: bool = True,
-        # predictor=None,
-        predictor_weight: float = 0.0,
-        predictor_bias: int = 0,
-        sampling_ratio: float = 0.2,
         share_embedding: bool = False,
-        # preencoder: Optional[AbsPreEncoder] = None,
-        # postencoder: Optional[AbsPostEncoder] = None,
-        use_1st_decoder_loss: bool = False,
-        encoder1_encoder2_joint_training: bool = True,
         **kwargs,
 
     ):
-        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
-        assert 0.0 <= interctc_weight < 1.0, interctc_weight
-
         super().__init__()
-        self.blank_id = 0
-        self.sos = 1
-        self.eos = 2
+
+        if specaug is not None:
+            specaug_class = tables.specaug_classes.get(specaug)
+            specaug = specaug_class(**specaug_conf)
+        if normalize is not None:
+            normalize_class = tables.normalize_classes.get(normalize)
+            normalize = normalize_class(**normalize_conf)
+
+        encoder_class = tables.encoder_classes.get(encoder)
+        encoder = encoder_class(input_size=input_size, **encoder_conf)
+        encoder_output_size = encoder.output_size()
+
+        decoder_class = tables.decoder_classes.get(decoder)
+        decoder = decoder_class(
+            vocab_size=vocab_size,
+            encoder_output_size=encoder_output_size,
+            **decoder_conf,
+        )
+        predictor_class = tables.predictor_classes.get(predictor)
+        predictor = predictor_class(**predictor_conf)
+
+
+
+        from funasr.models.transformer.utils.subsampling import Conv1dSubsampling
+        stride_conv = Conv1dSubsampling(**stride_conv_conf, idim=input_size + encoder_output_size,
+                                        odim=input_size + encoder_output_size)
+        stride_conv_output_size = stride_conv.output_size()
+
+        encoder_class = tables.encoder_classes.get(encoder2)
+        encoder2 = encoder_class(input_size=stride_conv_output_size, **encoder2_conf)
+        encoder2_output_size = encoder2.output_size()
+
+        decoder_class = tables.decoder_classes.get(decoder2)
+        decoder2 = decoder_class(
+            vocab_size=vocab_size,
+            encoder_output_size=encoder2_output_size,
+            **decoder2_conf,
+        )
+        predictor_class = tables.predictor_classes.get(predictor2)
+        predictor2 = predictor_class(**predictor2_conf)
+
+
+
+        self.blank_id = blank_id
+        self.sos = sos
+        self.eos = eos
         self.vocab_size = vocab_size
         self.ignore_id = ignore_id
         self.ctc_weight = ctc_weight
-        self.interctc_weight = interctc_weight
-        self.token_list = token_list.copy()
+        self.ctc2_weight = ctc2_weight
 
-        self.frontend = frontend
         self.specaug = specaug
         self.normalize = normalize
-        self.preencoder = preencoder
-        self.postencoder = postencoder
+
         self.encoder = encoder
 
-        if not hasattr(self.encoder, "interctc_use_conditioning"):
-            self.encoder.interctc_use_conditioning = False
-        if self.encoder.interctc_use_conditioning:
-            self.encoder.conditioning_layer = torch.nn.Linear(
-                vocab_size, self.encoder.output_size()
-            )
-
         self.error_calculator = None
 
-        # we set self.decoder = None in the CTC mode since
-        # self.decoder parameters were never used and PyTorch complained
-        # and threw an Exception in the multi-GPU experiment.
-        # thanks Jeff Farris for pointing out the issue.
-        if ctc_weight == 1.0:
-            self.decoder = None
-        else:
-            self.decoder = decoder
+        self.decoder = decoder
+        self.ctc = None
+        self.ctc2 = None
 
         self.criterion_att = LabelSmoothingLoss(
             size=vocab_size,
@@ -113,22 +142,13 @@ class UniASR(torch.nn.Module):
             smoothing=lsm_weight,
             normalize_length=length_normalized_loss,
         )
-
-        if report_cer or report_wer:
-            self.error_calculator = ErrorCalculator(
-                token_list, sym_space, sym_blank, report_cer, report_wer
-            )
-
-        if ctc_weight == 0.0:
-            self.ctc = None
-        else:
-            self.ctc = ctc
-
-        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
+
         self.predictor = predictor
         self.predictor_weight = predictor_weight
         self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
-        self.step_cur = 0
+        self.encoder1_encoder2_joint_training = kwargs.get("encoder1_encoder2_joint_training", True)
+
+
         if self.encoder.overlap_chunk_cls is not None:
             from funasr.models.scama.chunk_utilis import build_scama_mask_for_cross_attention_decoder
             self.build_scama_mask_for_cross_attention_decoder_fn = build_scama_mask_for_cross_attention_decoder
@@ -136,14 +156,10 @@ class UniASR(torch.nn.Module):
 
         self.encoder2 = encoder2
         self.decoder2 = decoder2
-        self.ctc_weight2 = ctc_weight2
-        if ctc_weight2 == 0.0:
-            self.ctc2 = None
-        else:
-            self.ctc2 = ctc2
-        self.interctc_weight2 = interctc_weight2
+        self.ctc2_weight = ctc2_weight
+
         self.predictor2 = predictor2
-        self.predictor_weight2 = predictor_weight2
+        self.predictor2_weight = predictor2_weight
         self.decoder_attention_chunk_type2 = decoder_attention_chunk_type2
         self.stride_conv = stride_conv
         self.loss_weight_model1 = loss_weight_model1
@@ -152,10 +168,10 @@ class UniASR(torch.nn.Module):
             self.build_scama_mask_for_cross_attention_decoder_fn2 = build_scama_mask_for_cross_attention_decoder
             self.decoder_attention_chunk_type2 = decoder_attention_chunk_type2
 
-        self.enable_maas_finetune = enable_maas_finetune
-        self.freeze_encoder2 = freeze_encoder2
-        self.encoder1_encoder2_joint_training = encoder1_encoder2_joint_training
         self.length_normalized_loss = length_normalized_loss
+        self.enable_maas_finetune = kwargs.get("enable_maas_finetune", False)
+        self.freeze_encoder2 = kwargs.get("freeze_encoder2", False)
+        self.beam_search = None
 
     def forward(
         self,
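
Note on the constructor hunks above: __init__ no longer receives pre-built torch modules; it takes registered class names plus *_conf dicts and builds every component (specaug, normalize, encoder/encoder2, decoder/decoder2, predictor/predictor2, stride_conv) through FunASR's registry tables. A minimal sketch of that lookup pattern, outside the diff (the registered name "SANMEncoderChunkOpt" and the config values are illustrative assumptions, not taken from this change):

    from funasr.register import tables  # registry module assumed to expose `tables`

    # Look up a registered encoder class by name and build it from its conf dict,
    # mirroring the tables.encoder_classes.get(...) calls added in __init__ above.
    encoder_conf = {"output_size": 320}                                 # hypothetical config
    encoder_class = tables.encoder_classes.get("SANMEncoderChunkOpt")   # hypothetical registered name
    encoder = encoder_class(input_size=80, **encoder_conf)
    print(encoder.output_size())
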
@@ -163,7 +179,7 @@ class UniASR(torch.nn.Module):
         speech_lengths: torch.Tensor,
         text: torch.Tensor,
         text_lengths: torch.Tensor,
-        decoding_ind: int = None,
+        **kwargs,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         """Frontend + Encoder + Decoder + Calc loss
         Args:
@@ -172,19 +188,14 @@ class UniASR(torch.nn.Module):
             text: (Batch, Length)
             text_lengths: (Batch,)
         """
-        assert text_lengths.dim() == 1, text_lengths.shape
-        # Check that batch_size is unified
-        assert (
-            speech.shape[0]
-            == speech_lengths.shape[0]
-            == text.shape[0]
-            == text_lengths.shape[0]
-        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
+        decoding_ind = kwargs.get("decoding_ind", None)
+        if len(text_lengths.size()) > 1:
+            text_lengths = text_lengths[:, 0]
+        if len(speech_lengths.size()) > 1:
+            speech_lengths = speech_lengths[:, 0]
+
         batch_size = speech.shape[0]
 
-        # for data-parallel
-        text = text[:, : text_lengths.max()]
-        speech = speech[:, :speech_lengths.max()]
 
         ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
         # 1. Encoder
@@ -194,10 +205,6 @@ class UniASR(torch.nn.Module):
         else:
             speech_raw, encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
 
-        intermediate_outs = None
-        if isinstance(encoder_out, tuple):
-            intermediate_outs = encoder_out[1]
-            encoder_out = encoder_out[0]
 
         loss_att, acc_att, cer_att, wer_att = None, None, None, None
         loss_ctc, cer_ctc = None, None
@@ -210,62 +217,12 @@ class UniASR(torch.nn.Module):
         # 1. CTC branch
         if self.enable_maas_finetune:
             with torch.no_grad():
-                if self.ctc_weight != 0.0:
-                    if self.encoder.overlap_chunk_cls is not None:
-                        encoder_out_ctc, encoder_out_lens_ctc = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
-                                                                                encoder_out_lens,
-                                                                                chunk_outs=None)
-                    loss_ctc, cer_ctc = self._calc_ctc_loss(
-                        encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
-                    )
-
-                    # Collect CTC branch stats
-                    stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
-                    stats["cer_ctc"] = cer_ctc
-
-                # Intermediate CTC (optional)
-                loss_interctc = 0.0
-                if self.interctc_weight != 0.0 and intermediate_outs is not None:
-                    for layer_idx, intermediate_out in intermediate_outs:
-                        # we assume intermediate_out has the same length & padding
-                        # as those of encoder_out
-                        if self.encoder.overlap_chunk_cls is not None:
-                            encoder_out_ctc, encoder_out_lens_ctc = \
-                                self.encoder.overlap_chunk_cls.remove_chunk(
-                                    intermediate_out,
-                                    encoder_out_lens,
-                                    chunk_outs=None)
-                        loss_ic, cer_ic = self._calc_ctc_loss(
-                            encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
-                        )
-                        loss_interctc = loss_interctc + loss_ic
-
-                        # Collect Intermedaite CTC stats
-                        stats["loss_interctc_layer{}".format(layer_idx)] = (
-                            loss_ic.detach() if loss_ic is not None else None
-                        )
-                        stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
-
-                    loss_interctc = loss_interctc / len(intermediate_outs)
-
-                    # calculate whole encoder loss
-                    loss_ctc = (
-                        1 - self.interctc_weight
-                    ) * loss_ctc + self.interctc_weight * loss_interctc
-
-                # 2b. Attention decoder branch
-                if self.ctc_weight != 1.0:
-                    loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_predictor_loss(
-                        encoder_out, encoder_out_lens, text, text_lengths
-                    )
-
-                # 3. CTC-Att loss definition
-                if self.ctc_weight == 0.0:
-                    loss = loss_att + loss_pre * self.predictor_weight
-                elif self.ctc_weight == 1.0:
-                    loss = loss_ctc
-                else:
-                    loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
+
+                loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_predictor_loss(
+                    encoder_out, encoder_out_lens, text, text_lengths
+                )
+
+                loss = loss_att + loss_pre * self.predictor_weight
 
                 # Collect Attn branch stats
                 stats["loss_att"] = loss_att.detach() if loss_att is not None else None
@@ -274,62 +231,13 @@ class UniASR(torch.nn.Module):
                 stats["wer"] = wer_att
                 stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
         else:
-            if self.ctc_weight != 0.0:
-                if self.encoder.overlap_chunk_cls is not None:
-                    encoder_out_ctc, encoder_out_lens_ctc = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
-                                                                            encoder_out_lens,
-                                                                            chunk_outs=None)
-                loss_ctc, cer_ctc = self._calc_ctc_loss(
-                    encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
-                )
+
+            loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_predictor_loss(
+                encoder_out, encoder_out_lens, text, text_lengths
+            )
 
-                # Collect CTC branch stats
-                stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
-                stats["cer_ctc"] = cer_ctc
-
-            # Intermediate CTC (optional)
-            loss_interctc = 0.0
-            if self.interctc_weight != 0.0 and intermediate_outs is not None:
-                for layer_idx, intermediate_out in intermediate_outs:
-                    # we assume intermediate_out has the same length & padding
-                    # as those of encoder_out
-                    if self.encoder.overlap_chunk_cls is not None:
-                        encoder_out_ctc, encoder_out_lens_ctc = \
-                            self.encoder.overlap_chunk_cls.remove_chunk(
-                                intermediate_out,
-                                encoder_out_lens,
-                                chunk_outs=None)
-                    loss_ic, cer_ic = self._calc_ctc_loss(
-                        encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
-                    )
-                    loss_interctc = loss_interctc + loss_ic
-
-                    # Collect Intermedaite CTC stats
-                    stats["loss_interctc_layer{}".format(layer_idx)] = (
-                        loss_ic.detach() if loss_ic is not None else None
-                    )
-                    stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
-
-                loss_interctc = loss_interctc / len(intermediate_outs)
-
-                # calculate whole encoder loss
-                loss_ctc = (
-                    1 - self.interctc_weight
-                ) * loss_ctc + self.interctc_weight * loss_interctc
-
-            # 2b. Attention decoder branch
-            if self.ctc_weight != 1.0:
-                loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_predictor_loss(
-                    encoder_out, encoder_out_lens, text, text_lengths
-                )
 
-            # 3. CTC-Att loss definition
-            if self.ctc_weight == 0.0:
-                loss = loss_att + loss_pre * self.predictor_weight
-            elif self.ctc_weight == 1.0:
-                loss = loss_ctc
-            else:
-                loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
+            loss = loss_att + loss_pre * self.predictor_weight
 
             # Collect Attn branch stats
             stats["loss_att"] = loss_att.detach() if loss_att is not None else None
@@ -354,67 +262,14 @@ class UniASR(torch.nn.Module):
         if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]
-        # CTC2
-        if self.ctc_weight2 != 0.0:
-            if self.encoder2.overlap_chunk_cls is not None:
-                encoder_out_ctc, encoder_out_lens_ctc = \
-                    self.encoder2.overlap_chunk_cls.remove_chunk(
-                        encoder_out,
-                        encoder_out_lens,
-                        chunk_outs=None,
-                    )
-            loss_ctc, cer_ctc = self._calc_ctc_loss2(
-                encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
-            )
-
-            # Collect CTC branch stats
-            stats["loss_ctc2"] = loss_ctc.detach() if loss_ctc is not None else None
-            stats["cer_ctc2"] = cer_ctc
-
-        # Intermediate CTC (optional)
-        loss_interctc = 0.0
-        if self.interctc_weight2 != 0.0 and intermediate_outs is not None:
-            for layer_idx, intermediate_out in intermediate_outs:
-                # we assume intermediate_out has the same length & padding
-                # as those of encoder_out
-                if self.encoder2.overlap_chunk_cls is not None:
-                    encoder_out_ctc, encoder_out_lens_ctc = \
-                        self.encoder2.overlap_chunk_cls.remove_chunk(
-                            intermediate_out,
-                            encoder_out_lens,
-                            chunk_outs=None)
-                loss_ic, cer_ic = self._calc_ctc_loss2(
-                    encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
-                )
-                loss_interctc = loss_interctc + loss_ic
-
-                # Collect Intermedaite CTC stats
-                stats["loss_interctc_layer{}2".format(layer_idx)] = (
-                    loss_ic.detach() if loss_ic is not None else None
-                )
-                stats["cer_interctc_layer{}2".format(layer_idx)] = cer_ic
 
-            loss_interctc = loss_interctc / len(intermediate_outs)
 
-            # calculate whole encoder loss
-            loss_ctc = (
-                1 - self.interctc_weight2
-            ) * loss_ctc + self.interctc_weight2 * loss_interctc
+        loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_predictor_loss2(
+            encoder_out, encoder_out_lens, text, text_lengths
+        )
 
-        # 2b. Attention decoder branch
-        if self.ctc_weight2 != 1.0:
-            loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_predictor_loss2(
-                encoder_out, encoder_out_lens, text, text_lengths
-            )
 
-        # 3. CTC-Att loss definition
-        if self.ctc_weight2 == 0.0:
-            loss = loss_att + loss_pre * self.predictor_weight2
-        elif self.ctc_weight2 == 1.0:
-            loss = loss_ctc
-        else:
-            loss = self.ctc_weight2 * loss_ctc + (
-                1 - self.ctc_weight2) * loss_att + loss_pre * self.predictor_weight2
+        loss = loss_att + loss_pre * self.predictor2_weight
 
         # Collect Attn branch stats
         stats["loss_att2"] = loss_att.detach() if loss_att is not None else None
@@ -422,6 +277,7 @@ class UniASR(torch.nn.Module):
         stats["cer2"] = cer_att
         stats["wer2"] = wer_att
         stats["loss_pre2"] = loss_pre.detach().cpu() if loss_pre is not None else None
+
         loss2 = loss
 
         loss = loss1 * self.loss_weight_model1 + loss2 * (1 - self.loss_weight_model1)
@@ -456,61 +312,31 @@ class UniASR(torch.nn.Module):
         return {"feats": feats, "feats_lengths": feats_lengths}
 
     def encode(
-        self, speech: torch.Tensor, speech_lengths: torch.Tensor, ind: int = 0,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
+    ):
         """Frontend + Encoder. Note that this method is used by asr_inference.py
         Args:
             speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
        """
+        ind = kwargs.get("ind", 0)
         with autocast(False):
-            # 1. Extract feats
-            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
-
-            # 2. Data augmentation
+            # Data augmentation
             if self.specaug is not None and self.training:
-                feats, feats_lengths = self.specaug(feats, feats_lengths)
-
-            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+                speech, speech_lengths = self.specaug(speech, speech_lengths)
+
+            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
             if self.normalize is not None:
-                feats, feats_lengths = self.normalize(feats, feats_lengths)
-        speech_raw = feats.clone().to(feats.device)
-        # Pre-encoder, e.g. used for raw input data
-        if self.preencoder is not None:
-            feats, feats_lengths = self.preencoder(feats, feats_lengths)
+                speech, speech_lengths = self.normalize(speech, speech_lengths)
+
+        speech_raw = speech.clone().to(speech.device)
+
 
         # 4. Forward encoder
-        # feats: (Batch, Length, Dim)
-        # -> encoder_out: (Batch, Length2, Dim2)
-        if self.encoder.interctc_use_conditioning:
-            encoder_out, encoder_out_lens, _ = self.encoder(
-                feats, feats_lengths, ctc=self.ctc, ind=ind
-            )
-        else:
-            encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths, ind=ind)
-        intermediate_outs = None
+        encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths, ind=ind)
         if isinstance(encoder_out, tuple):
-            intermediate_outs = encoder_out[1]
             encoder_out = encoder_out[0]
 
-        # Post-encoder, e.g. NLU
-        if self.postencoder is not None:
-            encoder_out, encoder_out_lens = self.postencoder(
-                encoder_out, encoder_out_lens
-            )
-
-        assert encoder_out.size(0) == speech.size(0), (
-            encoder_out.size(),
-            speech.size(0),
-        )
-        assert encoder_out.size(1) <= encoder_out_lens.max(), (
-            encoder_out.size(),
-            encoder_out_lens.max(),
-        )
-
-        if intermediate_outs is not None:
-            return (encoder_out, intermediate_outs), encoder_out_lens
-
         return speech_raw, encoder_out, encoder_out_lens
 
     def encode2(
@@ -519,28 +345,15 @@ class UniASR(torch.nn.Module):
         encoder_out_lens: torch.Tensor,
         speech: torch.Tensor,
         speech_lengths: torch.Tensor,
-        ind: int = 0,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        **kwargs,
+    ):
         """Frontend + Encoder. Note that this method is used by asr_inference.py
         Args:
             speech: (Batch, Length, ...)
             speech_lengths: (Batch, )
         """
-        # with autocast(False):
-        #     # 1. Extract feats
-        #     feats, feats_lengths = self._extract_feats(speech, speech_lengths)
-        #
-        #     # 2. Data augmentation
-        #     if self.specaug is not None and self.training:
-        #         feats, feats_lengths = self.specaug(feats, feats_lengths)
-        #
-        #     # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
-        #     if self.normalize is not None:
-        #         feats, feats_lengths = self.normalize(feats, feats_lengths)
-
-        # Pre-encoder, e.g. used for raw input data
-        # if self.preencoder is not None:
-        #     feats, feats_lengths = self.preencoder(feats, feats_lengths)
+
+        ind = kwargs.get("ind", 0)
         encoder_out_rm, encoder_out_lens_rm = self.encoder.overlap_chunk_cls.remove_chunk(
             encoder_out,
             encoder_out_lens,
@@ -557,55 +370,14 @@ class UniASR(torch.nn.Module):
         # 4. Forward encoder
         # feats: (Batch, Length, Dim)
         # -> encoder_out: (Batch, Length2, Dim2)
-        if self.encoder2.interctc_use_conditioning:
-            encoder_out, encoder_out_lens, _ = self.encoder2(
-                speech, speech_lengths, ctc=self.ctc2, ind=ind
-            )
-        else:
-            encoder_out, encoder_out_lens, _ = self.encoder2(speech, speech_lengths, ind=ind)
-        intermediate_outs = None
+
+        encoder_out, encoder_out_lens, _ = self.encoder2(speech, speech_lengths, ind=ind)
         if isinstance(encoder_out, tuple):
-            intermediate_outs = encoder_out[1]
             encoder_out = encoder_out[0]
 
-        # # Post-encoder, e.g. NLU
-        # if self.postencoder is not None:
-        #     encoder_out, encoder_out_lens = self.postencoder(
-        #         encoder_out, encoder_out_lens
-        #     )
-
-        assert encoder_out.size(0) == speech.size(0), (
-            encoder_out.size(),
-            speech.size(0),
-        )
-        assert encoder_out.size(1) <= encoder_out_lens.max(), (
-            encoder_out.size(),
-            encoder_out_lens.max(),
-        )
-
-        if intermediate_outs is not None:
-            return (encoder_out, intermediate_outs), encoder_out_lens
 
         return encoder_out, encoder_out_lens
 
-    def _extract_feats(
-        self, speech: torch.Tensor, speech_lengths: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        assert speech_lengths.dim() == 1, speech_lengths.shape
-
-        # for data-parallel
-        speech = speech[:, : speech_lengths.max()]
-
-        if self.frontend is not None:
-            # Frontend
-            # e.g. STFT and Feature extract
-            # data_loader may send time-domain signal in this case
-            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
-            feats, feats_lengths = self.frontend(speech, speech_lengths)
-        else:
-            # No frontend and no feature extract
-            feats, feats_lengths = speech, speech_lengths
-        return feats, feats_lengths
 
     def nll(
         self,
@@ -1024,36 +796,152 @@ class UniASR(torch.nn.Module):
 
         return pre_acoustic_embeds, pre_token_length, predictor_alignments, predictor_alignments_len, scama_mask
 
-    def _calc_ctc_loss(
-        self,
-        encoder_out: torch.Tensor,
-        encoder_out_lens: torch.Tensor,
-        ys_pad: torch.Tensor,
-        ys_pad_lens: torch.Tensor,
-    ):
-        # Calc CTC loss
-        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
+    def init_beam_search(self,
+                         **kwargs,
+                         ):
+        from funasr.models.uniasr.beam_search import BeamSearchScama
+        from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
+        from funasr.models.transformer.scorers.length_bonus import LengthBonus
 
-        # Calc CER using CTC
-        cer_ctc = None
-        if not self.training and self.error_calculator is not None:
-            ys_hat = self.ctc.argmax(encoder_out).data
-            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
-        return loss_ctc, cer_ctc
+        decoding_mode = kwargs.get("decoding_mode", "model1")
+        if decoding_mode == "model1":
+            decoder = self.decoder
+        else:
+            decoder = self.decoder2
+        # 1. Build ASR model
+        scorers = {}
+
+        if self.ctc != None:
+            ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
+            scorers.update(
+                ctc=ctc
+            )
+        token_list = kwargs.get("token_list")
+        scorers.update(
+            decoder=decoder,
+            length_bonus=LengthBonus(len(token_list)),
+        )
+
+        # 3. Build ngram model
+        # ngram is not supported now
+        ngram = None
+        scorers["ngram"] = ngram
+
+        weights = dict(
+            decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.0),
+            ctc=kwargs.get("decoding_ctc_weight", 0.0),
+            lm=kwargs.get("lm_weight", 0.0),
+            ngram=kwargs.get("ngram_weight", 0.0),
+            length_bonus=kwargs.get("penalty", 0.0),
+        )
+        beam_search = BeamSearchScama(
+            beam_size=kwargs.get("beam_size", 5),
+            weights=weights,
+            scorers=scorers,
+            sos=self.sos,
+            eos=self.eos,
+            vocab_size=len(token_list),
+            token_list=token_list,
+            pre_beam_score_key=None if self.ctc_weight == 1.0 else "full",
+        )
+
+        self.beam_search = beam_search
+
+    def inference(self,
+                  data_in,
+                  data_lengths=None,
+                  key: list = None,
+                  tokenizer=None,
+                  frontend=None,
+                  **kwargs,
+                  ):
+
+        decoding_model = kwargs.get("decoding_model", "normal")
+        token_num_relax = kwargs.get("token_num_relax", 5)
+        if decoding_model == "fast":
+            decoding_ind = 0
+            decoding_mode = "model1"
+        elif decoding_model == "offline":
+            decoding_ind = 1
+            decoding_mode = "model2"
+        else:
+            decoding_ind = 0
+            decoding_mode = "model2"
+        # init beamsearch
+
+        if self.beam_search is None:
+            logging.info("enable beam_search")
+            self.init_beam_search(decoding_mode=decoding_mode, **kwargs)
+            self.nbest = kwargs.get("nbest", 1)
+
+        meta_data = {}
+        if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank":  # fbank
+            speech, speech_lengths = data_in, data_lengths
+            if len(speech.shape) < 3:
+                speech = speech[None, :, :]
+            if speech_lengths is None:
+                speech_lengths = speech.shape[1]
+        else:
+            # extract fbank feats
+            time1 = time.perf_counter()
+            audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
+                                                            data_type=kwargs.get("data_type", "sound"),
+                                                            tokenizer=tokenizer)
+            time2 = time.perf_counter()
+            meta_data["load_data"] = f"{time2 - time1:0.3f}"
+            speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
+                                                   frontend=frontend)
+            time3 = time.perf_counter()
+            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+            meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+
+        speech = speech.to(device=kwargs["device"])
+        speech_lengths = speech_lengths.to(device=kwargs["device"])
+        speech_raw = speech.clone().to(device=kwargs["device"])
+        # Encoder
+        _, encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=decoding_ind)
+        if decoding_mode == "model1":
+            predictor_outs = self.calc_predictor_mask(encoder_out, encoder_out_lens)
+        else:
+            encoder_out, encoder_out_lens = self.encode2(encoder_out, encoder_out_lens, speech_raw, speech_lengths, ind=decoding_ind)
+            predictor_outs = self.calc_predictor_mask2(encoder_out, encoder_out_lens)
+
+
+        scama_mask = predictor_outs[4]
+        pre_token_length = predictor_outs[1]
+        pre_acoustic_embeds = predictor_outs[0]
+        maxlen = pre_token_length.sum().item() + token_num_relax
+        minlen = max(0, pre_token_length.sum().item() - token_num_relax)
+        # c. Passed the encoder result and the beam search
+        nbest_hyps = self.beam_search(
+            x=encoder_out[0], scama_mask=scama_mask, pre_acoustic_embeds=pre_acoustic_embeds, maxlenratio=0.0,
+            minlenratio=0.0, maxlen=int(maxlen), minlen=int(minlen),
+        )
 
-    def _calc_ctc_loss2(
-        self,
-        encoder_out: torch.Tensor,
-        encoder_out_lens: torch.Tensor,
-        ys_pad: torch.Tensor,
-        ys_pad_lens: torch.Tensor,
-    ):
-        # Calc CTC loss
-        loss_ctc = self.ctc2(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
-
-        # Calc CER using CTC
-        cer_ctc = None
-        if not self.training and self.error_calculator is not None:
-            ys_hat = self.ctc2.argmax(encoder_out).data
-            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
-        return loss_ctc, cer_ctc
+        nbest_hyps = nbest_hyps[: self.nbest]
+
+        results = []
+        for hyp in nbest_hyps:
+
+            # remove sos/eos and get results
+            last_pos = -1
+            if isinstance(hyp.yseq, list):
+                token_int = hyp.yseq[1:last_pos]
+            else:
+                token_int = hyp.yseq[1:last_pos].tolist()
+
+            # remove blank symbol id, which is assumed to be 0
+            token_int = list(filter(lambda x: x != 0, token_int))
+
+
+            # Change integer-ids to tokens
+            token = tokenizer.ids2tokens(token_int)
+            text_postprocessed = tokenizer.tokens2text(token)
+            if not hasattr(tokenizer, "bpemodel"):
+                text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+
+
+            result_i = {"key": key[0], "text": text_postprocessed}
+            results.append(result_i)
+
+        return results, meta_data
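
For context on the inference entry point added above, a minimal usage sketch (illustrative only: the wav path, key, and device are made up, and tokenizer/frontend are whatever objects the surrounding FunASR pipeline has already constructed):

    # Decode one file with the second-pass ("offline") branch of UniASR.
    results, meta_data = model.inference(
        ["example.wav"],               # data_in: hypothetical audio path
        key=["example"],               # utterance id echoed back in results
        tokenizer=tokenizer,           # token/id mapping built by the pipeline
        frontend=frontend,             # fbank frontend built by the pipeline
        device="cuda:0",               # read via kwargs["device"] in the diff
        decoding_model="offline",      # "fast" -> model1 branch, "offline" -> model2 branch
    )
    print(results[0]["text"], meta_data["extract_feat"])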