asr_utils.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import struct
  4. from typing import Any, Dict, List, Union
  5. import torchaudio
  6. import librosa
  7. import numpy as np
  8. import pkg_resources
  9. from modelscope.utils.logger import get_logger
# Module-wide logger shared by every helper in this file.
logger = get_logger()

# ANSI escape sequences used to colorize the WER report printed by
# print_wrong_sentence / print_correct_sentence.
green_color = '\033[1;32m'
red_color = '\033[0;31;40m'
yellow_color = '\033[0;33;40m'
end_color = '\033[0m'

# Default language for WER computation; for 'zh-cn' the hyp/ref strings are
# NOT whitespace-split (see compute_wer_by_line), i.e. char-level scoring.
global_asr_language = 'zh-cn'

# Audio file extensions accepted by the type/format detection helpers.
SUPPORT_AUDIO_TYPE_SETS = ['flac', 'mp3', 'ogg', 'opus', 'wav', 'pcm']
  17. def get_version():
  18. return float(pkg_resources.get_distribution('easyasr').version)
  19. def sample_rate_checking(audio_in: Union[str, bytes], audio_format: str):
  20. r_audio_fs = None
  21. if audio_format == 'wav' or audio_format == 'scp':
  22. r_audio_fs = get_sr_from_wav(audio_in)
  23. elif audio_format == 'pcm' and isinstance(audio_in, bytes):
  24. r_audio_fs = get_sr_from_bytes(audio_in)
  25. return r_audio_fs
