|
|
@@ -6,6 +6,8 @@ import time
|
|
|
import copy
|
|
|
import os
|
|
|
import codecs
|
|
|
+import tempfile
|
|
|
+import requests
|
|
|
from pathlib import Path
|
|
|
from typing import Optional
|
|
|
from typing import Sequence
|
|
|
@@ -175,10 +177,24 @@ class Speech2Text:
|
|
|
self.converter = converter
|
|
|
self.tokenizer = tokenizer
|
|
|
|
|
|
- # 6. [Optional] Build hotword list from file or str
|
|
|
+ # 6. [Optional] Build hotword list from str, local file or url
|
|
|
+ # for None
|
|
|
if hotword_list_or_file is None:
|
|
|
self.hotword_list = None
|
|
|
+ # for text str input
|
|
|
+ elif not os.path.exists(hotword_list_or_file) and not hotword_list_or_file.startswith('http'):
|
|
|
+ logging.info("Attempting to parse hotwords as str...")
|
|
|
+ self.hotword_list = []
|
|
|
+ hotword_str_list = []
|
|
|
+ for hw in hotword_list_or_file.strip().split():
|
|
|
+ hotword_str_list.append(hw)
|
|
|
+ self.hotword_list.append(self.converter.tokens2ids([i for i in hw]))
|
|
|
+ self.hotword_list.append([self.asr_model.sos])
|
|
|
+ hotword_str_list.append('<s>')
|
|
|
+ logging.info("Hotword list: {}.".format(hotword_str_list))
|
|
|
+ # for local txt inputs
|
|
|
elif os.path.exists(hotword_list_or_file):
|
|
|
+ logging.info("Attempting to parse hotwords from local txt...")
|
|
|
self.hotword_list = []
|
|
|
hotword_str_list = []
|
|
|
with codecs.open(hotword_list_or_file, 'r') as fin:
|
|
|
@@ -186,20 +202,31 @@ class Speech2Text:
|
|
|
hw = line.strip()
|
|
|
hotword_str_list.append(hw)
|
|
|
self.hotword_list.append(self.converter.tokens2ids([i for i in hw]))
|
|
|
- self.hotword_list.append([1])
|
|
|
+ self.hotword_list.append([self.asr_model.sos])
|
|
|
hotword_str_list.append('<s>')
|
|
|
logging.info("Initialized hotword list from file: {}, hotword list: {}."
|
|
|
.format(hotword_list_or_file, hotword_str_list))
|
|
|
+ # for url, download and generate txt
|
|
|
else:
|
|
|
- logging.info("Attempting to parse hotwords as str...")
|
|
|
+ logging.info("Attempting to parse hotwords from url...")
|
|
|
+ work_dir = tempfile.TemporaryDirectory().name
|
|
|
+ if not os.path.exists(work_dir):
|
|
|
+ os.makedirs(work_dir)
|
|
|
+ text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file))
|
|
|
+ local_file = requests.get(hotword_list_or_file)
|
|
|
+ open(text_file_path, "wb").write(local_file.content)
|
|
|
+ hotword_list_or_file = text_file_path
|
|
|
self.hotword_list = []
|
|
|
hotword_str_list = []
|
|
|
- for hw in hotword_list_or_file.strip().split():
|
|
|
- hotword_str_list.append(hw)
|
|
|
- self.hotword_list.append(self.converter.tokens2ids([i for i in hw]))
|
|
|
- self.hotword_list.append([1])
|
|
|
- hotword_str_list.append('<s>')
|
|
|
- logging.info("Hotword list: {}.".format(hotword_str_list))
|
|
|
+ with codecs.open(hotword_list_or_file, 'r') as fin:
|
|
|
+ for line in fin.readlines():
|
|
|
+ hw = line.strip()
|
|
|
+ hotword_str_list.append(hw)
|
|
|
+ self.hotword_list.append(self.converter.tokens2ids([i for i in hw]))
|
|
|
+ self.hotword_list.append([self.asr_model.sos])
|
|
|
+ hotword_str_list.append('<s>')
|
|
|
+ logging.info("Initialized hotword list from file: {}, hotword list: {}."
|
|
|
+ .format(hotword_list_or_file, hotword_str_list))
|
|
|
|
|
|
|
|
|
is_use_lm = lm_weight != 0.0 and lm_file is not None
|