Browse Source

Merge pull request #475 from alibaba-damo-academy/dev_cmz

onnx runtime model int32
zhifu gao 2 years ago
parent
commit
98223a3b59
1 changed files with 2 additions and 2 deletions
  1. 2 2
      funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py

+ 2 - 2
funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py

@@ -64,7 +64,7 @@ class CT_Transformer():
             mini_sentence = mini_sentences[mini_sentence_i]
             mini_sentence_id = mini_sentences_id[mini_sentence_i]
             mini_sentence = cache_sent + mini_sentence
-            mini_sentence_id = np.array(cache_sent_id + mini_sentence_id, dtype='int64')
+            mini_sentence_id = np.array(cache_sent_id + mini_sentence_id, dtype='int32')
             data = {
                 "text": mini_sentence_id[None,:],
                 "text_lengths": np.array([len(mini_sentence_id)], dtype='int32'),
@@ -166,7 +166,7 @@ class CT_Transformer_VadRealtime(CT_Transformer):
             mini_sentence = mini_sentences[mini_sentence_i]
             mini_sentence_id = mini_sentences_id[mini_sentence_i]
             mini_sentence = cache_sent + mini_sentence
-            mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
+            mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0,dtype='int32')
             text_length = len(mini_sentence_id)
             data = {
                 "input": mini_sentence_id[None,:],