|
|
@@ -21,6 +21,7 @@ from funasr.models.decoder.transformer_decoder import (
|
|
|
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
|
|
|
from funasr.models.decoder.transformer_decoder import TransformerDecoder
|
|
|
from funasr.models.e2e_asr import ASRModel
|
|
|
+from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
|
|
|
from funasr.models.e2e_asr_mfcca import MFCCA
|
|
|
from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, \
|
|
|
ContextualParaformer
|
|
|
@@ -87,6 +88,7 @@ model_choices = ClassChoices(
|
|
|
paraformer_bert=ParaformerBert,
|
|
|
bicif_paraformer=BiCifParaformer,
|
|
|
contextual_paraformer=ContextualParaformer,
|
|
|
+ neatcontextual_paraformer=NeatContextualParaformer,
|
|
|
mfcca=MFCCA,
|
|
|
timestamp_prediction=TimestampPredictor,
|
|
|
rnnt=TransducerModel,
|
|
|
@@ -267,7 +269,7 @@ def build_asr_model(args):
|
|
|
if args.normalize is not None:
|
|
|
normalize_class = normalize_choices.get_class(args.normalize)
|
|
|
if args.model == "mfcca":
|
|
|
- normalize = normalize_class(stats_file=args.cmvn_file,**args.normalize_conf)
|
|
|
+ normalize = normalize_class(stats_file=args.cmvn_file, **args.normalize_conf)
|
|
|
else:
|
|
|
normalize = normalize_class(**args.normalize_conf)
|
|
|
else:
|