# whisper.py
# Standard library
import argparse
import logging
import os
from pathlib import Path
from typing import Callable
from typing import Collection
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

# Third-party
import numpy as np
import torch
import yaml

# FunASR project imports
from funasr.datasets.collate_fn import CommonCollateFn
from funasr.datasets.preprocessor import CommonPreprocessor
from funasr.layers.abs_normalize import AbsNormalize
from funasr.layers.global_mvn import GlobalMVN
from funasr.layers.utterance_mvn import UtteranceMVN
from funasr.models.ctc import CTC
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.decoder.rnn_decoder import RNNDecoder
from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder, FsmnDecoderSCAMAOpt
from funasr.models.decoder.transformer_decoder import (
    DynamicConvolution2DTransformerDecoder,  # noqa: H301
)
from funasr.models.decoder.transformer_decoder import DynamicConvolutionTransformerDecoder
from funasr.models.decoder.transformer_decoder import (
    LightweightConvolution2DTransformerDecoder,  # noqa: H301
)
from funasr.models.decoder.transformer_decoder import (
    LightweightConvolutionTransformerDecoder,  # noqa: H301
)
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
from funasr.models.decoder.transformer_decoder import TransformerDecoder
from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
from funasr.models.decoder.transformer_decoder import SAAsrTransformerDecoder
from funasr.models.e2e_asr import ASRModel
from funasr.models.decoder.rnnt_decoder import RNNTDecoder
from funasr.models.joint_net.joint_network import JointNetwork
from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, ContextualParaformer
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
from funasr.models.e2e_tp import TimestampPredictor
from funasr.models.e2e_asr_mfcca import MFCCA
from funasr.models.e2e_sa_asr import SAASRModel
from funasr.models.e2e_uni_asr import UniASR
from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
from funasr.models.e2e_asr_bat import BATModel
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.encoder.conformer_encoder import ConformerEncoder, ConformerChunkEncoder
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
from funasr.models.encoder.rnn_encoder import RNNEncoder
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
from funasr.models.encoder.transformer_encoder import TransformerEncoder
from funasr.models.encoder.mfcca_encoder import MFCCAEncoder
from funasr.models.encoder.resnet34_encoder import ResNet34Diar
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.frontend.default import DefaultFrontend
from funasr.models.frontend.default import MultiChannelFrontend
from funasr.models.frontend.fused import FusedFrontends
from funasr.models.frontend.s3prl import S3prlFrontend
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.models.frontend.windowing import SlidingWindow
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.postencoder.hugging_face_transformers_postencoder import (
    HuggingFaceTransformersPostEncoder,  # noqa: H301
)
from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3, BATPredictor
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.preencoder.linear import LinearProjection
from funasr.models.preencoder.sinc import LightweightSincConvs
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.specaug.specaug import SpecAug
from funasr.models.specaug.specaug import SpecAugLFR
from funasr.modules.subsampling import Conv1dSubsampling
from funasr.tasks.abs_task import AbsTask
from funasr.tokenizer.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.initialize import initialize
from funasr.models.base_model import FunASRModel
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer
from funasr.utils.get_default_kwargs import get_default_kwargs
from funasr.utils.nested_dict_action import NestedDictAction
from funasr.utils.types import float_or_none
from funasr.utils.types import int_or_none
from funasr.utils.types import str2bool
from funasr.utils.types import str_or_none
from funasr.models.whisper_models.model import Whisper, AudioEncoder, TextDecoder
# ---------------------------------------------------------------------------
# Pluggable-component registries. Each ClassChoices maps a CLI choice name
# (e.g. "--frontend wav_frontend") to the class implementing it; AbsTask uses
# them to auto-generate the --<name> and --<name>_conf arguments.
# ---------------------------------------------------------------------------

