asr_utils.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import struct
from typing import Any, Dict, List, Union
import torchaudio
import numpy as np
import pkg_resources
from modelscope.utils.logger import get_logger
logger = get_logger()
# ANSI escape sequences used to colorize the ref/hyp printouts below
green_color = '\033[1;32m'
red_color = '\033[0;31;40m'
yellow_color = '\033[0;33;40m'
end_color = '\033[0m'
# default scoring language for compute_wer (character-level for zh-cn)
global_asr_language = 'zh-cn'
# audio file suffixes accepted by the wav-discovery helpers
SUPPORT_AUDIO_TYPE_SETS = ['flac', 'mp3', 'ogg', 'opus', 'wav', 'pcm']
  16. def get_version():
  17. return float(pkg_resources.get_distribution('easyasr').version)
  18. def sample_rate_checking(audio_in: Union[str, bytes], audio_format: str):
  19. r_audio_fs = None
  20. if audio_format == 'wav' or audio_format == 'scp':
  21. r_audio_fs = get_sr_from_wav(audio_in)
  22. elif audio_format == 'pcm' and isinstance(audio_in, bytes):
  23. r_audio_fs = get_sr_from_bytes(audio_in)
  24. return r_audio_fs
  25. def type_checking(audio_in: Union[str, bytes],
  26. audio_fs: int = None,
  27. recog_type: str = None,
  28. audio_format: str = None):
  29. r_recog_type = recog_type
  30. r_audio_format = audio_format
  31. r_wav_path = audio_in
  32. if isinstance(audio_in, str):
  33. assert os.path.exists(audio_in), f'wav_path:{audio_in} does not exist'
  34. elif isinstance(audio_in, bytes):
  35. assert len(audio_in) > 0, 'audio in is empty'
  36. r_audio_format = 'pcm'
  37. r_recog_type = 'wav'
  38. if audio_in is None:
  39. # for raw_inputs
  40. r_recog_type = 'wav'
  41. r_audio_format = 'pcm'
  42. if r_recog_type is None and audio_in is not None:
  43. # audio_in is wav, recog_type is wav_file
  44. if os.path.isfile(audio_in):
  45. audio_type = os.path.basename(audio_in).split(".")[-1].lower()
  46. if audio_type in SUPPORT_AUDIO_TYPE_SETS:
  47. r_recog_type = 'wav'
  48. r_audio_format = 'wav'
  49. elif audio_type == "scp":
  50. r_recog_type = 'wav'
  51. r_audio_format = 'scp'
  52. else:
  53. raise NotImplementedError(
  54. f'Not supported audio type: {audio_type}')
  55. # recog_type is datasets_file
  56. elif os.path.isdir(audio_in):
  57. dir_name = os.path.basename(audio_in)
  58. if 'test' in dir_name:
  59. r_recog_type = 'test'
  60. elif 'dev' in dir_name:
  61. r_recog_type = 'dev'
  62. elif 'train' in dir_name:
  63. r_recog_type = 'train'
  64. if r_audio_format is None:
  65. if find_file_by_ends(audio_in, '.ark'):
  66. r_audio_format = 'kaldi_ark'
  67. elif find_file_by_ends(audio_in, '.wav') or find_file_by_ends(
  68. audio_in, '.WAV'):
  69. r_audio_format = 'wav'
  70. elif find_file_by_ends(audio_in, '.records'):
  71. r_audio_format = 'tfrecord'
  72. if r_audio_format == 'kaldi_ark' and r_recog_type != 'wav':
  73. # datasets with kaldi_ark file
  74. r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../'))
  75. elif r_audio_format == 'tfrecord' and r_recog_type != 'wav':
  76. # datasets with tensorflow records file
  77. r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../'))
  78. elif r_audio_format == 'wav' and r_recog_type != 'wav':
  79. # datasets with waveform files
  80. r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../../'))
  81. return r_recog_type, r_audio_format, r_wav_path
  82. def get_sr_from_bytes(wav: bytes):
  83. sr = None
  84. data = wav
  85. if len(data) > 44:
  86. try:
  87. header_fields = {}
  88. header_fields['ChunkID'] = str(data[0:4], 'UTF-8')
  89. header_fields['Format'] = str(data[8:12], 'UTF-8')
  90. header_fields['Subchunk1ID'] = str(data[12:16], 'UTF-8')
  91. if header_fields['ChunkID'] == 'RIFF' and header_fields[
  92. 'Format'] == 'WAVE' and header_fields[
  93. 'Subchunk1ID'] == 'fmt ':
  94. header_fields['SampleRate'] = struct.unpack('<I',
  95. data[24:28])[0]
  96. sr = header_fields['SampleRate']
  97. except Exception:
  98. # no treatment
  99. pass
  100. else:
  101. logger.warn('audio bytes is ' + str(len(data)) + ' is invalid.')
  102. return sr
  103. def get_sr_from_wav(fname: str):
  104. fs = None
  105. if os.path.isfile(fname):
  106. audio_type = os.path.basename(fname).split(".")[-1].lower()
  107. if audio_type in SUPPORT_AUDIO_TYPE_SETS:
  108. if audio_type == "pcm":
  109. fs = None
  110. else:
  111. audio, fs = torchaudio.load(fname)
  112. elif audio_type == "scp":
  113. with open(fname, encoding="utf-8") as f:
  114. for line in f:
  115. wav_path = line.split()[1]
  116. fs = get_sr_from_wav(wav_path)
  117. if fs is not None:
  118. break
  119. return fs
  120. elif os.path.isdir(fname):
  121. dir_files = os.listdir(fname)
  122. for file in dir_files:
  123. file_path = os.path.join(fname, file)
  124. if os.path.isfile(file_path):
  125. audio_type = os.path.basename(file_path).split(".")[-1].lower()
  126. if audio_type in SUPPORT_AUDIO_TYPE_SETS:
  127. fs = get_sr_from_wav(file_path)
  128. elif os.path.isdir(file_path):
  129. fs = get_sr_from_wav(file_path)
  130. if fs is not None:
  131. break
  132. return fs
  133. def find_file_by_ends(dir_path: str, ends: str):
  134. dir_files = os.listdir(dir_path)
  135. for file in dir_files:
  136. file_path = os.path.join(dir_path, file)
  137. if os.path.isfile(file_path):
  138. if ends == ".wav" or ends == ".WAV":
  139. audio_type = os.path.basename(file_path).split(".")[-1].lower()
  140. if audio_type in SUPPORT_AUDIO_TYPE_SETS:
  141. return True
  142. else:
  143. raise NotImplementedError(
  144. f'Not supported audio type: {audio_type}')
  145. elif file_path.endswith(ends):
  146. return True
  147. elif os.path.isdir(file_path):
  148. if find_file_by_ends(file_path, ends):
  149. return True
  150. return False
  151. def recursion_dir_all_wav(wav_list, dir_path: str) -> List[str]:
  152. dir_files = os.listdir(dir_path)
  153. for file in dir_files:
  154. file_path = os.path.join(dir_path, file)
  155. if os.path.isfile(file_path):
  156. audio_type = os.path.basename(file_path).split(".")[-1].lower()
  157. if audio_type in SUPPORT_AUDIO_TYPE_SETS:
  158. wav_list.append(file_path)
  159. elif os.path.isdir(file_path):
  160. recursion_dir_all_wav(wav_list, file_path)
  161. return wav_list
  162. def compute_wer(hyp_list: List[Any],
  163. ref_list: List[Any],
  164. lang: str = None) -> Dict[str, Any]:
  165. assert len(hyp_list) > 0, 'hyp list is empty'
  166. assert len(ref_list) > 0, 'ref list is empty'
  167. rst = {
  168. 'Wrd': 0,
  169. 'Corr': 0,
  170. 'Ins': 0,
  171. 'Del': 0,
  172. 'Sub': 0,
  173. 'Snt': 0,
  174. 'Err': 0.0,
  175. 'S.Err': 0.0,
  176. 'wrong_words': 0,
  177. 'wrong_sentences': 0
  178. }
  179. if lang is None:
  180. lang = global_asr_language
  181. for h_item in hyp_list:
  182. for r_item in ref_list:
  183. if h_item['key'] == r_item['key']:
  184. out_item = compute_wer_by_line(h_item['value'],
  185. r_item['value'],
  186. lang)
  187. rst['Wrd'] += out_item['nwords']
  188. rst['Corr'] += out_item['cor']
  189. rst['wrong_words'] += out_item['wrong']
  190. rst['Ins'] += out_item['ins']
  191. rst['Del'] += out_item['del']
  192. rst['Sub'] += out_item['sub']
  193. rst['Snt'] += 1
  194. if out_item['wrong'] > 0:
  195. rst['wrong_sentences'] += 1
  196. print_wrong_sentence(key=h_item['key'],
  197. hyp=h_item['value'],
  198. ref=r_item['value'])
  199. else:
  200. print_correct_sentence(key=h_item['key'],
  201. hyp=h_item['value'],
  202. ref=r_item['value'])
  203. break
  204. if rst['Wrd'] > 0:
  205. rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2)
  206. if rst['Snt'] > 0:
  207. rst['S.Err'] = round(rst['wrong_sentences'] * 100 / rst['Snt'], 2)
  208. return rst
