asr_utils.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import struct
  4. from typing import Any, Dict, List, Union
  5. import librosa
  6. import numpy as np
  7. import pkg_resources
  8. from modelscope.utils.logger import get_logger
# Shared module-level logger for all helpers in this file.
logger = get_logger()
# ANSI escape sequences used to colorize the WER report printed to the
# terminal by print_wrong_sentence()/print_correct_sentence().
green_color = '\033[1;32m'  # correct hypothesis lines
red_color = '\033[0;31;40m'  # wrong hypothesis lines
yellow_color = '\033[0;33;40m'  # reference lines
end_color = '\033[0m'  # reset terminal attributes
# Default scoring language; overridden by set_parameters() and compute_wer().
global_asr_language = 'zh-cn'
  15. def get_version():
  16. return float(pkg_resources.get_distribution('easyasr').version)
  17. def sample_rate_checking(audio_in: Union[str, bytes], audio_format: str):
  18. r_audio_fs = None
  19. if audio_format == 'wav':
  20. r_audio_fs = get_sr_from_wav(audio_in)
  21. elif audio_format == 'pcm' and isinstance(audio_in, bytes):
  22. r_audio_fs = get_sr_from_bytes(audio_in)
  23. return r_audio_fs
def type_checking(audio_in: Union[str, bytes],
                  audio_fs: int = None,
                  recog_type: str = None,
                  audio_format: str = None):
    """Infer recognition type, audio format and dataset path from the input.

    Args:
        audio_in: a wav file path, a dataset directory, or raw audio bytes.
        audio_fs: accepted for interface compatibility; not used here.
        recog_type: optional explicit type ('wav', 'test', 'dev', 'train').
        audio_format: optional explicit format ('wav', 'pcm', 'kaldi_ark',
            'tfrecord').

    Returns:
        Tuple of (recog_type, audio_format, wav_path).  Either inferred
        value may remain None when nothing matches.
    """
    r_recog_type = recog_type
    r_audio_format = audio_format
    r_wav_path = audio_in

    if isinstance(audio_in, str):
        assert os.path.exists(audio_in), f'wav_path:{audio_in} does not exist'
    elif isinstance(audio_in, bytes):
        # Raw bytes are always treated as a single PCM utterance.
        assert len(audio_in) > 0, 'audio in is empty'
        r_audio_format = 'pcm'
        r_recog_type = 'wav'

    if r_recog_type is None:
        # audio_in is wav, recog_type is wav_file
        if os.path.isfile(audio_in):
            if audio_in.endswith('.wav') or audio_in.endswith('.WAV'):
                r_recog_type = 'wav'
                r_audio_format = 'wav'
        # recog_type is datasets_file
        elif os.path.isdir(audio_in):
            # The dataset split is encoded in the directory name.
            dir_name = os.path.basename(audio_in)
            if 'test' in dir_name:
                r_recog_type = 'test'
            elif 'dev' in dir_name:
                r_recog_type = 'dev'
            elif 'train' in dir_name:
                r_recog_type = 'train'

            if r_audio_format is None:
                # Sniff the on-disk format from file extensions under the
                # directory (find_file_by_ends walks it recursively).
                if find_file_by_ends(audio_in, '.ark'):
                    r_audio_format = 'kaldi_ark'
                elif find_file_by_ends(audio_in, '.wav') or find_file_by_ends(
                        audio_in, '.WAV'):
                    r_audio_format = 'wav'
                elif find_file_by_ends(audio_in, '.records'):
                    r_audio_format = 'tfrecord'

    # For dataset inputs, step up from the split directory to the dataset
    # root (wav layouts are nested one level deeper than ark/tfrecord).
    if r_audio_format == 'kaldi_ark' and r_recog_type != 'wav':
        # datasets with kaldi_ark file
        r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../'))
    elif r_audio_format == 'tfrecord' and r_recog_type != 'wav':
        # datasets with tensorflow records file
        r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../'))
    elif r_audio_format == 'wav' and r_recog_type != 'wav':
        # datasets with waveform files
        r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../../'))

    return r_recog_type, r_audio_format, r_wav_path
  70. def get_sr_from_bytes(wav: bytes):
  71. sr = None
  72. data = wav
  73. if len(data) > 44:
  74. try:
  75. header_fields = {}
  76. header_fields['ChunkID'] = str(data[0:4], 'UTF-8')
  77. header_fields['Format'] = str(data[8:12], 'UTF-8')
  78. header_fields['Subchunk1ID'] = str(data[12:16], 'UTF-8')
  79. if header_fields['ChunkID'] == 'RIFF' and header_fields[
  80. 'Format'] == 'WAVE' and header_fields[
  81. 'Subchunk1ID'] == 'fmt ':
  82. header_fields['SampleRate'] = struct.unpack('<I',
  83. data[24:28])[0]
  84. sr = header_fields['SampleRate']
  85. except Exception:
  86. # no treatment
  87. pass
  88. else:
  89. logger.warn('audio bytes is ' + str(len(data)) + ' is invalid.')
  90. return sr
  91. def get_sr_from_wav(fname: str):
  92. fs = None
  93. if os.path.isfile(fname):
  94. audio, fs = librosa.load(fname, sr=None)
  95. return fs
  96. elif os.path.isdir(fname):
  97. dir_files = os.listdir(fname)
  98. for file in dir_files:
  99. file_path = os.path.join(fname, file)
  100. if os.path.isfile(file_path):
  101. if file_path.endswith('.wav') or file_path.endswith('.WAV'):
  102. fs = get_sr_from_wav(file_path)
  103. elif os.path.isdir(file_path):
  104. fs = get_sr_from_wav(file_path)
  105. if fs is not None:
  106. break
  107. return fs
  108. def find_file_by_ends(dir_path: str, ends: str):
  109. dir_files = os.listdir(dir_path)
  110. for file in dir_files:
  111. file_path = os.path.join(dir_path, file)
  112. if os.path.isfile(file_path):
  113. if file_path.endswith(ends):
  114. return True
  115. elif os.path.isdir(file_path):
  116. if find_file_by_ends(file_path, ends):
  117. return True
  118. return False
  119. def recursion_dir_all_wav(wav_list, dir_path: str) -> List[str]:
  120. dir_files = os.listdir(dir_path)
  121. for file in dir_files:
  122. file_path = os.path.join(dir_path, file)
  123. if os.path.isfile(file_path):
  124. if file_path.endswith('.wav') or file_path.endswith('.WAV'):
  125. wav_list.append(file_path)
  126. elif os.path.isdir(file_path):
  127. recursion_dir_all_wav(wav_list, file_path)
  128. return wav_list
  129. def set_parameters(language: str = None):
  130. if language is not None:
  131. global global_asr_language
  132. global_asr_language = language
  133. def compute_wer(hyp_list: List[Any],
  134. ref_list: List[Any],
  135. lang: str = None) -> Dict[str, Any]:
  136. assert len(hyp_list) > 0, 'hyp list is empty'
  137. assert len(ref_list) > 0, 'ref list is empty'
  138. if lang is not None:
  139. global global_asr_language
  140. global_asr_language = lang
  141. rst = {
  142. 'Wrd': 0,
  143. 'Corr': 0,
  144. 'Ins': 0,
  145. 'Del': 0,
  146. 'Sub': 0,
  147. 'Snt': 0,
  148. 'Err': 0.0,
  149. 'S.Err': 0.0,
  150. 'wrong_words': 0,
  151. 'wrong_sentences': 0
  152. }
  153. for h_item in hyp_list:
  154. for r_item in ref_list:
  155. if h_item['key'] == r_item['key']:
  156. out_item = compute_wer_by_line(h_item['value'],
  157. r_item['value'],
  158. global_asr_language)
  159. rst['Wrd'] += out_item['nwords']
  160. rst['Corr'] += out_item['cor']
  161. rst['wrong_words'] += out_item['wrong']
  162. rst['Ins'] += out_item['ins']
  163. rst['Del'] += out_item['del']
  164. rst['Sub'] += out_item['sub']
  165. rst['Snt'] += 1
  166. if out_item['wrong'] > 0:
  167. rst['wrong_sentences'] += 1
  168. print_wrong_sentence(key=h_item['key'],
  169. hyp=h_item['value'],
  170. ref=r_item['value'])
  171. else:
  172. print_correct_sentence(key=h_item['key'],
  173. hyp=h_item['value'],
  174. ref=r_item['value'])
  175. break
  176. if rst['Wrd'] > 0:
  177. rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2)
  178. if rst['Snt'] > 0:
  179. rst['S.Err'] = round(rst['wrong_sentences'] * 100 / rst['Snt'], 2)
  180. return rst
