Explorar el Código

[Optimization] fix bug: validate launch_service.sh argument count, give symlinks explicit target names so re-launch relinks correctly, and truncate decoded tokens to token_num in the Triton encoder model

wanchen.swc hace 2 años
padre
commit
aafbf57cf9

+ 7 - 3
funasr/runtime/triton_gpu/launch_service.sh

@@ -1,3 +1,7 @@
+if [ $# != 2 ]; then
+    echo "./launch_service.sh [model_type] [instance_num]"
+    exit
+fi
 model_type=$1
 instance_num=$2
 
@@ -18,9 +22,9 @@ fi
 rm -f $model_repo/encoder/1/$model_name
 rm -f $model_repo/feature_extractor/am.mvn
 rm -f $model_repo/feature_extractor/config.yaml
-ln -s `realpath export_dir/$model_type/$model_name` $model_repo/encoder/1/
-ln -s `realpath export_dir/$model_type/am.mvn` $model_repo/feature_extractor/
-ln -s `realpath export_dir/$model_type/config.yaml` $model_repo/feature_extractor/
+ln -s `realpath export_dir/$model_type/$model_name` $model_repo/encoder/1/$model_name
+ln -s `realpath export_dir/$model_type/am.mvn` $model_repo/feature_extractor/am.mvn
+ln -s `realpath export_dir/$model_type/config.yaml` $model_repo/feature_extractor/config.yaml
 
 config_file=$model_repo/encoder/config.pbtxt
 cp $config_file config.pbtxt

+ 3 - 1
funasr/runtime/triton_gpu/model_repo_paraformer_torchscritpts/encoder/1/model.py

@@ -99,7 +99,8 @@ class TritonPythonModel:
         feats_len = torch.tensor(speech_len, dtype=torch.int32).to(self.device)
 
         with torch.no_grad():
-            logits = self.model(feats, feats_len)[0]
+            outputs = self.model(feats, feats_len)
+        logits, token_num = outputs[0], outputs[1]
 
         def replace_space(tokens):
             return [i if i != '<space>' else ' ' for i in tokens]
@@ -107,6 +108,7 @@ class TritonPythonModel:
         yseq = logits.argmax(axis=-1).tolist()
         token_int = [list(filter(lambda x: x not in (0, 2), y)) for y in yseq]
         tokens = [[self.vocab_dict[i] for i in t] for t in token_int]
+        tokens = [t[:int(token_num[i]) - 1] for i, t in enumerate(tokens)]
         hyps = [''.join(replace_space(t)).encode('utf-8') for t in tokens]
         responses = []
         for i in range(len(requests)):