嘉渊 2 سال پیش
والد
کامیت
94436935bd
1فایلهای تغییر یافته به همراه8 افزوده شده و 18 حذف شده
  1. 8 18
      funasr/bin/asr_inference_launch.py

+ 8 - 18
funasr/bin/asr_inference_launch.py

@@ -35,8 +35,6 @@ from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
 from funasr.fileio.datadir_writer import DatadirWriter
 from funasr.modules.beam_search.beam_search import Hypothesis
 from funasr.modules.subsampling import TooShortUttError
-from funasr.tasks.asr import ASRTask
-from funasr.tasks.vad import VADTask
 from funasr.torch_utils.device_funcs import to_device
 from funasr.torch_utils.set_all_random_seed import set_all_random_seed
 from funasr.utils import asr_utils, postprocess_utils
@@ -1360,20 +1358,14 @@ def inference_transducer(
                  **kwargs,
                  ):
         # 3. Build data-iterator
-        loader = ASRTask.build_streaming_iterator(
-            data_path_and_name_and_type,
+        loader = build_streaming_iterator(
+            task_name="asr",
+            preprocess_args=speech2text.asr_train_args,
+            data_path_and_name_and_type=data_path_and_name_and_type,
             dtype=dtype,
             batch_size=batch_size,
             key_file=key_file,
             num_workers=num_workers,
-            preprocess_fn=ASRTask.build_preprocess_fn(
-                speech2text.asr_train_args, False
-            ),
-            collate_fn=ASRTask.build_collate_fn(
-                speech2text.asr_train_args, False
-            ),
-            allow_variable_data_keys=allow_variable_data_keys,
-            inference=True,
         )
 
         # 4 .Start for-loop
@@ -1520,18 +1512,16 @@ def inference_sa_asr(
             if isinstance(raw_inputs, torch.Tensor):
                 raw_inputs = raw_inputs.numpy()
             data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-        loader = ASRTask.build_streaming_iterator(
-            data_path_and_name_and_type,
+        loader = build_streaming_iterator(
+            task_name="asr",
+            preprocess_args=speech2text.asr_train_args,
+            data_path_and_name_and_type=data_path_and_name_and_type,
             dtype=dtype,
             fs=fs,
             mc=mc,
             batch_size=batch_size,
             key_file=key_file,
             num_workers=num_workers,
-            preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
-            collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
-            allow_variable_data_keys=allow_variable_data_keys,
-            inference=True,
         )
 
         finish_count = 0