# asr_utils.py
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import struct
  4. from typing import Any, Dict, List, Union
  5. import torchaudio
  6. import numpy as np
  7. import pkg_resources
  8. from modelscope.utils.logger import get_logger
# Module-level logger shared by all helpers in this file.
logger = get_logger()

# ANSI escape sequences used to colorize ref/hyp printouts on a terminal.
green_color = '\033[1;32m'
red_color = '\033[0;31;40m'
yellow_color = '\033[0;33;40m'
end_color = '\033[0m'  # resets all terminal attributes

# Default language tag used by compute_wer when the caller passes lang=None.
global_asr_language = 'zh-cn'

# Audio file extensions (lower-case) this module treats as supported audio.
SUPPORT_AUDIO_TYPE_SETS = ['flac', 'mp3', 'ogg', 'opus', 'wav', 'pcm']
  16. def get_version():
  17. return float(pkg_resources.get_distribution('easyasr').version)
  18. def sample_rate_checking(audio_in: Union[str, bytes], audio_format: str):
  19. r_audio_fs = None
  20. if audio_format == 'wav':
  21. r_audio_fs = get_sr_from_wav(audio_in)
  22. elif audio_format == 'pcm' and isinstance(audio_in, bytes):
  23. r_audio_fs = get_sr_from_bytes(audio_in)
  24. return r_audio_fs
  25. def type_checking(audio_in: Union[str, bytes],
  26. audio_fs: int = None,
  27. recog_type: str = None,
  28. audio_format: str = None):
  29. r_recog_type = recog_type
  30. r_audio_format = audio_format
  31. r_wav_path = audio_in
  32. if isinstance(audio_in, str):
  33. assert os.path.exists(audio_in), f'wav_path:{audio_in} does not exist'
  34. elif isinstance(audio_in, bytes):
  35. assert len(audio_in) > 0, 'audio in is empty'
  36. r_audio_format = 'pcm'
  37. r_recog_type = 'wav'
  38. if audio_in is None:
  39. # for raw_inputs
  40. r_recog_type = 'wav'
  41. r_audio_format = 'pcm'
  42. if r_recog_type is None and audio_in is not None:
  43. # audio_in is wav, recog_type is wav_file
  44. if os.path.isfile(audio_in):
  45. audio_type = os.path.basename(audio_in).split(".")[-1].lower()
  46. if audio_type in SUPPORT_AUDIO_TYPE_SETS:
  47. r_recog_type = 'wav'
  48. r_audio_format = 'wav'
  49. elif audio_type == "scp":
  50. r_recog_type = 'wav'
  51. r_audio_format = 'scp'
  52. else:
  53. raise NotImplementedError(
  54. f'Not supported audio type: {audio_type}')
  55. # recog_type is datasets_file
  56. elif os.path.isdir(audio_in):
  57. dir_name = os.path.basename(audio_in)
  58. if 'test' in dir_name:
  59. r_recog_type = 'test'
  60. elif 'dev' in dir_name:
  61. r_recog_type = 'dev'
  62. elif 'train' in dir_name:
  63. r_recog_type = 'train'
  64. if r_audio_format is None:
  65. if find_file_by_ends(audio_in, '.ark'):
  66. r_audio_format = 'kaldi_ark'
  67. elif find_file_by_ends(audio_in, '.wav') or find_file_by_ends(
  68. audio_in, '.WAV'):
  69. r_audio_format = 'wav'
  70. elif find_file_by_ends(audio_in, '.records'):
  71. r_audio_format = 'tfrecord'
  72. if r_audio_format == 'kaldi_ark' and r_recog_type != 'wav':
  73. # datasets with kaldi_ark file
  74. r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../'))
  75. elif r_audio_format == 'tfrecord' and r_recog_type != 'wav':
  76. # datasets with tensorflow records file
  77. r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../'))
  78. elif r_audio_format == 'wav' and r_recog_type != 'wav':
  79. # datasets with waveform files
  80. r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../../'))
  81. return r_recog_type, r_audio_format, r_wav_path
  82. def get_sr_from_bytes(wav: bytes):
  83. sr = None
  84. data = wav
  85. if len(data) > 44:
  86. try:
  87. header_fields = {}
  88. header_fields['ChunkID'] = str(data[0:4], 'UTF-8')
  89. header_fields['Format'] = str(data[8:12], 'UTF-8')
  90. header_fields['Subchunk1ID'] = str(data[12:16], 'UTF-8')
  91. if header_fields['ChunkID'] == 'RIFF' and header_fields[
  92. 'Format'] == 'WAVE' and header_fields[
  93. 'Subchunk1ID'] == 'fmt ':
  94. header_fields['SampleRate'] = struct.unpack('<I',
  95. data[24:28])[0]
  96. sr = header_fields['SampleRate']
  97. except Exception:
  98. # no treatment
  99. pass
  100. else:
  101. logger.warn('audio bytes is ' + str(len(data)) + ' is invalid.')
  102. return sr
  103. def get_sr_from_wav(fname: str):
  104. fs = None
  105. if os.path.isfile(fname):
  106. audio_type = os.path.basename(fname).split(".")[-1].lower()
  107. if audio_type in SUPPORT_AUDIO_TYPE_SETS:
  108. if audio_type == "pcm":
  109. fs = None
  110. else:
  111. audio, fs = torchaudio.load(fname)
  112. return fs
  113. elif os.path.isdir(fname):
  114. dir_files = os.listdir(fname)
  115. for file in dir_files:
  116. file_path = os.path.join(fname, file)
  117. if os.path.isfile(file_path):
  118. audio_type = os.path.basename(file_path).split(".")[-1].lower()
  119. if audio_type in SUPPORT_AUDIO_TYPE_SETS:
  120. fs = get_sr_from_wav(file_path)
  121. elif os.path.isdir(file_path):
  122. fs = get_sr_from_wav(file_path)
  123. if fs is not None:
  124. break
  125. return fs
  126. def find_file_by_ends(dir_path: str, ends: str):
  127. dir_files = os.listdir(dir_path)
  128. for file in dir_files:
  129. file_path = os.path.join(dir_path, file)
  130. if os.path.isfile(file_path):
  131. if ends == ".wav" or ends == ".WAV":
  132. audio_type = os.path.basename(file_path).split(".")[-1].lower()
  133. if audio_type in SUPPORT_AUDIO_TYPE_SETS:
  134. return True
  135. else:
  136. raise NotImplementedError(
  137. f'Not supported audio type: {audio_type}')
  138. elif file_path.endswith(ends):
  139. return True
  140. elif os.path.isdir(file_path):
  141. if find_file_by_ends(file_path, ends):
  142. return True
  143. return False
  144. def recursion_dir_all_wav(wav_list, dir_path: str) -> List[str]:
  145. dir_files = os.listdir(dir_path)
  146. for file in dir_files:
  147. file_path = os.path.join(dir_path, file)
  148. if os.path.isfile(file_path):
  149. audio_type = os.path.basename(file_path).split(".")[-1].lower()
  150. if audio_type in SUPPORT_AUDIO_TYPE_SETS:
  151. wav_list.append(file_path)
  152. elif os.path.isdir(file_path):
  153. recursion_dir_all_wav(wav_list, file_path)
  154. return wav_list
  155. def compute_wer(hyp_list: List[Any],
  156. ref_list: List[Any],
  157. lang: str = None) -> Dict[str, Any]:
  158. assert len(hyp_list) > 0, 'hyp list is empty'
  159. assert len(ref_list) > 0, 'ref list is empty'
  160. rst = {
  161. 'Wrd': 0,
  162. 'Corr': 0,
  163. 'Ins': 0,
  164. 'Del': 0,
  165. 'Sub': 0,
  166. 'Snt': 0,
  167. 'Err': 0.0,
  168. 'S.Err': 0.0,
  169. 'wrong_words': 0,
  170. 'wrong_sentences': 0
  171. }
  172. if lang is None:
  173. lang = global_asr_language
  174. for h_item in hyp_list:
  175. for r_item in ref_list:
  176. if h_item['key'] == r_item['key']:
  177. out_item = compute_wer_by_line(h_item['value'],
  178. r_item['value'],
  179. lang)
  180. rst['Wrd'] += out_item['nwords']
  181. rst['Corr'] += out_item['cor']
  182. rst['wrong_words'] += out_item['wrong']
  183. rst['Ins'] += out_item['ins']
  184. rst['Del'] += out_item['del']
  185. rst['Sub'] += out_item['sub']
  186. rst['Snt'] += 1
  187. if out_item['wrong'] > 0:
  188. rst['wrong_sentences'] += 1
  189. print_wrong_sentence(key=h_item['key'],
  190. hyp=h_item['value'],
  191. ref=r_item['value'])
  192. else:
  193. print_correct_sentence(key=h_item['key'],
  194. hyp=h_item['value'],
  195. ref=r_item['value'])
  196. break
  197. if rst['Wrd'] > 0:
  198. rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2)
  199. if rst['Snt'] > 0:
  200. rst['S.Err'] = round(rst['wrong_sentences'] * 100 / rst['Snt'], 2)
  201. return rst
