preprocessor.py 1.0 KB

12345678910111213141516171819202122232425262728293031323334353637
  1. import os
  2. import json
  3. import torch
  4. import logging
  5. import concurrent.futures
  6. import librosa
  7. import torch.distributed as dist
  8. from typing import Collection
  9. import torch
  10. import torchaudio
  11. from torch import nn
  12. import random
  13. import re
  14. import string
  15. from funasr.tokenizer.cleaner import TextCleaner
  16. from funasr.register import tables
  17. @tables.register("preprocessor_classes", "TextPreprocessRemovePunctuation")
  18. class TextPreprocessSegDict(nn.Module):
  19. def __init__(self,
  20. **kwargs):
  21. super().__init__()
  22. def forward(self, text, **kwargs):
  23. # 定义英文标点符号
  24. en_punct = string.punctuation
  25. # 定义中文标点符号(部分常用的)
  26. cn_punct = '。?!,、;:“”‘’()《》【】…—~·'
  27. # 合并英文和中文标点符号
  28. all_punct = en_punct + cn_punct
  29. # 创建正则表达式模式,匹配任何在all_punct中的字符
  30. punct_pattern = re.compile('[{}]'.format(re.escape(all_punct)))
  31. # 使用正则表达式的sub方法替换掉这些字符
  32. return punct_pattern.sub('', text)