|
|
@@ -76,9 +76,8 @@ class TargetDelayTransformer():
|
|
|
try:
|
|
|
outputs = self.infer(data['text'], data['text_lengths'])
|
|
|
y = outputs[0]
|
|
|
- _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
|
|
|
- punctuations = indices
|
|
|
- assert punctuations.size()[0] == len(mini_sentence)
|
|
|
+ punctuations = np.argmax(y,axis=-1)[0]
|
|
|
+ assert punctuations.size == len(mini_sentence)
|
|
|
except ONNXRuntimeError:
|
|
|
logging.warning("error")
|
|
|
|
|
|
@@ -102,8 +101,7 @@ class TargetDelayTransformer():
|
|
|
mini_sentence = mini_sentence[0:sentenceEnd + 1]
|
|
|
punctuations = punctuations[0:sentenceEnd + 1]
|
|
|
|
|
|
- punctuations_np = punctuations.cpu().numpy()
|
|
|
- new_mini_sentence_punc += [int(x) for x in punctuations_np]
|
|
|
+ new_mini_sentence_punc += [int(x) for x in punctuations]
|
|
|
words_with_punc = []
|
|
|
for i in range(len(mini_sentence)):
|
|
|
if i > 0:
|