#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

import os
import re
import time
import torch
import codecs
import logging
import tempfile
import requests
import numpy as np
from typing import Dict, Tuple
from contextlib import contextmanager
from distutils.version import LooseVersion

from funasr.register import tables
from funasr.losses.label_smoothing_loss import (
    LabelSmoothingLoss,  # noqa: H301
)
from funasr.utils import postprocess_utils
from funasr.metrics.compute_acc import th_accuracy
from funasr.models.paraformer.model import Paraformer
from funasr.utils.datadir_writer import DatadirWriter
from funasr.models.paraformer.search import Hypothesis
from funasr.train_utils.device_funcs import force_gatherable
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank

if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch < 1.6.0
    @contextmanager
    def autocast(enabled=True):
        yield


@tables.register("model_classes", "ContextualParaformer")
class ContextualParaformer(Paraformer):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    FunASR: A Fundamental End-to-End Speech Recognition Toolkit
    https://arxiv.org/abs/2305.11013
    """

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        self.target_buffer_length = kwargs.get("target_buffer_length", -1)
        inner_dim = kwargs.get("inner_dim", 256)
        bias_encoder_type = kwargs.get("bias_encoder_type", "lstm")
        use_decoder_embedding = kwargs.get("use_decoder_embedding", False)
        crit_attn_weight = kwargs.get("crit_attn_weight", 0.0)
        crit_attn_smooth = kwargs.get("crit_attn_smooth", 0.0)
        bias_encoder_dropout_rate = kwargs.get("bias_encoder_dropout_rate", 0.0)
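
        # "lstm" encodes each hotword with a single-layer LSTM (the output at
        # the last valid step serves as its context vector); "mean" builds only
        # the embedding table.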
        if bias_encoder_type == "lstm":
            self.bias_encoder = torch.nn.LSTM(
                inner_dim, inner_dim, 1, batch_first=True, dropout=bias_encoder_dropout_rate
            )
            self.bias_embed = torch.nn.Embedding(self.vocab_size, inner_dim)
        elif bias_encoder_type == "mean":
            self.bias_embed = torch.nn.Embedding(self.vocab_size, inner_dim)
        else:
            logging.error("Unsupported bias encoder type: {}".format(bias_encoder_type))

        if self.target_buffer_length > 0:
            self.hotword_buffer = None
            self.length_record = []
            self.current_buffer_length = 0
        self.use_decoder_embedding = use_decoder_embedding
        self.crit_attn_weight = crit_attn_weight
        if self.crit_attn_weight > 0:
            self.attn_loss = torch.nn.L1Loss()
        self.crit_attn_smooth = crit_attn_smooth

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch,)
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        hotword_pad = kwargs.get("hotword_pad")
        hotword_lengths = kwargs.get("hotword_lengths")
        dha_pad = kwargs.get("dha_pad")
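        # hotword_pad / hotword_lengths come from the training dataloader;
        # dha_pad is retrieved here but not used further on this code path.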

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        loss_ctc, cer_ctc = None, None
        stats = dict()

        # 2a. CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, cer_ctc = self._calc_ctc_loss(encoder_out, encoder_out_lens, text, text_lengths)

            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc

        # 2b. Attention decoder branch
        loss_att, acc_att, cer_att, wer_att, loss_pre, loss_ideal = self._calc_att_clas_loss(
            encoder_out, encoder_out_lens, text, text_lengths, hotword_pad, hotword_lengths
        )

        # 3. CTC-Att loss definition
        if self.ctc_weight == 0.0:
            loss = loss_att + loss_pre * self.predictor_weight
        else:
            loss = (
                self.ctc_weight * loss_ctc
                + (1 - self.ctc_weight) * loss_att
                + loss_pre * self.predictor_weight
            )

        if loss_ideal is not None:
            loss = loss + loss_ideal * self.crit_attn_weight
            stats["loss_ideal"] = loss_ideal.detach().cpu()

        # Collect Attn branch stats
        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
        stats["acc"] = acc_att
        stats["cer"] = cer_att
        stats["wer"] = wer_att
        stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
        stats["loss"] = torch.clone(loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            batch_size = int((text_lengths + self.predictor_bias).sum())
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def _calc_att_clas_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
        hotword_pad: torch.Tensor,
        hotword_lengths: torch.Tensor,
    ):
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device
        )
        if self.predictor_bias == 1:
            _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
            ys_pad_lens = ys_pad_lens + self.predictor_bias
        pre_acoustic_embeds, pre_token_length, _, _ = self.predictor(
            encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id
        )

        # -1. bias encoder
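        # Embed the padded hotword batch and run it through the bias encoder;
        # for each hotword, the LSTM output at its last valid step serves as
        # the context vector, which is then tiled across the batch.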
        if self.use_decoder_embedding:
            hw_embed = self.decoder.embed(hotword_pad)
        else:
            hw_embed = self.bias_embed(hotword_pad)
        hw_embed, (_, _) = self.bias_encoder(hw_embed)
        _ind = np.arange(0, hotword_pad.shape[0]).tolist()
        selected = hw_embed[_ind, [i - 1 for i in hotword_lengths.detach().cpu().tolist()]]
        contextual_info = selected.squeeze(0).repeat(ys_pad.shape[0], 1, 1).to(ys_pad.device)

        # 0. sampler
        decoder_out_1st = None
        if self.sampling_ratio > 0.0:
            if self.step_cur < 2:
                logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
            sematic_embeds, decoder_out_1st = self.sampler(
                encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, contextual_info
            )
        else:
            if self.step_cur < 2:
                logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
            sematic_embeds = pre_acoustic_embeds

        # 1. Forward decoder
        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info
        )
        decoder_out, _ = decoder_outs[0], decoder_outs[1]

        # NOTE: the ideal-attention loss is disabled in this version;
        # the original implementation is kept below for reference.
        '''
        if self.crit_attn_weight > 0 and attn.shape[-1] > 1:
            ideal_attn = ideal_attn + self.crit_attn_smooth / (self.crit_attn_smooth + 1.0)
            attn_non_blank = attn[:, :, :, :-1]
            ideal_attn_non_blank = ideal_attn[:, :, :-1]
            loss_ideal = self.attn_loss(attn_non_blank.max(1)[0], ideal_attn_non_blank.to(attn.device))
        else:
            loss_ideal = None
        '''
        loss_ideal = None

        if decoder_out_1st is None:
            decoder_out_1st = decoder_out
        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_pad)
        acc_att = th_accuracy(
            decoder_out_1st.view(-1, self.vocab_size),
            ys_pad,
            ignore_label=self.ignore_id,
        )
        loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)

        # Compute cer/wer using attention-decoder
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out_1st.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        return loss_att, acc_att, cer_att, wer_att, loss_pre, loss_ideal

    def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, contextual_info):
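        # Glancing-style sampling: run the decoder once without gradients,
        # count how many positions it already predicts correctly, then replace
        # a number of acoustic-embedding positions proportional to the
        # remaining errors with ground-truth token embeddings at random
        # positions.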
        tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
        ys_pad = ys_pad * tgt_mask[:, :, 0]
        if self.share_embedding:
            ys_pad_embed = self.decoder.output_layer.weight[ys_pad]
        else:
            ys_pad_embed = self.decoder.embed(ys_pad)
        with torch.no_grad():
            decoder_outs = self.decoder(
                encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens, contextual_info=contextual_info
            )
            decoder_out, _ = decoder_outs[0], decoder_outs[1]
            pred_tokens = decoder_out.argmax(-1)
            nonpad_positions = ys_pad.ne(self.ignore_id)
            seq_lens = (nonpad_positions).sum(1)
            same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
            input_mask = torch.ones_like(nonpad_positions)
            bsz, seq_len = ys_pad.size()
            for li in range(bsz):
                target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
                if target_num > 0:
                    input_mask[li].scatter_(
                        dim=0,
                        index=torch.randperm(seq_lens[li])[:target_num].to(pre_acoustic_embeds.device),
                        value=0,
                    )
            input_mask = input_mask.eq(1)
            input_mask = input_mask.masked_fill(~nonpad_positions, False)
            input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)

        sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
            input_mask_expand_dim, 0
        )
        return sematic_embeds * tgt_mask, decoder_out * tgt_mask

    def cal_decoder_with_predictor(
        self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None, clas_scale=1.0
    ):
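        # With no hotwords, a single placeholder token (id 1) is encoded so
        # the decoder still receives a (dummy) context vector; otherwise the
        # padded hotword batch is packed and run through the LSTM, whose final
        # hidden state per hotword is tiled across the batch.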
        if hw_list is None:
            hw_list = [torch.Tensor([1]).long().to(encoder_out.device)]  # empty hotword list
            hw_list_pad = pad_list(hw_list, 0)
            if self.use_decoder_embedding:
                hw_embed = self.decoder.embed(hw_list_pad)
            else:
                hw_embed = self.bias_embed(hw_list_pad)
            hw_embed, (h_n, _) = self.bias_encoder(hw_embed)
            hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)
        else:
            hw_lengths = [len(i) for i in hw_list]
            hw_list_pad = pad_list([torch.Tensor(i).long() for i in hw_list], 0).to(encoder_out.device)
            if self.use_decoder_embedding:
                hw_embed = self.decoder.embed(hw_list_pad)
            else:
                hw_embed = self.bias_embed(hw_list_pad)
            hw_embed = torch.nn.utils.rnn.pack_padded_sequence(
                hw_embed, hw_lengths, batch_first=True, enforce_sorted=False
            )
            _, (h_n, _) = self.bias_encoder(hw_embed)
            hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)

        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=hw_embed, clas_scale=clas_scale
        )
        decoder_out = decoder_outs[0]
        decoder_out = torch.log_softmax(decoder_out, dim=-1)
        return decoder_out, ys_pad_lens

    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        # init beam search
        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc is not None
        is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
        if self.beam_search is None and (is_use_lm or is_use_ctc):
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)

        meta_data = {}

        # extract fbank feats
        time1 = time.perf_counter()
        audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
        time2 = time.perf_counter()
        meta_data["load_data"] = f"{time2 - time1:0.3f}"
        speech, speech_lengths = extract_fbank(
            audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
        )
        time3 = time.perf_counter()
        meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
        meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000

        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])

        # hotword
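        # The "hotword" kwarg may be a space-separated string, a local .txt
        # file, or a URL to one; see generate_hotwords_list below.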
        self.hotword_list = self.generate_hotwords_list(
            kwargs.get("hotword", None), tokenizer=tokenizer, frontend=frontend
        )

        # Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]

        # predictor
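        # The (CIF-style) predictor estimates how many tokens the output
        # should contain and yields one acoustic embedding per predicted
        # token.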
        predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = (
            predictor_outs[0],
            predictor_outs[1],
            predictor_outs[2],
            predictor_outs[3],
        )
        pre_token_length = pre_token_length.round().long()
        if torch.max(pre_token_length) < 1:
            return []
        decoder_outs = self.cal_decoder_with_predictor(
            encoder_out,
            encoder_out_lens,
            pre_acoustic_embeds,
            pre_token_length,
            hw_list=self.hotword_list,
            clas_scale=kwargs.get("clas_scale", 1.0),
        )
        decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]

        results = []
        b, n, d = decoder_out.size()
        for i in range(b):
            x = encoder_out[i, : encoder_out_lens[i], :]
            am_scores = decoder_out[i, : pre_token_length[i], :]
            if self.beam_search is not None:
                nbest_hyps = self.beam_search(
                    x=x,
                    am_scores=am_scores,
                    maxlenratio=kwargs.get("maxlenratio", 0.0),
                    minlenratio=kwargs.get("minlenratio", 0.0),
                )
                nbest_hyps = nbest_hyps[: self.nbest]
            else:
                yseq = am_scores.argmax(dim=-1)
                score = am_scores.max(dim=-1)[0]
                score = torch.sum(score, dim=-1)
                # pad with mask tokens to ensure compatibility with sos/eos tokens
                yseq = torch.tensor([self.sos] + yseq.tolist() + [self.eos], device=yseq.device)
                nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
            for nbest_idx, hyp in enumerate(nbest_hyps):
                ibest_writer = None
                if kwargs.get("output_dir") is not None:
                    if not hasattr(self, "writer"):
                        self.writer = DatadirWriter(kwargs.get("output_dir"))
                    ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]
                # remove sos/eos and get results
                last_pos = -1
                if isinstance(hyp.yseq, list):
                    token_int = hyp.yseq[1:last_pos]
                else:
                    token_int = hyp.yseq[1:last_pos].tolist()

                # remove blank symbol id, which is assumed to be 0
                token_int = list(
                    filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int)
                )

                if tokenizer is not None:
                    # Change integer-ids to tokens
                    token = tokenizer.ids2tokens(token_int)
                    text = tokenizer.tokens2text(token)

                    text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                    result_i = {"key": key[i], "text": text_postprocessed}

                    if ibest_writer is not None:
                        ibest_writer["token"][key[i]] = " ".join(token)
                        ibest_writer["text"][key[i]] = text
                        ibest_writer["text_postprocessed"][key[i]] = text_postprocessed
                else:
                    result_i = {"key": key[i], "token_int": token_int}
                results.append(result_i)

        return results, meta_data

    def generate_hotwords_list(self, hotword_list_or_file, tokenizer=None, frontend=None):
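        # Accepts four input forms: None (no biasing), a local .txt file with
        # one hotword per line, a URL to such a file, or a space-separated
        # string. An optional `seg_dict` next to the model's cmvn file is used
        # to re-segment words into modeling units before tokenization.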
        def load_seg_dict(seg_dict_file):
            seg_dict = {}
            assert isinstance(seg_dict_file, str)
            with open(seg_dict_file, "r", encoding="utf8") as f:
                lines = f.readlines()
                for line in lines:
                    s = line.strip().split()
                    key = s[0]
                    value = s[1:]
                    seg_dict[key] = " ".join(value)
            return seg_dict

        def seg_tokenize(txt, seg_dict):
            pattern = re.compile(r"^[\u4E00-\u9FA50-9]+$")
            out_txt = ""
            for word in txt:
                word = word.lower()
                if word in seg_dict:
                    out_txt += seg_dict[word] + " "
                else:
                    if pattern.match(word):
                        for char in word:
                            if char in seg_dict:
                                out_txt += seg_dict[char] + " "
                            else:
                                out_txt += "<unk>" + " "
                    else:
                        out_txt += "<unk>" + " "
            return out_txt.strip().split()

        seg_dict = None
        if frontend.cmvn_file is not None:
            model_dir = os.path.dirname(frontend.cmvn_file)
            seg_dict_file = os.path.join(model_dir, "seg_dict")
            if os.path.exists(seg_dict_file):
                seg_dict = load_seg_dict(seg_dict_file)

        # for None
        if hotword_list_or_file is None:
            hotword_list = None
        # for local txt inputs
        elif os.path.exists(hotword_list_or_file) and hotword_list_or_file.endswith(".txt"):
            logging.info("Attempting to parse hotwords from local txt...")
            hotword_list = []
            hotword_str_list = []
            with codecs.open(hotword_list_or_file, "r") as fin:
                for line in fin.readlines():
                    hw = line.strip()
                    hw_list = hw.split()
                    if seg_dict is not None:
                        hw_list = seg_tokenize(hw_list, seg_dict)
                    hotword_str_list.append(hw)
                    hotword_list.append(tokenizer.tokens2ids(hw_list))
            hotword_list.append([self.sos])
            hotword_str_list.append("<s>")
            logging.info(
                "Initialized hotword list from file: {}, hotword list: {}.".format(
                    hotword_list_or_file, hotword_str_list
                )
            )
        # for url, download and generate txt
        elif hotword_list_or_file.startswith("http"):
            logging.info("Attempting to parse hotwords from url...")
            work_dir = tempfile.TemporaryDirectory().name
            if not os.path.exists(work_dir):
                os.makedirs(work_dir)
            text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file))
            local_file = requests.get(hotword_list_or_file)
            with open(text_file_path, "wb") as fout:
                fout.write(local_file.content)
            hotword_list_or_file = text_file_path
            hotword_list = []
            hotword_str_list = []
            with codecs.open(hotword_list_or_file, "r") as fin:
                for line in fin.readlines():
                    hw = line.strip()
                    hw_list = hw.split()
                    if seg_dict is not None:
                        hw_list = seg_tokenize(hw_list, seg_dict)
                    hotword_str_list.append(hw)
                    hotword_list.append(tokenizer.tokens2ids(hw_list))
            hotword_list.append([self.sos])
            hotword_str_list.append("<s>")
            logging.info(
                "Initialized hotword list from file: {}, hotword list: {}.".format(
                    hotword_list_or_file, hotword_str_list
                )
            )
        # for text str input
        elif not hotword_list_or_file.endswith(".txt"):
            logging.info("Attempting to parse hotwords as str...")
            hotword_list = []
            hotword_str_list = []
            for hw in hotword_list_or_file.strip().split():
                hotword_str_list.append(hw)
                hw_list = hw.strip().split()
                if seg_dict is not None:
                    hw_list = seg_tokenize(hw_list, seg_dict)
                hotword_list.append(tokenizer.tokens2ids(hw_list))
            hotword_list.append([self.sos])
            hotword_str_list.append("<s>")
            logging.info("Hotword list: {}.".format(hotword_str_list))
        else:
            hotword_list = None
        return hotword_list
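

if __name__ == "__main__":
    # Hedged usage sketch (not part of the upstream file): contextual
    # Paraformer checkpoints are normally driven through funasr.AutoModel
    # rather than by instantiating this class directly; the hotword list is
    # passed per call as a space-separated string, a local .txt file, or a
    # URL (see generate_hotwords_list above). The model id and wav path below
    # are assumptions; substitute the checkpoint and audio you actually use.
    from funasr import AutoModel

    model = AutoModel(model="iic/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404")
    res = model.generate(input="asr_example_hotword.wav", hotword="魔搭")
    print(res)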