|
|
@@ -21,6 +21,7 @@ from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
|
|
|
from funasr.utils import postprocess_utils
|
|
|
from funasr.utils.datadir_writer import DatadirWriter
|
|
|
from funasr.register import tables
|
|
|
+
|
|
|
import pdb
|
|
|
@tables.register("model_classes", "LCBNet")
|
|
|
class LCBNet(nn.Module):
|
|
|
@@ -92,6 +93,7 @@ class LCBNet(nn.Module):
|
|
|
bias_predictor_class = tables.encoder_classes.get(bias_predictor)
|
|
|
bias_predictor = bias_predictor_class(**bias_predictor_conf)
|
|
|
|
|
|
+
|
|
|
if decoder is not None:
|
|
|
decoder_class = tables.decoder_classes.get(decoder)
|
|
|
decoder = decoder_class(
|
|
|
@@ -272,15 +274,15 @@ class LCBNet(nn.Module):
|
|
|
ind: int
|
|
|
"""
|
|
|
with autocast(False):
|
|
|
-
|
|
|
+ pdb.set_trace()
|
|
|
# Data augmentation
|
|
|
if self.specaug is not None and self.training:
|
|
|
speech, speech_lengths = self.specaug(speech, speech_lengths)
|
|
|
-
|
|
|
+ pdb.set_trace()
|
|
|
# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
|
|
|
if self.normalize is not None:
|
|
|
speech, speech_lengths = self.normalize(speech, speech_lengths)
|
|
|
-
|
|
|
+ pdb.set_trace()
|
|
|
# Forward encoder
|
|
|
# feats: (Batch, Length, Dim)
|
|
|
# -> encoder_out: (Batch, Length2, Dim2)
|
|
|
@@ -297,7 +299,7 @@ class LCBNet(nn.Module):
|
|
|
|
|
|
if intermediate_outs is not None:
|
|
|
return (encoder_out, intermediate_outs), encoder_out_lens
|
|
|
-
|
|
|
+ pdb.set_trace()
|
|
|
return encoder_out, encoder_out_lens
|
|
|
|
|
|
def _calc_att_loss(
|
|
|
@@ -442,6 +444,7 @@ class LCBNet(nn.Module):
|
|
|
|
|
|
speech = speech.to(device=kwargs["device"])
|
|
|
speech_lengths = speech_lengths.to(device=kwargs["device"])
|
|
|
+ pdb.set_trace()
|
|
|
# Encoder
|
|
|
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
|
|
|
if isinstance(encoder_out, tuple):
|