preprocessor.py 1.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. import os
  2. import json
  3. import torch
  4. import logging
  5. import concurrent.futures
  6. import librosa
  7. import torch.distributed as dist
  8. from typing import Collection
  9. import torch
  10. import torchaudio
  11. from torch import nn
  12. import random
  13. import re
  14. from funasr.tokenizer.cleaner import TextCleaner
  15. from funasr.register import tables
  16. @tables.register("preprocessor_classes", "SpeechPreprocessSpeedPerturb")
  17. class SpeechPreprocessSpeedPerturb(nn.Module):
  18. def __init__(self, speed_perturb: list=None, **kwargs):
  19. super().__init__()
  20. self.speed_perturb = speed_perturb
  21. def forward(self, waveform, fs, **kwargs):
  22. if self.speed_perturb is None:
  23. return waveform
  24. speed = random.choice(self.speed_perturb)
  25. if speed != 1.0:
  26. if not isinstance(waveform, torch.Tensor):
  27. waveform = torch.tensor(waveform)
  28. waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
  29. waveform.view(1, -1), fs, [['speed', str(speed)], ['rate', str(fs)]])
  30. waveform = waveform.view(-1)
  31. return waveform
  32. @tables.register("preprocessor_classes", "TextPreprocessSegDict")
  33. class TextPreprocessSegDict(nn.Module):
  34. def __init__(self, seg_dict: str = None,
  35. text_cleaner: Collection[str] = None,
  36. split_with_space: bool = False,
  37. **kwargs):
  38. super().__init__()
  39. self.text_cleaner = TextCleaner(text_cleaner)
  40. def forward(self, text, **kwargs):
  41. text = self.text_cleaner(text)
  42. return text