def compute_wer_by_line(hyp: List[str],
                        ref: List[str],
                        lang: str = 'zh-cn') -> Dict[str, Any]:
    """Compute Levenshtein-based error counts for one hyp/ref pair.

    NOTE(review): despite the List[str] annotations, for lang != 'zh-cn'
    hyp/ref must be whitespace-separated strings (``.split()`` is called
    on them); for 'zh-cn' any character sequence works — confirm with
    callers.

    Returns a dict with nwords (= len(ref)), cor, wrong, ins, del, sub.
    """
    if lang != 'zh-cn':
        # word-level scoring: tokenize on whitespace
        hyp = hyp.split()
        ref = ref.split()
    # comparison is case-insensitive
    hyp = list(map(lambda x: x.lower(), hyp))
    ref = list(map(lambda x: x.lower(), ref))
    len_hyp = len(hyp)
    len_ref = len(ref)
    # cost_matrix[i][j]: edit distance between hyp[:i] and ref[:j]
    # NOTE(review): int16 overflows for sequences longer than ~32k tokens
    cost_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int16)
    # ops_matrix records the op taken: 0 match, 1 sub, 2 ins, 3 del
    ops_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int8)
    for i in range(len_hyp + 1):
        cost_matrix[i][0] = i
    for j in range(len_ref + 1):
        cost_matrix[0][j] = j
    for i in range(1, len_hyp + 1):
        for j in range(1, len_ref + 1):
            if hyp[i - 1] == ref[j - 1]:
                cost_matrix[i][j] = cost_matrix[i - 1][j - 1]
            else:
                substitution = cost_matrix[i - 1][j - 1] + 1
                insertion = cost_matrix[i - 1][j] + 1
                deletion = cost_matrix[i][j - 1] + 1
                compare_val = [substitution, insertion, deletion]
                min_val = min(compare_val)
                # index()+1 keeps op code 0 reserved for "match"
                operation_idx = compare_val.index(min_val) + 1
                cost_matrix[i][j] = min_val
                ops_matrix[i][j] = operation_idx
    match_idx = []
    i = len_hyp
    j = len_ref
    rst = {
        'nwords': len_ref,
        'cor': 0,
        'wrong': 0,
        'ins': 0,
        'del': 0,
        'sub': 0
    }
    # backtrace from the bottom-right corner, tallying each operation
    while i >= 0 or j >= 0:
        i_idx = max(0, i)
        j_idx = max(0, j)
        if ops_matrix[i_idx][j_idx] == 0:  # correct
            if i - 1 >= 0 and j - 1 >= 0:
                match_idx.append((j - 1, i - 1))
                rst['cor'] += 1
            i -= 1
            j -= 1
        elif ops_matrix[i_idx][j_idx] == 2:  # insert
            i -= 1
            rst['ins'] += 1
        elif ops_matrix[i_idx][j_idx] == 3:  # delete
            j -= 1
            rst['del'] += 1
        elif ops_matrix[i_idx][j_idx] == 1:  # substitute
            i -= 1
            j -= 1
            rst['sub'] += 1
        # once one index runs off the matrix, the remainder of the other
        # sequence is all deletions (ref left over) or insertions (hyp left)
        if i < 0 and j >= 0:
            rst['del'] += 1
        elif j < 0 and i >= 0:
            rst['ins'] += 1
    match_idx.reverse()
    wrong_cnt = cost_matrix[len_hyp][len_ref]
    rst['wrong'] = wrong_cnt
    return rst
  276. def print_wrong_sentence(key: str, hyp: str, ref: str):
  277. space = len(key)
  278. print(key + yellow_color + ' ref: ' + ref)
  279. print(' ' * space + red_color + ' hyp: ' + hyp + end_color)
  280. def print_correct_sentence(key: str, hyp: str, ref: str):
  281. space = len(key)
  282. print(key + yellow_color + ' ref: ' + ref)
  283. print(' ' * space + green_color + ' hyp: ' + hyp + end_color)
  284. def print_progress(percent):
  285. if percent > 1:
  286. percent = 1
  287. res = int(50 * percent) * '#'
  288. print('\r[%-50s] %d%%' % (res, int(100 * percent)), end='')