#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

import torch
from typing import List, Optional, Tuple

from funasr.register import tables
from funasr.models.specaug.specaug import SpecAug
from funasr.models.transducer.beam_search_transducer import Hypothesis
@tables.register("decoder_classes", "rnnt_decoder")
class RNNTDecoder(torch.nn.Module):
    """RNN prediction network (label decoder) for the RNN-Transducer.

    Args:
        vocab_size: Vocabulary size.
        embed_size: Embedding size.
        hidden_size: Hidden size.
        rnn_type: Decoder layers type ("lstm" or "gru").
        num_layers: Number of decoder layers.
        dropout_rate: Dropout rate for decoder layers.
        embed_dropout_rate: Dropout rate for embedding layer.
        embed_pad: Embedding padding symbol ID.
        use_embed_mask: If True, apply SpecAug-style time masking to the
            label embeddings during training.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_size: int = 256,
        hidden_size: int = 256,
        rnn_type: str = "lstm",
        num_layers: int = 1,
        dropout_rate: float = 0.0,
        embed_dropout_rate: float = 0.0,
        embed_pad: int = 0,
        use_embed_mask: bool = False,
    ) -> None:
        """Construct an RNNTDecoder object."""
        super().__init__()

        if rnn_type not in ("lstm", "gru"):
            raise ValueError(f"Not supported: rnn_type={rnn_type}")

        self.embed = torch.nn.Embedding(vocab_size, embed_size, padding_idx=embed_pad)
        self.dropout_embed = torch.nn.Dropout(p=embed_dropout_rate)

        rnn_class = torch.nn.LSTM if rnn_type == "lstm" else torch.nn.GRU
        # One single-layer RNN module per decoder layer (instead of one
        # multi-layer RNN) so per-layer hidden states can be sliced and
        # re-assembled independently during beam search.
        self.rnn = torch.nn.ModuleList(
            [rnn_class(embed_size, hidden_size, 1, batch_first=True)]
        )
        for _ in range(1, num_layers):
            self.rnn += [rnn_class(hidden_size, hidden_size, 1, batch_first=True)]

        self.dropout_rnn = torch.nn.ModuleList(
            [torch.nn.Dropout(p=dropout_rate) for _ in range(num_layers)]
        )

        self.dlayers = num_layers
        self.dtype = rnn_type
        self.output_size = hidden_size
        self.vocab_size = vocab_size
        # Snapshot of the parameter device at construction time; updated
        # explicitly via set_device() after .to(device) moves.
        self.device = next(self.parameters()).device
        # Cache of (dec_out, dec_state) keyed by label-sequence string so
        # beam-search hypotheses sharing a prefix reuse one forward pass.
        # NOTE(review): grows unboundedly — presumably cleared by the beam
        # search between utterances; verify against the caller.
        self.score_cache = {}

        self.use_embed_mask = use_embed_mask
        if self.use_embed_mask:
            self._embed_mask = SpecAug(
                time_mask_width_range=3,
                num_time_mask=4,
                apply_freq_mask=False,
                apply_time_warp=False,
            )

    def forward(
        self,
        labels: torch.Tensor,
        label_lens: torch.Tensor,
        states: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None,
    ) -> torch.Tensor:
        """Encode source label sequences.

        Args:
            labels: Label ID sequences. (B, L)
            label_lens: Label sequence lengths. (B,)
            states: Decoder hidden states.
                ((N, B, D_dec), (N, B, D_dec) or None) or None

        Returns:
            dec_out: Decoder output sequences. (B, U, D_dec)
        """
        if states is None:
            states = self.init_state(labels.size(0))

        dec_embed = self.dropout_embed(self.embed(labels))
        if self.use_embed_mask and self.training:
            # SpecAug returns (masked, lengths); keep only the masked embeddings.
            dec_embed = self._embed_mask(dec_embed, label_lens)[0]
        dec_out, states = self.rnn_forward(dec_embed, states)

        return dec_out

    def rnn_forward(
        self,
        x: torch.Tensor,
        state: Tuple[torch.Tensor, Optional[torch.Tensor]],
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """Encode source label sequences through the stacked RNN layers.

        Args:
            x: RNN input sequences. (B, U, D_emb)
            state: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)

        Returns:
            x: RNN output sequences. (B, U, D_dec)
            (h_next, c_next): Decoder hidden states.
                ((N, B, D_dec), (N, B, D_dec) or None)
        """
        h_prev, c_prev = state
        # Fresh zero tensors that the per-layer slices below are written into.
        h_next, c_next = self.init_state(x.size(0))

        for layer in range(self.dlayers):
            if self.dtype == "lstm":
                x, (h_next[layer : layer + 1], c_next[layer : layer + 1]) = self.rnn[
                    layer
                ](x, hx=(h_prev[layer : layer + 1], c_prev[layer : layer + 1]))
            else:
                # GRU carries no cell state; c_next stays None.
                x, h_next[layer : layer + 1] = self.rnn[layer](
                    x, hx=h_prev[layer : layer + 1]
                )

            x = self.dropout_rnn[layer](x)

        return x, (h_next, c_next)

    def score(
        self,
        label: torch.Tensor,
        label_sequence: List[int],
        dec_state: Tuple[torch.Tensor, Optional[torch.Tensor]],
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """One-step forward hypothesis.

        Args:
            label: Previous label. (1, 1)
            label_sequence: Current label sequence.
            dec_state: Previous decoder hidden states.
                ((N, 1, D_dec), (N, 1, D_dec) or None)

        Returns:
            dec_out: Decoder output sequence. (1, D_dec)
            dec_state: Decoder hidden states.
                ((N, 1, D_dec), (N, 1, D_dec) or None)
        """
        str_labels = "_".join(map(str, label_sequence))

        if str_labels in self.score_cache:
            dec_out, dec_state = self.score_cache[str_labels]
        else:
            dec_embed = self.embed(label)
            dec_out, dec_state = self.rnn_forward(dec_embed, dec_state)

            self.score_cache[str_labels] = (dec_out, dec_state)

        return dec_out[0], dec_state

    def batch_score(
        self,
        hyps: List[Hypothesis],
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """One-step forward hypotheses.

        Args:
            hyps: Hypotheses.

        Returns:
            dec_out: Decoder output sequences. (B, D_dec)
            states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
        """
        # torch.tensor instead of the legacy torch.LongTensor constructor:
        # the legacy constructor rejects a CUDA `device` kwarg at runtime.
        labels = torch.tensor(
            [[h.yseq[-1]] for h in hyps], dtype=torch.long, device=self.device
        )
        dec_embed = self.embed(labels)

        states = self.create_batch_states([h.dec_state for h in hyps])
        dec_out, states = self.rnn_forward(dec_embed, states)

        return dec_out.squeeze(1), states

    def set_device(self, device: torch.device) -> None:
        """Set GPU device to use.

        Args:
            device: Device ID.
        """
        self.device = device

    def init_state(
        self, batch_size: int
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Initialize decoder states.

        Args:
            batch_size: Batch size.

        Returns:
            : Initial decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
        """
        h_n = torch.zeros(
            self.dlayers,
            batch_size,
            self.output_size,
            device=self.device,
        )

        if self.dtype == "lstm":
            c_n = torch.zeros(
                self.dlayers,
                batch_size,
                self.output_size,
                device=self.device,
            )

            return (h_n, c_n)

        return (h_n, None)

    def select_state(
        self, states: Tuple[torch.Tensor, Optional[torch.Tensor]], idx: int
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Get specified ID state from decoder hidden states.

        Args:
            states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
            idx: State ID to extract.

        Returns:
            : Decoder hidden state for given ID. ((N, 1, D_dec), (N, 1, D_dec) or None)
        """
        return (
            states[0][:, idx : idx + 1, :],
            states[1][:, idx : idx + 1, :] if self.dtype == "lstm" else None,
        )

    def create_batch_states(
        self,
        new_states: List[Tuple[torch.Tensor, Optional[torch.Tensor]]],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Create decoder hidden states by batching per-hypothesis states.

        Args:
            new_states: Decoder hidden states. [N x ((1, D_dec), (1, D_dec) or None)]

        Returns:
            states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
        """
        return (
            torch.cat([s[0] for s in new_states], dim=1),
            torch.cat([s[1] for s in new_states], dim=1)
            if self.dtype == "lstm"
            else None,
        )
|