# Feature-extraction frontends (--frontend / --frontend_conf).
frontend_choices = ClassChoices(
    name="frontend",
    classes=dict(
        default=DefaultFrontend,
        sliding_window=SlidingWindow,
        s3prl=S3prlFrontend,
        fused=FusedFrontends,
        wav_frontend=WavFrontend,
        multichannelfrontend=MultiChannelFrontend,
    ),
    type_check=AbsFrontend,
    default="default",
)

# Spectrogram augmentation (--specaug / --specaug_conf); optional.
specaug_choices = ClassChoices(
    name="specaug",
    classes=dict(
        specaug=SpecAug,
        specaug_lfr=SpecAugLFR,
    ),
    type_check=AbsSpecAug,
    default=None,
    optional=True,
)

# Feature-normalization layers (--normalize / --normalize_conf); optional.
normalize_choices = ClassChoices(
    "normalize",
    classes=dict(
        global_mvn=GlobalMVN,
        utterance_mvn=UtteranceMVN,
    ),
    type_check=AbsNormalize,
    default=None,
    optional=True,
)

# Top-level model architectures (--model / --model_conf).
model_choices = ClassChoices(
    "model",
    classes=dict(
        asr=ASRModel,
        uniasr=UniASR,
        paraformer=Paraformer,
        paraformer_online=ParaformerOnline,
        paraformer_bert=ParaformerBert,
        bicif_paraformer=BiCifParaformer,
        contextual_paraformer=ContextualParaformer,
        neatcontextual_paraformer=NeatContextualParaformer,
        mfcca=MFCCA,
        timestamp_prediction=TimestampPredictor,
        rnnt=TransducerModel,
        rnnt_unified=UnifiedTransducerModel,
        bat=BATModel,
        sa_asr=SAASRModel,
        whisper=Whisper,
    ),
    type_check=FunASRModel,
    default="asr",
)
# Pre-encoder layers applied before the main encoder
# (--preencoder / --preencoder_conf); optional.
preencoder_choices = ClassChoices(
    name="preencoder",
    classes=dict(
        sinc=LightweightSincConvs,
        linear=LinearProjection,
    ),
    type_check=AbsPreEncoder,
    default=None,
    optional=True,
)

# Main acoustic encoders (--encoder / --encoder_conf).
encoder_choices = ClassChoices(
    "encoder",
    classes=dict(
        conformer=ConformerEncoder,
        transformer=TransformerEncoder,
        rnn=RNNEncoder,
        sanm=SANMEncoder,
        sanm_chunk_opt=SANMEncoderChunkOpt,
        data2vec_encoder=Data2VecEncoder,
        mfcca_enc=MFCCAEncoder,
        chunk_conformer=ConformerChunkEncoder,
    ),
    type_check=AbsEncoder,
    default="rnn",
)

# Second-pass encoders (--encoder2 / --encoder2_conf), e.g. used by UniASR.
encoder_choices2 = ClassChoices(
    "encoder2",
    classes=dict(
        conformer=ConformerEncoder,
        transformer=TransformerEncoder,
        rnn=RNNEncoder,
        sanm=SANMEncoder,
        sanm_chunk_opt=SANMEncoderChunkOpt,
    ),
    type_check=AbsEncoder,
    default="rnn",
)

# ASR branch encoders for the speaker-attributed ASR model
# (--asr_encoder / --asr_encoder_conf).
asr_encoder_choices = ClassChoices(
    "asr_encoder",
    classes=dict(
        conformer=ConformerEncoder,
        transformer=TransformerEncoder,
        rnn=RNNEncoder,
        sanm=SANMEncoder,
        sanm_chunk_opt=SANMEncoderChunkOpt,
        data2vec_encoder=Data2VecEncoder,
        mfcca_enc=MFCCAEncoder,
    ),
    type_check=AbsEncoder,
    default="rnn",
)

# Speaker-embedding encoders (--spk_encoder / --spk_encoder_conf).
# NOTE: no type_check here, unlike the other encoder registries.
spk_encoder_choices = ClassChoices(
    "spk_encoder",
    classes=dict(
        resnet34_diar=ResNet34Diar,
    ),
    default="resnet34_diar",
)

