# build_lm_model.py (1.4 KB)
  1. import logging
  2. from funasr.train.abs_model import AbsLM
  3. from funasr.train.abs_model import LanguageModel
  4. from funasr.models.seq_rnn_lm import SequentialRNNLM
  5. from funasr.models.transformer_lm import TransformerLM
  6. from funasr.torch_utils.initialize import initialize
  7. from funasr.train.class_choices import ClassChoices
  8. lm_choices = ClassChoices(
  9. "lm",
  10. classes=dict(
  11. seq_rnn=SequentialRNNLM,
  12. transformer=TransformerLM,
  13. ),
  14. type_check=AbsLM,
  15. default="seq_rnn",
  16. )
  17. model_choices = ClassChoices(
  18. "model",
  19. classes=dict(
  20. lm=LanguageModel,
  21. ),
  22. default="lm",
  23. )
  24. class_choices_list = [
  25. # --lm and --lm_conf
  26. lm_choices,
  27. # --model and --model_conf
  28. model_choices
  29. ]
  30. def build_lm_model(args):
  31. # token_list
  32. if args.token_list is not None:
  33. with open(args.token_list) as f:
  34. token_list = [line.rstrip() for line in f]
  35. args.token_list = list(token_list)
  36. vocab_size = len(token_list)
  37. logging.info(f"Vocabulary size: {vocab_size}")
  38. else:
  39. vocab_size = None
  40. # lm
  41. lm_class = lm_choices.get_class(args.lm)
  42. lm = lm_class(vocab_size=vocab_size, **args.lm_conf)
  43. model_class = model_choices.get_class(args.model)
  44. model = model_class(lm=lm, vocab_size=vocab_size, **args.model_conf)
  45. # initialize
  46. if args.init is not None:
  47. initialize(model, args.init)
  48. return model