| 1234567891011121314151617181920212223242526272829303132333435363738 |
- #!/usr/bin/env python3
- # 2021, Carnegie Mellon University; Xuankai Chang
- # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
- """Linear Projection."""
- from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
- from typeguard import check_argument_types
- from typing import Tuple
- import torch
- class LinearProjection(AbsPreEncoder):
- """Linear Projection Preencoder."""
- def __init__(
- self,
- input_size: int,
- output_size: int,
- ):
- """Initialize the module."""
- assert check_argument_types()
- super().__init__()
- self.output_dim = output_size
- self.linear_out = torch.nn.Linear(input_size, output_size)
- def forward(
- self, input: torch.Tensor, input_lengths: torch.Tensor
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Forward."""
- output = self.linear_out(input)
- return output, input_lengths # no state in this layer
- def output_size(self) -> int:
- """Get the output size."""
- return self.output_dim
|