auto_frontend.py 3.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. import json
  2. import time
  3. import torch
  4. import hydra
  5. import random
  6. import string
  7. import logging
  8. import os.path
  9. from tqdm import tqdm
  10. from omegaconf import DictConfig, OmegaConf, ListConfig
  11. from funasr.register import tables
  12. from funasr.utils.load_utils import load_bytes
  13. from funasr.download.file import download_from_url
  14. from funasr.download.download_from_hub import download_model
  15. from funasr.utils.vad_utils import slice_padding_audio_samples
  16. from funasr.train_utils.set_all_random_seed import set_all_random_seed
  17. from funasr.train_utils.load_pretrained_model import load_pretrained_model
  18. from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
  19. from funasr.utils.timestamp_tools import timestamp_sentence
  20. from funasr.models.campplus.utils import sv_chunk, postprocess, distribute_spk
  21. from funasr.auto.auto_model import prepare_data_iterator
  22. class AutoFrontend:
  23. def __init__(self, **kwargs):
  24. assert "model" in kwargs
  25. if "model_conf" not in kwargs:
  26. logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
  27. kwargs = download_model(**kwargs)
  28. # build frontend
  29. frontend = kwargs.get("frontend", None)
  30. if frontend is not None:
  31. frontend_class = tables.frontend_classes.get(frontend)
  32. frontend = frontend_class(**kwargs["frontend_conf"])
  33. self.frontend = frontend
  34. if "frontend" in kwargs:
  35. del kwargs["frontend"]
  36. self.kwargs = kwargs
  37. def __call__(self, input, input_len=None, kwargs=None, **cfg):
  38. kwargs = self.kwargs if kwargs is None else kwargs
  39. kwargs.update(cfg)
  40. key_list, data_list = prepare_data_iterator(input, input_len=input_len)
  41. batch_size = kwargs.get("batch_size", 1)
  42. device = kwargs.get("device", "cpu")
  43. if device == "cpu":
  44. batch_size = 1
  45. meta_data = {}
  46. result_list = []
  47. num_samples = len(data_list)
  48. pbar = tqdm(colour="blue", total=num_samples + 1, dynamic_ncols=True)
  49. time0 = time.perf_counter()
  50. for beg_idx in range(0, num_samples, batch_size):
  51. end_idx = min(num_samples, beg_idx + batch_size)
  52. data_batch = data_list[beg_idx:end_idx]
  53. key_batch = key_list[beg_idx:end_idx]
  54. # extract fbank feats
  55. time1 = time.perf_counter()
  56. audio_sample_list = load_audio_text_image_video(data_batch, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000))
  57. time2 = time.perf_counter()
  58. meta_data["load_data"] = f"{time2 - time1:0.3f}"
  59. speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
  60. frontend=self.frontend, **kwargs)
  61. time3 = time.perf_counter()
  62. meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
  63. meta_data["batch_data_time"] = speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000
  64. speech.to(device=device), speech_lengths.to(device=device)
  65. batch = {"input": speech, "input_len": speech_lengths, "key": key_batch}
  66. result_list.append(batch)
  67. pbar.update(1)
  68. description = (
  69. f"{meta_data}, "
  70. )
  71. pbar.set_description(description)
  72. time_end = time.perf_counter()
  73. pbar.set_description(f"time escaped total: {time_end - time0:0.3f}")
  74. return result_list