Quellcode durchsuchen

onnx export funasr_onnx

游雁 vor 2 Jahren
Ursprung
Commit
d25d0942f9

+ 7 - 6
funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py

@@ -32,7 +32,7 @@ class Paraformer():
                  plot_timestamp_to: str = "",
                  quantize: bool = False,
                  intra_op_num_threads: int = 4,
-                 cache_dir=None
+                 cache_dir: str = None
                  ):
 
         if not Path(model_dir).exists():
@@ -41,6 +41,12 @@ class Paraformer():
                 model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
             except:
                 raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(model_dir)
+        
+        model_file = os.path.join(model_dir, 'model.onnx')
+        if quantize:
+            model_file = os.path.join(model_dir, 'model_quant.onnx')
+        if not os.path.exists(model_file):
+            print(".onnx is not exist, begin to export onnx")
             from funasr.export.export_model import ModelExport
             export_model = ModelExport(
                 cache_dir=cache_dir,
@@ -50,11 +56,6 @@ class Paraformer():
             )
             export_model.export(model_dir)
             
-            
-
-        model_file = os.path.join(model_dir, 'model.onnx')
-        if quantize:
-            model_file = os.path.join(model_dir, 'model_quant.onnx')
         config_file = os.path.join(model_dir, 'config.yaml')
         cmvn_file = os.path.join(model_dir, 'am.mvn')
         config = read_yaml(config_file)

+ 24 - 6
funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py

@@ -24,15 +24,32 @@ class CT_Transformer():
                  batch_size: int = 1,
                  device_id: Union[str, int] = "-1",
                  quantize: bool = False,
-                 intra_op_num_threads: int = 4
+                 intra_op_num_threads: int = 4,
+                 cache_dir: str = None,
                  ):
-
+    
         if not Path(model_dir).exists():
-            raise FileNotFoundError(f'{model_dir} does not exist.')
-
+            from modelscope.hub.snapshot_download import snapshot_download
+            try:
+                model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
+            except:
+                raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
+                    model_dir)
+    
         model_file = os.path.join(model_dir, 'model.onnx')
         if quantize:
             model_file = os.path.join(model_dir, 'model_quant.onnx')
+        if not os.path.exists(model_file):
+            print(".onnx is not exist, begin to export onnx")
+            from funasr.export.export_model import ModelExport
+            export_model = ModelExport(
+                cache_dir=cache_dir,
+                onnx=True,
+                device="cpu",
+                quant=quantize,
+            )
+            export_model.export(model_dir)
+            
         config_file = os.path.join(model_dir, 'punc.yaml')
         config = read_yaml(config_file)
 
@@ -135,9 +152,10 @@ class CT_Transformer_VadRealtime(CT_Transformer):
                  batch_size: int = 1,
                  device_id: Union[str, int] = "-1",
                  quantize: bool = False,
-                 intra_op_num_threads: int = 4
+                 intra_op_num_threads: int = 4,
+                 cache_dir: str = None
                  ):
-        super(CT_Transformer_VadRealtime, self).__init__(model_dir, batch_size, device_id, quantize, intra_op_num_threads)
+        super(CT_Transformer_VadRealtime, self).__init__(model_dir, batch_size, device_id, quantize, intra_op_num_threads, cache_dir=cache_dir)
 
     def __call__(self, text: str, param_dict: map, split_size=20):
         cache_key = "cache"

+ 1 - 0
funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py

@@ -271,4 +271,5 @@ def get_logger(name='funasr_onnx'):
     logger.addHandler(sh)
     logger_initialized[name] = True
     logger.propagate = False
+    logging.basicConfig(level=logging.ERROR)
     return logger

+ 34 - 3
funasr/runtime/python/onnxruntime/funasr_onnx/vad_bin.py

@@ -31,14 +31,30 @@ class Fsmn_vad():
 	             quantize: bool = False,
 	             intra_op_num_threads: int = 4,
 	             max_end_sil: int = None,
+	             cache_dir: str = None
 	             ):
 		
 		if not Path(model_dir).exists():
-			raise FileNotFoundError(f'{model_dir} does not exist.')
+			from modelscope.hub.snapshot_download import snapshot_download
+			try:
+				model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
+			except:
+				raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
+					model_dir)
 		
 		model_file = os.path.join(model_dir, 'model.onnx')
 		if quantize:
 			model_file = os.path.join(model_dir, 'model_quant.onnx')
+		if not os.path.exists(model_file):
+			print(".onnx is not exist, begin to export onnx")
+			from funasr.export.export_model import ModelExport
+			export_model = ModelExport(
+				cache_dir=cache_dir,
+				onnx=True,
+				device="cpu",
+				quant=quantize,
+			)
+			export_model.export(model_dir)
 		config_file = os.path.join(model_dir, 'vad.yaml')
 		cmvn_file = os.path.join(model_dir, 'vad.mvn')
 		config = read_yaml(config_file)
@@ -172,14 +188,29 @@ class Fsmn_vad_online():
 	             quantize: bool = False,
 	             intra_op_num_threads: int = 4,
 	             max_end_sil: int = None,
+	             cache_dir: str = None
 	             ):
-		
 		if not Path(model_dir).exists():
-			raise FileNotFoundError(f'{model_dir} does not exist.')
+			from modelscope.hub.snapshot_download import snapshot_download
+			try:
+				model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
+			except:
+				raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
+					model_dir)
 		
 		model_file = os.path.join(model_dir, 'model.onnx')
 		if quantize:
 			model_file = os.path.join(model_dir, 'model_quant.onnx')
+		if not os.path.exists(model_file):
+			print(".onnx is not exist, begin to export onnx")
+			from funasr.export.export_model import ModelExport
+			export_model = ModelExport(
+				cache_dir=cache_dir,
+				onnx=True,
+				device="cpu",
+				quant=quantize,
+			)
+			export_model.export(model_dir)
 		config_file = os.path.join(model_dir, 'vad.yaml')
 		cmvn_file = os.path.join(model_dir, 'vad.mvn')
 		config = read_yaml(config_file)