Răsfoiți Sursa

Merge branch 'dev_cmz2' of github.com:alibaba-damo-academy/FunASR into dev_cmz2
add

游雁 3 ani în urmă
părinte
comite
c5acc04e2d

+ 2 - 2
funasr/runtime/python/onnxruntime/demo_punc_offline.py

@@ -4,6 +4,6 @@ model_dir = "/disk1/mengzhe.cmz/workspace/FunASR/funasr/export/damo/punc_ct-tran
 model = TargetDelayTransformer(model_dir)
 model = TargetDelayTransformer(model_dir)
 
 
 text_in = "我们都是木头人不会讲话不会动"
 text_in = "我们都是木头人不会讲话不会动"
-
+text_in="跨境河流是养育沿岸人民的生命之源长期以来为帮助下游地区防灾减灾中方技术人员在上游地区极为恶劣的自然条件下克服巨大困难甚至冒着生命危险向印方提供汛期水文资料处理紧急事件中方重视印方在跨境河流问题上的关切愿意进一步完善双方联合工作机制凡是中方能做的我们都会去做而且会做得更好我请印度朋友们放心中国在上游的任何开发利用都会经过科学规划和论证兼顾上下游的利益"
 result = model(text_in)
 result = model(text_in)
-print(result)
+print(result[0])

+ 3 - 5
funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py

@@ -76,9 +76,8 @@ class TargetDelayTransformer():
             try:
             try:
                 outputs = self.infer(data['text'], data['text_lengths'])
                 outputs = self.infer(data['text'], data['text_lengths'])
                 y = outputs[0]
                 y = outputs[0]
-                _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
-                punctuations = indices
-                assert punctuations.size()[0] == len(mini_sentence)
+                punctuations = np.argmax(y,axis=-1)[0]
+                assert punctuations.size == len(mini_sentence)
             except ONNXRuntimeError:
             except ONNXRuntimeError:
                 logging.warning("error")
                 logging.warning("error")
 
 
@@ -102,8 +101,7 @@ class TargetDelayTransformer():
                 mini_sentence = mini_sentence[0:sentenceEnd + 1]
                 mini_sentence = mini_sentence[0:sentenceEnd + 1]
                 punctuations = punctuations[0:sentenceEnd + 1]
                 punctuations = punctuations[0:sentenceEnd + 1]
 
 
-            punctuations_np = punctuations.cpu().numpy()
-            new_mini_sentence_punc += [int(x) for x in punctuations_np]
+            new_mini_sentence_punc += [int(x) for x in punctuations]
             words_with_punc = []
             words_with_punc = []
             for i in range(len(mini_sentence)):
             for i in range(len(mini_sentence)):
                 if i > 0:
                 if i > 0: