| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641 |
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- # Copyright 2019 Shigeki Karita
- # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
- """Multi-Head Attention layer definition."""
- import math
- import numpy
- import torch
- from torch import nn
- from typing import Optional, Tuple
- import torch.nn.functional as F
- from funasr.models.transformer.utils.nets_utils import make_pad_mask
- import funasr.models.lora.layers as lora
class MultiHeadedAttention(nn.Module):
    """Multi-Head Attention layer.

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features (must be divisible by n_head).
        dropout_rate (float): Dropout rate applied to the attention weights.
    """

    def __init__(self, n_head, n_feat, dropout_rate):
        """Construct a MultiHeadedAttention object."""
        super(MultiHeadedAttention, self).__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)
        # Attention weights of the most recent forward pass (batch, head, time1, time2).
        self.attn = None
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward_qkv(self, query, key, value):
        """Project query, key and value and split them into heads.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).

        Returns:
            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
        """
        n_batch = query.size(0)
        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
        v = v.transpose(1, 2)  # (batch, head, time2, d_k)
        return q, k, v

    def forward_attention(self, value, scores, mask):
        """Compute the attention context vector.

        Args:
            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2);
                zero entries are masked out. May be None.

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).
        """
        n_batch = value.size(0)
        if mask is not None:
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            # torch.finfo handles every floating dtype, including bfloat16,
            # which the former numpy round-trip could not represent.
            min_value = torch.finfo(scores.dtype).min
            scores = scores.masked_fill(mask, min_value)
            # Re-zero masked weights so fully-masked rows do not attend uniformly.
            self.attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        )  # (batch, time1, d_model)
        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(self, query, key, value, mask):
        """Compute scaled dot product attention.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2), or None.

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
        """
        q, k, v = self.forward_qkv(query, key, value)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        return self.forward_attention(v, scores, mask)
