# cif.py

import torch
from torch import nn
from torch import Tensor
import logging
import numpy as np
from funasr.torch_utils.device_funcs import to_device
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.streaming_utils.utils import sequence_mask
from typing import Optional, Tuple
class CifPredictor(nn.Module):
    def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0, noise_threshold=0,
                 tail_threshold=0.45):
        super(CifPredictor, self).__init__()

        self.pad = nn.ConstantPad1d((l_order, r_order), 0)
        self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1, groups=idim)
        self.cif_output = nn.Linear(idim, 1)
        self.dropout = torch.nn.Dropout(p=dropout)
        self.threshold = threshold
        self.smooth_factor = smooth_factor
        self.noise_threshold = noise_threshold
        self.tail_threshold = tail_threshold

    def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
                target_label_length=None):
        h = hidden
        context = h.transpose(1, 2)
        queries = self.pad(context)
        memory = self.cif_conv1d(queries)
        output = memory + context
        output = self.dropout(output)
        output = output.transpose(1, 2)
        output = torch.relu(output)
        output = self.cif_output(output)
        alphas = torch.sigmoid(output)
        alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
        if mask is not None:
            mask = mask.transpose(-1, -2).float()
            alphas = alphas * mask
        if mask_chunk_predictor is not None:
            alphas = alphas * mask_chunk_predictor
        alphas = alphas.squeeze(-1)
        mask = mask.squeeze(-1)
        if target_label_length is not None:
            target_length = target_label_length
        elif target_label is not None:
            target_length = (target_label != ignore_id).float().sum(-1)
        else:
            target_length = None
        token_num = alphas.sum(-1)
        if target_length is not None:
            alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
        elif self.tail_threshold > 0.0:
            hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)

        acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)

        if target_length is None and self.tail_threshold > 0.0:
            token_num_int = torch.max(token_num).type(torch.int32).item()
            acoustic_embeds = acoustic_embeds[:, :token_num_int, :]

        return acoustic_embeds, token_num, alphas, cif_peak
    def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
        b, t, d = hidden.size()
        tail_threshold = self.tail_threshold
        if mask is not None:
            zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
            ones_t = torch.ones_like(zeros_t)
            mask_1 = torch.cat([mask, zeros_t], dim=1)
            mask_2 = torch.cat([ones_t, mask], dim=1)
            mask = mask_2 - mask_1
            tail_threshold = mask * tail_threshold
            alphas = torch.cat([alphas, zeros_t], dim=1)
            alphas = torch.add(alphas, tail_threshold)
        else:
            tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
            tail_threshold = torch.reshape(tail_threshold, (1, 1))
            alphas = torch.cat([alphas, tail_threshold], dim=1)
        zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
        hidden = torch.cat([hidden, zeros], dim=1)
        token_num = alphas.sum(dim=-1)
        token_num_floor = torch.floor(token_num)
        return hidden, alphas, token_num_floor
    def gen_frame_alignments(self,
                             alphas: torch.Tensor = None,
                             encoder_sequence_length: torch.Tensor = None):
        batch_size, maximum_length = alphas.size()
        int_type = torch.int32

        is_training = self.training
        if is_training:
            token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type)
        else:
            token_num = torch.floor(torch.sum(alphas, dim=1)).type(int_type)
        max_token_num = torch.max(token_num).item()

        alphas_cumsum = torch.cumsum(alphas, dim=1)
        alphas_cumsum = torch.floor(alphas_cumsum).type(int_type)
        alphas_cumsum = alphas_cumsum[:, None, :].repeat(1, max_token_num, 1)

        index = torch.ones([batch_size, max_token_num], dtype=int_type)
        index = torch.cumsum(index, dim=1)
        index = index[:, :, None].repeat(1, 1, maximum_length).to(alphas_cumsum.device)

        index_div = torch.floor(torch.true_divide(alphas_cumsum, index)).type(int_type)
        index_div_bool_zeros = index_div.eq(0)
        index_div_bool_zeros_count = torch.sum(index_div_bool_zeros, dim=-1) + 1
        index_div_bool_zeros_count = torch.clamp(index_div_bool_zeros_count, 0, encoder_sequence_length.max())
        token_num_mask = (~make_pad_mask(token_num, maxlen=max_token_num)).to(token_num.device)
        index_div_bool_zeros_count *= token_num_mask

        index_div_bool_zeros_count_tile = index_div_bool_zeros_count[:, :, None].repeat(1, 1, maximum_length)
        ones = torch.ones_like(index_div_bool_zeros_count_tile)
        zeros = torch.zeros_like(index_div_bool_zeros_count_tile)
        ones = torch.cumsum(ones, dim=2)
        cond = index_div_bool_zeros_count_tile == ones
        index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones)

        index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type(torch.bool)
        index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type(int_type)
        index_div_bool_zeros_count_tile_out = torch.sum(index_div_bool_zeros_count_tile, dim=1)
        index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type(int_type)
        predictor_mask = (~make_pad_mask(encoder_sequence_length, maxlen=encoder_sequence_length.max())).type(
            int_type).to(encoder_sequence_length.device)
        index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out * predictor_mask

        predictor_alignments = index_div_bool_zeros_count_tile_out
        predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
        return predictor_alignments.detach(), predictor_alignments_length.detach()
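
# A minimal usage sketch for CifPredictor (not part of the original file). The
# shapes are assumptions inferred from the transposes in forward(): `hidden` is
# (B, T, idim) encoder output and `mask` is (B, 1, T) with 1s on valid frames;
# idim=256, l_order=1, r_order=1 are illustrative values.
#
#     predictor = CifPredictor(idim=256, l_order=1, r_order=1)
#     hidden = torch.randn(2, 50, 256)
#     mask = torch.ones(2, 1, 50)
#     acoustic_embeds, token_num, alphas, cif_peak = predictor(hidden, mask=mask)
#     # acoustic_embeds: (B, N, idim) integrated token embeddings
#     # token_num: (B,) predicted token counts; alphas / cif_peak: per-frame weights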
class CifPredictorV2(nn.Module):
    def __init__(self,
                 idim,
                 l_order,
                 r_order,
                 threshold=1.0,
                 dropout=0.1,
                 smooth_factor=1.0,
                 noise_threshold=0,
                 tail_threshold=0.0,
                 tf2torch_tensor_name_prefix_torch="predictor",
                 tf2torch_tensor_name_prefix_tf="seq2seq/cif",
                 tail_mask=True,
                 ):
        super(CifPredictorV2, self).__init__()

        self.pad = nn.ConstantPad1d((l_order, r_order), 0)
        self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1)
        self.cif_output = nn.Linear(idim, 1)
        self.dropout = torch.nn.Dropout(p=dropout)
        self.threshold = threshold
        self.smooth_factor = smooth_factor
        self.noise_threshold = noise_threshold
        self.tail_threshold = tail_threshold
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
        self.tail_mask = tail_mask

    def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
                target_label_length=None):
        h = hidden
        context = h.transpose(1, 2)
        queries = self.pad(context)
        output = torch.relu(self.cif_conv1d(queries))
        output = output.transpose(1, 2)
        output = self.cif_output(output)
        alphas = torch.sigmoid(output)
        alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
        if mask is not None:
            mask = mask.transpose(-1, -2).float()
            alphas = alphas * mask
        if mask_chunk_predictor is not None:
            alphas = alphas * mask_chunk_predictor
        alphas = alphas.squeeze(-1)
        mask = mask.squeeze(-1)
        if target_label_length is not None:
            target_length = target_label_length
        elif target_label is not None:
            target_length = (target_label != ignore_id).float().sum(-1)
        else:
            target_length = None
        token_num = alphas.sum(-1)
        if target_length is not None:
            alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
        elif self.tail_threshold > 0.0:
            if self.tail_mask:
                hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
            else:
                hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=None)

        acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)

        if target_length is None and self.tail_threshold > 0.0:
            token_num_int = torch.max(token_num).type(torch.int32).item()
            acoustic_embeds = acoustic_embeds[:, :token_num_int, :]

        return acoustic_embeds, token_num, alphas, cif_peak
    def forward_chunk(self, hidden, cache=None):
        batch_size, len_time, hidden_size = hidden.shape
        h = hidden
        context = h.transpose(1, 2)
        queries = self.pad(context)
        output = torch.relu(self.cif_conv1d(queries))
        output = output.transpose(1, 2)
        output = self.cif_output(output)
        alphas = torch.sigmoid(output)
        alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
        alphas = alphas.squeeze(-1)

        token_length = []
        list_fires = []
        list_frames = []
        cache_alphas = []
        cache_hiddens = []

        if cache is not None and "chunk_size" in cache:
            alphas[:, :cache["chunk_size"][0]] = 0.0
            if "is_final" in cache and not cache["is_final"]:
                alphas[:, sum(cache["chunk_size"][:2]):] = 0.0
        if cache is not None and "cif_alphas" in cache and "cif_hidden" in cache:
            cache["cif_hidden"] = to_device(cache["cif_hidden"], device=hidden.device)
            cache["cif_alphas"] = to_device(cache["cif_alphas"], device=alphas.device)
            hidden = torch.cat((cache["cif_hidden"], hidden), dim=1)
            alphas = torch.cat((cache["cif_alphas"], alphas), dim=1)
        if cache is not None and "is_final" in cache and cache["is_final"]:
            tail_hidden = torch.zeros((batch_size, 1, hidden_size), device=hidden.device)
            tail_alphas = torch.tensor([[self.tail_threshold]], device=alphas.device)
            tail_alphas = torch.tile(tail_alphas, (batch_size, 1))
            hidden = torch.cat((hidden, tail_hidden), dim=1)
            alphas = torch.cat((alphas, tail_alphas), dim=1)

        len_time = alphas.shape[1]
        for b in range(batch_size):
            integrate = 0.0
            frames = torch.zeros((hidden_size), device=hidden.device)
            list_frame = []
            list_fire = []
            for t in range(len_time):
                alpha = alphas[b][t]
                if alpha + integrate < self.threshold:
                    integrate += alpha
                    list_fire.append(integrate)
                    frames += alpha * hidden[b][t]
                else:
                    frames += (self.threshold - integrate) * hidden[b][t]
                    list_frame.append(frames)
                    integrate += alpha
                    list_fire.append(integrate)
                    integrate -= self.threshold
                    frames = integrate * hidden[b][t]
            cache_alphas.append(integrate)
            if integrate > 0.0:
                cache_hiddens.append(frames / integrate)
            else:
                cache_hiddens.append(frames)
            token_length.append(torch.tensor(len(list_frame), device=alphas.device))
            list_fires.append(list_fire)
            list_frames.append(list_frame)

        cache["cif_alphas"] = torch.stack(cache_alphas, axis=0)
        cache["cif_alphas"] = torch.unsqueeze(cache["cif_alphas"], axis=0)
        cache["cif_hidden"] = torch.stack(cache_hiddens, axis=0)
        cache["cif_hidden"] = torch.unsqueeze(cache["cif_hidden"], axis=0)

        max_token_len = max(token_length)
        if max_token_len == 0:
            return hidden, torch.stack(token_length, 0)
        list_ls = []
        for b in range(batch_size):
            pad_frames = torch.zeros((max_token_len - token_length[b], hidden_size), device=alphas.device)
            if token_length[b] == 0:
                list_ls.append(pad_frames)
            else:
                list_frames[b] = torch.stack(list_frames[b])
                list_ls.append(torch.cat((list_frames[b], pad_frames), dim=0))

        cache["cif_alphas"] = torch.stack(cache_alphas, axis=0)
        cache["cif_alphas"] = torch.unsqueeze(cache["cif_alphas"], axis=0)
        cache["cif_hidden"] = torch.stack(cache_hiddens, axis=0)
        cache["cif_hidden"] = torch.unsqueeze(cache["cif_hidden"], axis=0)
        return torch.stack(list_ls, 0), torch.stack(token_length, 0)
    def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
        b, t, d = hidden.size()
        tail_threshold = self.tail_threshold
        if mask is not None:
            zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
            ones_t = torch.ones_like(zeros_t)
            mask_1 = torch.cat([mask, zeros_t], dim=1)
            mask_2 = torch.cat([ones_t, mask], dim=1)
            mask = mask_2 - mask_1
            tail_threshold = mask * tail_threshold
            alphas = torch.cat([alphas, zeros_t], dim=1)
            alphas = torch.add(alphas, tail_threshold)
        else:
            tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
            tail_threshold = torch.reshape(tail_threshold, (1, 1))
            if b > 1:
                alphas = torch.cat([alphas, tail_threshold.repeat(b, 1)], dim=1)
            else:
                alphas = torch.cat([alphas, tail_threshold], dim=1)
        zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
        hidden = torch.cat([hidden, zeros], dim=1)
        token_num = alphas.sum(dim=-1)
        token_num_floor = torch.floor(token_num)
        return hidden, alphas, token_num_floor
    def gen_frame_alignments(self,
                             alphas: torch.Tensor = None,
                             encoder_sequence_length: torch.Tensor = None):
        batch_size, maximum_length = alphas.size()
        int_type = torch.int32

        is_training = self.training
        if is_training:
            token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type)
        else:
            token_num = torch.floor(torch.sum(alphas, dim=1)).type(int_type)
        max_token_num = torch.max(token_num).item()

        alphas_cumsum = torch.cumsum(alphas, dim=1)
        alphas_cumsum = torch.floor(alphas_cumsum).type(int_type)
        alphas_cumsum = alphas_cumsum[:, None, :].repeat(1, max_token_num, 1)

        index = torch.ones([batch_size, max_token_num], dtype=int_type)
        index = torch.cumsum(index, dim=1)
        index = index[:, :, None].repeat(1, 1, maximum_length).to(alphas_cumsum.device)

        index_div = torch.floor(torch.true_divide(alphas_cumsum, index)).type(int_type)
        index_div_bool_zeros = index_div.eq(0)
        index_div_bool_zeros_count = torch.sum(index_div_bool_zeros, dim=-1) + 1
        index_div_bool_zeros_count = torch.clamp(index_div_bool_zeros_count, 0, encoder_sequence_length.max())
        token_num_mask = (~make_pad_mask(token_num, maxlen=max_token_num)).to(token_num.device)
        index_div_bool_zeros_count *= token_num_mask

        index_div_bool_zeros_count_tile = index_div_bool_zeros_count[:, :, None].repeat(1, 1, maximum_length)
        ones = torch.ones_like(index_div_bool_zeros_count_tile)
        zeros = torch.zeros_like(index_div_bool_zeros_count_tile)
        ones = torch.cumsum(ones, dim=2)
        cond = index_div_bool_zeros_count_tile == ones
        index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones)

        index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type(torch.bool)
        index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type(int_type)
        index_div_bool_zeros_count_tile_out = torch.sum(index_div_bool_zeros_count_tile, dim=1)
        index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type(int_type)
        predictor_mask = (~make_pad_mask(encoder_sequence_length, maxlen=encoder_sequence_length.max())).type(
            int_type).to(encoder_sequence_length.device)
        index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out * predictor_mask

        predictor_alignments = index_div_bool_zeros_count_tile_out
        predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
        return predictor_alignments.detach(), predictor_alignments_length.detach()
    def gen_tf2torch_map_dict(self):
        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
        map_dict_local = {
            ## predictor
            "{}.cif_conv1d.weight".format(tensor_name_prefix_torch):
                {"name": "{}/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": (2, 1, 0),
                 },  # (256,256,3),(3,256,256)
            "{}.cif_conv1d.bias".format(tensor_name_prefix_torch):
                {"name": "{}/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.cif_output.weight".format(tensor_name_prefix_torch):
                {"name": "{}/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1,256),(1,256,1)
            "{}.cif_output.bias".format(tensor_name_prefix_torch):
                {"name": "{}/conv1d_1/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1,),(1,)
        }
        return map_dict_local
    def convert_tf2torch(self,
                         var_dict_tf,
                         var_dict_torch,
                         ):
        map_dict = self.gen_tf2torch_map_dict()
        var_dict_torch_update = dict()
        for name in sorted(var_dict_torch.keys(), reverse=False):
            names = name.split('.')
            if names[0] == self.tf2torch_tensor_name_prefix_torch:
                name_tf = map_dict[name]["name"]
                data_tf = var_dict_tf[name_tf]
                if map_dict[name]["squeeze"] is not None:
                    data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
                if map_dict[name]["transpose"] is not None:
                    data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
                data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(
                    name, name_tf, var_dict_torch[name].size(), data_tf.size())
                var_dict_torch_update[name] = data_tf
                logging.info(
                    "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                        name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape))
        return var_dict_torch_update
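
# A streaming usage sketch for CifPredictorV2.forward_chunk (not part of the
# original file). The cache keys ("chunk_size", "is_final", "cif_alphas",
# "cif_hidden") come from the method above; the chunk sizes and feature shape
# are illustrative assumptions.
#
#     predictor = CifPredictorV2(idim=256, l_order=1, r_order=1, tail_threshold=0.45)
#     cache = {"chunk_size": [5, 10, 5], "is_final": False}
#     chunk = torch.randn(1, 20, 256)              # one encoder chunk (B, T_chunk, idim)
#     embeds, n_tokens = predictor.forward_chunk(chunk, cache=cache)
#     # leftover weight/frames carry over via cache["cif_alphas"] / cache["cif_hidden"]
#     cache["is_final"] = True                     # on the last chunk, tail_threshold is appended
#     embeds, n_tokens = predictor.forward_chunk(torch.randn(1, 20, 256), cache=cache)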
class mae_loss(nn.Module):
    def __init__(self, normalize_length=False):
        super(mae_loss, self).__init__()
        self.normalize_length = normalize_length
        self.criterion = torch.nn.L1Loss(reduction='sum')

    def forward(self, token_length, pre_token_length):
        loss_token_normalizer = token_length.size(0)
        if self.normalize_length:
            loss_token_normalizer = token_length.sum().type(torch.float32)
        loss = self.criterion(token_length, pre_token_length)
        loss = loss / loss_token_normalizer
        return loss
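
# A minimal sketch of the quantity loss (not part of the original file): mae_loss
# penalizes the L1 gap between the ground-truth token count and the predictor's
# token_num output, normalized by batch size (or by total token count when
# normalize_length=True). The tensors below are illustrative.
#
#     quantity_loss = mae_loss()
#     target_lengths = torch.tensor([7.0, 5.0])            # ground-truth token counts
#     predicted_lengths = torch.tensor([6.4, 5.3])         # token_num from the predictor
#     loss = quantity_loss(target_lengths, predicted_lengths)   # (0.6 + 0.3) / 2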
def cif(hidden, alphas, threshold):
    batch_size, len_time, hidden_size = hidden.size()

    # loop vars
    integrate = torch.zeros([batch_size], device=hidden.device)
    frame = torch.zeros([batch_size, hidden_size], device=hidden.device)
    # intermediate vars along time
    list_fires = []
    list_frames = []

    for t in range(len_time):
        alpha = alphas[:, t]
        distribution_completion = torch.ones([batch_size], device=hidden.device) - integrate

        integrate += alpha
        list_fires.append(integrate)

        fire_place = integrate >= threshold
        integrate = torch.where(fire_place,
                                integrate - torch.ones([batch_size], device=hidden.device),
                                integrate)
        cur = torch.where(fire_place,
                          distribution_completion,
                          alpha)
        remainds = alpha - cur

        frame += cur[:, None] * hidden[:, t, :]
        list_frames.append(frame)
        frame = torch.where(fire_place[:, None].repeat(1, hidden_size),
                            remainds[:, None] * hidden[:, t, :],
                            frame)

    fires = torch.stack(list_fires, 1)
    frames = torch.stack(list_frames, 1)
    list_ls = []
    len_labels = torch.round(alphas.sum(-1)).int()
    max_label_len = len_labels.max()
    for b in range(batch_size):
        fire = fires[b, :]
        l = torch.index_select(frames[b, :, :], 0, torch.nonzero(fire >= threshold).squeeze())
        pad_l = torch.zeros([max_label_len - l.size(0), hidden_size], device=hidden.device)
        list_ls.append(torch.cat([l, pad_l], 0))
    return torch.stack(list_ls, 0), fires
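
# A tiny worked example of cif() (not part of the original file), assuming
# threshold=1.0. With per-frame weights of 0.5 the running integral reaches 1.0
# at frames 2 and 4, so two token embeddings are emitted; each is the
# weight-scaled sum of the frames it integrated over.
#
#     hidden = torch.randn(1, 4, 8)
#     alphas = torch.full((1, 4), 0.5)
#     embeds, fires = cif(hidden, alphas, threshold=1.0)
#     # embeds: (1, 2, 8); embeds[0, 0] == 0.5 * (hidden[0, 0] + hidden[0, 1])
#     # fires:  (1, 4) cumulative weight, dropping by 1.0 after each fire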
def cif_wo_hidden(alphas, threshold):
    batch_size, len_time = alphas.size()

    # loop vars
    integrate = torch.zeros([batch_size], device=alphas.device)
    # intermediate vars along time
    list_fires = []

    for t in range(len_time):
        alpha = alphas[:, t]

        integrate += alpha
        list_fires.append(integrate)

        fire_place = integrate >= threshold
        integrate = torch.where(fire_place,
                                integrate - torch.ones([batch_size], device=alphas.device) * threshold,
                                integrate)

    fires = torch.stack(list_fires, 1)
    return fires
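
# A brief sketch (not part of the original file): cif_wo_hidden only tracks the
# cumulative weight, so frames where the returned value reaches the threshold
# mark CIF "fire" positions; the timestamp path below uses this to locate token
# boundaries. Values are illustrative.
#
#     peaks = cif_wo_hidden(torch.tensor([[0.5, 0.5, 0.3, 0.8]]), threshold=1.0)
#     fire_frames = (peaks >= 1.0).nonzero()[:, 1]   # tensor([1, 3])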
class CifPredictorV3(nn.Module):
    def __init__(self,
                 idim,
                 l_order,
                 r_order,
                 threshold=1.0,
                 dropout=0.1,
                 smooth_factor=1.0,
                 noise_threshold=0,
                 tail_threshold=0.0,
                 tf2torch_tensor_name_prefix_torch="predictor",
                 tf2torch_tensor_name_prefix_tf="seq2seq/cif",
                 smooth_factor2=1.0,
                 noise_threshold2=0,
                 upsample_times=5,
                 upsample_type="cnn",
                 use_cif1_cnn=True,
                 tail_mask=True,
                 ):
        super(CifPredictorV3, self).__init__()

        self.pad = nn.ConstantPad1d((l_order, r_order), 0)
        self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1)
        self.cif_output = nn.Linear(idim, 1)
        self.dropout = torch.nn.Dropout(p=dropout)
        self.threshold = threshold
        self.smooth_factor = smooth_factor
        self.noise_threshold = noise_threshold
        self.tail_threshold = tail_threshold
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf

        self.upsample_times = upsample_times
        self.upsample_type = upsample_type
        self.use_cif1_cnn = use_cif1_cnn
        if self.upsample_type == 'cnn':
            self.upsample_cnn = nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
            self.cif_output2 = nn.Linear(idim, 1)
        elif self.upsample_type == 'cnn_blstm':
            self.upsample_cnn = nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
            self.blstm = nn.LSTM(idim, idim, 1, bias=True, batch_first=True, dropout=0.0, bidirectional=True)
            self.cif_output2 = nn.Linear(idim * 2, 1)
        elif self.upsample_type == 'cnn_attn':
            self.upsample_cnn = nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
            from funasr.models.encoder.transformer_encoder import EncoderLayer as TransformerEncoderLayer
            from funasr.modules.attention import MultiHeadedAttention
            from funasr.modules.positionwise_feed_forward import PositionwiseFeedForward
            positionwise_layer_args = (
                idim,
                idim * 2,
                0.1,
            )
            self.self_attn = TransformerEncoderLayer(
                idim,
                MultiHeadedAttention(4, idim, 0.1),
                PositionwiseFeedForward(*positionwise_layer_args),
                0.1,
                True,   # normalize_before
                False,  # concat_after
            )
            self.cif_output2 = nn.Linear(idim, 1)
        self.smooth_factor2 = smooth_factor2
        self.noise_threshold2 = noise_threshold2
    def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
                target_label_length=None):
        h = hidden
        context = h.transpose(1, 2)
        queries = self.pad(context)
        output = torch.relu(self.cif_conv1d(queries))

        # alphas2 is an extra head for timestamp prediction
        if not self.use_cif1_cnn:
            _output = context
        else:
            _output = output
        if self.upsample_type == 'cnn':
            output2 = self.upsample_cnn(_output)
            output2 = output2.transpose(1, 2)
        elif self.upsample_type == 'cnn_blstm':
            output2 = self.upsample_cnn(_output)
            output2 = output2.transpose(1, 2)
            output2, (_, _) = self.blstm(output2)
        elif self.upsample_type == 'cnn_attn':
            output2 = self.upsample_cnn(_output)
            output2 = output2.transpose(1, 2)
            output2, _ = self.self_attn(output2, mask)
        alphas2 = torch.sigmoid(self.cif_output2(output2))
        alphas2 = torch.nn.functional.relu(alphas2 * self.smooth_factor2 - self.noise_threshold2)
        # repeat the mask in the T dimension to match the upsampled length
        if mask is not None:
            mask2 = mask.repeat(1, self.upsample_times, 1).transpose(-1, -2).reshape(alphas2.shape[0], -1)
            mask2 = mask2.unsqueeze(-1)
            alphas2 = alphas2 * mask2
        alphas2 = alphas2.squeeze(-1)
        token_num2 = alphas2.sum(-1)

        output = output.transpose(1, 2)
        output = self.cif_output(output)
        alphas = torch.sigmoid(output)
        alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
        if mask is not None:
            mask = mask.transpose(-1, -2).float()
            alphas = alphas * mask
        if mask_chunk_predictor is not None:
            alphas = alphas * mask_chunk_predictor
        alphas = alphas.squeeze(-1)
        mask = mask.squeeze(-1)
        if target_label_length is not None:
            target_length = target_label_length
        elif target_label is not None:
            target_length = (target_label != ignore_id).float().sum(-1)
        else:
            target_length = None
        token_num = alphas.sum(-1)
        if target_length is not None:
            alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
        elif self.tail_threshold > 0.0:
            hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)

        acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)

        if target_length is None and self.tail_threshold > 0.0:
            token_num_int = torch.max(token_num).type(torch.int32).item()
            acoustic_embeds = acoustic_embeds[:, :token_num_int, :]

        return acoustic_embeds, token_num, alphas, cif_peak, token_num2
    def get_upsample_timestamp(self, hidden, mask=None, token_num=None):
        h = hidden
        b = hidden.shape[0]
        context = h.transpose(1, 2)
        queries = self.pad(context)
        output = torch.relu(self.cif_conv1d(queries))

        # alphas2 is an extra head for timestamp prediction
        if not self.use_cif1_cnn:
            _output = context
        else:
            _output = output
        if self.upsample_type == 'cnn':
            output2 = self.upsample_cnn(_output)
            output2 = output2.transpose(1, 2)
        elif self.upsample_type == 'cnn_blstm':
            output2 = self.upsample_cnn(_output)
            output2 = output2.transpose(1, 2)
            output2, (_, _) = self.blstm(output2)
        elif self.upsample_type == 'cnn_attn':
            output2 = self.upsample_cnn(_output)
            output2 = output2.transpose(1, 2)
            output2, _ = self.self_attn(output2, mask)
        alphas2 = torch.sigmoid(self.cif_output2(output2))
        alphas2 = torch.nn.functional.relu(alphas2 * self.smooth_factor2 - self.noise_threshold2)
        # repeat the mask in the T dimension to match the upsampled length
        if mask is not None:
            mask2 = mask.repeat(1, self.upsample_times, 1).transpose(-1, -2).reshape(alphas2.shape[0], -1)
            mask2 = mask2.unsqueeze(-1)
            alphas2 = alphas2 * mask2
        alphas2 = alphas2.squeeze(-1)
        _token_num = alphas2.sum(-1)
        if token_num is not None:
            alphas2 *= (token_num / _token_num)[:, None].repeat(1, alphas2.size(1))
        # re-downsample
        ds_alphas = alphas2.reshape(b, -1, self.upsample_times).sum(-1)
        ds_cif_peak = cif_wo_hidden(ds_alphas, self.threshold - 1e-4)
        # upsampled alphas and cif_peak
        us_alphas = alphas2
        us_cif_peak = cif_wo_hidden(us_alphas, self.threshold - 1e-4)
        return ds_alphas, ds_cif_peak, us_alphas, us_cif_peak
    def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
        b, t, d = hidden.size()
        tail_threshold = self.tail_threshold
        if mask is not None:
            zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
            ones_t = torch.ones_like(zeros_t)
            mask_1 = torch.cat([mask, zeros_t], dim=1)
            mask_2 = torch.cat([ones_t, mask], dim=1)
            mask = mask_2 - mask_1
            tail_threshold = mask * tail_threshold
            alphas = torch.cat([alphas, zeros_t], dim=1)
            alphas = torch.add(alphas, tail_threshold)
        else:
            tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
            tail_threshold = torch.reshape(tail_threshold, (1, 1))
            alphas = torch.cat([alphas, tail_threshold], dim=1)
        zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
        hidden = torch.cat([hidden, zeros], dim=1)
        token_num = alphas.sum(dim=-1)
        token_num_floor = torch.floor(token_num)
        return hidden, alphas, token_num_floor
    def gen_frame_alignments(self,
                             alphas: torch.Tensor = None,
                             encoder_sequence_length: torch.Tensor = None):
        batch_size, maximum_length = alphas.size()
        int_type = torch.int32

        is_training = self.training
        if is_training:
            token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type)
        else:
            token_num = torch.floor(torch.sum(alphas, dim=1)).type(int_type)
        max_token_num = torch.max(token_num).item()

        alphas_cumsum = torch.cumsum(alphas, dim=1)
        alphas_cumsum = torch.floor(alphas_cumsum).type(int_type)
        alphas_cumsum = alphas_cumsum[:, None, :].repeat(1, max_token_num, 1)

        index = torch.ones([batch_size, max_token_num], dtype=int_type)
        index = torch.cumsum(index, dim=1)
        index = index[:, :, None].repeat(1, 1, maximum_length).to(alphas_cumsum.device)

        index_div = torch.floor(torch.true_divide(alphas_cumsum, index)).type(int_type)
        index_div_bool_zeros = index_div.eq(0)
        index_div_bool_zeros_count = torch.sum(index_div_bool_zeros, dim=-1) + 1
        index_div_bool_zeros_count = torch.clamp(index_div_bool_zeros_count, 0, encoder_sequence_length.max())
        token_num_mask = (~make_pad_mask(token_num, maxlen=max_token_num)).to(token_num.device)
        index_div_bool_zeros_count *= token_num_mask

        index_div_bool_zeros_count_tile = index_div_bool_zeros_count[:, :, None].repeat(1, 1, maximum_length)
        ones = torch.ones_like(index_div_bool_zeros_count_tile)
        zeros = torch.zeros_like(index_div_bool_zeros_count_tile)
        ones = torch.cumsum(ones, dim=2)
        cond = index_div_bool_zeros_count_tile == ones
        index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones)

        index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type(torch.bool)
        index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type(int_type)
        index_div_bool_zeros_count_tile_out = torch.sum(index_div_bool_zeros_count_tile, dim=1)
        index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type(int_type)
        predictor_mask = (~make_pad_mask(encoder_sequence_length, maxlen=encoder_sequence_length.max())).type(
            int_type).to(encoder_sequence_length.device)
        index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out * predictor_mask

        predictor_alignments = index_div_bool_zeros_count_tile_out
        predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
        return predictor_alignments.detach(), predictor_alignments_length.detach()
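
# A minimal sketch of the V3 timestamp path (not part of the original file).
# forward() returns an extra token_num2 from the upsampled head, and
# get_upsample_timestamp() exposes the per-frame weights used to derive token
# boundaries. Shapes and hyperparameters here are illustrative assumptions.
#
#     predictor = CifPredictorV3(idim=256, l_order=1, r_order=1, tail_threshold=0.45)
#     hidden = torch.randn(2, 50, 256)
#     mask = torch.ones(2, 1, 50)
#     embeds, token_num, alphas, cif_peak, token_num2 = predictor(hidden, mask=mask)
#     ds_alphas, ds_peak, us_alphas, us_peak = predictor.get_upsample_timestamp(
#         hidden, mask=mask, token_num=token_num)
#     # us_alphas / us_peak cover T * upsample_times frames (here 250)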
class BATPredictor(nn.Module):
    def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0, noise_threshold=0,
                 return_accum=False):
        super(BATPredictor, self).__init__()

        self.pad = nn.ConstantPad1d((l_order, r_order), 0)
        self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1, groups=idim)
        self.cif_output = nn.Linear(idim, 1)
        self.dropout = torch.nn.Dropout(p=dropout)
        self.threshold = threshold
        self.smooth_factor = smooth_factor
        self.noise_threshold = noise_threshold
        self.return_accum = return_accum

    def cif(
            self,
            input: Tensor,
            alpha: Tensor,
            beta: float = 1.0,
            return_accum: bool = False,
    ) -> Tuple[Tensor, Tensor]:
        B, S, C = input.size()
        assert tuple(alpha.size()) == (B, S), f"{alpha.size()} != {(B, S)}"

        dtype = alpha.dtype
        alpha = alpha.float()
        alpha_sum = alpha.sum(1)
        feat_lengths = (alpha_sum / beta).floor().long()
        T = feat_lengths.max()

        # aggregate and integrate
        csum = alpha.cumsum(-1)
        with torch.no_grad():
            # indices used for scattering
            right_idx = (csum / beta).floor().long().clip(max=T)
            left_idx = right_idx.roll(1, dims=1)
            left_idx[:, 0] = 0

            # count # of fires from each source
            fire_num = right_idx - left_idx
            extra_weights = (fire_num - 1).clip(min=0)

        # The extra entry at index T is a tail/overflow slot; it is dropped when
        # the output is sliced back to length T below.
        output = input.new_zeros((B, T + 1, C))
        source_range = torch.arange(1, 1 + S).unsqueeze(0).type_as(input)
        zero = alpha.new_zeros((1,))

        # right scatter
        fire_mask = fire_num > 0
        right_weight = torch.where(
            fire_mask,
            csum - right_idx.type_as(alpha) * beta,
            zero
        ).type_as(input)
        # assert right_weight.ge(0).all(), f"{right_weight} should be non-negative."
        output.scatter_add_(
            1,
            right_idx.unsqueeze(-1).expand(-1, -1, C),
            right_weight.unsqueeze(-1) * input
        )

        # left scatter
        left_weight = (
            alpha - right_weight - extra_weights.type_as(alpha) * beta
        ).type_as(input)
        output.scatter_add_(
            1,
            left_idx.unsqueeze(-1).expand(-1, -1, C),
            left_weight.unsqueeze(-1) * input
        )

        # extra scatters
        if extra_weights.ge(0).any():
            extra_steps = extra_weights.max().item()
            tgt_idx = left_idx
            src_feats = input * beta
            for _ in range(extra_steps):
                tgt_idx = (tgt_idx + 1).clip(max=T)
                # (B, S, 1)
                src_mask = (extra_weights > 0)
                output.scatter_add_(
                    1,
                    tgt_idx.unsqueeze(-1).expand(-1, -1, C),
                    src_feats * src_mask.unsqueeze(2)
                )
                extra_weights -= 1

        output = output[:, :T, :]
        if return_accum:
            return output, csum
        else:
            return output, alpha
    def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
                target_label_length=None):
        h = hidden
        context = h.transpose(1, 2)
        queries = self.pad(context)
        memory = self.cif_conv1d(queries)
        output = memory + context
        output = self.dropout(output)
        output = output.transpose(1, 2)
        output = torch.relu(output)
        output = self.cif_output(output)
        alphas = torch.sigmoid(output)
        alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
        if mask is not None:
            alphas = alphas * mask.transpose(-1, -2).float()
        if mask_chunk_predictor is not None:
            alphas = alphas * mask_chunk_predictor
        alphas = alphas.squeeze(-1)
        if target_label_length is not None:
            target_length = target_label_length
        elif target_label is not None:
            target_length = (target_label != ignore_id).float().sum(-1)
            # logging.info("target_length: {}".format(target_length))
        else:
            target_length = None
        token_num = alphas.sum(-1)
        if target_length is not None:
            # length_noise = torch.rand(alphas.size(0), device=alphas.device) - 0.5
            # target_length = length_noise + target_length
            alphas *= ((target_length + 1e-4) / token_num)[:, None].repeat(1, alphas.size(1))
        acoustic_embeds, cif_peak = self.cif(hidden, alphas, self.threshold, self.return_accum)
        return acoustic_embeds, token_num, alphas, cif_peak
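
# A minimal sketch for BATPredictor (not part of the original file). Unlike the
# frame-by-frame cif() loop above, BATPredictor.cif scatters weighted frames into
# token slots with a handful of vectorized ops, which is why forward() rescales
# alphas so they sum (almost exactly) to the target length before integrating.
# Shapes are illustrative assumptions.
#
#     predictor = BATPredictor(idim=256, l_order=1, r_order=1)
#     hidden = torch.randn(2, 50, 256)
#     mask = torch.ones(2, 1, 50)
#     target_label = torch.randint(0, 100, (2, 8))        # 8 tokens per utterance
#     embeds, token_num, alphas, peaks = predictor(
#         hidden, target_label=target_label, mask=mask)
#     # embeds: (2, 8, 256) -- one integrated embedding per target token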