# auto_frontend.py
  1. import json
  2. import time
  3. import torch
  4. import hydra
  5. import random
  6. import string
  7. import logging
  8. import os.path
  9. from tqdm import tqdm
  10. from omegaconf import DictConfig, OmegaConf, ListConfig
  11. from funasr.register import tables
  12. from funasr.utils.load_utils import load_bytes
  13. from funasr.download.file import download_from_url
  14. from funasr.download.download_from_hub import download_model
  15. from funasr.utils.vad_utils import slice_padding_audio_samples
  16. from funasr.train_utils.set_all_random_seed import set_all_random_seed
  17. from funasr.train_utils.load_pretrained_model import load_pretrained_model
  18. from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
  19. from funasr.utils.timestamp_tools import timestamp_sentence
  20. from funasr.models.campplus.utils import sv_chunk, postprocess, distribute_spk
  21. from funasr.models.campplus.cluster_backend import ClusterBackend
  22. from funasr.auto.auto_model import prepare_data_iterator
  23. class AutoFrontend:
  24. def __init__(self, **kwargs):
  25. assert "model" in kwargs
  26. if "model_conf" not in kwargs:
  27. logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
  28. kwargs = download_model(**kwargs)
  29. # build frontend
  30. frontend = kwargs.get("frontend", None)
  31. if frontend is not None:
  32. frontend_class = tables.frontend_classes.get(frontend)
  33. frontend = frontend_class(**kwargs["frontend_conf"])
  34. self.frontend = frontend
  35. if "frontend" in kwargs:
  36. del kwargs["frontend"]
  37. self.kwargs = kwargs
  38. def __call__(self, input, input_len=None, kwargs=None, **cfg):
  39. kwargs = self.kwargs if kwargs is None else kwargs
  40. kwargs.update(cfg)
  41. key_list, data_list = prepare_data_iterator(input, input_len=input_len)
  42. batch_size = kwargs.get("batch_size", 1)
  43. device = kwargs.get("device", "cpu")
  44. if device == "cpu":
  45. batch_size = 1
  46. meta_data = {}
  47. result_list = []
  48. num_samples = len(data_list)
  49. pbar = tqdm(colour="blue", total=num_samples + 1, dynamic_ncols=True)
  50. time0 = time.perf_counter()
  51. for beg_idx in range(0, num_samples, batch_size):
  52. end_idx = min(num_samples, beg_idx + batch_size)
  53. data_batch = data_list[beg_idx:end_idx]
  54. key_batch = key_list[beg_idx:end_idx]
  55. # extract fbank feats
  56. time1 = time.perf_counter()
  57. audio_sample_list = load_audio_text_image_video(data_batch, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000))
  58. time2 = time.perf_counter()
  59. meta_data["load_data"] = f"{time2 - time1:0.3f}"
  60. speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
  61. frontend=self.frontend, **kwargs)
  62. time3 = time.perf_counter()
  63. meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
  64. meta_data["batch_data_time"] = speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000
  65. speech.to(device=device), speech_lengths.to(device=device)
  66. batch = {"input": speech, "input_len": speech_lengths, "key": key_batch}
  67. result_list.append(batch)
  68. pbar.update(1)
  69. description = (
  70. f"{meta_data}, "
  71. )
  72. pbar.set_description(description)
  73. time_end = time.perf_counter()
  74. pbar.set_description(f"time escaped total: {time_end - time0:0.3f}")
  75. return result_list