class MultiHeadedAttentionSANM(nn.Module):
    """Multi-head self-attention with an FSMN memory branch (SANM).

    The output is the sum of scaled dot-product self-attention and a depthwise
    1-D convolution ("FSMN memory") applied to the un-split value projection.

    Args:
        n_head (int): The number of heads.
        in_feat (int): The number of input features.
        n_feat (int): The number of attention features (must be divisible by n_head).
        dropout_rate (float): Dropout rate.
        kernel_size (int): Kernel size of the FSMN depthwise convolution.
        sanm_shfit (int): Extra left padding for the FSMN convolution, moving its
            receptive field towards earlier frames.
        lora_list (list of str, optional): Projections that get LoRA adapters;
            "q", "k", "v" select parts of the fused qkv projection and "o" the
            output projection.
        lora_rank (int): LoRA rank, forwarded to the LoRA layers.
        lora_alpha (int): LoRA alpha, forwarded to the LoRA layers.
        lora_dropout (float): LoRA dropout, forwarded to the LoRA layers.
    """

    def __init__(self, n_head, in_feat, n_feat, dropout_rate, kernel_size, sanm_shfit=0,
                 lora_list=None, lora_rank=8, lora_alpha=16, lora_dropout=0.1):
        """Construct a MultiHeadedAttentionSANM object."""
        super().__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        if lora_list is not None:
            if "o" in lora_list:
                self.linear_out = lora.Linear(
                    n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout
                )
            else:
                self.linear_out = nn.Linear(n_feat, n_feat)
            lora_qkv_list = ["q" in lora_list, "k" in lora_list, "v" in lora_list]
            if lora_qkv_list == [False, False, False]:
                self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
            else:
                self.linear_q_k_v = lora.MergedLinear(
                    in_feat, n_feat * 3, r=lora_rank, lora_alpha=lora_alpha,
                    lora_dropout=lora_dropout, enable_lora=lora_qkv_list,
                )
        else:
            self.linear_out = nn.Linear(n_feat, n_feat)
            self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
        # Attention weights of the most recent forward pass (batch, head, time1, time2).
        self.attn = None
        self.dropout = nn.Dropout(p=dropout_rate)
        self.fsmn_block = nn.Conv1d(
            n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False
        )
        # kernel_size - 1 frames of padding in total; sanm_shfit of them are
        # moved from the right side to the left side.
        left_padding = (kernel_size - 1) // 2
        if sanm_shfit > 0:
            left_padding = left_padding + sanm_shfit
        right_padding = kernel_size - 1 - left_padding
        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)

    @staticmethod
    def _drop_last_frames(t, n):
        """Return ``t`` without its last ``n`` frames along dim 2 (``n == 0`` keeps all)."""
        # Explicit end index: ``t[:, :, :-0]`` would wrongly return an empty tensor.
        return t[:, :, :max(t.size(2) - n, 0), :]

    @staticmethod
    def _keep_last_frames(t, n):
        """Return only the last ``n`` frames of ``t`` along dim 2 (``n == 0`` keeps none)."""
        # Explicit start index: ``t[:, :, -0:]`` would wrongly return everything.
        return t[:, :, max(t.size(2) - n, 0):, :]

    def forward_fsmn(self, inputs, mask, mask_shfit_chunk=None):
        """Apply the FSMN memory block (masked depthwise conv + residual).

        Args:
            inputs (torch.Tensor): Input tensor (#batch, time, n_feat).
            mask (torch.Tensor): Mask reshaped to (#batch, time, 1), or None.
            mask_shfit_chunk (torch.Tensor, optional): Extra multiplicative mask
                combined with ``mask``.

        Returns:
            torch.Tensor: Output tensor (#batch, time, n_feat).
        """
        b = inputs.size(0)
        if mask is not None:
            mask = torch.reshape(mask, (b, -1, 1))
            if mask_shfit_chunk is not None:
                mask = mask * mask_shfit_chunk
            inputs = inputs * mask
        x = inputs.transpose(1, 2)  # (batch, n_feat, time)
        x = self.pad_fn(x)
        x = self.fsmn_block(x)
        x = x.transpose(1, 2)  # (batch, time, n_feat)
        x = x + inputs
        x = self.dropout(x)
        if mask is not None:
            x = x * mask
        return x

    def forward_qkv(self, x):
        """Project ``x`` with the fused qkv layer and split into heads.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, in_feat).

        Returns:
            torch.Tensor: Query (#batch, n_head, time, d_k).
            torch.Tensor: Key (#batch, n_head, time, d_k).
            torch.Tensor: Value (#batch, n_head, time, d_k).
            torch.Tensor: Un-split value (#batch, time, n_feat), used by the FSMN branch.
        """
        b, t, _ = x.size()
        q_k_v = self.linear_q_k_v(x)
        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
        q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time1, d_k)
        k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
        v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
        return q_h, k_h, v_h, v

    def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
        """Compute the attention context vector.

        Args:
            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2);
                zero entries are masked out. May be None.
            mask_att_chunk_encoder (torch.Tensor, optional): Extra multiplicative
                mask combined with ``mask`` (ignored when ``mask`` is None).

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).
        """
        n_batch = value.size(0)
        if mask is not None:
            if mask_att_chunk_encoder is not None:
                mask = mask * mask_att_chunk_encoder
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            # torch.finfo handles every floating dtype, including bfloat16.
            min_value = torch.finfo(scores.dtype).min
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        )  # (batch, time1, d_model)
        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
        """Compute SANM self-attention (attention output + FSMN memory).

        Args:
            x (torch.Tensor): Input tensor (#batch, time1, in_feat).
            mask (torch.Tensor): Mask tensor (#batch, 1, time1) or
                (#batch, time1, time1), or None.
            mask_shfit_chunk (torch.Tensor, optional): Extra mask for the FSMN branch.
            mask_att_chunk_encoder (torch.Tensor, optional): Extra mask for attention.

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
        """
        q_h, k_h, v_h, v = self.forward_qkv(x)
        fsmn_memory = self.forward_fsmn(v, mask, mask_shfit_chunk)
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
        return att_outs + fsmn_memory

    def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
        """Compute SANM self-attention for one streaming chunk.

        Args:
            x (torch.Tensor): Input chunk (#batch, time1, in_feat).
            cache (dict, optional): {"k", "v"} tensors (#batch, n_head, frames, d_k)
                from previous chunks. Updated in place when given.
            chunk_size (sequence of int, optional): Chunk geometry. The last
                ``chunk_size[2]`` frames of the chunk are not cached, and the
                history is trimmed to ``look_back * chunk_size[1]`` frames.
            look_back (int): Number of past chunks to keep; -1 keeps all;
                0 disables caching.

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
            dict or None: The updated cache.
        """
        q_h, k_h, v_h, v = self.forward_qkv(x)
        # Cache maintenance needs the chunk geometry whatever look_back is.
        if chunk_size is not None and (look_back > 0 or look_back == -1):
            k_h_stride = self._drop_last_frames(k_h, chunk_size[2])
            v_h_stride = self._drop_last_frames(v_h, chunk_size[2])
            if cache is not None:
                k_h = torch.cat((cache["k"], k_h), dim=2)
                v_h = torch.cat((cache["v"], v_h), dim=2)
                cache["k"] = torch.cat((cache["k"], k_h_stride), dim=2)
                cache["v"] = torch.cat((cache["v"], v_h_stride), dim=2)
                if look_back != -1:
                    keep = look_back * chunk_size[1]
                    cache["k"] = self._keep_last_frames(cache["k"], keep)
                    cache["v"] = self._keep_last_frames(cache["v"], keep)
            else:
                cache = {"k": k_h_stride, "v": v_h_stride}
        fsmn_memory = self.forward_fsmn(v, None)
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        att_outs = self.forward_attention(v_h, scores, None)
        return att_outs + fsmn_memory, cache
