| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202 |
- # -*- encoding: utf-8 -*-
- #!/usr/bin/env python3
- # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
- # MIT License (https://opensource.org/licenses/MIT)
- import argparse
- import logging
- import os
- import sys
- import json
- from pathlib import Path
- from typing import Any
- from typing import List
- from typing import Optional
- from typing import Sequence
- from typing import Tuple
- from typing import Union
- from typing import Dict
- import math
- import numpy as np
- import torch
- from typeguard import check_argument_types
- from typeguard import check_return_type
- from funasr.fileio.datadir_writer import DatadirWriter
- from funasr.modules.scorers.scorer_interface import BatchScorerInterface
- from funasr.modules.subsampling import TooShortUttError
- from funasr.tasks.vad import VADTask
- from funasr.torch_utils.device_funcs import to_device
- from funasr.torch_utils.set_all_random_seed import set_all_random_seed
- from funasr.utils import config_argparse
- from funasr.utils.cli_utils import get_commandline_args
- from funasr.utils.types import str2bool
- from funasr.utils.types import str2triple_str
- from funasr.utils.types import str_or_none
- from funasr.utils import asr_utils, wav_utils, postprocess_utils
- from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
class Speech2VadSegment:
    """Offline VAD: run a VAD model over complete utterances.

    The configured frontend produces features via ``forward_fbank`` followed by
    ``forward_lfr_cmvn``. The features are fed to the VAD model in windows of at
    most 6000 frames, the last window flagged ``is_final=True``, and whatever the
    model returns for each batch index is accumulated.

    Examples:
        >>> import soundfile
        >>> speech2segment = Speech2VadSegment("vad_config.yml", "vad.pt")
        >>> audio, rate = soundfile.read("speech.wav")
        >>> fbanks, segments = speech2segment(audio[None, :])  # doctest: +SKIP
        >>> segments[0]  # doctest: +SKIP
        [[10, 230], [245, 450], ...]
    """

    def __init__(
        self,
        vad_infer_config: Union[Path, str] = None,
        vad_model_file: Union[Path, str] = None,
        vad_cmvn_file: Union[Path, str] = None,
        device: str = "cpu",
        batch_size: int = 1,
        dtype: str = "float32",
        **kwargs,
    ):
        """Build the VAD model and, if configured, its feature frontend.

        Args:
            vad_infer_config: Config file passed to ``VADTask.build_model_from_file``.
            vad_model_file: Model file passed to ``VADTask.build_model_from_file``.
            vad_cmvn_file: CMVN file handed to ``WavFrontend``.
            device: Device used to build the model and to hold the features.
            batch_size: Number of per-item segment lists ``__call__`` returns.
            dtype: Name of a ``torch`` dtype attribute the model is cast to.
            **kwargs: Accepted for interface compatibility; unused here.
        """
        assert check_argument_types()
        # 1. Build vad model
        vad_model, vad_infer_args = VADTask.build_model_from_file(
            vad_infer_config, vad_model_file, device
        )
        frontend = None
        if vad_infer_args.frontend is not None:
            frontend = WavFrontend(cmvn_file=vad_cmvn_file, **vad_infer_args.frontend_conf)
        logging.info("vad_model: {}".format(vad_model))
        logging.info("vad_infer_args: {}".format(vad_infer_args))
        vad_model.to(dtype=getattr(torch, dtype)).eval()
        self.vad_model = vad_model
        self.vad_infer_args = vad_infer_args
        self.device = device
        self.dtype = dtype
        self.frontend = frontend
        self.batch_size = batch_size

    @torch.no_grad()
    def __call__(
        self,
        speech: Union[torch.Tensor, np.ndarray],
        speech_lengths: Union[torch.Tensor, np.ndarray] = None,
        in_cache: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, List[list]]:
        """Run VAD over complete utterance(s).

        Args:
            speech: Input waveform(s). It is sliced as ``speech[:, samples]``
                below, so it must be 2-D with samples on the last axis.
            speech_lengths: Lengths forwarded to ``frontend.forward_fbank``.
            in_cache: Cache dict forwarded to the VAD model. ``None`` (the
                default) means a fresh empty dict for this call.

        Returns:
            ``(fbanks, segments)``. ``fbanks`` holds the frontend filterbank
            features moved to ``self.device``. ``segments`` is a list of
            ``self.batch_size`` independent lists, each the concatenation of
            what the model returned for that batch index.

        Raises:
            Exception: if no frontend is configured.
        """
        assert check_argument_types()
        # Fresh dict per call. A shared ``dict()`` default would carry over
        # anything the model might write into it to the next call.
        if in_cache is None:
            in_cache = {}
        # Input as audio signal
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
        if self.frontend is None:
            raise Exception("Need to extract feats first, please configure frontend configuration")

        # NOTE(review): presumably lifts the frontend's max-length limit for
        # whole-utterance input -- TODO confirm against WavFrontend.
        self.frontend.filter_length_max = math.inf
        fbanks, fbanks_len = self.frontend.forward_fbank(speech, speech_lengths)
        feats, feats_len = self.frontend.forward_lfr_cmvn(fbanks, fbanks_len)
        fbanks = to_device(fbanks, device=self.device)
        feats = to_device(feats, device=self.device)
        feats_len = feats_len.int()

        # b. Forward encoder in windows of at most 6000 frames.
        total_frames = int(feats_len.max())
        step = min(total_frames, 6000)
        # One independent list per batch item: ``[[]] * n`` would alias a
        # single list n times, and ``+=`` below extends in place.
        segments = [[] for _ in range(self.batch_size)]
        t_offset = 0
        while t_offset < total_frames:
            # When the window reaches (or stops one frame short of) the end,
            # widen it to cover every remaining frame and mark it final.
            if t_offset + step >= total_frames - 1:
                step = total_frames - t_offset
                is_final = True
            else:
                is_final = False
            # Frame index -> sample index uses a hard-coded 160-sample hop;
            # the +400 presumably is the analysis window of the last frame
            # -- TODO confirm against frontend_conf.
            wave_end = min(speech.shape[-1], (t_offset + step - 1) * 160 + 400)
            batch = {
                "feats": feats[:, t_offset:t_offset + step, :],
                "waveform": speech[:, t_offset * 160:wave_end],
                "is_final": is_final,
                "in_cache": in_cache,
            }
            # NOTE(review): the waveform is not moved to self.device here
            # (a batch-wide to_device call used to be commented out).
            segments_part, in_cache = self.vad_model(**batch)
            if segments_part:
                for batch_num in range(self.batch_size):
                    segments[batch_num] += segments_part[batch_num]
            if is_final:
                # The final window already covered all remaining frames; never
                # feed the model another chunk after is_final=True.
                break
            t_offset += step
        return fbanks, segments
