- from typing import List
- from typing import Tuple
- import logging
- import torch
- import torch.nn as nn
- import numpy as np
- from funasr.modules.streaming_utils import utils as myutils
- from funasr.models.decoder.transformer_decoder import BaseTransformerDecoder
- from funasr.modules.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
- from funasr.modules.embedding import PositionalEncoding
- from funasr.modules.layer_norm import LayerNorm
- from funasr.modules.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
- from funasr.modules.repeat import repeat
- class DecoderLayerSANM(nn.Module):
- """Single decoder layer module.
- Args:
- size (int): Input dimension.
- self_attn (torch.nn.Module): Self-attention module instance.
- `MultiHeadedAttentionSANMDecoder` instance can be used as the argument.
- src_attn (torch.nn.Module): Source-attention (cross-attention) module instance.
- `MultiHeadedAttentionCrossAtt` instance can be used as the argument.
- feed_forward (torch.nn.Module): Feed-forward module instance.
- `PositionwiseFeedForwardDecoderSANM` instance can be used as the argument.
- dropout_rate (float): Dropout rate.
- normalize_before (bool): Whether to use layer_norm before the first block.
- concat_after (bool): Whether to concat attention layer's input and output.
- if True, additional linear will be applied.
- i.e. x -> x + linear(concat(x, att(x)))
- if False, no additional linear will be applied. i.e. x -> x + att(x)
- """
- def __init__(
- self,
- size,
- self_attn,
- src_attn,
- feed_forward,
- dropout_rate,
- normalize_before=True,
- concat_after=False,
- ):
- """Construct an DecoderLayer object."""
- super(DecoderLayerSANM, self).__init__()
- self.size = size
- self.self_attn = self_attn
- self.src_attn = src_attn
- self.feed_forward = feed_forward
- self.norm1 = LayerNorm(size)
- if self_attn is not None:
- self.norm2 = LayerNorm(size)
- if src_attn is not None:
- self.norm3 = LayerNorm(size)
- self.dropout = nn.Dropout(dropout_rate)
- self.normalize_before = normalize_before
- self.concat_after = concat_after
- if self.concat_after:
- self.concat_linear1 = nn.Linear(size + size, size)
- self.concat_linear2 = nn.Linear(size + size, size)
- def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
- """Compute decoded features.
- Args:
- tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
- tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
- memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
- memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
- cache (List[torch.Tensor]): List of cached tensors.
- Each tensor shape should be (#batch, maxlen_out - 1, size).
- Returns:
- torch.Tensor: Output tensor (#batch, maxlen_out, size).
- torch.Tensor: Mask for output tensor (#batch, maxlen_out).
- torch.Tensor: Encoded memory (#batch, maxlen_in, size).
- torch.Tensor: Encoded memory mask (#batch, maxlen_in).
- """
- # tgt = self.dropout(tgt)
- residual = tgt
- if self.normalize_before:
- tgt = self.norm1(tgt)
- tgt = self.feed_forward(tgt)
- x = tgt
- if self.self_attn:
- if self.normalize_before:
- tgt = self.norm2(tgt)
- x, _ = self.self_attn(tgt, tgt_mask)
- x = residual + self.dropout(x)
- if self.src_attn is not None:
- residual = x
- if self.normalize_before:
- x = self.norm3(x)
- x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
- return x, tgt_mask, memory, memory_mask, cache
- def forward_one_step(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
- """Compute decoded features.
- Args:
- tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
- tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
- memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
- memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
- cache (List[torch.Tensor]): List of cached tensors.
- Each tensor shape should be (#batch, maxlen_out - 1, size).
- Returns:
- torch.Tensor: Output tensor (#batch, maxlen_out, size).
- torch.Tensor: Mask for output tensor (#batch, maxlen_out).
- torch.Tensor: Encoded memory (#batch, maxlen_in, size).
- torch.Tensor: Encoded memory mask (#batch, maxlen_in).
- """
- # tgt = self.dropout(tgt)
- residual = tgt
- if self.normalize_before:
- tgt = self.norm1(tgt)
- tgt = self.feed_forward(tgt)
- x = tgt
- if self.self_attn:
- if self.normalize_before:
- tgt = self.norm2(tgt)
- if self.training:
- cache = None
- x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
- x = residual + self.dropout(x)
- if self.src_attn is not None:
- residual = x
- if self.normalize_before:
- x = self.norm3(x)
- x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
- return x, tgt_mask, memory, memory_mask, cache
- def forward_chunk(self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0):
- """Compute decoded features.
- Args:
- tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
- tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
- memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
- memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
- cache (List[torch.Tensor]): List of cached tensors.
- Each tensor shape should be (#batch, maxlen_out - 1, size).
- Returns:
- torch.Tensor: Output tensor(#batch, maxlen_out, size).
- torch.Tensor: Mask for output tensor (#batch, maxlen_out).
- torch.Tensor: Encoded memory (#batch, maxlen_in, size).
- torch.Tensor: Encoded memory mask (#batch, maxlen_in).
- """
- residual = tgt
- if self.normalize_before:
- tgt = self.norm1(tgt)
- tgt = self.feed_forward(tgt)
- x = tgt
- if self.self_attn:
- if self.normalize_before:
- tgt = self.norm2(tgt)
- x, fsmn_cache = self.self_attn(tgt, None, fsmn_cache)
- x = residual + self.dropout(x)
- if self.src_attn is not None:
- residual = x
- if self.normalize_before:
- x = self.norm3(x)
- x, opt_cache = self.src_attn.forward_chunk(x, memory, opt_cache, chunk_size, look_back)
- x = residual + x
- return x, memory, fsmn_cache, opt_cache
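- # Editor's sketch (illustrative, not part of the original API): how a single
- # DecoderLayerSANM is wired and called, mirroring the construction used by the
- # decoder classes below. Dimensions and the helper name are assumptions.
- def _demo_decoder_layer_sanm():
-     size, heads, linear_units, kernel_size = 256, 4, 2048, 21
-     layer = DecoderLayerSANM(
-         size,
-         MultiHeadedAttentionSANMDecoder(size, 0.1, kernel_size, sanm_shfit=0),
-         MultiHeadedAttentionCrossAtt(heads, size, 0.1),
-         PositionwiseFeedForwardDecoderSANM(size, linear_units, 0.1),
-         dropout_rate=0.1,
-     )
-     tgt = torch.randn(2, 10, size)      # (batch, maxlen_out, size)
-     tgt_mask = torch.ones(2, 10, 1)     # mask shaped as in the decoders' forward()
-     memory = torch.randn(2, 30, size)   # (batch, maxlen_in, size)
-     memory_mask = torch.ones(2, 1, 30)
-     x, _, _, _, _ = layer(tgt, tgt_mask, memory, memory_mask)
-     return x.shape                      # expected: torch.Size([2, 10, 256])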
- class FsmnDecoderSCAMAOpt(BaseTransformerDecoder):
- """
- Author: Speech Lab of DAMO Academy, Alibaba Group
- SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
- https://arxiv.org/abs/2006.01713
- """
- def __init__(
- self,
- vocab_size: int,
- encoder_output_size: int,
- attention_heads: int = 4,
- linear_units: int = 2048,
- num_blocks: int = 6,
- dropout_rate: float = 0.1,
- positional_dropout_rate: float = 0.1,
- self_attention_dropout_rate: float = 0.0,
- src_attention_dropout_rate: float = 0.0,
- input_layer: str = "embed",
- use_output_layer: bool = True,
- pos_enc_class=PositionalEncoding,
- normalize_before: bool = True,
- concat_after: bool = False,
- att_layer_num: int = 6,
- kernel_size: int = 21,
- sanm_shfit: int = None,
- concat_embeds: bool = False,
- attention_dim: int = None,
- tf2torch_tensor_name_prefix_torch: str = "decoder",
- tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
- embed_tensor_name_prefix_tf: str = None,
- ):
- super().__init__(
- vocab_size=vocab_size,
- encoder_output_size=encoder_output_size,
- dropout_rate=dropout_rate,
- positional_dropout_rate=positional_dropout_rate,
- input_layer=input_layer,
- use_output_layer=use_output_layer,
- pos_enc_class=pos_enc_class,
- normalize_before=normalize_before,
- )
- if attention_dim is None:
- attention_dim = encoder_output_size
- if input_layer == "embed":
- self.embed = torch.nn.Sequential(
- torch.nn.Embedding(vocab_size, attention_dim),
- )
- elif input_layer == "linear":
- self.embed = torch.nn.Sequential(
- torch.nn.Linear(vocab_size, attention_dim),
- torch.nn.LayerNorm(attention_dim),
- torch.nn.Dropout(dropout_rate),
- torch.nn.ReLU(),
- pos_enc_class(attention_dim, positional_dropout_rate),
- )
- else:
- raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
- self.normalize_before = normalize_before
- if self.normalize_before:
- self.after_norm = LayerNorm(attention_dim)
- if use_output_layer:
- self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
- else:
- self.output_layer = None
- self.att_layer_num = att_layer_num
- self.num_blocks = num_blocks
- if sanm_shfit is None:
- sanm_shfit = (kernel_size - 1) // 2
- self.decoders = repeat(
- att_layer_num,
- lambda lnum: DecoderLayerSANM(
- attention_dim,
- MultiHeadedAttentionSANMDecoder(
- attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
- ),
- MultiHeadedAttentionCrossAtt(
- attention_heads, attention_dim, src_attention_dropout_rate, encoder_output_size=encoder_output_size
- ),
- PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
- dropout_rate,
- normalize_before,
- concat_after,
- ),
- )
- if num_blocks - att_layer_num <= 0:
- self.decoders2 = None
- else:
- self.decoders2 = repeat(
- num_blocks - att_layer_num,
- lambda lnum: DecoderLayerSANM(
- attention_dim,
- MultiHeadedAttentionSANMDecoder(
- attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
- ),
- None,
- PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
- dropout_rate,
- normalize_before,
- concat_after,
- ),
- )
- self.decoders3 = repeat(
- 1,
- lambda lnum: DecoderLayerSANM(
- attention_dim,
- None,
- None,
- PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
- dropout_rate,
- normalize_before,
- concat_after,
- ),
- )
- if concat_embeds:
- self.embed_concat_ffn = repeat(
- 1,
- lambda lnum: DecoderLayerSANM(
- attention_dim + encoder_output_size,
- None,
- None,
- PositionwiseFeedForwardDecoderSANM(attention_dim + encoder_output_size, linear_units, dropout_rate,
- adim=attention_dim),
- dropout_rate,
- normalize_before,
- concat_after,
- ),
- )
- else:
- self.embed_concat_ffn = None
- self.concat_embeds = concat_embeds
- self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
- self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
- self.embed_tensor_name_prefix_tf = embed_tensor_name_prefix_tf
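- # Layer layout built above (editor's note): self.decoders holds the first
- # att_layer_num blocks (FSMN self-attention + SCAMA cross-attention over the
- # encoder memory), self.decoders2 holds the remaining num_blocks - att_layer_num
- # FSMN-only blocks (None when att_layer_num == num_blocks), and self.decoders3
- # is a single FFN-only block; embed_concat_ffn exists only when concat_embeds=True.
- # E.g. num_blocks=6, att_layer_num=3 gives 3 cross-attention blocks, 3 FSMN-only
- # blocks, plus the final FFN block.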
- def forward(
- self,
- hs_pad: torch.Tensor,
- hlens: torch.Tensor,
- ys_in_pad: torch.Tensor,
- ys_in_lens: torch.Tensor,
- chunk_mask: torch.Tensor = None,
- pre_acoustic_embeds: torch.Tensor = None,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Forward decoder.
- Args:
- hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
- hlens: (batch)
- ys_in_pad:
- input token ids, int64 (batch, maxlen_out)
- if input_layer == "embed"
- input tensor (batch, maxlen_out, #mels) in the other cases
- ys_in_lens: (batch)
- Returns:
- (tuple): tuple containing:
- x: decoded token score before softmax (batch, maxlen_out, token)
- if use_output_layer is True,
- olens: (batch, )
- """
- tgt = ys_in_pad
- tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
- memory = hs_pad
- memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
- if chunk_mask is not None:
- memory_mask = memory_mask * chunk_mask
- if tgt_mask.size(1) != memory_mask.size(1):
- memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)
- x = self.embed(tgt)
- if pre_acoustic_embeds is not None and self.concat_embeds:
- x = torch.cat((x, pre_acoustic_embeds), dim=-1)
- x, _, _, _, _ = self.embed_concat_ffn(x, None, None, None, None)
- x, tgt_mask, memory, memory_mask, _ = self.decoders(
- x, tgt_mask, memory, memory_mask
- )
- if self.decoders2 is not None:
- x, tgt_mask, memory, memory_mask, _ = self.decoders2(
- x, tgt_mask, memory, memory_mask
- )
- x, tgt_mask, memory, memory_mask, _ = self.decoders3(
- x, tgt_mask, memory, memory_mask
- )
- if self.normalize_before:
- x = self.after_norm(x)
- if self.output_layer is not None:
- x = self.output_layer(x)
- olens = tgt_mask.sum(1)
- return x, olens
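- # Editor's sketch of a full (offline) forward pass; shapes follow the masks built
- # above and all concrete values are illustrative assumptions:
- #   decoder = FsmnDecoderSCAMAOpt(vocab_size=4235, encoder_output_size=256)
- #   hs_pad = torch.randn(2, 50, 256); hlens = torch.tensor([50, 42])
- #   ys_in_pad = torch.randint(0, 4235, (2, 12)); ys_in_lens = torch.tensor([12, 9])
- #   scores, olens = decoder(hs_pad, hlens, ys_in_pad, ys_in_lens)
- #   # scores: (2, 12, 4235) token scores before softmax; olens: valid output lengths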
- def score(self, ys, state, x, x_mask=None, pre_acoustic_embeds: torch.Tensor = None, ):
- """Score."""
- ys_mask = myutils.sequence_mask(torch.tensor([len(ys)], dtype=torch.int32), device=x.device)[:, :, None]
- logp, state = self.forward_one_step(
- ys.unsqueeze(0), ys_mask, x.unsqueeze(0), memory_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds,
- cache=state
- )
- return logp.squeeze(0), state
- def forward_one_step(
- self,
- tgt: torch.Tensor,
- tgt_mask: torch.Tensor,
- memory: torch.Tensor,
- memory_mask: torch.Tensor = None,
- pre_acoustic_embeds: torch.Tensor = None,
- cache: List[torch.Tensor] = None,
- ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
- """Forward one step.
- Args:
- tgt: input token ids, int64 (batch, maxlen_out)
- tgt_mask: input token mask, (batch, maxlen_out)
- dtype=torch.uint8 before PyTorch 1.2,
- dtype=torch.bool in PyTorch 1.2 and later
- memory: encoded memory, float32 (batch, maxlen_in, feat)
- cache: cached output list of (batch, max_time_out-1, size)
- Returns:
- y, cache: NN output value and cache per `self.decoders`.
- `y.shape` is (batch, maxlen_out, token)
- """
- x = tgt[:, -1:]
- tgt_mask = None
- x = self.embed(x)
- if pre_acoustic_embeds is not None and self.concat_embeds:
- x = torch.cat((x, pre_acoustic_embeds), dim=-1)
- x, _, _, _, _ = self.embed_concat_ffn(x, None, None, None, None)
- if cache is None:
- cache_layer_num = len(self.decoders)
- if self.decoders2 is not None:
- cache_layer_num += len(self.decoders2)
- cache = [None] * cache_layer_num
- new_cache = []
- # for c, decoder in zip(cache, self.decoders):
- for i in range(self.att_layer_num):
- decoder = self.decoders[i]
- c = cache[i]
- x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
- x, tgt_mask, memory, memory_mask, cache=c
- )
- new_cache.append(c_ret)
- if self.num_blocks - self.att_layer_num >= 1:
- for i in range(self.num_blocks - self.att_layer_num):
- j = i + self.att_layer_num
- decoder = self.decoders2[i]
- c = cache[j]
- x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
- x, tgt_mask, memory, memory_mask, cache=c
- )
- new_cache.append(c_ret)
- for decoder in self.decoders3:
- x, tgt_mask, memory, memory_mask, _ = decoder.forward_one_step(
- x, tgt_mask, memory, None, cache=None
- )
- if self.normalize_before:
- y = self.after_norm(x[:, -1])
- else:
- y = x[:, -1]
- if self.output_layer is not None:
- y = self.output_layer(y)
- y = torch.log_softmax(y, dim=-1)
- return y, new_cache
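- # Editor's sketch of step-by-step decoding with forward_one_step (illustrative;
- # sos_id, max_len and mask are assumptions). The cache is a flat list with one
- # entry per layer in self.decoders and self.decoders2; pass the returned list
- # back in on the next step:
- #   ys = torch.full((1, 1), sos_id, dtype=torch.long); cache = None
- #   for _ in range(max_len):
- #       logp, cache = decoder.forward_one_step(ys, None, memory, memory_mask=mask, cache=cache)
- #       next_token = logp.argmax(-1, keepdim=True)   # logp: (batch, vocab) log-probabilities
- #       ys = torch.cat([ys, next_token], dim=-1)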
- def gen_tf2torch_map_dict(self):
-
- tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
- tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
- embed_tensor_name_prefix_tf = self.embed_tensor_name_prefix_tf if self.embed_tensor_name_prefix_tf is not None else tensor_name_prefix_tf
- map_dict_local = {
-
- ## decoder
- # ffn
- "{}.decoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (1024,256),(1,256,1024)
- "{}.decoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.decoders.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.decoders.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.decoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (256,1024),(1,1024,256)
-
- # fsmn
- "{}.decoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/gamma".format(
- tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/beta".format(
- tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/depth_conv_w".format(
- tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 2, 0),
- }, # (256,1,31),(1,31,256,1)
- # src att
- "{}.decoders.layeridx.norm3.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders.layeridx.norm3.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders.layeridx.src_attn.linear_q.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (256,256),(1,256,256)
- "{}.decoders.layeridx.src_attn.linear_q.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders.layeridx.src_attn.linear_k_v.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (1024,256),(1,256,1024)
- "{}.decoders.layeridx.src_attn.linear_k_v.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.decoders.layeridx.src_attn.linear_out.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (256,256),(1,256,256)
- "{}.decoders.layeridx.src_attn.linear_out.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- # dnn
- "{}.decoders3.layeridx.norm1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders3.layeridx.norm1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders3.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_dnn_layer_layeridx/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (1024,256),(1,256,1024)
- "{}.decoders3.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_dnn_layer_layeridx/conv1d/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.decoders3.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.decoders3.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.decoders3.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_dnn_layer_layeridx/conv1d_1/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (256,1024),(1,1024,256)
-
- # embed_concat_ffn
- "{}.embed_concat_ffn.layeridx.norm1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/cif_concat/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.embed_concat_ffn.layeridx.norm1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/cif_concat/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.embed_concat_ffn.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/cif_concat/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (1024,256),(1,256,1024)
- "{}.embed_concat_ffn.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/cif_concat/conv1d/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.embed_concat_ffn.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
- {"name": "{}/cif_concat/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.embed_concat_ffn.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
- {"name": "{}/cif_concat/LayerNorm_1/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.embed_concat_ffn.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
- {"name": "{}/cif_concat/conv1d_1/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (256,1024),(1,1024,256)
-
- # out norm
- "{}.after_norm.weight".format(tensor_name_prefix_torch):
- {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.after_norm.bias".format(tensor_name_prefix_torch):
- {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
-
- # in embed
- "{}.embed.0.weight".format(tensor_name_prefix_torch):
- {"name": "{}/w_embs".format(embed_tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (4235,256),(4235,256)
-
- # out layer
- "{}.output_layer.weight".format(tensor_name_prefix_torch):
- {"name": ["{}/dense/kernel".format(tensor_name_prefix_tf),
- "{}/w_embs".format(embed_tensor_name_prefix_tf)],
- "squeeze": [None, None],
- "transpose": [(1, 0), None],
- }, # (4235,256),(256,4235)
- "{}.output_layer.bias".format(tensor_name_prefix_torch):
- {"name": ["{}/dense/bias".format(tensor_name_prefix_tf),
- "seq2seq/2bias" if tensor_name_prefix_tf == "seq2seq/decoder/inputter_1" else "seq2seq/bias"],
- "squeeze": [None, None],
- "transpose": [None, None],
- }, # (4235,),(4235,)
-
- }
- return map_dict_local
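- # How the map above is consumed by convert_tf2torch below (editor's note): the
- # "layeridx" placeholder in both the torch key and the TF name is replaced by the
- # concrete layer index, e.g. "decoder.decoders.3.norm1.weight" is loaded from
- # "seq2seq/decoder/decoder_fsmn_layer_3/decoder_ffn/LayerNorm/gamma"; the optional
- # "squeeze"/"transpose" fields then reshape the TF value, e.g. a (1, 256, 1024)
- # conv1d kernel becomes the (1024, 256) weight of a torch Linear.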
- def convert_tf2torch(self,
- var_dict_tf,
- var_dict_torch,
- ):
-
- map_dict = self.gen_tf2torch_map_dict()
- var_dict_torch_update = dict()
- decoder_layeridx_sets = set()
- for name in sorted(var_dict_torch.keys(), reverse=False):
- names = name.split('.')
- if names[0] == self.tf2torch_tensor_name_prefix_torch:
- if names[1] == "decoders":
- layeridx = int(names[2])
- name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
- layeridx_bias = 0
- layeridx += layeridx_bias
- decoder_layeridx_sets.add(layeridx)
- if name_q in map_dict.keys():
- name_v = map_dict[name_q]["name"]
- name_tf = name_v.replace("layeridx", "{}".format(layeridx))
- data_tf = var_dict_tf[name_tf]
- if map_dict[name_q]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
- if map_dict[name_q]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
- var_dict_tf[name_tf].shape))
-
- elif names[1] == "decoders2":
- layeridx = int(names[2])
- name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
- name_q = name_q.replace("decoders2", "decoders")
- layeridx_bias = len(decoder_layeridx_sets)
-
- layeridx += layeridx_bias
- if "decoders." in name:
- decoder_layeridx_sets.add(layeridx)
- if name_q in map_dict.keys():
- name_v = map_dict[name_q]["name"]
- name_tf = name_v.replace("layeridx", "{}".format(layeridx))
- data_tf = var_dict_tf[name_tf]
- if map_dict[name_q]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
- if map_dict[name_q]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
- var_dict_tf[name_tf].shape))
-
- elif names[1] == "decoders3":
- layeridx = int(names[2])
- name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
-
- layeridx_bias = 0
- layeridx += layeridx_bias
- if "decoders." in name:
- decoder_layeridx_sets.add(layeridx)
- if name_q in map_dict.keys():
- name_v = map_dict[name_q]["name"]
- name_tf = name_v.replace("layeridx", "{}".format(layeridx))
- data_tf = var_dict_tf[name_tf]
- if map_dict[name_q]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
- if map_dict[name_q]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
- var_dict_tf[name_tf].shape))
-
- elif names[1] == "embed" or names[1] == "output_layer":
- name_tf = map_dict[name]["name"]
- if isinstance(name_tf, list):
- idx_list = 0
- if name_tf[idx_list] in var_dict_tf.keys():
- pass
- else:
- idx_list = 1
- data_tf = var_dict_tf[name_tf[idx_list]]
- if map_dict[name]["squeeze"][idx_list] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"][idx_list])
- if map_dict[name]["transpose"][idx_list] is not None:
- data_tf = np.transpose(data_tf, map_dict[name]["transpose"][idx_list])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(),
- name_tf[idx_list],
- var_dict_tf[name_tf[
- idx_list]].shape))
-
- else:
- data_tf = var_dict_tf[name_tf]
- if map_dict[name]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
- if map_dict[name]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
- var_dict_tf[name_tf].shape))
-
- elif names[1] == "after_norm":
- name_tf = map_dict[name]["name"]
- data_tf = var_dict_tf[name_tf]
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
- var_dict_tf[name_tf].shape))
-
- elif names[1] == "embed_concat_ffn":
- layeridx = int(names[2])
- name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
-
- layeridx_bias = 0
- layeridx += layeridx_bias
- if "decoders." in name:
- decoder_layeridx_sets.add(layeridx)
- if name_q in map_dict.keys():
- name_v = map_dict[name_q]["name"]
- name_tf = name_v.replace("layeridx", "{}".format(layeridx))
- data_tf = var_dict_tf[name_tf]
- if map_dict[name_q]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
- if map_dict[name_q]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
- var_dict_tf[name_tf].shape))
-
- return var_dict_torch_update
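- # Editor's sketch of the TF-to-PyTorch weight import flow (the helper name and the
- # way the TF variables are obtained are assumptions, not part of this module):
- def _demo_convert_tf2torch(decoder, tf_vars):
-     """tf_vars is assumed to map TF variable names to numpy arrays."""
-     state = decoder.state_dict()
-     state.update(decoder.convert_tf2torch(tf_vars, state))
-     decoder.load_state_dict(state)
-     return decoder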
- class ParaformerSANMDecoder(BaseTransformerDecoder):
- """
- Author: Speech Lab of DAMO Academy, Alibaba Group
- Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
- https://arxiv.org/abs/2206.08317
- """
- def __init__(
- self,
- vocab_size: int,
- encoder_output_size: int,
- attention_heads: int = 4,
- linear_units: int = 2048,
- num_blocks: int = 6,
- dropout_rate: float = 0.1,
- positional_dropout_rate: float = 0.1,
- self_attention_dropout_rate: float = 0.0,
- src_attention_dropout_rate: float = 0.0,
- input_layer: str = "embed",
- use_output_layer: bool = True,
- pos_enc_class=PositionalEncoding,
- normalize_before: bool = True,
- concat_after: bool = False,
- att_layer_num: int = 6,
- kernel_size: int = 21,
- sanm_shfit: int = 0,
- lora_list: List[str] = None,
- lora_rank: int = 8,
- lora_alpha: int = 16,
- lora_dropout: float = 0.1,
- chunk_multiply_factor: tuple = (1,),
- tf2torch_tensor_name_prefix_torch: str = "decoder",
- tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
- ):
- super().__init__(
- vocab_size=vocab_size,
- encoder_output_size=encoder_output_size,
- dropout_rate=dropout_rate,
- positional_dropout_rate=positional_dropout_rate,
- input_layer=input_layer,
- use_output_layer=use_output_layer,
- pos_enc_class=pos_enc_class,
- normalize_before=normalize_before,
- )
- attention_dim = encoder_output_size
- if input_layer == "embed":
- self.embed = torch.nn.Sequential(
- torch.nn.Embedding(vocab_size, attention_dim),
- # pos_enc_class(attention_dim, positional_dropout_rate),
- )
- elif input_layer == "linear":
- self.embed = torch.nn.Sequential(
- torch.nn.Linear(vocab_size, attention_dim),
- torch.nn.LayerNorm(attention_dim),
- torch.nn.Dropout(dropout_rate),
- torch.nn.ReLU(),
- pos_enc_class(attention_dim, positional_dropout_rate),
- )
- else:
- raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
- self.normalize_before = normalize_before
- if self.normalize_before:
- self.after_norm = LayerNorm(attention_dim)
- if use_output_layer:
- self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
- else:
- self.output_layer = None
- self.att_layer_num = att_layer_num
- self.num_blocks = num_blocks
- if sanm_shfit is None:
- sanm_shfit = (kernel_size - 1) // 2
- self.decoders = repeat(
- att_layer_num,
- lambda lnum: DecoderLayerSANM(
- attention_dim,
- MultiHeadedAttentionSANMDecoder(
- attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
- ),
- MultiHeadedAttentionCrossAtt(
- attention_heads, attention_dim, src_attention_dropout_rate, lora_list, lora_rank, lora_alpha, lora_dropout
- ),
- PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
- dropout_rate,
- normalize_before,
- concat_after,
- ),
- )
- if num_blocks - att_layer_num <= 0:
- self.decoders2 = None
- else:
- self.decoders2 = repeat(
- num_blocks - att_layer_num,
- lambda lnum: DecoderLayerSANM(
- attention_dim,
- MultiHeadedAttentionSANMDecoder(
- attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=0
- ),
- None,
- PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
- dropout_rate,
- normalize_before,
- concat_after,
- ),
- )
- self.decoders3 = repeat(
- 1,
- lambda lnum: DecoderLayerSANM(
- attention_dim,
- None,
- None,
- PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
- dropout_rate,
- normalize_before,
- concat_after,
- ),
- )
- self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
- self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
- self.chunk_multiply_factor = chunk_multiply_factor
- def forward(
- self,
- hs_pad: torch.Tensor,
- hlens: torch.Tensor,
- ys_in_pad: torch.Tensor,
- ys_in_lens: torch.Tensor,
- chunk_mask: torch.Tensor = None,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Forward decoder.
- Args:
- hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
- hlens: (batch)
- ys_in_pad:
- acoustic embeddings, float32 (batch, maxlen_out, attention_dim);
- in Paraformer this is the predictor (CIF) output and is fed to the
- decoder blocks directly, bypassing the token embedding layer
- ys_in_lens: (batch)
- Returns:
- (tuple): tuple containing:
- x: decoded token score before softmax (batch, maxlen_out, token)
- if use_output_layer is True,
- olens: (batch, )
- """
- tgt = ys_in_pad
- tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
-
- memory = hs_pad
- memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
- if chunk_mask is not None:
- memory_mask = memory_mask * chunk_mask
- if tgt_mask.size(1) != memory_mask.size(1):
- memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)
- x = tgt
- x, tgt_mask, memory, memory_mask, _ = self.decoders(
- x, tgt_mask, memory, memory_mask
- )
- if self.decoders2 is not None:
- x, tgt_mask, memory, memory_mask, _ = self.decoders2(
- x, tgt_mask, memory, memory_mask
- )
- x, tgt_mask, memory, memory_mask, _ = self.decoders3(
- x, tgt_mask, memory, memory_mask
- )
- if self.normalize_before:
- x = self.after_norm(x)
- if self.output_layer is not None:
- x = self.output_layer(x)
- olens = tgt_mask.sum(1)
- return x, olens
- def score(self, ys, state, x):
- """Score."""
- ys_mask = myutils.sequence_mask(torch.tensor([len(ys)], dtype=torch.int32), device=x.device)[:, :, None]
- logp, state = self.forward_one_step(
- ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
- )
- return logp.squeeze(0), state
- def forward_chunk(
- self,
- memory: torch.Tensor,
- tgt: torch.Tensor,
- cache: dict = None,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Forward decoder.
- Args:
- hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
- hlens: (batch)
- ys_in_pad:
- input token ids, int64 (batch, maxlen_out)
- if input_layer == "embed"
- input tensor (batch, maxlen_out, #mels) in the other cases
- ys_in_lens: (batch)
- Returns:
- (tuple): tuple containing:
- x: decoded token score before softmax (batch, maxlen_out, token)
- if use_output_layer is True,
- olens: (batch, )
- """
- x = tgt
- if cache["decode_fsmn"] is None:
- cache_layer_num = len(self.decoders)
- if self.decoders2 is not None:
- cache_layer_num += len(self.decoders2)
- fsmn_cache = [None] * cache_layer_num
- else:
- fsmn_cache = cache["decode_fsmn"]
- if cache["opt"] is None:
- cache_layer_num = len(self.decoders)
- opt_cache = [None] * cache_layer_num
- else:
- opt_cache = cache["opt"]
- for i in range(self.att_layer_num):
- decoder = self.decoders[i]
- x, memory, fsmn_cache[i], opt_cache[i] = decoder.forward_chunk(
- x, memory, fsmn_cache=fsmn_cache[i], opt_cache=opt_cache[i],
- chunk_size=cache["chunk_size"], look_back=cache["decoder_chunk_look_back"]
- )
- if self.num_blocks - self.att_layer_num >= 1:  # match the condition under which self.decoders2 was built
- for i in range(self.num_blocks - self.att_layer_num):
- j = i + self.att_layer_num
- decoder = self.decoders2[i]
- x, memory, fsmn_cache[j], _ = decoder.forward_chunk(
- x, memory, fsmn_cache=fsmn_cache[j]
- )
- for decoder in self.decoders3:
- x, memory, _, _ = decoder.forward_chunk(
- x, memory
- )
- if self.normalize_before:
- x = self.after_norm(x)
- if self.output_layer is not None:
- x = self.output_layer(x)
- cache["decode_fsmn"] = fsmn_cache
- if cache["decoder_chunk_look_back"] > 0 or cache["decoder_chunk_look_back"] == -1:
- cache["opt"] = opt_cache
- return x
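- # Editor's note on the streaming cache consumed above (keys taken from this
- # method; the concrete values are assumptions): forward_chunk expects something like
- #   cache = {"decode_fsmn": None,           # per-layer FSMN states, filled in here
- #            "opt": None,                   # per-layer cross-attention caches
- #            "chunk_size": chunk_size,      # forwarded to the cross-attention
- #            "decoder_chunk_look_back": 1}  # previous chunks visible to cross-attention
- # and it updates "decode_fsmn" (and "opt" when look-back is enabled) in place for
- # the next chunk.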
- def forward_one_step(
- self,
- tgt: torch.Tensor,
- tgt_mask: torch.Tensor,
- memory: torch.Tensor,
- cache: List[torch.Tensor] = None,
- ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
- """Forward one step.
- Args:
- tgt: input token ids, int64 (batch, maxlen_out)
- tgt_mask: input token mask, (batch, maxlen_out)
- dtype=torch.uint8 before PyTorch 1.2,
- dtype=torch.bool in PyTorch 1.2 and later
- memory: encoded memory, float32 (batch, maxlen_in, feat)
- cache: cached output list of (batch, max_time_out-1, size)
- Returns:
- y, cache: NN output value and cache per `self.decoders`.
- `y.shape` is (batch, maxlen_out, token)
- """
- x = self.embed(tgt)
- if cache is None:
- cache_layer_num = len(self.decoders)
- if self.decoders2 is not None:
- cache_layer_num += len(self.decoders2)
- cache = [None] * cache_layer_num
- new_cache = []
- # for c, decoder in zip(cache, self.decoders):
- for i in range(self.att_layer_num):
- decoder = self.decoders[i]
- c = cache[i]
- x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
- x, tgt_mask, memory, None, cache=c
- )
- new_cache.append(c_ret)
- if self.num_blocks - self.att_layer_num >= 1:  # match the condition under which self.decoders2 was built
- for i in range(self.num_blocks - self.att_layer_num):
- j = i + self.att_layer_num
- decoder = self.decoders2[i]
- c = cache[j]
- x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
- x, tgt_mask, memory, None, cache=c
- )
- new_cache.append(c_ret)
- for decoder in self.decoders3:
- x, tgt_mask, memory, memory_mask, _ = decoder.forward_one_step(
- x, tgt_mask, memory, None, cache=None
- )
- if self.normalize_before:
- y = self.after_norm(x[:, -1])
- else:
- y = x[:, -1]
- if self.output_layer is not None:
- y = torch.log_softmax(self.output_layer(y), dim=-1)
- return y, new_cache
- def gen_tf2torch_map_dict(self):
-
- tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
- tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
- map_dict_local = {
-
- ## decoder
- # ffn
- "{}.decoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (1024,256),(1,256,1024)
- "{}.decoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.decoders.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.decoders.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.decoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (256,1024),(1,1024,256)
-
- # fsmn
- "{}.decoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/gamma".format(
- tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/beta".format(
- tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/depth_conv_w".format(
- tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 2, 0),
- }, # (256,1,31),(1,31,256,1)
- # src att
- "{}.decoders.layeridx.norm3.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders.layeridx.norm3.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders.layeridx.src_attn.linear_q.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (256,256),(1,256,256)
- "{}.decoders.layeridx.src_attn.linear_q.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders.layeridx.src_attn.linear_k_v.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (1024,256),(1,256,1024)
- "{}.decoders.layeridx.src_attn.linear_k_v.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.decoders.layeridx.src_attn.linear_out.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (256,256),(1,256,256)
- "{}.decoders.layeridx.src_attn.linear_out.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- # dnn
- "{}.decoders3.layeridx.norm1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders3.layeridx.norm1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders3.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_dnn_layer_layeridx/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (1024,256),(1,256,1024)
- "{}.decoders3.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_dnn_layer_layeridx/conv1d/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.decoders3.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.decoders3.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.decoders3.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_dnn_layer_layeridx/conv1d_1/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (256,1024),(1,1024,256)
-
- # embed_concat_ffn
- "{}.embed_concat_ffn.layeridx.norm1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/cif_concat/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.embed_concat_ffn.layeridx.norm1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/cif_concat/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.embed_concat_ffn.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/cif_concat/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (1024,256),(1,256,1024)
- "{}.embed_concat_ffn.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/cif_concat/conv1d/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.embed_concat_ffn.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
- {"name": "{}/cif_concat/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.embed_concat_ffn.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
- {"name": "{}/cif_concat/LayerNorm_1/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.embed_concat_ffn.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
- {"name": "{}/cif_concat/conv1d_1/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (256,1024),(1,1024,256)
-
- # out norm
- "{}.after_norm.weight".format(tensor_name_prefix_torch):
- {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.after_norm.bias".format(tensor_name_prefix_torch):
- {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
-
- # in embed
- "{}.embed.0.weight".format(tensor_name_prefix_torch):
- {"name": "{}/w_embs".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (4235,256),(4235,256)
-
- # out layer
- "{}.output_layer.weight".format(tensor_name_prefix_torch):
- {"name": ["{}/dense/kernel".format(tensor_name_prefix_tf), "{}/w_embs".format(tensor_name_prefix_tf)],
- "squeeze": [None, None],
- "transpose": [(1, 0), None],
- }, # (4235,256),(256,4235)
- "{}.output_layer.bias".format(tensor_name_prefix_torch):
- {"name": ["{}/dense/bias".format(tensor_name_prefix_tf),
- "seq2seq/2bias" if tensor_name_prefix_tf == "seq2seq/decoder/inputter_1" else "seq2seq/bias"],
- "squeeze": [None, None],
- "transpose": [None, None],
- }, # (4235,),(4235,)
-
- }
- return map_dict_local
- def convert_tf2torch(self,
- var_dict_tf,
- var_dict_torch,
- ):
- map_dict = self.gen_tf2torch_map_dict()
- var_dict_torch_update = dict()
- decoder_layeridx_sets = set()
- for name in sorted(var_dict_torch.keys(), reverse=False):
- names = name.split('.')
- if names[0] == self.tf2torch_tensor_name_prefix_torch:
- if names[1] == "decoders":
- layeridx = int(names[2])
- name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
- layeridx_bias = 0
- layeridx += layeridx_bias
- decoder_layeridx_sets.add(layeridx)
- if name_q in map_dict.keys():
- name_v = map_dict[name_q]["name"]
- name_tf = name_v.replace("layeridx", "{}".format(layeridx))
- data_tf = var_dict_tf[name_tf]
- if map_dict[name_q]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
- if map_dict[name_q]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
- var_dict_tf[name_tf].shape))
-
- elif names[1] == "decoders2":
- layeridx = int(names[2])
- name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
- name_q = name_q.replace("decoders2", "decoders")
- layeridx_bias = len(decoder_layeridx_sets)
-
- layeridx += layeridx_bias
- if "decoders." in name:
- decoder_layeridx_sets.add(layeridx)
- if name_q in map_dict.keys():
- name_v = map_dict[name_q]["name"]
- name_tf = name_v.replace("layeridx", "{}".format(layeridx))
- data_tf = var_dict_tf[name_tf]
- if map_dict[name_q]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
- if map_dict[name_q]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
- var_dict_tf[name_tf].shape))
-
- elif names[1] == "decoders3":
- layeridx = int(names[2])
- name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
-
- layeridx_bias = 0
- layeridx += layeridx_bias
- if "decoders." in name:
- decoder_layeridx_sets.add(layeridx)
- if name_q in map_dict.keys():
- name_v = map_dict[name_q]["name"]
- name_tf = name_v.replace("layeridx", "{}".format(layeridx))
- data_tf = var_dict_tf[name_tf]
- if map_dict[name_q]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
- if map_dict[name_q]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
- var_dict_tf[name_tf].shape))
-
- elif names[1] == "embed" or names[1] == "output_layer":
- name_tf = map_dict[name]["name"]
- if isinstance(name_tf, list):
- idx_list = 0
- if name_tf[idx_list] in var_dict_tf.keys():
- pass
- else:
- idx_list = 1
- data_tf = var_dict_tf[name_tf[idx_list]]
- if map_dict[name]["squeeze"][idx_list] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"][idx_list])
- if map_dict[name]["transpose"][idx_list] is not None:
- data_tf = np.transpose(data_tf, map_dict[name]["transpose"][idx_list])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(),
- name_tf[idx_list],
- var_dict_tf[name_tf[
- idx_list]].shape))
-
- else:
- data_tf = var_dict_tf[name_tf]
- if map_dict[name]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
- if map_dict[name]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
- var_dict_tf[name_tf].shape))
-
- elif names[1] == "after_norm":
- name_tf = map_dict[name]["name"]
- data_tf = var_dict_tf[name_tf]
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
- var_dict_tf[name_tf].shape))
-
- elif names[1] == "embed_concat_ffn":
- layeridx = int(names[2])
- name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
-
- layeridx_bias = 0
- layeridx += layeridx_bias
- if "decoders." in name:
- decoder_layeridx_sets.add(layeridx)
- if name_q in map_dict.keys():
- name_v = map_dict[name_q]["name"]
- name_tf = name_v.replace("layeridx", "{}".format(layeridx))
- data_tf = var_dict_tf[name_tf]
- if map_dict[name_q]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
- if map_dict[name_q]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
- var_dict_tf[name_tf].shape))
-
- return var_dict_torch_update
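- # Editor's sketch of a non-autoregressive (Paraformer-style) decoding pass; the
- # dimensions and the helper name are illustrative assumptions:
- def _demo_paraformer_sanm_decoder():
-     decoder = ParaformerSANMDecoder(vocab_size=4235, encoder_output_size=256)
-     hs_pad = torch.randn(2, 50, 256)               # encoder output
-     hlens = torch.tensor([50, 42])
-     pre_acoustic_embeds = torch.randn(2, 12, 256)  # predictor (CIF) output, one vector per token
-     token_lens = torch.tensor([12, 9])
-     scores, olens = decoder(hs_pad, hlens, pre_acoustic_embeds, token_lens)
-     return scores.shape                            # expected: (2, 12, 4235) scores before softmax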
|