Browse Source

Dev lyh (#645)

* update

* update

* fix bug

* fix bug
yhliang 2 years ago
parent
commit
e8528b8f62
63 changed files with 1417 additions and 120 deletions
  1. 0 29
      egs/alimeeting/sa-asr/conf/train_lm_transformer.yaml
  2. 86 0
      egs/alimeeting/sa_asr/README.md
  3. 0 0
      egs/alimeeting/sa_asr/conf/decode_asr_rnn.yaml
  4. 102 0
      egs/alimeeting/sa_asr/conf/train_asr_conformer.yaml
  5. 131 0
      egs/alimeeting/sa_asr/conf/train_sa_asr_conformer.yaml
  6. 24 26
      egs/alimeeting/sa_asr/local/alimeeting_data_prep.sh
  7. 0 0
      egs/alimeeting/sa_asr/local/alimeeting_data_prep_test_2023.sh
  8. 0 0
      egs/alimeeting/sa_asr/local/alimeeting_process_overlap_force.py
  9. 0 0
      egs/alimeeting/sa_asr/local/alimeeting_process_textgrid.py
  10. 0 0
      egs/alimeeting/sa_asr/local/apply_map.pl
  11. 0 0
      egs/alimeeting/sa_asr/local/combine_data.sh
  12. 134 0
      egs/alimeeting/sa_asr/local/compute_cmvn.py
  13. 39 0
      egs/alimeeting/sa_asr/local/compute_cmvn.sh
  14. 0 0
      egs/alimeeting/sa_asr/local/compute_cpcer.py
  15. 29 0
      egs/alimeeting/sa_asr/local/convert_model.py
  16. 0 0
      egs/alimeeting/sa_asr/local/copy_data_dir.sh
  17. 0 0
      egs/alimeeting/sa_asr/local/data/get_reco2dur.sh
  18. 0 0
      egs/alimeeting/sa_asr/local/data/get_segments_for_data.sh
  19. 0 0
      egs/alimeeting/sa_asr/local/data/get_utt2dur.sh
  20. 0 0
      egs/alimeeting/sa_asr/local/data/split_data.sh
  21. 105 0
      egs/alimeeting/sa_asr/local/download_and_untar.sh
  22. 0 0
      egs/alimeeting/sa_asr/local/download_pretrained_model_from_modelscope.py
  23. 0 0
      egs/alimeeting/sa_asr/local/download_xvector_model.py
  24. 0 0
      egs/alimeeting/sa_asr/local/filter_utt2spk_all_fifo.py
  25. 0 0
      egs/alimeeting/sa_asr/local/fix_data_dir.sh
  26. 0 0
      egs/alimeeting/sa_asr/local/format_wav_scp.py
  27. 0 0
      egs/alimeeting/sa_asr/local/format_wav_scp.sh
  28. 3 3
      egs/alimeeting/sa_asr/local/gen_cluster_profile_infer.py
  29. 3 3
      egs/alimeeting/sa_asr/local/gen_oracle_embedding.py
  30. 0 0
      egs/alimeeting/sa_asr/local/gen_oracle_profile_nopadding.py
  31. 2 2
      egs/alimeeting/sa_asr/local/gen_oracle_profile_padding.py
  32. 0 0
      egs/alimeeting/sa_asr/local/perturb_data_dir_speed.sh
  33. 1 2
      egs/alimeeting/sa_asr/local/process_sot_fifo_textchar2spk.py
  34. 0 0
      egs/alimeeting/sa_asr/local/process_text_id.py
  35. 0 0
      egs/alimeeting/sa_asr/local/process_text_spk_merge.py
  36. 0 0
      egs/alimeeting/sa_asr/local/process_textgrid_to_single_speaker_wav.py
  37. 0 0
      egs/alimeeting/sa_asr/local/spk2utt_to_utt2spk.pl
  38. 0 0
      egs/alimeeting/sa_asr/local/text_format.pl
  39. 0 0
      egs/alimeeting/sa_asr/local/text_normalize.pl
  40. 0 0
      egs/alimeeting/sa_asr/local/utt2spk_to_spk2utt.pl
  41. 0 0
      egs/alimeeting/sa_asr/local/validate_data_dir.sh
  42. 0 0
      egs/alimeeting/sa_asr/local/validate_text.pl
  43. 6 0
      egs/alimeeting/sa_asr/path.sh
  44. 435 0
      egs/alimeeting/sa_asr/run.sh
  45. 0 0
      egs/alimeeting/sa_asr/utils
  46. 0 0
      egs/alimeeting/sa_asr_deprecated/README.md
  47. 0 0
      egs/alimeeting/sa_asr_deprecated/asr_local.sh
  48. 0 0
      egs/alimeeting/sa_asr_deprecated/asr_local_m2met_2023_infer.sh
  49. 6 0
      egs/alimeeting/sa_asr_deprecated/conf/decode_asr_rnn.yaml
  50. 0 0
      egs/alimeeting/sa_asr_deprecated/conf/train_asr_conformer.yaml
  51. 0 0
      egs/alimeeting/sa_asr_deprecated/conf/train_sa_asr_conformer.yaml
  52. 1 0
      egs/alimeeting/sa_asr_deprecated/local
  53. 0 0
      egs/alimeeting/sa_asr_deprecated/path.sh
  54. 0 0
      egs/alimeeting/sa_asr_deprecated/run.sh
  55. 0 0
      egs/alimeeting/sa_asr_deprecated/run_m2met_2023_infer.sh
  56. 1 0
      egs/alimeeting/sa_asr_deprecated/utils
  57. 6 4
      funasr/bin/asr_infer.py
  58. 1 1
      funasr/bin/train.py
  59. 59 1
      funasr/build_utils/build_asr_model.py
  60. 2 15
      funasr/models/e2e_sa_asr.py
  61. 86 31
      funasr/models/frontend/default.py
  62. 152 0
      funasr/tasks/asr.py
  63. 3 3
      funasr/tasks/sa_asr.py

+ 0 - 29
egs/alimeeting/sa-asr/conf/train_lm_transformer.yaml

@@ -1,29 +0,0 @@
-lm: transformer
-lm_conf:
-    pos_enc: null
-    embed_unit: 128
-    att_unit: 512
-    head: 8
-    unit: 2048
-    layer: 16
-    dropout_rate: 0.1
-
-# optimization related
-grad_clip: 5.0
-batch_type: numel
-batch_bins: 500000 # 4gpus * 500000
-accum_grad: 1
-max_epoch: 15  # 15epoch is enougth
-
-optim: adam
-optim_conf:
-   lr: 0.001
-scheduler: warmuplr
-scheduler_conf:
-   warmup_steps: 25000
-
-best_model_criterion:
--   - valid
-    - loss
-    - min
-keep_nbest_models: 10  # 10 is good.

+ 86 - 0
egs/alimeeting/sa_asr/README.md

@@ -0,0 +1,86 @@
+# Get Started
+Speaker Attributed Automatic Speech Recognition (SA-ASR) is a task proposed to answer "who spoke what": the goal is not only to obtain multi-speaker transcriptions, but also to identify the corresponding speaker for each utterance. The method used in this example follows the paper [End-to-End Speaker-Attributed ASR with Transformer](https://www.isca-speech.org/archive/pdfs/interspeech_2021/kanda21b_interspeech.pdf).  
+# Train
+First, install FunASR and ModelScope ([installation](https://github.com/alibaba-damo-academy/FunASR#installation)).
+After FunASR and ModelScope are installed, manually download and unpack the [AliMeeting](http://www.openslr.org/119/) corpus and place it in the `./dataset` directory, which should be organized as follows:
+```shell
+dataset
+|—— Eval_Ali_far
+|—— Eval_Ali_near
+|—— Test_Ali_far
+|—— Test_Ali_near
+|—— Train_Ali_far
+|—— Train_Ali_near
+```
+Then you can run this recipe with:
+```shell
+bash run.sh --stage 0 --stop-stage 6
+```
+There are 8 stages in `run.sh`:
+```shell
+stage 0: Data preparation; remove audio that is too long or too short.
+stage 1: Speaker profile and CMVN generation.
+stage 2: Dictionary preparation.
+stage 3: LM training (not supported).
+stage 4: ASR training.
+stage 5: SA-ASR training.
+stage 6: Inference on the test sets.
+stage 7: Inference on Test_2023_Ali_far.
+```
+<!-- The baseline model is available on [ModelScope](https://www.modelscope.cn/models/damo/speech_saasr_asr-zh-cn-16k-alimeeting/summary). -->
+# Infer
+1. Download the final test set and extract it.
+2. Put the audio files in `./dataset/Test_2023_Ali_far/` and the `wav.scp`, `segments`, `utt2spk`, and `spk2utt` files in `./data/org/Test_2023_Ali_far/`.
+3. Set `test_2023` in `run.sh` to `Test_2023_Ali_far`.
+4. Run `run.sh` as follows:
+```shell
+# Prepare test_2023 set
+bash run.sh --stage 0 --stop-stage 1
+# Decode test_2023 set
+bash run.sh --stage 7 --stop-stage 7
+```
+# Format of Final Submission
+Finally, you need to submit a file called `text_spk_merge` with the following format:
+```shell
+Meeting_1 text_spk_1_A$text_spk_1_B$text_spk_1_C ...
+Meeting_2 text_spk_2_A$text_spk_2_B$text_spk_2_C ...
+...
+```
+Here, text_spk_1_A represents the full transcription of speaker_A in Meeting_1 (merged in chronological order), and `$` is the separator symbol. There is no need to worry about the speaker permutation: the optimal permutation is computed at scoring time (see the sketch below). For more information, please refer to the results generated after executing the baseline code.
+# Baseline Results
+The results of the baseline system are as follows. The baseline results include the speaker-independent character error rate (SI-CER) and the concatenated minimum-permutation character error rate (cpCER); the former ignores speaker labels, while the latter is speaker dependent. During training, the speaker profile uses oracle speaker embeddings. Since oracle speaker labels are unavailable during evaluation, the speaker profile produced by an additional spectral clustering step is used instead. Results with the oracle speaker profile on the Test set are also reported to show the impact of speaker profile accuracy.  
+<!-- <table>
+    <tr >
+	    <td rowspan="2"></td>
+        <td colspan="2">SI-CER(%)</td>
+	    <td colspan="2">cpCER(%)</td>
+	</tr>
+    <tr>
+        <td>Eval</td>
+	    <td>Test</td>
+	    <td>Eval</td>
+	    <td>Test</td>
+	</tr>
+    <tr>
+	    <td>oracle profile</td>
+        <td>32.05</td>
+        <td>32.72</td>
+	    <td>47.40</td>
+        <td>42.92</td>
+	</tr>
+    <tr>
+	    <td>cluster profile</td>
+        <td>32.05</td>
+        <td>32.73</td>
+	    <td>53.76</td>
+        <td>49.37</td>
+	</tr>
+</table> -->
+|                |SI-CER(%)     |cpCER(%)   |
+|:---------------|:------------:|----------:|
+|oracle profile  |32.72         |42.92      |
+|cluster profile |32.73         |49.37      |
+
+
+# Reference
+N. Kanda, G. Ye, Y. Gaur, X. Wang, Z. Meng, Z. Chen, and T. Yoshioka, "End-to-end speaker-attributed ASR with transformer," in Interspeech. ISCA, 2021, pp. 4413–4417.
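The cpCER scoring mentioned in the README above can be sketched briefly: concatenate each speaker's transcript, then take the minimum total character edit distance over all speaker permutations, normalized by the reference length. This is a minimal sketch, not the recipe's `local/compute_cpcer.py`; the `editdistance` package and the toy strings are assumptions.

```python
import itertools

import editdistance  # assumption: pip install editdistance


def cpcer(ref_spk_texts, hyp_spk_texts):
    """Per-speaker transcripts of one meeting, each merged chronologically
    (the text_spk_merge format, split on the `$` separator)."""
    n = max(len(ref_spk_texts), len(hyp_spk_texts))
    refs = ref_spk_texts + [""] * (n - len(ref_spk_texts))  # pad missing speakers
    hyps = hyp_spk_texts + [""] * (n - len(hyp_spk_texts))
    total_ref_chars = sum(len(r) for r in refs)
    best = min(
        sum(editdistance.eval(refs[i], hyps[p[i]]) for i in range(n))
        for p in itertools.permutations(range(n))
    )
    return best / max(total_ref_chars, 1)


# Toy example: the texts are right but attributed to the wrong speakers,
# so the optimal permutation recovers a cpCER of 0.
print(cpcer(["今天天气不错", "是的"], ["是的", "今天天气不错"]))  # 0.0
```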

+ 0 - 0
egs/alimeeting/sa-asr/conf/decode_asr_rnn.yaml → egs/alimeeting/sa_asr/conf/decode_asr_rnn.yaml


+ 102 - 0
egs/alimeeting/sa_asr/conf/train_asr_conformer.yaml

@@ -0,0 +1,102 @@
+# network architecture
+frontend: multichannelfrontend
+frontend_conf:
+    fs: 16000
+    window: hann
+    n_fft: 400
+    n_mels: 80
+    frame_length: 25
+    frame_shift: 10
+    lfr_m: 1
+    lfr_n: 1
+    use_channel: 0
+    
+# encoder related
+encoder: conformer
+encoder_conf:
+    output_size: 256    # dimension of attention
+    attention_heads: 4
+    linear_units: 2048  # the number of units of position-wise feed forward
+    num_blocks: 12      # the number of encoder blocks
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    attention_dropout_rate: 0.0
+    input_layer: conv2d # encoder architecture type
+    normalize_before: true
+    rel_pos_type: latest
+    pos_enc_layer_type: rel_pos
+    selfattention_layer_type: rel_selfattn
+    activation_type: swish
+    macaron_style: true
+    use_cnn_module: true
+    cnn_module_kernel: 15
+
+# decoder related
+decoder: transformer
+decoder_conf:
+    attention_heads: 4
+    linear_units: 2048
+    num_blocks: 6
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    self_attention_dropout_rate: 0.0
+    src_attention_dropout_rate: 0.0
+
+# ctc related
+ctc_conf:
+    ignore_nan_grad: true
+
+# hybrid CTC/attention
+model_conf:
+    ctc_weight: 0.3
+    lsm_weight: 0.1     # label smoothing option
+    length_normalized_loss: false
+
+
+dataset_conf:
+    data_names: speech,text
+    data_types: sound,text
+    shuffle: True
+    shuffle_conf:
+        shuffle_size: 2048
+        sort_size: 500
+    batch_conf:
+        batch_type: token
+        batch_size: 7000
+    num_workers: 8
+
+# optimization related
+accum_grad: 1
+grad_clip: 5
+max_epoch: 100
+val_scheduler_criterion:
+    - valid
+    - acc
+best_model_criterion:
+-   - valid
+    - acc
+    - max
+keep_nbest_models: 10
+
+optim: adam
+optim_conf:
+   lr: 0.001
+scheduler: warmuplr
+scheduler_conf:
+   warmup_steps: 25000
+
+specaug: specaug
+specaug_conf:
+    apply_time_warp: true
+    time_warp_window: 5
+    time_warp_mode: bicubic
+    apply_freq_mask: true
+    freq_mask_width_range:
+    - 0
+    - 30
+    num_freq_mask: 2
+    apply_time_mask: true
+    time_mask_width_range:
+    - 0
+    - 40
+    num_time_mask: 2
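A note on the `lfr_m`/`lfr_n` options in `frontend_conf` above: FunASR's low frame rate (LFR) stacks `lfr_m` consecutive fbank frames and keeps one stack every `lfr_n` frames, so `lfr_m: 1, lfr_n: 1` as configured here leaves the features unchanged. A hedged sketch of the idea (the padding details of the actual frontend may differ):

```python
import numpy as np


def apply_lfr(feats, lfr_m, lfr_n):
    """Stack lfr_m frames, keep one stack every lfr_n frames.
    feats: (T, dim) -> (ceil(T / lfr_n), dim * lfr_m)."""
    T, _ = feats.shape
    left = (lfr_m - 1) // 2
    padded = np.vstack([np.tile(feats[:1], (left, 1)), feats])
    out = []
    for t in range(0, T, lfr_n):
        window = padded[t:t + lfr_m]
        if len(window) < lfr_m:  # pad the tail with the last frame
            pad = np.tile(padded[-1:], (lfr_m - len(window), 1))
            window = np.vstack([window, pad])
        out.append(window.reshape(-1))
    return np.stack(out)


print(apply_lfr(np.ones((100, 80)), 1, 1).shape)  # (100, 80): a no-op here
```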

+ 131 - 0
egs/alimeeting/sa_asr/conf/train_sa_asr_conformer.yaml

@@ -0,0 +1,131 @@
+# network architecture
+frontend: multichannelfrontend
+frontend_conf:
+    fs: 16000
+    window: hann
+    n_fft: 400
+    n_mels: 80
+    frame_length: 25
+    frame_shift: 10
+    lfr_m: 1
+    lfr_n: 1
+    use_channel: 0
+
+# encoder related
+asr_encoder: conformer
+asr_encoder_conf:
+    output_size: 256    # dimension of attention
+    attention_heads: 4
+    linear_units: 2048  # the number of units of position-wise feed forward
+    num_blocks: 12      # the number of encoder blocks
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    attention_dropout_rate: 0.0
+    input_layer: conv2d # encoder architecture type
+    normalize_before: true
+    pos_enc_layer_type: rel_pos
+    selfattention_layer_type: rel_selfattn
+    activation_type: swish
+    macaron_style: true
+    use_cnn_module: true
+    cnn_module_kernel: 15
+
+spk_encoder: resnet34_diar
+spk_encoder_conf:
+  use_head_conv: true
+  batchnorm_momentum: 0.5
+  use_head_maxpool: false
+  num_nodes_pooling_layer: 256
+  layers_in_block:
+    - 3
+    - 4
+    - 6
+    - 3
+  filters_in_block:
+    - 32
+    - 64
+    - 128
+    - 256
+  pooling_type: statistic
+  num_nodes_resnet1: 256
+  num_nodes_last_layer: 256
+  batchnorm_momentum: 0.5
+
+# decoder related
+decoder: sa_decoder
+decoder_conf:
+    attention_heads: 4
+    linear_units: 2048
+    asr_num_blocks: 6
+    spk_num_blocks: 3
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    self_attention_dropout_rate: 0.0
+    src_attention_dropout_rate: 0.0
+
+# hybrid CTC/attention
+model_conf:
+    spk_weight: 0.5
+    ctc_weight: 0.3
+    lsm_weight: 0.1     # label smoothing option
+    length_normalized_loss: false
+    max_spk_num: 4
+
+ctc_conf:
+    ignore_nan_grad: true
+
+# minibatch related
+dataset_conf:
+    data_names: speech,text,profile,text_id
+    data_types: sound,text,npy,text_int
+    shuffle: True
+    shuffle_conf:
+        shuffle_size: 2048
+        sort_size: 500
+    batch_conf:
+        batch_type: token
+        batch_size: 7000
+    num_workers: 8
+
+# optimization related
+accum_grad: 1
+grad_clip: 5
+max_epoch: 60
+val_scheduler_criterion:
+    - valid
+    - loss
+best_model_criterion:
+-   - valid
+    - acc
+    - max
+-   - valid
+    - acc_spk
+    - max
+-   - valid
+    - loss
+    - min
+keep_nbest_models: 10
+
+optim: adam
+optim_conf:
+   lr: 0.0005
+scheduler: warmuplr
+scheduler_conf:
+   warmup_steps: 8000
+
+specaug: specaug
+specaug_conf:
+    apply_time_warp: true
+    time_warp_window: 5
+    time_warp_mode: bicubic
+    apply_freq_mask: true
+    freq_mask_width_range:
+    - 0
+    - 30
+    num_freq_mask: 2
+    apply_time_mask: true
+    time_mask_width_range:
+    - 0
+    - 40
+    num_time_mask: 2
+

+ 24 - 26
egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh → egs/alimeeting/sa_asr/local/alimeeting_data_prep.sh

@@ -21,6 +21,8 @@ EOF
 
 SECONDS=0
 tgt=Train #Train or Eval
+min_wav_duration=0.1
+max_wav_duration=20
 
 
 log "$0 $*"
@@ -57,27 +59,24 @@ stage=1
 stop_stage=4
 mkdir -p $far_dir
 mkdir -p $near_dir
+mkdir -p data/org
 
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then 
     log "stage 1:process alimeeting near dir"
     
     find -L $near_raw_dir/audio_dir -iname "*.wav" | sort >  $near_dir/wavlist
-    awk -F '/' '{print $NF}' $near_dir/wavlist | awk -F '.' '{print $1}' > $near_dir/uttid   
-    find -L $near_raw_dir/textgrid_dir  -iname "*.TextGrid" | sort > $near_dir/textgrid.flist
+    awk -F '/' '{print $NF}' $near_dir/wavlist | awk -F '.' '{print $1}' | sort > $near_dir/uttid   
+    find -L $near_raw_dir/textgrid_dir  -iname "*.TextGrid" > $near_dir/textgrid.flist
     n1_wav=$(wc -l < $near_dir/wavlist)
     n2_text=$(wc -l < $near_dir/textgrid.flist)
     log  near file found $n1_wav wav and $n2_text text.
 
-    paste $near_dir/uttid $near_dir/wavlist > $near_dir/wav_raw.scp
-
-    # cat $near_dir/wav_raw.scp | awk '{printf("%s sox -t wav  %s -r 16000 -b 16 -c 1 -t wav  - |\n", $1, $2)}'  > $near_dir/wav.scp
-    cat $near_dir/wav_raw.scp | awk '{printf("%s sox -t wav  %s -r 16000 -b 16 -t wav  - |\n", $1, $2)}'  > $near_dir/wav.scp
+    paste $near_dir/uttid $near_dir/wavlist -d " " > $near_dir/wav.scp
     
     python local/alimeeting_process_textgrid.py --path $near_dir --no-overlap False
     cat $near_dir/text_all | local/text_normalize.pl | local/text_format.pl | sort -u > $near_dir/text
     utils/filter_scp.pl -f 1 $near_dir/text $near_dir/utt2spk_all | sort -u > $near_dir/utt2spk
-    #sed -e 's/ [a-z,A-Z,_,0-9,-]\+SPK/ SPK/'  $near_dir/utt2spk_old >$near_dir/tmp1
-    #sed -e 's/-[a-z,A-Z,0-9]\+$//' $near_dir/tmp1 | sort -u > $near_dir/utt2spk
+
     local/utt2spk_to_spk2utt.pl $near_dir/utt2spk > $near_dir/spk2utt
     utils/filter_scp.pl -f 1 $near_dir/text $near_dir/segments_all | sort -u > $near_dir/segments
     sed -e 's/ $//g' $near_dir/text> $near_dir/tmp1
@@ -97,9 +96,7 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
     n2_text=$(wc -l < $far_dir/textgrid.flist)
     log  far file found $n1_wav wav and $n2_text text.
 
-    paste $far_dir/uttid $far_dir/wavlist > $far_dir/wav_raw.scp
-
-    cat $far_dir/wav_raw.scp | awk '{printf("%s sox -t wav  %s -r 16000 -b 16 -t wav  - |\n", $1, $2)}'  > $far_dir/wav.scp
+    paste $far_dir/uttid $far_dir/wavlist -d " " > $far_dir/wav.scp
 
     python local/alimeeting_process_overlap_force.py  --path $far_dir \
         --no-overlap false --mars True \
@@ -119,28 +116,28 @@ fi
 
 
 if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
-    log "stage 3: finali data process"
+    log "stage 3: final data process"
     local/fix_data_dir.sh $near_dir
     local/fix_data_dir.sh $far_dir
-    local/copy_data_dir.sh $near_dir data/${tgt}_Ali_near
-    local/copy_data_dir.sh $far_dir data/${tgt}_Ali_far
+    local/copy_data_dir.sh $near_dir data/org/${tgt}_Ali_near
+    local/copy_data_dir.sh $far_dir data/org/${tgt}_Ali_far
 
-    sort $far_dir/utt2spk_all_fifo > data/${tgt}_Ali_far/utt2spk_all_fifo
-    sed -i "s/src/$/g" data/${tgt}_Ali_far/utt2spk_all_fifo
+    sort $far_dir/utt2spk_all_fifo > data/org/${tgt}_Ali_far/utt2spk_all_fifo
+    sed -i "s/src/$/g" data/org/${tgt}_Ali_far/utt2spk_all_fifo
 
     # remove space in text
     for x in ${tgt}_Ali_near ${tgt}_Ali_far; do
-        cp data/${x}/text data/${x}/text.org
-        paste -d " " <(cut -f 1 -d" " data/${x}/text.org) <(cut -f 2- -d" " data/${x}/text.org | tr -d " ") \
-        > data/${x}/text
-        rm data/${x}/text.org
+        cp data/org/${x}/text data/org/${x}/text.org
+        paste -d " " <(cut -f 1 -d" " data/org/${x}/text.org) <(cut -f 2- -d" " data/org/${x}/text.org | tr -d " ") \
+        > data/org/${x}/text
+        rm data/org/${x}/text.org
     done
 
     log "Successfully finished. [elapsed=${SECONDS}s]"
 fi
 
 if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
-    log "stage 4: process alimeeting far dir (single speaker by oracle time strap)"
+    log "stage 4: process alimeeting far dir (single speaker by oracle time stamp)"
     cp -r $far_dir/* $far_single_speaker_dir 
     mv $far_single_speaker_dir/textgrid.flist  $far_single_speaker_dir/textgrid_oldpath
     paste -d " " $far_single_speaker_dir/uttid $far_single_speaker_dir/textgrid_oldpath > $far_single_speaker_dir/textgrid.flist
@@ -150,14 +147,15 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
     local/utt2spk_to_spk2utt.pl $far_single_speaker_dir/utt2spk > $far_single_speaker_dir/spk2utt
 
     ./local/fix_data_dir.sh $far_single_speaker_dir 
-    local/copy_data_dir.sh $far_single_speaker_dir data/${tgt}_Ali_far_single_speaker
+    local/copy_data_dir.sh $far_single_speaker_dir data/org/${tgt}_Ali_far_single_speaker
 
     # remove space in text
     for x in ${tgt}_Ali_far_single_speaker; do
-        cp data/${x}/text data/${x}/text.org
-        paste -d " " <(cut -f 1 -d" " data/${x}/text.org) <(cut -f 2- -d" " data/${x}/text.org | tr -d " ") \
-        > data/${x}/text
-        rm data/${x}/text.org
+        cp data/org/${x}/text data/org/${x}/text.org
+        paste -d " " <(cut -f 1 -d" " data/org/${x}/text.org) <(cut -f 2- -d" " data/org/${x}/text.org | tr -d " ") \
+        > data/org/${x}/text
+        rm data/org/${x}/text.org
     done
+    rm -rf data/local
     log "Successfully finished. [elapsed=${SECONDS}s]"
 fi

+ 0 - 0
egs/alimeeting/sa-asr/local/alimeeting_data_prep_test_2023.sh → egs/alimeeting/sa_asr/local/alimeeting_data_prep_test_2023.sh


+ 0 - 0
egs/alimeeting/sa-asr/local/alimeeting_process_overlap_force.py → egs/alimeeting/sa_asr/local/alimeeting_process_overlap_force.py


+ 0 - 0
egs/alimeeting/sa-asr/local/alimeeting_process_textgrid.py → egs/alimeeting/sa_asr/local/alimeeting_process_textgrid.py


+ 0 - 0
egs/alimeeting/sa-asr/local/apply_map.pl → egs/alimeeting/sa_asr/local/apply_map.pl


+ 0 - 0
egs/alimeeting/sa-asr/local/combine_data.sh → egs/alimeeting/sa_asr/local/combine_data.sh


+ 134 - 0
egs/alimeeting/sa_asr/local/compute_cmvn.py

@@ -0,0 +1,134 @@
+import argparse
+import json
+import os
+
+import numpy as np
+import torchaudio
+import torchaudio.compliance.kaldi as kaldi
+import yaml
+from funasr.models.frontend.default import DefaultFrontend
+import torch
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        description="computer global cmvn",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument(
+        "--dim",
+        default=80,
+        type=int,
+        help="feature dimension",
+    )
+    parser.add_argument(
+        "--wav_path",
+        default=False,
+        required=True,
+        type=str,
+        help="the path of wav scps",
+    )
+    parser.add_argument(
+        "--config_file",
+        type=str,
+        help="the config file for computing cmvn",
+    )
+    parser.add_argument(
+        "--idx",
+        default=1,
+        required=True,
+        type=int,
+        help="index",
+    )
+    return parser
+
+
+def compute_fbank(wav_file,
+                  num_mel_bins=80,
+                  frame_length=25,
+                  frame_shift=10,
+                  dither=0.0,
+                  resample_rate=16000,
+                  speed=1.0,
+                  window_type="hamming"):
+    waveform, sample_rate = torchaudio.load(wav_file)
+    if resample_rate != sample_rate:
+        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate,
+                                                  new_freq=resample_rate)(waveform)
+    if speed != 1.0:
+        waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
+            waveform, resample_rate,
+            [['speed', str(speed)], ['rate', str(resample_rate)]]
+        )
+
+    waveform = waveform * (1 << 15)
+    mat = kaldi.fbank(waveform,
+                      num_mel_bins=num_mel_bins,
+                      frame_length=frame_length,
+                      frame_shift=frame_shift,
+                      dither=dither,
+                      energy_floor=0.0,
+                      window_type=window_type,
+                      sample_frequency=resample_rate)
+
+    return mat.numpy()
+
+
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+
+    wav_scp_file = os.path.join(args.wav_path, "wav.{}.scp".format(args.idx))
+    cmvn_file = os.path.join(args.wav_path, "cmvn.{}.json".format(args.idx))
+
+    mean_stats = np.zeros(args.dim)
+    var_stats = np.zeros(args.dim)
+    total_frames = 0
+
+    # with ReadHelper('ark:{}'.format(ark_file)) as ark_reader:
+    #     for key, mat in ark_reader:
+    #         mean_stats += np.sum(mat, axis=0)
+    #         var_stats += np.sum(np.square(mat), axis=0)
+    #         total_frames += mat.shape[0]
+
+    with open(args.config_file) as f:
+        configs = yaml.safe_load(f)
+        frontend_configs = configs.get("frontend_conf", {})
+        num_mel_bins = frontend_configs.get("n_mels", 80)
+        frame_length = frontend_configs.get("frame_length", 25)
+        frame_shift = frontend_configs.get("frame_shift", 10)
+        window_type = frontend_configs.get("window", "hamming")
+        resample_rate = frontend_configs.get("fs", 16000)
+        n_fft = frontend_configs.get("n_fft", 400)
+        use_channel = frontend_configs.get("use_channel", None)
+        assert num_mel_bins == args.dim
+    frontend = DefaultFrontend(
+        fs=resample_rate,
+        n_fft=n_fft,
+        win_length=frame_length * 16,
+        hop_length=frame_shift * 16,
+        window=window_type,
+        n_mels=num_mel_bins,
+        use_channel=use_channel,
+    )
+    with open(wav_scp_file) as f:
+        lines = f.readlines()
+        for line in lines:
+            _, wav_file = line.strip().split()
+            wavform, _ = torchaudio.load(wav_file)
+            fbank, _ = frontend(wavform.transpose(0, 1).unsqueeze(0), torch.tensor([wavform.shape[1]]))
+            fbank = fbank.squeeze(0).numpy()
+            mean_stats += np.sum(fbank, axis=0)
+            var_stats += np.sum(np.square(fbank), axis=0)
+            total_frames += fbank.shape[0]
+
+    cmvn_info = {
+        'mean_stats': list(mean_stats.tolist()),
+        'var_stats': list(var_stats.tolist()),
+        'total_frames': total_frames
+    }
+    with open(cmvn_file, 'w') as fout:
+        fout.write(json.dumps(cmvn_info))
+
+
+if __name__ == '__main__':
+    main()
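The JSON written above holds raw sums rather than a finished normalizer; merging the per-split statistics is left to `utils/combine_cmvn_file.py`. A hedged sketch of how such accumulated stats become a global mean/variance normalizer:

```python
import json

import numpy as np


def load_cmvn(json_path):
    """Turn accumulated sums (mean_stats, var_stats, total_frames) into
    per-dimension mean and standard deviation."""
    with open(json_path) as f:
        info = json.load(f)
    mean = np.array(info["mean_stats"]) / info["total_frames"]
    var = np.array(info["var_stats"]) / info["total_frames"] - mean ** 2
    return mean, np.sqrt(np.maximum(var, 1e-20))


def apply_cmvn(feats, mean, std):
    # feats: (num_frames, dim) -> zero mean, unit variance per dimension
    return (feats - mean) / std
```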

+ 39 - 0
egs/alimeeting/sa_asr/local/compute_cmvn.sh

@@ -0,0 +1,39 @@
+#!/usr/bin/env bash
+
+. ./path.sh || exit 1;
+# Begin configuration section.
+fbankdir=
+nj=32
+cmd=./utils/run.pl
+feats_dim=80
+config_file=
+scale=1.0
+
+echo "$0 $@"
+
+. utils/parse_options.sh || exit 1;
+
+# shellcheck disable=SC2046
+head -n $(awk -v lines="$(wc -l < ${fbankdir}/wav.scp)" -v scale="$scale" 'BEGIN { printf "%.0f\n", lines*scale }') ${fbankdir}/wav.scp > ${fbankdir}/wav.scp.scale
+
+split_dir=${fbankdir}/cmvn/split_${nj};
+mkdir -p $split_dir
+split_scps=""
+for n in $(seq $nj); do
+    split_scps="$split_scps $split_dir/wav.$n.scp"
+done
+utils/split_scp.pl ${fbankdir}/wav.scp.scale $split_scps || exit 1;
+
+logdir=${fbankdir}/cmvn/log
+$cmd JOB=1:$nj $logdir/cmvn.JOB.log \
+    python local/compute_cmvn.py \
+      --dim ${feats_dim} \
+      --wav_path $split_dir \
+      --config_file $config_file \
+      --idx JOB \
+
+python utils/combine_cmvn_file.py --dim ${feats_dim} --cmvn_dir $split_dir --nj $nj --output_dir ${fbankdir}/cmvn
+
+python utils/cmvn_converter.py --cmvn_json ${fbankdir}/cmvn/cmvn.json --am_mvn ${fbankdir}/cmvn/am.mvn
+
+echo "$0: Succeeded compute global cmvn"

+ 0 - 0
egs/alimeeting/sa-asr/local/compute_cpcer.py → egs/alimeeting/sa_asr/local/compute_cpcer.py


+ 29 - 0
egs/alimeeting/sa_asr/local/convert_model.py

@@ -0,0 +1,29 @@
+import codecs
+import pdb
+import sys
+import torch
+
+char1 = sys.argv[1]
+char2 = sys.argv[2]
+model1 = torch.load(sys.argv[3], map_location='cpu')
+model2_path = sys.argv[4]
+
+d_new = model1
+char1_list = []
+map_list = []
+
+
+with codecs.open(char1) as f:
+    for line in f.readlines():
+        char1_list.append(line.strip())
+
+with codecs.open(char2) as f:
+    for line in f.readlines():
+        map_list.append(char1_list.index(line.strip()))
+print(map_list)
+
+for k, v in d_new.items():
+    if k == 'ctc.ctc_lo.weight' or k == 'ctc.ctc_lo.bias' or k == 'decoder.output_layer.weight' or k == 'decoder.output_layer.bias' or k == 'decoder.embed.0.weight':
+        d_new[k] = v[map_list]
+    
+torch.save(d_new, model2_path)
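In miniature, `local/convert_model.py` above reindexes the rows of vocabulary-dependent tensors (CTC output layer, decoder output layer, decoder embedding) from an old token order to a new one. A toy illustration with made-up token lists and tensors, not the script itself:

```python
import torch

old_tokens = ["<blank>", "a", "b", "c"]
new_tokens = ["<blank>", "c", "a", "b"]  # same symbols, new order
map_list = [old_tokens.index(t) for t in new_tokens]

output_layer = torch.arange(8).reshape(4, 2).float()  # (vocab, dim)
remapped = output_layer[map_list]  # row i now corresponds to new_tokens[i]
print(map_list)     # [0, 3, 1, 2]
print(remapped[1])  # the row that previously scored "c"
```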

+ 0 - 0
egs/alimeeting/sa-asr/local/copy_data_dir.sh → egs/alimeeting/sa_asr/local/copy_data_dir.sh


+ 0 - 0
egs/alimeeting/sa-asr/local/data/get_reco2dur.sh → egs/alimeeting/sa_asr/local/data/get_reco2dur.sh


+ 0 - 0
egs/alimeeting/sa-asr/local/data/get_segments_for_data.sh → egs/alimeeting/sa_asr/local/data/get_segments_for_data.sh


+ 0 - 0
egs/alimeeting/sa-asr/local/data/get_utt2dur.sh → egs/alimeeting/sa_asr/local/data/get_utt2dur.sh


+ 0 - 0
egs/alimeeting/sa-asr/local/data/split_data.sh → egs/alimeeting/sa_asr/local/data/split_data.sh


+ 105 - 0
egs/alimeeting/sa_asr/local/download_and_untar.sh

@@ -0,0 +1,105 @@
+#!/usr/bin/env bash
+
+# Copyright   2014  Johns Hopkins University (author: Daniel Povey)
+#             2017  Xingyu Na
+# Apache 2.0
+
+remove_archive=false
+
+if [ "$1" == --remove-archive ]; then
+  remove_archive=true
+  shift
+fi
+
+if [ $# -ne 3 ]; then
+  echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
+  echo "e.g.: $0 /export/a05/xna/data www.openslr.org/resources/33 data_aishell"
+  echo "With --remove-archive it will remove the archive after successfully un-tarring it."
+  echo "<corpus-part> can be one of: data_aishell, resource_aishell."
+fi
+fi
+data=$1
+url=$2
+part=$3
+
+if [ ! -d "$data" ]; then
+  echo "$0: no such directory $data"
+  exit 1;
+fi
+
+part_ok=false
+list="data_aishell resource_aishell"
+for x in $list; do
+  if [ "$part" == $x ]; then part_ok=true; fi
+done
+if ! $part_ok; then
+  echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
+  exit 1;
+fi
+
+if [ -z "$url" ]; then
+  echo "$0: empty URL base."
+  exit 1;
+fi
+
+if [ -f $data/$part/.complete ]; then
+  echo "$0: data part $part was already successfully extracted, nothing to do."
+  exit 0;
+fi
+
+# sizes of the archive files in bytes.
+sizes="15582913665 1246920"
+
+if [ -f $data/$part.tgz ]; then
+  size=$(/bin/ls -l $data/$part.tgz | awk '{print $5}')
+  size_ok=false
+  for s in $sizes; do if [ $s == $size ]; then size_ok=true; fi; done
+  if ! $size_ok; then
+    echo "$0: removing existing file $data/$part.tgz because its size in bytes $size"
+    echo "does not equal the size of one of the archives."
+    rm $data/$part.tgz
+  else
+    echo "$data/$part.tgz exists and appears to be complete."
+  fi
+fi
+
+if [ ! -f $data/$part.tgz ]; then
+  if ! command -v wget >/dev/null; then
+    echo "$0: wget is not installed."
+    exit 1;
+  fi
+  full_url=$url/$part.tgz
+  echo "$0: downloading data from $full_url.  This may take some time, please be patient."
+
+  cd $data || exit 1
+  if ! wget --no-check-certificate $full_url; then
+    echo "$0: error executing wget $full_url"
+    exit 1;
+  fi
+fi
+
+cd $data || exit 1
+
+if ! tar -xvzf $part.tgz; then
+  echo "$0: error un-tarring archive $data/$part.tgz"
+  exit 1;
+fi
+
+touch $data/$part/.complete
+
+if [ $part == "data_aishell" ]; then
+  cd $data/$part/wav || exit 1
+  for wav in ./*.tar.gz; do
+    echo "Extracting wav from $wav"
+    tar -zxf $wav && rm $wav
+  done
+fi
+
+echo "$0: Successfully downloaded and un-tarred $data/$part.tgz"
+
+if $remove_archive; then
+  echo "$0: removing $data/$part.tgz file since --remove-archive option was supplied."
+  rm $data/$part.tgz
+fi
+
+exit 0;

+ 0 - 0
egs/alimeeting/sa-asr/local/download_pretrained_model_from_modelscope.py → egs/alimeeting/sa_asr/local/download_pretrained_model_from_modelscope.py


+ 0 - 0
egs/alimeeting/sa-asr/local/download_xvector_model.py → egs/alimeeting/sa_asr/local/download_xvector_model.py


+ 0 - 0
egs/alimeeting/sa-asr/local/filter_utt2spk_all_fifo.py → egs/alimeeting/sa_asr/local/filter_utt2spk_all_fifo.py


+ 0 - 0
egs/alimeeting/sa-asr/local/fix_data_dir.sh → egs/alimeeting/sa_asr/local/fix_data_dir.sh


+ 0 - 0
egs/alimeeting/sa-asr/local/format_wav_scp.py → egs/alimeeting/sa_asr/local/format_wav_scp.py


+ 0 - 0
egs/alimeeting/sa-asr/local/format_wav_scp.sh → egs/alimeeting/sa_asr/local/format_wav_scp.sh


+ 3 - 3
egs/alimeeting/sa-asr/local/gen_cluster_profile_infer.py → egs/alimeeting/sa_asr/local/gen_cluster_profile_infer.py

@@ -63,7 +63,7 @@ if __name__ == "__main__":
     wav_scp_file = open(path+'/wav.scp', 'r')
     wav_scp = wav_scp_file.readlines()
     wav_scp_file.close()
-    raw_meeting_scp_file = open(raw_path + '/wav_raw.scp', 'r')
+    raw_meeting_scp_file = open(raw_path + '/wav.scp', 'r')
     raw_meeting_scp = raw_meeting_scp_file.readlines()
     raw_meeting_scp_file.close()
     segments_scp_file = open(raw_path + '/segments', 'r')
@@ -92,8 +92,8 @@ if __name__ == "__main__":
     cluster_spk_num_file = open(path + '/cluster_spk_num', 'w')
     meeting_map = {}
     for line in raw_meeting_scp:
-        meeting = line.strip().split('\t')[0]
-        wav_path = line.strip().split('\t')[1]
+        meeting = line.strip().split(' ')[0]
+        wav_path = line.strip().split(' ')[1]
         wav = soundfile.read(wav_path)[0]
         # take the first channel
         if wav.ndim == 2:

+ 3 - 3
egs/alimeeting/sa-asr/local/gen_oracle_embedding.py → egs/alimeeting/sa_asr/local/gen_oracle_embedding.py

@@ -9,7 +9,7 @@ import soundfile
 if __name__=="__main__":
     path = sys.argv[1] # dump2/raw/Eval_Ali_far
     raw_path = sys.argv[2] # data/local/Eval_Ali_far_correct_single_speaker
-    raw_meeting_scp_file = open(raw_path + '/wav_raw.scp', 'r')
+    raw_meeting_scp_file = open(raw_path + '/wav.scp', 'r')
     raw_meeting_scp = raw_meeting_scp_file.readlines()
     raw_meeting_scp_file.close()
     segments_scp_file = open(raw_path + '/segments', 'r')
@@ -22,8 +22,8 @@ if __name__=="__main__":
 
     raw_wav_map = {}
     for line in raw_meeting_scp:
-        meeting = line.strip().split('\t')[0]
-        wav_path = line.strip().split('\t')[1]
+        meeting = line.strip().split(' ')[0]
+        wav_path = line.strip().split(' ')[1]
         raw_wav_map[meeting] = wav_path
     
     spk_map = {}

+ 0 - 0
egs/alimeeting/sa-asr/local/gen_oracle_profile_nopadding.py → egs/alimeeting/sa_asr/local/gen_oracle_profile_nopadding.py


+ 2 - 2
egs/alimeeting/sa-asr/local/gen_oracle_profile_padding.py → egs/alimeeting/sa_asr/local/gen_oracle_profile_padding.py

@@ -5,7 +5,7 @@ import sys
 
 
 if __name__=="__main__":
-    path = sys.argv[1] # dump2/raw/Train_Ali_far
+    path = sys.argv[1] 
     wav_scp_file = open(path+"/wav.scp", 'r')
     wav_scp = wav_scp_file.readlines()
     wav_scp_file.close()
@@ -29,7 +29,7 @@ if __name__=="__main__":
         line_list = line.strip().split(' ')
         meeting = line_list[0].split('-')[0]
         spk_id = line_list[0].split('-')[-1].split('_')[-1]
-        spk = meeting+'_' + spk_id
+        spk = meeting + '_' + spk_id
         global_spk_list.append(spk)
         if meeting in meeting_map_tmp.keys():
             meeting_map_tmp[meeting].append(spk)

+ 0 - 0
egs/alimeeting/sa-asr/local/perturb_data_dir_speed.sh → egs/alimeeting/sa_asr/local/perturb_data_dir_speed.sh


+ 1 - 2
egs/alimeeting/sa-asr/local/process_sot_fifo_textchar2spk.py → egs/alimeeting/sa_asr/local/process_sot_fifo_textchar2spk.py

@@ -30,8 +30,7 @@ def main(args):
     meetingid_map = {}
     for line in spk2utt:
         spkid = line.strip().split(" ")[0]
-        meeting_id_list = spkid.split("_")[:3]
-        meeting_id = meeting_id_list[0] + "_" + meeting_id_list[1] + "_" + meeting_id_list[2]
+        meeting_id = spkid.split("-")[0]
         if meeting_id not in meetingid_map:
             meetingid_map[meeting_id] = 1     
         else:

+ 0 - 0
egs/alimeeting/sa-asr/local/process_text_id.py → egs/alimeeting/sa_asr/local/process_text_id.py


+ 0 - 0
egs/alimeeting/sa-asr/local/process_text_spk_merge.py → egs/alimeeting/sa_asr/local/process_text_spk_merge.py


+ 0 - 0
egs/alimeeting/sa-asr/local/process_textgrid_to_single_speaker_wav.py → egs/alimeeting/sa_asr/local/process_textgrid_to_single_speaker_wav.py


+ 0 - 0
egs/alimeeting/sa-asr/local/spk2utt_to_utt2spk.pl → egs/alimeeting/sa_asr/local/spk2utt_to_utt2spk.pl


+ 0 - 0
egs/alimeeting/sa-asr/local/text_format.pl → egs/alimeeting/sa_asr/local/text_format.pl


+ 0 - 0
egs/alimeeting/sa-asr/local/text_normalize.pl → egs/alimeeting/sa_asr/local/text_normalize.pl


+ 0 - 0
egs/alimeeting/sa-asr/local/utt2spk_to_spk2utt.pl → egs/alimeeting/sa_asr/local/utt2spk_to_spk2utt.pl


+ 0 - 0
egs/alimeeting/sa-asr/local/validate_data_dir.sh → egs/alimeeting/sa_asr/local/validate_data_dir.sh


+ 0 - 0
egs/alimeeting/sa-asr/local/validate_text.pl → egs/alimeeting/sa_asr/local/validate_text.pl


+ 6 - 0
egs/alimeeting/sa_asr/path.sh

@@ -0,0 +1,6 @@
+export FUNASR_DIR=$PWD/../../..
+
+# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+export PATH=$FUNASR_DIR/funasr/bin:./utils:$FUNASR_DIR:$PATH
+export PYTHONPATH=$FUNASR_DIR:$PYTHONPATH

+ 435 - 0
egs/alimeeting/sa_asr/run.sh

@@ -0,0 +1,435 @@
+#!/usr/bin/env bash
+
+. ./path.sh || exit 1;
+
+# machines configuration
+CUDA_VISIBLE_DEVICES="6,7"
+gpu_num=2
+count=1
+gpu_inference=true  # Whether to perform gpu decoding, set false for cpu decoding
+# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
+njob=8
+train_cmd=utils/run.pl
+infer_cmd=utils/run.pl
+
+# general configuration
+feats_dir="data" #feature output dictionary
+exp_dir="exp"
+lang=zh
+token_type=char
+type=sound
+scp=wav.scp
+speed_perturb="1.0"
+min_wav_duration=0.1
+max_wav_duration=20
+profile_modes="cluster oracle"
+stage=7
+stop_stage=7
+
+# feature configuration
+feats_dim=80
+nj=32
+
+# data
+raw_data=
+data_url=
+
+# exp tag
+tag=""
+
+. utils/parse_options.sh || exit 1;
+
+# Set bash to 'debug' mode, it will exit on :
+# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
+set -e
+set -u
+set -o pipefail
+
+train_set=Train_Ali_far
+valid_set=Eval_Ali_far
+test_sets="Test_Ali_far Eval_Ali_far"
+test_2023="Test_2023_Ali_far_release"
+
+asr_config=conf/train_asr_conformer.yaml
+sa_asr_config=conf/train_sa_asr_conformer.yaml
+asr_model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
+sa_asr_model_dir="baseline_$(basename "${sa_asr_config}" .yaml)_${lang}_${token_type}_${tag}"
+inference_config=conf/decode_asr_rnn.yaml
+inference_sa_asr_model=valid.acc_spk.ave.pb
+
+# you can set gpu num for decoding here
+gpuid_list=$CUDA_VISIBLE_DEVICES  # set gpus for decoding, the same as training stage by default
+ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')
+
+if ${gpu_inference}; then
+    inference_nj=$[${ngpu}*${njob}]
+    _ngpu=1
+else
+    inference_nj=$njob
+    _ngpu=0
+fi
+
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    echo "stage 0: Data preparation"
+    # Data preparation
+    ./local/alimeeting_data_prep.sh --tgt Test --min_wav_duration $min_wav_duration --max_wav_duration $max_wav_duration
+    ./local/alimeeting_data_prep.sh --tgt Eval --min_wav_duration $min_wav_duration --max_wav_duration $max_wav_duration
+    ./local/alimeeting_data_prep.sh --tgt Train --min_wav_duration $min_wav_duration --max_wav_duration $max_wav_duration
+    # remove long/short data
+    for x in ${train_set} ${valid_set} ${test_sets}; do
+        cp -r ${feats_dir}/org/${x} ${feats_dir}/${x}
+        rm ${feats_dir}/"${x}"/wav.scp ${feats_dir}/"${x}"/segments
+        local/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \
+            --audio-format wav --segments ${feats_dir}/org/${x}/segments \
+            "${feats_dir}/org/${x}/${scp}" "${feats_dir}/${x}"
+        _min_length=$(python3 -c "print(int(${min_wav_duration} * 16000))")
+        _max_length=$(python3 -c "print(int(${max_wav_duration} * 16000))")
+        <"${feats_dir}/${x}/utt2num_samples" \
+        awk '{if($2 > '$_min_length' && $2 < '$_max_length')print $0;}' \
+            >"${feats_dir}/${x}/utt2num_samples_rmls"
+        mv ${feats_dir}/${x}/utt2num_samples_rmls ${feats_dir}/${x}/utt2num_samples
+        <"${feats_dir}/${x}/wav.scp" \
+            utils/filter_scp.pl "${feats_dir}/${x}/utt2num_samples"  \
+            >"${feats_dir}/${x}/wav.scp_rmls"
+        mv ${feats_dir}/${x}/wav.scp_rmls ${feats_dir}/${x}/wav.scp
+        <"${feats_dir}/${x}/text" \
+            awk '{ if( NF != 1 ) print $0; }' >"${feats_dir}/${x}/text_rmblank"
+        mv ${feats_dir}/${x}/text_rmblank ${feats_dir}/${x}/text
+        local/fix_data_dir.sh "${feats_dir}/${x}"
+        <"${feats_dir}/${x}/utt2spk_all_fifo" \
+            utils/filter_scp.pl "${feats_dir}/${x}/text"  \
+            >"${feats_dir}/${x}/utt2spk_all_fifo_rmls"
+        mv "${feats_dir}/${x}/utt2spk_all_fifo_rmls" "${feats_dir}/${x}/utt2spk_all_fifo"
+    done
+    for x in ${test_2023}; do
+        local/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \
+            --audio-format wav --segments ${feats_dir}/org/${x}/segments \
+            "${feats_dir}/org/${x}/${scp}" "${feats_dir}/${x}"
+        cut -d " " -f1 ${feats_dir}/${x}/wav.scp > ${feats_dir}/${x}/uttid
+        paste -d " " ${feats_dir}/${x}/uttid ${feats_dir}/${x}/uttid > ${feats_dir}/${x}/utt2spk
+        cp ${feats_dir}/${x}/utt2spk ${feats_dir}/${x}/spk2utt
+    done
+fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    echo "stage 1: Speaker profile and CMVN Generation"
+    
+    mkdir -p "profile_log"
+    for x in "${train_set}" "${valid_set}" "${test_sets}"; do
+        # generate text_id spk2id
+        python local/process_sot_fifo_textchar2spk.py --path ${feats_dir}/${x}
+        echo "Successfully generate ${feats_dir}/${x}/text_id ${feats_dir}/${x}/spk2id"
+        # generate text_id_train for sot
+        python local/process_text_id.py ${feats_dir}/${x}
+        echo "Successfully generate ${feats_dir}/${x}/text_id_train"
+        # generate oracle_embedding from single-speaker audio segment
+        echo "oracle_embedding is being generated in the background, and the log is profile_log/gen_oracle_embedding_${x}.log"
+        python local/gen_oracle_embedding.py "${feats_dir}/${x}" "data/org/${x}_single_speaker" &> "profile_log/gen_oracle_embedding_${x}.log"
+        echo "Successfully generate oracle embedding for ${x} (${feats_dir}/${x}/oracle_embedding.scp)"
+        # generate oracle_profile and cluster_profile from oracle_embedding and cluster_embedding (padding the speaker during training)
+        if [ "${x}" = "${train_set}" ]; then
+            python local/gen_oracle_profile_padding.py ${feats_dir}/${x}
+            echo "Successfully generate oracle profile for ${x} (${feats_dir}/${x}/oracle_profile_padding.scp)"
+        else
+            python local/gen_oracle_profile_nopadding.py ${feats_dir}/${x}
+            echo "Successfully generate oracle profile for ${x} (${feats_dir}/${x}/oracle_profile_nopadding.scp)"
+        fi
+        # generate cluster_profile with spectral-cluster directly (for infering and without oracle information)
+        if [ "${x}" = "${valid_set}" ] || [ "${x}" = "${test_sets}" ]; then
+            echo "cluster_profile is being generated in the background, and the log is profile_log/gen_cluster_profile_infer_${x}.log"
+            python local/gen_cluster_profile_infer.py "${feats_dir}/${x}" "${feats_dir}/org/${x}" 0.996 0.815 &> "profile_log/gen_cluster_profile_infer_${x}.log"
+            echo "Successfully generate cluster profile for ${x} (${feats_dir}/${x}/cluster_profile_infer.scp)"
+        fi
+        # compute CMVN
+        if [ "${x}" = "${train_set}" ]; then
+            local/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --fbankdir ${feats_dir}/${train_set} --feats_dim ${feats_dim} --config_file "$asr_config" --scale 1.0
+        fi
+    done
+
+    for x in "${test_2023}"; do
+        # generate cluster_profile with spectral-cluster directly (for infering and without oracle information)
+        python local/gen_cluster_profile_infer.py "${feats_dir}/${x}" "${feats_dir}/org/${x}" 0.996 0.815 &> "profile_log/gen_cluster_profile_infer_${x}.log"
+        echo "Successfully generate cluster profile for ${x} (${feats_dir}/${x}/cluster_profile_infer.scp)"
+    done
+fi
+
+token_list=${feats_dir}/${lang}_token_list/char/tokens.txt
+echo "dictionary: ${token_list}"
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    echo "stage 2: Dictionary Preparation"
+    mkdir -p ${feats_dir}/${lang}_token_list/char/
+
+    echo "make a dictionary"
+    echo "<blank>" > ${token_list}
+    echo "<s>" >> ${token_list}
+    echo "</s>" >> ${token_list}
+    utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/$train_set/text | cut -f 2- -d" " | tr " " "\n" \
+        | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
+    echo "<unk>" >> ${token_list}
+fi
+
+# LM Training Stage
+world_size=$gpu_num  # run on one machine
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+    echo "stage 3: LM Training"
+fi
+
+# ASR Training Stage
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+    echo "Stage 4: ASR Training"
+    asr_exp=${exp_dir}/${asr_model_dir}
+    mkdir -p ${asr_exp}
+    mkdir -p ${asr_exp}/log
+    INIT_FILE=${asr_exp}/ddp_init
+    if [ -f $INIT_FILE ];then
+        rm -f $INIT_FILE
+    fi 
+    init_method=file://$(readlink -f $INIT_FILE)
+    echo "$0: init method is $init_method"
+    for ((i = 0; i < $ngpu; ++i)); do
+        {
+            # i=0
+            rank=$i
+            local_rank=$i
+            gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
+            train.py \
+                --task_name asr \
+                --model asr \
+                --gpu_id $gpu_id \
+                --use_preprocessor true \
+                --split_with_space false \
+                --token_type char \
+                --token_list $token_list \
+                --data_dir ${feats_dir} \
+                --train_set ${train_set} \
+                --valid_set ${valid_set} \
+                --data_file_names "wav.scp,text" \
+                --cmvn_file ${feats_dir}/${train_set}/cmvn/cmvn.mvn \
+                --speed_perturb ${speed_perturb} \
+                --resume true \
+                --output_dir ${exp_dir}/${asr_model_dir} \
+                --config $asr_config \
+                --ngpu $gpu_num \
+                --num_worker_count $count \
+                --dist_init_method $init_method \
+                --dist_world_size $world_size \
+                --dist_rank $rank \
+                --local_rank $local_rank 1> ${exp_dir}/${asr_model_dir}/log/train.log.$i 2>&1
+        } &
+    done
+    wait
+
+fi
+
+
+
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+    echo "SA-ASR training"
+    asr_exp=${exp_dir}/${asr_model_dir}
+    sa_asr_exp=${exp_dir}/${sa_asr_model_dir}
+    mkdir -p ${sa_asr_exp}
+    mkdir -p ${sa_asr_exp}/log
+    INIT_FILE=${sa_asr_exp}/ddp_init
+    if [ ! -L ${feats_dir}/${train_set}/profile.scp ]; then
+        ln -sr ${feats_dir}/${train_set}/oracle_profile_padding.scp ${feats_dir}/${train_set}/profile.scp
+        ln -sr ${feats_dir}/${valid_set}/oracle_profile_nopadding.scp ${feats_dir}/${valid_set}/profile.scp
+    fi
+    
+    if [ ! -f "${exp_dir}/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth" ]; then
+        # download xvector extractor model file
+        python local/download_xvector_model.py ${exp_dir}
+        echo "Successfully download the pretrained xvector extractor to exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth"
+    fi
+    
+    if [ -f $INIT_FILE ];then
+        rm -f $INIT_FILE
+    fi 
+    init_method=file://$(readlink -f $INIT_FILE)
+    echo "$0: init method is $init_method"
+    for ((i = 0; i < $ngpu; ++i)); do
+        {
+            rank=$i
+            local_rank=$i
+            gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
+	        train.py \
+                --task_name asr \
+                --model sa_asr \
+                --gpu_id $gpu_id \
+                --use_preprocessor true \
+                --split_with_space false \
+                --unused_parameters true \
+                --token_type char \
+                --resume true \
+                --token_list $token_list \
+                --data_dir ${feats_dir} \
+                --train_set ${train_set} \
+                --valid_set ${valid_set} \
+                --data_file_names "wav.scp,text,profile.scp,text_id_train" \
+                --cmvn_file ${feats_dir}/${train_set}/cmvn/cmvn.mvn \
+                --speed_perturb ${speed_perturb} \
+                --init_param "${asr_exp}/valid.acc.ave.pb:encoder:asr_encoder"   \
+                --init_param "${asr_exp}/valid.acc.ave.pb:ctc:ctc"   \
+                --init_param "${asr_exp}/valid.acc.ave.pb:decoder.embed:decoder.embed" \
+                --init_param "${asr_exp}/valid.acc.ave.pb:decoder.output_layer:decoder.asr_output_layer" \
+                --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.0.self_attn:decoder.decoder1.self_attn" \
+                --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.0.src_attn:decoder.decoder3.src_attn" \
+                --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.0.feed_forward:decoder.decoder3.feed_forward" \
+                --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.1:decoder.decoder4.0" \
+                --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.2:decoder.decoder4.1" \
+                --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.3:decoder.decoder4.2" \
+                --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.4:decoder.decoder4.3" \
+                --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.5:decoder.decoder4.4" \
+                --init_param "exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth:encoder:spk_encoder"   \
+                --init_param "exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth:decoder:spk_encoder:decoder.output_dense"   \
+                --output_dir ${exp_dir}/${sa_asr_model_dir} \
+                --config $sa_asr_config \
+                --ngpu $gpu_num \
+                --num_worker_count $count \
+                --dist_init_method $init_method \
+                --dist_world_size $world_size \
+                --dist_rank $rank \
+                --local_rank $local_rank 1> ${exp_dir}/${sa_asr_model_dir}/log/train.log.$i 2>&1
+        } &
+    done
+    wait
+fi
+                
+
+if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
+    echo "stage 6: Inference test sets"
+    for x in ${test_sets}; do
+        for profile_mode in ${profile_modes}; do
+            echo "decoding ${x} with ${profile_mode} profile"
+            sa_asr_exp=${exp_dir}/${sa_asr_model_dir}
+            inference_tag="$(basename "${inference_config}" .yaml)"
+            _dir="${sa_asr_exp}/${inference_tag}_${profile_mode}/${inference_sa_asr_model}/${x}"
+            _logdir="${_dir}/logdir"
+            if [ -d ${_dir} ]; then
+                echo "${_dir} already exists. If you want to decode again, please delete this directory first."
+                exit 0
+            fi
+            mkdir -p "${_logdir}"
+            _data="${feats_dir}/${x}"
+            key_file=${_data}/${scp}
+            num_scp_file="$(<${key_file} wc -l)"
+            _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
+            split_scps=
+            for n in $(seq "${_nj}"); do
+                split_scps+=" ${_logdir}/keys.${n}.scp"
+            done
+            # shellcheck disable=SC2086
+            utils/split_scp.pl "${key_file}" ${split_scps}
+            _opts=
+            if [ -n "${inference_config}" ]; then
+                _opts+="--config ${inference_config} "
+            fi
+            if [ $profile_mode = "oracle" ]; then
+                profile_scp=${profile_mode}_profile_nopadding.scp
+            else
+                profile_scp=${profile_mode}_profile_infer.scp
+            fi
+            ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
+                python -m funasr.bin.asr_inference_launch \
+                    --batch_size 1 \
+                    --mc True \
+                    --ngpu "${_ngpu}" \
+                    --njob ${njob} \
+                    --nbest 1 \
+                    --gpuid_list ${gpuid_list} \
+                    --allow_variable_data_keys true \
+                    --cmvn_file ${feats_dir}/${train_set}/cmvn/cmvn.mvn \
+                    --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+                    --data_path_and_name_and_type "${_data}/$profile_scp,profile,npy" \
+                    --key_file "${_logdir}"/keys.JOB.scp \
+                    --asr_train_config "${sa_asr_exp}"/config.yaml \
+                    --asr_model_file "${sa_asr_exp}"/"${inference_sa_asr_model}" \
+                    --output_dir "${_logdir}"/output.JOB \
+                    --mode sa_asr \
+                    ${_opts}
+
+            for f in token token_int score text text_id; do
+                if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then
+                    for i in $(seq "${_nj}"); do
+                        cat "${_logdir}/output.${i}/1best_recog/${f}"
+                    done | sort -k1 >"${_dir}/${f}"
+                fi
+            done
+            sed 's/\$//g' ${_data}/text > ${_data}/text_nosrc
+            sed 's/\$//g' ${_dir}/text > ${_dir}/text_nosrc
+            python utils/proce_text.py ${_data}/text_nosrc ${_data}/text.proc
+            python utils/proce_text.py ${_dir}/text_nosrc ${_dir}/text.proc
+
+            python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
+            tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
+            cat ${_dir}/text.cer.txt
+
+            python local/process_text_spk_merge.py ${_dir}
+            python local/process_text_spk_merge.py ${_data}
+            
+            python local/compute_cpcer.py ${_data}/text_spk_merge ${_dir}/text_spk_merge ${_dir}/text.cpcer
+            tail -n 1 ${_dir}/text.cpcer > ${_dir}/text.cpcer.txt
+            cat ${_dir}/text.cpcer.txt
+        done
+    done
+fi
+
+if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
+    echo "stage 7: Inference test 2023"
+    for x in ${test_2023}; do
+        sa_asr_exp=${exp_dir}/${sa_asr_model_dir}
+        inference_tag="$(basename "${inference_config}" .yaml)"
+        _dir="${sa_asr_exp}/${inference_tag}_cluster/${inference_sa_asr_model}/${x}"
+        _logdir="${_dir}/logdir"
+        if [ -d ${_dir} ]; then
+            echo "${_dir} already exists. If you want to decode again, please delete this directory first."
+            exit 0
+        fi
+        mkdir -p "${_logdir}"
+        _data="${feats_dir}/${x}"
+        key_file=${_data}/${scp}
+        num_scp_file="$(<${key_file} wc -l)"
+        _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
+        split_scps=
+        for n in $(seq "${_nj}"); do
+            split_scps+=" ${_logdir}/keys.${n}.scp"
+        done
+        # shellcheck disable=SC2086
+        utils/split_scp.pl "${key_file}" ${split_scps}
+        _opts=
+        if [ -n "${inference_config}" ]; then
+            _opts+="--config ${inference_config} "
+        fi
+        ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
+            python -m funasr.bin.asr_inference_launch \
+                --batch_size 1 \
+                --mc True \
+                --ngpu "${_ngpu}" \
+                --njob ${njob} \
+                --nbest 1 \
+                --gpuid_list ${gpuid_list} \
+                --allow_variable_data_keys true \
+                --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+                --data_path_and_name_and_type "${_data}/cluster_profile_infer.scp,profile,npy" \
+                --cmvn_file ${feats_dir}/${train_set}/cmvn/cmvn.mvn \
+                --key_file "${_logdir}"/keys.JOB.scp \
+                --asr_train_config "${sa_asr_exp}"/config.yaml \
+                --asr_model_file "${sa_asr_exp}"/"${inference_sa_asr_model}" \
+                --output_dir "${_logdir}"/output.JOB \
+                --mode sa_asr \
+                ${_opts}
+
+        for f in token token_int score text text_id; do
+            if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then
+                for i in $(seq "${_nj}"); do
+                    cat "${_logdir}/output.${i}/1best_recog/${f}"
+                done | sort -k1 >"${_dir}/${f}"
+            fi
+        done
+
+        python local/process_text_spk_merge.py ${_dir}
+
+    done
+fi
+
+

+ 0 - 0
egs/alimeeting/sa-asr/utils → egs/alimeeting/sa_asr/utils


+ 0 - 0
egs/alimeeting/sa-asr/README.md → egs/alimeeting/sa_asr_deprecated/README.md


+ 0 - 0
egs/alimeeting/sa-asr/asr_local.sh → egs/alimeeting/sa_asr_deprecated/asr_local.sh


+ 0 - 0
egs/alimeeting/sa-asr/asr_local_m2met_2023_infer.sh → egs/alimeeting/sa_asr_deprecated/asr_local_m2met_2023_infer.sh


+ 6 - 0
egs/alimeeting/sa_asr_deprecated/conf/decode_asr_rnn.yaml

@@ -0,0 +1,6 @@
+beam_size: 20
+penalty: 0.0
+maxlenratio: 0.0
+minlenratio: 0.0
+ctc_weight: 0.6
+lm_weight: 0.3
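For orientation, these weights follow the standard hybrid CTC/attention beam search. A hedged sketch of the per-hypothesis score in the log domain, assuming the usual ESPnet-style combination rather than quoting FunASR's decoder internals:

```python
def joint_score(log_p_ctc, log_p_att, log_p_lm,
                ctc_weight=0.6, lm_weight=0.3):
    """Score of a partial hypothesis under hybrid CTC/attention decoding
    with shallow LM fusion (weights as in the config above)."""
    return (ctc_weight * log_p_ctc
            + (1.0 - ctc_weight) * log_p_att
            + lm_weight * log_p_lm)
```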

+ 0 - 0
egs/alimeeting/sa-asr/conf/train_asr_conformer.yaml → egs/alimeeting/sa_asr_deprecated/conf/train_asr_conformer.yaml


+ 0 - 0
egs/alimeeting/sa-asr/conf/train_sa_asr_conformer.yaml → egs/alimeeting/sa_asr_deprecated/conf/train_sa_asr_conformer.yaml


+ 1 - 0
egs/alimeeting/sa_asr_deprecated/local

@@ -0,0 +1 @@
+../sa_asr/local/

+ 0 - 0
egs/alimeeting/sa-asr/path.sh → egs/alimeeting/sa_asr_deprecated/path.sh


+ 0 - 0
egs/alimeeting/sa-asr/run.sh → egs/alimeeting/sa_asr_deprecated/run.sh


+ 0 - 0
egs/alimeeting/sa-asr/run_m2met_2023_infer.sh → egs/alimeeting/sa_asr_deprecated/run_m2met_2023_infer.sh


+ 1 - 0
egs/alimeeting/sa_asr_deprecated/utils

@@ -0,0 +1 @@
+../../aishell/transformer/utils

+ 6 - 4
funasr/bin/asr_infer.py

@@ -1651,15 +1651,17 @@ class Speech2TextSAASR:
         assert check_argument_types()
         
         # 1. Build ASR model
-        from funasr.tasks.sa_asr import ASRTask
+        from funasr.tasks.asr import ASRTaskSAASR
         scorers = {}
-        asr_model, asr_train_args = ASRTask.build_model_from_file(
+        asr_model, asr_train_args = ASRTaskSAASR.build_model_from_file(
             asr_train_config, asr_model_file, cmvn_file, device
         )
         frontend = None
         if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
-            if asr_train_args.frontend == 'wav_frontend':
-                frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
+            from funasr.tasks.sa_asr import frontend_choices
+            if asr_train_args.frontend == 'wav_frontend' or asr_train_args.frontend == "multichannelfrontend":
+                frontend_class = frontend_choices.get_class(asr_train_args.frontend)
+                frontend = frontend_class(cmvn_file=cmvn_file, **asr_train_args.frontend_conf).eval()
             else:
                 frontend_class = frontend_choices.get_class(asr_train_args.frontend)
                 frontend = frontend_class(**asr_train_args.frontend_conf).eval()
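The hunk above replaces the hard-coded WavFrontend import with a registry lookup, so both wav_frontend and multichannelfrontend resolve through the same path. A simplified stand-in for how a ClassChoices-style registry dispatches (illustrative only, not funasr's actual class):

class SimpleChoices:
    """Toy name-to-class registry in the spirit of ClassChoices."""
    def __init__(self, classes, default=None):
        self.classes = classes
        self.default = default

    def get_class(self, name):
        if name not in self.classes:
            raise ValueError(f"unknown choice: {name}")
        return self.classes[name]

# e.g. choices = SimpleChoices({"wav_frontend": WavFrontend, ...})
# frontend = choices.get_class(args.frontend)(cmvn_file=..., **conf).eval()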

+ 1 - 1
funasr/bin/train.py

@@ -299,7 +299,7 @@ def get_parser():
         "--freeze_param",
         type=str,
         default=[],
-        nargs="*",
+        action="append",
         help="Freeze parameters",
     )
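This changes how values reach --freeze_param: with nargs="*" a single flag consumed several values, while with action="append" the flag is repeated once per parameter prefix. A standalone demonstration of the two parsing styles:

import argparse

old = argparse.ArgumentParser()
old.add_argument("--freeze_param", type=str, default=[], nargs="*")
# old style: one flag, many values -> ['enc', 'dec']
print(old.parse_args(["--freeze_param", "enc", "dec"]).freeze_param)

new = argparse.ArgumentParser()
new.add_argument("--freeze_param", type=str, default=[], action="append")
# new style: repeat the flag per value -> ['enc', 'dec']
print(new.parse_args(["--freeze_param", "enc", "--freeze_param", "dec"]).freeze_param)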
 

+ 59 - 1
funasr/build_utils/build_asr_model.py

@@ -21,8 +21,10 @@ from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
 from funasr.models.decoder.transformer_decoder import TransformerDecoder
 from funasr.models.decoder.rnnt_decoder import RNNTDecoder
 from funasr.models.joint_net.joint_network import JointNetwork
+from funasr.models.decoder.transformer_decoder import SAAsrTransformerDecoder
 from funasr.models.e2e_asr import ASRModel
 from funasr.models.e2e_asr_mfcca import MFCCA
+from funasr.models.e2e_sa_asr import SAASRModel
 from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, ContextualParaformer
 from funasr.models.e2e_tp import TimestampPredictor
 from funasr.models.e2e_uni_asr import UniASR
@@ -30,6 +32,7 @@ from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerM
 from funasr.models.encoder.conformer_encoder import ConformerEncoder, ConformerChunkEncoder
 from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
 from funasr.models.encoder.mfcca_encoder import MFCCAEncoder
+from funasr.models.encoder.resnet34_encoder import ResNet34Diar
 from funasr.models.encoder.rnn_encoder import RNNEncoder
 from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
 from funasr.models.encoder.transformer_encoder import TransformerEncoder
@@ -90,6 +93,8 @@ model_choices = ClassChoices(
         timestamp_prediction=TimestampPredictor,
         rnnt=TransducerModel,
         rnnt_unified=UnifiedTransducerModel,
+        sa_asr=SAASRModel,
+
     ),
     default="asr",
 )
@@ -107,6 +112,27 @@ encoder_choices = ClassChoices(
     ),
     default="rnn",
 )
+asr_encoder_choices = ClassChoices(
+    "asr_encoder",
+    classes=dict(
+        conformer=ConformerEncoder,
+        transformer=TransformerEncoder,
+        rnn=RNNEncoder,
+        sanm=SANMEncoder,
+        sanm_chunk_opt=SANMEncoderChunkOpt,
+        data2vec_encoder=Data2VecEncoder,
+        mfcca_enc=MFCCAEncoder,
+    ),
+    default="rnn",
+)
+
+spk_encoder_choices = ClassChoices(
+    "spk_encoder",
+    classes=dict(
+        resnet34_diar=ResNet34Diar,
+    ),
+    default="resnet34_diar",
+)
 encoder_choices2 = ClassChoices(
     "encoder2",
     classes=dict(
@@ -131,6 +157,7 @@ decoder_choices = ClassChoices(
         paraformer_decoder_sanm=ParaformerSANMDecoder,
         paraformer_decoder_san=ParaformerDecoderSAN,
         contextual_paraformer_decoder=ContextualParaformerDecoder,
+        sa_decoder=SAAsrTransformerDecoder,
     ),
     default="rnn",
 )
@@ -222,6 +249,10 @@ class_choices_list = [
     rnnt_decoder_choices,
     # --joint_network and --joint_network_conf
     joint_network_choices,
+    # --asr_encoder and --asr_encoder_conf
+    asr_encoder_choices,
+    # --spk_encoder and --spk_encoder_conf
+    spk_encoder_choices,
 ]
 
 
@@ -239,7 +270,7 @@ def build_asr_model(args):
     # frontend
     if args.input_size is None:
         frontend_class = frontend_choices.get_class(args.frontend)
-        if args.frontend == 'wav_frontend':
+        if args.frontend == 'wav_frontend' or args.frontend == 'multichannelfrontend':
             frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
         else:
             frontend = frontend_class(**args.frontend_conf)
@@ -413,6 +444,33 @@ def build_asr_model(args):
             joint_network=joint_network,
             **args.model_conf,
         )
+    elif args.model == "sa_asr":
+        asr_encoder_class = asr_encoder_choices.get_class(args.asr_encoder)
+        asr_encoder = asr_encoder_class(input_size=input_size, **args.asr_encoder_conf)
+        spk_encoder_class = spk_encoder_choices.get_class(args.spk_encoder)
+        spk_encoder = spk_encoder_class(input_size=input_size, **args.spk_encoder_conf)
+        decoder = decoder_class(
+            vocab_size=vocab_size,
+            encoder_output_size=asr_encoder.output_size(),
+            **args.decoder_conf,
+        )
+        ctc = CTC(
+            odim=vocab_size, encoder_output_size=asr_encoder.output_size(), **args.ctc_conf
+        )
+
+        model_class = model_choices.get_class(args.model)
+        model = model_class(
+            vocab_size=vocab_size,
+            frontend=frontend,
+            specaug=specaug,
+            normalize=normalize,
+            asr_encoder=asr_encoder,
+            spk_encoder=spk_encoder,
+            decoder=decoder,
+            ctc=ctc,
+            token_list=token_list,
+            **args.model_conf,
+        )
 
     else:
         raise NotImplementedError("Not supported model: {}".format(args.model))
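The new sa_asr branch assembles two encoders (ASR and speaker) plus a shared decoder and a CTC head over the ASR encoder's output. For orientation, these are the argument fields the branch reads, shown as a placeholder namespace; the concrete values below are illustrative only:

from types import SimpleNamespace

args = SimpleNamespace(
    model="sa_asr",
    asr_encoder="conformer",      asr_encoder_conf=dict(output_size=256),
    spk_encoder="resnet34_diar",  spk_encoder_conf=dict(),
    decoder="sa_decoder",         decoder_conf=dict(),
    ctc_conf=dict(),              model_conf=dict(),
)

Note that both encoders receive the same frontend input_size, while the decoder and CTC are sized from asr_encoder.output_size(), so the speaker encoder's output dimension is reconciled inside SAASRModel rather than here.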

+ 2 - 15
funasr/models/e2e_sa_asr.py

@@ -40,7 +40,7 @@ else:
         yield
 
 
-class ESPnetASRModel(FunASRModel):
+class SAASRModel(FunASRModel):
     """CTC-attention hybrid Encoder-Decoder model"""
 
     def __init__(
@@ -51,10 +51,8 @@ class ESPnetASRModel(FunASRModel):
             frontend: Optional[AbsFrontend],
             specaug: Optional[AbsSpecAug],
             normalize: Optional[AbsNormalize],
-            preencoder: Optional[AbsPreEncoder],
             asr_encoder: AbsEncoder,
             spk_encoder: torch.nn.Module,
-            postencoder: Optional[AbsPostEncoder],
             decoder: AbsDecoder,
             ctc: CTC,
             spk_weight: float = 0.5,
@@ -89,8 +87,6 @@ class ESPnetASRModel(FunASRModel):
         self.frontend = frontend
         self.specaug = specaug
         self.normalize = normalize
-        self.preencoder = preencoder
-        self.postencoder = postencoder
         self.asr_encoder = asr_encoder
         self.spk_encoder = spk_encoder
 
@@ -293,10 +289,6 @@ class ESPnetASRModel(FunASRModel):
             if self.normalize is not None:
                 feats, feats_lengths = self.normalize(feats, feats_lengths)
 
-        # Pre-encoder, e.g. used for raw input data
-        if self.preencoder is not None:
-            feats, feats_lengths = self.preencoder(feats, feats_lengths)
-
         # 4. Forward encoder
         # feats: (Batch, Length, Dim)
         # -> encoder_out: (Batch, Length2, Dim2)
@@ -317,11 +309,6 @@ class ESPnetASRModel(FunASRModel):
             encoder_out_spk=F.interpolate(encoder_out_spk_ori.transpose(-2,-1), size=(encoder_out.size(1)), mode='nearest').transpose(-2,-1)
         else:
             encoder_out_spk=encoder_out_spk_ori
-        # Post-encoder, e.g. NLU
-        if self.postencoder is not None:
-            encoder_out, encoder_out_lens = self.postencoder(
-                encoder_out, encoder_out_lens
-            )
 
         assert encoder_out.size(0) == speech.size(0), (
             encoder_out.size(),
@@ -337,7 +324,7 @@ class ESPnetASRModel(FunASRModel):
         )
 
         if intermediate_outs is not None:
-            return (encoder_out, intermediate_outs), encoder_out_lens
+            return (encoder_out, intermediate_outs), encoder_out_lens, encoder_out_spk
 
         return encoder_out, encoder_out_lens, encoder_out_spk
 

+ 86 - 31
funasr/models/frontend/default.py

@@ -2,7 +2,7 @@ import copy
 from typing import Optional
 from typing import Tuple
 from typing import Union
-
+import logging
 import humanfriendly
 import numpy as np
 import torch
@@ -14,6 +14,7 @@ from funasr.layers.stft import Stft
 from funasr.models.frontend.abs_frontend import AbsFrontend
 from funasr.modules.frontends.frontend import Frontend
 from funasr.utils.get_default_kwargs import get_default_kwargs
+from funasr.modules.nets_utils import make_pad_mask
 
 
 class DefaultFrontend(AbsFrontend):
@@ -137,8 +138,6 @@ class DefaultFrontend(AbsFrontend):
         return input_stft, feats_lens
 
 
-
-
 class MultiChannelFrontend(AbsFrontend):
     """Conventional frontend structure for ASR.
     Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
@@ -147,9 +146,9 @@ class MultiChannelFrontend(AbsFrontend):
     def __init__(
             self,
             fs: Union[int, str] = 16000,
-            n_fft: int = 512,
-            win_length: int = None,
-            hop_length: int = 128,
+            n_fft: int = 400,
+            frame_length: int = 25,
+            frame_shift: int = 10,
             window: Optional[str] = "hann",
             center: bool = True,
             normalized: bool = False,
@@ -160,10 +159,10 @@ class MultiChannelFrontend(AbsFrontend):
             htk: bool = False,
             frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
             apply_stft: bool = True,
-            frame_length: int = None,
-            frame_shift: int = None,
-            lfr_m: int = None,
-            lfr_n: int = None,
+            use_channel: int = None,
+            lfr_m: int = 1,
+            lfr_n: int = 1,
+            cmvn_file: str = None
     ):
         assert check_argument_types()
         super().__init__()
@@ -172,13 +171,14 @@ class MultiChannelFrontend(AbsFrontend):
 
         # Deepcopy (In general, dict shouldn't be used as default arg)
         frontend_conf = copy.deepcopy(frontend_conf)
-        self.hop_length = hop_length
+        self.win_length = frame_length * 16
+        self.hop_length = frame_shift * 16
 
         if apply_stft:
             self.stft = Stft(
                 n_fft=n_fft,
-                win_length=win_length,
-                hop_length=hop_length,
+                win_length=self.win_length,
+                hop_length=self.hop_length,
                 center=center,
                 window=window,
                 normalized=normalized,
@@ -202,7 +202,17 @@ class MultiChannelFrontend(AbsFrontend):
             htk=htk,
         )
         self.n_mels = n_mels
-        self.frontend_type = "multichannelfrontend"
+        self.frontend_type = "default"
+        self.use_channel = use_channel
+        if self.use_channel is not None:
+            logging.info("use the channel %d" % (self.use_channel))
+        else:
+            logging.info("random select channel")
+        self.cmvn_file = cmvn_file
+        if self.cmvn_file is not None:
+            mean, std = self._load_cmvn(self.cmvn_file)
+            self.register_buffer("mean", torch.from_numpy(mean))
+            self.register_buffer("std", torch.from_numpy(std))
 
     def output_size(self) -> int:
         return self.n_mels
@@ -215,16 +225,29 @@ class MultiChannelFrontend(AbsFrontend):
         if self.stft is not None:
             input_stft, feats_lens = self._compute_stft(input, input_lengths)
         else:
-            if isinstance(input, ComplexTensor):
-                input_stft = input
-            else:
-                input_stft = ComplexTensor(input[..., 0], input[..., 1])
+            input_stft = ComplexTensor(input[..., 0], input[..., 1])
             feats_lens = input_lengths
         # 2. [Option] Speech enhancement
         if self.frontend is not None:
             assert isinstance(input_stft, ComplexTensor), type(input_stft)
             # input_stft: (Batch, Length, [Channel], Freq)
             input_stft, _, mask = self.frontend(input_stft, feats_lens)
+
+        # 3. [Multi channel case]: Select a channel
+        if input_stft.dim() == 4:
+            # h: (B, T, C, F) -> h: (B, T, F)
+            if self.training:
+                if self.use_channel is not None:
+                    input_stft = input_stft[:, :, self.use_channel, :]
+                    
+                else:
+                    # Select 1ch randomly
+                    ch = np.random.randint(input_stft.size(2))
+                    input_stft = input_stft[:, :, ch, :]
+            else:
+                # Use the first channel
+                input_stft = input_stft[:, :, 0, :]
+
         # 4. STFT -> Power spectrum
         # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
         input_power = input_stft.real ** 2 + input_stft.imag ** 2
@@ -233,18 +256,27 @@ class MultiChannelFrontend(AbsFrontend):
         # input_power: (Batch, [Channel,] Length, Freq)
         #       -> input_feats: (Batch, Length, Dim)
         input_feats, _ = self.logmel(input_power, feats_lens)
-        bt = input_feats.size(0)
-        if input_feats.dim() ==4:
-            channel_size = input_feats.size(2)
-            # batch * channel * T * D
-            #pdb.set_trace()
-            input_feats = input_feats.transpose(1,2).reshape(bt*channel_size,-1,80).contiguous()
-            # input_feats = input_feats.transpose(1,2)
-            # batch * channel
-            feats_lens = feats_lens.repeat(1,channel_size).squeeze()
-        else:
-            channel_size = 1
-        return input_feats, feats_lens, channel_size
+        
+        # 6. Apply CMVN
+        if self.cmvn_file is not None:
+            if feats_lens is None:
+                feats_lens = input_feats.new_full([input_feats.size(0)], input_feats.size(1))
+            self.mean = self.mean.to(input_feats.device, input_feats.dtype)
+            self.std = self.std.to(input_feats.device, input_feats.dtype)
+            mask = make_pad_mask(feats_lens, input_feats, 1)
+
+            if input_feats.requires_grad:
+                input_feats = input_feats + self.mean
+            else:
+                input_feats += self.mean
+            if input_feats.requires_grad:
+                input_feats = input_feats.masked_fill(mask, 0.0)
+            else:
+                input_feats.masked_fill_(mask, 0.0)
+
+            input_feats *= self.std
+
+        return input_feats, feats_lens
 
     def _compute_stft(
             self, input: torch.Tensor, input_lengths: torch.Tensor
@@ -258,4 +290,27 @@ class MultiChannelFrontend(AbsFrontend):
         # Change torch.Tensor to ComplexTensor
         # input_stft: (..., F, 2) -> (..., F)
         input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
-        return input_stft, feats_lens
+        return input_stft, feats_lens
+
+    def _load_cmvn(self, cmvn_file):
+        with open(cmvn_file, 'r', encoding='utf-8') as f:
+            lines = f.readlines()
+        means_list = []
+        vars_list = []
+        for i in range(len(lines)):
+            line_item = lines[i].split()
+            if line_item[0] == '<AddShift>':
+                line_item = lines[i + 1].split()
+                if line_item[0] == '<LearnRateCoef>':
+                    add_shift_line = line_item[3:(len(line_item) - 1)]
+                    means_list = list(add_shift_line)
+                    continue
+            elif line_item[0] == '<Rescale>':
+                line_item = lines[i + 1].split()
+                if line_item[0] == '<LearnRateCoef>':
+                    rescale_line = line_item[3:(len(line_item) - 1)]
+                    vars_list = list(rescale_line)
+                    continue
+        means = np.array(means_list).astype(np.float64)
+        vars = np.array(vars_list).astype(np.float64)
+        return means, vars
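Two details in this frontend rewrite are worth spelling out: frame parameters are now given in milliseconds and converted to samples with a hard-coded 16 kHz factor (the * 16 above), and the Kaldi-text CMVN file stores a negative mean shift (<AddShift>) plus an inverse stddev (<Rescale>), so normalization is (x + mean) * std rather than (x - mean) / std. A small numpy sketch under those assumptions, with stand-in statistics:

import numpy as np

fs = 16000
frame_length_ms, frame_shift_ms = 25, 10
win_length = frame_length_ms * fs // 1000   # 400 samples, matches n_fft
hop_length = frame_shift_ms * fs // 1000    # 160 samples

feats = np.random.randn(100, 80).astype(np.float32)  # (frames, n_mels)
mean = -feats.mean(axis=0)                  # stand-in for the <AddShift> row
std = 1.0 / (feats.std(axis=0) + 1e-8)      # stand-in for the <Rescale> row
normalized = (feats + mean) * std           # same form as the forward() above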

+ 152 - 0
funasr/tasks/asr.py

@@ -38,6 +38,7 @@ from funasr.models.decoder.transformer_decoder import (
 from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
 from funasr.models.decoder.transformer_decoder import TransformerDecoder
 from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
+from funasr.models.decoder.transformer_decoder import SAAsrTransformerDecoder
 from funasr.models.e2e_asr import ASRModel
 from funasr.models.decoder.rnnt_decoder import RNNTDecoder
 from funasr.models.joint_net.joint_network import JointNetwork
@@ -45,6 +46,7 @@ from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, Paraf
 from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
 from funasr.models.e2e_tp import TimestampPredictor
 from funasr.models.e2e_asr_mfcca import MFCCA
+from funasr.models.e2e_sa_asr import SAASRModel
 from funasr.models.e2e_uni_asr import UniASR
 from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
 from funasr.models.encoder.abs_encoder import AbsEncoder
@@ -54,6 +56,7 @@ from funasr.models.encoder.rnn_encoder import RNNEncoder
 from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
 from funasr.models.encoder.transformer_encoder import TransformerEncoder
 from funasr.models.encoder.mfcca_encoder import MFCCAEncoder
+from funasr.models.encoder.resnet34_encoder import ResNet34Diar
 from funasr.models.frontend.abs_frontend import AbsFrontend
 from funasr.models.frontend.default import DefaultFrontend
 from funasr.models.frontend.default import MultiChannelFrontend
@@ -134,6 +137,7 @@ model_choices = ClassChoices(
         timestamp_prediction=TimestampPredictor,
         rnnt=TransducerModel,
         rnnt_unified=UnifiedTransducerModel,
+        sa_asr=SAASRModel,
     ),
     type_check=FunASRModel,
     default="asr",
@@ -175,6 +179,27 @@ encoder_choices2 = ClassChoices(
     type_check=AbsEncoder,
     default="rnn",
 )
+asr_encoder_choices = ClassChoices(
+    "asr_encoder",
+    classes=dict(
+        conformer=ConformerEncoder,
+        transformer=TransformerEncoder,
+        rnn=RNNEncoder,
+        sanm=SANMEncoder,
+        sanm_chunk_opt=SANMEncoderChunkOpt,
+        data2vec_encoder=Data2VecEncoder,
+        mfcca_enc=MFCCAEncoder,
+    ),
+    type_check=AbsEncoder,
+    default="rnn",
+)
+spk_encoder_choices = ClassChoices(
+    "spk_encoder",
+    classes=dict(
+        resnet34_diar=ResNet34Diar,
+    ),
+    default="resnet34_diar",
+)
 postencoder_choices = ClassChoices(
     name="postencoder",
     classes=dict(
@@ -197,6 +222,7 @@ decoder_choices = ClassChoices(
         paraformer_decoder_sanm=ParaformerSANMDecoder,
         paraformer_decoder_san=ParaformerDecoderSAN,
         contextual_paraformer_decoder=ContextualParaformerDecoder,
+        sa_decoder=SAAsrTransformerDecoder,
     ),
     type_check=AbsDecoder,
     default="rnn",
@@ -329,6 +355,12 @@ class ASRTask(AbsTask):
             default=True,
             help="whether to split text using <space>",
         )
+        group.add_argument(
+            "--max_spk_num",
+            type=int_or_none,
+            default=None,
+            help="A text mapping int-id to token",
+        )
         group.add_argument(
             "--seg_dict_file",
             type=str,
@@ -1495,3 +1527,123 @@ class ASRTransducerTask(ASRTask):
         #assert check_return_type(model)
 
         return model
+
+
+class ASRTaskSAASR(ASRTask):
+    # If you need more than one optimizers, change this value
+    num_optimizers: int = 1
+
+    # Add variable objects configurations
+    class_choices_list = [
+        # --frontend and --frontend_conf
+        frontend_choices,
+        # --specaug and --specaug_conf
+        specaug_choices,
+        # --normalize and --normalize_conf
+        normalize_choices,
+        # --model and --model_conf
+        model_choices,
+        # --preencoder and --preencoder_conf
+        preencoder_choices,
+        # --asr_encoder and --asr_encoder_conf
+        asr_encoder_choices,
+        # --spk_encoder and --spk_encoder_conf
+        spk_encoder_choices,
+        # --decoder and --decoder_conf
+        decoder_choices,
+    ]
+
+    # If you need to modify train() or eval() procedures, change Trainer class here
+    trainer = Trainer
+
+    @classmethod
+    def build_model(cls, args: argparse.Namespace):
+        assert check_argument_types()
+        if isinstance(args.token_list, str):
+            with open(args.token_list, encoding="utf-8") as f:
+                token_list = [line.rstrip() for line in f]
+
+            # Overwriting token_list to keep it as "portable".
+            args.token_list = list(token_list)
+        elif isinstance(args.token_list, (tuple, list)):
+            token_list = list(args.token_list)
+        else:
+            raise RuntimeError("token_list must be str or list")
+        vocab_size = len(token_list)
+        logging.info(f"Vocabulary size: {vocab_size}")
+
+        # 1. frontend
+        if args.input_size is None:
+            # Extract features in the model
+            frontend_class = frontend_choices.get_class(args.frontend)
+            if args.frontend == 'wav_frontend' or args.frontend == "multichannelfrontend":
+                frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
+            else:
+                frontend = frontend_class(**args.frontend_conf)
+            input_size = frontend.output_size()
+        else:
+            # Give features from data-loader
+            args.frontend = None
+            args.frontend_conf = {}
+            frontend = None
+            input_size = args.input_size
+
+        # 2. Data augmentation for spectrogram
+        if args.specaug is not None:
+            specaug_class = specaug_choices.get_class(args.specaug)
+            specaug = specaug_class(**args.specaug_conf)
+        else:
+            specaug = None
+
+        # 3. Normalization layer
+        if args.normalize is not None:
+            normalize_class = normalize_choices.get_class(args.normalize)
+            normalize = normalize_class(**args.normalize_conf)
+        else:
+            normalize = None
+
+        # 5. Encoder
+        asr_encoder_class = asr_encoder_choices.get_class(args.asr_encoder)
+        asr_encoder = asr_encoder_class(input_size=input_size, **args.asr_encoder_conf)
+        spk_encoder_class = spk_encoder_choices.get_class(args.spk_encoder)
+        spk_encoder = spk_encoder_class(input_size=input_size, **args.spk_encoder_conf)
+
+        # 7. Decoder
+        decoder_class = decoder_choices.get_class(args.decoder)
+        decoder = decoder_class(
+            vocab_size=vocab_size,
+            encoder_output_size=asr_encoder.output_size(),
+            **args.decoder_conf,
+        )
+
+        # 8. CTC
+        ctc = CTC(
+            odim=vocab_size, encoder_output_size=asr_encoder.output_size(), **args.ctc_conf
+        )
+
+        # 9. Build model
+        try:
+            model_class = model_choices.get_class(args.model)
+        except AttributeError:
+            model_class = model_choices.get_class("asr")
+        model = model_class(
+            vocab_size=vocab_size,
+            frontend=frontend,
+            specaug=specaug,
+            normalize=normalize,
+            asr_encoder=asr_encoder,
+            spk_encoder=spk_encoder,
+            decoder=decoder,
+            ctc=ctc,
+            token_list=token_list,
+            **args.model_conf,
+        )
+
+        # 10. Initialize
+        if args.init is not None:
+            initialize(model, args.init)
+
+        assert check_return_type(model)
+        return model
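ASRTaskSAASR is the task class that the asr_infer.py hunk earlier in this diff now builds from. A hedged usage sketch of the inference-side entry point; the file paths and checkpoint name below are placeholders, while the argument order matches the Speech2TextSAASR call shown above:

from funasr.tasks.asr import ASRTaskSAASR

model, train_args = ASRTaskSAASR.build_model_from_file(
    "exp/sa_asr_exp/config.yaml",        # asr_train_config
    "exp/sa_asr_exp/valid.acc.ave.pb",   # asr_model_file (placeholder name)
    "data/Train_Ali_far/cmvn/cmvn.mvn",  # cmvn_file
    "cpu",                               # device
)
model.eval()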

+ 3 - 3
funasr/tasks/sa_asr.py

@@ -39,7 +39,7 @@ from funasr.models.decoder.transformer_decoder import (
 from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
 from funasr.models.decoder.transformer_decoder import TransformerDecoder
 from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
-from funasr.models.e2e_sa_asr import ESPnetASRModel
+from funasr.models.e2e_sa_asr import SAASRModel
 from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer
 from funasr.models.e2e_tp import TimestampPredictor
 from funasr.models.e2e_asr_mfcca import MFCCA
@@ -120,7 +120,7 @@ normalize_choices = ClassChoices(
 model_choices = ClassChoices(
     "model",
     classes=dict(
-        asr=ESPnetASRModel,
+        asr=SAASRModel,
         uniasr=UniASR,
         paraformer=Paraformer,
         paraformer_bert=ParaformerBert,
@@ -620,4 +620,4 @@ class ASRTask(AbsTask):
             initialize(model, args.init)
 
         assert check_return_type(model)
-        return model
+        return model