|
@@ -23,7 +23,7 @@ from funasr.models.decoder.rnnt_decoder import RNNTDecoder
|
|
|
from funasr.models.joint_net.joint_network import JointNetwork
|
|
from funasr.models.joint_net.joint_network import JointNetwork
|
|
|
from funasr.models.e2e_asr import ASRModel
|
|
from funasr.models.e2e_asr import ASRModel
|
|
|
from funasr.models.e2e_asr_mfcca import MFCCA
|
|
from funasr.models.e2e_asr_mfcca import MFCCA
|
|
|
-from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer
|
|
|
|
|
|
|
+from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, ContextualParaformer
|
|
|
from funasr.models.e2e_tp import TimestampPredictor
|
|
from funasr.models.e2e_tp import TimestampPredictor
|
|
|
from funasr.models.e2e_uni_asr import UniASR
|
|
from funasr.models.e2e_uni_asr import UniASR
|
|
|
from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
|
|
from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
|
|
@@ -82,6 +82,7 @@ model_choices = ClassChoices(
|
|
|
asr=ASRModel,
|
|
asr=ASRModel,
|
|
|
uniasr=UniASR,
|
|
uniasr=UniASR,
|
|
|
paraformer=Paraformer,
|
|
paraformer=Paraformer,
|
|
|
|
|
+ paraformer_online=ParaformerOnline,
|
|
|
paraformer_bert=ParaformerBert,
|
|
paraformer_bert=ParaformerBert,
|
|
|
bicif_paraformer=BiCifParaformer,
|
|
bicif_paraformer=BiCifParaformer,
|
|
|
contextual_paraformer=ContextualParaformer,
|
|
contextual_paraformer=ContextualParaformer,
|
|
@@ -293,7 +294,7 @@ def build_asr_model(args):
|
|
|
token_list=token_list,
|
|
token_list=token_list,
|
|
|
**args.model_conf,
|
|
**args.model_conf,
|
|
|
)
|
|
)
|
|
|
- elif args.model in ["paraformer", "paraformer_bert", "bicif_paraformer", "contextual_paraformer"]:
|
|
|
|
|
|
|
+ elif args.model in ["paraformer", "paraformer_online", "paraformer_bert", "bicif_paraformer", "contextual_paraformer"]:
|
|
|
# predictor
|
|
# predictor
|
|
|
predictor_class = predictor_choices.get_class(args.predictor)
|
|
predictor_class = predictor_choices.get_class(args.predictor)
|
|
|
predictor = predictor_class(**args.predictor_conf)
|
|
predictor = predictor_class(**args.predictor_conf)
|