decoding.py 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710
  1. from dataclasses import dataclass, field
  2. from typing import Dict, List, Tuple, Iterable, Optional, Sequence, Union, TYPE_CHECKING
  3. import numpy as np
  4. import torch
  5. import torch.nn.functional as F
  6. from torch import Tensor
  7. from torch.distributions import Categorical
  8. from funasr.utils.whisper_utils.audio import CHUNK_LENGTH
  9. from funasr.utils.whisper_utils.tokenizer import Tokenizer, get_tokenizer
  10. from funasr.utils.whisper_utils.utils import compression_ratio
  11. if TYPE_CHECKING:
  12. from funasr.models.whisper_models.model import Whisper
  13. @torch.no_grad()
  14. def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None) -> Tuple[Tensor, List[dict]]:
  15. """
  16. Detect the spoken language in the audio, and return them as list of strings, along with the ids
  17. of the most probable language tokens and the probability distribution over all language tokens.
  18. This is performed outside the main decode loop in order to not interfere with kv-caching.
  19. Returns
  20. -------
  21. language_tokens : Tensor, shape = (n_audio,)
  22. ids of the most probable language tokens, which appears after the startoftranscript token.
  23. language_probs : List[Dict[str, float]], length = n_audio
  24. list of dictionaries containing the probability distribution over all languages.
  25. """
  26. if tokenizer is None:
  27. tokenizer = get_tokenizer(model.is_multilingual)
  28. if tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence:
  29. raise ValueError(f"This model doesn't have language tokens so it can't perform lang id")
  30. single = mel.ndim == 2
  31. if single:
  32. mel = mel.unsqueeze(0)
  33. # skip encoder forward pass if already-encoded audio features were given
  34. if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
  35. mel = model.encoder(mel)
  36. # forward pass using a single token, startoftranscript
  37. n_audio = mel.shape[0]
  38. x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1]
  39. logits = model.logits(x, mel)[:, 0]
  40. # collect detected languages; suppress all non-language tokens
  41. mask = torch.ones(logits.shape[-1], dtype=torch.bool)
  42. mask[list(tokenizer.all_language_tokens)] = False
  43. logits[:, mask] = -np.inf
  44. language_tokens = logits.argmax(dim=-1)
  45. language_token_probs = logits.softmax(dim=-1).cpu()
  46. language_probs = [
  47. {
  48. c: language_token_probs[i, j].item()
  49. for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
  50. }
  51. for i in range(n_audio)
  52. ]
  53. if single:
  54. language_tokens = language_tokens[0]
  55. language_probs = language_probs[0]
  56. return language_tokens, language_probs
  57. @dataclass(frozen=True)
  58. class DecodingOptions:
  59. task: str = "transcribe" # whether to perform X->X "transcribe" or X->English "translate"
  60. language: Optional[str] = None # language that the audio is in; uses detected language if None
  61. # sampling-related options
  62. temperature: float = 0.0
  63. sample_len: Optional[int] = None # maximum number of tokens to sample
  64. best_of: Optional[int] = None # number of independent samples to collect, when t > 0
  65. beam_size: Optional[int] = None # number of beams in beam search, when t == 0
  66. patience: Optional[float] = None # patience in beam search (https://arxiv.org/abs/2204.05424)
  67. # options for ranking generations (either beams or best-of-N samples)
  68. length_penalty: Optional[float] = None # "alpha" in Google NMT, None defaults to length norm
  69. # prompt, prefix, and token suppression
  70. prompt: Optional[Union[str, List[int]]] = None # text or tokens for the previous context
  71. prefix: Optional[Union[str, List[int]]] = None # text or tokens to prefix the current context
  72. suppress_blank: bool = True # this will suppress blank outputs
  73. # list of tokens ids (or comma-separated token ids) to suppress
  74. # "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
  75. suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
  76. # timestamp sampling options
  77. without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only
  78. max_initial_timestamp: Optional[float] = 1.0 # the initial timestamp cannot be later than this
  79. # implementation details
  80. fp16: bool = True # use fp16 for most of the calculation
  81. @dataclass(frozen=True)
  82. class DecodingResult:
  83. audio_features: Tensor
  84. language: str
  85. language_probs: Optional[Dict[str, float]] = None
  86. tokens: List[int] = field(default_factory=list)
  87. text: str = ""
  88. avg_logprob: float = np.nan
  89. no_speech_prob: float = np.nan
  90. temperature: float = np.nan
  91. compression_ratio: float = np.nan
  92. class Inference:
  93. def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
  94. """Perform a forward pass on the decoder and return per-token logits"""
  95. raise NotImplementedError
  96. def rearrange_kv_cache(self, source_indices) -> None:
  97. """Update the key-value cache according to the updated beams"""
  98. raise NotImplementedError
  99. def cleanup_caching(self) -> None:
  100. """Clean up any resources or hooks after decoding is finished"""
  101. pass
  102. class PyTorchInference(Inference):
  103. def __init__(self, model: "Whisper", initial_token_length: int):
  104. self.model: "Whisper" = model
  105. self.initial_token_length = initial_token_length
  106. self.kv_cache = {}
  107. self.hooks = []
  108. def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
  109. if not self.kv_cache:
  110. self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
  111. if tokens.shape[-1] > self.initial_token_length:
  112. # only need to use the last token except in the first forward pass
  113. tokens = tokens[:, -1:]
  114. return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
  115. def cleanup_caching(self):
  116. for hook in self.hooks:
  117. hook.remove()
  118. self.kv_cache = {}
  119. self.hooks = []
  120. def rearrange_kv_cache(self, source_indices):
  121. for module, tensor in self.kv_cache.items():
  122. # update the key/value cache to contain the selected sequences
  123. self.kv_cache[module] = tensor[source_indices].detach()
  124. class SequenceRanker:
  125. def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]) -> List[int]:
  126. """
  127. Given a list of groups of samples and their cumulative log probabilities,
  128. return the indices of the samples in each group to select as the final result
  129. """
  130. raise NotImplementedError
  131. class MaximumLikelihoodRanker(SequenceRanker):
  132. """
  133. Select the sample with the highest log probabilities, penalized using either
  134. a simple length normalization or Google NMT paper's length penalty
  135. """
  136. def __init__(self, length_penalty: Optional[float]):
  137. self.length_penalty = length_penalty
  138. def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]):
  139. def scores(logprobs, lengths):
  140. result = []
  141. for logprob, length in zip(logprobs, lengths):
  142. if self.length_penalty is None:
  143. penalty = length
  144. else:
  145. # from the Google NMT paper
  146. penalty = ((5 + length) / 6) ** self.length_penalty
  147. result.append(logprob / penalty)
  148. return result
  149. # get the sequence with the highest score
  150. lengths = [[len(t) for t in s] for s in tokens]
  151. return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)]
  152. class TokenDecoder:
  153. def reset(self):
  154. """Initialize any stateful variables for decoding a new sequence"""
  155. def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
  156. """Specify how to select the next token, based on the current trace and logits
  157. Parameters
  158. ----------
  159. tokens : Tensor, shape = (n_batch, current_sequence_length)
  160. all tokens in the context so far, including the prefix and sot_sequence tokens
  161. logits : Tensor, shape = (n_batch, vocab_size)
  162. per-token logits of the probability distribution at the current step
  163. sum_logprobs : Tensor, shape = (n_batch)
  164. cumulative log probabilities for each sequence
  165. Returns
  166. -------
  167. tokens : Tensor, shape = (n_batch, current_sequence_length + 1)
  168. the tokens, appended with the selected next token
  169. completed : bool
  170. True if all sequences has reached the end of text
  171. """
  172. raise NotImplementedError
  173. def finalize(
  174. self, tokens: Tensor, sum_logprobs: Tensor
  175. ) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]]]:
  176. """Finalize search and return the final candidate sequences
  177. Parameters
  178. ----------
  179. tokens : Tensor, shape = (n_audio, n_group, current_sequence_length)
  180. all tokens in the context so far, including the prefix and sot_sequence
  181. sum_logprobs : Tensor, shape = (n_audio, n_group)
  182. cumulative log probabilities for each sequence
  183. Returns
  184. -------
  185. tokens : Sequence[Sequence[Tensor]], length = n_audio
  186. sequence of Tensors containing candidate token sequences, for each audio input
  187. sum_logprobs : List[List[float]], length = n_audio
  188. sequence of cumulative log probabilities corresponding to the above
  189. """
  190. raise NotImplementedError
  191. class GreedyDecoder(TokenDecoder):
  192. def __init__(self, temperature: float, eot: int):
  193. self.temperature = temperature
  194. self.eot = eot
  195. def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
  196. temperature = self.temperature
  197. if temperature == 0:
  198. next_tokens = logits.argmax(dim=-1)
  199. else:
  200. next_tokens = Categorical(logits=logits / temperature).sample()
  201. logprobs = F.log_softmax(logits.float(), dim=-1)
  202. current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
  203. sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)
  204. next_tokens[tokens[:, -1] == self.eot] = self.eot
  205. tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)
  206. completed = (tokens[:, -1] == self.eot).all()
  207. return tokens, completed
  208. def finalize(self, tokens: Tensor, sum_logprobs: Tensor):
  209. # make sure each sequence has at least one EOT token at the end
  210. tokens = F.pad(tokens, (0, 1), value=self.eot)
  211. return tokens, sum_logprobs.tolist()
  212. class BeamSearchDecoder(TokenDecoder):
  213. def __init__(self, beam_size: int, eot: int, inference: Inference, patience: Optional[float] = None):
  214. self.beam_size = beam_size
  215. self.eot = eot
  216. self.inference = inference
  217. self.patience = patience or 1.0
  218. self.max_candidates: int = round(beam_size * self.patience)
  219. self.finished_sequences = None
  220. assert self.max_candidates > 0, f"Invalid beam size ({beam_size}) or patience ({patience})"
  221. def reset(self):
  222. self.finished_sequences = None
  223. def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
  224. if tokens.shape[0] % self.beam_size != 0:
  225. raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
  226. n_audio = tokens.shape[0] // self.beam_size
  227. if self.finished_sequences is None: # for the first update
  228. self.finished_sequences = [{} for _ in range(n_audio)]
  229. logprobs = F.log_softmax(logits.float(), dim=-1)
  230. next_tokens, source_indices, finished_sequences = [], [], []
  231. for i in range(n_audio):
  232. scores, sources, finished = {}, {}, {}
  233. # STEP 1: calculate the cumulative log probabilities for possible candidates
  234. for j in range(self.beam_size):
  235. idx = i * self.beam_size + j
  236. prefix = tokens[idx].tolist()
  237. for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)):
  238. new_logprob = (sum_logprobs[idx] + logprob).item()
  239. sequence = tuple(prefix + [token.item()])
  240. scores[sequence] = new_logprob
  241. sources[sequence] = idx
  242. # STEP 2: rank the candidates and keep the top beam_size sequences for each audio
  243. saved = 0
  244. for sequence in sorted(scores, key=scores.get, reverse=True):
  245. if sequence[-1] == self.eot:
  246. finished[sequence] = scores[sequence]
  247. else:
  248. sum_logprobs[len(next_tokens)] = scores[sequence]
  249. next_tokens.append(sequence)
  250. source_indices.append(sources[sequence])
  251. saved += 1
  252. if saved == self.beam_size:
  253. break
  254. finished_sequences.append(finished)
  255. tokens = torch.tensor(next_tokens, device=tokens.device)
  256. self.inference.rearrange_kv_cache(source_indices)
  257. # add newly finished sequences to self.finished_sequences
  258. assert len(self.finished_sequences) == len(finished_sequences)
  259. for previously_finished, newly_finished in zip(self.finished_sequences, finished_sequences):
  260. for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
  261. if len(previously_finished) >= self.max_candidates:
  262. break # the candidate list is full
  263. previously_finished[seq] = newly_finished[seq]
  264. # mark as completed if all audio has enough number of samples
  265. completed = all(
  266. len(sequences) >= self.max_candidates for sequences in self.finished_sequences
  267. )
  268. return tokens, completed
  269. def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor):
  270. # collect all finished sequences, including patience, and add unfinished ones if not enough
  271. sum_logprobs = sum_logprobs.cpu()
  272. for i, sequences in enumerate(self.finished_sequences):
  273. if len(sequences) < self.beam_size: # when not enough sequences are finished
  274. for j in list(np.argsort(sum_logprobs[i]))[::-1]:
  275. sequence = preceding_tokens[i, j].tolist() + [self.eot]
  276. sequences[tuple(sequence)] = sum_logprobs[i][j].item()
  277. if len(sequences) >= self.beam_size:
  278. break
  279. tokens: List[List[Tensor]] = [
  280. [torch.tensor(seq) for seq in sequences.keys()] for sequences in self.finished_sequences
  281. ]
  282. sum_logprobs: List[List[float]] = [
  283. list(sequences.values()) for sequences in self.finished_sequences
  284. ]
  285. return tokens, sum_logprobs
  286. class LogitFilter:
  287. def apply(self, logits: Tensor, tokens: Tensor) -> None:
  288. """Apply any filtering or masking to logits in-place
  289. Parameters
  290. ----------
  291. logits : Tensor, shape = (n_batch, vocab_size)
  292. per-token logits of the probability distribution at the current step
  293. tokens : Tensor, shape = (n_batch, current_sequence_length)
  294. all tokens in the context so far, including the prefix and sot_sequence tokens
  295. """
  296. raise NotImplementedError
  297. class SuppressBlank(LogitFilter):
  298. def __init__(self, tokenizer: Tokenizer, sample_begin: int):
  299. self.tokenizer = tokenizer
  300. self.sample_begin = sample_begin
  301. def apply(self, logits: Tensor, tokens: Tensor):
  302. if tokens.shape[1] == self.sample_begin:
  303. logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
  304. class SuppressTokens(LogitFilter):
  305. def __init__(self, suppress_tokens: Sequence[int]):
  306. self.suppress_tokens = list(suppress_tokens)
  307. def apply(self, logits: Tensor, tokens: Tensor):
  308. logits[:, self.suppress_tokens] = -np.inf
  309. class ApplyTimestampRules(LogitFilter):
  310. def __init__(
  311. self, tokenizer: Tokenizer, sample_begin: int, max_initial_timestamp_index: Optional[int]
  312. ):
  313. self.tokenizer = tokenizer
  314. self.sample_begin = sample_begin
  315. self.max_initial_timestamp_index = max_initial_timestamp_index
  316. def apply(self, logits: Tensor, tokens: Tensor):
  317. # suppress <|notimestamps|> which is handled by without_timestamps
  318. if self.tokenizer.no_timestamps is not None:
  319. logits[:, self.tokenizer.no_timestamps] = -np.inf
  320. # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
  321. for k in range(tokens.shape[0]):
  322. seq = [t for t in tokens[k, self.sample_begin :].tolist()]
  323. last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
  324. penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
  325. if last_was_timestamp:
  326. if penultimate_was_timestamp: # has to be non-timestamp
  327. logits[k, self.tokenizer.timestamp_begin :] = -np.inf
  328. else: # cannot be normal text tokens
  329. logits[k, : self.tokenizer.eot] = -np.inf
  330. if tokens.shape[1] == self.sample_begin:
  331. # suppress generating non-timestamp tokens at the beginning
  332. logits[:, : self.tokenizer.timestamp_begin] = -np.inf
  333. # apply the `max_initial_timestamp` option
  334. if self.max_initial_timestamp_index is not None:
  335. last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
  336. logits[:, last_allowed + 1 :] = -np.inf
  337. # if sum of probability over timestamps is above any other token, sample timestamp
  338. logprobs = F.log_softmax(logits.float(), dim=-1)
  339. for k in range(tokens.shape[0]):
  340. timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(dim=-1)
  341. max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
  342. if timestamp_logprob > max_text_token_logprob:
  343. logits[k, : self.tokenizer.timestamp_begin] = -np.inf
  344. class DecodingTask:
  345. inference: Inference
  346. sequence_ranker: SequenceRanker
  347. decoder: TokenDecoder
  348. logit_filters: List[LogitFilter]
  349. def __init__(self, model: "Whisper", options: DecodingOptions):
  350. self.model = model
  351. language = options.language or "en"
  352. tokenizer = get_tokenizer(model.is_multilingual, language=language, task=options.task)
  353. self.tokenizer: Tokenizer = tokenizer
  354. self.options: DecodingOptions = self._verify_options(options)
  355. self.n_group: int = options.beam_size or options.best_of or 1
  356. self.n_ctx: int = model.dims.n_text_ctx
  357. self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2
  358. self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
  359. if self.options.without_timestamps:
  360. self.sot_sequence = tokenizer.sot_sequence_including_notimestamps
  361. self.initial_tokens: Tuple[int] = self._get_initial_tokens()
  362. self.sample_begin: int = len(self.initial_tokens)
  363. self.sot_index: int = self.initial_tokens.index(tokenizer.sot)
  364. # inference: implements the forward pass through the decoder, including kv caching
  365. self.inference = PyTorchInference(model, len(self.initial_tokens))
  366. # sequence ranker: implements how to rank a group of sampled sequences
  367. self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
  368. # decoder: implements how to select the next tokens, given the autoregressive distribution
  369. if options.beam_size is not None:
  370. self.decoder = BeamSearchDecoder(
  371. options.beam_size, tokenizer.eot, self.inference, options.patience
  372. )
  373. else:
  374. self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
  375. # logit filters: applies various rules to suppress or penalize certain tokens
  376. self.logit_filters = []
  377. if self.options.suppress_blank:
  378. self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
  379. if self.options.suppress_tokens:
  380. self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
  381. if not options.without_timestamps:
  382. precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds
  383. max_initial_timestamp_index = None
  384. if options.max_initial_timestamp:
  385. max_initial_timestamp_index = round(self.options.max_initial_timestamp / precision)
  386. self.logit_filters.append(
  387. ApplyTimestampRules(tokenizer, self.sample_begin, max_initial_timestamp_index)
  388. )
  389. def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
  390. if options.beam_size is not None and options.best_of is not None:
  391. raise ValueError("beam_size and best_of can't be given together")
  392. if options.temperature == 0:
  393. if options.best_of is not None:
  394. raise ValueError("best_of with greedy sampling (T=0) is not compatible")
  395. if options.patience is not None and options.beam_size is None:
  396. raise ValueError("patience requires beam_size to be given")
  397. if options.length_penalty is not None and not (0 <= options.length_penalty <= 1):
  398. raise ValueError("length_penalty (alpha) should be a value between 0 and 1")
  399. return options
  400. def _get_initial_tokens(self) -> Tuple[int]:
  401. tokens = list(self.sot_sequence)
  402. prefix = self.options.prefix
  403. prompt = self.options.prompt
  404. if prefix:
  405. prefix_tokens = (
  406. self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix
  407. )
  408. if self.sample_len is not None:
  409. max_prefix_len = self.n_ctx // 2 - self.sample_len
  410. prefix_tokens = prefix_tokens[-max_prefix_len:]
  411. tokens = tokens + prefix_tokens
  412. if prompt:
  413. prompt_tokens = (
  414. self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt
  415. )
  416. tokens = [self.tokenizer.sot_prev] + prompt_tokens[-(self.n_ctx // 2 - 1) :] + tokens
  417. return tuple(tokens)
  418. def _get_suppress_tokens(self) -> Tuple[int]:
  419. suppress_tokens = self.options.suppress_tokens
  420. if isinstance(suppress_tokens, str):
  421. suppress_tokens = [int(t) for t in suppress_tokens.split(",")]
  422. if -1 in suppress_tokens:
  423. suppress_tokens = [t for t in suppress_tokens if t >= 0]
  424. suppress_tokens.extend(self.tokenizer.non_speech_tokens)
  425. elif suppress_tokens is None or len(suppress_tokens) == 0:
  426. suppress_tokens = [] # interpret empty string as an empty list
  427. else:
  428. assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
  429. suppress_tokens.extend(
  430. [self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm]
  431. )
  432. if self.tokenizer.no_speech is not None:
  433. # no-speech probability is collected separately
  434. suppress_tokens.append(self.tokenizer.no_speech)
  435. return tuple(sorted(set(suppress_tokens)))
  436. def _get_audio_features(self, mel: Tensor):
  437. if self.options.fp16:
  438. mel = mel.half()
  439. if mel.shape[-2:] == (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state):
  440. # encoded audio features are given; skip audio encoding
  441. audio_features = mel
  442. else:
  443. audio_features = self.model.encoder(mel)
  444. if audio_features.dtype != (torch.float16 if self.options.fp16 else torch.float32):
  445. return TypeError(f"audio_features has an incorrect dtype: {audio_features.dtype}")
  446. return audio_features
  447. def _detect_language(self, audio_features: Tensor, tokens: Tensor):
  448. languages = [self.options.language] * audio_features.shape[0]
  449. lang_probs = None
  450. if self.options.language is None or self.options.task == "lang_id":
  451. lang_tokens, lang_probs = self.model.detect_language(audio_features, self.tokenizer)
  452. languages = [max(probs, key=probs.get) for probs in lang_probs]
  453. if self.options.language is None:
  454. tokens[:, self.sot_index + 1] = lang_tokens # write language tokens
  455. return languages, lang_probs
  456. def _main_loop(self, audio_features: Tensor, tokens: Tensor):
  457. assert audio_features.shape[0] == tokens.shape[0]
  458. n_batch = tokens.shape[0]
  459. sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
  460. no_speech_probs = [np.nan] * n_batch
  461. try:
  462. for i in range(self.sample_len):
  463. logits = self.inference.logits(tokens, audio_features)
  464. if i == 0 and self.tokenizer.no_speech is not None: # save no_speech_probs
  465. probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
  466. no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
  467. # now we need to consider the logits at the last token only
  468. logits = logits[:, -1]
  469. # apply the logit filters, e.g. for suppressing or applying penalty to
  470. for logit_filter in self.logit_filters:
  471. logit_filter.apply(logits, tokens)
  472. # expand the tokens tensor with the selected next tokens
  473. tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)
  474. if completed or tokens.shape[-1] > self.n_ctx:
  475. break
  476. finally:
  477. self.inference.cleanup_caching()
  478. return tokens, sum_logprobs, no_speech_probs
  479. @torch.no_grad()
  480. def run(self, mel: Tensor) -> List[DecodingResult]:
  481. self.decoder.reset()
  482. tokenizer: Tokenizer = self.tokenizer
  483. n_audio: int = mel.shape[0]
  484. audio_features: Tensor = self._get_audio_features(mel) # encoder forward pass
  485. tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)
  486. # detect language if requested, overwriting the language token
  487. languages, language_probs = self._detect_language(audio_features, tokens)
  488. if self.options.task == "lang_id":
  489. return [
  490. DecodingResult(audio_features=features, language=language, language_probs=probs)
  491. for features, language, probs in zip(audio_features, languages, language_probs)
  492. ]
  493. # repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
  494. audio_features = audio_features.repeat_interleave(self.n_group, dim=0)
  495. tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)
  496. # call the main sampling loop
  497. tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)
  498. # reshape the tensors to have (n_audio, n_group) as the first two dimensions
  499. audio_features = audio_features[:: self.n_group]
  500. no_speech_probs = no_speech_probs[:: self.n_group]
  501. assert audio_features.shape[0] == len(no_speech_probs) == n_audio
  502. tokens = tokens.reshape(n_audio, self.n_group, -1)
  503. sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)
  504. # get the final candidates for each group, and slice between the first sampled token and EOT
  505. tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
  506. tokens: List[List[Tensor]] = [
  507. [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] for s in tokens
  508. ]
  509. # select the top-ranked sample in each group
  510. selected = self.sequence_ranker.rank(tokens, sum_logprobs)
  511. tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
  512. texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
  513. sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
  514. avg_logprobs: List[float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)]
  515. fields = (texts, languages, tokens, audio_features, avg_logprobs, no_speech_probs)
  516. if len(set(map(len, fields))) != 1:
  517. raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
  518. return [
  519. DecodingResult(
  520. audio_features=features,
  521. language=language,
  522. tokens=tokens,
  523. text=text,
  524. avg_logprob=avg_logprob,
  525. no_speech_prob=no_speech_prob,
  526. temperature=self.options.temperature,
  527. compression_ratio=compression_ratio(text),
  528. )
  529. for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields)
  530. ]
  531. @torch.no_grad()
  532. def decode(model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions()) -> Union[DecodingResult, List[DecodingResult]]:
  533. """
  534. Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
  535. Parameters
  536. ----------
  537. model: Whisper
  538. the Whisper model instance
  539. mel: torch.Tensor, shape = (80, 3000) or (*, 80, 3000)
  540. A tensor containing the Mel spectrogram(s)
  541. options: DecodingOptions
  542. A dataclass that contains all necessary options for decoding 30-second segments
  543. Returns
  544. -------
  545. result: Union[DecodingResult, List[DecodingResult]]
  546. The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
  547. """
  548. single = mel.ndim == 2
  549. if single:
  550. mel = mel.unsqueeze(0)
  551. result = DecodingTask(model, options).run(mel)
  552. if single:
  553. result = result[0]
  554. return result