speech_asr vor 3 Jahren
Ursprung
Commit
6165c13918
1 geänderte Dateien mit 23 neuen und 15 gelöschten Zeilen
  1. 23 15
      funasr/models/frontend/wav_frontend.py

+ 23 - 15
funasr/models/frontend/wav_frontend.py

@@ -1,15 +1,15 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 # Copyright (c) Alibaba, Inc. and its affiliates.
 # Part of the implementation is borrowed from espnet/espnet.
 # Part of the implementation is borrowed from espnet/espnet.
-from abc import ABC
 from typing import Tuple
 from typing import Tuple
 
 
 import numpy as np
 import numpy as np
 import torch
 import torch
 import torchaudio.compliance.kaldi as kaldi
 import torchaudio.compliance.kaldi as kaldi
-from funasr.models.frontend.abs_frontend import AbsFrontend
-import funasr.models.frontend.eend_ola_feature as eend_ola_feature
-from typeguard import check_argument_types
 from torch.nn.utils.rnn import pad_sequence
 from torch.nn.utils.rnn import pad_sequence
+from typeguard import check_argument_types
+
+import funasr.models.frontend.eend_ola_feature as eend_ola_feature
+from funasr.models.frontend.abs_frontend import AbsFrontend
 
 
 
 
 def load_cmvn(cmvn_file):
 def load_cmvn(cmvn_file):
@@ -276,7 +276,8 @@ class WavFrontendOnline(AbsFrontend):
     # inputs tensor has catted the cache tensor
     # inputs tensor has catted the cache tensor
     # def apply_lfr(inputs: torch.Tensor, lfr_m: int, lfr_n: int, inputs_lfr_cache: torch.Tensor = None,
     # def apply_lfr(inputs: torch.Tensor, lfr_m: int, lfr_n: int, inputs_lfr_cache: torch.Tensor = None,
     #               is_final: bool = False) -> Tuple[torch.Tensor, torch.Tensor, int]:
     #               is_final: bool = False) -> Tuple[torch.Tensor, torch.Tensor, int]:
-    def apply_lfr(inputs: torch.Tensor, lfr_m: int, lfr_n: int, is_final: bool = False) -> Tuple[torch.Tensor, torch.Tensor, int]:
+    def apply_lfr(inputs: torch.Tensor, lfr_m: int, lfr_n: int, is_final: bool = False) -> Tuple[
+        torch.Tensor, torch.Tensor, int]:
         """
         """
         Apply lfr with data
         Apply lfr with data
         """
         """
@@ -377,7 +378,8 @@ class WavFrontendOnline(AbsFrontend):
             if self.lfr_m != 1 or self.lfr_n != 1:
             if self.lfr_m != 1 or self.lfr_n != 1:
                 # update self.lfr_splice_cache in self.apply_lfr
                 # update self.lfr_splice_cache in self.apply_lfr
                 # mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(mat, self.lfr_m, self.lfr_n, self.lfr_splice_cache[i],
                 # mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(mat, self.lfr_m, self.lfr_n, self.lfr_splice_cache[i],
-                mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(mat, self.lfr_m, self.lfr_n, is_final)
+                mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(mat, self.lfr_m, self.lfr_n,
+                                                                                     is_final)
             if self.cmvn_file is not None:
             if self.cmvn_file is not None:
                 mat = self.apply_cmvn(mat, self.cmvn)
                 mat = self.apply_cmvn(mat, self.cmvn)
             feat_length = mat.size(0)
             feat_length = mat.size(0)
@@ -399,9 +401,10 @@ class WavFrontendOnline(AbsFrontend):
         assert batch_size == 1, 'we support to extract feature online only when the batch size is equal to 1 now'
         assert batch_size == 1, 'we support to extract feature online only when the batch size is equal to 1 now'
         waveforms, feats, feats_lengths = self.forward_fbank(input, input_lengths)  # input shape: B T D
         waveforms, feats, feats_lengths = self.forward_fbank(input, input_lengths)  # input shape: B T D
         if feats.shape[0]:
         if feats.shape[0]:
-            #if self.reserve_waveforms is None and self.lfr_m > 1:
+            # if self.reserve_waveforms is None and self.lfr_m > 1:
             #    self.reserve_waveforms = waveforms[:, :(self.lfr_m - 1) // 2 * self.frame_shift_sample_length]
             #    self.reserve_waveforms = waveforms[:, :(self.lfr_m - 1) // 2 * self.frame_shift_sample_length]
