build_punc_model.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. import logging
  2. from funasr.models.target_delay_transformer import TargetDelayTransformer
  3. from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
  4. from funasr.torch_utils.initialize import initialize
  5. from funasr.train.abs_model import PunctuationModel
  6. from funasr.train.class_choices import ClassChoices
  7. punc_choices = ClassChoices(
  8. "punctuation",
  9. classes=dict(
  10. target_delay=TargetDelayTransformer,
  11. vad_realtime=VadRealtimeTransformer
  12. ),
  13. default="target_delay",
  14. )
  15. model_choices = ClassChoices(
  16. "model",
  17. classes=dict(
  18. punc=PunctuationModel,
  19. ),
  20. default="punc",
  21. )
  22. class_choices_list = [
  23. # --punc and --punc_conf
  24. punc_choices,
  25. # --model and --model_conf
  26. model_choices
  27. ]
  28. def build_punc_model(args):
  29. # token_list and punc list
  30. if isinstance(args.token_list, str):
  31. with open(args.token_list, encoding="utf-8") as f:
  32. token_list = [line.rstrip() for line in f]
  33. args.token_list = token_list.copy()
  34. if isinstance(args.punc_list, str):
  35. with open(args.punc_list, encoding="utf-8") as f2:
  36. pairs = [line.rstrip().split(":") for line in f2]
  37. punc_list = [pair[0] for pair in pairs]
  38. punc_weight_list = [float(pair[1]) for pair in pairs]
  39. args.punc_list = punc_list.copy()
  40. elif isinstance(args.punc_list, list):
  41. punc_list = args.punc_list.copy()
  42. punc_weight_list = [1] * len(punc_list)
  43. if isinstance(args.token_list, (tuple, list)):
  44. token_list = args.token_list.copy()
  45. else:
  46. raise RuntimeError("token_list must be str or dict")
  47. vocab_size = len(token_list)
  48. punc_size = len(punc_list)
  49. logging.info(f"Vocabulary size: {vocab_size}")
  50. # punc
  51. punc_class = punc_choices.get_class(args.punctuation)
  52. punc = punc_class(vocab_size=vocab_size, punc_size=punc_size, **args.punctuation_conf)
  53. if "punc_weight" in args.model_conf:
  54. args.model_conf.pop("punc_weight")
  55. model = PunctuationModel(punc_model=punc, vocab_size=vocab_size, punc_weight=punc_weight_list, **args.model_conf)
  56. # initialize
  57. if args.init is not None:
  58. initialize(model, args.init)
  59. return model