load_utils.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. import os
  2. import torch
  3. import json
  4. import torch.distributed as dist
  5. import numpy as np
  6. import kaldiio
  7. import librosa
  8. import torchaudio
  9. import time
  10. import logging
  11. from torch.nn.utils.rnn import pad_sequence
  12. try:
  13. from funasr.download.file import download_from_url
  14. except:
  15. print("urllib is not installed, if you infer from url, please install it first.")
  16. import pdb
  17. def load_audio_text_image_video(data_or_path_or_list, fs: int = 16000, audio_fs: int = 16000, data_type="sound", tokenizer=None, **kwargs):
  18. if isinstance(data_or_path_or_list, (list, tuple)):
  19. if data_type is not None and isinstance(data_type, (list, tuple)):
  20. data_types = [data_type] * len(data_or_path_or_list)
  21. data_or_path_or_list_ret = [[] for d in data_type]
  22. for i, (data_type_i, data_or_path_or_list_i) in enumerate(zip(data_types, data_or_path_or_list)):
  23. for j, (data_type_j, data_or_path_or_list_j) in enumerate(zip(data_type_i, data_or_path_or_list_i)):
  24. data_or_path_or_list_j = load_audio_text_image_video(data_or_path_or_list_j, fs=fs, audio_fs=audio_fs, data_type=data_type_j, tokenizer=tokenizer, **kwargs)
  25. data_or_path_or_list_ret[j].append(data_or_path_or_list_j)
  26. return data_or_path_or_list_ret
  27. else:
  28. return [load_audio_text_image_video(audio, fs=fs, audio_fs=audio_fs, data_type=data_type, **kwargs) for audio in data_or_path_or_list]
  29. if isinstance(data_or_path_or_list, str) and data_or_path_or_list.startswith('http'): # download url to local file
  30. data_or_path_or_list = download_from_url(data_or_path_or_list)
  31. pdb.set_trace()
  32. if isinstance(data_or_path_or_list, str) and os.path.exists(data_or_path_or_list): # local file
  33. if data_type is None or data_type == "sound":
  34. data_or_path_or_list, audio_fs = torchaudio.load(data_or_path_or_list)
  35. if kwargs.get("reduce_channels", True):
  36. data_or_path_or_list = data_or_path_or_list.mean(0)
  37. elif data_type == "text" and tokenizer is not None:
  38. data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
  39. elif data_type == "image": # undo
  40. pass
  41. elif data_type == "video": # undo
  42. pass
  43. # if data_in is a file or url, set is_final=True
  44. if "cache" in kwargs:
  45. kwargs["cache"]["is_final"] = True
  46. kwargs["cache"]["is_streaming_input"] = False
  47. elif isinstance(data_or_path_or_list, str) and data_type == "text" and tokenizer is not None:
  48. data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
  49. elif isinstance(data_or_path_or_list, np.ndarray): # audio sample point
  50. data_or_path_or_list = torch.from_numpy(data_or_path_or_list).squeeze() # [n_samples,]
  51. elif isinstance(data_or_path_or_list, str) and data_type == "kaldi_ark":
  52. data_mat = kaldiio.load_mat(data_or_path_or_list)
  53. if isinstance(data_mat, tuple):
  54. audio_fs, mat = data_mat
  55. else:
  56. mat = data_mat
  57. if mat.dtype == 'int16' or mat.dtype == 'int32':
  58. mat = mat.astype(np.float64)
  59. mat = mat / 32768
  60. if mat.ndim ==2:
  61. mat = mat[:,0]
  62. data_or_path_or_list = mat
  63. else:
  64. pass
  65. # print(f"unsupport data type: {data_or_path_or_list}, return raw data")
  66. if audio_fs != fs and data_type != "text":
  67. resampler = torchaudio.transforms.Resample(audio_fs, fs)
  68. data_or_path_or_list = resampler(data_or_path_or_list[None, :])[0, :]
  69. return data_or_path_or_list
  70. def load_bytes(input):
  71. middle_data = np.frombuffer(input, dtype=np.int16)
  72. middle_data = np.asarray(middle_data)
  73. if middle_data.dtype.kind not in 'iu':
  74. raise TypeError("'middle_data' must be an array of integers")
  75. dtype = np.dtype('float32')
  76. if dtype.kind != 'f':
  77. raise TypeError("'dtype' must be a floating point type")
  78. i = np.iinfo(middle_data.dtype)
  79. abs_max = 2 ** (i.bits - 1)
  80. offset = i.min + abs_max
  81. array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
  82. return array
  83. def extract_fbank(data, data_len = None, data_type: str="sound", frontend=None, **kwargs):
  84. # import pdb;
  85. # pdb.set_trace()
  86. if isinstance(data, np.ndarray):
  87. data = torch.from_numpy(data)
  88. if len(data.shape) < 2:
  89. data = data[None, :] # data: [batch, N]
  90. data_len = [data.shape[1]] if data_len is None else data_len
  91. elif isinstance(data, torch.Tensor):
  92. if len(data.shape) < 2:
  93. data = data[None, :] # data: [batch, N]
  94. data_len = [data.shape[1]] if data_len is None else data_len
  95. elif isinstance(data, (list, tuple)):
  96. data_list, data_len = [], []
  97. for data_i in data:
  98. if isinstance(data_i, np.ndarray):
  99. data_i = torch.from_numpy(data_i)
  100. data_list.append(data_i)
  101. data_len.append(data_i.shape[0])
  102. data = pad_sequence(data_list, batch_first=True) # data: [batch, N]
  103. data, data_len = frontend(data, data_len, **kwargs)
  104. if isinstance(data_len, (list, tuple)):
  105. data_len = torch.tensor([data_len])
  106. return data.to(torch.float32), data_len.to(torch.int32)