# Post-encoder adapters (--postencoder / --postencoder_conf); optional.
postencoder_choices = ClassChoices(
    name="postencoder",
    classes=dict(
        hugging_face_transformers=HuggingFaceTransformersPostEncoder,
    ),
    type_check=AbsPostEncoder,
    default=None,
    optional=True,
)

# Decoders (--decoder / --decoder_conf).
decoder_choices = ClassChoices(
    "decoder",
    classes=dict(
        transformer=TransformerDecoder,
        lightweight_conv=LightweightConvolutionTransformerDecoder,
        lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
        dynamic_conv=DynamicConvolutionTransformerDecoder,
        dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
        rnn=RNNDecoder,
        fsmn_scama_opt=FsmnDecoderSCAMAOpt,
        paraformer_decoder_sanm=ParaformerSANMDecoder,
        paraformer_decoder_san=ParaformerDecoderSAN,
        contextual_paraformer_decoder=ContextualParaformerDecoder,
        sa_decoder=SAAsrTransformerDecoder,
    ),
    type_check=AbsDecoder,
    default="rnn",
)
# Second-pass decoders (--decoder2 / --decoder2_conf), e.g. used by UniASR.
decoder_choices2 = ClassChoices(
    "decoder2",
    classes=dict(
        transformer=TransformerDecoder,
        lightweight_conv=LightweightConvolutionTransformerDecoder,
        lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
        dynamic_conv=DynamicConvolutionTransformerDecoder,
        dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
        rnn=RNNDecoder,
        fsmn_scama_opt=FsmnDecoderSCAMAOpt,
        paraformer_decoder_sanm=ParaformerSANMDecoder,
    ),
    type_check=AbsDecoder,
    default="rnn",
)

# Transducer prediction networks (--rnnt_decoder / --rnnt_decoder_conf).
rnnt_decoder_choices = ClassChoices(
    "rnnt_decoder",
    classes=dict(
        rnnt=RNNTDecoder,
    ),
    type_check=RNNTDecoder,
    default="rnnt",
)

# Transducer joint networks (--joint_network / --joint_network_conf); optional.
joint_network_choices = ClassChoices(
    name="joint_network",
    classes=dict(
        joint_network=JointNetwork,
    ),
    default="joint_network",
    optional=True,
)

# Token-boundary predictors for Paraformer-style models
# (--predictor / --predictor_conf); optional.
# NOTE: ctc_predictor deliberately maps to None (no predictor module).
predictor_choices = ClassChoices(
    name="predictor",
    classes=dict(
        cif_predictor=CifPredictor,
        ctc_predictor=None,
        cif_predictor_v2=CifPredictorV2,
        cif_predictor_v3=CifPredictorV3,
        bat_predictor=BATPredictor,
    ),
    type_check=None,
    default="cif_predictor",
    optional=True,
)

# Second-pass predictors (--predictor2 / --predictor2_conf); optional.
predictor_choices2 = ClassChoices(
    name="predictor2",
    classes=dict(
        cif_predictor=CifPredictor,
        ctc_predictor=None,
        cif_predictor_v2=CifPredictorV2,
    ),
    type_check=None,
    default="cif_predictor",
    optional=True,
)

