|
|
@@ -240,14 +240,12 @@ class LLMASRNAR(nn.Module):
|
|
|
**kwargs,
|
|
|
):
|
|
|
|
|
|
+ prompt = kwargs.get("prompt", "Transcribe speech to text.")
|
|
|
+
|
|
|
if kwargs.get("batch_size", 1) > 1:
|
|
|
raise NotImplementedError("batch decoding is not implemented")
|
|
|
-
|
|
|
- # init beamsearch
|
|
|
- if self.beam_search is None:
|
|
|
- logging.info("enable beam_search")
|
|
|
- self.init_beam_search(**kwargs)
|
|
|
- self.nbest = kwargs.get("nbest", 1)
|
|
|
+
|
|
|
+
|
|
|
|
|
|
meta_data = {}
|
|
|
if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank": # fbank
|
|
|
@@ -272,50 +270,64 @@ class LLMASRNAR(nn.Module):
|
|
|
|
|
|
speech = speech.to(device=kwargs["device"])
|
|
|
speech_lengths = speech_lengths.to(device=kwargs["device"])
|
|
|
+
|
|
|
# Encoder
|
|
|
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
|
|
|
- if isinstance(encoder_out, tuple):
|
|
|
- encoder_out = encoder_out[0]
|
|
|
+
|
|
|
+ # adaptor
|
|
|
+ encoder_out = self.adaptor(encoder_out)
|
|
|
|
|
|
- # c. Passed the encoder result and the beam search
|
|
|
- nbest_hyps = self.beam_search(
|
|
|
- x=encoder_out[0], maxlenratio=kwargs.get("maxlenratio", 0.0), minlenratio=kwargs.get("minlenratio", 0.0)
|
|
|
- )
|
|
|
+
|
|
|
+ prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(prompt)
|
|
|
+ prompt_ids = self.tokenizer.encode(prompt_pre)
|
|
|
+ prompt_length = len(prompt_ids)
|
|
|
+ prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64).to(kwargs["device"])
|
|
|
+
|
|
|
+
|
|
|
+ if hasattr(self.llm.model, "embed_tokens"):
|
|
|
+ inputs_embeds = self.llm.model.embed_tokens(prompt_ids)
|
|
|
+ elif hasattr(self.llm.model.model, "embed_tokens"):
|
|
|
+ inputs_embeds = self.llm.model.model.embed_tokens(prompt_ids)
|
|
|
+ else:
|
|
|
+ inputs_embeds = self.llm.model.model.model.embed_tokens(prompt_ids)
|
|
|
+
|
|
|
+ inputs_embeds = torch.cat((inputs_embeds[None, :, :], encoder_out), dim=1) # [prompt, audio]
|
|
|
+ attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to(kwargs["device"])
|
|
|
|
|
|
- nbest_hyps = nbest_hyps[: self.nbest]
|
|
|
+ model_outputs = self.llm.generate(
|
|
|
+ inputs_embeds=inputs_embeds,
|
|
|
+ max_length=kwargs.get("max_length", 200),
|
|
|
+ max_new_tokens=kwargs.get("max_new_tokens", 200),
|
|
|
+ num_beams=kwargs.get("num_beams", 4),
|
|
|
+ do_sample=kwargs.get("do_sample", False),
|
|
|
+ min_length=kwargs.get("min_length", 1),
|
|
|
+ top_p=kwargs.get("top_p", 1.0),
|
|
|
+ repetition_penalty=kwargs.get("repetition_penalty", 1.0),
|
|
|
+ length_penalty=kwargs.get("length_penalty", 1.0),
|
|
|
+ temperature=kwargs.get("temperature", 1.0),
|
|
|
+ attention_mask=attention_mask,
|
|
|
+ bos_token_id=tokenizer.bos_token_id,
|
|
|
+ eos_token_id=tokenizer.eos_token_id,
|
|
|
+ pad_token_id=tokenizer.pad_token_id
|
|
|
+ )
|
|
|
+
|
|
|
+ text = tokenizer.batch_decode(model_outputs, add_special_tokens=False, skip_special_tokens=True)
|
|
|
|
|
|
+ ibest_writer = None
|
|
|
+ if kwargs.get("output_dir") is not None:
|
|
|
+ if not hasattr(self, "writer"):
|
|
|
+ self.writer = DatadirWriter(kwargs.get("output_dir"))
|
|
|
+ ibest_writer = self.writer[f"{0 + 1}best_recog"]
|
|
|
+
|
|
|
results = []
|
|
|
- b, n, d = encoder_out.size()
|
|
|
- for i in range(b):
|
|
|
-
|
|
|
- for nbest_idx, hyp in enumerate(nbest_hyps):
|
|
|
- ibest_writer = None
|
|
|
- if kwargs.get("output_dir") is not None:
|
|
|
- if not hasattr(self, "writer"):
|
|
|
- self.writer = DatadirWriter(kwargs.get("output_dir"))
|
|
|
- ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]
|
|
|
-
|
|
|
- # remove sos/eos and get results
|
|
|
- last_pos = -1
|
|
|
- if isinstance(hyp.yseq, list):
|
|
|
- token_int = hyp.yseq[1:last_pos]
|
|
|
- else:
|
|
|
- token_int = hyp.yseq[1:last_pos].tolist()
|
|
|
-
|
|
|
- # remove blank symbol id, which is assumed to be 0
|
|
|
- token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
|
|
|
-
|
|
|
- # Change integer-ids to tokens
|
|
|
- token = tokenizer.ids2tokens(token_int)
|
|
|
- text = tokenizer.tokens2text(token)
|
|
|
-
|
|
|
- text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
|
|
|
- result_i = {"key": key[i], "token": token, "text": text_postprocessed}
|
|
|
- results.append(result_i)
|
|
|
-
|
|
|
- if ibest_writer is not None:
|
|
|
- ibest_writer["token"][key[i]] = " ".join(token)
|
|
|
- ibest_writer["text"][key[i]] = text_postprocessed
|
|
|
+ result_i = {"key": key[0], "text": text}
|
|
|
+ results.append(result_i)
|
|
|
+
|
|
|
+ if ibest_writer is not None:
|
|
|
+ ibest_writer["text"][key[0]] = text
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
|
|
|
return results, meta_data
|
|
|
|