|
|
@@ -76,9 +76,8 @@ class TargetDelayTransformer():
|
|
|
try:
|
|
|
outputs = self.infer(data['text'], data['text_lengths'])
|
|
|
y = outputs[0]
|
|
|
- _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
|
|
|
- punctuations = indices
|
|
|
- assert punctuations.size()[0] == len(mini_sentence)
|
|
|
+ punctuations = np.argmax(y,axis=-1)[0]
|
|
|
+ assert punctuations.size == len(mini_sentence)
|
|
|
except ONNXRuntimeError:
|
|
|
logging.warning("error")
|
|
|
|