# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from tritonclient.utils import np_to_triton_dtype
  15. import numpy as np
  16. import math
  17. import soundfile as sf
  18. class OfflineSpeechClient(object):
  19. def __init__(self, triton_client, model_name, protocol_client, args):
  20. self.triton_client = triton_client
  21. self.protocol_client = protocol_client
  22. self.model_name = model_name
  23. def recognize(self, wav_file, idx=0):
  24. waveform, sample_rate = sf.read(wav_file)
  25. samples = np.array([waveform], dtype=np.float32)
  26. lengths = np.array([[len(waveform)]], dtype=np.int32)
  27. # better pad waveform to nearest length here
  28. # target_seconds = math.cel(len(waveform) / sample_rate)
  29. # target_samples = np.zeros([1, target_seconds * sample_rate])
  30. # target_samples[0][0: len(waveform)] = waveform
  31. # samples = target_samples
  32. sequence_id = 10086 + idx
  33. result = ""
  34. inputs = [
  35. self.protocol_client.InferInput(
  36. "WAV", samples.shape, np_to_triton_dtype(samples.dtype)
  37. ),
  38. self.protocol_client.InferInput(
  39. "WAV_LENS", lengths.shape, np_to_triton_dtype(lengths.dtype)
  40. ),
  41. ]
  42. inputs[0].set_data_from_numpy(samples)
  43. inputs[1].set_data_from_numpy(lengths)
  44. outputs = [self.protocol_client.InferRequestedOutput("TRANSCRIPTS")]
  45. response = self.triton_client.infer(
  46. self.model_name,
  47. inputs,
  48. request_id=str(sequence_id),
  49. outputs=outputs,
  50. )
  51. result = response.as_numpy("TRANSCRIPTS")[0].decode("utf-8")
  52. return [result]
  53. class StreamingSpeechClient(object):
  54. def __init__(self, triton_client, model_name, protocol_client, args):
  55. self.triton_client = triton_client
  56. self.protocol_client = protocol_client
  57. self.model_name = model_name
  58. chunk_size = args.chunk_size
  59. subsampling = args.subsampling
  60. context = args.context
  61. frame_shift_ms = args.frame_shift_ms
  62. frame_length_ms = args.frame_length_ms
  63. # for the first chunk
  64. # we need additional frames to generate
  65. # the exact first chunk length frames
  66. # since the subsampling will look ahead several frames
  67. first_chunk_length = (chunk_size - 1) * subsampling + context
  68. add_frames = math.ceil(
  69. (frame_length_ms - frame_shift_ms) / frame_shift_ms
  70. )
  71. first_chunk_ms = (first_chunk_length + add_frames) * frame_shift_ms
  72. other_chunk_ms = chunk_size * subsampling * frame_shift_ms
  73. self.first_chunk_in_secs = first_chunk_ms / 1000
  74. self.other_chunk_in_secs = other_chunk_ms / 1000
  75. def recognize(self, wav_file, idx=0):
  76. waveform, sample_rate = sf.read(wav_file)
  77. wav_segs = []
  78. i = 0
  79. while i < len(waveform):
  80. if i == 0:
  81. stride = int(self.first_chunk_in_secs * sample_rate)
  82. wav_segs.append(waveform[i : i + stride])
  83. else:
  84. stride = int(self.other_chunk_in_secs * sample_rate)
  85. wav_segs.append(waveform[i : i + stride])
  86. i += len(wav_segs[-1])
  87. sequence_id = idx + 10086
  88. # simulate streaming
  89. for idx, seg in enumerate(wav_segs):
  90. chunk_len = len(seg)
  91. if idx == 0:
  92. chunk_samples = int(self.first_chunk_in_secs * sample_rate)
  93. expect_input = np.zeros((1, chunk_samples), dtype=np.float32)
  94. else:
  95. chunk_samples = int(self.other_chunk_in_secs * sample_rate)
  96. expect_input = np.zeros((1, chunk_samples), dtype=np.float32)
  97. expect_input[0][0:chunk_len] = seg
  98. input0_data = expect_input
  99. input1_data = np.array([[chunk_len]], dtype=np.int32)
  100. inputs = [
  101. self.protocol_client.InferInput(
  102. "WAV",
  103. input0_data.shape,
  104. np_to_triton_dtype(input0_data.dtype),
  105. ),
  106. self.protocol_client.InferInput(
  107. "WAV_LENS",
  108. input1_data.shape,
  109. np_to_triton_dtype(input1_data.dtype),
  110. ),
  111. ]
  112. inputs[0].set_data_from_numpy(input0_data)
  113. inputs[1].set_data_from_numpy(input1_data)
  114. outputs = [self.protocol_client.InferRequestedOutput("TRANSCRIPTS")]
  115. end = False
  116. if idx == len(wav_segs) - 1:
  117. end = True
  118. response = self.triton_client.infer(
  119. self.model_name,
  120. inputs,
  121. outputs=outputs,
  122. sequence_id=sequence_id,
  123. sequence_start=idx == 0,
  124. sequence_end=end,
  125. )
  126. idx += 1
  127. result = response.as_numpy("TRANSCRIPTS")[0].decode("utf-8")
  128. print("Get response from {}th chunk: {}".format(idx, result))
  129. return [result]