def compute_wer_by_line(hyp: List[str],
                        ref: List[str],
                        lang: str = 'zh-cn') -> Dict[str, Any]:
    """Align one hypothesis against one reference and count edit operations.

    Args:
        hyp: recognized text.  NOTE(review): despite the annotation this is
            split with str.split() for non-Chinese, which implies a plain
            string input; for 'zh-cn' the string is scored character by
            character — confirm against callers.
        ref: reference text (same caveat as ``hyp``).
        lang: 'zh-cn' scores per character; anything else per
            whitespace-separated word.

    Returns:
        Dict with 'nwords' (reference length), 'cor', 'ins', 'del', 'sub'
        counts and 'wrong' (the total edit distance).
    """
    if lang != 'zh-cn':
        hyp = hyp.split()
        ref = ref.split()
    # Case-insensitive comparison.
    hyp = list(map(lambda x: x.lower(), hyp))
    ref = list(map(lambda x: x.lower(), ref))
    len_hyp = len(hyp)
    len_ref = len(ref)
    # cost_matrix[i][j]: edit distance between hyp[:i] and ref[:j].
    # NOTE(review): int16 overflows for sequences beyond ~32k tokens.
    cost_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int16)
    # ops_matrix records the operation chosen per cell for the backtrace:
    # 0 = match (also the unassigned first row/column), 1 = substitution,
    # 2 = insertion, 3 = deletion.
    ops_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int8)
    for i in range(len_hyp + 1):
        cost_matrix[i][0] = i
    for j in range(len_ref + 1):
        cost_matrix[0][j] = j
    # Standard Levenshtein dynamic program.
    for i in range(1, len_hyp + 1):
        for j in range(1, len_ref + 1):
            if hyp[i - 1] == ref[j - 1]:
                cost_matrix[i][j] = cost_matrix[i - 1][j - 1]
            else:
                substitution = cost_matrix[i - 1][j - 1] + 1
                insertion = cost_matrix[i - 1][j] + 1
                deletion = cost_matrix[i][j - 1] + 1
                compare_val = [substitution, insertion, deletion]
                min_val = min(compare_val)
                # index + 1 maps directly onto the op codes listed above.
                operation_idx = compare_val.index(min_val) + 1
                cost_matrix[i][j] = min_val
                ops_matrix[i][j] = operation_idx
    match_idx = []
    i = len_hyp
    j = len_ref
    rst = {
        'nwords': len_ref,
        'cor': 0,
        'wrong': 0,
        'ins': 0,
        'del': 0,
        'sub': 0
    }
    # Backtrace from the bottom-right corner, tallying operation types.
    # Indices are clamped to 0 so runs along the first row/column are
    # consumed via the trailing i<0 / j<0 adjustments below.
    while i >= 0 or j >= 0:
        i_idx = max(0, i)
        j_idx = max(0, j)
        if ops_matrix[i_idx][j_idx] == 0:  # correct
            if i - 1 >= 0 and j - 1 >= 0:
                match_idx.append((j - 1, i - 1))
                rst['cor'] += 1
            i -= 1
            j -= 1
        elif ops_matrix[i_idx][j_idx] == 2:  # insert
            i -= 1
            rst['ins'] += 1
        elif ops_matrix[i_idx][j_idx] == 3:  # delete
            j -= 1
            rst['del'] += 1
        elif ops_matrix[i_idx][j_idx] == 1:  # substitute
            i -= 1
            j -= 1
            rst['sub'] += 1
        # Once one sequence is exhausted, the remainder of the other is
        # pure deletions (ref left over) or insertions (hyp left over).
        if i < 0 and j >= 0:
            rst['del'] += 1
        elif j < 0 and i >= 0:
            rst['ins'] += 1
    match_idx.reverse()
    # Total edit distance doubles as the "wrong word" count.
    wrong_cnt = cost_matrix[len_hyp][len_ref]
    rst['wrong'] = wrong_cnt
    return rst
  248. def print_wrong_sentence(key: str, hyp: str, ref: str):
  249. space = len(key)
  250. print(key + yellow_color + ' ref: ' + ref)
  251. print(' ' * space + red_color + ' hyp: ' + hyp + end_color)
  252. def print_correct_sentence(key: str, hyp: str, ref: str):
  253. space = len(key)
  254. print(key + yellow_color + ' ref: ' + ref)
  255. print(' ' * space + green_color + ' hyp: ' + hyp + end_color)
  256. def print_progress(percent):
  257. if percent > 1:
  258. percent = 1
  259. res = int(50 * percent) * '#'
  260. print('\r[%-50s] %d%%' % (res, int(100 * percent)), end='')