| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321 |
- # -*- encoding: utf-8 -*-
- import os.path
- from pathlib import Path
- from typing import List, Union, Tuple
- import copy
- import librosa
- import numpy as np
- from .utils.utils import (CharTokenizer, Hypothesis, ONNXRuntimeError,
- OrtInferSession, TokenIDConverter, get_logger,
- read_yaml)
- from .utils.postprocess_utils import sentence_postprocess
- from .utils.frontend import WavFrontendOnline, SinusoidalPositionEncoderOnline
- logging = get_logger()
class Paraformer():
    """ONNX runtime wrapper for the streaming (online) Paraformer ASR model.

    Loads encoder/decoder ONNX sessions plus frontend/tokenizer resources from
    `model_dir` (a local path, or a modelscope model name downloaded on demand).
    """

    def __init__(self, model_dir: Union[str, Path] = None,
                 batch_size: int = 1,
                 chunk_size: List = [5, 10, 5],
                 device_id: Union[str, int] = "-1",
                 quantize: bool = False,
                 intra_op_num_threads: int = 4,
                 cache_dir: str = None
                 ):
        """
        :param model_dir: local model directory, or a modelscope model name.
        :param batch_size: inference batch size (streaming path uses 1).
        :param chunk_size: [left context, middle, right context] sizes in frames.
        :param device_id: ONNX runtime device id; "-1" means CPU.
        :param quantize: load the ``*_quant.onnx`` model variants when True.
        :param intra_op_num_threads: ONNX runtime intra-op thread count.
        :param cache_dir: cache directory for modelscope download / onnx export.
        :raises ImportError: modelscope/funasr is required but not installed.
        :raises ValueError: `model_dir` cannot be resolved via modelscope.
        """
        if not Path(model_dir).exists():
            try:
                from modelscope.hub.snapshot_download import snapshot_download
            except ImportError as e:
                # BUG FIX: the original did `raise "<message>"`, which is itself
                # a TypeError in Python 3 (exceptions must derive from
                # BaseException), hiding the intended message entirely.
                raise ImportError(
                    "You are exporting model from modelscope, please install modelscope and try it again. To install modelscope, you could:\n"
                    "\npip3 install -U modelscope\n"
                    "For the users in China, you could install with the command:\n"
                    "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple"
                ) from e
            try:
                model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
            except Exception as e:
                raise ValueError(
                    "model_dir must be model_name in modelscope or local path "
                    "downloaded from modelscope, but is {}".format(model_dir)) from e

        encoder_model_file = os.path.join(model_dir, 'model.onnx')
        decoder_model_file = os.path.join(model_dir, 'decoder.onnx')
        if quantize:
            encoder_model_file = os.path.join(model_dir, 'model_quant.onnx')
            decoder_model_file = os.path.join(model_dir, 'decoder_quant.onnx')
        if not os.path.exists(encoder_model_file) or not os.path.exists(decoder_model_file):
            # Export on the fly when the onnx files are missing.
            logging.info(".onnx is not exist, begin to export onnx")
            try:
                from funasr.export.export_model import ModelExport
            except ImportError as e:
                raise ImportError(
                    "You are exporting onnx, please install funasr and try it again. To install funasr, you could:\n"
                    "\npip3 install -U funasr\n"
                    "For the users in China, you could install with the command:\n"
                    "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
                ) from e
            export_model = ModelExport(
                cache_dir=cache_dir,
                onnx=True,
                device="cpu",
                quant=quantize,
            )
            export_model.export(model_dir)

        config_file = os.path.join(model_dir, 'config.yaml')
        cmvn_file = os.path.join(model_dir, 'am.mvn')
        config = read_yaml(config_file)

        self.converter = TokenIDConverter(config['token_list'])
        self.tokenizer = CharTokenizer()
        self.frontend = WavFrontendOnline(
            cmvn_file=cmvn_file,
            **config['frontend_conf']
        )
        self.pe = SinusoidalPositionEncoderOnline()
        self.ort_encoder_infer = OrtInferSession(
            encoder_model_file, device_id, intra_op_num_threads=intra_op_num_threads)
        self.ort_decoder_infer = OrtInferSession(
            decoder_model_file, device_id, intra_op_num_threads=intra_op_num_threads)

        self.batch_size = batch_size
        self.chunk_size = chunk_size
        self.encoder_output_size = config["encoder_conf"]["output_size"]
        # FSMN decoder cache geometry (lorder = kernel_size - 1 lookback frames).
        self.fsmn_layer = config["decoder_conf"]["num_blocks"]
        self.fsmn_lorder = config["decoder_conf"]["kernel_size"] - 1
        self.fsmn_dims = config["encoder_conf"]["output_size"]
        self.feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"]
        # CIF predictor firing threshold and tail flush weight.
        self.cif_threshold = config["predictor_conf"]["threshold"]
        self.tail_threshold = config["predictor_conf"]["tail_threshold"]
