فهرست منبع

add speaker-attributed ASR task for alimeeting

smohan-speech 2 سال پیش
والد
کامیت
af6740a220

+ 79 - 0
egs/alimeeting/sa-asr/README.md

@@ -0,0 +1,79 @@
+# Get Started
+Speaker Attributed Automatic Speech Recognition (SA-ASR) is a task proposed to solve "who spoke what". Specifically, the goal of SA-ASR is not only to obtain multi-speaker transcriptions, but also to identify the corresponding speaker for each utterance. The method used in this example is referenced in the paper: [End-to-End Speaker-Attributed ASR with Transformer](https://www.isca-speech.org/archive/pdfs/interspeech_2021/kanda21b_interspeech.pdf).  
+To run this receipe, first you need to install FunASR and ModelScope. ([installation](https://alibaba-damo-academy.github.io/FunASR/en/installation.html))  
+There are two startup scripts, `run.sh` for training and evaluating on the old eval and test sets, and `run_m2met_2023_infer.sh` for inference on the new test set of the Multi-Channel Multi-Party Meeting Transcription 2.0 ([M2MET2.0](https://alibaba-damo-academy.github.io/FunASR/m2met2/index.html)) Challenge.  
+Before running `run.sh`, you must manually download and unpack the [AliMeeting](http://www.openslr.org/119/) corpus and place it in the `./dataset` directory:
+```shell
+dataset
+|—— Eval_Ali_far
+|—— Eval_Ali_near
+|—— Test_Ali_far
+|—— Test_Ali_near
+|—— Train_Ali_far
+|—— Train_Ali_near
+```
+There are 18 stages in `run.sh`:
+```shell
+stage 1 - 5: Data preparation and processing.
+stage 6: Generate speaker profiles (Stage 6 takes a lot of time).
+stage 7 - 9: Language model training (Optional).
+stage 10 - 11: ASR training (SA-ASR requires loading the pre-trained ASR model).
+stage 12: SA-ASR training.
+stage 13 - 18: Inference and evaluation.
+```
+Before running `run_m2met_2023_infer.sh`, you need to place the new test set `Test_2023_Ali_far` (to be released after the challenge starts) in the `./dataset` directory, which contains only raw audios. Then put the given `wav.scp`, `wav_raw.scp`, `segments`, `utt2spk` and `spk2utt` in the `./data/Test_2023_Ali_far` directory.  
+```shell
+data/Test_2023_Ali_far
+|—— wav.scp
+|—— wav_raw.scp
+|—— segments
+|—— utt2spk
+|—— spk2utt
+```
+There are 4 stages in `run_m2met_2023_infer.sh`:
+```shell
+stage 1: Data preparation and processing.
+stage 2: Generate speaker profiles for inference.
+stage 3: Inference.
+stage 4: Generation of SA-ASR results required for final submission.
+```
+# Format of Final Submission
+Finally, you need to submit a file called `text_spk_merge` with the following format:
+```shell
+Meeting_1 text_spk_1_A$text_spk_1_B$text_spk_1_C ...
+Meeting_2 text_spk_2_A$text_spk_2_B$text_spk_2_C ...
+...
+```
+Here, text_spk_1_A represents the full transcription of speaker_A of Meeting_1 (merged in chronological order), and $ represents the separator symbol. There's no need to worry about the speaker permutation as the optimal permutation will be computed in the end.  For more information, please refer to the results generated after executing the baseline code.
+# Baseline Results
+The results of the baseline system are as follows. The baseline results include speaker independent character error rate (SI-CER) and concatenated minimum permutation character error rate (cpCER), the former is speaker independent and the latter is speaker dependent. The speaker profile adopts the oracle speaker embedding during training. However, due to the lack of oracle speaker label during evaluation, the speaker profile provided by an additional spectral clustering is used. Meanwhile, the results of using the oracle speaker profile on Eval and Test Set are also provided to show the impact of speaker profile accuracy.  
+<table>
+    <tr >
+	    <td rowspan="2"></td>
+        <td colspan="2">SI-CER(%)</td>
+	    <td colspan="2">cpCER(%)</td>
+	</tr>
+    <tr>
+        <td>Eval</td>
+	    <td>Test</td>
+	    <td>Eval</td>
+	    <td>Test</td>
+	</tr>
+    <tr>
+	    <td>oracle profile</td>
+        <td>31.93</td>
+        <td>32.75</td>
+	    <td>48.56</td>
+        <td>53.33</td>
+	</tr>
+    <tr>
+	    <td>cluster profile</td>
+        <td>31.94</td>
+        <td>32.77</td>
+	    <td>55.49</td>
+        <td>58.17</td>
+	</tr>
+</table>
+
+# Reference
+N. Kanda, G. Ye, Y. Gaur, X. Wang, Z. Meng, Z. Chen, and T. Yoshioka, "End-to-end speaker-attributed ASR with transformer," in Interspeech. ISCA, 2021, pp. 4413–4417.

+ 11 - 4
egs/alimeeting/sa-asr/asr_local.sh

@@ -475,7 +475,9 @@ if ! "${skip_data_prep}"; then
                 fi
                 local/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}"
                 
-                cp data/"${dset}"/utt2spk_all_fifo "${data_feats}${_suf}/${dset}/"
+                if [ "${dset}" = "Train_Ali_far" ] || [ "${dset}" = "Eval_Ali_far" ] || [ "${dset}" = "Test_Ali_far" ]; then
+                    cp data/"${dset}"/utt2spk_all_fifo "${data_feats}${_suf}/${dset}/"
+                fi
 
                 rm -f ${data_feats}${_suf}/${dset}/{segments,wav.scp,reco2file_and_channel,reco2dur}
                 _opts=
@@ -568,8 +570,11 @@ if ! "${skip_data_prep}"; then
 
             # generate uttid
             cut -d ' ' -f 1 "${data_feats}/${dset}/wav.scp" > "${data_feats}/${dset}/uttid"
-            # filter utt2spk_all_fifo
-            python local/filter_utt2spk_all_fifo.py ${data_feats}/${dset}/uttid ${data_feats}/org/${dset} ${data_feats}/${dset}
+            
+            if [ "${dset}" = "Train_Ali_far" ] || [ "${dset}" = "Eval_Ali_far" ] || [ "${dset}" = "Test_Ali_far" ]; then
+                # filter utt2spk_all_fifo
+                python local/filter_utt2spk_all_fifo.py ${data_feats}/${dset}/uttid ${data_feats}/org/${dset} ${data_feats}/${dset}
+            fi
         done
 
         # shellcheck disable=SC2002
@@ -585,7 +590,7 @@ if ! "${skip_data_prep}"; then
         echo "<blank>" > ${token_list}
         echo "<s>" >> ${token_list}
         echo "</s>" >> ${token_list}
-        local/text2token.py -s 1 -n 1 --space "" ${data_feats}/lm_train.txt | cut -f 2- -d" " | tr " " "\n" \
+        utils/text2token.py -s 1 -n 1 --space "" ${data_feats}/lm_train.txt | cut -f 2- -d" " | tr " " "\n" \
             | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
         num_token=$(cat ${token_list} | wc -l)
         echo "<unk>" >> ${token_list}
@@ -603,6 +608,7 @@ if ! "${skip_data_prep}"; then
             python local/process_text_id.py ${data_feats}/${dset}
             log "Successfully generate ${data_feats}/${dset}/text_id_train"
             # generate oracle_embedding from single-speaker audio segment
+            log "oracle_embedding is being generated in the background, and the log is profile_log/gen_oracle_embedding_${dset}.log"
             python local/gen_oracle_embedding.py "${data_feats}/${dset}" "data/local/${dset}_correct_single_speaker" &> "profile_log/gen_oracle_embedding_${dset}.log"
             log "Successfully generate oracle embedding for ${dset} (${data_feats}/${dset}/oracle_embedding.scp)"
             # generate oracle_profile and cluster_profile from oracle_embedding and cluster_embedding (padding the speaker during training)
@@ -615,6 +621,7 @@ if ! "${skip_data_prep}"; then
             fi
             # generate cluster_profile with spectral-cluster directly (for infering and without oracle information)
             if [ "${dset}" = "${valid_set}" ] || [ "${dset}" = "${test_sets}" ]; then
+                log "cluster_profile is being generated in the background, and the log is profile_log/gen_cluster_profile_infer_${dset}.log"
                 python local/gen_cluster_profile_infer.py "${data_feats}/${dset}" "data/local/${dset}" 0.996 0.815 &> "profile_log/gen_cluster_profile_infer_${dset}.log"
                 log "Successfully generate cluster profile for ${dset} (${data_feats}/${dset}/cluster_profile_infer.scp)"
             fi

+ 2 - 2
egs/alimeeting/sa-asr/asr_local_infer.sh → egs/alimeeting/sa-asr/asr_local_m2met_2023_infer.sh

@@ -449,7 +449,7 @@ if ! "${skip_data_prep}"; then
                     _opts+="--segments data/${dset}/segments "
                 fi
                 # shellcheck disable=SC2086
-                scripts/audio/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \
+                local/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \
                     --audio-format "${audio_format}" --fs "${fs}" ${_opts} \
                     "data/${dset}/wav.scp" "${data_feats}${_suf}/${dset}"
 
@@ -467,7 +467,7 @@ if ! "${skip_data_prep}"; then
         mkdir -p "profile_log"
         for dset in "${test_sets}"; do
             # generate cluster_profile with spectral-cluster directly (for infering and without oracle information)
-            python local/gen_cluster_profile_infer.py "${data_feats}/${dset}" "data/local/${dset}" 0.996 0.815 &> "profile_log/gen_cluster_profile_infer_${dset}.log"
+            python local/gen_cluster_profile_infer.py "${data_feats}/${dset}" "data/${dset}" 0.996 0.815 &> "profile_log/gen_cluster_profile_infer_${dset}.log"
             log "Successfully generate cluster profile for ${dset} (${data_feats}/${dset}/cluster_profile_infer.scp)"
             done
     fi

+ 0 - 157
egs/alimeeting/sa-asr/local/compute_wer.py

@@ -1,157 +0,0 @@
-import os
-import numpy as np
-import sys
-
-def compute_wer(ref_file,
-                hyp_file,
-                cer_detail_file):
-    rst = {
-        'Wrd': 0,
-        'Corr': 0,
-        'Ins': 0,
-        'Del': 0,
-        'Sub': 0,
-        'Snt': 0,
-        'Err': 0.0,
-        'S.Err': 0.0,
-        'wrong_words': 0,
-        'wrong_sentences': 0
-    }
-
-    hyp_dict = {}
-    ref_dict = {}
-    with open(hyp_file, 'r') as hyp_reader:
-        for line in hyp_reader:
-            key = line.strip().split()[0]
-            value = line.strip().split()[1:]
-            hyp_dict[key] = value
-    with open(ref_file, 'r') as ref_reader:
-        for line in ref_reader:
-            key = line.strip().split()[0]
-            value = line.strip().split()[1:]
-            ref_dict[key] = value
-
-    cer_detail_writer = open(cer_detail_file, 'w')
-    for hyp_key in hyp_dict:
-        if hyp_key in ref_dict:
-           out_item = compute_wer_by_line(hyp_dict[hyp_key], ref_dict[hyp_key])
-           rst['Wrd'] += out_item['nwords']
-           rst['Corr'] += out_item['cor']
-           rst['wrong_words'] += out_item['wrong']
-           rst['Ins'] += out_item['ins']
-           rst['Del'] += out_item['del']
-           rst['Sub'] += out_item['sub']
-           rst['Snt'] += 1
-           if out_item['wrong'] > 0:
-               rst['wrong_sentences'] += 1
-           cer_detail_writer.write(hyp_key + print_cer_detail(out_item) + '\n')
-           cer_detail_writer.write("ref:" + '\t' + "".join(ref_dict[hyp_key]) + '\n')
-           cer_detail_writer.write("hyp:" + '\t' + "".join(hyp_dict[hyp_key]) + '\n')
-
-    if rst['Wrd'] > 0:
-        rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2)
-    if rst['Snt'] > 0:
-        rst['S.Err'] = round(rst['wrong_sentences'] * 100 / rst['Snt'], 2)
-
-    cer_detail_writer.write('\n')
-    cer_detail_writer.write("%WER " + str(rst['Err']) + " [ " + str(rst['wrong_words'])+ " / " + str(rst['Wrd']) +
-                            ", " + str(rst['Ins']) + " ins, " + str(rst['Del']) + " del, " + str(rst['Sub']) + " sub ]" + '\n')
-    cer_detail_writer.write("%SER " + str(rst['S.Err']) + " [ " + str(rst['wrong_sentences']) + " / " + str(rst['Snt']) + " ]" + '\n')
-    cer_detail_writer.write("Scored " + str(len(hyp_dict)) + " sentences, " + str(len(hyp_dict) - rst['Snt']) + " not present in hyp." + '\n')
-
-     
-def compute_wer_by_line(hyp,
-                        ref):
-    hyp = list(map(lambda x: x.lower(), hyp))
-    ref = list(map(lambda x: x.lower(), ref))
-
-    len_hyp = len(hyp)
-    len_ref = len(ref)
-
-    cost_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int16)
-
-    ops_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int8)
-
-    for i in range(len_hyp + 1):
-        cost_matrix[i][0] = i
-    for j in range(len_ref + 1):
-        cost_matrix[0][j] = j
-
-    for i in range(1, len_hyp + 1):
-        for j in range(1, len_ref + 1):
-            if hyp[i - 1] == ref[j - 1]:
-                cost_matrix[i][j] = cost_matrix[i - 1][j - 1]
-            else:
-                substitution = cost_matrix[i - 1][j - 1] + 1
-                insertion = cost_matrix[i - 1][j] + 1
-                deletion = cost_matrix[i][j - 1] + 1
-
-                compare_val = [substitution, insertion, deletion]
-
-                min_val = min(compare_val)
-                operation_idx = compare_val.index(min_val) + 1
-                cost_matrix[i][j] = min_val
-                ops_matrix[i][j] = operation_idx
-
-    match_idx = []
-    i = len_hyp
-    j = len_ref
-    rst = {
-        'nwords': len_ref,
-        'cor': 0,
-        'wrong': 0,
-        'ins': 0,
-        'del': 0,
-        'sub': 0
-    }
-    while i >= 0 or j >= 0:
-        i_idx = max(0, i)
-        j_idx = max(0, j)
-
-        if ops_matrix[i_idx][j_idx] == 0:  # correct
-            if i - 1 >= 0 and j - 1 >= 0:
-                match_idx.append((j - 1, i - 1))
-                rst['cor'] += 1
-
-            i -= 1
-            j -= 1
-
-        elif ops_matrix[i_idx][j_idx] == 2:  # insert
-            i -= 1
-            rst['ins'] += 1
-
-        elif ops_matrix[i_idx][j_idx] == 3:  # delete
-            j -= 1
-            rst['del'] += 1
-
-        elif ops_matrix[i_idx][j_idx] == 1:  # substitute
-            i -= 1
-            j -= 1
-            rst['sub'] += 1
-
-        if i < 0 and j >= 0:
-            rst['del'] += 1
-        elif j < 0 and i >= 0:
-            rst['ins'] += 1
-
-    match_idx.reverse()
-    wrong_cnt = cost_matrix[len_hyp][len_ref]
-    rst['wrong'] = wrong_cnt
-
-    return rst
-
-def print_cer_detail(rst):
-    return ("(" + "nwords=" + str(rst['nwords']) + ",cor=" + str(rst['cor'])
-            + ",ins=" + str(rst['ins']) + ",del=" + str(rst['del']) + ",sub="
-            + str(rst['sub']) + ") corr:" + '{:.2%}'.format(rst['cor']/rst['nwords'])
-            + ",cer:" + '{:.2%}'.format(rst['wrong']/rst['nwords']))
-
-if __name__ == '__main__':
-    if len(sys.argv) != 4:
-        print("usage : python compute-wer.py test.ref test.hyp test.wer")
-        sys.exit(0)
-
-    ref_file = sys.argv[1]
-    hyp_file = sys.argv[2]
-    cer_detail_file = sys.argv[3]
-    compute_wer(ref_file, hyp_file, cer_detail_file)

+ 12 - 12
egs/alimeeting/sa-asr/local/perturb_data_dir_speed.sh

@@ -63,20 +63,20 @@ else
 fi
 
 
-<"${srcdir}"/utt2spk utils/apply_map.pl -f 1 "${destdir}"/utt_map | \
-  utils/apply_map.pl -f 2 "${destdir}"/spk_map >"${destdir}"/utt2spk
+<"${srcdir}"/utt2spk local/apply_map.pl -f 1 "${destdir}"/utt_map | \
+  local/apply_map.pl -f 2 "${destdir}"/spk_map >"${destdir}"/utt2spk
 
-utils/utt2spk_to_spk2utt.pl <"${destdir}"/utt2spk >"${destdir}"/spk2utt
+local/utt2spk_to_spk2utt.pl <"${destdir}"/utt2spk >"${destdir}"/spk2utt
 
 if [[ -f ${srcdir}/segments ]]; then
 
-  utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/segments | \
-      utils/apply_map.pl -f 2 "${destdir}"/reco_map | \
+  local/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/segments | \
+      local/apply_map.pl -f 2 "${destdir}"/reco_map | \
           awk -v factor="${factor}" \
             '{s=$3/factor; e=$4/factor; if (e > s + 0.01) { printf("%s %s %.2f %.2f\n", $1, $2, $3/factor, $4/factor);} }' \
             >"${destdir}"/segments
 
-  utils/apply_map.pl -f 1 "${destdir}"/reco_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \
+  local/apply_map.pl -f 1 "${destdir}"/reco_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \
       # Handle three cases of rxfilenames appropriately; "input piped command", "file offset" and "filename"
       awk -v factor="${factor}" \
           '{wid=$1; $1=""; if ($NF=="|") {print wid $_ " sox -t wav - -t wav - speed " factor " |"}
@@ -84,13 +84,13 @@ if [[ -f ${srcdir}/segments ]]; then
             else  {print wid " sox" $_ " -t wav - speed " factor " |"}}' \
              > "${destdir}"/wav.scp
   if [[ -f ${srcdir}/reco2file_and_channel ]]; then
-      utils/apply_map.pl -f 1 "${destdir}"/reco_map \
+      local/apply_map.pl -f 1 "${destdir}"/reco_map \
        <"${srcdir}"/reco2file_and_channel >"${destdir}"/reco2file_and_channel
   fi
 
 else # no segments->wav indexed by utterance.
     if [[ -f ${srcdir}/wav.scp ]]; then
-        utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \
+        local/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \
          # Handle three cases of rxfilenames appropriately; "input piped command", "file offset" and "filename"
          awk -v factor="${factor}" \
            '{wid=$1; $1=""; if ($NF=="|") {print wid $_ " sox -t wav - -t wav - speed " factor " |"}
@@ -101,16 +101,16 @@ else # no segments->wav indexed by utterance.
 fi
 
 if [[ -f ${srcdir}/text ]]; then
-    utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/text >"${destdir}"/text
+    local/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/text >"${destdir}"/text
 fi
 if [[ -f ${srcdir}/spk2gender ]]; then
-    utils/apply_map.pl -f 1 "${destdir}"/spk_map <"${srcdir}"/spk2gender >"${destdir}"/spk2gender
+    local/apply_map.pl -f 1 "${destdir}"/spk_map <"${srcdir}"/spk2gender >"${destdir}"/spk2gender
 fi
 if [[ -f ${srcdir}/utt2lang ]]; then
-    utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/utt2lang >"${destdir}"/utt2lang
+    local/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/utt2lang >"${destdir}"/utt2lang
 fi
 
 rm "${destdir}"/spk_map "${destdir}"/utt_map "${destdir}"/reco_map 2>/dev/null
 echo "$0: generated speed-perturbed version of data in ${srcdir}, in ${destdir}"
 
-utils/validate_data_dir.sh --no-feats --no-text "${destdir}"
+local/validate_data_dir.sh --no-feats --no-text "${destdir}"

+ 0 - 32
egs/alimeeting/sa-asr/local/proce_text.py

@@ -1,32 +0,0 @@
-
-import sys
-import re
-
-in_f = sys.argv[1]
-out_f = sys.argv[2]
-
-
-with open(in_f, "r", encoding="utf-8") as f:
-  lines = f.readlines()
-
-with open(out_f, "w", encoding="utf-8") as f:
-  for line in lines:
-    outs = line.strip().split(" ", 1)
-    if len(outs) == 2:
-      idx, text = outs
-      text = re.sub("</s>", "", text)
-      text = re.sub("<s>", "", text)
-      text = re.sub("@@", "", text)
-      text = re.sub("@", "", text)
-      text = re.sub("<unk>", "", text)
-      text = re.sub(" ", "", text)
-      text = re.sub("\$", "", text)
-      text = text.lower()
-    else:
-      idx = outs[0]
-      text = " "
-
-    text = [x for x in text]
-    text = " ".join(text)
-    out = "{} {}\n".format(idx, text)
-    f.write(out)

+ 0 - 1
egs/alimeeting/sa-asr/run_m2met_2023.sh → egs/alimeeting/sa-asr/run.sh

@@ -8,7 +8,6 @@ set -o pipefail
 ngpu=4
 device="0,1,2,3"
 
-#stage 1 creat both near and far
 stage=1
 stop_stage=18
 

+ 1 - 1
egs/alimeeting/sa-asr/run_m2met_2023_infer.sh

@@ -22,7 +22,7 @@ inference_config=conf/decode_asr_rnn.yaml
 lm_config=conf/train_lm_transformer.yaml
 use_lm=false
 use_wordlm=false
-./asr_local_infer.sh                                         \
+./asr_local_m2met_2023_infer.sh                                         \
     --device ${device}                                 \
     --ngpu ${ngpu}                                     \
     --stage ${stage}                                   \

+ 1 - 8
funasr/bin/asr_inference.py

@@ -94,7 +94,7 @@ class Speech2Text:
         frontend = None
         if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
             if asr_train_args.frontend=='wav_frontend':
-                frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf).eval()
+                frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
             else:
                 frontend_class=frontend_choices.get_class(asr_train_args.frontend)
                 frontend = frontend_class(**asr_train_args.frontend_conf).eval()
