|
@@ -24,7 +24,7 @@ import torch
|
|
|
from packaging.version import parse as V
|
|
from packaging.version import parse as V
|
|
|
from typeguard import check_argument_types
|
|
from typeguard import check_argument_types
|
|
|
from typeguard import check_return_type
|
|
from typeguard import check_return_type
|
|
|
-
|
|
|
|
|
|
|
+from funasr.build_utils.build_model_from_file import build_model_from_file
|
|
|
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
|
|
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
|
|
|
from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
|
|
from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
|
|
|
from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
|
|
from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
|
|
@@ -35,9 +35,7 @@ from funasr.modules.beam_search.beam_search_transducer import BeamSearchTransduc
|
|
|
from funasr.modules.beam_search.beam_search_transducer import Hypothesis as HypothesisTransducer
|
|
from funasr.modules.beam_search.beam_search_transducer import Hypothesis as HypothesisTransducer
|
|
|
from funasr.modules.scorers.ctc import CTCPrefixScorer
|
|
from funasr.modules.scorers.ctc import CTCPrefixScorer
|
|
|
from funasr.modules.scorers.length_bonus import LengthBonus
|
|
from funasr.modules.scorers.length_bonus import LengthBonus
|
|
|
-from funasr.tasks.asr import ASRTask
|
|
|
|
|
-from funasr.tasks.asr import frontend_choices
|
|
|
|
|
-from funasr.tasks.lm import LMTask
|
|
|
|
|
|
|
+from funasr.build_utils.build_asr_model import frontend_choices
|
|
|
from funasr.text.build_tokenizer import build_tokenizer
|
|
from funasr.text.build_tokenizer import build_tokenizer
|
|
|
from funasr.text.token_id_converter import TokenIDConverter
|
|
from funasr.text.token_id_converter import TokenIDConverter
|
|
|
from funasr.torch_utils.device_funcs import to_device
|
|
from funasr.torch_utils.device_funcs import to_device
|
|
@@ -84,15 +82,14 @@ class Speech2Text:
|
|
|
|
|
|
|
|
# 1. Build ASR model
|
|
# 1. Build ASR model
|
|
|
scorers = {}
|
|
scorers = {}
|
|
|
- asr_model, asr_train_args = ASRTask.build_model_from_file(
|
|
|
|
|
- asr_train_config, asr_model_file, cmvn_file, device
|
|
|
|
|
|
|
+ asr_model, asr_train_args = build_model_from_file(
|
|
|
|
|
+ asr_train_config, asr_model_file, cmvn_file, device, mode="asr"
|
|
|
)
|
|
)
|
|
|
frontend = None
|
|
frontend = None
|
|
|
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
|
|
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
|
|
|
if asr_train_args.frontend == 'wav_frontend':
|
|
if asr_train_args.frontend == 'wav_frontend':
|
|
|
frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
|
|
frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
|
|
|
else:
|
|
else:
|
|
|
- from funasr.tasks.asr import frontend_choices
|
|
|
|
|
frontend_class = frontend_choices.get_class(asr_train_args.frontend)
|
|
frontend_class = frontend_choices.get_class(asr_train_args.frontend)
|
|
|
frontend = frontend_class(**asr_train_args.frontend_conf).eval()
|
|
frontend = frontend_class(**asr_train_args.frontend_conf).eval()
|
|
|
|
|
|
|
@@ -112,7 +109,7 @@ class Speech2Text:
|
|
|
|
|
|
|
|
# 2. Build Language model
|
|
# 2. Build Language model
|
|
|
if lm_train_config is not None:
|
|
if lm_train_config is not None:
|
|
|
- lm, lm_train_args = LMTask.build_model_from_file(
|
|
|
|
|
|
|
+ lm, lm_train_args = build_model_from_file(
|
|
|
lm_train_config, lm_file, None, device
|
|
lm_train_config, lm_file, None, device
|
|
|
)
|
|
)
|
|
|
scorers["lm"] = lm.lm
|
|
scorers["lm"] = lm.lm
|
|
@@ -295,9 +292,8 @@ class Speech2TextParaformer:
|
|
|
|
|
|
|
|
# 1. Build ASR model
|
|
# 1. Build ASR model
|
|
|
scorers = {}
|
|
scorers = {}
|
|
|
- from funasr.tasks.asr import ASRTaskParaformer as ASRTask
|
|
|
|
|
- asr_model, asr_train_args = ASRTask.build_model_from_file(
|
|
|
|
|
- asr_train_config, asr_model_file, cmvn_file, device
|
|
|
|
|
|
|
+ asr_model, asr_train_args = build_model_from_file(
|
|
|
|
|
+ asr_train_config, asr_model_file, cmvn_file, device, mode="paraformer"
|
|
|
)
|
|
)
|
|
|
frontend = None
|
|
frontend = None
|
|
|
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
|
|
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
|
|
@@ -319,7 +315,7 @@ class Speech2TextParaformer:
|
|
|
|
|
|
|
|
# 2. Build Language model
|
|
# 2. Build Language model
|
|
|
if lm_train_config is not None:
|
|
if lm_train_config is not None:
|
|
|
- lm, lm_train_args = LMTask.build_model_from_file(
|
|
|
|
|
|
|
+ lm, lm_train_args = build_model_from_file(
|
|
|
lm_train_config, lm_file, device
|
|
lm_train_config, lm_file, device
|
|
|
)
|
|
)
|
|
|
scorers["lm"] = lm.lm
|
|
scorers["lm"] = lm.lm
|
|
@@ -616,9 +612,8 @@ class Speech2TextParaformerOnline:
|
|
|
|
|
|
|
|
# 1. Build ASR model
|
|
# 1. Build ASR model
|
|
|
scorers = {}
|
|
scorers = {}
|
|
|
- from funasr.tasks.asr import ASRTaskParaformer as ASRTask
|
|
|
|
|
- asr_model, asr_train_args = ASRTask.build_model_from_file(
|
|
|
|
|
- asr_train_config, asr_model_file, cmvn_file, device
|
|
|
|
|
|
|
+ asr_model, asr_train_args = build_model_from_file(
|
|
|
|
|
+ asr_train_config, asr_model_file, cmvn_file, device, mode="paraformer"
|
|
|
)
|
|
)
|
|
|
frontend = None
|
|
frontend = None
|
|
|
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
|
|
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
|
|
@@ -640,7 +635,7 @@ class Speech2TextParaformerOnline:
|
|
|
|
|
|
|
|
# 2. Build Language model
|
|
# 2. Build Language model
|
|
|
if lm_train_config is not None:
|
|
if lm_train_config is not None:
|
|
|
- lm, lm_train_args = LMTask.build_model_from_file(
|
|
|
|
|
|
|
+ lm, lm_train_args = build_model_from_file(
|
|
|
lm_train_config, lm_file, device
|
|
lm_train_config, lm_file, device
|
|
|
)
|
|
)
|
|
|
scorers["lm"] = lm.lm
|
|
scorers["lm"] = lm.lm
|
|
@@ -873,9 +868,8 @@ class Speech2TextUniASR:
|
|
|
|
|
|
|
|
# 1. Build ASR model
|
|
# 1. Build ASR model
|
|
|
scorers = {}
|
|
scorers = {}
|
|
|
- from funasr.tasks.asr import ASRTaskUniASR as ASRTask
|
|
|
|
|
- asr_model, asr_train_args = ASRTask.build_model_from_file(
|
|
|
|
|
- asr_train_config, asr_model_file, cmvn_file, device
|
|
|
|
|
|
|
+ asr_model, asr_train_args = build_model_from_file(
|
|
|
|
|
+ asr_train_config, asr_model_file, cmvn_file, device, mode="uniasr"
|
|
|
)
|
|
)
|
|
|
frontend = None
|
|
frontend = None
|
|
|
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
|
|
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
|
|
@@ -901,8 +895,8 @@ class Speech2TextUniASR:
|
|
|
|
|
|
|
|
# 2. Build Language model
|
|
# 2. Build Language model
|
|
|
if lm_train_config is not None:
|
|
if lm_train_config is not None:
|
|
|
- lm, lm_train_args = LMTask.build_model_from_file(
|
|
|
|
|
- lm_train_config, lm_file, device
|
|
|
|
|
|
|
+ lm, lm_train_args = build_model_from_file(
|
|
|
|
|
+ lm_train_config, lm_file, device, "lm"
|
|
|
)
|
|
)
|
|
|
scorers["lm"] = lm.lm
|
|
scorers["lm"] = lm.lm
|
|
|
|
|
|
|
@@ -1104,9 +1098,8 @@ class Speech2TextMFCCA:
|
|
|
assert check_argument_types()
|
|
assert check_argument_types()
|
|
|
|
|
|
|
|
# 1. Build ASR model
|
|
# 1. Build ASR model
|
|
|
- from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
|
|
|
|
|
scorers = {}
|
|
scorers = {}
|
|
|
- asr_model, asr_train_args = ASRTask.build_model_from_file(
|
|
|
|
|
|
|
+ asr_model, asr_train_args = build_model_from_file(
|
|
|
asr_train_config, asr_model_file, cmvn_file, device
|
|
asr_train_config, asr_model_file, cmvn_file, device
|
|
|
)
|
|
)
|
|
|
|
|
|
|
@@ -1126,7 +1119,7 @@ class Speech2TextMFCCA:
|
|
|
|
|
|
|
|
# 2. Build Language model
|
|
# 2. Build Language model
|
|
|
if lm_train_config is not None:
|
|
if lm_train_config is not None:
|
|
|
- lm, lm_train_args = LMTask.build_model_from_file(
|
|
|
|
|
|
|
+ lm, lm_train_args = build_model_from_file(
|
|
|
lm_train_config, lm_file, device
|
|
lm_train_config, lm_file, device
|
|
|
)
|
|
)
|
|
|
lm.to(device)
|
|
lm.to(device)
|
|
@@ -1315,8 +1308,7 @@ class Speech2TextTransducer:
|
|
|
super().__init__()
|
|
super().__init__()
|
|
|
|
|
|
|
|
assert check_argument_types()
|
|
assert check_argument_types()
|
|
|
- from funasr.tasks.asr import ASRTransducerTask
|
|
|
|
|
- asr_model, asr_train_args = ASRTransducerTask.build_model_from_file(
|
|
|
|
|
|
|
+ asr_model, asr_train_args = build_model_from_file(
|
|
|
asr_train_config, asr_model_file, cmvn_file, device
|
|
asr_train_config, asr_model_file, cmvn_file, device
|
|
|
)
|
|
)
|
|
|
|
|
|
|
@@ -1350,7 +1342,7 @@ class Speech2TextTransducer:
|
|
|
asr_model.to(dtype=getattr(torch, dtype)).eval()
|
|
asr_model.to(dtype=getattr(torch, dtype)).eval()
|
|
|
|
|
|
|
|
if lm_train_config is not None:
|
|
if lm_train_config is not None:
|
|
|
- lm, lm_train_args = LMTask.build_model_from_file(
|
|
|
|
|
|
|
+ lm, lm_train_args = build_model_from_file(
|
|
|
lm_train_config, lm_file, device
|
|
lm_train_config, lm_file, device
|
|
|
)
|
|
)
|
|
|
lm_scorer = lm.lm
|
|
lm_scorer = lm.lm
|
|
@@ -1638,9 +1630,8 @@ class Speech2TextSAASR:
|
|
|
assert check_argument_types()
|
|
assert check_argument_types()
|
|
|
|
|
|
|
|
# 1. Build ASR model
|
|
# 1. Build ASR model
|
|
|
- from funasr.tasks.sa_asr import ASRTask
|
|
|
|
|
scorers = {}
|
|
scorers = {}
|
|
|
- asr_model, asr_train_args = ASRTask.build_model_from_file(
|
|
|
|
|
|
|
+ asr_model, asr_train_args = build_model_from_file(
|
|
|
asr_train_config, asr_model_file, cmvn_file, device
|
|
asr_train_config, asr_model_file, cmvn_file, device
|
|
|
)
|
|
)
|
|
|
frontend = None
|
|
frontend = None
|
|
@@ -1667,7 +1658,7 @@ class Speech2TextSAASR:
|
|
|
|
|
|
|
|
# 2. Build Language model
|
|
# 2. Build Language model
|
|
|
if lm_train_config is not None:
|
|
if lm_train_config is not None:
|
|
|
- lm, lm_train_args = LMTask.build_model_from_file(
|
|
|
|
|
|
|
+ lm, lm_train_args = build_model_from_file(
|
|
|
lm_train_config, lm_file, None, device
|
|
lm_train_config, lm_file, None, device
|
|
|
)
|
|
)
|
|
|
scorers["lm"] = lm.lm
|
|
scorers["lm"] = lm.lm
|