- def prepare_cache(self, cache: dict = {}, batch_size=1):
- if len(cache) > 0:
- return cache
- cache["start_idx"] = 0
- cache["cif_hidden"] = np.zeros((batch_size, 1, self.encoder_output_size)).astype(np.float32)
- cache["cif_alphas"] = np.zeros((batch_size, 1)).astype(np.float32)
- cache["chunk_size"] = self.chunk_size
- cache["last_chunk"] = False
- cache["feats"] = np.zeros((batch_size, self.chunk_size[0] + self.chunk_size[2], self.feats_dims)).astype(np.float32)
- cache["decoder_fsmn"] = []
- for i in range(self.fsmn_layer):
- fsmn_cache = np.zeros((batch_size, self.fsmn_dims, self.fsmn_lorder)).astype(np.float32)
- cache["decoder_fsmn"].append(fsmn_cache)
- return cache
- def add_overlap_chunk(self, feats: np.ndarray, cache: dict = {}):
- if len(cache) == 0:
- return feats
- # process last chunk
- overlap_feats = np.concatenate((cache["feats"], feats), axis=1)
- if cache["is_final"]:
- cache["feats"] = overlap_feats[:, -self.chunk_size[0]:, :]
- if not cache["last_chunk"]:
- padding_length = sum(self.chunk_size) - overlap_feats.shape[1]
- overlap_feats = np.pad(overlap_feats, ((0, 0), (0, padding_length), (0, 0)))
- else:
- cache["feats"] = overlap_feats[:, -(self.chunk_size[0] + self.chunk_size[2]):, :]
- return overlap_feats
    def __call__(self, audio_in: np.ndarray, **kwargs):
        """Streaming entry point: run ASR over one audio chunk.

        `kwargs['param_dict']` carries 'is_final' (flush flag) and the
        streaming 'cache' dict that is mutated across calls. Returns a list of
        per-utterance result dicts (possibly empty for an intermediate chunk).
        """
        waveforms = np.expand_dims(audio_in, axis=0)  # add batch dim -> (1, samples)
        param_dict = kwargs.get('param_dict', dict())
        is_final = param_dict.get('is_final', False)
        cache = param_dict.get('cache', dict())
        asr_res = []

        # Final call with only a tiny residue of audio: skip feature extraction
        # and flush using the cached context frames.
        # NOTE(review): 16 * 60 presumably means 60 ms at 16 kHz (960 samples)
        # — confirm against the frontend's sample rate.
        if waveforms.shape[1] < 16 * 60 and is_final and len(cache) > 0:
            cache["last_chunk"] = True
            feats = cache["feats"]
            feats_len = np.array([feats.shape[1]]).astype(np.int32)
            asr_res = self.infer(feats, feats_len, cache)
            return asr_res

        feats, feats_len = self.extract_feat(waveforms, is_final)
        if feats.shape[1] != 0:
            # Scale features by sqrt(d_model) before position encoding.
            feats *= self.encoder_output_size ** 0.5
            cache = self.prepare_cache(cache)
            cache["is_final"] = is_final
            # fbank -> position encoding -> overlap chunk
            feats = self.pe.forward(feats, cache["start_idx"])
            cache["start_idx"] += feats.shape[1]
            if is_final:
                if feats.shape[1] + self.chunk_size[2] <= self.chunk_size[1]:
                    # Everything fits into a single flushed last chunk.
                    cache["last_chunk"] = True
                    feats = self.add_overlap_chunk(feats, cache)
                else:
                    # Too much remaining audio for one chunk: run a normal
                    # first chunk, then flush the remainder as the last chunk.
                    # first chunk
                    feats_chunk1 = self.add_overlap_chunk(feats[:, :self.chunk_size[1], :], cache)
                    feats_len = np.array([feats_chunk1.shape[1]]).astype(np.int32)
                    asr_res_chunk1 = self.infer(feats_chunk1, feats_len, cache)
                    # last chunk
                    cache["last_chunk"] = True
                    feats_chunk2 = self.add_overlap_chunk(
                        feats[:, -(feats.shape[1] + self.chunk_size[2] - self.chunk_size[1]):, :], cache)
                    feats_len = np.array([feats_chunk2.shape[1]]).astype(np.int32)
                    asr_res_chunk2 = self.infer(feats_chunk2, feats_len, cache)

                    # Merge the two partial results key-by-key; value[0] is the
                    # text (string concat) and value[1] the token list (extend).
                    asr_res_chunk = asr_res_chunk1 + asr_res_chunk2
                    res = {}
                    for pred in asr_res_chunk:
                        for key, value in pred.items():
                            if key in res:
                                res[key][0] += value[0]
                                res[key][1].extend(value[1])
                            else:
                                res[key] = [value[0], value[1]]
                    return [res]
            else:
                # Intermediate chunk: just splice in the lookback context.
                feats = self.add_overlap_chunk(feats, cache)
            feats_len = np.array([feats.shape[1]]).astype(np.int32)
            asr_res = self.infer(feats, feats_len, cache)
        return asr_res
