# power.py
  1. import numpy as np
  2. import torch
  3. import torch.multiprocessing
  4. import torch.nn.functional as F
  5. from itertools import combinations
  6. from itertools import permutations
  7. def generate_mapping_dict(max_speaker_num=6, max_olp_speaker_num=3):
  8. all_kinds = []
  9. all_kinds.append(0)
  10. for i in range(max_olp_speaker_num):
  11. selected_num = i + 1
  12. coms = np.array(list(combinations(np.arange(max_speaker_num), selected_num)))
  13. for com in coms:
  14. tmp = np.zeros(max_speaker_num)
  15. tmp[com] = 1
  16. item = int(raw_dec_trans(tmp.reshape(1, -1), max_speaker_num)[0])
  17. all_kinds.append(item)
  18. all_kinds_order = sorted(all_kinds)
  19. mapping_dict = {}
  20. mapping_dict['dec2label'] = {}
  21. mapping_dict['label2dec'] = {}
  22. for i in range(len(all_kinds_order)):
  23. dec = all_kinds_order[i]
  24. mapping_dict['dec2label'][dec] = i
  25. mapping_dict['label2dec'][i] = dec
  26. oov_id = len(all_kinds_order)
  27. mapping_dict['oov'] = oov_id
  28. return mapping_dict
  29. def raw_dec_trans(x, max_speaker_num):
  30. num_list = []
  31. for i in range(max_speaker_num):
  32. num_list.append(x[:, i])
  33. base = 1
  34. T = x.shape[0]
  35. res = np.zeros((T))
  36. for num in num_list:
  37. res += num * base
  38. base = base * 2
  39. return res
  40. def mapping_func(num, mapping_dict):
  41. if num in mapping_dict['dec2label'].keys():
  42. label = mapping_dict['dec2label'][num]
  43. else:
  44. label = mapping_dict['oov']
  45. return label
  46. def dec_trans(x, max_speaker_num, mapping_dict):
  47. num_list = []
  48. for i in range(max_speaker_num):
  49. num_list.append(x[:, i])
  50. base = 1
  51. T = x.shape[0]
  52. res = np.zeros((T))
  53. for num in num_list:
  54. res += num * base
  55. base = base * 2
  56. res = np.array([mapping_func(i, mapping_dict) for i in res])
  57. return res
  58. def create_powerlabel(label, mapping_dict, max_speaker_num=6, max_olp_speaker_num=3):
  59. T, C = label.shape
  60. padding_label = np.zeros((T, max_speaker_num))
  61. padding_label[:, :C] = label
  62. out_label = dec_trans(padding_label, max_speaker_num, mapping_dict)
  63. out_label = torch.from_numpy(out_label)
  64. return out_label
  65. def generate_perm_pse(label, n_speaker, mapping_dict, max_speaker_num, max_olp_speaker_num=3):
  66. perms = np.array(list(permutations(range(n_speaker)))).astype(np.float32)
  67. perms = torch.from_numpy(perms).to(label.device).to(torch.int64)
  68. perm_labels = [label[:, perm] for perm in perms]
  69. perm_pse_labels = [create_powerlabel(perm_label.cpu().numpy(), mapping_dict, max_speaker_num).
  70. to(perm_label.device, non_blocking=True) for perm_label in perm_labels]
  71. return perm_labels, perm_pse_labels
  72. def generate_min_pse(label, n_speaker, mapping_dict, max_speaker_num, pse_logit, max_olp_speaker_num=3):
  73. perm_labels, perm_pse_labels = generate_perm_pse(label, n_speaker, mapping_dict, max_speaker_num,
  74. max_olp_speaker_num=max_olp_speaker_num)
  75. losses = [F.cross_entropy(input=pse_logit, target=perm_pse_label.to(torch.long)) * len(pse_logit)
  76. for perm_pse_label in perm_pse_labels]
  77. loss = torch.stack(losses)
  78. min_index = torch.argmin(loss)
  79. selected_perm_label, selected_pse_label = perm_labels[min_index], perm_pse_labels[min_index]
  80. return selected_perm_label, selected_pse_label