# build_pretrain_model.py
  1. from funasr.layers.global_mvn import GlobalMVN
  2. from funasr.layers.utterance_mvn import UtteranceMVN
  3. from funasr.models.data2vec import Data2VecPretrainModel
  4. from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
  5. from funasr.models.frontend.default import DefaultFrontend
  6. from funasr.models.frontend.windowing import SlidingWindow
  7. from funasr.models.frontend.wav_frontend import WavFrontend
  8. from funasr.models.specaug.specaug import SpecAug
  9. from funasr.torch_utils.initialize import initialize
  10. from funasr.train.class_choices import ClassChoices
  11. frontend_choices = ClassChoices(
  12. name="frontend",
  13. classes=dict(
  14. default=DefaultFrontend,
  15. sliding_window=SlidingWindow,
  16. wav_frontend=WavFrontend,
  17. ),
  18. default="default",
  19. )
  20. specaug_choices = ClassChoices(
  21. name="specaug",
  22. classes=dict(specaug=SpecAug),
  23. default=None,
  24. optional=True,
  25. )
  26. normalize_choices = ClassChoices(
  27. "normalize",
  28. classes=dict(
  29. global_mvn=GlobalMVN,
  30. utterance_mvn=UtteranceMVN,
  31. ),
  32. default=None,
  33. optional=True,
  34. )
  35. encoder_choices = ClassChoices(
  36. "encoder",
  37. classes=dict(
  38. data2vec_encoder=Data2VecEncoder,
  39. ),
  40. default="data2vec_encoder",
  41. )
  42. model_choices = ClassChoices(
  43. "model",
  44. classes=dict(
  45. data2vec=Data2VecPretrainModel,
  46. ),
  47. default="data2vec",
  48. )
  49. class_choices_list = [
  50. # --frontend and --frontend_conf
  51. frontend_choices,
  52. # --specaug and --specaug_conf
  53. specaug_choices,
  54. # --normalize and --normalize_conf
  55. normalize_choices,
  56. # --encoder and --encoder_conf
  57. encoder_choices,
  58. # --model and --model_conf
  59. model_choices,
  60. ]
  61. def build_pretrain_model(args):
  62. # frontend
  63. if args.input_size is None:
  64. frontend_class = frontend_choices.get_class(args.frontend)
  65. frontend = frontend_class(**args.frontend_conf)
  66. input_size = frontend.output_size()
  67. else:
  68. args.frontend = None
  69. args.frontend_conf = {}
  70. frontend = None
  71. input_size = args.input_size
  72. # data augmentation for spectrogram
  73. if args.specaug is not None:
  74. specaug_class = specaug_choices.get_class(args.specaug)
  75. specaug = specaug_class(**args.specaug_conf)
  76. else:
  77. specaug = None
  78. # normalization layer
  79. if args.normalize is not None:
  80. normalize_class = normalize_choices.get_class(args.normalize)
  81. normalize = normalize_class(**args.normalize_conf)
  82. else:
  83. normalize = None
  84. # encoder
  85. encoder_class = encoder_choices.get_class(args.encoder)
  86. encoder = encoder_class(
  87. input_size=input_size,
  88. **args.encoder_conf,
  89. )
  90. if args.model == "data2vec":
  91. model_class = model_choices.get_class("data2vec")
  92. model = model_class(
  93. frontend=frontend,
  94. specaug=specaug,
  95. normalize=normalize,
  96. encoder=encoder,
  97. )
  98. else:
  99. raise NotImplementedError("Not supported model: {}".format(args.model))
  100. # initialize
  101. if args.init is not None:
  102. initialize(model, args.init)
  103. return model