auto_frontend.py 3.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. #!/usr/bin/env python3
  2. # -*- encoding: utf-8 -*-
  3. # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
  4. # MIT License (https://opensource.org/licenses/MIT)
  5. import json
  6. import time
  7. import torch
  8. import hydra
  9. import random
  10. import string
  11. import logging
  12. import os.path
  13. from tqdm import tqdm
  14. from omegaconf import DictConfig, OmegaConf, ListConfig
  15. from funasr.register import tables
  16. from funasr.utils.load_utils import load_bytes
  17. from funasr.download.file import download_from_url
  18. from funasr.auto.auto_model import prepare_data_iterator
  19. from funasr.utils.timestamp_tools import timestamp_sentence
  20. from funasr.download.download_from_hub import download_model
  21. from funasr.utils.vad_utils import slice_padding_audio_samples
  22. from funasr.train_utils.set_all_random_seed import set_all_random_seed
  23. from funasr.train_utils.load_pretrained_model import load_pretrained_model
  24. from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
  25. from funasr.models.campplus.utils import sv_chunk, postprocess, distribute_spk
  26. class AutoFrontend:
  27. def __init__(self, **kwargs):
  28. assert "model" in kwargs
  29. if "model_conf" not in kwargs:
  30. logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
  31. kwargs = download_model(**kwargs)
  32. # build frontend
  33. frontend = kwargs.get("frontend", None)
  34. if frontend is not None:
  35. frontend_class = tables.frontend_classes.get(frontend)
  36. frontend = frontend_class(**kwargs["frontend_conf"])
  37. self.frontend = frontend
  38. if "frontend" in kwargs:
  39. del kwargs["frontend"]
  40. self.kwargs = kwargs
  41. def __call__(self, input, input_len=None, kwargs=None, **cfg):
  42. kwargs = self.kwargs if kwargs is None else kwargs
  43. kwargs.update(cfg)
  44. key_list, data_list = prepare_data_iterator(input, input_len=input_len)
  45. batch_size = kwargs.get("batch_size", 1)
  46. device = kwargs.get("device", "cpu")
  47. if device == "cpu":
  48. batch_size = 1
  49. meta_data = {}
  50. result_list = []
  51. num_samples = len(data_list)
  52. pbar = tqdm(colour="blue", total=num_samples + 1, dynamic_ncols=True)
  53. time0 = time.perf_counter()
  54. for beg_idx in range(0, num_samples, batch_size):
  55. end_idx = min(num_samples, beg_idx + batch_size)
  56. data_batch = data_list[beg_idx:end_idx]
  57. key_batch = key_list[beg_idx:end_idx]
  58. # extract fbank feats
  59. time1 = time.perf_counter()
  60. audio_sample_list = load_audio_text_image_video(data_batch, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000))
  61. time2 = time.perf_counter()
  62. meta_data["load_data"] = f"{time2 - time1:0.3f}"
  63. speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
  64. frontend=self.frontend, **kwargs)
  65. time3 = time.perf_counter()
  66. meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
  67. meta_data["batch_data_time"] = speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000
  68. speech.to(device=device), speech_lengths.to(device=device)
  69. batch = {"input": speech, "input_len": speech_lengths, "key": key_batch}
  70. result_list.append(batch)
  71. pbar.update(1)
  72. description = (
  73. f"{meta_data}, "
  74. )
  75. pbar.set_description(description)
  76. time_end = time.perf_counter()
  77. pbar.set_description(f"time escaped total: {time_end - time0:0.3f}")
  78. return result_list