postprocess_utils.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293
  1. # -*- encoding: utf-8 -*-
  2. # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
  3. # MIT License (https://opensource.org/licenses/MIT)
  4. import string
  5. import logging
  6. from typing import Any, List, Union
  7. def isChinese(ch: str):
  8. if '\u4e00' <= ch <= '\u9fff' or '\u0030' <= ch <= '\u0039':
  9. return True
  10. return False
  11. def isAllChinese(word: Union[List[Any], str]):
  12. word_lists = []
  13. for i in word:
  14. cur = i.replace(' ', '')
  15. cur = cur.replace('</s>', '')
  16. cur = cur.replace('<s>', '')
  17. word_lists.append(cur)
  18. if len(word_lists) == 0:
  19. return False
  20. for ch in word_lists:
  21. if isChinese(ch) is False:
  22. return False
  23. return True
  24. def isAllAlpha(word: Union[List[Any], str]):
  25. word_lists = []
  26. for i in word:
  27. cur = i.replace(' ', '')
  28. cur = cur.replace('</s>', '')
  29. cur = cur.replace('<s>', '')
  30. word_lists.append(cur)
  31. if len(word_lists) == 0:
  32. return False
  33. for ch in word_lists:
  34. if ch.isalpha() is False and ch != "'":
  35. return False
  36. elif ch.isalpha() is True and isChinese(ch) is True:
  37. return False
  38. return True
  39. # def abbr_dispose(words: List[Any]) -> List[Any]:
  40. def abbr_dispose(words: List[Any], time_stamp: List[List] = None) -> List[Any]:
  41. words_size = len(words)
  42. word_lists = []
  43. abbr_begin = []
  44. abbr_end = []
  45. last_num = -1
  46. ts_lists = []
  47. ts_nums = []
  48. ts_index = 0
  49. for num in range(words_size):
  50. if num <= last_num:
  51. continue
  52. if len(words[num]) == 1 and words[num].encode('utf-8').isalpha():
  53. if num + 1 < words_size and words[
  54. num + 1] == ' ' and num + 2 < words_size and len(
  55. words[num +
  56. 2]) == 1 and words[num +
  57. 2].encode('utf-8').isalpha():
  58. # found the begin of abbr
  59. abbr_begin.append(num)
  60. num += 2
  61. abbr_end.append(num)
  62. # to find the end of abbr
  63. while True:
  64. num += 1
  65. if num < words_size and words[num] == ' ':
  66. num += 1
  67. if num < words_size and len(
  68. words[num]) == 1 and words[num].encode(
  69. 'utf-8').isalpha():
  70. abbr_end.pop()
  71. abbr_end.append(num)
  72. last_num = num
  73. else:
  74. break
  75. else:
  76. break
  77. for num in range(words_size):
  78. if words[num] == ' ':
  79. ts_nums.append(ts_index)
  80. else:
  81. ts_nums.append(ts_index)
  82. ts_index += 1
  83. last_num = -1
  84. for num in range(words_size):
  85. if num <= last_num:
  86. continue
  87. if num in abbr_begin:
  88. if time_stamp is not None:
  89. begin = time_stamp[ts_nums[num]][0]
  90. word_lists.append(words[num].upper())
  91. num += 1
  92. while num < words_size:
  93. if num in abbr_end:
  94. word_lists.append(words[num].upper())
  95. last_num = num
  96. break
  97. else:
  98. if words[num].encode('utf-8').isalpha():
  99. word_lists.append(words[num].upper())
  100. num += 1
  101. if time_stamp is not None:
  102. end = time_stamp[ts_nums[num]][1]
  103. ts_lists.append([begin, end])
  104. else:
  105. word_lists.append(words[num])
  106. if time_stamp is not None and words[num] != ' ':
  107. begin = time_stamp[ts_nums[num]][0]
  108. end = time_stamp[ts_nums[num]][1]
  109. ts_lists.append([begin, end])
  110. begin = end
  111. if time_stamp is not None:
  112. return word_lists, ts_lists
  113. else:
  114. return word_lists
  115. def sentence_postprocess(words: List[Any], time_stamp: List[List] = None):
  116. middle_lists = []
  117. word_lists = []
  118. word_item = ''
  119. ts_lists = []
  120. # wash words lists
  121. for i in words:
  122. word = ''
  123. if isinstance(i, str):
  124. word = i
  125. else:
  126. word = i.decode('utf-8')
  127. if word in ['<s>', '</s>', '<unk>']:
  128. continue
  129. else:
  130. middle_lists.append(word)
  131. # all chinese characters
  132. if isAllChinese(middle_lists):
  133. for i, ch in enumerate(middle_lists):
  134. word_lists.append(ch.replace(' ', ''))
  135. if time_stamp is not None:
  136. ts_lists = time_stamp
  137. # all alpha characters
  138. elif isAllAlpha(middle_lists):
  139. ts_flag = True
  140. for i, ch in enumerate(middle_lists):
  141. if ts_flag and time_stamp is not None:
  142. begin = time_stamp[i][0]
  143. end = time_stamp[i][1]
  144. word = ''
  145. if '@@' in ch:
  146. word = ch.replace('@@', '')
  147. word_item += word
  148. if time_stamp is not None:
  149. ts_flag = False
  150. end = time_stamp[i][1]
  151. else:
  152. word_item += ch
  153. word_lists.append(word_item)
  154. word_lists.append(' ')
  155. word_item = ''
  156. if time_stamp is not None:
  157. ts_flag = True
  158. end = time_stamp[i][1]
  159. ts_lists.append([begin, end])
  160. begin = end
  161. # mix characters
  162. else:
  163. alpha_blank = False
  164. ts_flag = True
  165. begin = -1
  166. end = -1
  167. for i, ch in enumerate(middle_lists):
  168. if ts_flag and time_stamp is not None:
  169. begin = time_stamp[i][0]
  170. end = time_stamp[i][1]
  171. word = ''
  172. if isAllChinese(ch):
  173. if alpha_blank is True:
  174. word_lists.pop()
  175. word_lists.append(ch)
  176. alpha_blank = False
  177. if time_stamp is not None:
  178. ts_flag = True
  179. ts_lists.append([begin, end])
  180. begin = end
  181. elif '@@' in ch:
  182. word = ch.replace('@@', '')
  183. word_item += word
  184. alpha_blank = False
  185. if time_stamp is not None:
  186. ts_flag = False
  187. end = time_stamp[i][1]
  188. elif isAllAlpha(ch):
  189. word_item += ch
  190. word_lists.append(word_item)
  191. word_lists.append(' ')
  192. word_item = ''
  193. alpha_blank = True
  194. if time_stamp is not None:
  195. ts_flag = True
  196. end = time_stamp[i][1]
  197. ts_lists.append([begin, end])
  198. begin = end
  199. else:
  200. raise ValueError('invalid character: {}'.format(ch))
  201. if time_stamp is not None:
  202. word_lists, ts_lists = abbr_dispose(word_lists, ts_lists)
  203. real_word_lists = []
  204. for ch in word_lists:
  205. if ch != ' ':
  206. real_word_lists.append(ch)
  207. sentence = ' '.join(real_word_lists).strip()
  208. return sentence, ts_lists, real_word_lists
  209. else:
  210. word_lists = abbr_dispose(word_lists)
  211. real_word_lists = []
  212. for ch in word_lists:
  213. if ch != ' ':
  214. real_word_lists.append(ch)
  215. sentence = ''.join(word_lists).strip()
  216. return sentence, real_word_lists
  217. def sentence_postprocess_sentencepiece(words):
  218. middle_lists = []
  219. word_lists = []
  220. word_item = ''
  221. # wash words lists
  222. for i in words:
  223. word = ''
  224. if isinstance(i, str):
  225. word = i
  226. else:
  227. word = i.decode('utf-8')
  228. if word in ['<s>', '</s>', '<unk>', '<OOV>']:
  229. continue
  230. else:
  231. middle_lists.append(word)
  232. # all alpha characters
  233. for i, ch in enumerate(middle_lists):
  234. word = ''
  235. if '\u2581' in ch and i == 0:
  236. word_item = ''
  237. word = ch.replace('\u2581', '')
  238. word_item += word
  239. elif '\u2581' in ch and i != 0:
  240. word_lists.append(word_item)
  241. word_lists.append(' ')
  242. word_item = ''
  243. word = ch.replace('\u2581', '')
  244. word_item += word
  245. else:
  246. word_item += ch
  247. if word_item is not None:
  248. word_lists.append(word_item)
  249. #word_lists = abbr_dispose(word_lists)
  250. real_word_lists = []
  251. for ch in word_lists:
  252. if ch != ' ':
  253. if ch == "i":
  254. ch = ch.replace("i", "I")
  255. elif ch == "i'm":
  256. ch = ch.replace("i'm", "I'm")
  257. elif ch == "i've":
  258. ch = ch.replace("i've", "I've")
  259. elif ch == "i'll":
  260. ch = ch.replace("i'll", "I'll")
  261. real_word_lists.append(ch)
  262. sentence = ''.join(word_lists)
  263. return sentence, real_word_lists