|
@@ -27,7 +27,7 @@ class Paraformer():
|
|
|
if not Path(model_dir).exists():
|
|
if not Path(model_dir).exists():
|
|
|
raise FileNotFoundError(f'{model_dir} does not exist.')
|
|
raise FileNotFoundError(f'{model_dir} does not exist.')
|
|
|
|
|
|
|
|
- model_file = os.path.join(model_dir, 'model.onnx')
|
|
|
|
|
|
|
+ model_file = os.path.join(model_dir, 'model.torchscripts')
|
|
|
config_file = os.path.join(model_dir, 'config.yaml')
|
|
config_file = os.path.join(model_dir, 'config.yaml')
|
|
|
cmvn_file = os.path.join(model_dir, 'am.mvn')
|
|
cmvn_file = os.path.join(model_dir, 'am.mvn')
|
|
|
config = read_yaml(config_file)
|
|
config = read_yaml(config_file)
|
|
@@ -52,9 +52,8 @@ class Paraformer():
|
|
|
feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
|
|
feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
|
|
|
|
|
|
|
|
try:
|
|
try:
|
|
|
- outputs = self.infer(feats, feats_len)
|
|
|
|
|
- outs = outputs[0], outputs[1]
|
|
|
|
|
- am_scores, valid_token_lens = outs[0], outs[1]
|
|
|
|
|
|
|
+ outputs = self.ort_infer(feats, feats_len)
|
|
|
|
|
+ am_scores, valid_token_lens = outputs[0], outputs[1]
|
|
|
if len(outputs) == 4:
|
|
if len(outputs) == 4:
|
|
|
# for BiCifParaformer Inference
|
|
# for BiCifParaformer Inference
|
|
|
us_alphas, us_cif_peak = outputs[2], outputs[3]
|
|
us_alphas, us_cif_peak = outputs[2], outputs[3]
|
|
@@ -65,7 +64,7 @@ class Paraformer():
|
|
|
logging.warning("input wav is silence or noise")
|
|
logging.warning("input wav is silence or noise")
|
|
|
preds = ['']
|
|
preds = ['']
|
|
|
else:
|
|
else:
|
|
|
- am_scores, valid_token_lens = am_scores.cpu().numpy(), valid_token_lens.cpu().numpy()
|
|
|
|
|
|
|
+ am_scores, valid_token_lens = am_scores.detach().cpu().numpy(), valid_token_lens.detach().cpu().numpy()
|
|
|
preds, raw_token = self.decode(am_scores, valid_token_lens)[0]
|
|
preds, raw_token = self.decode(am_scores, valid_token_lens)[0]
|
|
|
res['preds'] = preds
|
|
res['preds'] = preds
|
|
|
if us_cif_peak is not None:
|
|
if us_cif_peak is not None:
|
|
@@ -105,6 +104,8 @@ class Paraformer():
|
|
|
|
|
|
|
|
feats = self.pad_feats(feats, np.max(feats_len))
|
|
feats = self.pad_feats(feats, np.max(feats_len))
|
|
|
feats_len = np.array(feats_len).astype(np.int32)
|
|
feats_len = np.array(feats_len).astype(np.int32)
|
|
|
|
|
+ feats = torch.from_numpy(feats).type(torch.float32)
|
|
|
|
|
+ feats_len = torch.from_numpy(feats_len).type(torch.int32)
|
|
|
return feats, feats_len
|
|
return feats, feats_len
|
|
|
|
|
|
|
|
@staticmethod
|
|
@staticmethod
|