- def infer(self, feats: np.ndarray, feats_len: np.ndarray, cache):
- # encoder forward
- enc_input = [feats, feats_len]
- enc, enc_lens, cif_alphas = self.ort_encoder_infer(enc_input)
- # predictor forward
- acoustic_embeds, acoustic_embeds_len = self.cif_search(enc, cif_alphas, cache)
- # decoder forward
- asr_res = []
- if acoustic_embeds.shape[1] > 0:
- dec_input = [enc, enc_lens, acoustic_embeds, acoustic_embeds_len]
- dec_input.extend(cache["decoder_fsmn"])
- dec_output = self.ort_decoder_infer(dec_input)
- logits, sample_ids, cache["decoder_fsmn"] = dec_output[0], dec_output[1], dec_output[2:]
- cache["decoder_fsmn"] = [item[:, :, -self.fsmn_lorder:] for item in cache["decoder_fsmn"]]
- preds = self.decode(logits, acoustic_embeds_len)
- for pred in preds:
- pred = sentence_postprocess(pred)
- asr_res.append({'preds': pred})
- return asr_res
- def load_data(self,
- wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
- def load_wav(path: str) -> np.ndarray:
- waveform, _ = librosa.load(path, sr=fs)
- return waveform
- if isinstance(wav_content, np.ndarray):
- return [wav_content]
- if isinstance(wav_content, str):
- return [load_wav(wav_content)]
- if isinstance(wav_content, list):
- return [load_wav(path) for path in wav_content]
- raise TypeError(
- f'The type of {wav_content} is not in [str, np.ndarray, list]')
- def extract_feat(self,
- waveforms: np.ndarray, is_final: bool = False
- ) -> Tuple[np.ndarray, np.ndarray]:
- waveforms_lens = np.zeros(waveforms.shape[0]).astype(np.int32)
- for idx, waveform in enumerate(waveforms):
- waveforms_lens[idx] = waveform.shape[-1]
- feats, feats_len = self.frontend.extract_fbank(waveforms, waveforms_lens, is_final)
- return feats.astype(np.float32), feats_len.astype(np.int32)
- def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]:
- return [self.decode_one(am_score, token_num)
- for am_score, token_num in zip(am_scores, token_nums)]
- def decode_one(self,
- am_score: np.ndarray,
- valid_token_num: int) -> List[str]:
- yseq = am_score.argmax(axis=-1)
- score = am_score.max(axis=-1)
- score = np.sum(score, axis=-1)
- # pad with mask tokens to ensure compatibility with sos/eos tokens
- # asr_model.sos:1 asr_model.eos:2
- yseq = np.array([1] + yseq.tolist() + [2])
- hyp = Hypothesis(yseq=yseq, score=score)
- # remove sos/eos and get results
- last_pos = -1
- token_int = hyp.yseq[1:last_pos].tolist()
- # remove blank symbol id, which is assumed to be 0
- token_int = list(filter(lambda x: x not in (0, 2), token_int))
- # Change integer-ids to tokens
- token = self.converter.ids2tokens(token_int)
- token = token[:valid_token_num]
- # texts = sentence_postprocess(token)
- return token
    def cif_search(self, hidden, alphas, cache=None):
        """Continuous Integrate-and-Fire (CIF) token boundary search.

        Accumulates per-frame weights `alphas` until they cross
        `self.cif_threshold`; each crossing "fires" one token embedding built
        from the alpha-weighted sum of `hidden` frames. Leftover weight/state
        is carried across chunks via `cache["cif_alphas"]`/`cache["cif_hidden"]`.

        :param hidden: (batch, time, hidden_size) encoder outputs.
        :param alphas: (batch, time) CIF weights; MUTATED IN PLACE (context
            frames are zeroed below).
        :param cache: streaming state dict. NOTE(review): the writeback at the
            end dereferences `cache`, so passing None raises — callers in this
            file always supply a dict.
        :return: (token embeddings padded to the batch max token count,
                  per-item token counts as int32).
        """
        batch_size, len_time, hidden_size = hidden.shape
        token_length = []
        list_fires = []
        list_frames = []
        cache_alphas = []
        cache_hiddens = []
        # Zero out the weights of the left-context and right-context frames so
        # only the middle chunk can fire tokens (mutates `alphas` in place).
        alphas[:, :self.chunk_size[0]] = 0.0
        alphas[:, sum(self.chunk_size[:2]):] = 0.0
        # Prepend the carry-over state from the previous chunk.
        if cache is not None and "cif_alphas" in cache and "cif_hidden" in cache:
            hidden = np.concatenate((cache["cif_hidden"], hidden), axis=1)
            alphas = np.concatenate((cache["cif_alphas"], alphas), axis=1)
        # On the last chunk, append a tail frame so any residual weight still
        # fires a final token.
        if cache is not None and "last_chunk" in cache and cache["last_chunk"]:
            tail_hidden = np.zeros((batch_size, 1, hidden_size)).astype(np.float32)
            tail_alphas = np.array([[self.tail_threshold]]).astype(np.float32)
            tail_alphas = np.tile(tail_alphas, (batch_size, 1))
            hidden = np.concatenate((hidden, tail_hidden), axis=1)
            alphas = np.concatenate((alphas, tail_alphas), axis=1)

        len_time = alphas.shape[1]
        for b in range(batch_size):
            integrate = 0.0  # accumulated alpha weight since the last fire
            frames = np.zeros(hidden_size).astype(np.float32)  # weighted hidden sum
            list_frame = []
            list_fire = []
            for t in range(len_time):
                alpha = alphas[b][t]
                if alpha + integrate < self.cif_threshold:
                    # Below threshold: keep integrating.
                    integrate += alpha
                    list_fire.append(integrate)
                    frames += alpha * hidden[b][t]
                else:
                    # Fire: spend only the weight needed to reach the
                    # threshold, emit the token embedding, and start the next
                    # accumulation with the remainder of this frame's weight.
                    frames += (self.cif_threshold - integrate) * hidden[b][t]
                    list_frame.append(frames)
                    integrate += alpha
                    list_fire.append(integrate)
                    integrate -= self.cif_threshold
                    frames = integrate * hidden[b][t]

            # Carry leftover weight/state into the next chunk (un-weight the
            # hidden sum so it can be re-weighted when concatenated).
            cache_alphas.append(integrate)
            if integrate > 0.0:
                cache_hiddens.append(frames / integrate)
            else:
                cache_hiddens.append(frames)

            token_length.append(len(list_frame))
            list_fires.append(list_fire)
            list_frames.append(list_frame)

        # Pad each item's fired embeddings up to the batch max token count.
        max_token_len = max(token_length)
        list_ls = []
        for b in range(batch_size):
            pad_frames = np.zeros((max_token_len - token_length[b], hidden_size)).astype(np.float32)
            if token_length[b] == 0:
                list_ls.append(pad_frames)
            else:
                list_ls.append(np.concatenate((list_frames[b], pad_frames), axis=0))
        # NOTE(review): expand_dims(axis=0) yields shape (1, batch, ...) while
        # prepare_cache builds (batch, 1, ...); these agree only for batch
        # size 1, which the streaming path uses — confirm before batching.
        cache["cif_alphas"] = np.stack(cache_alphas, axis=0)
        cache["cif_alphas"] = np.expand_dims(cache["cif_alphas"], axis=0)
        cache["cif_hidden"] = np.stack(cache_hiddens, axis=0)
        cache["cif_hidden"] = np.expand_dims(cache["cif_hidden"], axis=0)
        return np.stack(list_ls, axis=0).astype(np.float32), np.stack(token_length, axis=0).astype(np.int32)
|