|
|
@@ -279,7 +279,7 @@ class Paraformer(FunASRModel):
|
|
|
|
|
|
def encode(
|
|
|
self, speech: torch.Tensor, speech_lengths: torch.Tensor, ind: int = 0,
|
|
|
- ) -> Tuple[Tuple[Any, Optional[Any]], Any]:
|
|
|
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
|
"""Frontend + Encoder. Note that this method is used by asr_inference.py
|
|
|
Args:
|
|
|
speech: (Batch, Length, ...)
|
|
|
@@ -649,7 +649,35 @@ class ParaformerOnline(Paraformer):
|
|
|
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
|
|
|
assert 0.0 <= interctc_weight < 1.0, interctc_weight
|
|
|
|
|
|
- super().__init__()
|
|
|
+ super().__init__(
|
|
|
+ vocab_size=vocab_size,
|
|
|
+ token_list=token_list,
|
|
|
+ frontend=frontend,
|
|
|
+ specaug=specaug,
|
|
|
+ normalize=normalize,
|
|
|
+ preencoder=preencoder,
|
|
|
+ encoder=encoder,
|
|
|
+ postencoder=postencoder,
|
|
|
+ decoder=decoder,
|
|
|
+ ctc=ctc,
|
|
|
+ ctc_weight=ctc_weight,
|
|
|
+ interctc_weight=interctc_weight,
|
|
|
+ ignore_id=ignore_id,
|
|
|
+ blank_id=blank_id,
|
|
|
+ sos=sos,
|
|
|
+ eos=eos,
|
|
|
+ lsm_weight=lsm_weight,
|
|
|
+ length_normalized_loss=length_normalized_loss,
|
|
|
+ report_cer=report_cer,
|
|
|
+ report_wer=report_wer,
|
|
|
+ sym_space=sym_space,
|
|
|
+ sym_blank=sym_blank,
|
|
|
+ extract_feats_in_collect_stats=extract_feats_in_collect_stats,
|
|
|
+ predictor=predictor,
|
|
|
+ predictor_weight=predictor_weight,
|
|
|
+ predictor_bias=predictor_bias,
|
|
|
+ sampling_ratio=sampling_ratio,
|
|
|
+ )
|
|
|
# note that eos is the same as sos (equivalent ID)
|
|
|
self.blank_id = blank_id
|
|
|
self.sos = vocab_size - 1 if sos is None else sos
|
|
|
@@ -705,6 +733,7 @@ class ParaformerOnline(Paraformer):
|
|
|
self.sampling_ratio = sampling_ratio
|
|
|
self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
|
|
|
self.step_cur = 0
|
|
|
+ self.scama_mask = None
|
|
|
if hasattr(self.encoder, "overlap_chunk_cls") and self.encoder.overlap_chunk_cls is not None:
|
|
|
from funasr.modules.streaming_utils.chunk_utilis import build_scama_mask_for_cross_attention_decoder
|
|
|
self.build_scama_mask_for_cross_attention_decoder_fn = build_scama_mask_for_cross_attention_decoder
|
|
|
@@ -859,7 +888,7 @@ class ParaformerOnline(Paraformer):
|
|
|
# Pre-encoder, e.g. used for raw input data
|
|
|
if self.preencoder is not None:
|
|
|
feats, feats_lengths = self.preencoder(feats, feats_lengths)
|
|
|
-
|
|
|
+
|
|
|
# 4. Forward encoder
|
|
|
# feats: (Batch, Length, Dim)
|
|
|
# -> encoder_out: (Batch, Length2, Dim2)
|
|
|
@@ -1111,12 +1140,73 @@ class ParaformerOnline(Paraformer):
|
|
|
|
|
|
return sematic_embeds * tgt_mask, decoder_out * tgt_mask, pre_loss_att
|
|
|
|
|
|
+ def calc_predictor(self, encoder_out, encoder_out_lens):
|
|
|
+
|
|
|
+ encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
|
|
|
+ encoder_out.device)
|
|
|
+ mask_chunk_predictor = None
|
|
|
+ if self.encoder.overlap_chunk_cls is not None:
|
|
|
+ mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(None,
|
|
|
+ device=encoder_out.device,
|
|
|
+ batch_size=encoder_out.size(
|
|
|
+ 0))
|
|
|
+ mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(None, device=encoder_out.device,
|
|
|
+ batch_size=encoder_out.size(0))
|
|
|
+ encoder_out = encoder_out * mask_shfit_chunk
|
|
|
+ pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index = self.predictor(encoder_out,
|
|
|
+ None,
|
|
|
+ encoder_out_mask,
|
|
|
+ ignore_id=self.ignore_id,
|
|
|
+ mask_chunk_predictor=mask_chunk_predictor,
|
|
|
+ target_label_length=None,
|
|
|
+ )
|
|
|
+ predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(pre_alphas[:, :-1],
|
|
|
+ encoder_out_lens)
|
|
|
+
|
|
|
+ scama_mask = None
|
|
|
+ if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == 'chunk':
|
|
|
+ encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
|
|
|
+ attention_chunk_center_bias = 0
|
|
|
+ attention_chunk_size = encoder_chunk_size
|
|
|
+ decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
|
|
|
+ mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls.\
|
|
|
+ get_mask_shift_att_chunk_decoder(None,
|
|
|
+ device=encoder_out.device,
|
|
|
+ batch_size=encoder_out.size(0)
|
|
|
+ )
|
|
|
+ scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
|
|
|
+ predictor_alignments=predictor_alignments,
|
|
|
+ encoder_sequence_length=encoder_out_lens,
|
|
|
+ chunk_size=1,
|
|
|
+ encoder_chunk_size=encoder_chunk_size,
|
|
|
+ attention_chunk_center_bias=attention_chunk_center_bias,
|
|
|
+ attention_chunk_size=attention_chunk_size,
|
|
|
+ attention_chunk_type=self.decoder_attention_chunk_type,
|
|
|
+ step=None,
|
|
|
+ predictor_mask_chunk_hopping=mask_chunk_predictor,
|
|
|
+ decoder_att_look_back_factor=decoder_att_look_back_factor,
|
|
|
+ mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
|
|
|
+ target_length=None,
|
|
|
+ is_training=self.training,
|
|
|
+ )
|
|
|
+ self.scama_mask = scama_mask
|
|
|
+
|
|
|
+ return pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index
|
|
|
+
|
|
|
def calc_predictor_chunk(self, encoder_out, cache=None):
|
|
|
|
|
|
pre_acoustic_embeds, pre_token_length = \
|
|
|
self.predictor.forward_chunk(encoder_out, cache["encoder"])
|
|
|
return pre_acoustic_embeds, pre_token_length
|
|
|
|
|
|
+ def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
|
|
|
+ decoder_outs = self.decoder(
|
|
|
+ encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, self.scama_mask
|
|
|
+ )
|
|
|
+ decoder_out = decoder_outs[0]
|
|
|
+ decoder_out = torch.log_softmax(decoder_out, dim=-1)
|
|
|
+ return decoder_out, ys_pad_lens
|
|
|
+
|
|
|
def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None):
|
|
|
decoder_outs = self.decoder.forward_chunk(
|
|
|
encoder_out, sematic_embeds, cache["decoder"]
|