# tokenize.py
  1. #!/usr/bin/env python
  2. import re
  3. import numpy as np
  4. def forward_segment(text, seg_dict):
  5. word_list = []
  6. i = 0
  7. while i < len(text):
  8. longest_word = text[i]
  9. for j in range(i + 1, len(text) + 1):
  10. word = text[i:j]
  11. if word in seg_dict:
  12. if len(word) > len(longest_word):
  13. longest_word = word
  14. word_list.append(longest_word)
  15. i += len(longest_word)
  16. return word_list
  17. def seg_tokenize(txt, seg_dict):
  18. pattern = re.compile(r'^[\u4E00-\u9FA50-9]+$')
  19. out_txt = ""
  20. for word in txt:
  21. word = word.lower()
  22. if word in seg_dict:
  23. out_txt += seg_dict[word] + " "
  24. else:
  25. if pattern.match(word):
  26. for char in word:
  27. if char in seg_dict:
  28. out_txt += seg_dict[char] + " "
  29. else:
  30. out_txt += "<unk>" + " "
  31. else:
  32. out_txt += "<unk>" + " "
  33. return out_txt.strip().split()
  34. def tokenize(data,
  35. vocab=None,
  36. seg_dict=None,
  37. punc_dict=None,
  38. bpe_tokenizer=None):
  39. assert "text" in data
  40. assert isinstance(vocab, dict)
  41. text = data["text"]
  42. token = []
  43. vad = -2
  44. if bpe_tokenizer is not None:
  45. text = bpe_tokenizer.text2tokens("".join(text))
  46. if seg_dict is not None:
  47. assert isinstance(seg_dict, dict)
  48. text = seg_tokenize(text, seg_dict)
  49. length = len(text)
  50. for i in range(length):
  51. x = text[i]
  52. if i == length-1 and "punc" in data and x.startswith("vad:"):
  53. vad = x[4:]
  54. if len(vad) == 0:
  55. vad = -1
  56. else:
  57. vad = int(vad)
  58. elif x in vocab:
  59. token.append(vocab[x])
  60. else:
  61. token.append(vocab['<unk>'])
  62. if "punc" in data and punc_dict is not None:
  63. punc_token = []
  64. for punc in data["punc"]:
  65. if punc in punc_dict:
  66. punc_token.append(punc_dict[punc])
  67. else:
  68. punc_token.append(punc_dict["_"])
  69. data["punc"] = np.array(punc_token)
  70. data["text"] = np.array(token)
  71. if vad is not -2:
  72. data["vad_indexes"]=np.array([vad], dtype=np.int64)
  73. return data