Просмотр исходного кода

add speaker-attributed ASR task for alimeeting

smohan-speech 2 года назад
Родитель
Commit
d76aea23d9
35 изменённых файлов: 1090 строк добавлено, 516 строк удалено
  1. 18 15
      egs/alimeeting/sa-asr/asr_local.sh
  2. 2 1
      egs/alimeeting/sa-asr/asr_local_infer.sh
  3. 0 1
      egs/alimeeting/sa-asr/conf/train_asr_conformer.yaml
  4. 0 1
      egs/alimeeting/sa-asr/conf/train_sa_asr_conformer.yaml
  5. 7 7
      egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh
  6. 5 5
      egs/alimeeting/sa-asr/local/alimeeting_data_prep_test_2023.sh
  7. 0 0
      egs/alimeeting/sa-asr/local/apply_map.pl
  8. 3 3
      egs/alimeeting/sa-asr/local/combine_data.sh
  9. 14 14
      egs/alimeeting/sa-asr/local/copy_data_dir.sh
  10. 0 0
      egs/alimeeting/sa-asr/local/data/get_reco2dur.sh
  11. 1 1
      egs/alimeeting/sa-asr/local/data/get_segments_for_data.sh
  12. 1 1
      egs/alimeeting/sa-asr/local/data/get_utt2dur.sh
  13. 3 3
      egs/alimeeting/sa-asr/local/data/split_data.sh
  14. 3 3
      egs/alimeeting/sa-asr/local/fix_data_dir.sh
  15. 243 0
      egs/alimeeting/sa-asr/local/format_wav_scp.py
  16. 142 0
      egs/alimeeting/sa-asr/local/format_wav_scp.sh
  17. 116 0
      egs/alimeeting/sa-asr/local/perturb_data_dir_speed.sh
  18. 0 0
      egs/alimeeting/sa-asr/local/spk2utt_to_utt2spk.pl
  19. 0 0
      egs/alimeeting/sa-asr/local/utt2spk_to_spk2utt.pl
  20. 2 2
      egs/alimeeting/sa-asr/local/validate_data_dir.sh
  21. 0 0
      egs/alimeeting/sa-asr/local/validate_text.pl
  22. 1 2
      egs/alimeeting/sa-asr/path.sh
  23. 1 0
      egs/alimeeting/sa-asr/utils
  24. 0 87
      egs/alimeeting/sa-asr/utils/filter_scp.pl
  25. 0 97
      egs/alimeeting/sa-asr/utils/parse_options.sh
  26. 0 246
      egs/alimeeting/sa-asr/utils/split_scp.pl
  27. 24 4
      funasr/bin/asr_inference.py
  28. 7 1
      funasr/bin/asr_inference_launch.py
  29. 0 8
      funasr/bin/asr_train.py
  30. 22 2
      funasr/bin/sa_asr_inference.py
  31. 0 8
      funasr/bin/sa_asr_train.py
  32. 46 0
      funasr/losses/label_smoothing_loss.py
  33. 427 1
      funasr/models/decoder/transformer_decoder.py
  34. 1 2
      funasr/models/e2e_sa_asr.py
  35. 1 1
      funasr/tasks/sa_asr.py

+ 18 - 15
egs/alimeeting/sa-asr/asr_local.sh

@@ -434,14 +434,14 @@ if ! "${skip_data_prep}"; then
            log "Stage 2: Speed perturbation: data/${train_set} -> data/${train_set}_sp"
            for factor in ${speed_perturb_factors}; do
                if [[ $(bc <<<"${factor} != 1.0") == 1 ]]; then
-                   scripts/utils/perturb_data_dir_speed.sh "${factor}" "data/${train_set}" "data/${train_set}_sp${factor}"
+                   local/perturb_data_dir_speed.sh "${factor}" "data/${train_set}" "data/${train_set}_sp${factor}"
                    _dirs+="data/${train_set}_sp${factor} "
                else
                    # If speed factor is 1, same as the original
                    _dirs+="data/${train_set} "
                fi
            done
-           utils/combine_data.sh "data/${train_set}_sp" ${_dirs}
+           local/combine_data.sh "data/${train_set}_sp" ${_dirs}
         else
            log "Skip stage 2: Speed perturbation"
         fi
@@ -473,7 +473,7 @@ if ! "${skip_data_prep}"; then
                         _suf=""
                     fi
                 fi
-                utils/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}"
+                local/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}"
                 
                 cp data/"${dset}"/utt2spk_all_fifo "${data_feats}${_suf}/${dset}/"
 
@@ -488,7 +488,7 @@ if ! "${skip_data_prep}"; then
                     _opts+="--segments data/${dset}/segments "
                 fi
                 # shellcheck disable=SC2086
-                scripts/audio/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \
+                local/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \
                     --audio-format "${audio_format}" --fs "${fs}" ${_opts} \
                     "data/${dset}/wav.scp" "${data_feats}${_suf}/${dset}"
 
@@ -515,7 +515,7 @@ if ! "${skip_data_prep}"; then
         for dset in $rm_dset; do
 
             # Copy data dir
-            utils/copy_data_dir.sh --validate_opts --non-print "${data_feats}/org/${dset}" "${data_feats}/${dset}"
+            local/copy_data_dir.sh --validate_opts --non-print "${data_feats}/org/${dset}" "${data_feats}/${dset}"
             cp "${data_feats}/org/${dset}/feats_type" "${data_feats}/${dset}/feats_type"
 
             # Remove short utterances
@@ -564,7 +564,7 @@ if ! "${skip_data_prep}"; then
                 awk ' { if( NF != 1 ) print $0; } ' >"${data_feats}/${dset}/text"
 
             # fix_data_dir.sh leaves only utts which exist in all files
-            utils/fix_data_dir.sh "${data_feats}/${dset}"
+            local/fix_data_dir.sh "${data_feats}/${dset}"
 
             # generate uttid
             cut -d ' ' -f 1 "${data_feats}/${dset}/wav.scp" > "${data_feats}/${dset}/uttid"
