# asr_utils.py
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import struct
  4. from typing import Any, Dict, List, Union
  5. import torchaudio
  6. import numpy as np
  7. import pkg_resources
  8. from modelscope.utils.logger import get_logger
  9. logger = get_logger()
  10. green_color = '\033[1;32m'
  11. red_color = '\033[0;31;40m'
  12. yellow_color = '\033[0;33;40m'
  13. end_color = '\033[0m'
  14. global_asr_language = 'zh-cn'
  15. SUPPORT_AUDIO_TYPE_SETS = ['flac', 'mp3', 'ogg', 'opus', 'wav', 'pcm']
  16. def get_version():
  17. return float(pkg_resources.get_distribution('easyasr').version)
  18. def sample_rate_checking(audio_in: Union[str, bytes], audio_format: str):
  19. r_audio_fs = None
  20. if audio_format == 'wav' or audio_format == 'scp':
  21. r_audio_fs = get_sr_from_wav(audio_in)
  22. elif audio_format == 'pcm' and isinstance(audio_in, bytes):
  23. r_audio_fs = get_sr_from_bytes(audio_in)
  24. return r_audio_fs
  25. def type_checking(audio_in: Union[str, bytes],
  26. audio_fs: int = None,
  27. recog_type: str = None,
  28. audio_format: str = None):
  29. r_recog_type = recog_type
  30. r_audio_format = audio_format
  31. r_wav_path = audio_in
  32. if isinstance(audio_in, str):
  33. assert os.path.exists(audio_in), f'wav_path:{audio_in} does not exist'
  34. elif isinstance(audio_in, bytes):
  35. assert len(audio_in) > 0, 'audio in is empty'
  36. r_audio_format = 'pcm'
  37. r_recog_type = 'wav'
  38. if audio_in is None:
  39. # for raw_inputs
  40. r_recog_type = 'wav'
  41. r_audio_format = 'pcm'
  42. if r_recog_type is None and audio_in is not None:
  43. # audio_in is wav, recog_type is wav_file
  44. if os.path.isfile(audio_in):
  45. audio_type = os.path.basename(audio_in).lower()
  46. for support_audio_type in SUPPORT_AUDIO_TYPE_SETS:
  47. if audio_type.rfind(".{}".format(support_audio_type)) >= 0:
  48. r_recog_type = 'wav'
  49. r_audio_format = 'wav'
  50. if audio_type.rfind(".scp") >= 0:
  51. r_recog_type = 'wav'
  52. r_audio_format = 'scp'
  53. if r_recog_type is None:
  54. raise NotImplementedError(
  55. f'Not supported audio type: {audio_type}')
  56. # recog_type is datasets_file
  57. elif os.path.isdir(audio_in):
  58. dir_name = os.path.basename(audio_in)
  59. if 'test' in dir_name:
  60. r_recog_type = 'test'
  61. elif 'dev' in dir_name:
  62. r_recog_type = 'dev'
  63. elif 'train' in dir_name:
  64. r_recog_type = 'train'
  65. if r_audio_format is None:
  66. if find_file_by_ends(audio_in, '.ark'):
  67. r_audio_format = 'kaldi_ark'
  68. elif find_file_by_ends(audio_in, '.wav') or find_file_by_ends(
  69. audio_in, '.WAV'):
  70. r_audio_format = 'wav'
  71. elif find_file_by_ends(audio_in, '.records'):
  72. r_audio_format = 'tfrecord'
  73. if r_audio_format == 'kaldi_ark' and r_recog_type != 'wav':
  74. # datasets with kaldi_ark file
  75. r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../'))
  76. elif r_audio_format == 'tfrecord' and r_recog_type != 'wav':
  77. # datasets with tensorflow records file
  78. r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../'))
  79. elif r_audio_format == 'wav' and r_recog_type != 'wav':
  80. # datasets with waveform files
  81. r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../../'))
  82. return r_recog_type, r_audio_format, r_wav_path
  83. def get_sr_from_bytes(wav: bytes):
  84. sr = None
  85. data = wav
  86. if len(data) > 44:
  87. try:
  88. header_fields = {}
  89. header_fields['ChunkID'] = str(data[0:4], 'UTF-8')
  90. header_fields['Format'] = str(data[8:12], 'UTF-8')
  91. header_fields['Subchunk1ID'] = str(data[12:16], 'UTF-8')
  92. if header_fields['ChunkID'] == 'RIFF' and header_fields[
  93. 'Format'] == 'WAVE' and header_fields[
  94. 'Subchunk1ID'] == 'fmt ':
  95. header_fields['SampleRate'] = struct.unpack('<I',
  96. data[24:28])[0]
  97. sr = header_fields['SampleRate']
  98. except Exception:
  99. # no treatment
  100. pass
  101. else:
  102. logger.warn('audio bytes is ' + str(len(data)) + ' is invalid.')
  103. return sr
  104. def get_sr_from_wav(fname: str):
  105. fs = None
  106. if os.path.isfile(fname):
  107. audio_type = os.path.basename(fname).lower()
  108. for support_audio_type in SUPPORT_AUDIO_TYPE_SETS:
  109. if audio_type.rfind(".{}".format(support_audio_type)) >= 0:
  110. if support_audio_type == "pcm":
  111. fs = None
  112. else:
  113. audio, fs = torchaudio.load(fname)
  114. break
  115. if audio_type.rfind(".scp") >= 0:
  116. with open(fname, encoding="utf-8") as f:
  117. for line in f:
  118. wav_path = line.split()[1]
  119. fs = get_sr_from_wav(wav_path)
  120. if fs is not None:
  121. break
  122. return fs
  123. elif os.path.isdir(fname):
  124. dir_files = os.listdir(fname)
  125. for file in dir_files:
  126. file_path = os.path.join(fname, file)
  127. if os.path.isfile(file_path):
  128. fs = get_sr_from_wav(file_path)
  129. elif os.path.isdir(file_path):
  130. fs = get_sr_from_wav(file_path)
  131. if fs is not None:
  132. break
  133. return fs
  134. def find_file_by_ends(dir_path: str, ends: str):
  135. dir_files = os.listdir(dir_path)
  136. for file in dir_files:
  137. file_path = os.path.join(dir_path, file)
  138. if os.path.isfile(file_path):
  139. if ends == ".wav" or ends == ".WAV":
  140. audio_type = os.path.basename(file_path).lower()
  141. for support_audio_type in SUPPORT_AUDIO_TYPE_SETS:
  142. if audio_type.rfind(".{}".format(support_audio_type)) >= 0:
  143. return True
  144. raise NotImplementedError(
  145. f'Not supported audio type: {audio_type}')
  146. elif file_path.endswith(ends):
  147. return True
  148. elif os.path.isdir(file_path):
  149. if find_file_by_ends(file_path, ends):
  150. return True
  151. return False
  152. def recursion_dir_all_wav(wav_list, dir_path: str) -> List[str]:
  153. dir_files = os.listdir(dir_path)
  154. for file in dir_files:
  155. file_path = os.path.join(dir_path, file)
  156. if os.path.isfile(file_path):
  157. audio_type = os.path.basename(file_path).lower()
  158. for support_audio_type in SUPPORT_AUDIO_TYPE_SETS:
  159. if audio_type.rfind(".{}".format(support_audio_type)) >= 0:
  160. wav_list.append(file_path)
  161. elif os.path.isdir(file_path):
  162. recursion_dir_all_wav(wav_list, file_path)
  163. return wav_list
  164. def compute_wer(hyp_list: List[Any],
  165. ref_list: List[Any],
  166. lang: str = None) -> Dict[str, Any]:
  167. assert len(hyp_list) > 0, 'hyp list is empty'
  168. assert len(ref_list) > 0, 'ref list is empty'
  169. rst = {
  170. 'Wrd': 0,
  171. 'Corr': 0,
  172. 'Ins': 0,
  173. 'Del': 0,
  174. 'Sub': 0,
  175. 'Snt': 0,
  176. 'Err': 0.0,
  177. 'S.Err': 0.0,
  178. 'wrong_words': 0,
  179. 'wrong_sentences': 0
  180. }
  181. if lang is None:
  182. lang = global_asr_language
  183. for h_item in hyp_list:
  184. for r_item in ref_list:
  185. if h_item['key'] == r_item['key']:
  186. out_item = compute_wer_by_line(h_item['value'],
  187. r_item['value'],
  188. lang)
  189. rst['Wrd'] += out_item['nwords']
  190. rst['Corr'] += out_item['cor']
  191. rst['wrong_words'] += out_item['wrong']
  192. rst['Ins'] += out_item['ins']
  193. rst['Del'] += out_item['del']
  194. rst['Sub'] += out_item['sub']
  195. rst['Snt'] += 1
  196. if out_item['wrong'] > 0:
  197. rst['wrong_sentences'] += 1
  198. print_wrong_sentence(key=h_item['key'],
  199. hyp=h_item['value'],
  200. ref=r_item['value'])
  201. else:
  202. print_correct_sentence(key=h_item['key'],
  203. hyp=h_item['value'],
  204. ref=r_item['value'])
  205. break
  206. if rst['Wrd'] > 0:
  207. rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2)
  208. if rst['Snt'] > 0:
  209. rst['S.Err'] = round(rst['wrong_sentences'] * 100 / rst['Snt'], 2)
  210. return rst
  211. def compute_wer_by_line(hyp: List[str],
  212. ref: List[str],
  213. lang: str = 'zh-cn') -> Dict[str, Any]:
  214. if lang != 'zh-cn':
  215. hyp = hyp.split()
  216. ref = ref.split()
  217. hyp = list(map(lambda x: x.lower(), hyp))
  218. ref = list(map(lambda x: x.lower(), ref))
  219. len_hyp = len(hyp)
  220. len_ref = len(ref)
  221. cost_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int16)
  222. ops_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int8)
  223. for i in range(len_hyp + 1):
  224. cost_matrix[i][0] = i
  225. for j in range(len_ref + 1):
  226. cost_matrix[0][j] = j
  227. for i in range(1, len_hyp + 1):
  228. for j in range(1, len_ref + 1):
  229. if hyp[i - 1] == ref[j - 1]:
  230. cost_matrix[i][j] = cost_matrix[i - 1][j - 1]
  231. else:
  232. substitution = cost_matrix[i - 1][j - 1] + 1
  233. insertion = cost_matrix[i - 1][j] + 1
  234. deletion = cost_matrix[i][j - 1] + 1
  235. compare_val = [substitution, insertion, deletion]
  236. min_val = min(compare_val)
  237. operation_idx = compare_val.index(min_val) + 1
  238. cost_matrix[i][j] = min_val
  239. ops_matrix[i][j] = operation_idx
  240. match_idx = []
  241. i = len_hyp
  242. j = len_ref
  243. rst = {
  244. 'nwords': len_ref,
  245. 'cor': 0,
  246. 'wrong': 0,
  247. 'ins': 0,
  248. 'del': 0,
  249. 'sub': 0
  250. }
  251. while i >= 0 or j >= 0:
  252. i_idx = max(0, i)
  253. j_idx = max(0, j)
  254. if ops_matrix[i_idx][j_idx] == 0: # correct
  255. if i - 1 >= 0 and j - 1 >= 0:
  256. match_idx.append((j - 1, i - 1))
  257. rst['cor'] += 1
  258. i -= 1
  259. j -= 1
  260. elif ops_matrix[i_idx][j_idx] == 2: # insert
  261. i -= 1
  262. rst['ins'] += 1
  263. elif ops_matrix[i_idx][j_idx] == 3: # delete
  264. j -= 1
  265. rst['del'] += 1
  266. elif ops_matrix[i_idx][j_idx] == 1: # substitute
  267. i -= 1
  268. j -= 1
  269. rst['sub'] += 1
  270. if i < 0 and j >= 0:
  271. rst['del'] += 1
  272. elif j < 0 and i >= 0:
  273. rst['ins'] += 1
  274. match_idx.reverse()
  275. wrong_cnt = cost_matrix[len_hyp][len_ref]
  276. rst['wrong'] = wrong_cnt
  277. return rst
  278. def print_wrong_sentence(key: str, hyp: str, ref: str):
  279. space = len(key)
  280. print(key + yellow_color + ' ref: ' + ref)
  281. print(' ' * space + red_color + ' hyp: ' + hyp + end_color)
  282. def print_correct_sentence(key: str, hyp: str, ref: str):
  283. space = len(key)
  284. print(key + yellow_color + ' ref: ' + ref)
  285. print(' ' * space + green_color + ' hyp: ' + hyp + end_color)
  286. def print_progress(percent):
  287. if percent > 1:
  288. percent = 1
  289. res = int(50 * percent) * '#'
  290. print('\r[%-50s] %d%%' % (res, int(100 * percent)), end='')