游雁 2 ani în urmă
părinte
comite
65525a1af5
1 a modificat fișierele cu 56 adăugiri și 44 ștergeri
  1. 56 44
      funasr/models/llm_asr/model.py

+ 56 - 44
funasr/models/llm_asr/model.py

@@ -240,14 +240,12 @@ class LLMASRNAR(nn.Module):
                   **kwargs,
                   ):
         
+        prompt = kwargs.get("prompt", "Transcribe speech to text.")
+        
         if kwargs.get("batch_size", 1) > 1:
             raise NotImplementedError("batch decoding is not implemented")
-        
-        # init beamsearch
-        if self.beam_search is None:
-            logging.info("enable beam_search")
-            self.init_beam_search(**kwargs)
-            self.nbest = kwargs.get("nbest", 1)
+
+
         
         meta_data = {}
         if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank":  # fbank
@@ -272,50 +270,64 @@ class LLMASRNAR(nn.Module):
         
         speech = speech.to(device=kwargs["device"])
         speech_lengths = speech_lengths.to(device=kwargs["device"])
+        
         # Encoder
         encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
-        if isinstance(encoder_out, tuple):
-            encoder_out = encoder_out[0]
+
+        # adaptor
+        encoder_out = self.adaptor(encoder_out)
         
-        # c. Passed the encoder result and the beam search
-        nbest_hyps = self.beam_search(
-            x=encoder_out[0], maxlenratio=kwargs.get("maxlenratio", 0.0), minlenratio=kwargs.get("minlenratio", 0.0)
-        )
+    
+        prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(prompt)
+        prompt_ids = self.tokenizer.encode(prompt_pre)
+        prompt_length = len(prompt_ids)
+        prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64).to(kwargs["device"])
+
+
+        if hasattr(self.llm.model, "embed_tokens"):
+            inputs_embeds = self.llm.model.embed_tokens(prompt_ids)
+        elif hasattr(self.llm.model.model, "embed_tokens"):
+            inputs_embeds = self.llm.model.model.embed_tokens(prompt_ids)
+        else:
+            inputs_embeds = self.llm.model.model.model.embed_tokens(prompt_ids)
+
+        inputs_embeds = torch.cat((inputs_embeds[None, :, :], encoder_out), dim=1)  # [prompt, audio]
+        attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to(kwargs["device"])
         
-        nbest_hyps = nbest_hyps[: self.nbest]
+        model_outputs = self.llm.generate(
+            inputs_embeds=inputs_embeds,
+            max_length=kwargs.get("max_length", 200),
+            max_new_tokens=kwargs.get("max_new_tokens", 200),
+            num_beams=kwargs.get("num_beams", 4),
+            do_sample=kwargs.get("do_sample", False),
+            min_length=kwargs.get("min_length", 1),
+            top_p=kwargs.get("top_p", 1.0),
+            repetition_penalty=kwargs.get("repetition_penalty", 1.0),
+            length_penalty=kwargs.get("length_penalty", 1.0),
+            temperature=kwargs.get("temperature", 1.0),
+            attention_mask=attention_mask,
+            bos_token_id=tokenizer.bos_token_id,
+            eos_token_id=tokenizer.eos_token_id,
+            pad_token_id=tokenizer.pad_token_id
+        )
+
+        text = tokenizer.batch_decode(model_outputs, add_special_tokens=False, skip_special_tokens=True)
         
+        ibest_writer = None
+        if kwargs.get("output_dir") is not None:
+            if not hasattr(self, "writer"):
+                self.writer = DatadirWriter(kwargs.get("output_dir"))
+            ibest_writer = self.writer[f"{0 + 1}best_recog"]
+
         results = []
-        b, n, d = encoder_out.size()
-        for i in range(b):
-            
-            for nbest_idx, hyp in enumerate(nbest_hyps):
-                ibest_writer = None
-                if kwargs.get("output_dir") is not None:
-                    if not hasattr(self, "writer"):
-                        self.writer = DatadirWriter(kwargs.get("output_dir"))
-                    ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]
-                
-                # remove sos/eos and get results
-                last_pos = -1
-                if isinstance(hyp.yseq, list):
-                    token_int = hyp.yseq[1:last_pos]
-                else:
-                    token_int = hyp.yseq[1:last_pos].tolist()
-                
-                # remove blank symbol id, which is assumed to be 0
-                token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
-                
-                # Change integer-ids to tokens
-                token = tokenizer.ids2tokens(token_int)
-                text = tokenizer.tokens2text(token)
-                
-                text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
-                result_i = {"key": key[i], "token": token, "text": text_postprocessed}
-                results.append(result_i)
-                
-                if ibest_writer is not None:
-                    ibest_writer["token"][key[i]] = " ".join(token)
-                    ibest_writer["text"][key[i]] = text_postprocessed
+        result_i = {"key": key[0], "text": text}
+        results.append(result_i)
+
+        if ibest_writer is not None:
+            ibest_writer["text"][key[0]] = text
+        
+        
+        
         
         return results, meta_data