@@ -147,13 +147,6 @@ class Speech2Text:
             pre_beam_score_key=None if ctc_weight == 1.0 else "full",
         )
 
-        beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
-        for scorer in scorers.values():
-            if isinstance(scorer, torch.nn.Module):
-                scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
-        logging.info(f"Beam_search: {beam_search}")
-        logging.info(f"Decoding device={device}, dtype={dtype}")
-
         # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
         if token_type is None:
             token_type = asr_train_args.token_type

+ 1 - 8
funasr/bin/sa_asr_inference.py

@@ -89,7 +89,7 @@ class Speech2Text:
         frontend = None
         if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
             if asr_train_args.frontend=='wav_frontend':
-                frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf).eval()
+                frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
             else:
                 frontend_class=frontend_choices.get_class(asr_train_args.frontend)
                 frontend = frontend_class(**asr_train_args.frontend_conf).eval()
@@ -142,13 +142,6 @@ class Speech2Text:
             pre_beam_score_key=None if ctc_weight == 1.0 else "full",
         )
 
-        beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
-        for scorer in scorers.values():
-            if isinstance(scorer, torch.nn.Module):
-                scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
-        logging.info(f"Beam_search: {beam_search}")
-        logging.info(f"Decoding device={device}, dtype={dtype}")
-
         # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
         if token_type is None:
             token_type = asr_train_args.token_type

+ 1 - 1
funasr/losses/label_smoothing_loss.py

@@ -97,7 +97,7 @@ class NllLoss(nn.Module):
         normalize_length=False,
         criterion=nn.NLLLoss(reduction='none'),
     ):
-        """Construct an LabelSmoothingLoss object."""
+        """Construct an NllLoss object."""
         super(NllLoss, self).__init__()
         self.criterion = criterion
         self.padding_idx = padding_idx