# build_lm_model.py — class registries and builder for language-model training models.
import logging

from funasr.models.seq_rnn_lm import SequentialRNNLM
from funasr.models.transformer_lm import TransformerLM
from funasr.torch_utils.initialize import initialize
from funasr.train.abs_model import AbsLM
from funasr.train.abs_model import LanguageModel
from funasr.train.class_choices import ClassChoices
  8. lm_choices = ClassChoices(
  9. "lm",
  10. classes=dict(
  11. seq_rnn=SequentialRNNLM,
  12. transformer=TransformerLM,
  13. ),
  14. type_check=AbsLM,
  15. default="seq_rnn",
  16. )
  17. model_choices = ClassChoices(
  18. "model",
  19. classes=dict(
  20. lm=LanguageModel,
  21. ),
  22. default="lm",
  23. )
  24. class_choices_list = [
  25. # --lm and --lm_conf
  26. lm_choices,
  27. # --model and --model_conf
  28. model_choices
  29. ]
  30. def build_lm_model(args):
  31. # token_list
  32. if isinstance(args.token_list, str):
  33. with open(args.token_list, encoding="utf-8") as f:
  34. token_list = [line.rstrip() for line in f]
  35. args.token_list = list(token_list)
  36. vocab_size = len(token_list)
  37. logging.info(f"Vocabulary size: {vocab_size}")
  38. elif isinstance(args.token_list, (tuple, list)):
  39. token_list = list(args.token_list)
  40. vocab_size = len(token_list)
  41. logging.info(f"Vocabulary size: {vocab_size}")
  42. else:
  43. vocab_size = None
  44. # lm
  45. lm_class = lm_choices.get_class(args.lm)
  46. lm = lm_class(vocab_size=vocab_size, **args.lm_conf)
  47. args.model = args.model if hasattr(args, "model") else "lm"
  48. model_class = model_choices.get_class(args.model)
  49. model = model_class(lm=lm, vocab_size=vocab_size, **args.model_conf)
  50. # initialize
  51. if args.init is not None:
  52. initialize(model, args.init)
  53. return model