class MultiHeadedAttentionSANMDecoder(nn.Module):
    """FSMN memory block: masked depthwise 1-D convolution over time plus a residual.

    Despite the class name, this module has no attention heads or projections.
    It supports step-wise processing through an optional convolution-input cache.

    Args:
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        kernel_size (int): Kernel size of the depthwise convolution.
        sanm_shfit (int): Extra left padding for the convolution.
    """

    def __init__(self, n_feat, dropout_rate, kernel_size, sanm_shfit=0):
        """Construct a MultiHeadedAttentionSANMDecoder object."""
        super(MultiHeadedAttentionSANMDecoder, self).__init__()
        self.dropout = nn.Dropout(p=dropout_rate)
        self.fsmn_block = nn.Conv1d(
            n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False
        )
        # kernel_size - 1 frames of padding in total; sanm_shfit of them are
        # moved from the right side to the left side.
        left_padding = (kernel_size - 1) // 2
        if sanm_shfit > 0:
            left_padding = left_padding + sanm_shfit
        right_padding = kernel_size - 1 - left_padding
        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
        self.kernel_size = kernel_size

    def forward(self, inputs, mask, cache=None, mask_shfit_chunk=None):
        """Apply the FSMN memory block.

        Args:
            inputs (torch.Tensor): Input tensor (#batch, time1, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time1), or None.
            cache (torch.Tensor, optional): Convolution input (#batch, size, frames)
                returned by a previous call; when given, it supplies the left
                context instead of zero padding.
            mask_shfit_chunk (torch.Tensor, optional): Extra multiplicative mask
                combined with ``mask``.

        Returns:
            torch.Tensor: Output tensor (#batch, time1, size).
            torch.Tensor or None: Updated cache. When ``cache`` is None, a new
                cache is only produced in eval mode.
        """
        b, t, _ = inputs.size()
        if mask is not None:
            mask = torch.reshape(mask, (b, -1, 1))
            if mask_shfit_chunk is not None:
                mask = mask * mask_shfit_chunk
            inputs = inputs * mask

        x = inputs.transpose(1, 2)  # (batch, size, time1)
        if cache is None:
            x = self.pad_fn(x)
            if not self.training:
                cache = x
        else:
            # Drop the oldest cached frame, append the new ones and keep just
            # enough context (kernel_size + t - 1) for the conv to emit t frames.
            x = torch.cat((cache[:, :, 1:], x), dim=2)
            x = x[:, :, -(self.kernel_size + t - 1):]
            cache = x
        x = self.fsmn_block(x)
        x = x.transpose(1, 2)  # (batch, frames_out, size)
        if x.size(1) != inputs.size(1):
            # The conv produced fewer frames than the input (possible when the
            # cache held too little context). Keep the time axis ("-1:") so the
            # residual broadcasts per sample instead of across the batch.
            inputs = inputs[:, -1:, :]

        x = x + inputs
        x = self.dropout(x)
        if mask is not None:
            x = x * mask
        return x, cache
