Răsfoiți Sursa

[Quantization] automatic mixed precision quantization

wanchen.swc 3 ani în urmă
părinte
comite
8a620a5a36
2 a modificat fișierele cu 53 adăugiri și 30 ștergeri
  1. 17 9
      funasr/export/README.md
  2. 36 21
      funasr/export/export_model.py

+ 17 - 9
funasr/export/README.md

@@ -11,35 +11,43 @@ The installation is the same as [funasr](../../README.md)
    `Tips`: torch>=1.11.0
 
    ```shell
-   python -m funasr.export.export_model [model_name] [export_dir] [onnx] [quant]
+   python -m funasr.export.export_model \
+       --model-name [model_name] \
+       --export-dir [export_dir] \
+       --type [onnx, torch] \
+       --quantize \
+       --fallback-num [fallback_num]
    ```
-   `model_name`: the model is to export. It could be the models from modelscope, or local finetuned model(named: model.pb). 
+   `model-name`: the model is to export. It could be the models from modelscope, or local finetuned model(named: model.pb).
 
-   `export_dir`: the dir where the onnx is export.
+   `export-dir`: the dir where the onnx is export.
 
-   `onnx`: `true`, export onnx format model; `false`, export torchscripts format model.
+   `type`: `onnx` or `torch`, export onnx format model or torchscript format model.
+
+   `quantize`: `true`, export quantized model at the same time; `false`, export fp32 model only.
+
+   `fallback-num`: specify the number of fallback layers to perform automatic mixed precision quantization.
 
-   `quant`: `true`, export quantized model at the same time; `false`, export fp32 model only.
 
 ## For example
 ### Export onnx format model
 Export model from modelscope
 ```shell
-python -m funasr.export.export_model 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" true false
+python -m funasr.export.export_model --model-name damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type onnx
 ```
 Export model from local path, the model'name must be `model.pb`.
 ```shell
-python -m funasr.export.export_model '/mnt/workspace/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" true false
+python -m funasr.export.export_model --model-name /mnt/workspace/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type onnx
 ```
 
 ### Export torchscripts format model
 Export model from modelscope
 ```shell
-python -m funasr.export.export_model 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" false false
+python -m funasr.export.export_model --model-name damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type torch
 ```
 
 Export model from local path, the model'name must be `model.pb`.
 ```shell
-python -m funasr.export.export_model '/mnt/workspace/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" false false
+python -m funasr.export.export_model --model-name /mnt/workspace/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type torch
 ```
 

+ 36 - 21
funasr/export/export_model.py

@@ -16,7 +16,11 @@ import random
 
 class ASRModelExportParaformer:
     def __init__(
-        self, cache_dir: Union[Path, str] = None, onnx: bool = True, quant: bool = True
+        self,
+        cache_dir: Union[Path, str] = None,
+        onnx: bool = True,
+        quant: bool = True,
+        fallback_num: int = 0,
     ):
         assert check_argument_types()
         self.set_all_random_seed(0)
@@ -31,6 +35,7 @@ class ASRModelExportParaformer:
         print("output dir: {}".format(self.cache_dir))
         self.onnx = onnx
         self.quant = quant
+        self.fallback_num = fallback_num
         
 
     def _export(
@@ -60,8 +65,12 @@ class ASRModelExportParaformer:
 
 
     def _torch_quantize(self, model):
+        def _run_calibration_data(m):
+            # using dummy inputs for a example
+            dummy_input = model.get_dummy_inputs()
+            m(*dummy_input)
+
         from torch_quant.module import ModuleFilter
-        from torch_quant.observer import HistogramObserver
         from torch_quant.quantizer import Backend, Quantizer
         from funasr.export.models.modules.decoder_layer import DecoderLayerSANM
         from funasr.export.models.modules.encoder_layer import EncoderLayerSANM
@@ -70,17 +79,21 @@ class ASRModelExportParaformer:
         quantizer = Quantizer(
             module_filter=module_filter,
             backend=Backend.FBGEMM,
-            act_ob_ctr=HistogramObserver,
         )
         model.eval()
         calib_model = quantizer.calib(model)
-        # run calibration data
-        # using dummy inputs for a example
-        dummy_input = model.get_dummy_inputs()
-        _ = calib_model(*dummy_input)
+        _run_calibration_data(calib_model)
+        if self.fallback_num > 0:
+            # perform automatic mixed precision quantization
+            amp_model = quantizer.amp(model)
+            _run_calibration_data(amp_model)
+            quantizer.fallback(amp_model, num=self.fallback_num)
+            print('Fallback layers:')
+            print('\n'.join(quantizer.module_filter.exclude_names))
         quant_model = quantizer.quantize(model)
         return quant_model
 
+
     def _export_torchscripts(self, model, verbose, path, enc_size=None):
         if enc_size:
             dummy_input = model.get_dummy_inputs(enc_size)
@@ -170,17 +183,19 @@ class ASRModelExportParaformer:
 
 
 if __name__ == '__main__':
-    import sys
-    
-    model_path = sys.argv[1]
-    output_dir = sys.argv[2]
-    onnx = sys.argv[3]
-    quant = sys.argv[4]
-    onnx = onnx.lower()
-    onnx = onnx == 'true'
-    quant = quant == 'true'
-    # model_path = 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'
-    # output_dir = "../export"
-    export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=onnx, quant=quant)
-    export_model.export(model_path)
-    # export_model.export('/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--model-name', type=str, required=True)
+    parser.add_argument('--export-dir', type=str, required=True)
+    parser.add_argument('--type', type=str, default='onnx', help='["onnx", "torch"]')
+    parser.add_argument('--quantize', action='store_true', help='export quantized model')
+    parser.add_argument('--fallback-num', type=int, default=0, help='amp fallback number')
+    args = parser.parse_args()
+
+    export_model = ASRModelExportParaformer(
+        cache_dir=args.export_dir,
+        onnx=args.type == 'onnx',
+        quant=args.quantize,
+        fallback_num=args.fallback_num,
+    )
+    export_model.export(args.model_name)