| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127 |
- #!/usr/bin/env python3
- # -*- encoding: utf-8 -*-
- # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
- # MIT License (https://opensource.org/licenses/MIT)
- # Modified from 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker)
- import time
- import torch
- import numpy as np
- from collections import OrderedDict
- from contextlib import contextmanager
- from distutils.version import LooseVersion
- from funasr.register import tables
- from funasr.models.campplus.utils import extract_feature
- from funasr.utils.load_utils import load_audio_text_image_video
- from funasr.models.campplus.components import DenseLayer, StatsPool, \
- TDNNLayer, CAMDenseTDNNBlock, TransitLayer, get_nonlinear, FCM
# Probe for the AMP context manager directly instead of parsing
# ``torch.__version__``: ``distutils.version.LooseVersion`` is deprecated
# (PEP 632), and an import probe yields the same outcome as the former
# ``torch >= 1.6.0`` version gate.
try:
    from torch.cuda.amp import autocast
except ImportError:
    # Nothing to do if torch<1.6.0: fall back to a no-op context manager
    # that accepts the same ``enabled`` keyword.
    @contextmanager
    def autocast(enabled=True):
        """No-op stand-in for ``torch.cuda.amp.autocast``; ``enabled`` is ignored."""
        yield
@tables.register("model_classes", "CAMPPlus")
class CAMPPlus(torch.nn.Module):
    """Speaker-embedding model built from an FCM head and a dense TDNN stack.

    The network is assembled as:

    * ``head``: an ``FCM`` front-end applied to the input features.
    * ``xvector``: a ``torch.nn.Sequential`` containing, in order, ``tdnn``,
      three ``block{i}``/``transit{i}`` pairs (``CAMDenseTDNNBlock`` followed
      by a ``TransitLayer`` that halves the channel count),
      ``out_nonlinear`` and, for segment-level output only, ``stats``
      (``StatsPool``) and ``dense`` (``DenseLayer``).

    Args:
        feat_dim: Feature dimension forwarded to ``FCM``.
        embedding_size: Output size of the final ``DenseLayer``; only used
            when ``output_level == 'segment'``.
        growth_rate: ``out_channels`` of every ``CAMDenseTDNNBlock``; each
            block increases the channel count by ``num_layers * growth_rate``.
        bn_size: Multiplier such that each block receives
            ``bn_channels = bn_size * growth_rate``.
        init_channels: Output channels of the first ``TDNNLayer``.
        config_str: Layer configuration string forwarded to the TDNN, block,
            transit and non-linearity constructors.
        memory_efficient: Forwarded to ``CAMDenseTDNNBlock``.
        output_level: ``'segment'`` appends the pooling and dense layers;
            ``'frame'`` omits them and ``forward`` transposes the last two
            dimensions of the output instead.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If ``output_level`` is neither ``'segment'`` nor
            ``'frame'``.
    """

    def __init__(self,
                 feat_dim=80,
                 embedding_size=192,
                 growth_rate=32,
                 bn_size=4,
                 init_channels=128,
                 config_str='batchnorm-relu',
                 memory_efficient=True,
                 output_level='segment',
                 **kwargs,):
        super().__init__()

        # Validate with an explicit exception rather than ``assert`` so the
        # check still runs under ``python -O``, and fail before any layer is
        # built.
        if output_level not in ('segment', 'frame'):
            raise ValueError(
                "`output_level` should be set to 'segment' or 'frame'. ")

        self.head = FCM(feat_dim=feat_dim)
        channels = self.head.out_channels
        self.output_level = output_level

        self.xvector = torch.nn.Sequential(
            OrderedDict([
                ('tdnn',
                 TDNNLayer(channels,
                           init_channels,
                           5,
                           stride=2,
                           dilation=1,
                           padding=-1,
                           config_str=config_str)),
            ]))
        channels = init_channels

        # (num_layers, kernel_size, dilation) for each dense block.
        block_specs = ((12, 3, 1), (24, 3, 2), (16, 3, 2))
        for idx, (num_layers, kernel_size, dilation) in enumerate(block_specs, start=1):
            self.xvector.add_module(
                'block%d' % idx,
                CAMDenseTDNNBlock(num_layers=num_layers,
                                  in_channels=channels,
                                  out_channels=growth_rate,
                                  bn_channels=bn_size * growth_rate,
                                  kernel_size=kernel_size,
                                  dilation=dilation,
                                  config_str=config_str,
                                  memory_efficient=memory_efficient))
            # The block widens the channel count; the transit layer halves it.
            channels = channels + num_layers * growth_rate
            self.xvector.add_module(
                'transit%d' % idx,
                TransitLayer(channels,
                             channels // 2,
                             bias=False,
                             config_str=config_str))
            channels //= 2

        self.xvector.add_module(
            'out_nonlinear', get_nonlinear(config_str, channels))

        if self.output_level == 'segment':
            self.xvector.add_module('stats', StatsPool())
            # The dense layer is built with twice the channel count as its
            # input size, since it consumes the output of StatsPool.
            self.xvector.add_module(
                'dense',
                DenseLayer(
                    channels * 2, embedding_size, config_str='batchnorm_'))

        # Kaiming-normal weights and zero biases for all Conv1d/Linear layers.
        for m in self.modules():
            if isinstance(m, (torch.nn.Conv1d, torch.nn.Linear)):
                torch.nn.init.kaiming_normal_(m.weight.data)
                if m.bias is not None:
                    torch.nn.init.zeros_(m.bias)

    def forward(self, x):
        """Run the head and the x-vector stack on ``x``.

        Args:
            x: Input tensor laid out as (B, T, F); it is permuted to
                (B, F, T) before being passed to ``head``.

        Returns:
            The output of ``self.xvector``; when ``output_level == 'frame'``
            dimensions 1 and 2 are swapped before returning.
        """
        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)
        x = self.head(x)
        x = self.xvector(x)
        if self.output_level == 'frame':
            x = x.transpose(1, 2)
        return x

    def inference(self,
                  data_in,
                  data_lengths=None,
                  key: list = None,
                  tokenizer=None,
                  frontend=None,
                  **kwargs,
                  ):
        """Load audio, extract features and compute speaker embeddings.

        Args:
            data_in: Input handed to ``load_audio_text_image_video`` with
                ``data_type="sound"``.
            data_lengths: Unused by this method.
            key: Unused by this method.
            tokenizer: Unused by this method.
            frontend: Unused by this method.
            **kwargs: ``fs`` (default 16000) is passed as ``audio_fs`` to the
                loader. ``device`` selects where the features are moved; when
                it is missing or ``None`` the device of the model's own
                parameters is used.

        Returns:
            A ``(results, meta_data)`` tuple. ``results`` is a one-element
            list holding ``{"spk_embedding": <forward output>}``.
            ``meta_data`` holds the ``"load_data"`` and ``"extract_feat"``
            timings (seconds, formatted to three decimals) and
            ``"batch_data_time"``.
        """
        meta_data = {}

        load_start = time.perf_counter()
        audio_sample_list = load_audio_text_image_video(
            data_in, fs=16000, audio_fs=kwargs.get("fs", 16000), data_type="sound")
        load_end = time.perf_counter()
        meta_data["load_data"] = f"{load_end - load_start:0.3f}"

        # extract fbank feats
        speech, _, speech_times = extract_feature(audio_sample_list)

        # Fall back to the model's own device instead of raising KeyError
        # when the caller does not supply one.
        device = kwargs.get("device")
        if device is None:
            device = next(self.parameters()).device
        speech = speech.to(device=device)

        feat_end = time.perf_counter()
        meta_data["extract_feat"] = f"{feat_end - load_end:0.3f}"
        # NOTE(review): presumably speech_times are sample counts at 16 kHz,
        # making this a duration in seconds - confirm against extract_feature.
        meta_data["batch_data_time"] = np.array(speech_times).sum().item() / 16000.0

        results = [{"spk_embedding": self.forward(speech.to(torch.float32))}]
        return results, meta_data
|