# postprocess_utils.py
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
import string
from typing import Any, List, Optional, Tuple, Union
  5. def isChinese(ch: str):
  6. if '\u4e00' <= ch <= '\u9fff' or '\u0030' <= ch <= '\u0039' or ch == '@':
  7. return True
  8. return False
  9. def isAllChinese(word: Union[List[Any], str]):
  10. word_lists = []
  11. for i in word:
  12. cur = i.replace(' ', '')
  13. cur = cur.replace('</s>', '')
  14. cur = cur.replace('<s>', '')
  15. cur = cur.replace('<unk>', '')
  16. cur = cur.replace('<OOV>', '')
  17. word_lists.append(cur)
  18. if len(word_lists) == 0:
  19. return False
  20. for ch in word_lists:
  21. if isChinese(ch) is False:
  22. return False
  23. return True
  24. def isAllAlpha(word: Union[List[Any], str]):
  25. word_lists = []
  26. for i in word:
  27. cur = i.replace(' ', '')
  28. cur = cur.replace('</s>', '')
  29. cur = cur.replace('<s>', '')
  30. cur = cur.replace('<unk>', '')
  31. cur = cur.replace('<OOV>', '')
  32. word_lists.append(cur)
  33. if len(word_lists) == 0:
  34. return False
  35. for ch in word_lists:
  36. if ch.isalpha() is False and ch != "'":
  37. return False
  38. elif ch.isalpha() is True and isChinese(ch) is True:
  39. return False
  40. return True
  41. # def abbr_dispose(words: List[Any]) -> List[Any]:
  42. def abbr_dispose(words: List[Any], time_stamp: List[List] = None) -> List[Any]:
  43. words_size = len(words)
  44. word_lists = []
  45. abbr_begin = []
  46. abbr_end = []
  47. last_num = -1
  48. ts_lists = []
  49. ts_nums = []
  50. ts_index = 0
  51. for num in range(words_size):
  52. if num <= last_num:
  53. continue
  54. if len(words[num]) == 1 and words[num].encode('utf-8').isalpha():
  55. if num + 1 < words_size and words[
  56. num + 1] == ' ' and num + 2 < words_size and len(
  57. words[num +
  58. 2]) == 1 and words[num +
  59. 2].encode('utf-8').isalpha():
  60. # found the begin of abbr
  61. abbr_begin.append(num)
  62. num += 2
  63. abbr_end.append(num)
  64. # to find the end of abbr
  65. while True:
  66. num += 1
  67. if num < words_size and words[num] == ' ':
  68. num += 1
  69. if num < words_size and len(
  70. words[num]) == 1 and words[num].encode(
  71. 'utf-8').isalpha():
  72. abbr_end.pop()
  73. abbr_end.append(num)
  74. last_num = num
  75. else:
  76. break
  77. else:
  78. break
  79. for num in range(words_size):
  80. if words[num] == ' ':
  81. ts_nums.append(ts_index)
  82. else:
  83. ts_nums.append(ts_index)
  84. ts_index += 1
  85. last_num = -1
  86. for num in range(words_size):
  87. if num <= last_num:
  88. continue
  89. if num in abbr_begin:
  90. if time_stamp is not None:
  91. begin = time_stamp[ts_nums[num]][0]
  92. word_lists.append(words[num].upper())
  93. num += 1
  94. while num < words_size:
  95. if num in abbr_end:
  96. word_lists.append(words[num].upper())
  97. last_num = num
  98. break
  99. else:
  100. if words[num].encode('utf-8').isalpha():
  101. word_lists.append(words[num].upper())
  102. num += 1
  103. if time_stamp is not None:
  104. end = time_stamp[ts_nums[num]][1]
  105. ts_lists.append([begin, end])
  106. else:
  107. word_lists.append(words[num])
  108. if time_stamp is not None and words[num] != ' ':
  109. begin = time_stamp[ts_nums[num]][0]
  110. end = time_stamp[ts_nums[num]][1]
  111. ts_lists.append([begin, end])
  112. begin = end
  113. if time_stamp is not None:
  114. return word_lists, ts_lists
  115. else:
  116. return word_lists
  117. def sentence_postprocess(words: List[Any], time_stamp: List[List] = None):
  118. middle_lists = []
  119. word_lists = []
  120. word_item = ''
  121. ts_lists = []
  122. # wash words lists
  123. for i in words:
  124. word = ''
  125. if isinstance(i, str):
  126. word = i
  127. else:
  128. word = i.decode('utf-8')
  129. if word in ['<s>', '</s>', '<unk>', '<OOV>']:
  130. continue
  131. else:
  132. middle_lists.append(word)
  133. # all chinese characters
  134. if isAllChinese(middle_lists):
  135. for i, ch in enumerate(middle_lists):
  136. word_lists.append(ch.replace(' ', ''))
  137. if time_stamp is not None:
  138. ts_lists = time_stamp
  139. # all alpha characters
  140. elif isAllAlpha(middle_lists):
  141. ts_flag = True
  142. for i, ch in enumerate(middle_lists):
  143. if ts_flag and time_stamp is not None:
  144. begin = time_stamp[i][0]
  145. end = time_stamp[i][1]
  146. word = ''
  147. if '@@' in ch:
  148. word = ch.replace('@@', '')
  149. word_item += word
  150. if time_stamp is not None:
  151. ts_flag = False
  152. end = time_stamp[i][1]
  153. else:
  154. word_item += ch
  155. word_lists.append(word_item)
  156. word_lists.append(' ')
  157. word_item = ''
  158. if time_stamp is not None:
  159. ts_flag = True
  160. end = time_stamp[i][1]
  161. ts_lists.append([begin, end])
  162. begin = end
  163. # mix characters
  164. else:
  165. alpha_blank = False
  166. ts_flag = True
  167. begin = -1
  168. end = -1
  169. for i, ch in enumerate(middle_lists):
  170. if ts_flag and time_stamp is not None:
  171. begin = time_stamp[i][0]
  172. end = time_stamp[i][1]
  173. word = ''
  174. if isAllChinese(ch):
  175. if alpha_blank is True:
  176. word_lists.pop()
  177. word_lists.append(ch)
  178. alpha_blank = False
  179. if time_stamp is not None:
  180. ts_flag = True
  181. ts_lists.append([begin, end])
  182. begin = end
  183. elif '@@' in ch:
  184. word = ch.replace('@@', '')
  185. word_item += word
  186. alpha_blank = False
  187. if time_stamp is not None:
  188. ts_flag = False
  189. end = time_stamp[i][1]
  190. elif isAllAlpha(ch):
  191. word_item += ch
  192. word_lists.append(word_item)
  193. word_lists.append(' ')
  194. word_item = ''
  195. alpha_blank = True
  196. if time_stamp is not None:
  197. ts_flag = True
  198. end = time_stamp[i][1]
  199. ts_lists.append([begin, end])
  200. begin = end
  201. else:
  202. raise ValueError('invalid character: {}'.format(ch))
  203. if time_stamp is not None:
  204. word_lists, ts_lists = abbr_dispose(word_lists, ts_lists)
  205. real_word_lists = []
  206. for ch in word_lists:
  207. if ch != ' ':
  208. real_word_lists.append(ch)
  209. sentence = ' '.join(real_word_lists).strip()
  210. return sentence, ts_lists, real_word_lists
  211. else:
  212. word_lists = abbr_dispose(word_lists)
  213. real_word_lists = []
  214. for ch in word_lists:
  215. if ch != ' ':
  216. real_word_lists.append(ch)
  217. sentence = ''.join(word_lists).strip()
  218. return sentence, real_word_lists