游雁 1 gadu atpakaļ
vecāks
revīzija
8827e26b8d

+ 1 - 0
.gitignore

@@ -25,3 +25,4 @@ outputs*
 emotion2vec*
 GPT-SoVITS*
 modelscope_models
+examples/aishell/llm_asr_nar/*

+ 4 - 3
funasr/models/llm_asr/template.yaml → examples/aishell/llm_asr_nar/conf/template.yaml

@@ -6,7 +6,7 @@
 # tables.print()
 
 # network architecture
-model: LLMASR
+model: LLMASRNAR
 model_conf:
     lsm_weight: 0.1     # label smoothing option
     length_normalized_loss: true
@@ -16,12 +16,13 @@ encoder: Paraformer
 encoder_conf:
     hub: funasr
     init_param_path: "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+    freeze: false
 
 llm: Vicuna
 llm_conf:
   hub: hf
-  init_param_path: null
-  freeze_llm: true
+  init_param_path: "/nfs/maziyang.mzy/models/vicuna-7b-v1.5"
+  freeze: true
 
 adaptor: linear
 adaptor_conf:

+ 3 - 1
funasr/bin/train.py

@@ -108,8 +108,10 @@ def main(**kwargs):
                 )
             else:
                 logging.info(f"Checkpoint does not exist, init randomly: {p}")
-    else:
+    elif kwargs.get("init", None):
         initialize(model, kwargs.get("init", "kaiming_normal"))
+    else:
+        print("No initialize method")
 
 
     # freeze_param

+ 20 - 53
funasr/models/llm_asr/model.py

@@ -7,6 +7,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.cuda.amp import autocast
 
+from funasr.models.scama.utils import sequence_mask
 from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
 from funasr.models.ctc.ctc import CTC
 from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
@@ -19,8 +20,8 @@ from funasr.utils.datadir_writer import DatadirWriter
 from funasr.register import tables
 
 
-@tables.register("model_classes", "LLMASR")
-class LLMASR(nn.Module):
+@tables.register("model_classes", "LLMASRNAR")
+class LLMASRNAR(nn.Module):
     """ """
     
     def __init__(
@@ -72,15 +73,13 @@ class LLMASR(nn.Module):
         hub = encoder_conf.get("hub", None)
         if hub == "funasr":
             from funasr import AutoModel
-            from funasr.models.scama.utils import sequence_mask
             init_param_path = encoder_conf.get("hub", "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
             model = AutoModel(model=init_param_path, model_revision="v2.0.4")
-            frontend = model.kwargs.get("frontend")
+            # frontend = model.kwargs.get("frontend")
             model.model.decoder = None
             
-            self.model = model.model
-            self.frontend = frontend
-            self.mask_fn = sequence_mask
+            self.audio_encoder = model.model
+            # self.frontend = frontend
             
         elif hub == "hf":
             pass
@@ -102,8 +101,8 @@ class LLMASR(nn.Module):
                 device_map=None,
                 use_cache=None,
             )
-            freeze_llm = llm_conf.get("freeze_llm", True)
-            if freeze_llm:
+            freeze = llm_conf.get("freeze", True)
+            if freeze:
                 for name, param in model.named_parameters():
                     param.requires_grad = False
                 model.eval()
@@ -151,9 +150,9 @@ class LLMASR(nn.Module):
         text_lengths: torch.Tensor,
         input_ids: torch.Tensor,
         attention_mask:torch.Tensor,
-        labels_ids:torch.Tensor,
+        labels_ids: torch.Tensor,
         label_mask: torch.Tensor,
-        audio_mask:torch.Tensor,
+        audio_mask: torch.Tensor,
         **kwargs,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         """Encoder + Decoder + Calc loss
@@ -173,7 +172,7 @@ class LLMASR(nn.Module):
         batch_size = speech.shape[0]
         
         # audio encoder
-        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, audio_mask)
+        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, audio_mask=audio_mask)
         
         # adaptor
         encoder_out = self.adaptor(encoder_out)
@@ -194,18 +193,18 @@ class LLMASR(nn.Module):
                 inputs_embeds = encoder_outs_pad * audio_mask[:, :, None] + inputs_embeds * (~audio_mask[:, :, None])
                 inputs_embeds = F.pad(inputs_embeds[:, 1:, :], (0, 0, 0, 1, 0, 0), value=0.0)
 
-        model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)
+        model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids)
         loss = model_outputs.loss
 
-        acc_att = -1
+
+        stats = {}
         if self.metric:
             with torch.no_grad():
                 preds = torch.argmax(model_outputs.logits, -1)
                 acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100)
+                stats["acc"] = acc_att
 
-        stats = {}
-        # Collect Attn branch stats
-        stats["acc"] = acc_att.detach()
+        stats["loss"] = torch.clone(loss.detach())
 
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         if self.length_normalized_loss:
@@ -221,47 +220,15 @@ class LLMASR(nn.Module):
         audio_token_lengths = audio_mask.sum(-1)
 
         batch = {"speech": speech, "speech_lengths": speech_lengths}
-        enc, enc_lens = self.model.encode(**batch)
-        enc_mask = self.mask_fn(enc_lens, enc.size(1), device=enc.device)[:, None, :]
-        pre_acoustic_embeds, pre_token_length, _, _ = self.model.predictor(enc,
+        enc, enc_lens = self.audio_encoder.encode(**batch)
+        enc_mask = sequence_mask(enc_lens, enc.size(1), device=enc.device)[:, None, :]
+        pre_acoustic_embeds, pre_token_length, _, _ = self.audio_encoder.predictor(enc,
                                                                            mask=enc_mask,
                                                                            target_label_length=audio_token_lengths,
                                                                            )
 
         return pre_acoustic_embeds, pre_token_length
-    
-    def _calc_att_loss(
-        self,
-        encoder_out: torch.Tensor,
-        encoder_out_lens: torch.Tensor,
-        ys_pad: torch.Tensor,
-        ys_pad_lens: torch.Tensor,
-    ):
-        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
-        ys_in_lens = ys_pad_lens + 1
-        
-        # 1. Forward decoder
-        decoder_out, _ = self.decoder(
-            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
-        )
-        
-        # 2. Compute attention loss
-        loss_att = self.criterion_att(decoder_out, ys_out_pad)
-        acc_att = th_accuracy(
-            decoder_out.view(-1, self.vocab_size),
-            ys_out_pad,
-            ignore_label=self.ignore_id,
-        )
-        
-        # Compute cer/wer using attention-decoder
-        if self.training or self.error_calculator is None:
-            cer_att, wer_att = None, None
-        else:
-            ys_hat = decoder_out.argmax(dim=-1)
-            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
-        
-        return loss_att, acc_att, cer_att, wer_att
-    
+
 
     def inference(self,
                   data_in,

+ 4 - 1
funasr/train_utils/trainer.py

@@ -14,6 +14,7 @@ from pathlib import Path
 from funasr.train_utils.device_funcs import to_device
 from funasr.train_utils.recursive_op import recursive_average
 from funasr.train_utils.average_nbest_models import average_checkpoints
+from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 
 @contextmanager
 def maybe_autocast(enabled):
@@ -84,7 +85,9 @@ class Trainer:
         self.batch_total = 0
         self.use_fp16 = use_fp16
         self.disable_gpu_cache = kwargs.get("disable_gpu_cache", True)
-        self.scaler = GradScaler(enabled=use_fp16) if use_fp16 else None
+        scaler = GradScaler(enabled=use_fp16) if use_fp16 else None
+        scaler = ShardedGradScaler(enabled=use_fp16) if use_ddp else scaler
+        self.scaler = scaler
         
     
         try: