| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142 |
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from tritonclient.utils import np_to_triton_dtype
- import numpy as np
- import math
- import soundfile as sf
class OfflineSpeechClient(object):
    """Client that submits a whole audio file to the model in one request."""

    def __init__(self, triton_client, model_name, protocol_client, args):
        """Keep the handles that ``recognize`` uses to issue a request.

        Args:
            triton_client: Object whose ``infer`` method performs the request.
            model_name: Model name forwarded to ``infer``.
            protocol_client: Object providing the ``InferInput`` and
                ``InferRequestedOutput`` factories.
            args: Accepted but not used by this class.
        """
        self.model_name = model_name
        self.protocol_client = protocol_client
        self.triton_client = triton_client

    def recognize(self, wav_file, idx=0):
        """Transcribe ``wav_file`` with a single inference call.

        Args:
            wav_file: Audio file handed to ``soundfile.read``.
            idx: Offset added to 10086 to form the request id.

        Returns:
            A one-element list holding the first ``TRANSCRIPTS`` entry of the
            response, decoded as UTF-8.
        """
        audio, _sample_rate = sf.read(wav_file)

        # Both tensors carry a leading batch dimension of size one.
        wav = np.array([audio], dtype=np.float32)
        wav_lens = np.array([[len(audio)]], dtype=np.int32)
        # NOTE: the waveform is sent at its exact length. Zero-padding it up
        # to a whole number of seconds (math.ceil(len(audio) / sample_rate))
        # was suggested here but is not implemented.

        factory = self.protocol_client
        wav_input = factory.InferInput(
            "WAV", wav.shape, np_to_triton_dtype(wav.dtype)
        )
        lens_input = factory.InferInput(
            "WAV_LENS", wav_lens.shape, np_to_triton_dtype(wav_lens.dtype)
        )
        wav_input.set_data_from_numpy(wav)
        lens_input.set_data_from_numpy(wav_lens)

        requested = [factory.InferRequestedOutput("TRANSCRIPTS")]
        response = self.triton_client.infer(
            self.model_name,
            [wav_input, lens_input],
            request_id=str(10086 + idx),
            outputs=requested,
        )

        transcript = response.as_numpy("TRANSCRIPTS")[0].decode("utf-8")
        return [transcript]
class StreamingSpeechClient(object):
    """Client that simulates streaming by sending a file chunk by chunk."""

    def __init__(self, triton_client, model_name, protocol_client, args):
        """Store the client handles and derive the chunk durations.

        Args:
            triton_client: Object whose ``infer`` method performs each request.
            model_name: Model name forwarded to ``infer``.
            protocol_client: Object providing the ``InferInput`` and
                ``InferRequestedOutput`` factories.
            args: Object exposing ``chunk_size``, ``subsampling``, ``context``,
                ``frame_shift_ms`` and ``frame_length_ms``.
        """
        self.triton_client = triton_client
        self.protocol_client = protocol_client
        self.model_name = model_name

        chunk_size = args.chunk_size
        subsampling = args.subsampling
        context = args.context
        frame_shift_ms = args.frame_shift_ms
        frame_length_ms = args.frame_length_ms

        # The first chunk needs additional frames to yield exactly
        # ``chunk_size`` output frames, because the subsampling looks ahead
        # several frames.
        first_chunk_length = (chunk_size - 1) * subsampling + context
        # Extra frame shifts needed to cover the part of a frame window that
        # extends past one shift.
        add_frames = math.ceil(
            (frame_length_ms - frame_shift_ms) / frame_shift_ms
        )
        first_chunk_ms = (first_chunk_length + add_frames) * frame_shift_ms
        other_chunk_ms = chunk_size * subsampling * frame_shift_ms

        self.first_chunk_in_secs = first_chunk_ms / 1000
        self.other_chunk_in_secs = other_chunk_ms / 1000

    @staticmethod
    def _split_waveform(waveform, first_samples, other_samples):
        """Cut ``waveform`` into one leading chunk plus equal follow-up chunks.

        Args:
            waveform: Sliceable sequence of samples.
            first_samples: Length of the first segment.
            other_samples: Length of every later segment.

        Returns:
            List of consecutive slices; the last one may be shorter than its
            nominal length. Empty when ``waveform`` is empty.

        Raises:
            ValueError: If the chunk length needed next is not positive, which
                would otherwise make the loop spin forever.
        """
        segments = []
        total = len(waveform)
        start = 0
        while start < total:
            stride = first_samples if start == 0 else other_samples
            if stride <= 0:
                raise ValueError(
                    "chunk length must be at least one sample, got {}".format(
                        stride
                    )
                )
            segments.append(waveform[start : start + stride])
            start += stride
        return segments

    def recognize(self, wav_file, idx=0):
        """Transcribe ``wav_file`` by sending it as a sequence of chunks.

        Every chunk is zero-padded to its nominal length and sent together
        with its true sample count, so the final (short) chunk keeps the same
        tensor shape as the others.

        Args:
            wav_file: Audio file handed to ``soundfile.read``.
                NOTE(review): the padding step assumes a 1-D (mono) waveform;
                confirm against the expected inputs.
            idx: Offset added to 10086 to form the sequence id.

        Returns:
            A one-element list with the transcript decoded from the response
            to the last chunk, or ``[""]`` when the file holds no samples.

        Raises:
            ValueError: If a chunk duration maps to fewer than one sample at
                the file's sample rate.
        """
        waveform, sample_rate = sf.read(wav_file)

        # Chunk lengths depend only on the sample rate, so compute them once.
        first_samples = int(self.first_chunk_in_secs * sample_rate)
        other_samples = int(self.other_chunk_in_secs * sample_rate)
        wav_segs = self._split_waveform(waveform, first_samples, other_samples)

        sequence_id = idx + 10086
        last_chunk = len(wav_segs) - 1
        # Stays empty when there is no audio, so the return below is always
        # defined.
        result = ""

        # simulate streaming
        for chunk_idx, seg in enumerate(wav_segs):
            chunk_len = len(seg)
            chunk_samples = first_samples if chunk_idx == 0 else other_samples

            input0_data = np.zeros((1, chunk_samples), dtype=np.float32)
            input0_data[0][0:chunk_len] = seg
            input1_data = np.array([[chunk_len]], dtype=np.int32)

            inputs = [
                self.protocol_client.InferInput(
                    "WAV",
                    input0_data.shape,
                    np_to_triton_dtype(input0_data.dtype),
                ),
                self.protocol_client.InferInput(
                    "WAV_LENS",
                    input1_data.shape,
                    np_to_triton_dtype(input1_data.dtype),
                ),
            ]
            inputs[0].set_data_from_numpy(input0_data)
            inputs[1].set_data_from_numpy(input1_data)
            outputs = [self.protocol_client.InferRequestedOutput("TRANSCRIPTS")]

            response = self.triton_client.infer(
                self.model_name,
                inputs,
                outputs=outputs,
                sequence_id=sequence_id,
                sequence_start=chunk_idx == 0,
                sequence_end=chunk_idx == last_chunk,
            )
            result = response.as_numpy("TRANSCRIPTS")[0].decode("utf-8")
            # Chunks are reported with 1-based numbering.
            print(
                "Get response from {}th chunk: {}".format(chunk_idx + 1, result)
            )
        return [result]
|