# build_pretrain_model.py
  1. from funasr.layers.global_mvn import GlobalMVN
  2. from funasr.layers.utterance_mvn import UtteranceMVN
  3. from funasr.models.data2vec import Data2VecPretrainModel
  4. from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
  5. from funasr.models.frontend.default import DefaultFrontend
  6. from funasr.models.frontend.windowing import SlidingWindow
  7. from funasr.models.specaug.specaug import SpecAug
  8. from funasr.torch_utils.initialize import initialize
  9. from funasr.train.class_choices import ClassChoices
  10. frontend_choices = ClassChoices(
  11. name="frontend",
  12. classes=dict(default=DefaultFrontend, sliding_window=SlidingWindow),
  13. default="default",
  14. )
  15. specaug_choices = ClassChoices(
  16. name="specaug",
  17. classes=dict(specaug=SpecAug),
  18. default=None,
  19. optional=True,
  20. )
  21. normalize_choices = ClassChoices(
  22. "normalize",
  23. classes=dict(
  24. global_mvn=GlobalMVN,
  25. utterance_mvn=UtteranceMVN,
  26. ),
  27. default=None,
  28. optional=True,
  29. )
  30. encoder_choices = ClassChoices(
  31. "encoder",
  32. classes=dict(
  33. data2vec_encoder=Data2VecEncoder,
  34. ),
  35. default="data2vec_encoder",
  36. )
  37. model_choices = ClassChoices(
  38. "model",
  39. classes=dict(
  40. data2vec=Data2VecPretrainModel,
  41. ),
  42. default="data2vec",
  43. )
  44. class_choices_list = [
  45. # --frontend and --frontend_conf
  46. frontend_choices,
  47. # --specaug and --specaug_conf
  48. specaug_choices,
  49. # --normalize and --normalize_conf
  50. normalize_choices,
  51. # --encoder and --encoder_conf
  52. encoder_choices,
  53. # --model and --model_conf
  54. model_choices,
  55. ]
  56. def build_pretrain_model(args):
  57. # frontend
  58. if args.input_size is None:
  59. frontend_class = frontend_choices.get_class(args.frontend)
  60. frontend = frontend_class(**args.frontend_conf)
  61. input_size = frontend.output_size()
  62. else:
  63. args.frontend = None
  64. args.frontend_conf = {}
  65. frontend = None
  66. input_size = args.input_size
  67. # data augmentation for spectrogram
  68. if args.specaug is not None:
  69. specaug_class = specaug_choices.get_class(args.specaug)
  70. specaug = specaug_class(**args.specaug_conf)
  71. else:
  72. specaug = None
  73. # normalization layer
  74. if args.normalize is not None:
  75. normalize_class = normalize_choices.get_class(args.normalize)
  76. normalize = normalize_class(**args.normalize_conf)
  77. else:
  78. normalize = None
  79. # encoder
  80. encoder_class = encoder_choices.get_class(args.encoder)
  81. encoder = encoder_class(
  82. input_size=input_size,
  83. **args.encoder_conf,
  84. )
  85. if args.model == "data2vec":
  86. model_class = model_choices.get_class("data2vec")
  87. model = model_class(
  88. frontend=frontend,
  89. specaug=specaug,
  90. normalize=normalize,
  91. encoder=encoder,
  92. )
  93. else:
  94. raise NotImplementedError("Not supported model: {}".format(args.model))
  95. # initialize
  96. if args.init is not None:
  97. initialize(model, args.init)
  98. return model