|
|
@@ -236,6 +236,7 @@ def inference_paraformer(
|
|
|
timestamp_infer_config: Union[Path, str] = None,
|
|
|
timestamp_model_file: Union[Path, str] = None,
|
|
|
param_dict: dict = None,
|
|
|
+ decoding_ind: int = 0,
|
|
|
**kwargs,
|
|
|
):
|
|
|
ncpu = kwargs.get("ncpu", 1)
|
|
|
@@ -290,6 +291,7 @@ def inference_paraformer(
|
|
|
nbest=nbest,
|
|
|
hotword_list_or_file=hotword_list_or_file,
|
|
|
clas_scale=clas_scale,
|
|
|
+ decoding_ind=decoding_ind,
|
|
|
)
|
|
|
|
|
|
speech2text = Speech2TextParaformer(**speech2text_kwargs)
|
|
|
@@ -312,6 +314,7 @@ def inference_paraformer(
|
|
|
**kwargs,
|
|
|
):
|
|
|
|
|
|
+ decoding_ind = None
|
|
|
hotword_list_or_file = None
|
|
|
if param_dict is not None:
|
|
|
hotword_list_or_file = param_dict.get('hotword')
|
|
|
@@ -319,6 +322,8 @@ def inference_paraformer(
|
|
|
hotword_list_or_file = kwargs['hotword']
|
|
|
if hotword_list_or_file is not None or 'hotword' in kwargs:
|
|
|
speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
|
|
|
+ if param_dict is not None and "decoding_ind" in param_dict:
|
|
|
+ decoding_ind = param_dict["decoding_ind"]
|
|
|
|
|
|
# 3. Build data-iterator
|
|
|
if data_path_and_name_and_type is None and raw_inputs is not None:
|
|
|
@@ -365,6 +370,7 @@ def inference_paraformer(
|
|
|
# N-best list of (text, token, token_int, hyp_object)
|
|
|
|
|
|
time_beg = time.time()
|
|
|
+ batch["decoding_ind"] = decoding_ind
|
|
|
results = speech2text(**batch)
|
|
|
if len(results) < 1:
|
|
|
hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
|
|
|
@@ -1786,6 +1792,12 @@ def get_parser():
|
|
|
default=1,
|
|
|
help="The batch size for inference",
|
|
|
)
|
|
|
+ group.add_argument(
|
|
|
+ "--decoding_ind",
|
|
|
+ type=int,
|
|
|
+ default=0,
|
|
|
+ help="chunk select for chunk encoder",
|
|
|
+ )
|
|
|
group.add_argument("--nbest", type=int, default=5, help="Output N-best hypotheses")
|
|
|
group.add_argument("--beam_size", type=int, default=20, help="Beam size")
|
|
|
group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
|