def type_checking(audio_in: Union[str, bytes],
                  audio_fs: int = None,
                  recog_type: str = None,
                  audio_format: str = None):
    """Infer ``(recog_type, audio_format, wav_path)`` for an ASR input.

    Args:
        audio_in: an audio/scp file path, a dataset directory path, raw
            PCM ``bytes``, or ``None`` (raw inputs handled by the caller).
        audio_fs: sample-rate hint; not used by this function.
        recog_type: explicit recognition type ('wav'/'test'/'dev'/'train');
            inferred from ``audio_in`` when ``None``.
        audio_format: explicit format ('wav'/'scp'/'pcm'/'kaldi_ark'/
            'tfrecord'); inferred when ``None``.

    Returns:
        Tuple ``(r_recog_type, r_audio_format, r_wav_path)``, where the
        path may be rewritten to a dataset root for directory inputs.

    Raises:
        NotImplementedError: for a file whose extension is not in
            SUPPORT_AUDIO_TYPE_SETS (and is not '.scp').
    """
    r_recog_type = recog_type
    r_audio_format = audio_format
    r_wav_path = audio_in
    if isinstance(audio_in, str):
        assert os.path.exists(audio_in), f'wav_path:{audio_in} does not exist'
    elif isinstance(audio_in, bytes):
        assert len(audio_in) > 0, 'audio in is empty'
        # In-memory audio is always treated as headerless PCM.
        r_audio_format = 'pcm'
        r_recog_type = 'wav'
    if audio_in is None:
        # for raw_inputs
        r_recog_type = 'wav'
        r_audio_format = 'pcm'
    if r_recog_type is None and audio_in is not None:
        # audio_in is wav, recog_type is wav_file
        if os.path.isfile(audio_in):
            audio_type = os.path.basename(audio_in).lower()
            # substring match on the lowered basename, not a strict suffix
            for support_audio_type in SUPPORT_AUDIO_TYPE_SETS:
                if audio_type.rfind(".{}".format(support_audio_type)) >= 0:
                    r_recog_type = 'wav'
                    r_audio_format = 'wav'
            if audio_type.rfind(".scp") >= 0:
                r_recog_type = 'wav'
                r_audio_format = 'scp'
            if r_recog_type is None:
                raise NotImplementedError(
                    f'Not supported audio type: {audio_type}')
        # recog_type is datasets_file
        elif os.path.isdir(audio_in):
            # dataset split is inferred from the directory name
            dir_name = os.path.basename(audio_in)
            if 'test' in dir_name:
                r_recog_type = 'test'
            elif 'dev' in dir_name:
                r_recog_type = 'dev'
            elif 'train' in dir_name:
                r_recog_type = 'train'
            if r_audio_format is None:
                # format is inferred from the first matching file found
                if find_file_by_ends(audio_in, '.ark'):
                    r_audio_format = 'kaldi_ark'
                elif find_file_by_ends(audio_in, '.wav') or find_file_by_ends(
                        audio_in, '.WAV'):
                    r_audio_format = 'wav'
                elif find_file_by_ends(audio_in, '.records'):
                    r_audio_format = 'tfrecord'
    # For dataset inputs, walk up to the dataset root directory; the depth
    # presumably mirrors the expected dataset layout — TODO confirm.
    if r_audio_format == 'kaldi_ark' and r_recog_type != 'wav':
        # datasets with kaldi_ark file
        r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../'))
    elif r_audio_format == 'tfrecord' and r_recog_type != 'wav':
        # datasets with tensorflow records file
        r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../'))
    elif r_audio_format == 'wav' and r_recog_type != 'wav':
        # datasets with waveform files
        r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../../'))
    return r_recog_type, r_audio_format, r_wav_path
  84. def get_sr_from_bytes(wav: bytes):
  85. sr = None
  86. data = wav
  87. if len(data) > 44:
  88. try:
  89. header_fields = {}
  90. header_fields['ChunkID'] = str(data[0:4], 'UTF-8')
  91. header_fields['Format'] = str(data[8:12], 'UTF-8')
  92. header_fields['Subchunk1ID'] = str(data[12:16], 'UTF-8')
  93. if header_fields['ChunkID'] == 'RIFF' and header_fields[
  94. 'Format'] == 'WAVE' and header_fields[
  95. 'Subchunk1ID'] == 'fmt ':
  96. header_fields['SampleRate'] = struct.unpack('<I',
  97. data[24:28])[0]
  98. sr = header_fields['SampleRate']
  99. except Exception:
  100. # no treatment
  101. pass
  102. else:
  103. logger.warn('audio bytes is ' + str(len(data)) + ' is invalid.')
  104. return sr
  105. def get_sr_from_wav(fname: str):
  106. fs = None
  107. if os.path.isfile(fname):
  108. audio_type = os.path.basename(fname).lower()
  109. for support_audio_type in SUPPORT_AUDIO_TYPE_SETS:
  110. if audio_type.rfind(".{}".format(support_audio_type)) >= 0:
  111. if support_audio_type == "pcm":
  112. fs = None
  113. else:
  114. try:
  115. audio, fs = torchaudio.load(fname)
  116. except:
  117. audio, fs = librosa.load(fname)
  118. break
  119. if audio_type.rfind(".scp") >= 0:
  120. with open(fname, encoding="utf-8") as f:
  121. for line in f:
  122. wav_path = line.split()[1]
  123. fs = get_sr_from_wav(wav_path)
  124. if fs is not None:
  125. break
  126. return fs
  127. elif os.path.isdir(fname):
  128. dir_files = os.listdir(fname)
  129. for file in dir_files:
  130. file_path = os.path.join(fname, file)
  131. if os.path.isfile(file_path):
  132. fs = get_sr_from_wav(file_path)
  133. elif os.path.isdir(file_path):
  134. fs = get_sr_from_wav(file_path)
  135. if fs is not None:
  136. break
  137. return fs
  138. def find_file_by_ends(dir_path: str, ends: str):
  139. dir_files = os.listdir(dir_path)
  140. for file in dir_files:
  141. file_path = os.path.join(dir_path, file)
  142. if os.path.isfile(file_path):
  143. if ends == ".wav" or ends == ".WAV":
  144. audio_type = os.path.basename(file_path).lower()
  145. for support_audio_type in SUPPORT_AUDIO_TYPE_SETS:
  146. if audio_type.rfind(".{}".format(support_audio_type)) >= 0:
  147. return True
  148. raise NotImplementedError(
  149. f'Not supported audio type: {audio_type}')
  150. elif file_path.endswith(ends):
  151. return True
  152. elif os.path.isdir(file_path):
  153. if find_file_by_ends(file_path, ends):
  154. return True
  155. return False
  156. def recursion_dir_all_wav(wav_list, dir_path: str) -> List[str]:
  157. dir_files = os.listdir(dir_path)
  158. for file in dir_files:
  159. file_path = os.path.join(dir_path, file)
  160. if os.path.isfile(file_path):
  161. audio_type = os.path.basename(file_path).lower()
  162. for support_audio_type in SUPPORT_AUDIO_TYPE_SETS:
  163. if audio_type.rfind(".{}".format(support_audio_type)) >= 0:
  164. wav_list.append(file_path)
  165. elif os.path.isdir(file_path):
  166. recursion_dir_all_wav(wav_list, file_path)
  167. return wav_list
  168. def compute_wer(hyp_list: List[Any],
  169. ref_list: List[Any],
  170. lang: str = None) -> Dict[str, Any]:
  171. assert len(hyp_list) > 0, 'hyp list is empty'
  172. assert len(ref_list) > 0, 'ref list is empty'
  173. rst = {
  174. 'Wrd': 0,
  175. 'Corr': 0,
  176. 'Ins': 0,
  177. 'Del': 0,
  178. 'Sub': 0,
  179. 'Snt': 0,
  180. 'Err': 0.0,
  181. 'S.Err': 0.0,
  182. 'wrong_words': 0,
  183. 'wrong_sentences': 0
  184. }
  185. if lang is None:
  186. lang = global_asr_language
  187. for h_item in hyp_list:
  188. for r_item in ref_list:
  189. if h_item['key'] == r_item['key']:
  190. out_item = compute_wer_by_line(h_item['value'],
  191. r_item['value'],
  192. lang)
  193. rst['Wrd'] += out_item['nwords']
  194. rst['Corr'] += out_item['cor']
  195. rst['wrong_words'] += out_item['wrong']
  196. rst['Ins'] += out_item['ins']
  197. rst['Del'] += out_item['del']
  198. rst['Sub'] += out_item['sub']
  199. rst['Snt'] += 1
  200. if out_item['wrong'] > 0:
  201. rst['wrong_sentences'] += 1
  202. print_wrong_sentence(key=h_item['key'],
  203. hyp=h_item['value'],
  204. ref=r_item['value'])
  205. else:
  206. print_correct_sentence(key=h_item['key'],
  207. hyp=h_item['value'],
  208. ref=r_item['value'])
  209. break
  210. if rst['Wrd'] > 0:
  211. rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2)
  212. if rst['Snt'] > 0:
  213. rst['S.Err'] = round(rst['wrong_sentences'] * 100 / rst['Snt'], 2)
  214. return rst
