# adaptor.py
import torch
import torch.nn as nn

from funasr.register import tables


@tables.register("adaptor_classes", "Linear")
  5. class Linear(nn.Module):
  6. def __init__(self, downsample_rate, encoder_dim, llm_dim, ffn_dim: int = 2048, **kwargs):
  7. super().__init__()
  8. self.k = downsample_rate
  9. self.encoder_dim = encoder_dim
  10. self.llm_dim = llm_dim
  11. self.linear1 = nn.Linear(self.encoder_dim * self.k, ffn_dim)
  12. self.relu = nn.ReLU()
  13. self.linear2 = nn.Linear(ffn_dim, self.llm_dim)
  14. def forward(self, x):
  15. batch_size, seq_len, dim = x.size()
  16. num_frames_to_discard = seq_len % self.k
  17. if num_frames_to_discard > 0:
  18. x = x[:, :-num_frames_to_discard, :]
  19. seq_len = x.size(1)
  20. x = x.contiguous()
  21. x = x.view(batch_size, seq_len // self.k, dim * self.k)
  22. x = self.linear1(x)
  23. x = self.relu(x)
  24. x = self.linear2(x)
  25. return x