def compute_wer_by_line(hyp: List[str],
                        ref: List[str],
                        lang: str = 'zh-cn') -> Dict[str, Any]:
    """Score one hypothesis against one reference via Levenshtein alignment.

    For lang == 'zh-cn' the inputs are used as-is (pre-tokenized
    sequences); otherwise they are whitespace-split first.
    NOTE(review): the annotations say List[str], but the non-zh branch
    calls .split(), which implies plain strings there — confirm callers.

    Returns:
        dict with 'nwords' (reference length), 'cor', 'wrong', 'ins',
        'del', 'sub' counts.
    """
    if lang != 'zh-cn':
        hyp = hyp.split()
        ref = ref.split()
    # case-insensitive comparison
    hyp = list(map(lambda x: x.lower(), hyp))
    ref = list(map(lambda x: x.lower(), ref))
    len_hyp = len(hyp)
    len_ref = len(ref)
    # Standard edit-distance DP table plus an operation table used for the
    # backtrace below.  NOTE(review): int16 cost cells overflow for
    # sequences longer than ~32k tokens — confirm inputs stay short.
    cost_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int16)
    ops_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int8)
    for i in range(len_hyp + 1):
        cost_matrix[i][0] = i
    for j in range(len_ref + 1):
        cost_matrix[0][j] = j
    for i in range(1, len_hyp + 1):
        for j in range(1, len_ref + 1):
            if hyp[i - 1] == ref[j - 1]:
                # match: carry cost diagonally, ops stays 0 ("correct")
                cost_matrix[i][j] = cost_matrix[i - 1][j - 1]
            else:
                substitution = cost_matrix[i - 1][j - 1] + 1
                insertion = cost_matrix[i - 1][j] + 1
                deletion = cost_matrix[i][j - 1] + 1
                compare_val = [substitution, insertion, deletion]
                min_val = min(compare_val)
                # 1 = substitute, 2 = insert, 3 = delete (list order + 1)
                operation_idx = compare_val.index(min_val) + 1
                cost_matrix[i][j] = min_val
                ops_matrix[i][j] = operation_idx
    # Backtrace from the bottom-right corner, classifying each step.
    match_idx = []
    i = len_hyp
    j = len_ref
    rst = {
        'nwords': len_ref,
        'cor': 0,
        'wrong': 0,
        'ins': 0,
        'del': 0,
        'sub': 0
    }
    while i >= 0 or j >= 0:
        # clamp to row/column 0 once an index has run off the table
        i_idx = max(0, i)
        j_idx = max(0, j)
        if ops_matrix[i_idx][j_idx] == 0:  # correct
            if i - 1 >= 0 and j - 1 >= 0:
                match_idx.append((j - 1, i - 1))
                rst['cor'] += 1
            i -= 1
            j -= 1
        elif ops_matrix[i_idx][j_idx] == 2:  # insert
            i -= 1
            rst['ins'] += 1
        elif ops_matrix[i_idx][j_idx] == 3:  # delete
            j -= 1
            rst['del'] += 1
        elif ops_matrix[i_idx][j_idx] == 1:  # substitute
            i -= 1
            j -= 1
            rst['sub'] += 1
        # boundary compensation: one side exhausted, the other still has
        # tokens left — count them as deletions/insertions
        if i < 0 and j >= 0:
            rst['del'] += 1
        elif j < 0 and i >= 0:
            rst['ins'] += 1
    match_idx.reverse()
    # 'wrong' is taken from the DP corner (true edit distance), not from
    # the backtrace tallies, so it is exact even at the boundaries
    wrong_cnt = cost_matrix[len_hyp][len_ref]
    rst['wrong'] = wrong_cnt
    return rst
  269. def print_wrong_sentence(key: str, hyp: str, ref: str):
  270. space = len(key)
  271. print(key + yellow_color + ' ref: ' + ref)
  272. print(' ' * space + red_color + ' hyp: ' + hyp + end_color)
  273. def print_correct_sentence(key: str, hyp: str, ref: str):
  274. space = len(key)
  275. print(key + yellow_color + ' ref: ' + ref)
  276. print(' ' * space + green_color + ' hyp: ' + hyp + end_color)
  277. def print_progress(percent):
  278. if percent > 1:
  279. percent = 1
  280. res = int(50 * percent) * '#'
  281. print('\r[%-50s] %d%%' % (res, int(100 * percent)), end='')