class MultiHeadedAttentionCrossAtt(nn.Module):
    """Multi-head cross-attention from decoder states to an encoder memory.

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features (must be divisible by n_head).
        dropout_rate (float): Dropout rate.
        lora_list (list of str, optional): Projections that get LoRA adapters:
            "q", "k", "v" (fused key/value projection) and "o".
        lora_rank (int): LoRA rank, forwarded to the LoRA layers.
        lora_alpha (int): LoRA alpha, forwarded to the LoRA layers.
        lora_dropout (float): LoRA dropout, forwarded to the LoRA layers.
        encoder_output_size (int, optional): Feature size of ``memory``;
            defaults to ``n_feat``.
    """

    def __init__(self, n_head, n_feat, dropout_rate, lora_list=None, lora_rank=8, lora_alpha=16,
                 lora_dropout=0.1, encoder_output_size=None):
        """Construct a MultiHeadedAttentionCrossAtt object."""
        super(MultiHeadedAttentionCrossAtt, self).__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        kv_in_feat = n_feat if encoder_output_size is None else encoder_output_size
        if lora_list is not None:
            if "q" in lora_list:
                self.linear_q = lora.Linear(
                    n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout
                )
            else:
                self.linear_q = nn.Linear(n_feat, n_feat)
            lora_kv_list = ["k" in lora_list, "v" in lora_list]
            if lora_kv_list == [False, False]:
                self.linear_k_v = nn.Linear(kv_in_feat, n_feat * 2)
            else:
                self.linear_k_v = lora.MergedLinear(
                    kv_in_feat, n_feat * 2, r=lora_rank, lora_alpha=lora_alpha,
                    lora_dropout=lora_dropout, enable_lora=lora_kv_list,
                )
            if "o" in lora_list:
                self.linear_out = lora.Linear(
                    n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout
                )
            else:
                self.linear_out = nn.Linear(n_feat, n_feat)
        else:
            self.linear_q = nn.Linear(n_feat, n_feat)
            self.linear_k_v = nn.Linear(kv_in_feat, n_feat * 2)
            self.linear_out = nn.Linear(n_feat, n_feat)
        # Attention weights of the most recent forward pass (batch, head, time1, time2).
        self.attn = None
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward_qkv(self, x, memory):
        """Project queries from ``x`` and keys/values from ``memory``, split into heads.

        Args:
            x (torch.Tensor): Query input (#batch, time1, n_feat).
            memory (torch.Tensor): Key/value input (#batch, time2, encoder_output_size).

        Returns:
            torch.Tensor: Query (#batch, n_head, time1, d_k).
            torch.Tensor: Key (#batch, n_head, time2, d_k).
            torch.Tensor: Value (#batch, n_head, time2, d_k).
        """
        b = x.size(0)
        q = self.linear_q(x)
        q_h = torch.reshape(q, (b, -1, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time1, d_k)
        k_v = self.linear_k_v(memory)
        k, v = torch.split(k_v, int(self.h * self.d_k), dim=-1)
        k_h = torch.reshape(k, (b, -1, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
        v_h = torch.reshape(v, (b, -1, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
        return q_h, k_h, v_h

    def forward_attention(self, value, scores, mask, ret_attn=False):
        """Compute the attention context vector.

        Args:
            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2);
                zero entries are masked out. May be None.
            ret_attn (bool): Also return the attention weights.

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).
            torch.Tensor: Attention weights (#batch, n_head, time1, time2),
                only when ``ret_attn`` is True.
        """
        n_batch = value.size(0)
        if mask is not None:
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            # torch.finfo handles every floating dtype, including bfloat16.
            min_value = torch.finfo(scores.dtype).min
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        )  # (batch, time1, d_model)
        if ret_attn:
            return self.linear_out(x), self.attn  # (batch, time1, d_model)
        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(self, x, memory, memory_mask, ret_attn=False):
        """Compute scaled dot product cross-attention.

        Args:
            x (torch.Tensor): Query input (#batch, time1, n_feat).
            memory (torch.Tensor): Key/value input (#batch, time2, encoder_output_size).
            memory_mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2), or None.
            ret_attn (bool): Also return the attention weights.

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model), plus the
                attention weights when ``ret_attn`` is True.
        """
        q_h, k_h, v_h = self.forward_qkv(x, memory)
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        return self.forward_attention(v_h, scores, memory_mask, ret_attn=ret_attn)

    def forward_chunk(self, x, memory, cache=None, chunk_size=None, look_back=0):
        """Compute cross-attention for one streaming chunk.

        Args:
            x (torch.Tensor): Query input (#batch, time1, n_feat).
            memory (torch.Tensor): Key/value input for this chunk.
            cache (dict, optional): {"k", "v"} tensors from previous chunks;
                updated in place when given.
            chunk_size (sequence of int, optional): Chunk geometry; the cache keeps
                the last ``look_back * chunk_size[1]`` frames.
            look_back (int): Number of past chunks to keep; <= 0 disables caching.

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
            dict or None: The updated cache.
        """
        q_h, k_h, v_h = self.forward_qkv(x, memory)
        if chunk_size is not None and look_back > 0:
            keep = look_back * chunk_size[1]
            if cache is not None:
                k_h = torch.cat((cache["k"], k_h), dim=2)
                v_h = torch.cat((cache["v"], v_h), dim=2)
                cache["k"] = k_h[:, :, -keep:, :]
                cache["v"] = v_h[:, :, -keep:, :]
            else:
                cache = {"k": k_h[:, :, -keep:, :], "v": v_h[:, :, -keep:, :]}
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        return self.forward_attention(v_h, scores, None), cache
class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention with a fused qkv projection.

    Args:
        n_head (int): The number of heads.
        in_feat (int): The number of input features.
        n_feat (int): The number of attention features (must be divisible by n_head).
        dropout_rate (float): Dropout rate.
    """

    def __init__(self, n_head, in_feat, n_feat, dropout_rate):
        """Construct a MultiHeadSelfAttention object."""
        super(MultiHeadSelfAttention, self).__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
        # Attention weights of the most recent forward pass (batch, head, time1, time2).
        self.attn = None
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward_qkv(self, x):
        """Project ``x`` with the fused qkv layer and split into heads.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, in_feat).

        Returns:
            torch.Tensor: Query (#batch, n_head, time, d_k).
            torch.Tensor: Key (#batch, n_head, time, d_k).
            torch.Tensor: Value (#batch, n_head, time, d_k).
            torch.Tensor: Un-split value (#batch, time, n_feat).
        """
        b, t, _ = x.size()
        q_k_v = self.linear_q_k_v(x)
        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
        q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time1, d_k)
        k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
        v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
        return q_h, k_h, v_h, v

    def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
        """Compute the attention context vector.

        Args:
            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2);
                zero entries are masked out. May be None.
            mask_att_chunk_encoder (torch.Tensor, optional): Extra multiplicative
                mask combined with ``mask`` (ignored when ``mask`` is None).

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).
        """
        n_batch = value.size(0)
        if mask is not None:
            if mask_att_chunk_encoder is not None:
                mask = mask * mask_att_chunk_encoder
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            # torch.finfo handles every floating dtype, including bfloat16.
            min_value = torch.finfo(scores.dtype).min
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        )  # (batch, time1, d_model)
        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(self, x, mask, mask_att_chunk_encoder=None):
        """Compute scaled dot product self-attention.

        Args:
            x (torch.Tensor): Input tensor (#batch, time1, in_feat).
            mask (torch.Tensor): Mask tensor (#batch, 1, time1) or
                (#batch, time1, time1), or None.
            mask_att_chunk_encoder (torch.Tensor, optional): Extra attention mask.

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
        """
        q_h, k_h, v_h, _ = self.forward_qkv(x)
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        return self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
|