# model.py
  1. #!/usr/bin/env python3
  2. # -*- encoding: utf-8 -*-
  3. # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
  4. # MIT License (https://opensource.org/licenses/MIT)
  5. import os
  6. import re
  7. import time
  8. import copy
  9. import torch
  10. import codecs
  11. import logging
  12. import tempfile
  13. import requests
  14. import numpy as np
  15. from typing import Dict, Tuple
  16. from contextlib import contextmanager
  17. from distutils.version import LooseVersion
  18. from funasr.register import tables
  19. from funasr.utils import postprocess_utils
  20. from funasr.models.paraformer.model import Paraformer
  21. from funasr.utils.datadir_writer import DatadirWriter
  22. from funasr.models.paraformer.search import Hypothesis
  23. from funasr.train_utils.device_funcs import force_gatherable
  24. from funasr.models.bicif_paraformer.model import BiCifParaformer
  25. from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
  26. from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
  27. from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
  28. from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
  29. from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
  30. import pdb
  31. if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
  32. from torch.cuda.amp import autocast
  33. else:
  34. # Nothing to do if torch<1.6.0
  35. @contextmanager
  36. def autocast(enabled=True):
  37. yield
  38. @tables.register("model_classes", "SeacoParaformer")
  39. class SeacoParaformer(BiCifParaformer, Paraformer):
  40. """
  41. Author: Speech Lab of DAMO Academy, Alibaba Group
  42. SeACo-Paraformer: A Non-Autoregressive ASR System with Flexible and Effective Hotword Customization Ability
  43. https://arxiv.org/abs/2308.03266
  44. """
  45. def __init__(
  46. self,
  47. *args,
  48. **kwargs,
  49. ):
  50. super().__init__(*args, **kwargs)
  51. self.inner_dim = kwargs.get("inner_dim", 256)
  52. self.bias_encoder_type = kwargs.get("bias_encoder_type", "lstm")
  53. bias_encoder_dropout_rate = kwargs.get("bias_encoder_dropout_rate", 0.0)
  54. bias_encoder_bid = kwargs.get("bias_encoder_bid", False)
  55. seaco_lsm_weight = kwargs.get("seaco_lsm_weight", 0.0)
  56. seaco_length_normalized_loss = kwargs.get("seaco_length_normalized_loss", True)
  57. # bias encoder
  58. if self.bias_encoder_type == 'lstm':
  59. self.bias_encoder = torch.nn.LSTM(self.inner_dim,
  60. self.inner_dim,
  61. 2,
  62. batch_first=True,
  63. dropout=bias_encoder_dropout_rate,
  64. bidirectional=bias_encoder_bid)
  65. if bias_encoder_bid:
  66. self.lstm_proj = torch.nn.Linear(self.inner_dim*2, self.inner_dim)
  67. else:
  68. self.lstm_proj = None
  69. # self.bias_embed = torch.nn.Embedding(self.vocab_size, self.inner_dim)
  70. elif self.bias_encoder_type == 'mean':
  71. self.bias_embed = torch.nn.Embedding(self.vocab_size, self.inner_dim)
  72. else:
  73. logging.error("Unsupport bias encoder type: {}".format(self.bias_encoder_type))
  74. # seaco decoder
  75. seaco_decoder = kwargs.get("seaco_decoder", None)
  76. if seaco_decoder is not None:
  77. seaco_decoder_conf = kwargs.get("seaco_decoder_conf")
  78. seaco_decoder_class = tables.decoder_classes.get(seaco_decoder)
  79. self.seaco_decoder = seaco_decoder_class(
  80. vocab_size=self.vocab_size,
  81. encoder_output_size=self.inner_dim,
  82. **seaco_decoder_conf,
  83. )
  84. self.hotword_output_layer = torch.nn.Linear(self.inner_dim, self.vocab_size)
  85. self.criterion_seaco = LabelSmoothingLoss(
  86. size=self.vocab_size,
  87. padding_idx=self.ignore_id,
  88. smoothing=seaco_lsm_weight,
  89. normalize_length=seaco_length_normalized_loss,
  90. )
  91. self.train_decoder = kwargs.get("train_decoder", False)
  92. self.NO_BIAS = kwargs.get("NO_BIAS", 8377)
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
            kwargs: must also carry hotword_pad, hotword_lengths and dha_pad
                (the hotword-aware targets) for the seaco loss.

        Returns:
            (loss, stats, weight) made DataParallel-safe via force_gatherable.
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        hotword_pad = kwargs.get("hotword_pad")
        hotword_lengths = kwargs.get("hotword_lengths")
        dha_pad = kwargs.get("dha_pad")
        batch_size = speech.shape[0]
        # for data-parallel: trim padding to the longest item in this shard
        text = text[:, : text_lengths.max()]
        speech = speech[:, :speech_lengths.max()]
        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        # NOTE(review): ys_pad / ys_lengths are only bound when
        # predictor_bias == 1; any other value raises NameError below — confirm
        # the config always sets predictor_bias to 1.
        if self.predictor_bias == 1:
            _, ys_pad = add_sos_eos(text, self.sos, self.eos, self.ignore_id)
            ys_lengths = text_lengths + self.predictor_bias
        stats = dict()
        # seaco (hotword-biasing) loss
        loss_seaco = self._calc_seaco_loss(encoder_out,
                                           encoder_out_lens,
                                           ys_pad,
                                           ys_lengths,
                                           hotword_pad,
                                           hotword_lengths,
                                           dha_pad,
                                           )
        if self.train_decoder:
            # optionally keep training the base attention decoder alongside seaco
            loss_att, acc_att = self._calc_att_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )
            loss = loss_seaco + loss_att
            stats["loss_att"] = torch.clone(loss_att.detach())
            stats["acc_att"] = acc_att
        else:
            loss = loss_seaco
        stats["loss_seaco"] = torch.clone(loss_seaco.detach())
        stats["loss"] = torch.clone(loss.detach())
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size)
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
  153. def _merge(self, cif_attended, dec_attended):
  154. return cif_attended + dec_attended
    def _calc_seaco_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_lengths: torch.Tensor,
        hotword_pad: torch.Tensor,
        hotword_lengths: torch.Tensor,
        dha_pad: torch.Tensor,
    ):
        """Compute the label-smoothed seaco (hotword-prediction) loss.

        Args:
            encoder_out / encoder_out_lens: acoustic encoder outputs and lengths.
            ys_pad / ys_lengths: target token ids (with sos/eos applied) and lengths.
            hotword_pad / hotword_lengths: padded hotword token ids and their lengths.
            dha_pad: per-position hotword targets (presumably NO_BIAS where no
                hotword applies — confirm against the data pipeline).

        Returns:
            Scalar seaco loss from self.criterion_seaco.
        """
        # predictor forward: CIF acoustic embeddings aligned to target tokens
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device)
        pre_acoustic_embeds, _, _, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask,
                                                      ignore_id=self.ignore_id)
        # decoder forward (return_hidden=True: hidden states, not logits)
        decoder_out, _ = self.decoder(encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_lengths, return_hidden=True)
        # one embedding per hotword; the same hotword set is shared by every
        # utterance in the batch, hence the repeat along dim 0
        selected = self._hotword_representation(hotword_pad,
                                                hotword_lengths)
        contextual_info = selected.squeeze(0).repeat(encoder_out.shape[0], 1, 1).to(encoder_out.device)
        num_hot_word = contextual_info.shape[1]
        _contextual_length = torch.Tensor([num_hot_word]).int().repeat(encoder_out.shape[0]).to(encoder_out.device)
        # dha core: attend over hotwords from both the CIF and decoder streams,
        # then sum the two attention outputs (see _merge)
        cif_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, pre_acoustic_embeds, ys_lengths)
        dec_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, decoder_out, ys_lengths)
        merged = self._merge(cif_attended, dec_attended)
        dha_output = self.hotword_output_layer(merged[:, :-1])  # remove the last token in loss calculation
        loss_att = self.criterion_seaco(dha_output, dha_pad)
        return loss_att
  184. def _seaco_decode_with_ASF(self,
  185. encoder_out,
  186. encoder_out_lens,
  187. sematic_embeds,
  188. ys_pad_lens,
  189. hw_list,
  190. nfilter=50,
  191. seaco_weight=1.0):
  192. # decoder forward
  193. decoder_out, decoder_hidden, _ = self.decoder(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, return_hidden=True, return_both=True)
  194. decoder_pred = torch.log_softmax(decoder_out, dim=-1)
  195. if hw_list is not None:
  196. hw_lengths = [len(i) for i in hw_list]
  197. hw_list_ = [torch.Tensor(i).long() for i in hw_list]
  198. hw_list_pad = pad_list(hw_list_, 0).to(encoder_out.device)
  199. selected = self._hotword_representation(hw_list_pad, torch.Tensor(hw_lengths).int().to(encoder_out.device))
  200. contextual_info = selected.squeeze(0).repeat(encoder_out.shape[0], 1, 1).to(encoder_out.device)
  201. num_hot_word = contextual_info.shape[1]
  202. _contextual_length = torch.Tensor([num_hot_word]).int().repeat(encoder_out.shape[0]).to(encoder_out.device)
  203. # ASF Core
  204. if nfilter > 0 and nfilter < num_hot_word:
  205. hotword_scores = self.seaco_decoder.forward_asf6(contextual_info, _contextual_length, decoder_hidden, ys_pad_lens)
  206. hotword_scores = hotword_scores[0].sum(0).sum(0)
  207. # hotword_scores /= torch.sqrt(torch.tensor(hw_lengths)[:-1].float()).to(hotword_scores.device)
  208. dec_filter = torch.topk(hotword_scores, min(nfilter, num_hot_word-1))[1].tolist()
  209. add_filter = dec_filter
  210. add_filter.append(len(hw_list_pad)-1)
  211. # filter hotword embedding
  212. selected = selected[add_filter]
  213. # again
  214. contextual_info = selected.squeeze(0).repeat(encoder_out.shape[0], 1, 1).to(encoder_out.device)
  215. num_hot_word = contextual_info.shape[1]
  216. _contextual_length = torch.Tensor([num_hot_word]).int().repeat(encoder_out.shape[0]).to(encoder_out.device)
  217. # SeACo Core
  218. cif_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, sematic_embeds, ys_pad_lens)
  219. dec_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, decoder_hidden, ys_pad_lens)
  220. merged = self._merge(cif_attended, dec_attended)
  221. dha_output = self.hotword_output_layer(merged) # remove the last token in loss calculation
  222. dha_pred = torch.log_softmax(dha_output, dim=-1)
  223. def _merge_res(dec_output, dha_output):
  224. lmbd = torch.Tensor([seaco_weight] * dha_output.shape[0])
  225. dha_ids = dha_output.max(-1)[-1]# [0]
  226. dha_mask = (dha_ids == 8377).int().unsqueeze(-1)
  227. a = (1 - lmbd) / lmbd
  228. b = 1 / lmbd
  229. a, b = a.to(dec_output.device), b.to(dec_output.device)
  230. dha_mask = (dha_mask + a.reshape(-1, 1, 1)) / b.reshape(-1, 1, 1)
  231. # logits = dec_output * dha_mask + dha_output[:,:,:-1] * (1-dha_mask)
  232. logits = dec_output * dha_mask + dha_output[:,:,:] * (1-dha_mask)
  233. return logits
  234. merged_pred = _merge_res(decoder_pred, dha_pred)
  235. return merged_pred
  236. else:
  237. return decoder_pred
  238. def _hotword_representation(self,
  239. hotword_pad,
  240. hotword_lengths):
  241. if self.bias_encoder_type != 'lstm':
  242. logging.error("Unsupported bias encoder type")
  243. '''
  244. hw_embed = self.decoder.embed(hotword_pad)
  245. hw_embed, (_, _) = self.bias_encoder(hw_embed)
  246. if self.lstm_proj is not None:
  247. hw_embed = self.lstm_proj(hw_embed)
  248. _ind = np.arange(0, hw_embed.shape[0]).tolist()
  249. selected = hw_embed[_ind, [i-1 for i in hotword_lengths.detach().cpu().tolist()]]
  250. return selected
  251. '''
  252. # hw_embed = self.sac_embedding(hotword_pad)
  253. hw_embed = self.decoder.embed(hotword_pad)
  254. hw_embed = torch.nn.utils.rnn.pack_padded_sequence(hw_embed, hotword_lengths.cpu().type(torch.int64), batch_first=True, enforce_sorted=False)
  255. packed_rnn_output, _ = self.bias_encoder(hw_embed)
  256. rnn_output = torch.nn.utils.rnn.pad_packed_sequence(packed_rnn_output, batch_first=True)[0]
  257. if self.lstm_proj is not None:
  258. hw_hidden = self.lstm_proj(rnn_output)
  259. else:
  260. hw_hidden = rnn_output
  261. _ind = np.arange(0, hw_hidden.shape[0]).tolist()
  262. selected = hw_hidden[_ind, [i-1 for i in hotword_lengths.detach().cpu().tolist()]]
  263. return selected
    def inference(self,
                  data_in,
                  data_lengths=None,
                  key: list = None,
                  tokenizer=None,
                  frontend=None,
                  **kwargs,
                  ):
        """Run hotword-biased recognition on raw audio inputs.

        Args:
            data_in: audio input(s) accepted by load_audio_text_image_video.
            data_lengths: unused here; lengths come from feature extraction.
            key: utterance ids, indexed in parallel with the batch.
            tokenizer: maps token ids to text; if None, raw ids are returned.
            frontend: feature-extraction frontend (fs, frame_shift, lfr_n, ...).
            kwargs: device, hotword, decoding weights, output_dir, etc.

        Returns:
            (results, meta_data): per-utterance dicts with text/timestamp
            (or token_int when no tokenizer) plus timing metadata; an empty
            list when the predictor yields no tokens at all.
        """
        # init beamsearch (only when CTC or LM rescoring is requested)
        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
        is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
        if self.beam_search is None and (is_use_lm or is_use_ctc):
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)
        meta_data = {}
        # extract fbank feats
        time1 = time.perf_counter()
        audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
        time2 = time.perf_counter()
        meta_data["load_data"] = f"{time2 - time1:0.3f}"
        speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
                                               frontend=frontend)
        time3 = time.perf_counter()
        meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
        meta_data[
            "batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])
        # hotword: parsed once per call and cached on the instance
        self.hotword_list = self.generate_hotwords_list(kwargs.get("hotword", None), tokenizer=tokenizer, frontend=frontend)
        # Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]
        # predictor: CIF token-count and acoustic-embedding prediction
        predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
        pre_acoustic_embeds, pre_token_length, _, _ = predictor_outs[0], predictor_outs[1], \
            predictor_outs[2], predictor_outs[3]
        pre_token_length = pre_token_length.round().long()
        if torch.max(pre_token_length) < 1:
            # no tokens predicted for the whole batch
            return []
        # decoder with ASF filtering + SeACo hotword biasing
        decoder_out = self._seaco_decode_with_ASF(encoder_out, encoder_out_lens,
                                                  pre_acoustic_embeds,
                                                  pre_token_length,
                                                  hw_list=self.hotword_list)
        # timestamps from the upsampled CIF alphas/peaks
        _, _, us_alphas, us_peaks = self.calc_predictor_timestamp(encoder_out, encoder_out_lens,
                                                                  pre_token_length)
        results = []
        b, n, d = decoder_out.size()
        for i in range(b):
            x = encoder_out[i, :encoder_out_lens[i], :]
            am_scores = decoder_out[i, :pre_token_length[i], :]
            if self.beam_search is not None:
                nbest_hyps = self.beam_search(
                    x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0),
                    minlenratio=kwargs.get("minlenratio", 0.0)
                )
                nbest_hyps = nbest_hyps[: self.nbest]
            else:
                # greedy decoding: per-frame argmax, summed max-score
                yseq = am_scores.argmax(dim=-1)
                score = am_scores.max(dim=-1)[0]
                score = torch.sum(score, dim=-1)
                # pad with mask tokens to ensure compatibility with sos/eos tokens
                yseq = torch.tensor(
                    [self.sos] + yseq.tolist() + [self.eos], device=yseq.device
                )
                nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
            for nbest_idx, hyp in enumerate(nbest_hyps):
                ibest_writer = None
                if kwargs.get("output_dir") is not None:
                    if not hasattr(self, "writer"):
                        self.writer = DatadirWriter(kwargs.get("output_dir"))
                    ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]
                # remove sos/eos and get results
                last_pos = -1
                if isinstance(hyp.yseq, list):
                    token_int = hyp.yseq[1:last_pos]
                else:
                    token_int = hyp.yseq[1:last_pos].tolist()
                # remove blank symbol id, which is assumed to be 0
                token_int = list(
                    filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
                if tokenizer is not None:
                    # Change integer-ids to tokens
                    token = tokenizer.ids2tokens(token_int)
                    text = tokenizer.tokens2text(token)
                    # token-level timestamps; encoder frames upsampled 3x
                    _, timestamp = ts_prediction_lfr6_standard(us_alphas[i][:encoder_out_lens[i] * 3],
                                                               us_peaks[i][:encoder_out_lens[i] * 3],
                                                               copy.copy(token),
                                                               vad_offset=kwargs.get("begin_time", 0))
                    text_postprocessed, time_stamp_postprocessed, word_lists = postprocess_utils.sentence_postprocess(
                        token, timestamp)
                    result_i = {"key": key[i], "text": text_postprocessed,
                                "timestamp": time_stamp_postprocessed
                                }
                    if ibest_writer is not None:
                        ibest_writer["token"][key[i]] = " ".join(token)
                        ibest_writer["timestamp"][key[i]] = time_stamp_postprocessed
                        ibest_writer["text"][key[i]] = text_postprocessed
                else:
                    result_i = {"key": key[i], "token_int": token_int}
                results.append(result_i)
        return results, meta_data
  369. def generate_hotwords_list(self, hotword_list_or_file, tokenizer=None, frontend=None):
  370. def load_seg_dict(seg_dict_file):
  371. seg_dict = {}
  372. assert isinstance(seg_dict_file, str)
  373. with open(seg_dict_file, "r", encoding="utf8") as f:
  374. lines = f.readlines()
  375. for line in lines:
  376. s = line.strip().split()
  377. key = s[0]
  378. value = s[1:]
  379. seg_dict[key] = " ".join(value)
  380. return seg_dict
  381. def seg_tokenize(txt, seg_dict):
  382. pattern = re.compile(r'^[\u4E00-\u9FA50-9]+$')
  383. out_txt = ""
  384. for word in txt:
  385. word = word.lower()
  386. if word in seg_dict:
  387. out_txt += seg_dict[word] + " "
  388. else:
  389. if pattern.match(word):
  390. for char in word:
  391. if char in seg_dict:
  392. out_txt += seg_dict[char] + " "
  393. else:
  394. out_txt += "<unk>" + " "
  395. else:
  396. out_txt += "<unk>" + " "
  397. return out_txt.strip().split()
  398. seg_dict = None
  399. if frontend.cmvn_file is not None:
  400. model_dir = os.path.dirname(frontend.cmvn_file)
  401. seg_dict_file = os.path.join(model_dir, 'seg_dict')
  402. if os.path.exists(seg_dict_file):
  403. seg_dict = load_seg_dict(seg_dict_file)
  404. else:
  405. seg_dict = None
  406. # for None
  407. if hotword_list_or_file is None:
  408. hotword_list = None
  409. # for local txt inputs
  410. elif os.path.exists(hotword_list_or_file) and hotword_list_or_file.endswith('.txt'):
  411. logging.info("Attempting to parse hotwords from local txt...")
  412. hotword_list = []
  413. hotword_str_list = []
  414. with codecs.open(hotword_list_or_file, 'r') as fin:
  415. for line in fin.readlines():
  416. hw = line.strip()
  417. hw_list = hw.split()
  418. if seg_dict is not None:
  419. hw_list = seg_tokenize(hw_list, seg_dict)
  420. hotword_str_list.append(hw)
  421. hotword_list.append(tokenizer.tokens2ids(hw_list))
  422. hotword_list.append([self.sos])
  423. hotword_str_list.append('<s>')
  424. logging.info("Initialized hotword list from file: {}, hotword list: {}."
  425. .format(hotword_list_or_file, hotword_str_list))
  426. # for url, download and generate txt
  427. elif hotword_list_or_file.startswith('http'):
  428. logging.info("Attempting to parse hotwords from url...")
  429. work_dir = tempfile.TemporaryDirectory().name
  430. if not os.path.exists(work_dir):
  431. os.makedirs(work_dir)
  432. text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file))
  433. local_file = requests.get(hotword_list_or_file)
  434. open(text_file_path, "wb").write(local_file.content)
  435. hotword_list_or_file = text_file_path
  436. hotword_list = []
  437. hotword_str_list = []
  438. with codecs.open(hotword_list_or_file, 'r') as fin:
  439. for line in fin.readlines():
  440. hw = line.strip()
  441. hw_list = hw.split()
  442. if seg_dict is not None:
  443. hw_list = seg_tokenize(hw_list, seg_dict)
  444. hotword_str_list.append(hw)
  445. hotword_list.append(tokenizer.tokens2ids(hw_list))
  446. hotword_list.append([self.sos])
  447. hotword_str_list.append('<s>')
  448. logging.info("Initialized hotword list from file: {}, hotword list: {}."
  449. .format(hotword_list_or_file, hotword_str_list))
  450. # for text str input
  451. elif not hotword_list_or_file.endswith('.txt'):
  452. logging.info("Attempting to parse hotwords as str...")
  453. hotword_list = []
  454. hotword_str_list = []
  455. for hw in hotword_list_or_file.strip().split():
  456. hotword_str_list.append(hw)
  457. hw_list = hw.strip().split()
  458. if seg_dict is not None:
  459. hw_list = seg_tokenize(hw_list, seg_dict)
  460. hotword_list.append(tokenizer.tokens2ids(hw_list))
  461. hotword_list.append([self.sos])
  462. hotword_str_list.append('<s>')
  463. logging.info("Hotword list: {}.".format(hotword_str_list))
  464. else:
  465. hotword_list = None
  466. return hotword_list