cif.py 35 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748
  1. import torch
  2. from torch import nn
  3. import logging
  4. import numpy as np
  5. from funasr.torch_utils.device_funcs import to_device
  6. from funasr.modules.nets_utils import make_pad_mask
  7. from funasr.modules.streaming_utils.utils import sequence_mask
  8. class CifPredictor(nn.Module):
  9. def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0, noise_threshold=0, tail_threshold=0.45):
  10. super(CifPredictor, self).__init__()
  11. self.pad = nn.ConstantPad1d((l_order, r_order), 0)
  12. self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1, groups=idim)
  13. self.cif_output = nn.Linear(idim, 1)
  14. self.dropout = torch.nn.Dropout(p=dropout)
  15. self.threshold = threshold
  16. self.smooth_factor = smooth_factor
  17. self.noise_threshold = noise_threshold
  18. self.tail_threshold = tail_threshold
  19. def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
  20. target_label_length=None):
  21. h = hidden
  22. context = h.transpose(1, 2)
  23. queries = self.pad(context)
  24. memory = self.cif_conv1d(queries)
  25. output = memory + context
  26. output = self.dropout(output)
  27. output = output.transpose(1, 2)
  28. output = torch.relu(output)
  29. output = self.cif_output(output)
  30. alphas = torch.sigmoid(output)
  31. alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
  32. if mask is not None:
  33. mask = mask.transpose(-1, -2).float()
  34. alphas = alphas * mask
  35. if mask_chunk_predictor is not None:
  36. alphas = alphas * mask_chunk_predictor
  37. alphas = alphas.squeeze(-1)
  38. mask = mask.squeeze(-1)
  39. if target_label_length is not None:
  40. target_length = target_label_length
  41. elif target_label is not None:
  42. target_length = (target_label != ignore_id).float().sum(-1)
  43. else:
  44. target_length = None
  45. token_num = alphas.sum(-1)
  46. if target_length is not None:
  47. alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
  48. elif self.tail_threshold > 0.0:
  49. hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
  50. acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
  51. if target_length is None and self.tail_threshold > 0.0:
  52. token_num_int = torch.max(token_num).type(torch.int32).item()
  53. acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
  54. return acoustic_embeds, token_num, alphas, cif_peak
  55. def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
  56. b, t, d = hidden.size()
  57. tail_threshold = self.tail_threshold
  58. if mask is not None:
  59. zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
  60. ones_t = torch.ones_like(zeros_t)
  61. mask_1 = torch.cat([mask, zeros_t], dim=1)
  62. mask_2 = torch.cat([ones_t, mask], dim=1)
  63. mask = mask_2 - mask_1
  64. tail_threshold = mask * tail_threshold
  65. alphas = torch.cat([alphas, zeros_t], dim=1)
  66. alphas = torch.add(alphas, tail_threshold)
  67. else:
  68. tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
  69. tail_threshold = torch.reshape(tail_threshold, (1, 1))
  70. alphas = torch.cat([alphas, tail_threshold], dim=1)
  71. zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
  72. hidden = torch.cat([hidden, zeros], dim=1)
  73. token_num = alphas.sum(dim=-1)
  74. token_num_floor = torch.floor(token_num)
  75. return hidden, alphas, token_num_floor
  76. def gen_frame_alignments(self,
  77. alphas: torch.Tensor = None,
  78. encoder_sequence_length: torch.Tensor = None):
  79. batch_size, maximum_length = alphas.size()
  80. int_type = torch.int32
  81. is_training = self.training
  82. if is_training:
  83. token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type)
  84. else:
  85. token_num = torch.floor(torch.sum(alphas, dim=1)).type(int_type)
  86. max_token_num = torch.max(token_num).item()
  87. alphas_cumsum = torch.cumsum(alphas, dim=1)
  88. alphas_cumsum = torch.floor(alphas_cumsum).type(int_type)
  89. alphas_cumsum = alphas_cumsum[:, None, :].repeat(1, max_token_num, 1)
  90. index = torch.ones([batch_size, max_token_num], dtype=int_type)
  91. index = torch.cumsum(index, dim=1)
  92. index = index[:, :, None].repeat(1, 1, maximum_length).to(alphas_cumsum.device)
  93. index_div = torch.floor(torch.true_divide(alphas_cumsum, index)).type(int_type)
  94. index_div_bool_zeros = index_div.eq(0)
  95. index_div_bool_zeros_count = torch.sum(index_div_bool_zeros, dim=-1) + 1
  96. index_div_bool_zeros_count = torch.clamp(index_div_bool_zeros_count, 0, encoder_sequence_length.max())
  97. token_num_mask = (~make_pad_mask(token_num, maxlen=max_token_num)).to(token_num.device)
  98. index_div_bool_zeros_count *= token_num_mask
  99. index_div_bool_zeros_count_tile = index_div_bool_zeros_count[:, :, None].repeat(1, 1, maximum_length)
  100. ones = torch.ones_like(index_div_bool_zeros_count_tile)
  101. zeros = torch.zeros_like(index_div_bool_zeros_count_tile)
  102. ones = torch.cumsum(ones, dim=2)
  103. cond = index_div_bool_zeros_count_tile == ones
  104. index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones)
  105. index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type(torch.bool)
  106. index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type(int_type)
  107. index_div_bool_zeros_count_tile_out = torch.sum(index_div_bool_zeros_count_tile, dim=1)
  108. index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type(int_type)
  109. predictor_mask = (~make_pad_mask(encoder_sequence_length, maxlen=encoder_sequence_length.max())).type(
  110. int_type).to(encoder_sequence_length.device)
  111. index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out * predictor_mask
  112. predictor_alignments = index_div_bool_zeros_count_tile_out
  113. predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
  114. return predictor_alignments.detach(), predictor_alignments_length.detach()
  115. class CifPredictorV2(nn.Module):
  116. def __init__(self,
  117. idim,
  118. l_order,
  119. r_order,
  120. threshold=1.0,
  121. dropout=0.1,
  122. smooth_factor=1.0,
  123. noise_threshold=0,
  124. tail_threshold=0.0,
  125. tf2torch_tensor_name_prefix_torch="predictor",
  126. tf2torch_tensor_name_prefix_tf="seq2seq/cif",
  127. tail_mask=True,
  128. ):
  129. super(CifPredictorV2, self).__init__()
  130. self.pad = nn.ConstantPad1d((l_order, r_order), 0)
  131. self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1)
  132. self.cif_output = nn.Linear(idim, 1)
  133. self.dropout = torch.nn.Dropout(p=dropout)
  134. self.threshold = threshold
  135. self.smooth_factor = smooth_factor
  136. self.noise_threshold = noise_threshold
  137. self.tail_threshold = tail_threshold
  138. self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
  139. self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
  140. self.tail_mask = tail_mask
  141. def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
  142. target_label_length=None):
  143. h = hidden
  144. context = h.transpose(1, 2)
  145. queries = self.pad(context)
  146. output = torch.relu(self.cif_conv1d(queries))
  147. output = output.transpose(1, 2)
  148. output = self.cif_output(output)
  149. alphas = torch.sigmoid(output)
  150. alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
  151. if mask is not None:
  152. mask = mask.transpose(-1, -2).float()
  153. alphas = alphas * mask
  154. if mask_chunk_predictor is not None:
  155. alphas = alphas * mask_chunk_predictor
  156. alphas = alphas.squeeze(-1)
  157. mask = mask.squeeze(-1)
  158. if target_label_length is not None:
  159. target_length = target_label_length
  160. elif target_label is not None:
  161. target_length = (target_label != ignore_id).float().sum(-1)
  162. else:
  163. target_length = None
  164. token_num = alphas.sum(-1)
  165. if target_length is not None:
  166. alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
  167. elif self.tail_threshold > 0.0:
  168. if self.tail_mask:
  169. hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
  170. else:
  171. hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=None)
  172. acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
  173. if target_length is None and self.tail_threshold > 0.0:
  174. token_num_int = torch.max(token_num).type(torch.int32).item()
  175. acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
  176. return acoustic_embeds, token_num, alphas, cif_peak
  177. def forward_chunk(self, hidden, cache=None):
  178. batch_size, len_time, hidden_size = hidden.shape
  179. h = hidden
  180. context = h.transpose(1, 2)
  181. queries = self.pad(context)
  182. output = torch.relu(self.cif_conv1d(queries))
  183. output = output.transpose(1, 2)
  184. output = self.cif_output(output)
  185. alphas = torch.sigmoid(output)
  186. alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
  187. alphas = alphas.squeeze(-1)
  188. token_length = []
  189. list_fires = []
  190. list_frames = []
  191. cache_alphas = []
  192. cache_hiddens = []
  193. if cache is not None and "chunk_size" in cache:
  194. alphas[:, :cache["chunk_size"][0]] = 0.0
  195. alphas[:, sum(cache["chunk_size"][:2]):] = 0.0
  196. if cache is not None and "cif_alphas" in cache and "cif_hidden" in cache:
  197. cache["cif_hidden"] = to_device(cache["cif_hidden"], device=hidden.device)
  198. cache["cif_alphas"] = to_device(cache["cif_alphas"], device=alphas.device)
  199. hidden = torch.cat((cache["cif_hidden"], hidden), dim=1)
  200. alphas = torch.cat((cache["cif_alphas"], alphas), dim=1)
  201. if cache is not None and "last_chunk" in cache and cache["last_chunk"]:
  202. tail_hidden = torch.zeros((batch_size, 1, hidden_size), device=hidden.device)
  203. tail_alphas = torch.tensor([[self.tail_threshold]], device=alphas.device)
  204. tail_alphas = torch.tile(tail_alphas, (batch_size, 1))
  205. hidden = torch.cat((hidden, tail_hidden), dim=1)
  206. alphas = torch.cat((alphas, tail_alphas), dim=1)
  207. len_time = alphas.shape[1]
  208. for b in range(batch_size):
  209. integrate = 0.0
  210. frames = torch.zeros((hidden_size), device=hidden.device)
  211. list_frame = []
  212. list_fire = []
  213. for t in range(len_time):
  214. alpha = alphas[b][t]
  215. if alpha + integrate < self.threshold:
  216. integrate += alpha
  217. list_fire.append(integrate)
  218. frames += alpha * hidden[b][t]
  219. else:
  220. frames += (self.threshold - integrate) * hidden[b][t]
  221. list_frame.append(frames)
  222. integrate += alpha
  223. list_fire.append(integrate)
  224. integrate -= self.threshold
  225. frames = integrate * hidden[b][t]
  226. cache_alphas.append(integrate)
  227. if integrate > 0.0:
  228. cache_hiddens.append(frames / integrate)
  229. else:
  230. cache_hiddens.append(frames)
  231. token_length.append(torch.tensor(len(list_frame), device=alphas.device))
  232. list_fires.append(list_fire)
  233. list_frames.append(list_frame)
  234. cache["cif_alphas"] = torch.stack(cache_alphas, axis=0)
  235. cache["cif_alphas"] = torch.unsqueeze(cache["cif_alphas"], axis=0)
  236. cache["cif_hidden"] = torch.stack(cache_hiddens, axis=0)
  237. cache["cif_hidden"] = torch.unsqueeze(cache["cif_hidden"], axis=0)
  238. max_token_len = max(token_length)
  239. if max_token_len == 0:
  240. return hidden, torch.stack(token_length, 0)
  241. list_ls = []
  242. for b in range(batch_size):
  243. pad_frames = torch.zeros((max_token_len - token_length[b], hidden_size), device=alphas.device)
  244. if token_length[b] == 0:
  245. list_ls.append(pad_frames)
  246. else:
  247. list_frames[b] = torch.stack(list_frames[b])
  248. list_ls.append(torch.cat((list_frames[b], pad_frames), dim=0))
  249. cache["cif_alphas"] = torch.stack(cache_alphas, axis=0)
  250. cache["cif_alphas"] = torch.unsqueeze(cache["cif_alphas"], axis=0)
  251. cache["cif_hidden"] = torch.stack(cache_hiddens, axis=0)
  252. cache["cif_hidden"] = torch.unsqueeze(cache["cif_hidden"], axis=0)
  253. return torch.stack(list_ls, 0), torch.stack(token_length, 0)
  254. def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
  255. b, t, d = hidden.size()
  256. tail_threshold = self.tail_threshold
  257. if mask is not None:
  258. zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
  259. ones_t = torch.ones_like(zeros_t)
  260. mask_1 = torch.cat([mask, zeros_t], dim=1)
  261. mask_2 = torch.cat([ones_t, mask], dim=1)
  262. mask = mask_2 - mask_1
  263. tail_threshold = mask * tail_threshold
  264. alphas = torch.cat([alphas, zeros_t], dim=1)
  265. alphas = torch.add(alphas, tail_threshold)
  266. else:
  267. tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
  268. tail_threshold = torch.reshape(tail_threshold, (1, 1))
  269. if b > 1:
  270. alphas = torch.cat([alphas, tail_threshold.repeat(b, 1)], dim=1)
  271. else:
  272. alphas = torch.cat([alphas, tail_threshold], dim=1)
  273. zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
  274. hidden = torch.cat([hidden, zeros], dim=1)
  275. token_num = alphas.sum(dim=-1)
  276. token_num_floor = torch.floor(token_num)
  277. return hidden, alphas, token_num_floor
  278. def gen_frame_alignments(self,
  279. alphas: torch.Tensor = None,
  280. encoder_sequence_length: torch.Tensor = None):
  281. batch_size, maximum_length = alphas.size()
  282. int_type = torch.int32
  283. is_training = self.training
  284. if is_training:
  285. token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type)
  286. else:
  287. token_num = torch.floor(torch.sum(alphas, dim=1)).type(int_type)
  288. max_token_num = torch.max(token_num).item()
  289. alphas_cumsum = torch.cumsum(alphas, dim=1)
  290. alphas_cumsum = torch.floor(alphas_cumsum).type(int_type)
  291. alphas_cumsum = alphas_cumsum[:, None, :].repeat(1, max_token_num, 1)
  292. index = torch.ones([batch_size, max_token_num], dtype=int_type)
  293. index = torch.cumsum(index, dim=1)
  294. index = index[:, :, None].repeat(1, 1, maximum_length).to(alphas_cumsum.device)
  295. index_div = torch.floor(torch.true_divide(alphas_cumsum, index)).type(int_type)
  296. index_div_bool_zeros = index_div.eq(0)
  297. index_div_bool_zeros_count = torch.sum(index_div_bool_zeros, dim=-1) + 1
  298. index_div_bool_zeros_count = torch.clamp(index_div_bool_zeros_count, 0, encoder_sequence_length.max())
  299. token_num_mask = (~make_pad_mask(token_num, maxlen=max_token_num)).to(token_num.device)
  300. index_div_bool_zeros_count *= token_num_mask
  301. index_div_bool_zeros_count_tile = index_div_bool_zeros_count[:, :, None].repeat(1, 1, maximum_length)
  302. ones = torch.ones_like(index_div_bool_zeros_count_tile)
  303. zeros = torch.zeros_like(index_div_bool_zeros_count_tile)
  304. ones = torch.cumsum(ones, dim=2)
  305. cond = index_div_bool_zeros_count_tile == ones
  306. index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones)
  307. index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type(torch.bool)
  308. index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type(int_type)
  309. index_div_bool_zeros_count_tile_out = torch.sum(index_div_bool_zeros_count_tile, dim=1)
  310. index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type(int_type)
  311. predictor_mask = (~make_pad_mask(encoder_sequence_length, maxlen=encoder_sequence_length.max())).type(
  312. int_type).to(encoder_sequence_length.device)
  313. index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out * predictor_mask
  314. predictor_alignments = index_div_bool_zeros_count_tile_out
  315. predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
  316. return predictor_alignments.detach(), predictor_alignments_length.detach()
  317. def gen_tf2torch_map_dict(self):
  318. tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
  319. tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
  320. map_dict_local = {
  321. ## predictor
  322. "{}.cif_conv1d.weight".format(tensor_name_prefix_torch):
  323. {"name": "{}/conv1d/kernel".format(tensor_name_prefix_tf),
  324. "squeeze": None,
  325. "transpose": (2, 1, 0),
  326. }, # (256,256,3),(3,256,256)
  327. "{}.cif_conv1d.bias".format(tensor_name_prefix_torch):
  328. {"name": "{}/conv1d/bias".format(tensor_name_prefix_tf),
  329. "squeeze": None,
  330. "transpose": None,
  331. }, # (256,),(256,)
  332. "{}.cif_output.weight".format(tensor_name_prefix_torch):
  333. {"name": "{}/conv1d_1/kernel".format(tensor_name_prefix_tf),
  334. "squeeze": 0,
  335. "transpose": (1, 0),
  336. }, # (1,256),(1,256,1)
  337. "{}.cif_output.bias".format(tensor_name_prefix_torch):
  338. {"name": "{}/conv1d_1/bias".format(tensor_name_prefix_tf),
  339. "squeeze": None,
  340. "transpose": None,
  341. }, # (1,),(1,)
  342. }
  343. return map_dict_local
  344. def convert_tf2torch(self,
  345. var_dict_tf,
  346. var_dict_torch,
  347. ):
  348. map_dict = self.gen_tf2torch_map_dict()
  349. var_dict_torch_update = dict()
  350. for name in sorted(var_dict_torch.keys(), reverse=False):
  351. names = name.split('.')
  352. if names[0] == self.tf2torch_tensor_name_prefix_torch:
  353. name_tf = map_dict[name]["name"]
  354. data_tf = var_dict_tf[name_tf]
  355. if map_dict[name]["squeeze"] is not None:
  356. data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
  357. if map_dict[name]["transpose"] is not None:
  358. data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
  359. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  360. assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
  361. var_dict_torch[
  362. name].size(),
  363. data_tf.size())
  364. var_dict_torch_update[name] = data_tf
  365. logging.info(
  366. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
  367. var_dict_tf[name_tf].shape))
  368. return var_dict_torch_update
  369. class mae_loss(nn.Module):
  370. def __init__(self, normalize_length=False):
  371. super(mae_loss, self).__init__()
  372. self.normalize_length = normalize_length
  373. self.criterion = torch.nn.L1Loss(reduction='sum')
  374. def forward(self, token_length, pre_token_length):
  375. loss_token_normalizer = token_length.size(0)
  376. if self.normalize_length:
  377. loss_token_normalizer = token_length.sum().type(torch.float32)
  378. loss = self.criterion(token_length, pre_token_length)
  379. loss = loss / loss_token_normalizer
  380. return loss
  381. def cif(hidden, alphas, threshold):
  382. batch_size, len_time, hidden_size = hidden.size()
  383. # loop varss
  384. integrate = torch.zeros([batch_size], device=hidden.device)
  385. frame = torch.zeros([batch_size, hidden_size], device=hidden.device)
  386. # intermediate vars along time
  387. list_fires = []
  388. list_frames = []
  389. for t in range(len_time):
  390. alpha = alphas[:, t]
  391. distribution_completion = torch.ones([batch_size], device=hidden.device) - integrate
  392. integrate += alpha
  393. list_fires.append(integrate)
  394. fire_place = integrate >= threshold
  395. integrate = torch.where(fire_place,
  396. integrate - torch.ones([batch_size], device=hidden.device),
  397. integrate)
  398. cur = torch.where(fire_place,
  399. distribution_completion,
  400. alpha)
  401. remainds = alpha - cur
  402. frame += cur[:, None] * hidden[:, t, :]
  403. list_frames.append(frame)
  404. frame = torch.where(fire_place[:, None].repeat(1, hidden_size),
  405. remainds[:, None] * hidden[:, t, :],
  406. frame)
  407. fires = torch.stack(list_fires, 1)
  408. frames = torch.stack(list_frames, 1)
  409. list_ls = []
  410. len_labels = torch.round(alphas.sum(-1)).int()
  411. max_label_len = len_labels.max()
  412. for b in range(batch_size):
  413. fire = fires[b, :]
  414. l = torch.index_select(frames[b, :, :], 0, torch.nonzero(fire >= threshold).squeeze())
  415. pad_l = torch.zeros([max_label_len - l.size(0), hidden_size], device=hidden.device)
  416. list_ls.append(torch.cat([l, pad_l], 0))
  417. return torch.stack(list_ls, 0), fires
  418. def cif_wo_hidden(alphas, threshold):
  419. batch_size, len_time = alphas.size()
  420. # loop varss
  421. integrate = torch.zeros([batch_size], device=alphas.device)
  422. # intermediate vars along time
  423. list_fires = []
  424. for t in range(len_time):
  425. alpha = alphas[:, t]
  426. integrate += alpha
  427. list_fires.append(integrate)
  428. fire_place = integrate >= threshold
  429. integrate = torch.where(fire_place,
  430. integrate - torch.ones([batch_size], device=alphas.device),
  431. integrate)
  432. fires = torch.stack(list_fires, 1)
  433. return fires
  434. class CifPredictorV3(nn.Module):
  435. def __init__(self,
  436. idim,
  437. l_order,
  438. r_order,
  439. threshold=1.0,
  440. dropout=0.1,
  441. smooth_factor=1.0,
  442. noise_threshold=0,
  443. tail_threshold=0.0,
  444. tf2torch_tensor_name_prefix_torch="predictor",
  445. tf2torch_tensor_name_prefix_tf="seq2seq/cif",
  446. smooth_factor2=1.0,
  447. noise_threshold2=0,
  448. upsample_times=5,
  449. upsample_type="cnn",
  450. use_cif1_cnn=True,
  451. tail_mask=True,
  452. ):
  453. super(CifPredictorV3, self).__init__()
  454. self.pad = nn.ConstantPad1d((l_order, r_order), 0)
  455. self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1)
  456. self.cif_output = nn.Linear(idim, 1)
  457. self.dropout = torch.nn.Dropout(p=dropout)
  458. self.threshold = threshold
  459. self.smooth_factor = smooth_factor
  460. self.noise_threshold = noise_threshold
  461. self.tail_threshold = tail_threshold
  462. self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
  463. self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
  464. self.upsample_times = upsample_times
  465. self.upsample_type = upsample_type
  466. self.use_cif1_cnn = use_cif1_cnn
  467. if self.upsample_type == 'cnn':
  468. self.upsample_cnn = nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
  469. self.cif_output2 = nn.Linear(idim, 1)
  470. elif self.upsample_type == 'cnn_blstm':
  471. self.upsample_cnn = nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
  472. self.blstm = nn.LSTM(idim, idim, 1, bias=True, batch_first=True, dropout=0.0, bidirectional=True)
  473. self.cif_output2 = nn.Linear(idim*2, 1)
  474. elif self.upsample_type == 'cnn_attn':
  475. self.upsample_cnn = nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
  476. from funasr.models.encoder.transformer_encoder import EncoderLayer as TransformerEncoderLayer
  477. from funasr.modules.attention import MultiHeadedAttention
  478. from funasr.modules.positionwise_feed_forward import PositionwiseFeedForward
  479. positionwise_layer_args = (
  480. idim,
  481. idim*2,
  482. 0.1,
  483. )
  484. self.self_attn = TransformerEncoderLayer(
  485. idim,
  486. MultiHeadedAttention(
  487. 4, idim, 0.1
  488. ),
  489. PositionwiseFeedForward(*positionwise_layer_args),
  490. 0.1,
  491. True, #normalize_before,
  492. False, #concat_after,
  493. )
  494. self.cif_output2 = nn.Linear(idim, 1)
  495. self.smooth_factor2 = smooth_factor2
  496. self.noise_threshold2 = noise_threshold2
  497. def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
  498. target_label_length=None):
  499. h = hidden
  500. context = h.transpose(1, 2)
  501. queries = self.pad(context)
  502. output = torch.relu(self.cif_conv1d(queries))
  503. # alphas2 is an extra head for timestamp prediction
  504. if not self.use_cif1_cnn:
  505. _output = context
  506. else:
  507. _output = output
  508. if self.upsample_type == 'cnn':
  509. output2 = self.upsample_cnn(_output)
  510. output2 = output2.transpose(1,2)
  511. elif self.upsample_type == 'cnn_blstm':
  512. output2 = self.upsample_cnn(_output)
  513. output2 = output2.transpose(1,2)
  514. output2, (_, _) = self.blstm(output2)
  515. elif self.upsample_type == 'cnn_attn':
  516. output2 = self.upsample_cnn(_output)
  517. output2 = output2.transpose(1,2)
  518. output2, _ = self.self_attn(output2, mask)
  519. # import pdb; pdb.set_trace()
  520. alphas2 = torch.sigmoid(self.cif_output2(output2))
  521. alphas2 = torch.nn.functional.relu(alphas2 * self.smooth_factor2 - self.noise_threshold2)
  522. # repeat the mask in T demension to match the upsampled length
  523. if mask is not None:
  524. mask2 = mask.repeat(1, self.upsample_times, 1).transpose(-1, -2).reshape(alphas2.shape[0], -1)
  525. mask2 = mask2.unsqueeze(-1)
  526. alphas2 = alphas2 * mask2
  527. alphas2 = alphas2.squeeze(-1)
  528. token_num2 = alphas2.sum(-1)
  529. output = output.transpose(1, 2)
  530. output = self.cif_output(output)
  531. alphas = torch.sigmoid(output)
  532. alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
  533. if mask is not None:
  534. mask = mask.transpose(-1, -2).float()
  535. alphas = alphas * mask
  536. if mask_chunk_predictor is not None:
  537. alphas = alphas * mask_chunk_predictor
  538. alphas = alphas.squeeze(-1)
  539. mask = mask.squeeze(-1)
  540. if target_label_length is not None:
  541. target_length = target_label_length
  542. elif target_label is not None:
  543. target_length = (target_label != ignore_id).float().sum(-1)
  544. else:
  545. target_length = None
  546. token_num = alphas.sum(-1)
  547. if target_length is not None:
  548. alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
  549. elif self.tail_threshold > 0.0:
  550. hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
  551. acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
  552. if target_length is None and self.tail_threshold > 0.0:
  553. token_num_int = torch.max(token_num).type(torch.int32).item()
  554. acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
  555. return acoustic_embeds, token_num, alphas, cif_peak, token_num2
  556. def get_upsample_timestamp(self, hidden, mask=None, token_num=None):
  557. h = hidden
  558. b = hidden.shape[0]
  559. context = h.transpose(1, 2)
  560. queries = self.pad(context)
  561. output = torch.relu(self.cif_conv1d(queries))
  562. # alphas2 is an extra head for timestamp prediction
  563. if not self.use_cif1_cnn:
  564. _output = context
  565. else:
  566. _output = output
  567. if self.upsample_type == 'cnn':
  568. output2 = self.upsample_cnn(_output)
  569. output2 = output2.transpose(1,2)
  570. elif self.upsample_type == 'cnn_blstm':
  571. output2 = self.upsample_cnn(_output)
  572. output2 = output2.transpose(1,2)
  573. output2, (_, _) = self.blstm(output2)
  574. elif self.upsample_type == 'cnn_attn':
  575. output2 = self.upsample_cnn(_output)
  576. output2 = output2.transpose(1,2)
  577. output2, _ = self.self_attn(output2, mask)
  578. alphas2 = torch.sigmoid(self.cif_output2(output2))
  579. alphas2 = torch.nn.functional.relu(alphas2 * self.smooth_factor2 - self.noise_threshold2)
  580. # repeat the mask in T demension to match the upsampled length
  581. if mask is not None:
  582. mask2 = mask.repeat(1, self.upsample_times, 1).transpose(-1, -2).reshape(alphas2.shape[0], -1)
  583. mask2 = mask2.unsqueeze(-1)
  584. alphas2 = alphas2 * mask2
  585. alphas2 = alphas2.squeeze(-1)
  586. _token_num = alphas2.sum(-1)
  587. if token_num is not None:
  588. alphas2 *= (token_num / _token_num)[:, None].repeat(1, alphas2.size(1))
  589. # re-downsample
  590. ds_alphas = alphas2.reshape(b, -1, self.upsample_times).sum(-1)
  591. ds_cif_peak = cif_wo_hidden(ds_alphas, self.threshold - 1e-4)
  592. # upsampled alphas and cif_peak
  593. us_alphas = alphas2
  594. us_cif_peak = cif_wo_hidden(us_alphas, self.threshold - 1e-4)
  595. return ds_alphas, ds_cif_peak, us_alphas, us_cif_peak
  596. def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
  597. b, t, d = hidden.size()
  598. tail_threshold = self.tail_threshold
  599. if mask is not None:
  600. zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
  601. ones_t = torch.ones_like(zeros_t)
  602. mask_1 = torch.cat([mask, zeros_t], dim=1)
  603. mask_2 = torch.cat([ones_t, mask], dim=1)
  604. mask = mask_2 - mask_1
  605. tail_threshold = mask * tail_threshold
  606. alphas = torch.cat([alphas, zeros_t], dim=1)
  607. alphas = torch.add(alphas, tail_threshold)
  608. else:
  609. tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
  610. tail_threshold = torch.reshape(tail_threshold, (1, 1))
  611. alphas = torch.cat([alphas, tail_threshold], dim=1)
  612. zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
  613. hidden = torch.cat([hidden, zeros], dim=1)
  614. token_num = alphas.sum(dim=-1)
  615. token_num_floor = torch.floor(token_num)
  616. return hidden, alphas, token_num_floor
  617. def gen_frame_alignments(self,
  618. alphas: torch.Tensor = None,
  619. encoder_sequence_length: torch.Tensor = None):
  620. batch_size, maximum_length = alphas.size()
  621. int_type = torch.int32
  622. is_training = self.training
  623. if is_training:
  624. token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type)
  625. else:
  626. token_num = torch.floor(torch.sum(alphas, dim=1)).type(int_type)
  627. max_token_num = torch.max(token_num).item()
  628. alphas_cumsum = torch.cumsum(alphas, dim=1)
  629. alphas_cumsum = torch.floor(alphas_cumsum).type(int_type)
  630. alphas_cumsum = alphas_cumsum[:, None, :].repeat(1, max_token_num, 1)
  631. index = torch.ones([batch_size, max_token_num], dtype=int_type)
  632. index = torch.cumsum(index, dim=1)
  633. index = index[:, :, None].repeat(1, 1, maximum_length).to(alphas_cumsum.device)
  634. index_div = torch.floor(torch.true_divide(alphas_cumsum, index)).type(int_type)
  635. index_div_bool_zeros = index_div.eq(0)
  636. index_div_bool_zeros_count = torch.sum(index_div_bool_zeros, dim=-1) + 1
  637. index_div_bool_zeros_count = torch.clamp(index_div_bool_zeros_count, 0, encoder_sequence_length.max())
  638. token_num_mask = (~make_pad_mask(token_num, maxlen=max_token_num)).to(token_num.device)
  639. index_div_bool_zeros_count *= token_num_mask
  640. index_div_bool_zeros_count_tile = index_div_bool_zeros_count[:, :, None].repeat(1, 1, maximum_length)
  641. ones = torch.ones_like(index_div_bool_zeros_count_tile)
  642. zeros = torch.zeros_like(index_div_bool_zeros_count_tile)
  643. ones = torch.cumsum(ones, dim=2)
  644. cond = index_div_bool_zeros_count_tile == ones
  645. index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones)
  646. index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type(torch.bool)
  647. index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type(int_type)
  648. index_div_bool_zeros_count_tile_out = torch.sum(index_div_bool_zeros_count_tile, dim=1)
  649. index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type(int_type)
  650. predictor_mask = (~make_pad_mask(encoder_sequence_length, maxlen=encoder_sequence_length.max())).type(
  651. int_type).to(encoder_sequence_length.device)
  652. index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out * predictor_mask
  653. predictor_alignments = index_div_bool_zeros_count_tile_out
  654. predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
  655. return predictor_alignments.detach(), predictor_alignments_length.detach()