Browse Source

Merge pull request #357 from alibaba-damo-academy/dev_lyb

modify paraformer onnx init
zhifu gao 2 năm trước cách đây
mục cha
commit
d7440147aa

+ 13 - 12
funasr/runtime/onnxruntime/src/paraformer_onnx.cpp

@@ -4,7 +4,7 @@ using namespace std;
 using namespace paraformer;
 using namespace paraformer;
 
 
 ModelImp::ModelImp(const char* path,int nNumThread, bool quantize)
 ModelImp::ModelImp(const char* path,int nNumThread, bool quantize)
-{
+:env_(ORT_LOGGING_LEVEL_ERROR, "paraformer"),sessionOptions{}{
     string model_path;
     string model_path;
     string cmvn_path;
     string cmvn_path;
     string config_path;
     string config_path;
@@ -29,20 +29,20 @@ ModelImp::ModelImp(const char* path,int nNumThread, bool quantize)
 
 
 #ifdef _WIN32
 #ifdef _WIN32
     wstring wstrPath = strToWstr(model_path);
     wstring wstrPath = strToWstr(model_path);
-    m_session = new Ort::Session(env, wstrPath.c_str(), sessionOptions);
+    m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), sessionOptions);
 #else
 #else
-    m_session = new Ort::Session(env, model_path.c_str(), sessionOptions);
+    m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), sessionOptions);
 #endif
 #endif
 
 
     string strName;
     string strName;
-    getInputName(m_session, strName);
+    getInputName(m_session.get(), strName);
     m_strInputNames.push_back(strName.c_str());
     m_strInputNames.push_back(strName.c_str());
-    getInputName(m_session, strName,1);
+    getInputName(m_session.get(), strName,1);
     m_strInputNames.push_back(strName);
     m_strInputNames.push_back(strName);
     
     
-    getOutputName(m_session, strName);
+    getOutputName(m_session.get(), strName);
     m_strOutputNames.push_back(strName);
     m_strOutputNames.push_back(strName);
-    getOutputName(m_session, strName,1);
+    getOutputName(m_session.get(), strName,1);
     m_strOutputNames.push_back(strName);
     m_strOutputNames.push_back(strName);
 
 
     for (auto& item : m_strInputNames)
     for (auto& item : m_strInputNames)
@@ -55,11 +55,6 @@ ModelImp::ModelImp(const char* path,int nNumThread, bool quantize)
 
 
 ModelImp::~ModelImp()
 ModelImp::~ModelImp()
 {
 {
-    if (m_session)
-    {
-        delete m_session;
-        m_session = nullptr;
-    }
     if(vocab)
     if(vocab)
         delete vocab;
         delete vocab;
     fftwf_free(fft_input);
     fftwf_free(fft_input);
@@ -172,6 +167,12 @@ string ModelImp::forward(float* din, int len, int flag)
     apply_cmvn(in);
     apply_cmvn(in);
     Ort::RunOptions run_option;
     Ort::RunOptions run_option;
 
 
+#ifdef _WIN_X86
+        Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
+#else
+        Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
+#endif
+
     std::array<int64_t, 3> input_shape_{ in->size[0],in->size[2],in->size[3] };
     std::array<int64_t, 3> input_shape_{ in->size[0],in->size[2],in->size[3] };
     Ort::Value onnx_feats = Ort::Value::CreateTensor<float>(m_memoryInfo,
     Ort::Value onnx_feats = Ort::Value::CreateTensor<float>(m_memoryInfo,
         in->buff,
         in->buff,

+ 3 - 9
funasr/runtime/onnxruntime/src/paraformer_onnx.h

@@ -24,15 +24,9 @@ namespace paraformer {
 
 
         string greedy_search( float* in, int nLen);
         string greedy_search( float* in, int nLen);
 
 
-#ifdef _WIN_X86
-        Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
-#else
-        Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
-#endif
-
-        Ort::Session* m_session = nullptr;
-        Ort::Env env = Ort::Env(ORT_LOGGING_LEVEL_ERROR, "paraformer");
-        Ort::SessionOptions sessionOptions = Ort::SessionOptions();
+        std::unique_ptr<Ort::Session> m_session;
+        Ort::Env env_;
+        Ort::SessionOptions sessionOptions;
 
 
         vector<string> m_strInputNames, m_strOutputNames;
         vector<string> m_strInputNames, m_strOutputNames;
         vector<const char*> m_szInputNames;
         vector<const char*> m_szInputNames;