|
|
@@ -99,7 +99,8 @@ class TritonPythonModel:
|
|
|
feats_len = torch.tensor(speech_len, dtype=torch.int32).to(self.device)
|
|
|
|
|
|
with torch.no_grad():
|
|
|
- logits = self.model(feats, feats_len)[0]
|
|
|
+ outputs = self.model(feats, feats_len)
|
|
|
+ logits, token_num = outputs[0], outputs[1]
|
|
|
|
|
|
def replace_space(tokens):
    """Return a new list in which every '<space>' token becomes ' '.

    All other entries of ``tokens`` are copied through unchanged and
    in their original order.
    """
    converted = []
    for token in tokens:
        if token != '<space>':
            converted.append(token)
        else:
            # The vocabulary spells a blank as the literal marker '<space>'.
            converted.append(' ')
    return converted
|
|
|
@@ -107,6 +108,7 @@ class TritonPythonModel:
|
|
|
yseq = logits.argmax(axis=-1).tolist()
|
|
|
token_int = [list(filter(lambda x: x not in (0, 2), y)) for y in yseq]
|
|
|
tokens = [[self.vocab_dict[i] for i in t] for t in token_int]
|
|
|
+ tokens = [t[:int(token_num[i]) - 1] for i, t in enumerate(tokens)]
|
|
|
hyps = [''.join(replace_space(t)).encode('utf-8') for t in tokens]
|
|
|
responses = []
|
|
|
for i in range(len(requests)):
|