# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import multiprocessing
import os
from multiprocessing import Pool

import numpy as np
import tritonclient.grpc as grpcclient

from speech_client import OfflineSpeechClient, StreamingSpeechClient
from utils import cal_cer
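
# Benchmark client for a WeNet-style ASR model served by Triton Inference
# Server: it reads audio paths (a single file or a wav.scp list), sends them
# to the server from parallel worker processes, and optionally scores the
# results with the character error rate (CER) against reference transcripts.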

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        required=False,
        default=False,
        help="Enable verbose output",
    )
    parser.add_argument(
        "-u",
        "--url",
        type=str,
        required=False,
        default="localhost:10086",
        help="Inference server URL. Default is localhost:10086.",
    )
    parser.add_argument(
        "--model_name",
        required=False,
        default="attention_rescoring",
        choices=["attention_rescoring", "streaming_wenet", "infer_pipeline"],
        help="the model to send requests to",
    )
    parser.add_argument(
        "--wavscp",
        type=str,
        required=False,
        default=None,
        help="wav list file, one 'audio_id<TAB>wav_path' entry per line",
    )
    parser.add_argument(
        "--trans",
        type=str,
        required=False,
        default=None,
        help="transcript file, one 'audio_id<TAB>text' entry per line",
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        required=False,
        default=None,
        help="path prefix for wav_path in wavscp/audio_file",
    )
    parser.add_argument(
        "--audio_file",
        type=str,
        required=False,
        default=None,
        help="single wav file path",
    )
    # The arguments below are only used in streaming mode and must match the
    # values used when the model was trained/exported; please check
    # onnx_config.yaml and train.yaml.
    parser.add_argument(
        "--streaming",
        action="store_true",
        required=False,
        help="use the streaming client instead of the offline one",
    )
    parser.add_argument(
        "--sample_rate",
        type=int,
        required=False,
        default=16000,
        help="sample rate used in training",
    )
    parser.add_argument(
        "--frame_length_ms",
        type=int,
        required=False,
        default=25,
        help="frame length in milliseconds",
    )
    parser.add_argument(
        "--frame_shift_ms",
        type=int,
        required=False,
        default=10,
        help="frame shift in milliseconds",
    )
    parser.add_argument(
        "--chunk_size",
        type=int,
        required=False,
        default=16,
        help="decoding chunk size (default: 16)",
    )
    parser.add_argument(
        "--context",
        type=int,
        required=False,
        default=7,
        help="subsampling context",
    )
    parser.add_argument(
        "--subsampling",
        type=int,
        required=False,
        default=4,
        help="subsampling rate",
    )
    FLAGS = parser.parse_args()
    print(FLAGS)

    # Load input data: --audio_file takes precedence over --wavscp; reference
    # transcripts from --trans are used for CER scoring at the end.
    filenames = []
    transcripts = []
    if FLAGS.audio_file is not None:
        path = FLAGS.audio_file
        if FLAGS.data_dir:
            path = os.path.join(FLAGS.data_dir, path)
        if os.path.exists(path):
            filenames = [path]
    elif FLAGS.wavscp is not None:
        audio_data = {}
        with open(FLAGS.wavscp, "r", encoding="utf-8") as f:
            for line in f:
                aid, path = line.strip().split("\t")
                if FLAGS.data_dir:
                    path = os.path.join(FLAGS.data_dir, path)
                audio_data[aid] = {"path": path}
        with open(FLAGS.trans, "r", encoding="utf-8") as f:
            for line in f:
                aid, text = line.strip().split("\t")
                audio_data[aid]["text"] = text
        for key, value in audio_data.items():
            filenames.append(value["path"])
            transcripts.append(value["text"])
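
    # Use half of the available CPU cores as parallel client workers.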
    num_workers = multiprocessing.cpu_count() // 2

    if FLAGS.streaming:
        speech_client_cls = StreamingSpeechClient
    else:
        speech_client_cls = OfflineSpeechClient
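
    # Each worker opens its own gRPC connection and transcribes one shard of
    # the file list; the shard index `idx` is forwarded to recognize(), e.g.
    # so streaming requests from different workers can be told apart. Note
    # that single_job is defined inside the __main__ guard, so the Pool below
    # relies on the default "fork" start method (Linux).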
    def single_job(client_files):
        with grpcclient.InferenceServerClient(
            url=FLAGS.url, verbose=FLAGS.verbose
        ) as triton_client:
            protocol_client = grpcclient
            speech_client = speech_client_cls(
                triton_client, FLAGS.model_name, protocol_client, FLAGS
            )
            idx, audio_files = client_files
            predictions = []
            for li in audio_files:
                result = speech_client.recognize(li, idx)
                print("Recognized {}:{}".format(li, result[0]))
                predictions += result
        return predictions

    # Start inference: group the files into one shard of requests per worker.
    predictions = []
    tasks = []
    splits = np.array_split(filenames, num_workers)
    for idx, per_split in enumerate(splits):
        cur_files = per_split.tolist()
        tasks.append((idx, cur_files))

    with Pool(processes=num_workers) as pool:
        predictions = pool.map(single_job, tasks)
    # Flatten the per-worker result lists into a single list of predictions.
    predictions = [item for sublist in predictions for item in sublist]

    if transcripts:
        cer = cal_cer(predictions, transcripts)
        print("CER is: {}".format(cer))