Explorar o código

Merge pull request #189 from yuekaizhang/token_list

[Triton] Read token list from config.yaml
zhifu gao %!s(int64=3) %!d(string=hai) anos
pai
achega
659ad8f48b

+ 3 - 4
funasr/runtime/triton_gpu/README.md

@@ -8,8 +8,8 @@ git clone https://www.modelscope.cn/damo/speech_paraformer-large_asr_nat-zh-cn-1
 
 
 pretrained_model_dir=$(pwd)/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch
 pretrained_model_dir=$(pwd)/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch
 
 
-cp $pretrained_model_dir/tokens.txt ./model_repo_paraformer_large_offline/scoring/
 cp $pretrained_model_dir/am.mvn ./model_repo_paraformer_large_offline/feature_extractor/
 cp $pretrained_model_dir/am.mvn ./model_repo_paraformer_large_offline/feature_extractor/
+cp $pretrained_model_dir/config.yaml ./model_repo_paraformer_large_offline/feature_extractor/
 
 
 # Refer here to get model.onnx (https://github.com/alibaba-damo-academy/FunASR/blob/main/funasr/export/README.md)
 # Refer here to get model.onnx (https://github.com/alibaba-damo-academy/FunASR/blob/main/funasr/export/README.md)
 cp <exported_onnx_dir>/model.onnx ./model_repo_paraformer_large_offline/encoder/1/
 cp <exported_onnx_dir>/model.onnx ./model_repo_paraformer_large_offline/encoder/1/
@@ -33,10 +33,9 @@ model_repo_paraformer_large_offline/
 `-- scoring
 `-- scoring
     |-- 1
     |-- 1
     |   `-- model.py
     |   `-- model.py
-    |-- config.pbtxt
-    `-- tokens.txt
+    `-- config.pbtxt
 
 
-8 directories, 10 files
+8 directories, 9 files
 ```
 ```
 
 
 2. Follow below instructions to launch triton server
 2. Follow below instructions to launch triton server

+ 9 - 7
funasr/runtime/triton_gpu/model_repo_paraformer_large_offline/feature_extractor/1/model.py

@@ -229,22 +229,24 @@ class TritonPythonModel:
             if key == "config_path":
             if key == "config_path":
                 with open(str(value), 'rb') as f:
                 with open(str(value), 'rb') as f:
                     config = yaml.load(f, Loader=yaml.Loader)
                     config = yaml.load(f, Loader=yaml.Loader)
+            if key == "cmvn_path":
+                cmvn_path = str(value)
 
 
         opts = kaldifeat.FbankOptions()
         opts = kaldifeat.FbankOptions()
         opts.frame_opts.dither = 1.0 # TODO: 0.0 or 1.0
         opts.frame_opts.dither = 1.0 # TODO: 0.0 or 1.0
-        opts.frame_opts.window_type = config['WavFrontend']['frontend_conf']['window']
-        opts.mel_opts.num_bins = int(config['WavFrontend']['frontend_conf']['n_mels'])
-        opts.frame_opts.frame_shift_ms = float(config['WavFrontend']['frontend_conf']['frame_shift'])
-        opts.frame_opts.frame_length_ms = float(config['WavFrontend']['frontend_conf']['frame_length'])
-        opts.frame_opts.samp_freq = int(config['WavFrontend']['frontend_conf']['fs'])
+        opts.frame_opts.window_type = config['frontend_conf']['window']
+        opts.mel_opts.num_bins = int(config['frontend_conf']['n_mels'])
+        opts.frame_opts.frame_shift_ms = float(config['frontend_conf']['frame_shift'])
+        opts.frame_opts.frame_length_ms = float(config['frontend_conf']['frame_length'])
+        opts.frame_opts.samp_freq = int(config['frontend_conf']['fs'])
         opts.device = torch.device(self.device)
         opts.device = torch.device(self.device)
         self.opts = opts
         self.opts = opts
         self.feature_extractor = Fbank(self.opts)
         self.feature_extractor = Fbank(self.opts)
         self.feature_size = opts.mel_opts.num_bins
         self.feature_size = opts.mel_opts.num_bins
 
 
         self.frontend = WavFrontend(
         self.frontend = WavFrontend(
-            cmvn_file=config['WavFrontend']['cmvn_file'],
-            **config['WavFrontend']['frontend_conf'])
+            cmvn_file=cmvn_path,
+            **config['frontend_conf'])
 
 
     def extract_feat(self,
     def extract_feat(self,
                      waveform_list: List[np.ndarray]
                      waveform_list: List[np.ndarray]

+ 4 - 0
funasr/runtime/triton_gpu/model_repo_paraformer_large_offline/feature_extractor/config.pbtxt

@@ -33,6 +33,10 @@ parameters [
     key: "sample_rate"
     key: "sample_rate"
     value: { string_value: "16000"}
     value: { string_value: "16000"}
   },
   },
+  {
+    key: "cmvn_path"
+    value: { string_value: "./model_repo_paraformer_large_offline/feature_extractor/am.mvn"}
+  },
   {
   {
     key: "config_path"
     key: "config_path"
     value: { string_value: "./model_repo_paraformer_large_offline/feature_extractor/config.yaml"}
     value: { string_value: "./model_repo_paraformer_large_offline/feature_extractor/config.yaml"}

+ 0 - 11
funasr/runtime/triton_gpu/model_repo_paraformer_large_offline/feature_extractor/config.yaml

@@ -1,11 +0,0 @@
-WavFrontend:
-  cmvn_file: ./model_repo_paraformer_large_offline/feature_extractor/am.mvn
-  frontend_conf:
-    fs: 16000
-    window: hamming
-    n_mels: 80
-    frame_length: 25
-    frame_shift: 10
-    lfr_m: 7
-    lfr_n: 6
-    filter_length_max: -.inf

+ 4 - 3
funasr/runtime/triton_gpu/model_repo_paraformer_large_offline/scoring/1/model.py

@@ -21,6 +21,7 @@ from torch.utils.dlpack import from_dlpack
 
 
 import json
 import json
 import os
 import os
+import yaml
 
 
 class TritonPythonModel:
 class TritonPythonModel:
     """Your Python model must use the same class name. Every Python model
     """Your Python model must use the same class name. Every Python model
@@ -73,9 +74,9 @@ class TritonPythonModel:
         """
         """
         load lang_char.txt
         load lang_char.txt
         """
         """
-        with open(str(vocab_file), 'r') as f:
-            token_list = [line.strip() for line in f]
-        return token_list
+        with open(str(vocab_file), 'rb') as f:
+            config = yaml.load(f, Loader=yaml.Loader)
+        return config['token_list']
 
 
     def execute(self, requests):
     def execute(self, requests):
         """`execute` must be implemented in every Python model. `execute`
         """`execute` must be implemented in every Python model. `execute`

+ 1 - 1
funasr/runtime/triton_gpu/model_repo_paraformer_large_offline/scoring/config.pbtxt

@@ -23,7 +23,7 @@ parameters [
   },
   },
   {
   {
     key: "vocabulary",
     key: "vocabulary",
-    value: { string_value: "./model_repo_paraformer_large_offline/scoring/tokens.txt"}
+    value: { string_value: "./model_repo_paraformer_large_offline/feature_extractor/config.yaml"}
   },
   },
   {
   {
     key: "lm_path"
     key: "lm_path"