|
|
@@ -1918,6 +1918,8 @@ class Speech2TextWhisper:
|
|
|
nbest: int = 1,
|
|
|
streaming: bool = False,
|
|
|
frontend_conf: dict = None,
|
|
|
+ language: str = None,
|
|
|
+ task: str = "transcribe",
|
|
|
**kwargs,
|
|
|
):
|
|
|
|
|
|
@@ -1960,6 +1962,8 @@ class Speech2TextWhisper:
|
|
|
self.device = device
|
|
|
self.dtype = dtype
|
|
|
self.frontend = frontend
|
|
|
+ self.language = language
|
|
|
+ self.task = task
|
|
|
|
|
|
@torch.no_grad()
|
|
|
def __call__(
|
|
|
@@ -1986,10 +1990,10 @@ class Speech2TextWhisper:
|
|
|
mel = log_mel_spectrogram(speech).to(self.device)
|
|
|
|
|
|
if self.asr_model.is_multilingual:
|
|
|
- options = DecodingOptions(fp16=False)
|
|
|
+ options = DecodingOptions(fp16=False, language=self.language, task=self.task)
|
|
|
asr_res = decode(self.asr_model, mel, options)
|
|
|
text = asr_res.text
|
|
|
- language = asr_res.language
|
|
|
+ language = self.language if self.language else asr_res.language
|
|
|
else:
|
|
|
asr_res = transcribe(self.asr_model, speech, fp16=False)
|
|
|
text = asr_res["text"]
|