model.py 1.3 KB

12345678910111213141516171819202122232425262728293031323334353637
  1. #!/usr/bin/env python3
  2. # -*- encoding: utf-8 -*-
  3. # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
  4. # MIT License (https://opensource.org/licenses/MIT)
  5. import time
  6. import torch
  7. import logging
  8. from contextlib import contextmanager
  9. from typing import Dict, Optional, Tuple
  10. from distutils.version import LooseVersion
  11. from funasr.register import tables
  12. from funasr.utils import postprocess_utils
  13. from funasr.utils.datadir_writer import DatadirWriter
  14. from funasr.models.transducer.model import Transducer
  15. from funasr.train_utils.device_funcs import force_gatherable
  16. from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
  17. from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
  18. from funasr.models.transformer.scorers.length_bonus import LengthBonus
  19. from funasr.models.transformer.utils.nets_utils import get_transducer_task_io
  20. from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
  21. from funasr.models.transducer.beam_search_transducer import BeamSearchTransducer
  22. if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
  23. from torch.cuda.amp import autocast
  24. else:
  25. # Nothing to do if torch<1.6.0
  26. @contextmanager
  27. def autocast(enabled=True):
  28. yield
  29. @tables.register("model_classes", "BAT") # TODO: BAT training
  30. class BAT(Transducer):
  31. pass