九耳 3 yıl önce
ebeveyn
işleme
7df8452a85

+ 2 - 3
funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py

@@ -76,9 +76,8 @@ class TargetDelayTransformer():
             try:
                 outputs = self.infer(data['text'], data['text_lengths'])
                 y = outputs[0]
-                _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
-                punctuations = indices
-                assert punctuations.size()[0] == len(mini_sentence)
+                punctuations = np.argmax(y,axis=-1)[0]
+                assert punctuations.size == len(mini_sentence)
             except ONNXRuntimeError:
                 logging.warning("error")