# Strided-convolution subsampling modules (--stride_conv / --stride_conv_conf);
# optional.
stride_conv_choices = ClassChoices(
    name="stride_conv",
    classes=dict(
        stride_conv1d=Conv1dSubsampling
    ),
    type_check=None,
    default="stride_conv1d",
    optional=True,
)
class ASRTask(AbsTask):
    """ASR task definition wiring the component registries into AbsTask.

    This variant builds a Whisper model (see ``build_model``) but reuses the
    generic FunASR task plumbing for argument parsing, preprocessing and
    checkpoint loading.
    """

    # If you need more than one optimizers, change this value
    num_optimizers: int = 1

    # Add variable objects configurations; AbsTask turns each entry into a
    # --<name> / --<name>_conf argument pair.
    class_choices_list = [
        # --frontend and --frontend_conf
        frontend_choices,
        # --specaug and --specaug_conf
        specaug_choices,
        # --normalize and --normalize_conf
        normalize_choices,
        # --model and --model_conf
        model_choices,
        # --preencoder and --preencoder_conf
        preencoder_choices,
        # --encoder and --encoder_conf
        encoder_choices,
        # --postencoder and --postencoder_conf
        postencoder_choices,
        # --decoder and --decoder_conf
        decoder_choices,
        # --predictor and --predictor_conf
        predictor_choices,
        # --encoder2 and --encoder2_conf
        encoder_choices2,
        # --decoder2 and --decoder2_conf
        decoder_choices2,
        # --predictor2 and --predictor2_conf
        predictor_choices2,
        # --stride_conv and --stride_conv_conf
        stride_conv_choices,
        # --rnnt_decoder and --rnnt_decoder_conf
        rnnt_decoder_choices,
    ]

    # If you need to modify train() or eval() procedures, change Trainer class here
    trainer = Trainer
    @classmethod
    def add_task_arguments(cls, parser: argparse.ArgumentParser):
        """Register task-specific CLI arguments on *parser*.

        Adds the "Task related" and "Preprocess related" argument groups,
        then lets every registry in ``cls.class_choices_list`` append its own
        --<name> / --<name>_conf pair.
        """
        group = parser.add_argument_group(description="Task related")
        # NOTE(kamo): add_arguments(..., required=True) can't be used
        # to provide --print_config mode. Instead of it, do as
        # required = parser.get_default("required")
        # required += ["token_list"]
        group.add_argument(
            "--token_list",
            type=str_or_none,
            default=None,
            help="A text mapping int-id to token",
        )
        group.add_argument(
            "--split_with_space",
            type=str2bool,
            default=True,
            help="whether to split text using <space>",
        )
        group.add_argument(
            "--max_spk_num",
            type=int_or_none,
            default=None,
            help="A text mapping int-id to token",
        )
        group.add_argument(
            "--seg_dict_file",
            type=str,
            default=None,
            help="seg_dict_file for text processing",
        )
        group.add_argument(
            "--init",
            type=lambda x: str_or_none(x.lower()),
            default=None,
            help="The initialization method",
            choices=[
                "chainer",
                "xavier_uniform",
                "xavier_normal",
                "kaiming_uniform",
                "kaiming_normal",
                None,
            ],
        )
        group.add_argument(
            "--input_size",
            type=int_or_none,
            default=None,
            help="The number of input dimension of the feature",
        )
        group.add_argument(
            "--ctc_conf",
            action=NestedDictAction,
            default=get_default_kwargs(CTC),
            help="The keyword arguments for CTC class.",
        )
        group = parser.add_argument_group(description="Preprocess related")
        group.add_argument(
            "--use_preprocessor",
            type=str2bool,
            default=True,
            help="Apply preprocessing to data or not",
        )
        group.add_argument(
            "--token_type",
            type=str,
            default="bpe",
            choices=["bpe", "char", "word", "phn"],
            help="The text will be tokenized " "in the specified level token",
        )
        group.add_argument(
            "--bpemodel",
            type=str_or_none,
            default=None,
            help="The model file of sentencepiece",
        )
        # NOTE(review): the arguments below are added to `parser` directly
        # rather than to `group` — presumably historical; confirm before
        # changing, since it affects --print_config grouping only.
        parser.add_argument(
            "--non_linguistic_symbols",
            type=str_or_none,
            default=None,
            help="non_linguistic_symbols file path",
        )
        parser.add_argument(
            "--cleaner",
            type=str_or_none,
            choices=[None, "tacotron", "jaconv", "vietnamese"],
            default=None,
            help="Apply text cleaning",
        )
        parser.add_argument(
            "--g2p",
            type=str_or_none,
            choices=g2p_choices,
            default=None,
            help="Specify g2p method if --token_type=phn",
        )
        parser.add_argument(
            "--speech_volume_normalize",
            type=float_or_none,
            default=None,
            help="Scale the maximum amplitude to the given value.",
        )
        parser.add_argument(
            "--rir_scp",
            type=str_or_none,
            default=None,
            help="The file path of rir scp file.",
        )
        parser.add_argument(
            "--rir_apply_prob",
            type=float,
            default=1.0,
            help="THe probability for applying RIR convolution.",
        )
        parser.add_argument(
            "--cmvn_file",
            type=str_or_none,
            default=None,
            help="The file path of noise scp file.",
        )
        parser.add_argument(
            "--noise_scp",
            type=str_or_none,
            default=None,
            help="The file path of noise scp file.",
        )
        parser.add_argument(
            "--noise_apply_prob",
            type=float,
            default=1.0,
            help="The probability applying Noise adding.",
        )
        parser.add_argument(
            "--noise_db_range",
            type=str,
            default="13_15",
            help="The range of noise decibel level.",
        )
        for class_choices in cls.class_choices_list:
            # Append --<name> and --<name>_conf.
            # e.g. --encoder and --encoder_conf
            class_choices.add_arguments(group)
  472. @classmethod
  473. def build_collate_fn(
  474. cls, args: argparse.Namespace, train: bool
  475. ) -> Callable[
  476. [Collection[Tuple[str, Dict[str, np.ndarray]]]],
  477. Tuple[List[str], Dict[str, torch.Tensor]],
  478. ]:
  479. # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
  480. return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
  481. @classmethod
  482. def build_preprocess_fn(
  483. cls, args: argparse.Namespace, train: bool
  484. ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
  485. if args.use_preprocessor:
  486. retval = CommonPreprocessor(
  487. train=train,
  488. token_type=args.token_type,
  489. token_list=args.token_list,
  490. bpemodel=args.bpemodel,
  491. non_linguistic_symbols=args.non_linguistic_symbols if hasattr(args, "non_linguistic_symbols") else None,
  492. text_cleaner=args.cleaner,
  493. g2p_type=args.g2p,
  494. split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
  495. seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None,
  496. # NOTE(kamo): Check attribute existence for backward compatibility
  497. rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
  498. rir_apply_prob=args.rir_apply_prob
  499. if hasattr(args, "rir_apply_prob")
  500. else 1.0,
  501. noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
  502. noise_apply_prob=args.noise_apply_prob
  503. if hasattr(args, "noise_apply_prob")
  504. else 1.0,
  505. noise_db_range=args.noise_db_range
  506. if hasattr(args, "noise_db_range")
  507. else "13_15",
  508. speech_volume_normalize=args.speech_volume_normalize
  509. if hasattr(args, "rir_scp")
  510. else None,
  511. )
  512. else:
  513. retval = None
  514. return retval
  515. @classmethod
  516. def required_data_names(
  517. cls, train: bool = True, inference: bool = False
  518. ) -> Tuple[str, ...]:
  519. if not inference:
  520. retval = ("speech", "text")
  521. else:
  522. # Recognition mode
  523. retval = ("speech",)
  524. return retval
  525. @classmethod
  526. def optional_data_names(
  527. cls, train: bool = True, inference: bool = False
  528. ) -> Tuple[str, ...]:
  529. retval = ()
  530. return retval
    @classmethod
    def build_model(cls, args: argparse.Namespace):
        """Build the model from parsed arguments.

        For this task the model class (default: Whisper via --model) is
        constructed solely from ``args.whisper_dims``; see the NOTE below
        about the other components built here.
        """
        if args.token_list is not None:
            if isinstance(args.token_list, str):
                with open(args.token_list, encoding="utf-8") as f:
                    token_list = [line.rstrip() for line in f]
                # Overwriting token_list to keep it as "portable".
                args.token_list = list(token_list)
            elif isinstance(args.token_list, (tuple, list)):
                token_list = list(args.token_list)
            else:
                raise RuntimeError("token_list must be str or list")
            vocab_size = len(token_list)
            logging.info(f"Vocabulary size: {vocab_size}")
        else:
            vocab_size = args.vocab_size
        # 1. frontend
        if args.input_size is None:
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            if args.frontend == 'wav_frontend':
                # wav_frontend additionally needs the CMVN statistics file
                frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
            else:
                frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from data-loader
            args.frontend = None
            args.frontend_conf = {}
            frontend = None
            input_size = args.input_size
        # 2. Data augmentation for spectrogram
        if args.specaug is not None:
            specaug_class = specaug_choices.get_class(args.specaug)
            specaug = specaug_class(**args.specaug_conf)
        else:
            specaug = None
        # 3. Normalization layer
        if args.normalize is not None:
            normalize_class = normalize_choices.get_class(args.normalize)
            normalize = normalize_class(**args.normalize_conf)
        else:
            normalize = None
        # 9. Build model
        # NOTE(review): vocab_size, frontend, specaug and normalize are
        # computed above but never passed to the model — only
        # args.whisper_dims is. Presumably the Whisper model handles feature
        # extraction internally; confirm before removing the unused setup.
        try:
            model_class = model_choices.get_class(args.model)
        except AttributeError:
            # Configs without a "model" entry fall back to the plain ASR model
            model_class = model_choices.get_class("asr")
        model = model_class(
            args.whisper_dims,
        )
        # 10. Initialize
        if args.init is not None:
            initialize(model, args.init)
        return model
  586. # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
  587. @classmethod
  588. def build_model_from_file(
  589. cls,
  590. config_file: Union[Path, str] = None,
  591. model_file: Union[Path, str] = None,
  592. cmvn_file: Union[Path, str] = None,
  593. device: str = "cpu",
  594. ):
  595. """Build model from the files.
  596. This method is used for inference or fine-tuning.
  597. Args:
  598. config_file: The yaml file saved when training.
  599. model_file: The model file saved when training.
  600. device: Device type, "cpu", "cuda", or "cuda:N".
  601. """
  602. if config_file is None:
  603. assert model_file is not None, (
  604. "The argument 'model_file' must be provided "
  605. "if the argument 'config_file' is not specified."
  606. )
  607. config_file = Path(model_file).parent / "config.yaml"
  608. else:
  609. config_file = Path(config_file)
  610. with config_file.open("r", encoding="utf-8") as f:
  611. args = yaml.safe_load(f)
  612. if cmvn_file is not None:
  613. args["cmvn_file"] = cmvn_file
  614. args = argparse.Namespace(**args)
  615. if model_file is not None:
  616. model_dict = torch.load(model_file, map_location=device)
  617. args.whisper_dims = model_dict["dims"]
  618. model = cls.build_model(args)
  619. if not isinstance(model, FunASRModel):
  620. raise RuntimeError(
  621. f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
  622. )
  623. model.to(device)
  624. model_dict = dict()
  625. model_name_pth = None
  626. if model_file is not None:
  627. logging.info("model_file is {}".format(model_file))
  628. if device == "cuda":
  629. device = f"cuda:{torch.cuda.current_device()}"
  630. model_dir = os.path.dirname(model_file)
  631. model_name = os.path.basename(model_file)
  632. model_dict = torch.load(model_file, map_location=device)
  633. model.load_state_dict(model_dict["model_state_dict"])
  634. if model_name_pth is not None and not os.path.exists(model_name_pth):
  635. torch.save(model_dict, model_name_pth)
  636. logging.info("model_file is saved to pth: {}".format(model_name_pth))
  637. return model, args