class Speech2VadSegmentOnline(Speech2VadSegment):
    """Streaming VAD built on :class:`Speech2VadSegment`.

    Replaces the offline frontend with ``WavFrontendOnline`` and calls
    ``vad_model.forward_online`` once per audio chunk. The returned
    ``in_cache`` is meant to be passed back on the next call for the same stream.

    Examples:
        >>> import soundfile
        >>> speech2segment = Speech2VadSegmentOnline("vad_config.yml", "vad.pt")
        >>> audio, rate = soundfile.read("speech.wav")
        >>> fbanks, segments, cache = speech2segment(audio[None, :], is_final=True)  # doctest: +SKIP
    """

    def __init__(
        self,
        vad_infer_config: Union[Path, str] = None,
        vad_model_file: Union[Path, str] = None,
        vad_cmvn_file: Union[Path, str] = None,
        **kwargs,
    ):
        """Build the model via the parent, then install a streaming frontend.

        Args:
            vad_infer_config: Forwarded to ``Speech2VadSegment.__init__``.
            vad_model_file: Forwarded to ``Speech2VadSegment.__init__``.
            vad_cmvn_file: CMVN file for both the parent and the online frontend.
            **kwargs: Remaining ``Speech2VadSegment`` options (device, batch_size,
                dtype, ...), forwarded unchanged.
        """
        # Named parameters let the positional form in the class docstring
        # work; keyword-only callers behave exactly as before.
        super().__init__(
            vad_infer_config=vad_infer_config,
            vad_model_file=vad_model_file,
            vad_cmvn_file=vad_cmvn_file,
            **kwargs,
        )
        # Replace the offline frontend the parent built with a streaming one.
        self.frontend = None
        if self.vad_infer_args.frontend is not None:
            self.frontend = WavFrontendOnline(cmvn_file=vad_cmvn_file, **self.vad_infer_args.frontend_conf)

    @torch.no_grad()
    def __call__(
        self,
        speech: Union[torch.Tensor, np.ndarray],
        speech_lengths: Union[torch.Tensor, np.ndarray] = None,
        in_cache: Optional[Dict[str, torch.Tensor]] = None,
        is_final: bool = False,
        max_end_sil: int = 800,
    ) -> Tuple[torch.Tensor, List[list], Dict[str, torch.Tensor]]:
        """Process one chunk of streaming audio.

        Args:
            speech: Audio chunk. ``speech.shape[0]`` is used as the batch size.
            speech_lengths: Lengths forwarded to ``frontend.forward``.
            in_cache: Cache returned by the previous call for this stream.
                ``None`` or an empty dict is passed to the frontend as
                ``reset=True``.
            is_final: Whether this is the last chunk; forwarded to the frontend
                and the model.
            max_end_sil: Forwarded to ``vad_model.forward_online``.

        Returns:
            ``(fbanks, segments, in_cache)``. ``fbanks`` comes from
            ``frontend.get_fbank()``. ``segments`` is the model output, or
            ``batch_size`` independent empty lists when ``feats`` is empty
            along its first dimension. ``in_cache`` is the cache to pass to
            the next call.

        Raises:
            Exception: if no frontend is configured.
        """
        assert check_argument_types()
        # Fresh dict per call. A shared ``dict()`` default, if ever mutated,
        # would stop signalling a new stream via ``reset`` below.
        if in_cache is None:
            in_cache = {}
        # Input as audio signal
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
        batch_size = speech.shape[0]
        # Independent lists: ``[[]] * n`` would alias a single list n times.
        segments = [[] for _ in range(batch_size)]
        if self.frontend is None:
            raise Exception("Need to extract feats first, please configure frontend configuration")

        # An empty cache is passed to the frontend as its reset flag.
        reset = in_cache == dict()
        feats, _ = self.frontend.forward(speech, speech_lengths, is_final, reset)
        fbanks, _ = self.frontend.get_fbank()

        if feats.shape[0]:
            feats = to_device(feats, device=self.device)
            batch = {
                "feats": feats,
                "waveform": self.frontend.get_waveforms(),
                "in_cache": in_cache,
                "is_final": is_final,
                "max_end_sil": max_end_sil,
            }
            # a. To device
            batch = to_device(batch, device=self.device)
            segments, in_cache = self.vad_model.forward_online(**batch)
        return fbanks, segments, in_cache
|