@@ -1283,6 +1283,7 @@ if ! "${skip_eval}"; then
             ${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
                 python -m funasr.bin.asr_inference_launch \
                     --batch_size 1 \
+                    --mc True   \
                     --nbest 1   \
                     --ngpu "${_ngpu}" \
                     --njob ${njob_infer} \
@@ -1312,10 +1313,10 @@ if ! "${skip_eval}"; then
             _data="${data_feats}/${dset}"
             _dir="${asr_exp}/${inference_tag}/${dset}"
 
-            python local/proce_text.py ${_data}/text ${_data}/text.proc
-            python local/proce_text.py ${_dir}/text ${_dir}/text.proc
+            python utils/proce_text.py ${_data}/text ${_data}/text.proc
+            python utils/proce_text.py ${_dir}/text ${_dir}/text.proc
 
-            python local/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
+            python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
             tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
             cat ${_dir}/text.cer.txt
             
@@ -1390,6 +1391,7 @@ if ! "${skip_eval}"; then
             ${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
                 python -m funasr.bin.asr_inference_launch \
                     --batch_size 1 \
+                    --mc True   \
                     --nbest 1   \
                     --ngpu "${_ngpu}" \
                     --njob ${njob_infer} \
@@ -1421,10 +1423,10 @@ if ! "${skip_eval}"; then
             _data="${data_feats}/${dset}"
             _dir="${sa_asr_exp}/${sa_asr_inference_tag}.oracle/${dset}"
 
-            python local/proce_text.py ${_data}/text ${_data}/text.proc
-            python local/proce_text.py ${_dir}/text ${_dir}/text.proc
+            python utils/proce_text.py ${_data}/text ${_data}/text.proc
+            python utils/proce_text.py ${_dir}/text ${_dir}/text.proc
 
-            python local/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
+            python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
             tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
             cat ${_dir}/text.cer.txt
 
@@ -1506,6 +1508,7 @@ if ! "${skip_eval}"; then
             ${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
                 python -m funasr.bin.asr_inference_launch \
                     --batch_size 1 \
+                    --mc True   \
                     --nbest 1   \
                     --ngpu "${_ngpu}" \
                     --njob ${njob_infer} \
@@ -1536,10 +1539,10 @@ if ! "${skip_eval}"; then
             _data="${data_feats}/${dset}"
             _dir="${sa_asr_exp}/${sa_asr_inference_tag}.cluster/${dset}"
 
-            python local/proce_text.py ${_data}/text ${_data}/text.proc
-            python local/proce_text.py ${_dir}/text ${_dir}/text.proc
+            python utils/proce_text.py ${_data}/text ${_data}/text.proc
+            python utils/proce_text.py ${_dir}/text ${_dir}/text.proc
 
-            python local/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
+            python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
             tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
             cat ${_dir}/text.cer.txt
 

+ 2 - 1
egs/alimeeting/sa-asr/asr_local_infer.sh

@@ -436,7 +436,7 @@ if ! "${skip_data_prep}"; then
             
                 _suf=""
 
-                utils/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}"
+                local/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}"
                 
                 rm -f ${data_feats}${_suf}/${dset}/{segments,wav.scp,reco2file_and_channel,reco2dur}
                 _opts=
@@ -548,6 +548,7 @@ if ! "${skip_eval}"; then
             ${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
                 python -m funasr.bin.asr_inference_launch \
                     --batch_size 1 \
+                    --mc True   \
                     --nbest 1   \
                     --ngpu "${_ngpu}" \
                     --njob ${njob_infer} \

+ 0 - 1
egs/alimeeting/sa-asr/conf/train_asr_conformer.yaml

@@ -4,7 +4,6 @@ frontend_conf:
     n_fft: 400
     win_length: 400
     hop_length: 160
-    use_channel: 0
     
 # encoder related
 encoder: conformer

+ 0 - 1
egs/alimeeting/sa-asr/conf/train_sa_asr_conformer.yaml

@@ -4,7 +4,6 @@ frontend_conf:
     n_fft: 400
     win_length: 400
     hop_length: 160
-    use_channel: 0
 
 # encoder related
 asr_encoder: conformer

+ 7 - 7
egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh

@@ -78,7 +78,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     utils/filter_scp.pl -f 1 $near_dir/text $near_dir/utt2spk_all | sort -u > $near_dir/utt2spk
     #sed -e 's/ [a-z,A-Z,_,0-9,-]\+SPK/ SPK/'  $near_dir/utt2spk_old >$near_dir/tmp1
     #sed -e 's/-[a-z,A-Z,0-9]\+$//' $near_dir/tmp1 | sort -u > $near_dir/utt2spk
-    utils/utt2spk_to_spk2utt.pl $near_dir/utt2spk > $near_dir/spk2utt
+    local/utt2spk_to_spk2utt.pl $near_dir/utt2spk > $near_dir/spk2utt
     utils/filter_scp.pl -f 1 $near_dir/text $near_dir/segments_all | sort -u > $near_dir/segments
     sed -e 's/ $//g' $near_dir/text> $near_dir/tmp1
     sed -e 's/!//g' $near_dir/tmp1> $near_dir/tmp2
@@ -109,7 +109,7 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
     utils/filter_scp.pl -f 1 $far_dir/text $far_dir/utt2spk_all | sort -u > $far_dir/utt2spk
     #sed -e 's/ [a-z,A-Z,_,0-9,-]\+SPK/ SPK/'  $far_dir/utt2spk_old >$far_dir/utt2spk
     
-    utils/utt2spk_to_spk2utt.pl $far_dir/utt2spk > $far_dir/spk2utt
+    local/utt2spk_to_spk2utt.pl $far_dir/utt2spk > $far_dir/spk2utt
     utils/filter_scp.pl -f 1 $far_dir/text $far_dir/segments_all | sort -u > $far_dir/segments
     sed -e 's/SRC/$/g' $far_dir/text> $far_dir/tmp1
     sed -e 's/ $//g' $far_dir/tmp1> $far_dir/tmp2
@@ -121,8 +121,8 @@ fi
 if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
     log "stage 3: finali data process"
 
-    utils/copy_data_dir.sh $near_dir data/${tgt}_Ali_near
-    utils/copy_data_dir.sh $far_dir data/${tgt}_Ali_far
+    local/copy_data_dir.sh $near_dir data/${tgt}_Ali_near
+    local/copy_data_dir.sh $far_dir data/${tgt}_Ali_far
 
     sort $far_dir/utt2spk_all_fifo > data/${tgt}_Ali_far/utt2spk_all_fifo
     sed -i "s/src/$/g" data/${tgt}_Ali_far/utt2spk_all_fifo
@@ -146,10 +146,10 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
     python local/process_textgrid_to_single_speaker_wav.py  --path $far_single_speaker_dir
     
     cp $far_single_speaker_dir/utt2spk $far_single_speaker_dir/text    
-    utils/utt2spk_to_spk2utt.pl $far_single_speaker_dir/utt2spk > $far_single_speaker_dir/spk2utt
+    local/utt2spk_to_spk2utt.pl $far_single_speaker_dir/utt2spk > $far_single_speaker_dir/spk2utt
 
-    ./utils/fix_data_dir.sh $far_single_speaker_dir 
-    utils/copy_data_dir.sh $far_single_speaker_dir data/${tgt}_Ali_far_single_speaker
+    ./local/fix_data_dir.sh $far_single_speaker_dir 
+    local/copy_data_dir.sh $far_single_speaker_dir data/${tgt}_Ali_far_single_speaker
 
     # remove space in text
     for x in ${tgt}_Ali_far_single_speaker; do

+ 5 - 5
egs/alimeeting/sa-asr/local/alimeeting_data_prep_test_2023.sh

@@ -77,7 +77,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     utils/filter_scp.pl -f 1 $far_dir/text $far_dir/utt2spk_all | sort -u > $far_dir/utt2spk
     #sed -e 's/ [a-z,A-Z,_,0-9,-]\+SPK/ SPK/'  $far_dir/utt2spk_old >$far_dir/utt2spk
     
-    utils/utt2spk_to_spk2utt.pl $far_dir/utt2spk > $far_dir/spk2utt
+    local/utt2spk_to_spk2utt.pl $far_dir/utt2spk > $far_dir/spk2utt
     utils/filter_scp.pl -f 1 $far_dir/text $far_dir/segments_all | sort -u > $far_dir/segments
     sed -e 's/SRC/$/g' $far_dir/text> $far_dir/tmp1
     sed -e 's/ $//g' $far_dir/tmp1> $far_dir/tmp2
@@ -89,7 +89,7 @@ fi
 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
     log "stage 2: finali data process"
 
-    utils/copy_data_dir.sh $far_dir data/${tgt}_Ali_far
+    local/copy_data_dir.sh $far_dir data/${tgt}_Ali_far
 
     sort $far_dir/utt2spk_all_fifo > data/${tgt}_Ali_far/utt2spk_all_fifo
     sed -i "s/src/$/g" data/${tgt}_Ali_far/utt2spk_all_fifo
@@ -113,10 +113,10 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
     python local/process_textgrid_to_single_speaker_wav.py  --path $far_single_speaker_dir
     
     cp $far_single_speaker_dir/utt2spk $far_single_speaker_dir/text    
-    utils/utt2spk_to_spk2utt.pl $far_single_speaker_dir/utt2spk > $far_single_speaker_dir/spk2utt
+    local/utt2spk_to_spk2utt.pl $far_single_speaker_dir/utt2spk > $far_single_speaker_dir/spk2utt
 
-    ./utils/fix_data_dir.sh $far_single_speaker_dir 
-    utils/copy_data_dir.sh $far_single_speaker_dir data/${tgt}_Ali_far_single_speaker
+    ./local/fix_data_dir.sh $far_single_speaker_dir 
+    local/copy_data_dir.sh $far_single_speaker_dir data/${tgt}_Ali_far_single_speaker
 
     # remove space in text
     for x in ${tgt}_Ali_far_single_speaker; do

+ 0 - 0
egs/alimeeting/sa-asr/utils/apply_map.pl → egs/alimeeting/sa-asr/local/apply_map.pl


+ 3 - 3
egs/alimeeting/sa-asr/utils/combine_data.sh → egs/alimeeting/sa-asr/local/combine_data.sh

@@ -98,7 +98,7 @@ if $has_segments; then
   for in_dir in $*; do
     if [ ! -f $in_dir/segments ]; then
       echo "$0 [info]: will generate missing segments for $in_dir" 1>&2
-      utils/data/get_segments_for_data.sh $in_dir
+      local/data/get_segments_for_data.sh $in_dir
     else
       cat $in_dir/segments
     fi
@@ -133,14 +133,14 @@ for file in utt2spk utt2lang utt2dur utt2num_frames reco2dur feats.scp text cmvn
   fi
 done
 
-utils/utt2spk_to_spk2utt.pl <$dest/utt2spk >$dest/spk2utt
+local/utt2spk_to_spk2utt.pl <$dest/utt2spk >$dest/spk2utt
 
 if [[ $dir_with_frame_shift ]]; then
   cp $dir_with_frame_shift/frame_shift $dest
 fi
 
 if ! $skip_fix ; then
-  utils/fix_data_dir.sh $dest || exit 1;
+  local/fix_data_dir.sh $dest || exit 1;
 fi
 
 exit 0

+ 14 - 14
egs/alimeeting/sa-asr/utils/copy_data_dir.sh → egs/alimeeting/sa-asr/local/copy_data_dir.sh

@@ -71,25 +71,25 @@ else
   cat $srcdir/utt2uniq | awk -v p=$utt_prefix -v s=$utt_suffix '{printf("%s%s%s %s\n", p, $1, s, $2);}' > $destdir/utt2uniq
 fi
 
-cat $srcdir/utt2spk | utils/apply_map.pl -f 1 $destdir/utt_map  | \
-  utils/apply_map.pl -f 2 $destdir/spk_map >$destdir/utt2spk
+cat $srcdir/utt2spk | local/apply_map.pl -f 1 $destdir/utt_map  | \
+  local/apply_map.pl -f 2 $destdir/spk_map >$destdir/utt2spk
 
-utils/utt2spk_to_spk2utt.pl <$destdir/utt2spk >$destdir/spk2utt
+local/utt2spk_to_spk2utt.pl <$destdir/utt2spk >$destdir/spk2utt
 
 if [ -f $srcdir/feats.scp ]; then
-  utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/feats.scp >$destdir/feats.scp
+  local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/feats.scp >$destdir/feats.scp
 fi
 
 if [ -f $srcdir/vad.scp ]; then
-  utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/vad.scp >$destdir/vad.scp
+  local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/vad.scp >$destdir/vad.scp
 fi
 
 if [ -f $srcdir/segments ]; then
-  utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/segments >$destdir/segments
+  local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/segments >$destdir/segments
   cp $srcdir/wav.scp $destdir
 else # no segments->wav indexed by utt.
   if [ -f $srcdir/wav.scp ]; then
-    utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/wav.scp >$destdir/wav.scp
+    local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/wav.scp >$destdir/wav.scp
   fi
 fi
 
@@ -98,26 +98,26 @@ if [ -f $srcdir/reco2file_and_channel ]; then
 fi
 
 if [ -f $srcdir/text ]; then
-  utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/text >$destdir/text
+  local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/text >$destdir/text
 fi
 if [ -f $srcdir/utt2dur ]; then
-  utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/utt2dur >$destdir/utt2dur
+  local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/utt2dur >$destdir/utt2dur
 fi
 if [ -f $srcdir/utt2num_frames ]; then
-  utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/utt2num_frames >$destdir/utt2num_frames
+  local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/utt2num_frames >$destdir/utt2num_frames
 fi
 if [ -f $srcdir/reco2dur ]; then
   if [ -f $srcdir/segments ]; then
     cp $srcdir/reco2dur $destdir/reco2dur
   else
-    utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/reco2dur >$destdir/reco2dur
+    local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/reco2dur >$destdir/reco2dur
   fi
 fi
 if [ -f $srcdir/spk2gender ]; then
-  utils/apply_map.pl -f 1 $destdir/spk_map <$srcdir/spk2gender >$destdir/spk2gender
+  local/apply_map.pl -f 1 $destdir/spk_map <$srcdir/spk2gender >$destdir/spk2gender
 fi
 if [ -f $srcdir/cmvn.scp ]; then
-  utils/apply_map.pl -f 1 $destdir/spk_map <$srcdir/cmvn.scp >$destdir/cmvn.scp
+  local/apply_map.pl -f 1 $destdir/spk_map <$srcdir/cmvn.scp >$destdir/cmvn.scp
 fi
 for f in frame_shift stm glm ctm; do
   if [ -f $srcdir/$f ]; then
@@ -142,4 +142,4 @@ done
 [ ! -f $srcdir/feats.scp ] && validate_opts="$validate_opts --no-feats"
 [ ! -f $srcdir/text ] && validate_opts="$validate_opts --no-text"
 
-utils/validate_data_dir.sh $validate_opts $destdir
+local/validate_data_dir.sh $validate_opts $destdir

+ 0 - 0
egs/alimeeting/sa-asr/utils/data/get_reco2dur.sh → egs/alimeeting/sa-asr/local/data/get_reco2dur.sh


+ 1 - 1
egs/alimeeting/sa-asr/utils/data/get_segments_for_data.sh → egs/alimeeting/sa-asr/local/data/get_segments_for_data.sh

@@ -20,7 +20,7 @@ fi
 data=$1
 
 if [ ! -s $data/utt2dur ]; then
-  utils/data/get_utt2dur.sh $data 1>&2 || exit 1;
+  local/data/get_utt2dur.sh $data 1>&2 || exit 1;
 fi
 
 # <utt-id> <utt-id> 0 <utt-dur>

+ 1 - 1
egs/alimeeting/sa-asr/utils/data/get_utt2dur.sh → egs/alimeeting/sa-asr/local/data/get_utt2dur.sh

@@ -94,7 +94,7 @@ elif [ -f $data/wav.scp ]; then
       nj=$num_utts
     fi
 
-    utils/data/split_data.sh --per-utt $data $nj
+    local/data/split_data.sh --per-utt $data $nj
     sdata=$data/split${nj}utt
 
     $cmd JOB=1:$nj $data/log/get_durations.JOB.log \

+ 3 - 3
egs/alimeeting/sa-asr/utils/data/split_data.sh → egs/alimeeting/sa-asr/local/data/split_data.sh

@@ -60,11 +60,11 @@ nf=`cat $data/feats.scp 2>/dev/null | wc -l`
 nt=`cat $data/text 2>/dev/null | wc -l` # take it as zero if no such file
 if [ -f $data/feats.scp ] && [ $nu -ne $nf ]; then
   echo "** split_data.sh: warning, #lines is (utt2spk,feats.scp) is ($nu,$nf); you can "
-  echo "**  use utils/fix_data_dir.sh $data to fix this."
+  echo "**  use local/fix_data_dir.sh $data to fix this."
 fi
 if [ -f $data/text ] && [ $nu -ne $nt ]; then
   echo "** split_data.sh: warning, #lines is (utt2spk,text) is ($nu,$nt); you can "
-  echo "** use utils/fix_data_dir.sh to fix this."
+  echo "** use local/fix_data_dir.sh to fix this."
 fi
 
 
@@ -112,7 +112,7 @@ utils/split_scp.pl $utt2spk_opt $data/utt2spk $utt2spks || exit 1
 
 for n in `seq $numsplit`; do
   dsn=$data/split${numsplit}${utt}/$n
-  utils/utt2spk_to_spk2utt.pl $dsn/utt2spk > $dsn/spk2utt || exit 1;
+  local/utt2spk_to_spk2utt.pl $dsn/utt2spk > $dsn/spk2utt || exit 1;
 done
 
 maybe_wav_scp=

+ 3 - 3
egs/alimeeting/sa-asr/utils/fix_data_dir.sh → egs/alimeeting/sa-asr/local/fix_data_dir.sh

@@ -112,7 +112,7 @@ function filter_recordings {
 
 function filter_speakers {
   # throughout this program, we regard utt2spk as primary and spk2utt as derived, so...
-  utils/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt
+  local/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt
 
   cat $data/spk2utt | awk '{print $1}' > $tmpdir/speakers
   for s in cmvn.scp spk2gender; do
@@ -123,7 +123,7 @@ function filter_speakers {
   done
 
   filter_file $tmpdir/speakers $data/spk2utt
-  utils/spk2utt_to_utt2spk.pl $data/spk2utt > $data/utt2spk
+  local/spk2utt_to_utt2spk.pl $data/spk2utt > $data/utt2spk
 
   for s in cmvn.scp spk2gender $spk_extra_files; do
     f=$data/$s
@@ -210,6 +210,6 @@ filter_utts
 filter_speakers
 filter_recordings
 
-utils/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt
+local/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt
 
 echo "fix_data_dir.sh: old files are kept in $data/.backup"

+ 243 - 0
egs/alimeeting/sa-asr/local/format_wav_scp.py

@@ -0,0 +1,243 @@
+#!/usr/bin/env python3
+import argparse
+import logging
+from io import BytesIO
+from pathlib import Path
+from typing import Tuple, Optional
+
+import kaldiio
+import humanfriendly
+import numpy as np
+import resampy
+import soundfile
+from tqdm import tqdm
+from typeguard import check_argument_types
+
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.fileio.read_text import read_2column_text
+from funasr.fileio.sound_scp import SoundScpWriter
+
+
+def humanfriendly_or_none(value: str):
+    if value in ("none", "None", "NONE"):
+        return None
+    return humanfriendly.parse_size(value)
+
+
+def str2int_tuple(integers: str) -> Optional[Tuple[int, ...]]:
+    """
+
+    >>> str2int_tuple('3,4,5')
+    (3, 4, 5)
+
+    """
+    assert check_argument_types()
+    if integers.strip() in ("none", "None", "NONE", "null", "Null", "NULL"):
+        return None
+    return tuple(map(int, integers.strip().split(",")))
+
+
+def main():
+    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
+    logging.basicConfig(level=logging.INFO, format=logfmt)
+    logging.info(get_commandline_args())
+
+    parser = argparse.ArgumentParser(
+        description='Create waves list from "wav.scp"',
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument("scp")
+    parser.add_argument("outdir")
+    parser.add_argument(
+        "--name",
+        default="wav",
+        help="Specify the prefix word of output file name " 'such as "wav.scp"',
+    )
+    parser.add_argument("--segments", default=None)
+    parser.add_argument(
+        "--fs",
+        type=humanfriendly_or_none,
+        default=None,
+        help="If the sampling rate specified, " "Change the sampling rate.",
+    )
+    parser.add_argument("--audio-format", default="wav")
+    group = parser.add_mutually_exclusive_group()
+    group.add_argument("--ref-channels", default=None, type=str2int_tuple)
+    group.add_argument("--utt2ref-channels", default=None, type=str)
+    args = parser.parse_args()
+
+    out_num_samples = Path(args.outdir) / f"utt2num_samples"
+
+    if args.ref_channels is not None:
+
+        def utt2ref_channels(x) -> Tuple[int, ...]:
+            return args.ref_channels
+
+    elif args.utt2ref_channels is not None:
+        utt2ref_channels_dict = read_2column_text(args.utt2ref_channels)
+
+        def utt2ref_channels(x, d=utt2ref_channels_dict) -> Tuple[int, ...]:
+            chs_str = d[x]
+            return tuple(map(int, chs_str.split()))
+
+    else:
+        utt2ref_channels = None
+
+    Path(args.outdir).mkdir(parents=True, exist_ok=True)
+    out_wavscp = Path(args.outdir) / f"{args.name}.scp"
+    if args.segments is not None:
+        # Note: kaldiio supports only wav-pcm-int16le file.
+        loader = kaldiio.load_scp_sequential(args.scp, segments=args.segments)
+        if args.audio_format.endswith("ark"):
+            fark = open(Path(args.outdir) / f"data_{args.name}.ark", "wb")
+            fscp = out_wavscp.open("w")
+        else:
+            writer = SoundScpWriter(
+                args.outdir,
+                out_wavscp,
+                format=args.audio_format,
+            )
+
+        with out_num_samples.open("w") as fnum_samples:
+            for uttid, (rate, wave) in tqdm(loader):
+                # wave: (Time,) or (Time, Nmic)
+                if wave.ndim == 2 and utt2ref_channels is not None:
+                    wave = wave[:, utt2ref_channels(uttid)]
+
+                if args.fs is not None and args.fs != rate:
+                    # FIXME(kamo): To use sox?
+                    wave = resampy.resample(
+                        wave.astype(np.float64), rate, args.fs, axis=0
+                    )
+                    wave = wave.astype(np.int16)
+                    rate = args.fs
+                if args.audio_format.endswith("ark"):
+                    if "flac" in args.audio_format:
+                        suf = "flac"
+                    elif "wav" in args.audio_format:
+                        suf = "wav"
+                    else:
+                        raise RuntimeError("wav.ark or flac")
+
+                    # NOTE(kamo): Using extended ark format style here.
+                    # This format is incompatible with Kaldi
+                    kaldiio.save_ark(
+                        fark,
+                        {uttid: (wave, rate)},
+                        scp=fscp,
+                        append=True,
+                        write_function=f"soundfile_{suf}",
+                    )
+
+                else:
+                    writer[uttid] = rate, wave
+                fnum_samples.write(f"{uttid} {len(wave)}\n")
+    else:
+        if args.audio_format.endswith("ark"):
+            fark = open(Path(args.outdir) / f"data_{args.name}.ark", "wb")
+        else:
+            wavdir = Path(args.outdir) / f"data_{args.name}"
+            wavdir.mkdir(parents=True, exist_ok=True)
+
+        with Path(args.scp).open("r") as fscp, out_wavscp.open(
+            "w"
+        ) as fout, out_num_samples.open("w") as fnum_samples:
+            for line in tqdm(fscp):
+                uttid, wavpath = line.strip().split(None, 1)
+
+                if wavpath.endswith("|"):
+                    # Streaming input e.g. cat a.wav |
+                    with kaldiio.open_like_kaldi(wavpath, "rb") as f:
+                        with BytesIO(f.read()) as g:
+                            wave, rate = soundfile.read(g, dtype=np.int16)
+                            if wave.ndim == 2 and utt2ref_channels is not None:
+                                wave = wave[:, utt2ref_channels(uttid)]
+
+                        if args.fs is not None and args.fs != rate:
+                            # FIXME(kamo): To use sox?
+                            wave = resampy.resample(
+                                wave.astype(np.float64), rate, args.fs, axis=0
+                            )
+                            wave = wave.astype(np.int16)
+                            rate = args.fs
+
+                        if args.audio_format.endswith("ark"):
+                            if "flac" in args.audio_format:
+                                suf = "flac"
+                            elif "wav" in args.audio_format:
+                                suf = "wav"
+                            else:
+                                raise RuntimeError("wav.ark or flac")
+
+                            # NOTE(kamo): Using extended ark format style here.
+                            # This format is incompatible with Kaldi
+                            kaldiio.save_ark(
+                                fark,
+                                {uttid: (wave, rate)},
+                                scp=fout,
+                                append=True,
+                                write_function=f"soundfile_{suf}",
+                            )
+                        else:
+                            owavpath = str(wavdir / f"{uttid}.{args.audio_format}")
+                            soundfile.write(owavpath, wave, rate)
+                            fout.write(f"{uttid} {owavpath}\n")
+                else:
+                    wave, rate = soundfile.read(wavpath, dtype=np.int16)
+                    if wave.ndim == 2 and utt2ref_channels is not None:
+                        wave = wave[:, utt2ref_channels(uttid)]
+                        save_asis = False
+
+                    elif args.audio_format.endswith("ark"):
+                        save_asis = False
+
+                    elif Path(wavpath).suffix == "." + args.audio_format and (
+                        args.fs is None or args.fs == rate
+                    ):
+                        save_asis = True
+
+                    else:
+                        save_asis = False
+
+                    if save_asis:
+                        # Neither --segments nor --fs are specified and
+                        # the line doesn't end with "|",
+                        # i.e. not using unix-pipe,
+                        # only in this case,
+                        # just using the original file as is.
+                        fout.write(f"{uttid} {wavpath}\n")
+                    else:
+                        if args.fs is not None and args.fs != rate:
+                            # FIXME(kamo): To use sox?
+                            wave = resampy.resample(
+                                wave.astype(np.float64), rate, args.fs, axis=0
+                            )
+                            wave = wave.astype(np.int16)
+                            rate = args.fs
+
+                        if args.audio_format.endswith("ark"):
+                            if "flac" in args.audio_format:
+                                suf = "flac"
+                            elif "wav" in args.audio_format:
+                                suf = "wav"
+                            else:
+                                raise RuntimeError("wav.ark or flac")
+
+                            # NOTE(kamo): Using extended ark format style here.
+                            # This format is not supported in Kaldi.
+                            kaldiio.save_ark(
+                                fark,
+                                {uttid: (wave, rate)},
+                                scp=fout,
+                                append=True,
+                                write_function=f"soundfile_{suf}",
+                            )
+                        else:
+                            owavpath = str(wavdir / f"{uttid}.{args.audio_format}")
+                            soundfile.write(owavpath, wave, rate)
+                            fout.write(f"{uttid} {owavpath}\n")
+                fnum_samples.write(f"{uttid} {len(wave)}\n")
+
+
+if __name__ == "__main__":
+    main()

+ 142 - 0
egs/alimeeting/sa-asr/local/format_wav_scp.sh

@@ -0,0 +1,142 @@
+#!/usr/bin/env bash
+set -euo pipefail
+SECONDS=0
+log() {
+    local fname=${BASH_SOURCE[1]##*/}
+    echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+help_message=$(cat << EOF
+Usage: $0 <in-wav.scp> <out-datadir> [<logdir> [<outdir>]]
+e.g.
+$0 data/test/wav.scp data/test_format/
+
+Format 'wav.scp': In short words,
+changing "kaldi-datadir" to "modified-kaldi-datadir"
+
+The 'wav.scp' format in kaldi is very flexible,
+e.g. It can use unix-pipe as describing that wav file,
+but it sometime looks confusing and make scripts more complex.
+This tools creates actual wav files from 'wav.scp'
+and also segments wav files using 'segments'.
+
+Options
+  --fs <fs>
+  --segments <segments>
+  --nj <nj>
+  --cmd <cmd>
+EOF
+)
+
+out_filename=wav.scp
+cmd=utils/run.pl
+nj=30
+fs=none
+segments=
+
+ref_channels=
+utt2ref_channels=
+
+audio_format=wav
+write_utt2num_samples=true
+
+log "$0 $*"
+. utils/parse_options.sh
+
+if [ $# -ne 2 ] && [ $# -ne 3 ] && [ $# -ne 4 ]; then
+    log "${help_message}"
+    log "Error: invalid command line arguments"
+    exit 1
+fi
+
+. ./path.sh  # Setup the environment
+
+scp=$1
+if [ ! -f "${scp}" ]; then
+    log "${help_message}"
+    echo "$0: Error: No such file: ${scp}"
+    exit 1
+fi
+dir=$2
+
+
+if [ $# -eq 2 ]; then
+    logdir=${dir}/logs
+    outdir=${dir}/data
+
+elif [ $# -eq 3 ]; then
+    logdir=$3
+    outdir=${dir}/data
+
+elif [ $# -eq 4 ]; then
+    logdir=$3
+    outdir=$4
+fi
+
+
+mkdir -p "${logdir}" "${dir}"
+
+rm -f "${dir}/${out_filename}"
+
+
+opts=
+if [ -n "${utt2ref_channels}" ]; then
+    opts="--utt2ref-channels ${utt2ref_channels} "
+elif [ -n "${ref_channels}" ]; then
+    opts="--ref-channels ${ref_channels} "
+fi
+
+
+if [ -n "${segments}" ]; then
+    log "[info]: using ${segments}"
+    nutt=$(<${segments} wc -l)
+    nj=$((nj<nutt?nj:nutt))
+
+    split_segments=""
+    for n in $(seq ${nj}); do
+        split_segments="${split_segments} ${logdir}/segments.${n}"
+    done
+
+    utils/split_scp.pl "${segments}" ${split_segments}
+
+    ${cmd} "JOB=1:${nj}" "${logdir}/format_wav_scp.JOB.log" \
+        local/format_wav_scp.py \
+            ${opts} \
+            --fs ${fs} \
+            --audio-format "${audio_format}" \
+            "--segment=${logdir}/segments.JOB" \
+            "${scp}" "${outdir}/format.JOB"
+
+else
+    log "[info]: without segments"
+    nutt=$(<${scp} wc -l)
+    nj=$((nj<nutt?nj:nutt))
+
+    split_scps=""
+    for n in $(seq ${nj}); do
+        split_scps="${split_scps} ${logdir}/wav.${n}.scp"
+    done
+
+    utils/split_scp.pl "${scp}" ${split_scps}
+    ${cmd} "JOB=1:${nj}" "${logdir}/format_wav_scp.JOB.log" \
+        local/format_wav_scp.py \
+        ${opts} \
+        --fs "${fs}" \
+        --audio-format "${audio_format}" \
+        "${logdir}/wav.JOB.scp" "${outdir}/format.JOB"
+fi
+
+# Workaround for the NFS problem
+ls ${outdir}/format.* > /dev/null
+
+# concatenate the .scp files together.
+for n in $(seq ${nj}); do
+    cat "${outdir}/format.${n}/wav.scp" || exit 1;
+done > "${dir}/${out_filename}" || exit 1
+
+if "${write_utt2num_samples}"; then
+    for n in $(seq ${nj}); do
+        cat "${outdir}/format.${n}/utt2num_samples" || exit 1;
+    done > "${dir}/utt2num_samples"  || exit 1
+fi
+
+log "Successfully finished. [elapsed=${SECONDS}s]"

+ 116 - 0
egs/alimeeting/sa-asr/local/perturb_data_dir_speed.sh

@@ -0,0 +1,116 @@
+#!/usr/bin/env bash
+
+# 2020 @kamo-naoyuki
+# This file was copied from Kaldi and 
+# I deleted parts related to wav duration 
+# because we shouldn't use kaldi's command here
+# and we don't need the files actually.
+
+# Copyright 2013  Johns Hopkins University (author: Daniel Povey)
+#           2014  Tom Ko
+#           2018  Emotech LTD (author: Pawel Swietojanski)
+# Apache 2.0
+
+# This script operates on a directory, such as in data/train/,
+# that contains some subset of the following files:
+#  wav.scp
+#  spk2utt
+#  utt2spk
+#  text
+#
+# It generates the files which are used for perturbing the speed of the original data.
+
+export LC_ALL=C
+set -euo pipefail
+
+if [[ $# != 3 ]]; then
+    echo "Usage: perturb_data_dir_speed.sh <warping-factor> <srcdir> <destdir>"
+    echo "e.g.:"
+    echo " $0 0.9 data/train_si284 data/train_si284p"
+    exit 1
+fi
+
+factor=$1
+srcdir=$2
+destdir=$3
+label="sp"
+spk_prefix="${label}${factor}-"
+utt_prefix="${label}${factor}-"
+
+#check is sox on the path
+
+! command -v sox &>/dev/null && echo "sox: command not found" && exit 1;
+
+if [[ ! -f ${srcdir}/utt2spk ]]; then
+  echo "$0: no such file ${srcdir}/utt2spk"
+  exit 1;
+fi
+
+if [[ ${destdir} == "${srcdir}" ]]; then
+  echo "$0: this script requires <srcdir> and <destdir> to be different."
+  exit 1
+fi
+
+mkdir -p "${destdir}"
+
+<"${srcdir}"/utt2spk awk -v p="${utt_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/utt_map"
+<"${srcdir}"/spk2utt awk -v p="${spk_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/spk_map"
+<"${srcdir}"/wav.scp awk -v p="${spk_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/reco_map"
+if [[ ! -f ${srcdir}/utt2uniq ]]; then
+    <"${srcdir}/utt2spk" awk -v p="${utt_prefix}" '{printf("%s%s %s\n", p, $1, $1);}' > "${destdir}/utt2uniq"
+else
+    <"${srcdir}/utt2uniq" awk -v p="${utt_prefix}" '{printf("%s%s %s\n", p, $1, $2);}' > "${destdir}/utt2uniq"
+fi
+
+
+<"${srcdir}"/utt2spk utils/apply_map.pl -f 1 "${destdir}"/utt_map | \
+  utils/apply_map.pl -f 2 "${destdir}"/spk_map >"${destdir}"/utt2spk
+
+utils/utt2spk_to_spk2utt.pl <"${destdir}"/utt2spk >"${destdir}"/spk2utt
+
+if [[ -f ${srcdir}/segments ]]; then
+
+  utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/segments | \
+      utils/apply_map.pl -f 2 "${destdir}"/reco_map | \
+          awk -v factor="${factor}" \
+            '{s=$3/factor; e=$4/factor; if (e > s + 0.01) { printf("%s %s %.2f %.2f\n", $1, $2, $3/factor, $4/factor);} }' \
+            >"${destdir}"/segments
+
+  utils/apply_map.pl -f 1 "${destdir}"/reco_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \
+      # Handle three cases of rxfilenames appropriately; "input piped command", "file offset" and "filename"
+      awk -v factor="${factor}" \
+          '{wid=$1; $1=""; if ($NF=="|") {print wid $_ " sox -t wav - -t wav - speed " factor " |"}
+            else if (match($0, /:[0-9]+$/)) {print wid " wav-copy" $_ " - | sox -t wav - -t wav - speed " factor " |" }
+            else  {print wid " sox" $_ " -t wav - speed " factor " |"}}' \
+             > "${destdir}"/wav.scp
+  if [[ -f ${srcdir}/reco2file_and_channel ]]; then
+      utils/apply_map.pl -f 1 "${destdir}"/reco_map \
+       <"${srcdir}"/reco2file_and_channel >"${destdir}"/reco2file_and_channel
+  fi
+
+else # no segments->wav indexed by utterance.
+    if [[ -f ${srcdir}/wav.scp ]]; then
+        utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \
+         # Handle three cases of rxfilenames appropriately; "input piped command", "file offset" and "filename"
+         awk -v factor="${factor}" \
+           '{wid=$1; $1=""; if ($NF=="|") {print wid $_ " sox -t wav - -t wav - speed " factor " |"}
+             else if (match($0, /:[0-9]+$/)) {print wid " wav-copy" $_ " - | sox -t wav - -t wav - speed " factor " |" }
+             else {print wid " sox" $_ " -t wav - speed " factor " |"}}' \
+                 > "${destdir}"/wav.scp
+    fi
+fi
+
+if [[ -f ${srcdir}/text ]]; then
+    utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/text >"${destdir}"/text
+fi
+if [[ -f ${srcdir}/spk2gender ]]; then
+    utils/apply_map.pl -f 1 "${destdir}"/spk_map <"${srcdir}"/spk2gender >"${destdir}"/spk2gender
+fi
+if [[ -f ${srcdir}/utt2lang ]]; then
+    utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/utt2lang >"${destdir}"/utt2lang
+fi
+
+rm "${destdir}"/spk_map "${destdir}"/utt_map "${destdir}"/reco_map 2>/dev/null
+echo "$0: generated speed-perturbed version of data in ${srcdir}, in ${destdir}"
+
+utils/validate_data_dir.sh --no-feats --no-text "${destdir}"

+ 0 - 0
egs/alimeeting/sa-asr/utils/spk2utt_to_utt2spk.pl → egs/alimeeting/sa-asr/local/spk2utt_to_utt2spk.pl


+ 0 - 0
egs/alimeeting/sa-asr/utils/utt2spk_to_spk2utt.pl → egs/alimeeting/sa-asr/local/utt2spk_to_spk2utt.pl


+ 2 - 2
egs/alimeeting/sa-asr/utils/validate_data_dir.sh → egs/alimeeting/sa-asr/local/validate_data_dir.sh

@@ -113,7 +113,7 @@ fi
 check_sorted_and_uniq $data/spk2utt
 
 ! cmp -s <(cat $data/utt2spk | awk '{print $1, $2;}') \
-     <(utils/spk2utt_to_utt2spk.pl $data/spk2utt)  && \
+     <(local/spk2utt_to_utt2spk.pl $data/spk2utt)  && \
    echo "$0: spk2utt and utt2spk do not seem to match" && exit 1;
 
 cat $data/utt2spk | awk '{print $1;}' > $tmpdir/utts
@@ -135,7 +135,7 @@ if ! $no_text; then
     echo "$0: text contains $n_non_print lines with non-printable characters" &&\
     exit 1;
   fi
-  utils/validate_text.pl $data/text || exit 1;
+  local/validate_text.pl $data/text || exit 1;
   check_sorted_and_uniq $data/text
   text_len=`cat $data/text | wc -l`
   illegal_sym_list="<s> </s> #0"

+ 0 - 0
egs/alimeeting/sa-asr/utils/validate_text.pl → egs/alimeeting/sa-asr/local/validate_text.pl


+ 1 - 2
egs/alimeeting/sa-asr/path.sh

@@ -2,5 +2,4 @@ export FUNASR_DIR=$PWD/../../..
 
 # NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
 export PYTHONIOENCODING=UTF-8
-export PATH=$FUNASR_DIR/funasr/bin:$PATH
-export PATH=$PWD/utils/:$PATH
+export PATH=$FUNASR_DIR/funasr/bin:$PATH

+ 1 - 0
egs/alimeeting/sa-asr/utils

@@ -0,0 +1 @@
+../../aishell/transformer/utils

+ 0 - 87
egs/alimeeting/sa-asr/utils/filter_scp.pl

@@ -1,87 +0,0 @@
-#!/usr/bin/env perl
-# Copyright 2010-2012 Microsoft Corporation
-#                     Johns Hopkins University (author: Daniel Povey)
-
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#  http://www.apache.org/licenses/LICENSE-2.0
-#
-# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
-# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
-# MERCHANTABLITY OR NON-INFRINGEMENT.
-# See the Apache 2 License for the specific language governing permissions and
-# limitations under the License.
-
-
-# This script takes a list of utterance-ids or any file whose first field
-# of each line is an utterance-id, and filters an scp
-# file (or any file whose "n-th" field is an utterance id), printing
-# out only those lines whose "n-th" field is in id_list. The index of
-# the "n-th" field is 1, by default, but can be changed by using
-# the -f <n> switch
-
-$exclude = 0;
-$field = 1;
-$shifted = 0;
-
-do {
-  $shifted=0;
-  if ($ARGV[0] eq "--exclude") {
-    $exclude = 1;
-    shift @ARGV;
-    $shifted=1;
-  }
-  if ($ARGV[0] eq "-f") {
-    $field = $ARGV[1];
-    shift @ARGV; shift @ARGV;
-    $shifted=1
-  }
-} while ($shifted);
-
-if(@ARGV < 1 || @ARGV > 2) {
-  die "Usage: filter_scp.pl [--exclude] [-f <field-to-filter-on>] id_list [in.scp] > out.scp \n" .
-      "Prints only the input lines whose f'th field (default: first) is in 'id_list'.\n" .
-      "Note: only the first field of each line in id_list matters.  With --exclude, prints\n" .
-      "only the lines that were *not* in id_list.\n" .
-      "Caution: previously, the -f option was interpreted as a zero-based field index.\n" .
-      "If your older scripts (written before Oct 2014) stopped working and you used the\n" .
-      "-f option, add 1 to the argument.\n" .
-      "See also: utils/filter_scp.pl .\n";
-}
-
-
-$idlist = shift @ARGV;
-open(F, "<$idlist") || die "Could not open id-list file $idlist";
-while(<F>) {
-  @A = split;
-  @A>=1 || die "Invalid id-list file line $_";
-  $seen{$A[0]} = 1;
-}
-
-if ($field == 1) { # Treat this as special case, since it is common.
-  while(<>) {
-    $_ =~ m/\s*(\S+)\s*/ || die "Bad line $_, could not get first field.";
-    # $1 is what we filter on.
-    if ((!$exclude && $seen{$1}) || ($exclude && !defined $seen{$1})) {
-      print $_;
-    }
-  }
-} else {
-  while(<>) {
-    @A = split;
-    @A > 0 || die "Invalid scp file line $_";
-    @A >= $field || die "Invalid scp file line $_";
-    if ((!$exclude && $seen{$A[$field-1]}) || ($exclude && !defined $seen{$A[$field-1]})) {
-      print $_;
-    }
-  }
-}
-
-# tests:
-# the following should print "foo 1"
-# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl <(echo foo)
-# the following should print "bar 2".
-# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl -f 2 <(echo 2)

+ 0 - 97
egs/alimeeting/sa-asr/utils/parse_options.sh

@@ -1,97 +0,0 @@
-#!/usr/bin/env bash
-
-# Copyright 2012  Johns Hopkins University (Author: Daniel Povey);
-#                 Arnab Ghoshal, Karel Vesely
-
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#  http://www.apache.org/licenses/LICENSE-2.0
-#
-# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
-# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
-# MERCHANTABLITY OR NON-INFRINGEMENT.
-# See the Apache 2 License for the specific language governing permissions and
-# limitations under the License.
-
-
-# Parse command-line options.
-# To be sourced by another script (as in ". parse_options.sh").
-# Option format is: --option-name arg
-# and shell variable "option_name" gets set to value "arg."
-# The exception is --help, which takes no arguments, but prints the
-# $help_message variable (if defined).
-
-
-###
-### The --config file options have lower priority to command line
-### options, so we need to import them first...
-###
-
-# Now import all the configs specified by command-line, in left-to-right order
-for ((argpos=1; argpos<$#; argpos++)); do
-  if [ "${!argpos}" == "--config" ]; then
-    argpos_plus1=$((argpos+1))
-    config=${!argpos_plus1}
-    [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
-    . $config  # source the config file.
-  fi
-done
-
-
-###
-### Now we process the command line options
-###
-while true; do
-  [ -z "${1:-}" ] && break;  # break if there are no arguments
-  case "$1" in
-    # If the enclosing script is called with --help option, print the help
-    # message and exit.  Scripts should put help messages in $help_message
-    --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
-      else printf "$help_message\n" 1>&2 ; fi;
-      exit 0 ;;
-    --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
-      exit 1 ;;
-    # If the first command-line argument begins with "--" (e.g. --foo-bar),
-    # then work out the variable name as $name, which will equal "foo_bar".
-    --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
-      # Next we test whether the variable in question is undefned-- if so it's
-      # an invalid option and we die.  Note: $0 evaluates to the name of the
-      # enclosing script.
-      # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
-      # is undefined.  We then have to wrap this test inside "eval" because
-      # foo_bar is itself inside a variable ($name).
-      eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
-
-      oldval="`eval echo \\$$name`";
-      # Work out whether we seem to be expecting a Boolean argument.
-      if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
-        was_bool=true;
-      else
-        was_bool=false;
-      fi
-
-      # Set the variable to the right value-- the escaped quotes make it work if
-      # the option had spaces, like --cmd "queue.pl -sync y"
-      eval $name=\"$2\";
-
-      # Check that Boolean-valued arguments are really Boolean.
-      if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
-        echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
-        exit 1;
-      fi
-      shift 2;
-      ;;
-  *) break;
-  esac
-done
-
-
-# Check for an empty argument to the --cmd option, which can easily occur as a
-# result of scripting errors.
-[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;
-
-
-true; # so this script returns exit code 0.

+ 0 - 246
egs/alimeeting/sa-asr/utils/split_scp.pl

@@ -1,246 +0,0 @@
-#!/usr/bin/env perl
-
-# Copyright 2010-2011 Microsoft Corporation
-
-# See ../../COPYING for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#  http://www.apache.org/licenses/LICENSE-2.0
-#
-# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
-# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
-# MERCHANTABLITY OR NON-INFRINGEMENT.
-# See the Apache 2 License for the specific language governing permissions and
-# limitations under the License.
-
-
-# This program splits up any kind of .scp or archive-type file.
-# If there is no utt2spk option it will work on any text  file and
-# will split it up with an approximately equal number of lines in
-# each but.
-# With the --utt2spk option it will work on anything that has the
-# utterance-id as the first entry on each line; the utt2spk file is
-# of the form "utterance speaker" (on each line).
-# It splits it into equal size chunks as far as it can.  If you use the utt2spk
-# option it will make sure these chunks coincide with speaker boundaries.  In
-# this case, if there are more chunks than speakers (and in some other
-# circumstances), some of the resulting chunks will be empty and it will print
-# an error message and exit with nonzero status.
-# You will normally call this like:
-# split_scp.pl scp scp.1 scp.2 scp.3 ...
-# or
-# split_scp.pl --utt2spk=utt2spk scp scp.1 scp.2 scp.3 ...
-# Note that you can use this script to split the utt2spk file itself,
-# e.g. split_scp.pl --utt2spk=utt2spk utt2spk utt2spk.1 utt2spk.2 ...
-
-# You can also call the scripts like:
-# split_scp.pl -j 3 0 scp scp.0
-# [note: with this option, it assumes zero-based indexing of the split parts,
-# i.e. the second number must be 0 <= n < num-jobs.]
-
-use warnings;
-
-$num_jobs = 0;
-$job_id = 0;
-$utt2spk_file = "";
-$one_based = 0;
-
-for ($x = 1; $x <= 3 && @ARGV > 0; $x++) {
-    if ($ARGV[0] eq "-j") {
-        shift @ARGV;
-        $num_jobs = shift @ARGV;
-        $job_id = shift @ARGV;
-    }
-    if ($ARGV[0] =~ /--utt2spk=(.+)/) {
-        $utt2spk_file=$1;
-        shift;
-    }
-    if ($ARGV[0] eq '--one-based') {
-        $one_based = 1;
-        shift @ARGV;
-    }
-}
-
-if ($num_jobs != 0 && ($num_jobs < 0 || $job_id - $one_based < 0 ||
-                       $job_id - $one_based >= $num_jobs)) {
-  die "$0: Invalid job number/index values for '-j $num_jobs $job_id" .
-      ($one_based ? " --one-based" : "") . "'\n"
-}
-
-$one_based
-    and $job_id--;
-
-if(($num_jobs == 0 && @ARGV < 2) || ($num_jobs > 0 && (@ARGV < 1 || @ARGV > 2))) {
-    die
-"Usage: split_scp.pl [--utt2spk=<utt2spk_file>] in.scp out1.scp out2.scp ...
-   or: split_scp.pl -j num-jobs job-id [--one-based] [--utt2spk=<utt2spk_file>] in.scp [out.scp]
- ... where 0 <= job-id < num-jobs, or 1 <= job-id <- num-jobs if --one-based.\n";
-}
-
-$error = 0;
-$inscp = shift @ARGV;
-if ($num_jobs == 0) { # without -j option
-    @OUTPUTS = @ARGV;
-} else {
-    for ($j = 0; $j < $num_jobs; $j++) {
-        if ($j == $job_id) {
-            if (@ARGV > 0) { push @OUTPUTS, $ARGV[0]; }
-            else { push @OUTPUTS, "-"; }
-        } else {
-            push @OUTPUTS, "/dev/null";
-        }
-    }
-}
-
-if ($utt2spk_file ne "") {  # We have the --utt2spk option...
-    open($u_fh, '<', $utt2spk_file) || die "$0: Error opening utt2spk file $utt2spk_file: $!\n";
-    while(<$u_fh>) {
-        @A = split;
-        @A == 2 || die "$0: Bad line $_ in utt2spk file $utt2spk_file\n";
-        ($u,$s) = @A;
-        $utt2spk{$u} = $s;
-    }
-    close $u_fh;
-    open($i_fh, '<', $inscp) || die "$0: Error opening input scp file $inscp: $!\n";
-    @spkrs = ();
-    while(<$i_fh>) {
-        @A = split;
-        if(@A == 0) { die "$0: Empty or space-only line in scp file $inscp\n"; }
-        $u = $A[0];
-        $s = $utt2spk{$u};
-        defined $s || die "$0: No utterance $u in utt2spk file $utt2spk_file\n";
-        if(!defined $spk_count{$s}) {
-            push @spkrs, $s;
-            $spk_count{$s} = 0;
-            $spk_data{$s} = [];  # ref to new empty array.
-        }
-        $spk_count{$s}++;
-        push @{$spk_data{$s}}, $_;
-    }
-    # Now split as equally as possible ..
-    # First allocate spks to files by allocating an approximately
-    # equal number of speakers.
-    $numspks = @spkrs;  # number of speakers.
-    $numscps = @OUTPUTS; # number of output files.
-    if ($numspks < $numscps) {
-      die "$0: Refusing to split data because number of speakers $numspks " .
-          "is less than the number of output .scp files $numscps\n";
-    }
-    for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
-        $scparray[$scpidx] = []; # [] is array reference.
-    }
-    for ($spkidx = 0; $spkidx < $numspks; $spkidx++) {
-        $scpidx = int(($spkidx*$numscps) / $numspks);
-        $spk = $spkrs[$spkidx];
-        push @{$scparray[$scpidx]}, $spk;
-        $scpcount[$scpidx] += $spk_count{$spk};
-    }
-
-    # Now will try to reassign beginning + ending speakers
-    # to different scp's and see if it gets more balanced.
-    # Suppose objf we're minimizing is sum_i (num utts in scp[i] - average)^2.
-    # We can show that if considering changing just 2 scp's, we minimize
-    # this by minimizing the squared difference in sizes.  This is
-    # equivalent to minimizing the absolute difference in sizes.  This
-    # shows this method is bound to converge.
-
-    $changed = 1;
-    while($changed) {
-        $changed = 0;
-        for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
-            # First try to reassign ending spk of this scp.
-            if($scpidx < $numscps-1) {
-                $sz = @{$scparray[$scpidx]};
-                if($sz > 0) {
-                    $spk = $scparray[$scpidx]->[$sz-1];
-                    $count = $spk_count{$spk};
-                    $nutt1 = $scpcount[$scpidx];
-                    $nutt2 = $scpcount[$scpidx+1];
-                    if( abs( ($nutt2+$count) - ($nutt1-$count))
-                        < abs($nutt2 - $nutt1))  { # Would decrease
-                        # size-diff by reassigning spk...
-                        $scpcount[$scpidx+1] += $count;
-                        $scpcount[$scpidx] -= $count;
-                        pop @{$scparray[$scpidx]};
-                        unshift @{$scparray[$scpidx+1]}, $spk;
-                        $changed = 1;
-                    }
-                }
-            }
-            if($scpidx > 0 && @{$scparray[$scpidx]} > 0) {
-                $spk = $scparray[$scpidx]->[0];
-                $count = $spk_count{$spk};
-                $nutt1 = $scpcount[$scpidx-1];
-                $nutt2 = $scpcount[$scpidx];
-                if( abs( ($nutt2-$count) - ($nutt1+$count))
-                    < abs($nutt2 - $nutt1))  { # Would decrease
-                    # size-diff by reassigning spk...
-                    $scpcount[$scpidx-1] += $count;
-                    $scpcount[$scpidx] -= $count;
-                    shift @{$scparray[$scpidx]};
-                    push @{$scparray[$scpidx-1]}, $spk;
-                    $changed = 1;
-                }
-            }
-        }
-    }
-    # Now print out the files...
-    for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
-        $scpfile = $OUTPUTS[$scpidx];
-        ($scpfile ne '-' ? open($f_fh, '>', $scpfile)
-                         : open($f_fh, '>&', \*STDOUT)) ||
-            die "$0: Could not open scp file $scpfile for writing: $!\n";
-        $count = 0;
-        if(@{$scparray[$scpidx]} == 0) {
-            print STDERR "$0: eError: split_scp.pl producing empty .scp file " .
-                         "$scpfile (too many splits and too few speakers?)\n";
-            $error = 1;
-        } else {
-            foreach $spk ( @{$scparray[$scpidx]} ) {
-                print $f_fh @{$spk_data{$spk}};
-                $count += $spk_count{$spk};
-            }
-            $count == $scpcount[$scpidx] || die "Count mismatch [code error]";
-        }
-        close($f_fh);
-    }
-} else {
-   # This block is the "normal" case where there is no --utt2spk
-   # option and we just break into equal size chunks.
-
-    open($i_fh, '<', $inscp) || die "$0: Error opening input scp file $inscp: $!\n";
-
-    $numscps = @OUTPUTS;  # size of array.
-    @F = ();
-    while(<$i_fh>) {
-        push @F, $_;
-    }
-    $numlines = @F;
-    if($numlines == 0) {
-        print STDERR "$0: error: empty input scp file $inscp\n";
-        $error = 1;
-    }
-    $linesperscp = int( $numlines / $numscps); # the "whole part"..
-    $linesperscp >= 1 || die "$0: You are splitting into too many pieces! [reduce \$nj ($numscps) to be smaller than the number of lines ($numlines) in $inscp]\n";
-    $remainder = $numlines - ($linesperscp * $numscps);
-    ($remainder >= 0 && $remainder < $numlines) || die "bad remainder $remainder";
-    # [just doing int() rounds down].
-    $n = 0;
-    for($scpidx = 0; $scpidx < @OUTPUTS; $scpidx++) {
-        $scpfile = $OUTPUTS[$scpidx];
-        ($scpfile ne '-' ? open($o_fh, '>', $scpfile)
-                         : open($o_fh, '>&', \*STDOUT)) ||
-            die "$0: Could not open scp file $scpfile for writing: $!\n";
-        for($k = 0; $k < $linesperscp + ($scpidx < $remainder ? 1 : 0); $k++) {
-            print $o_fh $F[$n++];
-        }
-        close($o_fh) || die "$0: Eror closing scp file $scpfile: $!\n";
-    }
-    $n == $numlines || die "$n != $numlines [code error]";
-}
-
-exit ($error);

+ 24 - 4
funasr/bin/asr_inference.py

@@ -40,6 +40,8 @@ from funasr.utils.types import str2bool
 from funasr.utils.types import str2triple_str
 from funasr.utils.types import str_or_none
 from funasr.utils import asr_utils, wav_utils, postprocess_utils
+from funasr.models.frontend.wav_frontend import WavFrontend
+from funasr.tasks.asr import frontend_choices
 
 
 header_colors = '\033[95m'
@@ -90,6 +92,12 @@ class Speech2Text:
             asr_train_config, asr_model_file, cmvn_file, device
         )
         frontend = None
+        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
+            if asr_train_args.frontend=='wav_frontend':
+                frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf).eval()
+            else:
+                frontend_class=frontend_choices.get_class(asr_train_args.frontend)
+                frontend = frontend_class(**asr_train_args.frontend_conf).eval()
 
         logging.info("asr_model: {}".format(asr_model))
         logging.info("asr_train_args: {}".format(asr_train_args))
@@ -197,12 +205,21 @@ class Speech2Text:
 
         """
         assert check_argument_types()
-
+        
         # Input as audio signal
         if isinstance(speech, np.ndarray):
             speech = torch.tensor(speech)
 
-        batch = {"speech": speech, "speech_lengths": speech_lengths}
+        if self.frontend is not None:
+            feats, feats_len = self.frontend.forward(speech, speech_lengths)
+            feats = to_device(feats, device=self.device)
+            feats_len = feats_len.int()
+            self.asr_model.frontend = None
+        else:
+            feats = speech
+            feats_len = speech_lengths
+        lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
+        batch = {"speech": feats, "speech_lengths": feats_len}
 
         # a. To device
         batch = to_device(batch, device=self.device)
@@ -275,6 +292,7 @@ def inference(
         ngram_weight: float = 0.9,
         nbest: int = 1,
         num_workers: int = 1,
+        mc: bool = False,
         **kwargs,
 ):
     inference_pipeline = inference_modelscope(
@@ -305,6 +323,7 @@ def inference(
         ngram_weight=ngram_weight,
         nbest=nbest,
         num_workers=num_workers,
+        mc=mc,
         **kwargs,
     )
     return inference_pipeline(data_path_and_name_and_type, raw_inputs)
@@ -337,6 +356,7 @@ def inference_modelscope(
     ngram_weight: float = 0.9,
     nbest: int = 1,
     num_workers: int = 1,
+    mc: bool = False,
     param_dict: dict = None,
     **kwargs,
 ):
@@ -406,7 +426,7 @@ def inference_modelscope(
             data_path_and_name_and_type,
             dtype=dtype,
             fs=fs,
-            mc=True,
+            mc=mc,
             batch_size=batch_size,
             key_file=key_file,
             num_workers=num_workers,
@@ -415,7 +435,7 @@ def inference_modelscope(
             allow_variable_data_keys=allow_variable_data_keys,
             inference=True,
         )
-        
+
         finish_count = 0
         file_count = 1
         # 7 .Start for-loop

+ 7 - 1
funasr/bin/asr_inference_launch.py

@@ -71,7 +71,13 @@ def get_parser():
     )
     group.add_argument("--key_file", type=str_or_none)
     group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
-
+    group.add_argument(
+        "--mc",
+        type=str2bool,
+        default=False,
+        help="MultiChannel input",
+    )
+
     group = parser.add_argument_group("The model configuration related")
     group.add_argument(
         "--vad_infer_config",

+ 0 - 8
funasr/bin/asr_train.py

@@ -2,14 +2,6 @@
 
 import os
 
-import logging
-
-logging.basicConfig(
-    level='INFO',
-    format=f"[{os.uname()[1].split('.')[0]}]"
-           f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
-)
-
 from funasr.tasks.asr import ASRTask
 
 

+ 22 - 2
funasr/bin/sa_asr_inference.py

@@ -35,6 +35,8 @@ from funasr.utils.types import str2bool
 from funasr.utils.types import str2triple_str
 from funasr.utils.types import str_or_none
 from funasr.utils import asr_utils, wav_utils, postprocess_utils
+from funasr.models.frontend.wav_frontend import WavFrontend
+from funasr.tasks.asr import frontend_choices
 
 
 header_colors = '\033[95m'
@@ -85,6 +87,12 @@ class Speech2Text:
             asr_train_config, asr_model_file, cmvn_file, device
         )
         frontend = None
+        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
+            if asr_train_args.frontend=='wav_frontend':
+                frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf).eval()
+            else:
+                frontend_class=frontend_choices.get_class(asr_train_args.frontend)
+                frontend = frontend_class(**asr_train_args.frontend_conf).eval()
 
         logging.info("asr_model: {}".format(asr_model))
         logging.info("asr_train_args: {}".format(asr_train_args))
@@ -201,7 +209,16 @@ class Speech2Text:
         if isinstance(profile, np.ndarray):
             profile = torch.tensor(profile)
 
-        batch = {"speech": speech, "speech_lengths": speech_lengths}
+        if self.frontend is not None:
+            feats, feats_len = self.frontend.forward(speech, speech_lengths)
+            feats = to_device(feats, device=self.device)
+            feats_len = feats_len.int()
+            self.asr_model.frontend = None
+        else:
+            feats = speech
+            feats_len = speech_lengths
+        lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
+        batch = {"speech": feats, "speech_lengths": feats_len}
 
         # a. To device
         batch = to_device(batch, device=self.device)
@@ -308,6 +325,7 @@ def inference(
         ngram_weight: float = 0.9,
         nbest: int = 1,
         num_workers: int = 1,
+        mc: bool = False,
         **kwargs,
 ):
     inference_pipeline = inference_modelscope(
@@ -338,6 +356,7 @@ def inference(
         ngram_weight=ngram_weight,
         nbest=nbest,
         num_workers=num_workers,
+        mc=mc,
         **kwargs,
     )
     return inference_pipeline(data_path_and_name_and_type, raw_inputs)
@@ -370,6 +389,7 @@ def inference_modelscope(
     ngram_weight: float = 0.9,
     nbest: int = 1,
     num_workers: int = 1,
+    mc: bool = False,
     param_dict: dict = None,
     **kwargs,
 ):
@@ -437,7 +457,7 @@ def inference_modelscope(
             data_path_and_name_and_type,
             dtype=dtype,
             fs=fs,
-            mc=True,
+            mc=mc,
             batch_size=batch_size,
             key_file=key_file,
             num_workers=num_workers,

+ 0 - 8
funasr/bin/sa_asr_train.py

@@ -2,14 +2,6 @@
 
 import os
 
-import logging
-
-logging.basicConfig(
-    level='INFO',
-    format=f"[{os.uname()[1].split('.')[0]}]"
-           f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
-)
-
 from funasr.tasks.sa_asr import ASRTask
 
 

+ 46 - 0
funasr/losses/label_smoothing_loss.py

@@ -79,3 +79,49 @@ class SequenceBinaryCrossEntropy(nn.Module):
         loss = self.criterion(pred, label)
         denom = (~pad_mask).sum() if self.normalize_length else pred.shape[0]
         return loss.masked_fill(pad_mask, 0).sum() / denom
+
+
+class NllLoss(nn.Module):
+    """Negative log-likelihood loss over padded sequences.
+
+    Flattens (batch, seqlen, class) predictions, masks out positions whose
+    target equals ``padding_idx``, and averages the per-token NLL either over
+    the number of non-padding tokens or over the batch size.
+
+    :param int size: the number of class
+    :param int padding_idx: ignored class id
+    :param bool normalize_length: normalize loss by sequence length if True
+    :param torch.nn.Module criterion: loss function
+    """
+
+    def __init__(
+        self,
+        size,
+        padding_idx,
+        normalize_length=False,
+        # NOTE(review): the default criterion is instantiated once at function
+        # definition time and shared by every instance relying on the default;
+        # harmless for a stateless NLLLoss, but worth keeping in mind.
+        criterion=nn.NLLLoss(reduction='none'),
+    ):
+        """Construct an NllLoss object."""
+        super(NllLoss, self).__init__()
+        self.criterion = criterion
+        self.padding_idx = padding_idx
+        self.size = size
+        self.true_dist = None
+        self.normalize_length = normalize_length
+
+    def forward(self, x, target):
+        """Compute loss between x and target.
+
+        :param torch.Tensor x: prediction (batch, seqlen, class)
+        :param torch.Tensor target:
+            target signal masked with self.padding_id (batch, seqlen)
+        :return: scalar float value
+        :rtype torch.Tensor
+        """
+        assert x.size(2) == self.size
+        batch_size = x.size(0)
+        x = x.view(-1, self.size)
+        target = target.view(-1)
+        with torch.no_grad():
+            ignore = target == self.padding_idx  # (B,)
+            total = len(target) - ignore.sum().item()
+            target = target.masked_fill(ignore, 0)  # avoid -1 index
+        # Per-token loss; padded positions are zeroed before the reduction so
+        # they contribute nothing to the sum.
+        kl = self.criterion(x , target)
+        # Denominator: token count when normalize_length, else batch size.
+        denom = total if self.normalize_length else batch_size
+        return kl.masked_fill(ignore, 0).sum() / denom

+ 427 - 1
funasr/models/decoder/transformer_decoder.py

@@ -13,6 +13,7 @@ from typeguard import check_argument_types
 
 from funasr.models.decoder.abs_decoder import AbsDecoder
 from funasr.modules.attention import MultiHeadedAttention
+from funasr.modules.attention import CosineDistanceAttention
 from funasr.modules.dynamic_conv import DynamicConvolution
 from funasr.modules.dynamic_conv2d import DynamicConvolution2D
 from funasr.modules.embedding import PositionalEncoding
@@ -763,4 +764,429 @@ class DynamicConvolution2DTransformerDecoder(BaseTransformerDecoder):
                 normalize_before,
                 concat_after,
             ),
-        )
+        )
+
+class BaseSAAsrTransformerDecoder(AbsDecoder, BatchScorerInterface):
+    """Base decoder for speaker-attributed (SA) ASR.
+
+    Runs two cross-attending decoder branches over a shared token embedding:
+    a speaker branch (``decoder1`` + ``decoder2``) that maps each output
+    token to a speaker embedding and attends over enrolled ``profile``
+    vectors with cosine-distance attention, and an ASR branch (``decoder3``
+    + ``decoder4``) conditioned on the attended profile.  Subclasses are
+    expected to populate ``decoder1``..``decoder4``.
+    """
+    
+    def __init__(
+        self,
+        vocab_size: int,
+        encoder_output_size: int,
+        spker_embedding_dim: int = 256,
+        dropout_rate: float = 0.1,
+        positional_dropout_rate: float = 0.1,
+        input_layer: str = "embed",
+        use_asr_output_layer: bool = True,
+        use_spk_output_layer: bool = True,
+        pos_enc_class=PositionalEncoding,
+        normalize_before: bool = True,
+    ):
+        assert check_argument_types()
+        super().__init__()
+        # Decoder width is tied to the encoder output dimension.
+        attention_dim = encoder_output_size
+
+        if input_layer == "embed":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Embedding(vocab_size, attention_dim),
+                pos_enc_class(attention_dim, positional_dropout_rate),
+            )
+        elif input_layer == "linear":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Linear(vocab_size, attention_dim),
+                torch.nn.LayerNorm(attention_dim),
+                torch.nn.Dropout(dropout_rate),
+                torch.nn.ReLU(),
+                pos_enc_class(attention_dim, positional_dropout_rate),
+            )
+        else:
+            raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
+
+        self.normalize_before = normalize_before
+        if self.normalize_before:
+            self.after_norm = LayerNorm(attention_dim)
+        if use_asr_output_layer:
+            self.asr_output_layer = torch.nn.Linear(attention_dim, vocab_size)
+        else:
+            self.asr_output_layer = None
+
+        if use_spk_output_layer:
+            # Projects decoder states into the speaker-embedding space so they
+            # can be compared against the enrolled profiles.
+            self.spk_output_layer = torch.nn.Linear(attention_dim, spker_embedding_dim)
+        else:
+            self.spk_output_layer = None
+
+        self.cos_distance_att = CosineDistanceAttention()
+
+        # Filled in by subclasses (see SAAsrTransformerDecoder below).
+        self.decoder1 = None
+        self.decoder2 = None
+        self.decoder3 = None
+        self.decoder4 = None
+
+    def forward(
+        self,
+        asr_hs_pad: torch.Tensor,
+        spk_hs_pad: torch.Tensor,
+        hlens: torch.Tensor,
+        ys_in_pad: torch.Tensor,
+        ys_in_lens: torch.Tensor,
+        profile: torch.Tensor,
+        profile_lens: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Full-sequence training forward.
+
+        :param asr_hs_pad: padded ASR encoder output
+        :param spk_hs_pad: padded speaker encoder output
+        :param hlens: encoder output lengths
+        :param ys_in_pad: padded input token ids
+        :param ys_in_lens: input token lengths
+        :param profile: enrolled speaker-profile embeddings
+        :param profile_lens: number of valid profiles per utterance
+        :return: (token scores, profile attention weights, output lengths)
+        """
+        tgt = ys_in_pad
+        # tgt_mask: (B, 1, L)
+        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
+        # m: (1, L, L)
+        m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
+        # tgt_mask: (B, L, L)
+        tgt_mask = tgt_mask & m
+
+        asr_memory = asr_hs_pad
+        spk_memory = spk_hs_pad
+        memory_mask = (~make_pad_mask(hlens))[:, None, :].to(asr_memory.device)
+        # Spk decoder
+        x = self.embed(tgt)
+
+        # decoder1 also returns z: its post-self-attention state, later fed
+        # to the ASR branch so both branches share the same query.
+        x, tgt_mask, asr_memory, spk_memory, memory_mask, z = self.decoder1(
+            x, tgt_mask, asr_memory, spk_memory, memory_mask
+        )
+        x, tgt_mask, spk_memory, memory_mask = self.decoder2(
+            x, tgt_mask, spk_memory, memory_mask
+        )
+        if self.normalize_before:
+            x = self.after_norm(x)
+        if self.spk_output_layer is not None:
+            x = self.spk_output_layer(x)
+        # dn: profile embedding attended per token; weights: speaker posteriors.
+        dn, weights = self.cos_distance_att(x, profile, profile_lens)
+        # Asr decoder
+        x, tgt_mask, asr_memory, memory_mask = self.decoder3(
+            z, tgt_mask, asr_memory, memory_mask, dn
+        )
+        x, tgt_mask, asr_memory, memory_mask = self.decoder4(
+            x, tgt_mask, asr_memory, memory_mask
+        )
+
+        if self.normalize_before:
+            x = self.after_norm(x)
+        if self.asr_output_layer is not None:
+            x = self.asr_output_layer(x)
+
+        olens = tgt_mask.sum(1)
+        return x, weights, olens
+
+
+    def forward_one_step(
+        self,
+        tgt: torch.Tensor,
+        tgt_mask: torch.Tensor,
+        asr_memory: torch.Tensor,
+        spk_memory: torch.Tensor,
+        profile: torch.Tensor,
+        cache: List[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
+        """Incremental (beam-search) forward for the last position only.
+
+        :param cache: cached decoder states, one entry per layer
+        :return: (log-probs for last token, profile weights, new cache)
+        """
+        x = self.embed(tgt)
+
+        if cache is None:
+            # One slot each for decoder1 and decoder3, plus their stacks.
+            cache = [None] * (2 + len(self.decoder2) + len(self.decoder4))
+        new_cache = []
+        x, tgt_mask, asr_memory, spk_memory, _, z = self.decoder1(
+                x, tgt_mask, asr_memory, spk_memory, None, cache=cache[0]
+        )
+        new_cache.append(x)
+        for c, decoder in zip(cache[1: len(self.decoder2) + 1], self.decoder2):
+            x, tgt_mask, spk_memory, _ = decoder(
+                x, tgt_mask, spk_memory, None, cache=c
+            )
+            new_cache.append(x)
+        if self.normalize_before:
+            x = self.after_norm(x)
+        else:
+            x = x
+        if self.spk_output_layer is not None:
+            x = self.spk_output_layer(x)
+        dn, weights = self.cos_distance_att(x, profile, None)
+
+        x, tgt_mask, asr_memory, _ = self.decoder3(
+            z, tgt_mask, asr_memory, None, dn, cache=cache[len(self.decoder2) + 1]
+        )
+        new_cache.append(x)
+
+        for c, decoder in zip(cache[len(self.decoder2) + 2: ], self.decoder4):
+            x, tgt_mask, asr_memory, _ = decoder(
+                x, tgt_mask, asr_memory, None, cache=c
+            )
+            new_cache.append(x)
+
+        # Only the last time step is scored during incremental decoding.
+        if self.normalize_before:
+            y = self.after_norm(x[:, -1])
+        else:
+            y = x[:, -1]
+        if self.asr_output_layer is not None:
+            y = torch.log_softmax(self.asr_output_layer(y), dim=-1)
+
+        return y, weights, new_cache
+
+    def score(self, ys, state, asr_enc, spk_enc, profile):
+        """Score a single hypothesis (ScorerInterface-style entry point)."""
+        ys_mask = subsequent_mask(len(ys), device=ys.device).unsqueeze(0)
+        logp, weights, state = self.forward_one_step(
+            ys.unsqueeze(0), ys_mask, asr_enc.unsqueeze(0), spk_enc.unsqueeze(0), profile.unsqueeze(0), cache=state
+        )
+        return logp.squeeze(0), weights.squeeze(), state
+
+class SAAsrTransformerDecoder(BaseSAAsrTransformerDecoder):
+    """Concrete SA-ASR transformer decoder.
+
+    Builds the four sub-decoders of the base class: a speaker branch of
+    ``spk_num_blocks`` layers and an ASR branch of ``asr_num_blocks``
+    layers, each headed by a specialized first layer.
+    """
+
+    def __init__(
+        self,
+        vocab_size: int,
+        encoder_output_size: int,
+        spker_embedding_dim: int = 256,
+        attention_heads: int = 4,
+        linear_units: int = 2048,
+        asr_num_blocks: int = 6,
+        spk_num_blocks: int = 3,
+        dropout_rate: float = 0.1,
+        positional_dropout_rate: float = 0.1,
+        self_attention_dropout_rate: float = 0.0,
+        src_attention_dropout_rate: float = 0.0,
+        input_layer: str = "embed",
+        use_asr_output_layer: bool = True,
+        use_spk_output_layer: bool = True,
+        pos_enc_class=PositionalEncoding,
+        normalize_before: bool = True,
+        concat_after: bool = False,
+    ):
+        assert check_argument_types()
+        super().__init__(
+            vocab_size=vocab_size,
+            encoder_output_size=encoder_output_size,
+            spker_embedding_dim=spker_embedding_dim,
+            dropout_rate=dropout_rate,
+            positional_dropout_rate=positional_dropout_rate,
+            input_layer=input_layer,
+            use_asr_output_layer=use_asr_output_layer,
+            use_spk_output_layer=use_spk_output_layer,
+            pos_enc_class=pos_enc_class,
+            normalize_before=normalize_before,
+        )
+
+        attention_dim = encoder_output_size
+
+        # Speaker branch: first layer attends jointly over ASR and speaker
+        # encoder memories, followed by (spk_num_blocks - 1) standard layers.
+        self.decoder1 = SpeakerAttributeSpkDecoderFirstLayer(
+            attention_dim,
+            MultiHeadedAttention(
+                attention_heads, attention_dim, self_attention_dropout_rate
+            ),
+            MultiHeadedAttention(
+                attention_heads, attention_dim, src_attention_dropout_rate
+            ),
+            PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
+            dropout_rate,
+            normalize_before,
+            concat_after,
+        )
+        self.decoder2 = repeat(
+            spk_num_blocks - 1,
+            lambda lnum: DecoderLayer(
+                attention_dim,
+                MultiHeadedAttention(
+                    attention_heads, attention_dim, self_attention_dropout_rate
+                ),
+                MultiHeadedAttention(
+                    attention_heads, attention_dim, src_attention_dropout_rate
+                ),
+                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
+        
+        # ASR branch: first layer fuses the attended speaker profile (dn),
+        # followed by (asr_num_blocks - 1) standard layers.
+        self.decoder3 = SpeakerAttributeAsrDecoderFirstLayer(
+            attention_dim,
+            spker_embedding_dim,
+            MultiHeadedAttention(
+                attention_heads, attention_dim, src_attention_dropout_rate
+            ),
+            PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
+            dropout_rate,
+            normalize_before,
+            concat_after,
+        )
+        self.decoder4 = repeat(
+            asr_num_blocks - 1,
+            lambda lnum: DecoderLayer(
+                attention_dim,
+                MultiHeadedAttention(
+                    attention_heads, attention_dim, self_attention_dropout_rate
+                ),
+                MultiHeadedAttention(
+                    attention_heads, attention_dim, src_attention_dropout_rate
+                ),
+                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
+
+class SpeakerAttributeSpkDecoderFirstLayer(nn.Module):
+    """First layer of the speaker decoder branch.
+
+    Self-attention over the token stream, then cross-attention whose keys
+    come from the ASR encoder memory and whose values come from the speaker
+    encoder memory, then a position-wise feed-forward.  Besides the layer
+    output it returns ``z``, the post-self-attention state, which the ASR
+    branch reuses as its input query.
+    """
+
+    def __init__(
+        self,
+        size,
+        self_attn,
+        src_attn,
+        feed_forward,
+        dropout_rate,
+        normalize_before=True,
+        concat_after=False,
+    ):
+        """Construct an DecoderLayer object."""
+        super(SpeakerAttributeSpkDecoderFirstLayer, self).__init__()
+        self.size = size
+        self.self_attn = self_attn
+        self.src_attn = src_attn
+        self.feed_forward = feed_forward
+        # NOTE(review): only two LayerNorms for three sublayers; norm1 is
+        # reused for both the self-attention and cross-attention sublayers
+        # below -- confirm this weight sharing is intentional.
+        self.norm1 = LayerNorm(size)
+        self.norm2 = LayerNorm(size)
+        self.dropout = nn.Dropout(dropout_rate)
+        self.normalize_before = normalize_before
+        self.concat_after = concat_after
+        if self.concat_after:
+            self.concat_linear1 = nn.Linear(size + size, size)
+            self.concat_linear2 = nn.Linear(size + size, size)
+
+    def forward(self, tgt, tgt_mask, asr_memory, spk_memory, memory_mask, cache=None):
+        """Run one layer; with ``cache`` only the last position is computed."""
+        residual = tgt
+        if self.normalize_before:
+            tgt = self.norm1(tgt)
+
+        if cache is None:
+            tgt_q = tgt
+            tgt_q_mask = tgt_mask
+        else:
+            # compute only the last frame query keeping dim: max_time_out -> 1
+            assert cache.shape == (
+                tgt.shape[0],
+                tgt.shape[1] - 1,
+                self.size,
+            ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
+            tgt_q = tgt[:, -1:, :]
+            residual = residual[:, -1:, :]
+            tgt_q_mask = None
+            if tgt_mask is not None:
+                tgt_q_mask = tgt_mask[:, -1:, :]
+
+        if self.concat_after:
+            tgt_concat = torch.cat(
+                (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
+            )
+            x = residual + self.concat_linear1(tgt_concat)
+        else:
+            x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
+        if not self.normalize_before:
+            x = self.norm1(x)
+        # z: post-self-attention state handed to the ASR decoder branch.
+        z = x
+        
+        residual = x
+        if self.normalize_before:
+            x = self.norm1(x)
+
+        # Cross-attention: query = decoder state, key = ASR encoder memory,
+        # value = speaker encoder memory.
+        skip = self.src_attn(x, asr_memory, spk_memory, memory_mask)
+
+        if self.concat_after:
+            x_concat = torch.cat(
+                (x, skip), dim=-1
+            )
+            x = residual + self.concat_linear2(x_concat)
+        else:
+            x = residual + self.dropout(skip)
+        if not self.normalize_before:
+            x = self.norm1(x)
+        
+        residual = x
+        if self.normalize_before:
+            x = self.norm2(x)
+        x = residual + self.dropout(self.feed_forward(x))
+        if not self.normalize_before:
+            x = self.norm2(x)
+
+        if cache is not None:
+            # Re-attach the cached prefix so callers always see full length.
+            x = torch.cat([cache, x], dim=1)
+            
+        return x, tgt_mask, asr_memory, spk_memory, memory_mask, z
+
+class SpeakerAttributeAsrDecoderFirstLayer(nn.Module):
+    """First layer of the ASR decoder branch.
+
+    Cross-attends over the ASR encoder memory, then injects the attended
+    speaker-profile embedding ``dn`` through a linear projection before the
+    position-wise feed-forward.  No self-attention: the input already is the
+    self-attended state ``z`` from the speaker branch's first layer.
+    """
+    
+    def __init__(
+        self,
+        size,
+        d_size,
+        src_attn,
+        feed_forward,
+        dropout_rate,
+        normalize_before=True,
+        concat_after=False,
+    ):
+        """Construct an DecoderLayer object."""
+        super(SpeakerAttributeAsrDecoderFirstLayer, self).__init__()
+        self.size = size
+        self.src_attn = src_attn
+        self.feed_forward = feed_forward
+        self.norm1 = LayerNorm(size)
+        self.norm2 = LayerNorm(size)
+        self.norm3 = LayerNorm(size)
+        self.dropout = nn.Dropout(dropout_rate)
+        self.normalize_before = normalize_before
+        self.concat_after = concat_after
+        # Maps the speaker-embedding space (d_size) into the decoder width.
+        self.spk_linear = nn.Linear(d_size, size, bias=False)
+        if self.concat_after:
+            self.concat_linear1 = nn.Linear(size + size, size)
+            self.concat_linear2 = nn.Linear(size + size, size)
+
+    def forward(self, tgt, tgt_mask, memory, memory_mask, dn, cache=None):
+        """Run one layer; with ``cache`` only the last position is computed."""
+        residual = tgt
+        if self.normalize_before:
+            tgt = self.norm1(tgt)
+
+        if cache is None:
+            tgt_q = tgt
+            tgt_q_mask = tgt_mask
+        else:
+            # compute only the last frame query: max_time_out -> 1
+            tgt_q = tgt[:, -1:, :]
+            residual = residual[:, -1:, :]
+            tgt_q_mask = None
+            if tgt_mask is not None:
+                tgt_q_mask = tgt_mask[:, -1:, :]
+
+        x = tgt_q
+        if self.normalize_before:
+            x = self.norm2(x)
+        if self.concat_after:
+            x_concat = torch.cat(
+                (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
+            )
+            x = residual + self.concat_linear2(x_concat)
+        else:
+            x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
+        if not self.normalize_before:
+            x = self.norm2(x)
+        # Residual is captured BEFORE the speaker-profile injection, so the
+        # spk_linear(dn) term only feeds the feed-forward path.
+        residual = x
+
+        # NOTE(review): 'dn != None' on a torch.Tensor is an elementwise
+        # comparison, not an identity check; 'dn is not None' is the safe
+        # form -- confirm this works on the torch version in use.
+        if dn!=None:
+            x = x + self.spk_linear(dn)
+        if self.normalize_before:
+            x = self.norm3(x)
+        
+        x = residual + self.dropout(self.feed_forward(x))
+        if not self.normalize_before:
+            x = self.norm3(x)
+
+        if cache is not None:
+            # Re-attach the cached prefix so callers always see full length.
+            x = torch.cat([cache, x], dim=1)
+
+        return x, tgt_mask, memory, memory_mask

+ 1 - 2
funasr/models/e2e_sa_asr.py

@@ -16,9 +16,8 @@ from typeguard import check_argument_types
 
 from funasr.layers.abs_normalize import AbsNormalize
 from funasr.losses.label_smoothing_loss import (
-    LabelSmoothingLoss,  # noqa: H301
+    LabelSmoothingLoss, NllLoss  # noqa: H301
 )
-from funasr.losses.nll_loss import NllLoss
 from funasr.models.ctc import CTC
 from funasr.models.decoder.abs_decoder import AbsDecoder
 from funasr.models.encoder.abs_encoder import AbsEncoder

+ 1 - 1
funasr/tasks/sa_asr.py

@@ -28,7 +28,7 @@ from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder, FsmnDecode
 from funasr.models.decoder.transformer_decoder import (
     DynamicConvolution2DTransformerDecoder,  # noqa: H301
 )
-from funasr.models.decoder.transformer_decoder_sa_asr import SAAsrTransformerDecoder
+from funasr.models.decoder.transformer_decoder import SAAsrTransformerDecoder
 from funasr.models.decoder.transformer_decoder import DynamicConvolutionTransformerDecoder
 from funasr.models.decoder.transformer_decoder import (
     LightweightConvolution2DTransformerDecoder,  # noqa: H301