# vad_infer.py
  1. # -*- encoding: utf-8 -*-
  2. #!/usr/bin/env python3
  3. # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
  4. # MIT License (https://opensource.org/licenses/MIT)
  5. import argparse
  6. import logging
  7. import os
  8. import sys
  9. import json
  10. from pathlib import Path
  11. from typing import Any
  12. from typing import List
  13. from typing import Optional
  14. from typing import Sequence
  15. from typing import Tuple
  16. from typing import Union
  17. from typing import Dict
  18. import math
  19. import numpy as np
  20. import torch
  21. from typeguard import check_argument_types
  22. from typeguard import check_return_type
  23. from funasr.fileio.datadir_writer import DatadirWriter
  24. from funasr.modules.scorers.scorer_interface import BatchScorerInterface
  25. from funasr.modules.subsampling import TooShortUttError
  26. from funasr.tasks.vad import VADTask
  27. from funasr.torch_utils.device_funcs import to_device
  28. from funasr.torch_utils.set_all_random_seed import set_all_random_seed
  29. from funasr.utils import config_argparse
  30. from funasr.utils.cli_utils import get_commandline_args
  31. from funasr.utils.types import str2bool
  32. from funasr.utils.types import str2triple_str
  33. from funasr.utils.types import str_or_none
  34. from funasr.utils import asr_utils, wav_utils, postprocess_utils
  35. from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
  36. class Speech2VadSegment:
  37. """Speech2VadSegment class
  38. Examples:
  39. >>> import soundfile
  40. >>> speech2segment = Speech2VadSegment("vad_config.yml", "vad.pt")
  41. >>> audio, rate = soundfile.read("speech.wav")
  42. >>> speech2segment(audio)
  43. [[10, 230], [245, 450], ...]
  44. """
  45. def __init__(
  46. self,
  47. vad_infer_config: Union[Path, str] = None,
  48. vad_model_file: Union[Path, str] = None,
  49. vad_cmvn_file: Union[Path, str] = None,
  50. device: str = "cpu",
  51. batch_size: int = 1,
  52. dtype: str = "float32",
  53. **kwargs,
  54. ):
  55. assert check_argument_types()
  56. # 1. Build vad model
  57. vad_model, vad_infer_args = VADTask.build_model_from_file(
  58. vad_infer_config, vad_model_file, device
  59. )
  60. frontend = None
  61. if vad_infer_args.frontend is not None:
  62. frontend = WavFrontend(cmvn_file=vad_cmvn_file, **vad_infer_args.frontend_conf)
  63. logging.info("vad_model: {}".format(vad_model))
  64. logging.info("vad_infer_args: {}".format(vad_infer_args))
  65. vad_model.to(dtype=getattr(torch, dtype)).eval()
  66. self.vad_model = vad_model
  67. self.vad_infer_args = vad_infer_args
  68. self.device = device
  69. self.dtype = dtype
  70. self.frontend = frontend
  71. self.batch_size = batch_size
  72. @torch.no_grad()
  73. def __call__(
  74. self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
  75. in_cache: Dict[str, torch.Tensor] = dict()
  76. ) -> Tuple[List[List[int]], Dict[str, torch.Tensor]]:
  77. """Inference
  78. Args:
  79. speech: Input speech data
  80. Returns:
  81. text, token, token_int, hyp
  82. """
  83. assert check_argument_types()
  84. # Input as audio signal
  85. if isinstance(speech, np.ndarray):
  86. speech = torch.tensor(speech)
  87. if self.frontend is not None:
  88. self.frontend.filter_length_max = math.inf
  89. fbanks, fbanks_len = self.frontend.forward_fbank(speech, speech_lengths)
  90. feats, feats_len = self.frontend.forward_lfr_cmvn(fbanks, fbanks_len)
  91. fbanks = to_device(fbanks, device=self.device)
  92. feats = to_device(feats, device=self.device)
  93. feats_len = feats_len.int()
  94. else:
  95. raise Exception("Need to extract feats first, please configure frontend configuration")
  96. # b. Forward Encoder streaming
  97. t_offset = 0
  98. step = min(feats_len.max(), 6000)
  99. segments = [[]] * self.batch_size
  100. for t_offset in range(0, feats_len, min(step, feats_len - t_offset)):
  101. if t_offset + step >= feats_len - 1:
  102. step = feats_len - t_offset
  103. is_final = True
  104. else:
  105. is_final = False
  106. batch = {
  107. "feats": feats[:, t_offset:t_offset + step, :],
  108. "waveform": speech[:, t_offset * 160:min(speech.shape[-1], (t_offset + step - 1) * 160 + 400)],
  109. "is_final": is_final,
  110. "in_cache": in_cache
  111. }
  112. # a. To device
  113. #batch = to_device(batch, device=self.device)
  114. segments_part, in_cache = self.vad_model(**batch)
  115. if segments_part:
  116. for batch_num in range(0, self.batch_size):
  117. segments[batch_num] += segments_part[batch_num]
  118. return fbanks, segments
  119. class Speech2VadSegmentOnline(Speech2VadSegment):
  120. """Speech2VadSegmentOnline class
  121. Examples:
  122. >>> import soundfile
  123. >>> speech2segment = Speech2VadSegmentOnline("vad_config.yml", "vad.pt")
  124. >>> audio, rate = soundfile.read("speech.wav")
  125. >>> speech2segment(audio)
  126. [[10, 230], [245, 450], ...]
  127. """
  128. def __init__(self, **kwargs):
  129. super(Speech2VadSegmentOnline, self).__init__(**kwargs)
  130. vad_cmvn_file = kwargs.get('vad_cmvn_file', None)
  131. self.frontend = None
  132. if self.vad_infer_args.frontend is not None:
  133. self.frontend = WavFrontendOnline(cmvn_file=vad_cmvn_file, **self.vad_infer_args.frontend_conf)
  134. @torch.no_grad()
  135. def __call__(
  136. self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
  137. in_cache: Dict[str, torch.Tensor] = dict(), is_final: bool = False, max_end_sil: int = 800
  138. ) -> Tuple[torch.Tensor, List[List[int]], torch.Tensor]:
  139. """Inference
  140. Args:
  141. speech: Input speech data
  142. Returns:
  143. text, token, token_int, hyp
  144. """
  145. assert check_argument_types()
  146. # Input as audio signal
  147. if isinstance(speech, np.ndarray):
  148. speech = torch.tensor(speech)
  149. batch_size = speech.shape[0]
  150. segments = [[]] * batch_size
  151. if self.frontend is not None:
  152. reset = in_cache == dict()
  153. feats, feats_len = self.frontend.forward(speech, speech_lengths, is_final, reset)
  154. fbanks, _ = self.frontend.get_fbank()
  155. else:
  156. raise Exception("Need to extract feats first, please configure frontend configuration")
  157. if feats.shape[0]:
  158. feats = to_device(feats, device=self.device)
  159. feats_len = feats_len.int()
  160. waveforms = self.frontend.get_waveforms()
  161. batch = {
  162. "feats": feats,
  163. "waveform": waveforms,
  164. "in_cache": in_cache,
  165. "is_final": is_final,
  166. "max_end_sil": max_end_sil
  167. }
  168. # a. To device
  169. batch = to_device(batch, device=self.device)
  170. segments, in_cache = self.vad_model.forward_online(**batch)
  171. # in_cache.update(batch['in_cache'])
  172. # in_cache = {key: value for key, value in batch['in_cache'].items()}
  173. return fbanks, segments, in_cache