postprocess_utils.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import string
  3. from typing import Any, List, Union
  4. def isChinese(ch: str):
  5. if '\u4e00' <= ch <= '\u9fff':
  6. return True
  7. return False
  8. def isAllChinese(word: Union[List[Any], str]):
  9. word_lists = []
  10. table = str.maketrans('', '', string.punctuation)
  11. for i in word:
  12. cur = i.translate(table)
  13. cur = cur.replace(' ', '')
  14. cur = cur.replace('</s>', '')
  15. cur = cur.replace('<s>', '')
  16. word_lists.append(cur)
  17. if len(word_lists) == 0:
  18. return False
  19. for ch in word_lists:
  20. if isChinese(ch) is False:
  21. return False
  22. return True
  23. def isAllAlpha(word: Union[List[Any], str]):
  24. word_lists = []
  25. table = str.maketrans('', '', string.punctuation)
  26. for i in word:
  27. cur = i.translate(table)
  28. cur = cur.replace(' ', '')
  29. cur = cur.replace('</s>', '')
  30. cur = cur.replace('<s>', '')
  31. word_lists.append(cur)
  32. if len(word_lists) == 0:
  33. return False
  34. for ch in word_lists:
  35. if ch.isalpha() is False:
  36. return False
  37. elif ch.isalpha() is True and isChinese(ch) is True:
  38. return False
  39. return True
  40. def abbr_dispose(words: List[Any]) -> List[Any]:
  41. words_size = len(words)
  42. word_lists = []
  43. abbr_begin = []
  44. abbr_end = []
  45. last_num = -1
  46. for num in range(words_size):
  47. if num <= last_num:
  48. continue
  49. if len(words[num]) == 1 and words[num].encode('utf-8').isalpha():
  50. if num + 1 < words_size and words[
  51. num + 1] == ' ' and num + 2 < words_size and len(
  52. words[num +
  53. 2]) == 1 and words[num +
  54. 2].encode('utf-8').isalpha():
  55. # found the begin of abbr
  56. abbr_begin.append(num)
  57. num += 2
  58. abbr_end.append(num)
  59. # to find the end of abbr
  60. while True:
  61. num += 1
  62. if num < words_size and words[num] == ' ':
  63. num += 1
  64. if num < words_size and len(
  65. words[num]) == 1 and words[num].encode(
  66. 'utf-8').isalpha():
  67. abbr_end.pop()
  68. abbr_end.append(num)
  69. last_num = num
  70. else:
  71. break
  72. else:
  73. break
  74. last_num = -1
  75. for num in range(words_size):
  76. if num <= last_num:
  77. continue
  78. if num in abbr_begin:
  79. word_lists.append(words[num].upper())
  80. num += 1
  81. while num < words_size:
  82. if num in abbr_end:
  83. word_lists.append(words[num].upper())
  84. last_num = num
  85. break
  86. else:
  87. if words[num].encode('utf-8').isalpha():
  88. word_lists.append(words[num].upper())
  89. num += 1
  90. else:
  91. word_lists.append(words[num])
  92. return word_lists
  93. def sentence_postprocess(words: List[Any]):
  94. middle_lists = []
  95. word_lists = []
  96. word_item = ''
  97. # wash words lists
  98. for i in words:
  99. word = ''
  100. if isinstance(i, str):
  101. word = i
  102. else:
  103. word = i.decode('utf-8')
  104. if word in ['<s>', '</s>', '<unk>']:
  105. continue
  106. else:
  107. middle_lists.append(word)
  108. # all chinese characters
  109. if isAllChinese(middle_lists):
  110. for ch in middle_lists:
  111. word_lists.append(ch.replace(' ', ''))
  112. # all alpha characters
  113. elif isAllAlpha(middle_lists):
  114. for ch in middle_lists:
  115. word = ''
  116. if '@@' in ch:
  117. word = ch.replace('@@', '')
  118. word_item += word
  119. else:
  120. word_item += ch
  121. word_lists.append(word_item)
  122. word_lists.append(' ')
  123. word_item = ''
  124. # mix characters
  125. else:
  126. alpha_blank = False
  127. for ch in middle_lists:
  128. word = ''
  129. if isAllChinese(ch):
  130. if alpha_blank is True:
  131. word_lists.pop()
  132. word_lists.append(ch)
  133. alpha_blank = False
  134. elif '@@' in ch:
  135. word = ch.replace('@@', '')
  136. word_item += word
  137. alpha_blank = False
  138. elif isAllAlpha(ch):
  139. word_item += ch
  140. word_lists.append(word_item)
  141. word_lists.append(' ')
  142. word_item = ''
  143. alpha_blank = True
  144. else:
  145. raise ValueError('invalid character: {}'.format(ch))
  146. word_lists = abbr_dispose(word_lists)
  147. sentence = ''.join(word_lists).strip()
  148. return sentence