|
|
@@ -600,6 +600,9 @@ def inference_paraformer_vad_punc(
|
|
|
if 'hotword' in kwargs:
|
|
|
hotword_list_or_file = kwargs['hotword']
|
|
|
|
|
|
+ batch_size_token = kwargs.get("batch_size_token", 6000)
|
|
|
+ print("batch_size_token: ", batch_size_token)
|
|
|
+
|
|
|
if speech2text.hotword_list is None:
|
|
|
speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
|
|
|
|
|
|
@@ -642,8 +645,10 @@ def inference_paraformer_vad_punc(
|
|
|
assert all(isinstance(s, str) for s in keys), keys
|
|
|
_bs = len(next(iter(batch.values())))
|
|
|
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
|
|
|
-
|
|
|
+ beg_vad = time.time()
|
|
|
vad_results = speech2vadsegment(**batch)
|
|
|
+ end_vad = time.time()
|
|
|
+ print("time cost vad: ", end_vad-beg_vad)
|
|
|
_, vadsegments = vad_results[0], vad_results[1][0]
|
|
|
|
|
|
speech, speech_lengths = batch["speech"], batch["speech_lengths"]
|
|
|
@@ -652,17 +657,29 @@ def inference_paraformer_vad_punc(
|
|
|
data_with_index = [(vadsegments[i], i) for i in range(n)]
|
|
|
sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
|
|
|
results_sorted = []
|
|
|
- for j, beg_idx in enumerate(range(0, n, batch_size)):
|
|
|
- end_idx = min(n, beg_idx + batch_size)
|
|
|
+ batch_size_token_ms = batch_size_token*60
|
|
|
+ batch_size_token_ms_cum = 0
|
|
|
+ beg_idx = 0
|
|
|
+ for j, _ in enumerate(range(0, n)):
|
|
|
+ batch_size_token_ms_cum += (sorted_data[j][0][1] - sorted_data[j][0][0])
|
|
|
+ if j < n-1 and (batch_size_token_ms_cum + sorted_data[j+1][0][1] - sorted_data[j+1][0][0])<batch_size_token_ms:
|
|
|
+ continue
|
|
|
+ batch_size_token_ms_cum = 0
|
|
|
+ end_idx = j + 1
|
|
|
speech_j, speech_lengths_j = slice_padding_fbank(speech, speech_lengths, sorted_data[beg_idx:end_idx])
|
|
|
-
|
|
|
+ beg_idx = end_idx
|
|
|
batch = {"speech": speech_j, "speech_lengths": speech_lengths_j}
|
|
|
batch = to_device(batch, device=device)
|
|
|
+ print("batch: ", speech_j.shape[0])
|
|
|
+ beg_asr = time.time()
|
|
|
results = speech2text(**batch)
|
|
|
+ end_asr = time.time()
|
|
|
+ print("time cost asr: ", end_asr - beg_asr)
|
|
|
|
|
|
if len(results) < 1:
|
|
|
results = [["", [], [], [], [], [], []]]
|
|
|
results_sorted.extend(results)
|
|
|
+
|
|
|
restored_data = [0] * n
|
|
|
for j in range(n):
|
|
|
index = sorted_data[j][1]
|
|
|
@@ -701,7 +718,10 @@ def inference_paraformer_vad_punc(
|
|
|
text_postprocessed_punc = text_postprocessed
|
|
|
punc_id_list = []
|
|
|
if len(word_lists) > 0 and text2punc is not None:
|
|
|
+ beg_punc = time.time()
|
|
|
text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20)
|
|
|
+ end_punc = time.time()
|
|
|
+ print("time cost punc: ", end_punc-beg_punc)
|
|
|
|
|
|
item = {'key': key, 'value': text_postprocessed_punc}
|
|
|
if text_postprocessed != "":
|