def compute_wer_by_line(hyp: List[str],
                        ref: List[str],
                        lang: str = 'zh-cn') -> Dict[str, Any]:
    """Compute per-sentence edit-distance statistics between hyp and ref.

    NOTE(review): despite the ``List[str]`` annotations, the non-'zh-cn'
    path calls ``.split()``, so callers apparently pass plain strings —
    whitespace-tokenized for other languages, scored per character for
    'zh-cn'. TODO confirm against callers.

    Args:
        hyp: hypothesis tokens (or string, see note above).
        ref: reference tokens (or string).
        lang: 'zh-cn' skips whitespace splitting.

    Returns:
        Dict with 'nwords' (= len(ref)), 'cor', 'wrong' (total edit
        distance), 'ins', 'del', 'sub'.
    """
    if lang != 'zh-cn':
        hyp = hyp.split()
        ref = ref.split()
    # case-insensitive comparison
    hyp = list(map(lambda x: x.lower(), hyp))
    ref = list(map(lambda x: x.lower(), ref))
    len_hyp = len(hyp)
    len_ref = len(ref)
    # Levenshtein DP table; int16/int8 limit lengths to ~32k tokens.
    cost_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int16)
    # ops codes: 0 = match (also the untouched default), 1 = substitution,
    # 2 = insertion, 3 = deletion (order of compare_val below).
    ops_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int8)
    for i in range(len_hyp + 1):
        cost_matrix[i][0] = i
    for j in range(len_ref + 1):
        cost_matrix[0][j] = j
    for i in range(1, len_hyp + 1):
        for j in range(1, len_ref + 1):
            if hyp[i - 1] == ref[j - 1]:
                cost_matrix[i][j] = cost_matrix[i - 1][j - 1]
            else:
                substitution = cost_matrix[i - 1][j - 1] + 1
                insertion = cost_matrix[i - 1][j] + 1
                deletion = cost_matrix[i][j - 1] + 1
                compare_val = [substitution, insertion, deletion]
                min_val = min(compare_val)
                # +1 so that code 0 stays reserved for "match"
                operation_idx = compare_val.index(min_val) + 1
                cost_matrix[i][j] = min_val
                ops_matrix[i][j] = operation_idx
    # Backtrace from the bottom-right corner to classify each edit.
    match_idx = []
    i = len_hyp
    j = len_ref
    rst = {
        'nwords': len_ref,
        'cor': 0,
        'wrong': 0,
        'ins': 0,
        'del': 0,
        'sub': 0
    }
    while i >= 0 or j >= 0:
        # clamp indices so the trace can run past row/column 0
        i_idx = max(0, i)
        j_idx = max(0, j)
        if ops_matrix[i_idx][j_idx] == 0:  # correct
            if i - 1 >= 0 and j - 1 >= 0:
                match_idx.append((j - 1, i - 1))
                rst['cor'] += 1
            i -= 1
            j -= 1
        elif ops_matrix[i_idx][j_idx] == 2:  # insert
            i -= 1
            rst['ins'] += 1
        elif ops_matrix[i_idx][j_idx] == 3:  # delete
            j -= 1
            rst['del'] += 1
        elif ops_matrix[i_idx][j_idx] == 1:  # substitute
            i -= 1
            j -= 1
            rst['sub'] += 1
        # account for leftover tokens once one side is exhausted
        if i < 0 and j >= 0:
            rst['del'] += 1
        elif j < 0 and i >= 0:
            rst['ins'] += 1
    match_idx.reverse()
    # 'wrong' is the total edit distance, taken from the DP corner rather
    # than from the backtrace tallies.
    wrong_cnt = cost_matrix[len_hyp][len_ref]
    rst['wrong'] = wrong_cnt
    return rst
  282. def print_wrong_sentence(key: str, hyp: str, ref: str):
  283. space = len(key)
  284. print(key + yellow_color + ' ref: ' + ref)
  285. print(' ' * space + red_color + ' hyp: ' + hyp + end_color)
  286. def print_correct_sentence(key: str, hyp: str, ref: str):
  287. space = len(key)
  288. print(key + yellow_color + ' ref: ' + ref)
  289. print(' ' * space + green_color + ' hyp: ' + hyp + end_color)
  290. def print_progress(percent):
  291. if percent > 1:
  292. percent = 1
  293. res = int(50 * percent) * '#'
  294. print('\r[%-50s] %d%%' % (res, int(100 * percent)), end='')