|
|
@@ -14,7 +14,8 @@ import numpy as np
|
|
|
from .utils.utils import (CharTokenizer, Hypothesis, ONNXRuntimeError,
|
|
|
OrtInferSession, TokenIDConverter, get_logger,
|
|
|
read_yaml)
|
|
|
-from .utils.postprocess_utils import sentence_postprocess
|
|
|
+from .utils.postprocess_utils import (sentence_postprocess,
|
|
|
+ sentence_postprocess_sentencepiece)
|
|
|
from .utils.frontend import WavFrontend
|
|
|
from .utils.timestamp_utils import time_stamp_lfr6_onnx
|
|
|
from .utils.utils import pad_list, make_pad_mask
|
|
|
@@ -86,6 +87,10 @@ class Paraformer():
|
|
|
self.pred_bias = config['model_conf']['predictor_bias']
|
|
|
else:
|
|
|
self.pred_bias = 0
|
|
|
+ if "lang" in config:
|
|
|
+ self.language = config['lang']
|
|
|
+ else:
|
|
|
+ self.language = None
|
|
|
|
|
|
def __call__(self, wav_content: Union[str, np.ndarray, List[str]], **kwargs) -> List:
|
|
|
waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
|
|
|
@@ -111,7 +116,10 @@ class Paraformer():
|
|
|
preds = self.decode(am_scores, valid_token_lens)
|
|
|
if us_peaks is None:
|
|
|
for pred in preds:
|
|
|
- pred = sentence_postprocess(pred)
|
|
|
+ if self.language == "en-bpe":
|
|
|
+ pred = sentence_postprocess_sentencepiece(pred)
|
|
|
+ else:
|
|
|
+ pred = sentence_postprocess(pred)
|
|
|
asr_res.append({'preds': pred})
|
|
|
else:
|
|
|
for pred, us_peaks_ in zip(preds, us_peaks):
|