Преглед изворни кода

Merge remote-tracking branch 'origin/main'

wucong.lyb пре 2 година
родитељ
комит
b7c82bbb57

+ 3 - 3
README.md

@@ -109,13 +109,13 @@ An example with websocket:
 For the server:
 ```shell
 cd funasr/runtime/python/websocket
-python wss_srv_asr.py --port 10095
+python funasr_wss_server.py --port 10095
 ```
 
 For the client:
 ```shell
-python wss_client_asr.py --host "127.0.0.1" --port 10095 --mode 2pass --chunk_size "5,10,5"
-#python wss_client_asr.py --host "127.0.0.1" --port 10095 --mode 2pass --chunk_size "8,8,4" --audio_in "./data/wav.scp" --output_dir "./results"
+python funasr_wss_client.py --host "127.0.0.1" --port 10095 --mode 2pass --chunk_size "5,10,5"
+#python funasr_wss_client.py --host "127.0.0.1" --port 10095 --mode 2pass --chunk_size "8,8,4" --audio_in "./data/wav.scp" --output_dir "./results"
 ```
 More examples could be found in [docs](https://alibaba-damo-academy.github.io/FunASR/en/runtime/websocket_python.html#id2)
 ## Contact

+ 1 - 1
docs/model_zoo/modelscope_models.md

@@ -15,7 +15,7 @@ Here we provided several pretrained models on different datasets. The details of
 |                                                                     Model Name                                                                     | Language |          Training Data           | Vocab Size | Parameter | Offline/Online | Notes                                                                                                                           |
 |:--------------------------------------------------------------------------------------------------------------------------------------------------:|:--------:|:--------------------------------:|:----------:|:---------:|:--------------:|:--------------------------------------------------------------------------------------------------------------------------------|
 |        [Paraformer-large](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary)        | CN & EN  | Alibaba Speech Data (60000hours) |    8404    |   220M    |    Offline     | Duration of input wav <= 20s                                                                                                    |
-| [Paraformer-large-long](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) | CN & EN  | Alibaba Speech Data (60000hours) |    8404    |   220M    |    Offline     | Which ould deal with arbitrary length input wav                                                                                 |
+| [Paraformer-large-long](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) | CN & EN  | Alibaba Speech Data (60000hours) |    8404    |   220M    |    Offline     | Which would deal with arbitrary length input wav                                                                                 |
 | [Paraformer-large-contextual](https://www.modelscope.cn/models/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/summary) | CN & EN  | Alibaba Speech Data (60000hours) |    8404    |   220M    |    Offline     | Which supports the hotword customization based on the incentive enhancement, and improves the recall and precision of hotwords. |
 |              [Paraformer](https://modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1/summary)              | CN & EN  | Alibaba Speech Data (50000hours) |    8358    |    68M    |    Offline     | Duration of input wav <= 20s                                                                                                    |
 |           [Paraformer-online](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/summary)           | CN & EN  | Alibaba Speech Data (50000hours) |    8404    |    68M    |     Online     | Which could deal with streaming input                                                                                           |

+ 1 - 0
docs/reference/papers.md

@@ -4,6 +4,7 @@ FunASR have implemented the following paper code
 
 ### Speech Recognition
 - [FunASR: A Fundamental End-to-End Speech Recognition Toolkit](https://arxiv.org/abs/2305.11013), INTERSPEECH 2023
+- [BAT: Boundary aware transducer for memory-efficient and low-latency ASR](https://arxiv.org/abs/2305.11571), INTERSPEECH 2023
 - [Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition](https://arxiv.org/abs/2206.08317), INTERSPEECH 2022
 - [Universal ASR: Unifying Streaming and Non-Streaming ASR Using a Single Encoder-Decoder Model](https://arxiv.org/abs/2010.14099), arXiv preprint arXiv:2010.14099, 2020.
 - [San-m: Memory equipped self-attention for end-to-end speech recognition](https://arxiv.org/pdf/2006.01713), INTERSPEECH 2020

+ 16 - 0
egs/aishell/bat/README.md

@@ -0,0 +1,16 @@
+# Boundary Aware Transducer (BAT) Result
+
+## Training Config
+- 8 gpu(Tesla V100)
+- Feature info: using 80 dims fbank, global cmvn, speed perturb(0.9, 1.0, 1.1), specaugment
+- Train config: conf/train_conformer_bat.yaml
+- LM config: LM was not used
+- Model size: 90M
+
+## Results (CER)
+- Decode config: conf/decode_bat_conformer.yaml
+
+|   testset   |  CER(%) |
+|:-----------:|:-------:|
+|     dev     |  4.56   |
+|    test     |  4.97   |

+ 1 - 0
egs/aishell/bat/conf/decode_bat_conformer.yaml

@@ -0,0 +1 @@
+beam_size: 10

+ 108 - 0
egs/aishell/bat/conf/train_conformer_bat.yaml

@@ -0,0 +1,108 @@
+encoder: chunk_conformer
+encoder_conf:
+      activation_type: swish
+      positional_dropout_rate: 0.5
+      time_reduction_factor: 2
+      embed_vgg_like: false
+      subsampling_factor: 4
+      linear_units: 2048
+      output_size: 512
+      attention_heads: 8
+      dropout_rate: 0.5
+      positional_dropout_rate: 0.5
+      attention_dropout_rate: 0.5
+      cnn_module_kernel: 15
+      num_blocks: 12    
+
+# decoder related
+rnnt_decoder: rnnt
+rnnt_decoder_conf:
+    embed_size: 512
+    hidden_size: 512
+    embed_dropout_rate: 0.5
+    dropout_rate: 0.5
+    use_embed_mask: true
+
+predictor: bat_predictor
+predictor_conf:
+  idim: 512
+  threshold: 1.0
+  l_order: 1
+  r_order: 1
+  return_accum: true
+
+joint_network_conf:
+    joint_space_size: 512
+
+# frontend related
+frontend: wav_frontend
+frontend_conf:
+    fs: 16000
+    window: hamming
+    n_mels: 80
+    frame_length: 25
+    frame_shift: 10
+    lfr_m: 1
+    lfr_n: 1
+
+
+# Auxiliary CTC
+model: bat
+model_conf:
+    auxiliary_ctc_weight: 0.0
+    cif_weight: 1.0
+    r_d: 3
+    r_u: 5
+
+# minibatch related
+use_amp: true
+
+# optimization related
+accum_grad: 1
+grad_clip: 5
+max_epoch: 100
+val_scheduler_criterion:
+    - valid
+    - loss
+best_model_criterion:
+-   - valid
+    - cer_transducer
+    - min
+keep_nbest_models: 10
+
+optim: adam
+optim_conf:
+   lr: 0.001
+scheduler: warmuplr
+scheduler_conf:
+   warmup_steps: 25000
+
+specaug: specaug
+specaug_conf:
+    apply_time_warp: true
+    time_warp_window: 5
+    time_warp_mode: bicubic
+    apply_freq_mask: true
+    freq_mask_width_range:
+    - 0
+    - 40
+    num_freq_mask: 2
+    apply_time_mask: true
+    time_mask_width_range:
+    - 0
+    - 50
+    num_time_mask: 5
+
+dataset_conf:
+    data_names: speech,text
+    data_types: sound,text
+    shuffle: True
+    shuffle_conf:
+        shuffle_size: 2048
+        sort_size: 500
+    batch_conf:
+        batch_type: token
+        batch_size: 25000
+    num_workers: 8
+
+log_interval: 50

+ 66 - 0
egs/aishell/bat/local/aishell_data_prep.sh

@@ -0,0 +1,66 @@
+#!/bin/bash
+
+# Copyright 2017 Xingyu Na
+# Apache 2.0
+
+#. ./path.sh || exit 1;
+
+if [ $# != 3 ]; then
+  echo "Usage: $0 <audio-path> <text-path> <output-path>"
+  echo " $0 /export/a05/xna/data/data_aishell/wav /export/a05/xna/data/data_aishell/transcript data"
+  exit 1;
+fi
+
+aishell_audio_dir=$1
+aishell_text=$2/aishell_transcript_v0.8.txt
+output_dir=$3
+
+train_dir=$output_dir/data/local/train
+dev_dir=$output_dir/data/local/dev
+test_dir=$output_dir/data/local/test
+tmp_dir=$output_dir/data/local/tmp
+
+mkdir -p $train_dir
+mkdir -p $dev_dir
+mkdir -p $test_dir
+mkdir -p $tmp_dir
+
+# data directory check
+if [ ! -d $aishell_audio_dir ] || [ ! -f $aishell_text ]; then
+  echo "Error: $0 requires two directory arguments"
+  exit 1;
+fi
+
+# find wav audio file for train, dev and test resp.
+find $aishell_audio_dir -iname "*.wav" > $tmp_dir/wav.flist
+n=`cat $tmp_dir/wav.flist | wc -l`
+[ $n -ne 141925 ] && \
+  echo Warning: expected 141925 data data files, found $n
+
+grep -i "wav/train" $tmp_dir/wav.flist > $train_dir/wav.flist || exit 1;
+grep -i "wav/dev" $tmp_dir/wav.flist > $dev_dir/wav.flist || exit 1;
+grep -i "wav/test" $tmp_dir/wav.flist > $test_dir/wav.flist || exit 1;
+
+rm -r $tmp_dir
+
+# Transcriptions preparation
+for dir in $train_dir $dev_dir $test_dir; do
+  echo Preparing $dir transcriptions
+  sed -e 's/\.wav//' $dir/wav.flist | awk -F '/' '{print $NF}' > $dir/utt.list
+  paste -d' ' $dir/utt.list $dir/wav.flist > $dir/wav.scp_all
+  utils/filter_scp.pl -f 1 $dir/utt.list $aishell_text > $dir/transcripts.txt
+  awk '{print $1}' $dir/transcripts.txt > $dir/utt.list
+  utils/filter_scp.pl -f 1 $dir/utt.list $dir/wav.scp_all | sort -u > $dir/wav.scp
+  sort -u $dir/transcripts.txt > $dir/text
+done
+
+mkdir -p $output_dir/data/train $output_dir/data/dev $output_dir/data/test
+
+for f in wav.scp text; do
+  cp $train_dir/$f $output_dir/data/train/$f || exit 1;
+  cp $dev_dir/$f $output_dir/data/dev/$f || exit 1;
+  cp $test_dir/$f $output_dir/data/test/$f || exit 1;
+done
+
+echo "$0: AISHELL data preparation succeeded"
+exit 0;

+ 5 - 0
egs/aishell/bat/path.sh

@@ -0,0 +1,5 @@
+export FUNASR_DIR=$PWD/../../..
+
+# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+export PATH=$FUNASR_DIR/funasr/bin:$PATH

+ 210 - 0
egs/aishell/bat/run.sh

@@ -0,0 +1,210 @@
+#!/usr/bin/env bash
+
+. ./path.sh || exit 1;
+
+# machines configuration
+CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
+gpu_num=8
+count=1
+gpu_inference=true  # Whether to perform gpu decoding, set false for cpu decoding
+# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
+njob=5
+train_cmd=utils/run.pl
+infer_cmd=utils/run.pl
+
+# general configuration
+feats_dir="../DATA" #feature output dictionary
+exp_dir="."
+lang=zh
+token_type=char
+type=sound
+scp=wav.scp
+speed_perturb="0.9 1.0 1.1"
+stage=0
+stop_stage=5
+
+# feature configuration
+feats_dim=80
+nj=64
+
+# data
+raw_data=../raw_data
+data_url=www.openslr.org/resources/33
+
+# exp tag
+tag="exp1"
+
+. utils/parse_options.sh || exit 1;
+
+# Set bash to 'debug' mode, it will exit on :
+# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
+set -e
+set -u
+set -o pipefail
+
+train_set=train
+valid_set=dev
+test_sets="dev test"
+
+asr_config=conf/train_conformer_bat.yaml
+model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
+
+inference_config=conf/decode_bat_conformer.yaml
+inference_asr_model=valid.cer_transducer.ave_10best.pb
+
+# you can set gpu num for decoding here
+gpuid_list=$CUDA_VISIBLE_DEVICES  # set gpus for decoding, the same as training stage by default
+ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')
+
+if ${gpu_inference}; then
+    inference_nj=$[${ngpu}*${njob}]
+    _ngpu=1
+else
+    inference_nj=$njob
+    _ngpu=0
+fi
+
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
+    echo "stage -1: Data Download"
+    local/download_and_untar.sh ${raw_data} ${data_url} data_aishell
+    local/download_and_untar.sh ${raw_data} ${data_url} resource_aishell
+fi
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    echo "stage 0: Data preparation"
+    # Data preparation
+    local/aishell_data_prep.sh ${raw_data}/data_aishell/wav ${raw_data}/data_aishell/transcript ${feats_dir}
+    for x in train dev test; do
+        cp ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
+        paste -d " " <(cut -f 1 -d" " ${feats_dir}/data/${x}/text.org) <(cut -f 2- -d" " ${feats_dir}/data/${x}/text.org | tr -d " ") \
+            > ${feats_dir}/data/${x}/text
+        utils/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org
+        mv ${feats_dir}/data/${x}/text.org ${feats_dir}/data/${x}/text
+    done
+fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    echo "stage 1: Feature and CMVN Generation"
+    utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} ${feats_dir}/data/${train_set}
+fi
+
+token_list=${feats_dir}/data/${lang}_token_list/char/tokens.txt
+echo "dictionary: ${token_list}"
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    echo "stage 2: Dictionary Preparation"
+    mkdir -p ${feats_dir}/data/${lang}_token_list/char/
+
+    echo "make a dictionary"
+    echo "<blank>" > ${token_list}
+    echo "<s>" >> ${token_list}
+    echo "</s>" >> ${token_list}
+    utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/$train_set/text | cut -f 2- -d" " | tr " " "\n" \
+        | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
+    echo "<unk>" >> ${token_list}
+fi
+
+# LM Training Stage
+world_size=$gpu_num  # run on one machine
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+    echo "stage 3: LM Training"
+fi
+
+# ASR Training Stage
+world_size=$gpu_num  # run on one machine
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+    echo "stage 4: ASR Training"
+    mkdir -p ${exp_dir}/exp/${model_dir}
+    mkdir -p ${exp_dir}/exp/${model_dir}/log
+    INIT_FILE=./ddp_init
+    if [ -f $INIT_FILE ];then
+        rm -f $INIT_FILE
+    fi 
+    init_method=file://$(readlink -f $INIT_FILE)
+    echo "$0: init method is $init_method"
+    for ((i = 0; i < $gpu_num; ++i)); do
+        {
+            rank=$i
+            local_rank=$i
+            gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
+            train.py \
+                --task_name asr \
+                --gpu_id $gpu_id \
+                --use_preprocessor true \
+                --token_type char \
+                --token_list $token_list \
+                --data_dir ${feats_dir}/data \
+                --train_set ${train_set} \
+                --valid_set ${valid_set} \
+                --data_file_names "wav.scp,text" \
+                --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
+                --speed_perturb ${speed_perturb} \
+                --resume true \
+                --output_dir ${exp_dir}/exp/${model_dir} \
+                --config $asr_config \
+                --ngpu $gpu_num \
+                --num_worker_count $count \
+                --dist_init_method $init_method \
+                --dist_world_size $world_size \
+                --dist_rank $rank \
+                --local_rank $local_rank 1> ${exp_dir}/exp/${model_dir}/log/train.log.$i 2>&1
+        } &
+        done
+        wait
+fi
+
+# Testing Stage
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+    echo "stage 5: Inference"
+    for dset in ${test_sets}; do
+        asr_exp=${exp_dir}/exp/${model_dir}
+        inference_tag="$(basename "${inference_config}" .yaml)"
+        _dir="${asr_exp}/${inference_tag}/${inference_asr_model}/${dset}"
+        _logdir="${_dir}/logdir"
+        if [ -d ${_dir} ]; then
+            echo "${_dir} is already exists. if you want to decode again, please delete this dir first."
+            exit 0
+        fi
+        mkdir -p "${_logdir}"
+        _data="${feats_dir}/data/${dset}"
+        key_file=${_data}/${scp}
+        num_scp_file="$(<${key_file} wc -l)"
+        _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
+        split_scps=
+        for n in $(seq "${_nj}"); do
+            split_scps+=" ${_logdir}/keys.${n}.scp"
+        done
+        # shellcheck disable=SC2086
+        utils/split_scp.pl "${key_file}" ${split_scps}
+        _opts=
+        if [ -n "${inference_config}" ]; then
+            _opts+="--config ${inference_config} "
+        fi
+        ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
+            python -m funasr.bin.asr_inference_launch \
+                --batch_size 1 \
+                --ngpu "${_ngpu}" \
+                --njob ${njob} \
+                --gpuid_list ${gpuid_list} \
+                --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+                --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
+                --key_file "${_logdir}"/keys.JOB.scp \
+                --asr_train_config "${asr_exp}"/config.yaml \
+                --asr_model_file "${asr_exp}"/"${inference_asr_model}" \
+                --output_dir "${_logdir}"/output.JOB \
+                --mode bat \
+                ${_opts}
+
+        for f in token token_int score text; do
+            if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then
+                for i in $(seq "${_nj}"); do
+                    cat "${_logdir}/output.${i}/1best_recog/${f}"
+                done | sort -k1 >"${_dir}/${f}"
+            fi
+        done
+        python utils/proce_text.py ${_dir}/text ${_dir}/text.proc
+        python utils/proce_text.py ${_data}/text ${_data}/text.proc
+        python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
+        tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
+        cat ${_dir}/text.cer.txt
+    done
+fi

+ 1 - 0
egs/aishell/bat/utils

@@ -0,0 +1 @@
+../transformer/utils

+ 4 - 0
egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/demo.py

@@ -3,6 +3,10 @@ from modelscope.utils.constant import Tasks
 
 param_dict = dict()
 param_dict['hotword'] = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/hotword.txt"
+param_dict['clas_scale'] = 1.00  # 1.50 # set it larger if you want high recall (sacrifice general accuracy)
+# 13% relative recall raise over internal hotword test set (45%->51%)
+# CER might raise when utterance contains no hotword
+
 inference_pipeline = pipeline(
     task=Tasks.auto_speech_recognition,
     model="damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404",

+ 10 - 4
funasr/bin/asr_infer.py

@@ -280,6 +280,7 @@ class Speech2TextParaformer:
             nbest: int = 1,
             frontend_conf: dict = None,
             hotword_list_or_file: str = None,
+            clas_scale: float = 1.0,
             decoding_ind: int = 0,
             **kwargs,
     ):
@@ -376,6 +377,7 @@ class Speech2TextParaformer:
         # 6. [Optional] Build hotword list from str, local file or url
         self.hotword_list = None
         self.hotword_list = self.generate_hotwords_list(hotword_list_or_file)
+        self.clas_scale = clas_scale
 
         is_use_lm = lm_weight != 0.0 and lm_file is not None
         if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm:
@@ -439,16 +441,20 @@ class Speech2TextParaformer:
         pre_token_length = pre_token_length.round().long()
         if torch.max(pre_token_length) < 1:
             return []
-        if not isinstance(self.asr_model, ContextualParaformer) and not isinstance(self.asr_model,
-                                                                                   NeatContextualParaformer):
+        if not isinstance(self.asr_model, ContextualParaformer) and \
+            not isinstance(self.asr_model, NeatContextualParaformer):
             if self.hotword_list:
                 logging.warning("Hotword is given but asr model is not a ContextualParaformer.")
             decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds,
                                                                      pre_token_length)
             decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
         else:
-            decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds,
-                                                                     pre_token_length, hw_list=self.hotword_list)
+            decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, 
+                                                                     enc_len, 
+                                                                     pre_acoustic_embeds,
+                                                                     pre_token_length, 
+                                                                     hw_list=self.hotword_list,
+                                                                     clas_scale=self.clas_scale)
             decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
 
         if isinstance(self.asr_model, BiCifParaformer):

+ 19 - 4
funasr/bin/asr_inference_launch.py

@@ -257,6 +257,7 @@ def inference_paraformer(
         export_mode = param_dict.get("export_mode", False)
     else:
         hotword_list_or_file = None
+    clas_scale = param_dict.get('clas_scale', 1.0)
 
     if kwargs.get("device", None) == "cpu":
         ngpu = 0
@@ -289,6 +290,7 @@ def inference_paraformer(
         penalty=penalty,
         nbest=nbest,
         hotword_list_or_file=hotword_list_or_file,
+        clas_scale=clas_scale,
     )
 
     speech2text = Speech2TextParaformer(**speech2text_kwargs)
@@ -617,6 +619,22 @@ def inference_paraformer_vad_punc(
             sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
             results_sorted = []
             
+            if not len(sorted_data):
+                key = keys[0]
+                # no active segments after VAD
+                if writer is not None:
+                    # Write empty results
+                    ibest_writer["token"][key] = ""
+                    ibest_writer["token_int"][key] = ""
+                    ibest_writer["vad"][key] = ""
+                    ibest_writer["text"][key] = ""
+                    ibest_writer["text_with_punc"][key] = ""
+                    if use_timestamp:
+                        ibest_writer["time_stamp"][key] = ""
+
+                logging.info("decoding, utt: {}, empty speech".format(key))
+                continue
+
             batch_size_token_ms = batch_size_token*60
             if speech2text.device == "cpu":
                 batch_size_token_ms = 0
@@ -1349,10 +1367,7 @@ def inference_transducer(
         left_context=left_context,
         right_context=right_context,
     )
-    speech2text = Speech2TextTransducer.from_pretrained(
-        model_tag=model_tag,
-        **speech2text_kwargs,
-    )
+    speech2text = Speech2TextTransducer(**speech2text_kwargs)
 
     def _forward(data_path_and_name_and_type,
                  raw_inputs: Union[np.ndarray, torch.Tensor] = None,

+ 3 - 1
funasr/bin/build_trainer.py

@@ -85,7 +85,9 @@ def build_trainer(modelscope_dict,
         finetune_configs = yaml.safe_load(f)
         # set data_types
         if dataset_type == "large":
-            finetune_configs["dataset_conf"]["data_types"] = "sound,text"
+            # finetune_configs["dataset_conf"]["data_types"] = "sound,text"
+            if 'data_types' not in finetune_configs['dataset_conf']:
+                finetune_configs["dataset_conf"]["data_types"] = "sound,text"
     finetune_configs = update_dct(configs, finetune_configs)
     for key, value in finetune_configs.items():
         if hasattr(args, key):

+ 3 - 12
funasr/bin/diar_inference_launch.py

@@ -92,10 +92,7 @@ def inference_sond(
             embedding_node="resnet1_dense"
         )
         logging.info("speech2xvector_kwargs: {}".format(speech2xvector_kwargs))
-        speech2xvector = Speech2Xvector.from_pretrained(
-            model_tag=model_tag,
-            **speech2xvector_kwargs,
-        )
+        speech2xvector = Speech2Xvector(**speech2xvector_kwargs)
         speech2xvector.sv_model.eval()
 
     # 2b. Build speech2diar
@@ -109,10 +106,7 @@ def inference_sond(
         dur_threshold=dur_threshold,
     )
     logging.info("speech2diarization_kwargs: {}".format(speech2diar_kwargs))
-    speech2diar = Speech2DiarizationSOND.from_pretrained(
-        model_tag=model_tag,
-        **speech2diar_kwargs,
-    )
+    speech2diar = Speech2DiarizationSOND(**speech2diar_kwargs)
     speech2diar.diar_model.eval()
 
     def output_results_str(results: dict, uttid: str):
@@ -257,10 +251,7 @@ def inference_eend(
         dtype=dtype,
     )
     logging.info("speech2diarization_kwargs: {}".format(speech2diar_kwargs))
-    speech2diar = Speech2DiarizationEEND.from_pretrained(
-        model_tag=model_tag,
-        **speech2diar_kwargs,
-    )
+    speech2diar = Speech2DiarizationEEND(**speech2diar_kwargs)
     speech2diar.diar_model.eval()
 
     def output_results_str(results: dict, uttid: str):

+ 10 - 8
funasr/datasets/large_datasets/dataset.py

@@ -202,14 +202,7 @@ def Dataset(data_list_file,
     data_types = conf.get("data_types", "kaldi_ark,text")
 
     pre_hwfile = conf.get("pre_hwlist", None)
-    pre_prob = conf.get("pre_prob", 0)  # unused yet
-
-    hw_config = {"sample_rate": conf.get("sample_rate", 0.6),
-                 "double_rate": conf.get("double_rate", 0.1),
-                 "hotword_min_length": conf.get("hotword_min_length", 2),
-                 "hotword_max_length": conf.get("hotword_max_length", 8),
-                 "pre_prob": conf.get("pre_prob", 0.0)}
-
+    # pre_prob = conf.get("pre_prob", 0)  # unused yet
     if pre_hwfile is not None:
         pre_hwlist = []
         with open(pre_hwfile, 'r') as fin:
@@ -218,6 +211,15 @@ def Dataset(data_list_file,
     else:
         pre_hwlist = None
 
+    hw_config = {"sample_rate": conf.get("sample_rate", 0.6),
+                 "double_rate": conf.get("double_rate", 0.1),
+                 "hotword_min_length": conf.get("hotword_min_length", 2),
+                 "hotword_max_length": conf.get("hotword_max_length", 8),
+                 "pre_prob": conf.get("pre_prob", 0.0),
+                 "pre_hwlist": pre_hwlist}
+
+    
+
     dataset = AudioDataset(scp_lists, 
                            data_names, 
                            data_types, 

+ 2 - 1
funasr/datasets/large_datasets/utils/hotword_utils.py

@@ -6,7 +6,8 @@ def sample_hotword(length,
                    sample_rate,
                    double_rate,
                    pre_prob,
-                   pre_index=None):
+                   pre_index=None,
+                   pre_hwlist=None):
         if length < hotword_min_length:
             return [-1]
         if random.random() < sample_rate:

+ 11 - 1
funasr/datasets/large_datasets/utils/tokenize.py

@@ -54,7 +54,17 @@ def tokenize(data,
 
     length = len(text)
     if 'hw_tag' in data:
-        hotword_indxs = sample_hotword(length, **hw_config)
+        if hw_config['pre_hwlist'] is not None and hw_config['pre_prob'] > 0:
+            # enable preset hotword detect in sampling
+            pre_index = None
+            for hw in hw_config['pre_hwlist']:
+                hw = " ".join(seg_tokenize(hw, seg_dict))
+                _find = " ".join(text).find(hw)
+                if _find != -1:
+                    # _find = text[:_find].count(" ")  # bpe sometimes
+                    pre_index = [_find, _find + max(hw.count(" "), 1)]
+                    break
+        hotword_indxs = sample_hotword(length, **hw_config, pre_index=pre_index)
         data['hotword_indxs'] = hotword_indxs
         del data['hw_tag']
     for i in range(length):

+ 2 - 1
funasr/models/decoder/contextual_decoder.py

@@ -244,6 +244,7 @@ class ContextualParaformerDecoder(ParaformerSANMDecoder):
         ys_in_pad: torch.Tensor,
         ys_in_lens: torch.Tensor,
         contextual_info: torch.Tensor,
+        clas_scale: float = 1.0,
         return_hidden: bool = False,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Forward decoder.
@@ -283,7 +284,7 @@ class ContextualParaformerDecoder(ParaformerSANMDecoder):
         cx, tgt_mask, _, _, _ = self.bias_decoder(x_self_attn, tgt_mask, contextual_info, memory_mask=contextual_mask)
 
         if self.bias_output is not None:
-            x = torch.cat([x_src_attn, cx], dim=2)
+            x = torch.cat([x_src_attn, cx*clas_scale], dim=2)
             x = self.bias_output(x.transpose(1, 2)).transpose(1, 2)  # 2D -> D
             x = x_self_attn + self.dropout(x)
 

+ 2 - 2
funasr/models/e2e_asr_contextual_paraformer.py

@@ -341,7 +341,7 @@ class NeatContextualParaformer(Paraformer):
             input_mask_expand_dim, 0)
         return sematic_embeds * tgt_mask, decoder_out * tgt_mask
 
-    def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None):
+    def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None, clas_scale=1.0):
         if hw_list is None:
             hw_list = [torch.Tensor([1]).long().to(encoder_out.device)]  # empty hotword list
             hw_list_pad = pad_list(hw_list, 0)
@@ -363,7 +363,7 @@ class NeatContextualParaformer(Paraformer):
             hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)
         
         decoder_outs = self.decoder(
-            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=hw_embed
+            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=hw_embed, clas_scale=clas_scale
         )
         decoder_out = decoder_outs[0]
         decoder_out = torch.log_softmax(decoder_out, dim=-1)

+ 3 - 2
funasr/runtime/docs/SDK_advanced_guide_offline_zh.md

@@ -35,9 +35,9 @@ sudo systemctl start docker
 通过下述命令拉取并启动FunASR runtime-SDK的docker镜像:
 
 ```shell
-sudo docker pull registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-latest
+sudo docker pull registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.1.0
 
-sudo docker run -p 10095:10095 -it --privileged=true -v /root:/workspace/models registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-latest
+sudo docker run -p 10095:10095 -it --privileged=true -v /root:/workspace/models registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.1.0
 ```
 
 命令参数介绍:
@@ -53,6 +53,7 @@ sudo docker run -p 10095:10095 -it --privileged=true -v /root:/workspace/models
 
 docker启动之后,启动 funasr-wss-server服务程序:
 ```shell
+cd FunASR/funasr/runtime
 ./run_server.sh --vad-dir damo/speech_fsmn_vad_zh-cn-16k-common-onnx \
   --model-dir damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx  \
   --punc-dir damo/punc_ct-transformer_zh-cn-common-vocab272727-onnx

+ 1 - 1
funasr/runtime/html5/readme.md

@@ -41,7 +41,7 @@ python h5Server.py --host 0.0.0.0 --port 1337
 `Tips:` asr service and html5 service should be deployed on the same device.
 ```shell
 cd ../python/websocket
-python wss_srv_asr.py --port 10095
+python funasr_wss_server.py --port 10095
 ```
 
 

+ 1 - 1
funasr/runtime/html5/readme_cn.md

@@ -49,7 +49,7 @@ python h5Server.py --host 0.0.0.0 --port 1337
 #### wss方式
 ```shell
 cd ../python/websocket
-python wss_srv_asr.py --port 10095
+python funasr_wss_server.py --port 10095
 ```
 
 ### 浏览器打开地址

+ 9 - 9
funasr/runtime/python/websocket/README.md

@@ -24,7 +24,7 @@ pip install -r requirements_server.txt
 
 ##### API-reference
 ```shell
-python wss_srv_asr.py \
+python funasr_wss_server.py \
 --port [port id] \
 --asr_model [asr model_name] \
 --asr_model_online [asr model_name] \
@@ -36,7 +36,7 @@ python wss_srv_asr.py \
 ```
 ##### Usage examples
 ```shell
-python wss_srv_asr.py --port 10095
+python funasr_wss_server.py --port 10095
 ```
 
 ## For the client
@@ -51,7 +51,7 @@ pip install -r requirements_client.txt
 ### Start client
 #### API-reference
 ```shell
-python wss_client_asr.py \
+python funasr_wss_client.py \
 --host [ip_address] \
 --port [port id] \
 --chunk_size ["5,10,5"=600ms, "8,8,4"=480ms] \
@@ -68,36 +68,36 @@ python wss_client_asr.py \
 Recording from mircrophone
 ```shell
 # --chunk_interval, "10": 600/10=60ms, "5"=600/5=120ms, "20": 600/12=30ms
-python wss_client_asr.py --host "0.0.0.0" --port 10095 --mode offline
+python funasr_wss_client.py --host "0.0.0.0" --port 10095 --mode offline
 ```
 Loadding from wav.scp(kaldi style)
 ```shell
 # --chunk_interval, "10": 600/10=60ms, "5"=600/5=120ms, "20": 600/12=30ms
-python wss_client_asr.py --host "0.0.0.0" --port 10095 --mode offline --audio_in "./data/wav.scp" --output_dir "./results"
+python funasr_wss_client.py --host "0.0.0.0" --port 10095 --mode offline --audio_in "./data/wav.scp" --output_dir "./results"
 ```
 
 ##### ASR streaming client
 Recording from mircrophone
 ```shell
 # --chunk_size, "5,10,5"=600ms, "8,8,4"=480ms
-python wss_client_asr.py --host "0.0.0.0" --port 10095 --mode online --chunk_size "5,10,5"
+python funasr_wss_client.py --host "0.0.0.0" --port 10095 --mode online --chunk_size "5,10,5"
 ```
 Loadding from wav.scp(kaldi style)
 ```shell
 # --chunk_size, "5,10,5"=600ms, "8,8,4"=480ms
-python wss_client_asr.py --host "0.0.0.0" --port 10095 --mode online --chunk_size "5,10,5" --audio_in "./data/wav.scp" --output_dir "./results"
+python funasr_wss_client.py --host "0.0.0.0" --port 10095 --mode online --chunk_size "5,10,5" --audio_in "./data/wav.scp" --output_dir "./results"
 ```
 
 ##### ASR offline/online 2pass client
 Recording from mircrophone
 ```shell
 # --chunk_size, "5,10,5"=600ms, "8,8,4"=480ms
-python wss_client_asr.py --host "0.0.0.0" --port 10095 --mode 2pass --chunk_size "8,8,4"
+python funasr_wss_client.py --host "0.0.0.0" --port 10095 --mode 2pass --chunk_size "8,8,4"
 ```
 Loadding from wav.scp(kaldi style)
 ```shell
 # --chunk_size, "5,10,5"=600ms, "8,8,4"=480ms
-python wss_client_asr.py --host "0.0.0.0" --port 10095 --mode 2pass --chunk_size "8,8,4" --audio_in "./data/wav.scp" --output_dir "./results"
+python funasr_wss_client.py --host "0.0.0.0" --port 10095 --mode 2pass --chunk_size "8,8,4" --audio_in "./data/wav.scp" --output_dir "./results"
 ```
 ## Acknowledge
 1. This project is maintained by [FunASR community](https://github.com/alibaba-damo-academy/FunASR).

+ 10 - 26
funasr/runtime/python/websocket/wss_client_asr.py → funasr/runtime/python/websocket/funasr_wss_client.py

@@ -100,11 +100,13 @@ async def record_microphone():
 
     message = json.dumps({"mode": args.mode, "chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval,
                           "wav_name": "microphone", "is_speaking": True})
-    voices.put(message)
+    #voices.put(message)
+    await websocket.send(message)
     while True:
         data = stream.read(CHUNK)
         message = data
-        voices.put(message)
+        #voices.put(message)
+        await websocket.send(message)
         await asyncio.sleep(0.005)
 
 async def record_from_scp(chunk_begin, chunk_size):
@@ -178,25 +180,7 @@ async def record_from_scp(chunk_begin, chunk_size):
     await websocket.close()
 
 
-async def ws_send():
-    global voices
-    global websocket
-    print("started to sending data!")
-    while True:
-        while not voices.empty():
-            data = voices.get()
-            voices.task_done()
-            try:
-                await websocket.send(data)
-            except Exception as e:
-                print('Exception occurred:', e)
-                traceback.print_exc()
-                exit(0)
-            await asyncio.sleep(0.005)
-        await asyncio.sleep(0.005)
-
- 
-             
+          
 async def message(id):
     global websocket,voices,offline_msg_done
     text_print = ""
@@ -215,12 +199,12 @@ async def message(id):
             if meg["mode"] == "online":
                 text_print += "{}".format(text)
                 text_print = text_print[-args.words_max_print:]
-                os.system('clear')
+                # os.system('clear')
                 print("\rpid" + str(id) + ": " + text_print)
             elif meg["mode"] == "offline":
                 text_print += "{}".format(text)
                 text_print = text_print[-args.words_max_print:]
-                os.system('clear')
+                # os.system('clear')
                 print("\rpid" + str(id) + ": " + text_print)
                 offline_msg_done=True
             else:
@@ -232,8 +216,9 @@ async def message(id):
                     text_print = text_print_2pass_offline + "{}".format(text)
                     text_print_2pass_offline += "{}".format(text)
                 text_print = text_print[-args.words_max_print:]
-                os.system('clear')
+                # os.system('clear')
                 print("\rpid" + str(id) + ": " + text_print)
+                offline_msg_done=True
 
     except Exception as e:
             print("Exception:", e)
@@ -277,9 +262,8 @@ async def ws_client(id, chunk_begin, chunk_size):
             task = asyncio.create_task(record_from_scp(i, 1))
         else:
             task = asyncio.create_task(record_microphone())
-        task2 = asyncio.create_task(ws_send())
         task3 = asyncio.create_task(message(str(id)+"_"+str(i))) #processid+fileid
-        await asyncio.gather(task, task2, task3)
+        await asyncio.gather(task, task3)
   exit(0)
     
 

+ 0 - 0
funasr/runtime/python/websocket/wss_srv_asr.py → funasr/runtime/python/websocket/funasr_wss_server.py


+ 12 - 4
tests/test_asr_inference_pipeline.py

@@ -119,20 +119,28 @@ class TestParaformerInferencePipelines(unittest.TestCase):
     def test_paraformer_large_online_common(self):
         inference_pipeline = pipeline(
             task=Tasks.auto_speech_recognition,
-            model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online')
+            model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online',
+            model_revision='v1.0.6',
+            update_model=False,
+            mode="paraformer_fake_streaming"
+        )
         rec_result = inference_pipeline(
             audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
         logger.info("asr inference result: {0}".format(rec_result))
-        assert rec_result["text"] == "欢迎大 家来 体验达 摩院推 出的 语音识 别模 型"
+        assert rec_result["text"] == "欢迎大家来体验达摩院推出的语音识别模型"
 
     def test_paraformer_online_common(self):
         inference_pipeline = pipeline(
             task=Tasks.auto_speech_recognition,
-            model='damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online')
+            model='damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online',
+            model_revision='v1.0.6',
+            update_model=False,
+            mode="paraformer_fake_streaming"
+        )
         rec_result = inference_pipeline(
             audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
         logger.info("asr inference result: {0}".format(rec_result))
-        assert rec_result["text"] == "欢迎 大家来 体验达 摩院推 出的 语音识 别模 型"
+        assert rec_result["text"] == "欢迎大家来体验达摩院推出的语音识别模型"
 
     def test_paraformer_tiny_commandword(self):
         inference_pipeline = pipeline(