Browse Source

torchscripts

游雁 3 years ago
parent
commit
548153260b

+ 1 - 1
funasr/export/test_torchscripts.py

@@ -2,7 +2,7 @@ import torch
 import numpy as np
 import numpy as np
 
 
 if __name__ == '__main__':
 if __name__ == '__main__':
-	onnx_path = "/mnt/workspace/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/model.torchscripts"
+	onnx_path = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/model.torchscripts"
 	loaded = torch.jit.load(onnx_path)
 	loaded = torch.jit.load(onnx_path)
 	
 	
 	x = torch.rand([2, 21, 560])
 	x = torch.rand([2, 21, 560])

+ 2 - 2
funasr/runtime/python/libtorch/setup.py

@@ -1,7 +1,7 @@
 # -*- encoding: utf-8 -*-
 # -*- encoding: utf-8 -*-
 from pathlib import Path
 from pathlib import Path
 import setuptools
 import setuptools
-
+from setuptools import find_packages
 
 
 def get_readme():
 def get_readme():
     root_dir = Path(__file__).resolve().parent
     root_dir = Path(__file__).resolve().parent
@@ -29,7 +29,7 @@ setuptools.setup(
                       "scipy", "numpy>=1.19.3",
                       "scipy", "numpy>=1.19.3",
                       "typeguard", "kaldi-native-fbank",
                       "typeguard", "kaldi-native-fbank",
                       "PyYAML>=5.1.2"],
                       "PyYAML>=5.1.2"],
-    packages=['torch_paraformer'],
+    packages=find_packages(include=["torch_paraformer*"]),
     keywords=[
     keywords=[
         'funasr,paraformer'
         'funasr,paraformer'
     ],
     ],

+ 6 - 5
funasr/runtime/python/libtorch/torch_paraformer/paraformer_bin.py

@@ -27,7 +27,7 @@ class Paraformer():
         if not Path(model_dir).exists():
         if not Path(model_dir).exists():
             raise FileNotFoundError(f'{model_dir} does not exist.')
             raise FileNotFoundError(f'{model_dir} does not exist.')
 
 
-        model_file = os.path.join(model_dir, 'model.onnx')
+        model_file = os.path.join(model_dir, 'model.torchscripts')
         config_file = os.path.join(model_dir, 'config.yaml')
         config_file = os.path.join(model_dir, 'config.yaml')
         cmvn_file = os.path.join(model_dir, 'am.mvn')
         cmvn_file = os.path.join(model_dir, 'am.mvn')
         config = read_yaml(config_file)
         config = read_yaml(config_file)
@@ -52,9 +52,8 @@ class Paraformer():
             feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
             feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
 
 
             try:
             try:
-                outputs = self.infer(feats, feats_len)
-                outs = outputs[0], outputs[1]
-                am_scores, valid_token_lens = outs[0], outs[1]
+                outputs = self.ort_infer(feats, feats_len)
+                am_scores, valid_token_lens = outputs[0], outputs[1]
                 if len(outputs) == 4:
                 if len(outputs) == 4:
                     # for BiCifParaformer Inference
                     # for BiCifParaformer Inference
                     us_alphas, us_cif_peak = outputs[2], outputs[3]
                     us_alphas, us_cif_peak = outputs[2], outputs[3]
@@ -65,7 +64,7 @@ class Paraformer():
                 logging.warning("input wav is silence or noise")
                 logging.warning("input wav is silence or noise")
                 preds = ['']
                 preds = ['']
             else:
             else:
-                am_scores, valid_token_lens = am_scores.cpu().numpy(), valid_token_lens.cpu().numpy()
+                am_scores, valid_token_lens = am_scores.detach().cpu().numpy(), valid_token_lens.detach().cpu().numpy()
                 preds, raw_token = self.decode(am_scores, valid_token_lens)[0]
                 preds, raw_token = self.decode(am_scores, valid_token_lens)[0]
                 res['preds'] = preds
                 res['preds'] = preds
                 if us_cif_peak is not None:
                 if us_cif_peak is not None:
@@ -105,6 +104,8 @@ class Paraformer():
 
 
         feats = self.pad_feats(feats, np.max(feats_len))
         feats = self.pad_feats(feats, np.max(feats_len))
         feats_len = np.array(feats_len).astype(np.int32)
         feats_len = np.array(feats_len).astype(np.int32)
+        feats = torch.from_numpy(feats).type(torch.float32)
+        feats_len = torch.from_numpy(feats_len).type(torch.int32)
         return feats, feats_len
         return feats, feats_len
 
 
     @staticmethod
     @staticmethod