# chunk_utilis.py

import torch
import numpy as np
import math
from funasr.modules.nets_utils import make_pad_mask
import logging
import torch.nn.functional as F
from funasr.modules.streaming_utils.utils import sequence_mask


class overlap_chunk():
    """
    author: Speech Lab, Alibaba Group, China

    SAN-M: Memory Equipped Self-Attention for End-to-End Speech Recognition
    https://arxiv.org/abs/2006.01713
    """

    def __init__(self,
                 chunk_size: tuple = (16,),
                 stride: tuple = (10,),
                 pad_left: tuple = (0,),
                 encoder_att_look_back_factor: tuple = (1,),
                 shfit_fsmn: int = 0,
                 decoder_att_look_back_factor: tuple = (1,),
                 ):
        pad_left = self.check_chunk_size_args(chunk_size, pad_left)
        encoder_att_look_back_factor = self.check_chunk_size_args(chunk_size, encoder_att_look_back_factor)
        decoder_att_look_back_factor = self.check_chunk_size_args(chunk_size, decoder_att_look_back_factor)
        self.chunk_size, self.stride, self.pad_left, self.encoder_att_look_back_factor, self.decoder_att_look_back_factor \
            = chunk_size, stride, pad_left, encoder_att_look_back_factor, decoder_att_look_back_factor
        self.shfit_fsmn = shfit_fsmn
        self.x_add_mask = None
        self.x_rm_mask = None
        self.x_len = None
        self.mask_shfit_chunk = None
        self.mask_chunk_predictor = None
        self.mask_att_chunk_encoder = None
        self.mask_shift_att_chunk_decoder = None
        self.chunk_outs = None
        self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur \
            = None, None, None, None, None

    def check_chunk_size_args(self, chunk_size, x):
        # broadcast a single setting so every chunk-size configuration has one
        if len(x) < len(chunk_size):
            x = [x[0] for _ in chunk_size]
        return x

    def get_chunk_size(self, ind: int = 0):
        # select the chunk configuration indexed by `ind` and cache it
        chunk_size, stride, pad_left, encoder_att_look_back_factor, decoder_att_look_back_factor = \
            self.chunk_size[ind], self.stride[ind], self.pad_left[ind], self.encoder_att_look_back_factor[ind], self.decoder_att_look_back_factor[ind]
        self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur, self.decoder_att_look_back_factor_cur \
            = chunk_size, stride, pad_left, encoder_att_look_back_factor, chunk_size + self.shfit_fsmn, decoder_att_look_back_factor
        return self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur

    def random_choice(self, training=True, decoding_ind=None):
        chunk_num = len(self.chunk_size)
        ind = 0
        if training and chunk_num > 1:
            # note: torch.randint's upper bound is exclusive, so this samples
            # ind from [0, chunk_num - 2] and never picks the last configuration
            ind = torch.randint(0, chunk_num - 1, ()).cpu().item()
        if not training and decoding_ind is not None:
            ind = int(decoding_ind)
        return ind

    def gen_chunk_mask(self, x_len, ind=0, num_units=1, num_units_predictor=1):
        """
        Precompute, gradient-free and in numpy, the projection masks that chunk
        (x_add_mask) and un-chunk (x_rm_mask) a batch, plus the attention and
        predictor masks for the chunk configuration selected by `ind`.
        """
        with torch.no_grad():
            x_len = x_len.cpu().numpy()
            x_len_max = x_len.max()

            chunk_size, stride, pad_left, encoder_att_look_back_factor, chunk_size_pad_shift = self.get_chunk_size(ind)
            shfit_fsmn = self.shfit_fsmn
            pad_right = chunk_size - stride - pad_left

            chunk_num_batch = np.ceil(x_len / stride).astype(np.int32)
            x_len_chunk = (chunk_num_batch - 1) * chunk_size_pad_shift + shfit_fsmn + pad_left + 0 + x_len - (chunk_num_batch - 1) * stride
            x_len_chunk = x_len_chunk.astype(x_len.dtype)
            x_len_chunk_max = x_len_chunk.max()

            chunk_num = int(math.ceil(x_len_max / stride))
            dtype = np.int32
            max_len_for_x_mask_tmp = max(chunk_size, x_len_max + pad_left)
            x_add_mask = np.zeros([0, max_len_for_x_mask_tmp], dtype=dtype)
            x_rm_mask = np.zeros([max_len_for_x_mask_tmp, 0], dtype=dtype)
            mask_shfit_chunk = np.zeros([0, num_units], dtype=dtype)
            mask_chunk_predictor = np.zeros([0, num_units_predictor], dtype=dtype)
            mask_shift_att_chunk_decoder = np.zeros([0, 1], dtype=dtype)
            mask_att_chunk_encoder = np.zeros([0, chunk_num * chunk_size_pad_shift], dtype=dtype)
            for chunk_ids in range(chunk_num):
                # x_mask add: scatter input frames into this chunk's rows
                fsmn_padding = np.zeros((shfit_fsmn, max_len_for_x_mask_tmp), dtype=dtype)
                x_mask_cur = np.diag(np.ones(chunk_size, dtype=np.float32))
                x_mask_pad_left = np.zeros((chunk_size, chunk_ids * stride), dtype=dtype)
                x_mask_pad_right = np.zeros((chunk_size, max_len_for_x_mask_tmp), dtype=dtype)
                x_cur_pad = np.concatenate([x_mask_pad_left, x_mask_cur, x_mask_pad_right], axis=1)
                x_cur_pad = x_cur_pad[:chunk_size, :max_len_for_x_mask_tmp]
                x_add_mask_fsmn = np.concatenate([fsmn_padding, x_cur_pad], axis=0)
                x_add_mask = np.concatenate([x_add_mask, x_add_mask_fsmn], axis=0)

                # x_mask rm: gather the `stride` new frames back out of the chunk
                fsmn_padding = np.zeros((max_len_for_x_mask_tmp, shfit_fsmn), dtype=dtype)
                padding_mask_left = np.zeros((max_len_for_x_mask_tmp, pad_left), dtype=dtype)
                padding_mask_right = np.zeros((max_len_for_x_mask_tmp, pad_right), dtype=dtype)
                x_mask_cur = np.diag(np.ones(stride, dtype=dtype))
                x_mask_cur_pad_top = np.zeros((chunk_ids * stride, stride), dtype=dtype)
                x_mask_cur_pad_bottom = np.zeros((max_len_for_x_mask_tmp, stride), dtype=dtype)
                x_rm_mask_cur = np.concatenate([x_mask_cur_pad_top, x_mask_cur, x_mask_cur_pad_bottom], axis=0)
                x_rm_mask_cur = x_rm_mask_cur[:max_len_for_x_mask_tmp, :stride]
                x_rm_mask_cur_fsmn = np.concatenate([fsmn_padding, padding_mask_left, x_rm_mask_cur, padding_mask_right], axis=1)
                x_rm_mask = np.concatenate([x_rm_mask, x_rm_mask_cur_fsmn], axis=1)

                # fsmn_padding_mask
                pad_shfit_mask = np.zeros([shfit_fsmn, num_units], dtype=dtype)
                ones_1 = np.ones([chunk_size, num_units], dtype=dtype)
                mask_shfit_chunk_cur = np.concatenate([pad_shfit_mask, ones_1], axis=0)
                mask_shfit_chunk = np.concatenate([mask_shfit_chunk, mask_shfit_chunk_cur], axis=0)

                # predictor mask: only the `stride` new frames of a chunk count
                zeros_1 = np.zeros([shfit_fsmn + pad_left, num_units_predictor], dtype=dtype)
                ones_2 = np.ones([stride, num_units_predictor], dtype=dtype)
                zeros_3 = np.zeros([chunk_size - stride - pad_left, num_units_predictor], dtype=dtype)
                ones_zeros = np.concatenate([ones_2, zeros_3], axis=0)
                mask_chunk_predictor_cur = np.concatenate([zeros_1, ones_zeros], axis=0)
                mask_chunk_predictor = np.concatenate([mask_chunk_predictor, mask_chunk_predictor_cur], axis=0)

                # encoder att mask: the current chunk plus a limited look-back
                zeros_1_top = np.zeros([shfit_fsmn, chunk_num * chunk_size_pad_shift], dtype=dtype)
                zeros_2_num = max(chunk_ids - encoder_att_look_back_factor, 0)
                zeros_2 = np.zeros([chunk_size, zeros_2_num * chunk_size_pad_shift], dtype=dtype)
                encoder_att_look_back_num = max(chunk_ids - zeros_2_num, 0)
                zeros_2_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype)
                ones_2_mid = np.ones([stride, stride], dtype=dtype)
                zeros_2_bottom = np.zeros([chunk_size - stride, stride], dtype=dtype)
                zeros_2_right = np.zeros([chunk_size, chunk_size - stride], dtype=dtype)
                ones_2 = np.concatenate([ones_2_mid, zeros_2_bottom], axis=0)
                ones_2 = np.concatenate([zeros_2_left, ones_2, zeros_2_right], axis=1)
                ones_2 = np.tile(ones_2, [1, encoder_att_look_back_num])
                zeros_3_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype)
                ones_3_right = np.ones([chunk_size, chunk_size], dtype=dtype)
                ones_3 = np.concatenate([zeros_3_left, ones_3_right], axis=1)
                zeros_remain_num = max(chunk_num - 1 - chunk_ids, 0)
                zeros_remain = np.zeros([chunk_size, zeros_remain_num * chunk_size_pad_shift], dtype=dtype)
                ones2_bottom = np.concatenate([zeros_2, ones_2, ones_3, zeros_remain], axis=1)
                mask_att_chunk_encoder_cur = np.concatenate([zeros_1_top, ones2_bottom], axis=0)
                mask_att_chunk_encoder = np.concatenate([mask_att_chunk_encoder, mask_att_chunk_encoder_cur], axis=0)

                # decoder fsmn_shift_att_mask
                zeros_1 = np.zeros([shfit_fsmn, 1])
                ones_1 = np.ones([chunk_size, 1])
                mask_shift_att_chunk_decoder_cur = np.concatenate([zeros_1, ones_1], axis=0)
                mask_shift_att_chunk_decoder = np.concatenate(
                    [mask_shift_att_chunk_decoder, mask_shift_att_chunk_decoder_cur], axis=0)

            self.x_add_mask = x_add_mask[:x_len_chunk_max, :x_len_max + pad_left]
            self.x_len_chunk = x_len_chunk
            self.x_rm_mask = x_rm_mask[:x_len_max, :x_len_chunk_max]
            self.x_len = x_len
            self.mask_shfit_chunk = mask_shfit_chunk[:x_len_chunk_max, :]
            self.mask_chunk_predictor = mask_chunk_predictor[:x_len_chunk_max, :]
            self.mask_att_chunk_encoder = mask_att_chunk_encoder[:x_len_chunk_max, :x_len_chunk_max]
            self.mask_shift_att_chunk_decoder = mask_shift_att_chunk_decoder[:x_len_chunk_max, :]

            self.chunk_outs = (self.x_add_mask,
                               self.x_len_chunk,
                               self.x_rm_mask,
                               self.x_len,
                               self.mask_shfit_chunk,
                               self.mask_chunk_predictor,
                               self.mask_att_chunk_encoder,
                               self.mask_shift_att_chunk_decoder)

        return self.chunk_outs
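
    # Worked example of the length arithmetic above (my own illustration, using
    # the default chunk_size=16, stride=10, pad_left=0, shfit_fsmn=0): for an
    # utterance of x_len = 35 frames,
    #   chunk_num_batch = ceil(35 / 10) = 4
    #   x_len_chunk = (4 - 1) * 16 + 35 - (4 - 1) * 10 = 53
    # i.e. each chunk contributes stride = 10 new frames and re-emits the
    # chunk_size - stride = 6 overlapped frames of its predecessor.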

    def split_chunk(self, x, x_len, chunk_outs):
        """
        Split the input into overlapped chunks (see _demo_overlap_chunk below).

        :param x: (b, t, d)
        :param x_len: (b,)
        :param chunk_outs: masks returned by gen_chunk_mask
        :return: chunked input and its lengths
        """
        x = x[:, :x_len.max(), :]
        b, t, d = x.size()
        x_len_mask = (~make_pad_mask(x_len, maxlen=t)).to(x.device)
        x *= x_len_mask[:, :, None]

        x_add_mask = self.get_x_add_mask(chunk_outs, x.device, dtype=x.dtype)
        x_len_chunk = self.get_x_len_chunk(chunk_outs, x_len.device, dtype=x_len.dtype)
        pad = (0, 0, self.pad_left_cur, 0)
        x = F.pad(x, pad, "constant", 0.0)
        b, t, d = x.size()
        x = torch.transpose(x, 1, 0)
        x = torch.reshape(x, [t, -1])
        x_chunk = torch.mm(x_add_mask, x)
        x_chunk = torch.reshape(x_chunk, [-1, b, d]).transpose(1, 0)
        return x_chunk, x_len_chunk

    def remove_chunk(self, x_chunk, x_len_chunk, chunk_outs):
        x_chunk = x_chunk[:, :x_len_chunk.max(), :]
        b, t, d = x_chunk.size()
        x_len_chunk_mask = (~make_pad_mask(x_len_chunk, maxlen=t)).to(x_chunk.device)
        x_chunk *= x_len_chunk_mask[:, :, None]

        x_rm_mask = self.get_x_rm_mask(chunk_outs, x_chunk.device, dtype=x_chunk.dtype)
        x_len = self.get_x_len(chunk_outs, x_len_chunk.device, dtype=x_len_chunk.dtype)
        x_chunk = torch.transpose(x_chunk, 1, 0)
        x_chunk = torch.reshape(x_chunk, [t, -1])
        x = torch.mm(x_rm_mask, x_chunk)
        x = torch.reshape(x, [-1, b, d]).transpose(1, 0)
        return x, x_len

    def get_x_add_mask(self, chunk_outs=None, device='cpu', idx=0, dtype=torch.float32):
        with torch.no_grad():
            x = chunk_outs[idx] if chunk_outs is not None else self.chunk_outs[idx]
            x = torch.from_numpy(x).type(dtype).to(device)
        return x

    def get_x_len_chunk(self, chunk_outs=None, device='cpu', idx=1, dtype=torch.float32):
        with torch.no_grad():
            x = chunk_outs[idx] if chunk_outs is not None else self.chunk_outs[idx]
            x = torch.from_numpy(x).type(dtype).to(device)
        return x

    def get_x_rm_mask(self, chunk_outs=None, device='cpu', idx=2, dtype=torch.float32):
        with torch.no_grad():
            x = chunk_outs[idx] if chunk_outs is not None else self.chunk_outs[idx]
            x = torch.from_numpy(x).type(dtype).to(device)
        return x

    def get_x_len(self, chunk_outs=None, device='cpu', idx=3, dtype=torch.float32):
        with torch.no_grad():
            x = chunk_outs[idx] if chunk_outs is not None else self.chunk_outs[idx]
            x = torch.from_numpy(x).type(dtype).to(device)
        return x

    def get_mask_shfit_chunk(self, chunk_outs=None, device='cpu', batch_size=1, num_units=1, idx=4, dtype=torch.float32):
        with torch.no_grad():
            x = chunk_outs[idx] if chunk_outs is not None else self.chunk_outs[idx]
            x = np.tile(x[None, :, :], [batch_size, 1, num_units])
            x = torch.from_numpy(x).type(dtype).to(device)
        return x

    def get_mask_chunk_predictor(self, chunk_outs=None, device='cpu', batch_size=1, num_units=1, idx=5, dtype=torch.float32):
        with torch.no_grad():
            x = chunk_outs[idx] if chunk_outs is not None else self.chunk_outs[idx]
            x = np.tile(x[None, :, :], [batch_size, 1, num_units])
            x = torch.from_numpy(x).type(dtype).to(device)
        return x

    def get_mask_att_chunk_encoder(self, chunk_outs=None, device='cpu', batch_size=1, idx=6, dtype=torch.float32):
        with torch.no_grad():
            x = chunk_outs[idx] if chunk_outs is not None else self.chunk_outs[idx]
            x = np.tile(x[None, :, :], [batch_size, 1, 1])
            x = torch.from_numpy(x).type(dtype).to(device)
        return x

    def get_mask_shift_att_chunk_decoder(self, chunk_outs=None, device='cpu', batch_size=1, idx=7, dtype=torch.float32):
        with torch.no_grad():
            x = chunk_outs[idx] if chunk_outs is not None else self.chunk_outs[idx]
            x = np.tile(x[None, None, :, 0], [batch_size, 1, 1])
            x = torch.from_numpy(x).type(dtype).to(device)
        return x
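

# A minimal round-trip sketch (not part of the original FunASR API; the demo
# name, shapes, and settings below are illustrative assumptions): generate the
# masks for a batch, split it into overlapped chunks with split_chunk, then
# recover the original frame layout with remove_chunk.
def _demo_overlap_chunk():
    chunker = overlap_chunk(chunk_size=(16,), stride=(10,))
    x = torch.randn(2, 35, 8)                           # (batch, time, feat)
    x_len = torch.tensor([35, 30], dtype=torch.int32)
    ind = chunker.random_choice(training=False, decoding_ind=0)
    chunk_outs = chunker.gen_chunk_mask(x_len, ind, num_units=8, num_units_predictor=8)
    x_chunk, x_len_chunk = chunker.split_chunk(x, x_len, chunk_outs)
    x_rec, x_len_rec = chunker.remove_chunk(x_chunk, x_len_chunk, chunk_outs)
    # with the defaults above, 35 input frames become (4 - 1) * 16 + 5 = 53
    # chunked frames, and remove_chunk restores the (2, 35, 8) layout
    print(x_chunk.shape, x_rec.shape)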


def build_scama_mask_for_cross_attention_decoder(
        predictor_alignments: torch.Tensor,
        encoder_sequence_length: torch.Tensor,
        chunk_size: int = 5,
        encoder_chunk_size: int = 5,
        attention_chunk_center_bias: int = 0,
        attention_chunk_size: int = 1,
        attention_chunk_type: str = 'chunk',
        step=None,
        predictor_mask_chunk_hopping: torch.Tensor = None,
        decoder_att_look_back_factor: int = 1,
        mask_shift_att_chunk_decoder: torch.Tensor = None,
        target_length: torch.Tensor = None,
        is_training=True,
        dtype: torch.dtype = torch.float32):
    """
    Build the decoder cross-attention mask from the predictor alignments: each
    target step may attend up to the encoder chunk in which it fires, plus a
    limited look-back controlled by decoder_att_look_back_factor.
    """
    with torch.no_grad():
        device = predictor_alignments.device
        batch_size, chunk_num = predictor_alignments.size()
        maximum_encoder_length = encoder_sequence_length.max().item()
        int_type = predictor_alignments.dtype

        if not is_training:
            target_length = predictor_alignments.sum(dim=-1).type(encoder_sequence_length.dtype)
        maximum_target_length = target_length.max()

        # for each target step, locate the chunk in which the predictor fires
        predictor_alignments_cumsum = torch.cumsum(predictor_alignments, dim=1)
        predictor_alignments_cumsum = predictor_alignments_cumsum[:, None, :].repeat(1, maximum_target_length, 1)

        index = torch.ones([batch_size, maximum_target_length], dtype=int_type).to(device)
        index = torch.cumsum(index, dim=1)
        index = index[:, :, None].repeat(1, 1, chunk_num)

        index_div = torch.floor(torch.divide(predictor_alignments_cumsum, index)).type(int_type)
        index_div_bool_zeros = index_div == 0
        index_div_bool_zeros_count = torch.sum(index_div_bool_zeros.type(int_type), dim=-1) + 1
        index_div_bool_zeros_count = torch.clip(index_div_bool_zeros_count, min=1, max=chunk_num)
        index_div_bool_zeros_count *= chunk_size
        index_div_bool_zeros_count += attention_chunk_center_bias
        index_div_bool_zeros_count = torch.clip(index_div_bool_zeros_count - 1, min=0, max=maximum_encoder_length)
        index_div_bool_zeros_count_ori = index_div_bool_zeros_count
        # round the attention endpoint up to the encoder chunk boundary
        index_div_bool_zeros_count = (torch.floor(index_div_bool_zeros_count / encoder_chunk_size) + 1) * encoder_chunk_size
        max_len_chunk = math.ceil(maximum_encoder_length / encoder_chunk_size) * encoder_chunk_size

        mask_flip, mask_flip2 = None, None
        if attention_chunk_size is not None:
            index_div_bool_zeros_count_beg = index_div_bool_zeros_count - attention_chunk_size
            index_div_bool_zeros_count_beg = torch.clip(index_div_bool_zeros_count_beg, 0, max_len_chunk)
            index_div_bool_zeros_count_beg_mask = sequence_mask(index_div_bool_zeros_count_beg, maxlen=max_len_chunk, dtype=int_type, device=device)
            mask_flip = 1 - index_div_bool_zeros_count_beg_mask

            attention_chunk_size2 = attention_chunk_size * (decoder_att_look_back_factor + 1)
            index_div_bool_zeros_count_beg = index_div_bool_zeros_count - attention_chunk_size2
            index_div_bool_zeros_count_beg = torch.clip(index_div_bool_zeros_count_beg, 0, max_len_chunk)
            index_div_bool_zeros_count_beg_mask = sequence_mask(index_div_bool_zeros_count_beg, maxlen=max_len_chunk, dtype=int_type, device=device)
            mask_flip2 = 1 - index_div_bool_zeros_count_beg_mask

        mask = sequence_mask(index_div_bool_zeros_count, maxlen=max_len_chunk, dtype=dtype, device=device)

        if predictor_mask_chunk_hopping is not None:
            b, k, t = mask.size()
            predictor_mask_chunk_hopping = predictor_mask_chunk_hopping[:, None, :, 0].repeat(1, k, 1)
            mask_mask_flip = mask
            if mask_flip is not None:
                mask_mask_flip = mask_flip * mask

            def _fn():
                mask_sliced = mask[:b, :k, encoder_chunk_size:t]
                zero_pad_right = torch.zeros([b, k, encoder_chunk_size], dtype=mask_sliced.dtype).to(device)
                mask_sliced = torch.cat([mask_sliced, zero_pad_right], dim=2)
                _, _, tt = predictor_mask_chunk_hopping.size()
                pad_right_p = max_len_chunk - tt
                predictor_mask_chunk_hopping_pad = torch.nn.functional.pad(predictor_mask_chunk_hopping, [0, pad_right_p], "constant", 0)
                masked = mask_sliced * predictor_mask_chunk_hopping_pad
                mask_true = mask_mask_flip + masked
                return mask_true

            mask = _fn() if t > chunk_size else mask_mask_flip

        if mask_flip2 is not None:
            mask *= mask_flip2

        mask_target = sequence_mask(target_length, maxlen=maximum_target_length, dtype=mask.dtype, device=device)
        mask = mask[:, :maximum_target_length, :] * mask_target[:, :, None]
        mask_len = sequence_mask(encoder_sequence_length, maxlen=maximum_encoder_length, dtype=mask.dtype, device=device)
        mask = mask[:, :, :maximum_encoder_length] * mask_len[:, None, :]

        if attention_chunk_type == 'full':
            mask = torch.ones_like(mask).to(device)
        if mask_shift_att_chunk_decoder is not None:
            mask = mask * mask_shift_att_chunk_decoder
        mask = mask[:, :maximum_target_length, :maximum_encoder_length].type(dtype).to(device)

    return mask
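

# A hedged sketch of driving build_scama_mask_for_cross_attention_decoder
# (the alignment values, lengths, and chunk settings are made-up examples):
def _demo_scama_mask():
    alignments = torch.tensor([[2, 1, 0, 1],
                               [1, 1, 1, 0]], dtype=torch.int32)  # tokens fired per chunk
    enc_len = torch.tensor([20, 15], dtype=torch.int32)
    mask = build_scama_mask_for_cross_attention_decoder(
        predictor_alignments=alignments,
        encoder_sequence_length=enc_len,
        chunk_size=5,
        encoder_chunk_size=5,
        attention_chunk_center_bias=0,
        attention_chunk_size=5,
        attention_chunk_type='chunk',
        decoder_att_look_back_factor=1,
        is_training=False,
        dtype=torch.float32,
    )
    print(mask.shape)  # (batch, max_target_len, max_encoder_len) = (2, 4, 20)


if __name__ == "__main__":
    _demo_overlap_chunk()
    _demo_scama_mask()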