-            self.waveforms = waveforms if self.reserve_waveforms is None else torch.cat((self.reserve_waveforms, waveforms), dim=1)
+            self.waveforms = waveforms if self.reserve_waveforms is None else torch.cat(
+                (self.reserve_waveforms, waveforms), dim=1)
             if not self.lfr_splice_cache:  # 初始化splice_cache
             if not self.lfr_splice_cache:  # 初始化splice_cache
                 for i in range(batch_size):
                 for i in range(batch_size):
                     self.lfr_splice_cache.append(feats[i][0, :].unsqueeze(dim=0).repeat((self.lfr_m - 1) // 2, 1))
                     self.lfr_splice_cache.append(feats[i][0, :].unsqueeze(dim=0).repeat((self.lfr_m - 1) // 2, 1))
@@ -410,7 +413,8 @@ class WavFrontendOnline(AbsFrontend):
                 lfr_splice_cache_tensor = torch.stack(self.lfr_splice_cache)  # B T D
                 lfr_splice_cache_tensor = torch.stack(self.lfr_splice_cache)  # B T D
                 feats = torch.cat((lfr_splice_cache_tensor, feats), dim=1)
                 feats = torch.cat((lfr_splice_cache_tensor, feats), dim=1)
                 feats_lengths += lfr_splice_cache_tensor[0].shape[0]
                 feats_lengths += lfr_splice_cache_tensor[0].shape[0]
-                frame_from_waveforms = int((self.waveforms.shape[1] - self.frame_sample_length) / self.frame_shift_sample_length + 1)
+                frame_from_waveforms = int(
+                    (self.waveforms.shape[1] - self.frame_sample_length) / self.frame_shift_sample_length + 1)
                 minus_frame = (self.lfr_m - 1) // 2 if self.reserve_waveforms is None else 0
                 minus_frame = (self.lfr_m - 1) // 2 if self.reserve_waveforms is None else 0
                 feats, feats_lengths, lfr_splice_frame_idxs = self.forward_lfr_cmvn(feats, feats_lengths, is_final)
                 feats, feats_lengths, lfr_splice_frame_idxs = self.forward_lfr_cmvn(feats, feats_lengths, is_final)
                 if self.lfr_m == 1:
                 if self.lfr_m == 1:
@@ -419,19 +423,22 @@ class WavFrontendOnline(AbsFrontend):
                     reserve_frame_idx = lfr_splice_frame_idxs[0] - minus_frame
                     reserve_frame_idx = lfr_splice_frame_idxs[0] - minus_frame
                     # print('reserve_frame_idx:  ' + str(reserve_frame_idx))
                     # print('reserve_frame_idx:  ' + str(reserve_frame_idx))
                     # print('frame_frame:  ' + str(frame_from_waveforms))
                     # print('frame_frame:  ' + str(frame_from_waveforms))
-                    self.reserve_waveforms = self.waveforms[:, reserve_frame_idx * self.frame_shift_sample_length:frame_from_waveforms * self.frame_shift_sample_length]
-                    sample_length = (frame_from_waveforms - 1) * self.frame_shift_sample_length + self.frame_sample_length
+                    self.reserve_waveforms = self.waveforms[:,
+                                             reserve_frame_idx * self.frame_shift_sample_length:frame_from_waveforms * self.frame_shift_sample_length]
+                    sample_length = (
+                                                frame_from_waveforms - 1) * self.frame_shift_sample_length + self.frame_sample_length
                     self.waveforms = self.waveforms[:, :sample_length]
                     self.waveforms = self.waveforms[:, :sample_length]
             else:
             else:
                 # update self.reserve_waveforms and self.lfr_splice_cache
                 # update self.reserve_waveforms and self.lfr_splice_cache
-                self.reserve_waveforms = self.waveforms[:, :-(self.frame_sample_length - self.frame_shift_sample_length)]
+                self.reserve_waveforms = self.waveforms[:,
+                                         :-(self.frame_sample_length - self.frame_shift_sample_length)]
                 for i in range(batch_size):
                 for i in range(batch_size):
                     self.lfr_splice_cache[i] = torch.cat((self.lfr_splice_cache[i], feats[i]), dim=0)
                     self.lfr_splice_cache[i] = torch.cat((self.lfr_splice_cache[i], feats[i]), dim=0)
                 return torch.empty(0), feats_lengths
                 return torch.empty(0), feats_lengths
         else:
         else:
             if is_final:
             if is_final:
                 self.waveforms = waveforms if self.reserve_waveforms is None else self.reserve_waveforms
                 self.waveforms = waveforms if self.reserve_waveforms is None else self.reserve_waveforms
-                feats = torch.stack(self.lfr_splice_cache) 
+                feats = torch.stack(self.lfr_splice_cache)
                 feats_lengths = torch.zeros(batch_size, dtype=torch.int) + feats.shape[1]
                 feats_lengths = torch.zeros(batch_size, dtype=torch.int) + feats.shape[1]
                 feats, feats_lengths, _ = self.forward_lfr_cmvn(feats, feats_lengths, is_final)
                 feats, feats_lengths, _ = self.forward_lfr_cmvn(feats, feats_lengths, is_final)
         if is_final:
         if is_final:
@@ -466,9 +473,10 @@ class WavFrontendMel23(AbsFrontend):
         self.frame_shift = frame_shift
         self.frame_shift = frame_shift
         self.lfr_m = lfr_m
         self.lfr_m = lfr_m
         self.lfr_n = lfr_n
         self.lfr_n = lfr_n
+        self.n_mels = 23
 
 
     def output_size(self) -> int:
     def output_size(self) -> int:
-        return self.n_mels * self.lfr_m
+        return self.n_mels * (2 * self.lfr_m + 1)
 
 
     def forward(
     def forward(
             self,
             self,
@@ -494,4 +502,4 @@ class WavFrontendMel23(AbsFrontend):
         feats_pad = pad_sequence(feats,
         feats_pad = pad_sequence(feats,
                                  batch_first=True,
                                  batch_first=True,
                                  padding_value=0.0)
                                  padding_value=0.0)
-        return feats_pad, feats_lens
+        return feats_pad, feats_lens