# model.py
  1. #!/usr/bin/env python3
  2. # -*- encoding: utf-8 -*-
  3. # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
  4. # MIT License (https://opensource.org/licenses/MIT)
  5. # Modified from 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker)
  6. import time
  7. import torch
  8. import numpy as np
  9. from collections import OrderedDict
  10. from contextlib import contextmanager
  11. from distutils.version import LooseVersion
  12. from funasr.register import tables
  13. from funasr.models.campplus.utils import extract_feature
  14. from funasr.utils.load_utils import load_audio_text_image_video
  15. from funasr.models.campplus.components import DenseLayer, StatsPool, \
  16. TDNNLayer, CAMDenseTDNNBlock, TransitLayer, get_nonlinear, FCM
# `autocast` (automatic mixed precision) was introduced in torch 1.6.0.
# On older torch versions, fall back to a no-op context manager so call
# sites can use `with autocast():` unconditionally.
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):
        # No-op stand-in: mixed precision is unavailable, just run the body.
        yield
@tables.register("model_classes", "CAMPPlus")
class CAMPPlus(torch.nn.Module):
    """CAM++ speaker-embedding model (dense-TDNN backbone).

    Ported from 3D-Speaker (see file header). The network is an `FCM`
    front-end followed by an ``xvector`` stack of TDNN / CAM-dense-TDNN /
    transition layers, optionally ending in statistics pooling plus a
    dense projection when ``output_level == 'segment'``.
    """

    def __init__(self,
                 feat_dim=80,
                 embedding_size=192,
                 growth_rate=32,
                 bn_size=4,
                 init_channels=128,
                 config_str='batchnorm-relu',
                 memory_efficient=True,
                 output_level='segment',
                 **kwargs,):
        """Build the CAM++ network.

        Args:
            feat_dim: Input feature dimension (fbank bins; default 80).
            embedding_size: Size of the output speaker embedding
                (only used when ``output_level == 'segment'``).
            growth_rate: Channels added per layer inside each dense block.
            bn_size: Bottleneck multiplier; bottleneck channels are
                ``bn_size * growth_rate``.
            init_channels: Output channels of the initial TDNN layer.
            config_str: Nonlinearity/normalization spec passed through to
                the component layers (e.g. 'batchnorm-relu').
            memory_efficient: Forwarded to `CAMDenseTDNNBlock`; presumably
                enables activation checkpointing — TODO confirm in
                funasr.models.campplus.components.
            output_level: 'segment' (pooled utterance embedding) or
                'frame' (per-frame features).
            **kwargs: Ignored; accepted for config-driven construction.

        Raises:
            AssertionError: If ``output_level`` is neither 'segment'
                nor 'frame'.
        """
        super().__init__()
        self.head = FCM(feat_dim=feat_dim)
        channels = self.head.out_channels
        self.output_level = output_level

        # NOTE: submodule names ('tdnn', 'block%d', 'transit%d', ...) define
        # checkpoint state-dict keys; do not rename them.
        self.xvector = torch.nn.Sequential(
            OrderedDict([
                ('tdnn',
                 TDNNLayer(channels,
                           init_channels,
                           5,
                           stride=2,
                           dilation=1,
                           padding=-1,
                           config_str=config_str)),
            ]))
        channels = init_channels
        # Three dense blocks: (num_layers, kernel_size, dilation) =
        # (12,3,1), (24,3,2), (16,3,2), each followed by a transition
        # layer that halves the channel count.
        for i, (num_layers, kernel_size,
                dilation) in enumerate(zip((12, 24, 16), (3, 3, 3), (1, 2, 2))):
            block = CAMDenseTDNNBlock(num_layers=num_layers,
                                      in_channels=channels,
                                      out_channels=growth_rate,
                                      bn_channels=bn_size * growth_rate,
                                      kernel_size=kernel_size,
                                      dilation=dilation,
                                      config_str=config_str,
                                      memory_efficient=memory_efficient)
            self.xvector.add_module('block%d' % (i + 1), block)
            # Dense connectivity: each layer appends `growth_rate` channels.
            channels = channels + num_layers * growth_rate
            self.xvector.add_module(
                'transit%d' % (i + 1),
                TransitLayer(channels,
                             channels // 2,
                             bias=False,
                             config_str=config_str))
            channels //= 2

        self.xvector.add_module(
            'out_nonlinear', get_nonlinear(config_str, channels))

        if self.output_level == 'segment':
            # Statistics pooling (mean+std doubles channels — presumably;
            # verify against StatsPool) then project to `embedding_size`.
            self.xvector.add_module('stats', StatsPool())
            self.xvector.add_module(
                'dense',
                DenseLayer(
                    channels * 2, embedding_size, config_str='batchnorm_'))
        else:
            assert self.output_level == 'frame', '`output_level` should be set to \'segment\' or \'frame\'. '

        # Kaiming initialization for all conv/linear weights; zero biases.
        for m in self.modules():
            if isinstance(m, (torch.nn.Conv1d, torch.nn.Linear)):
                torch.nn.init.kaiming_normal_(m.weight.data)
                if m.bias is not None:
                    torch.nn.init.zeros_(m.bias)

    def forward(self, x):
        """Run the network.

        Args:
            x: Feature tensor of shape (B, T, F).

        Returns:
            (B, embedding_size) when ``output_level == 'segment'``;
            otherwise per-frame output transposed back to (B, T', C).
        """
        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)
        x = self.head(x)
        x = self.xvector(x)
        if self.output_level == 'frame':
            # Back to time-major layout for frame-level consumers.
            x = x.transpose(1, 2)
        return x

    def inference(self,
                  data_in,
                  data_lengths=None,
                  key: list=None,
                  tokenizer=None,
                  frontend=None,
                  **kwargs,
                  ):
        """Extract speaker embeddings from raw audio input.

        Loads audio (resampled to 16 kHz), extracts fbank features, and
        runs the forward pass. ``data_lengths``, ``key``, ``tokenizer``
        and ``frontend`` are unused here; they exist to match the common
        FunASR inference signature. Requires ``kwargs['device']``.

        Returns:
            (results, meta_data): ``results`` is a one-element list of
            ``{"spk_embedding": Tensor}``; ``meta_data`` holds timing
            strings and total audio duration in seconds.
        """
        # extract fbank feats
        meta_data = {}
        time1 = time.perf_counter()
        audio_sample_list = load_audio_text_image_video(data_in, fs=16000, audio_fs=kwargs.get("fs", 16000), data_type="sound")
        time2 = time.perf_counter()
        meta_data["load_data"] = f"{time2 - time1:0.3f}"
        # speech_lengths is unused below; kept for tuple unpacking.
        speech, speech_lengths, speech_times = extract_feature(audio_sample_list)
        speech = speech.to(device=kwargs["device"])
        time3 = time.perf_counter()
        meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
        # Total batch audio duration in seconds (sample counts / 16 kHz).
        meta_data["batch_data_time"] = np.array(speech_times).sum().item() / 16000.0
        results = [{"spk_embedding": self.forward(speech.to(torch.float32))}]
        return results, meta_data