| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455 |
- # Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
- # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
- from abc import ABC
- from abc import abstractmethod
- from typing import Dict
- from typing import Tuple
- import torch
class AbsESPnetModel(torch.nn.Module, ABC):
    """Abstract base class shared by every task model.

    An "ESPnetModel" is a ``torch.nn.Module`` subclass that keeps the
    underlying DNN modules as member fields and delegates its ``forward``
    computation to them (the delegate pattern).  For its task it defines
    three values: ``loss``, ``stats`` and ``weight``.

    Any new task implemented in ESPnet must provide a model that inherits
    from this class.  Those three values are the only "mediator" objects
    exchanged between the training system and the task-specific model.

    Besides the abstract interface, the class holds an update counter,
    ``num_updates``, which starts at 0 and is read and written through
    :meth:`get_num_updates` / :meth:`set_num_updates`.
    NOTE(review): presumably the training loop sets this to the number of
    optimizer updates performed so far -- confirm against the caller.

    Example:
        >>> from funasr.tasks.abs_task import AbsTask
        >>> class YourESPnetModel(AbsESPnetModel):
        ...     def forward(self, input, input_lengths):
        ...         ...
        ...         return loss, stats, weight
        >>> class YourTask(AbsTask):
        ...     @classmethod
        ...     def build_model(cls, args: argparse.Namespace) -> YourESPnetModel:
        ...         ...
    """

    def __init__(self) -> None:
        super(AbsESPnetModel, self).__init__()
        # Update counter: starts at zero and changes only when
        # set_num_updates() (or a direct attribute write) stores a new value.
        self.num_updates = 0

    def get_num_updates(self):
        """Return the counter value last stored (0 until first set)."""
        return self.num_updates

    def set_num_updates(self, num_updates) -> None:
        """Store ``num_updates`` as the model's current update counter."""
        self.num_updates = num_updates

    @abstractmethod
    def forward(
        self,
        **batch: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Run the model on one batch of named tensors.

        Args:
            **batch: Input tensors, passed by keyword.

        Returns:
            A ``(loss, stats, weight)`` triple: the loss tensor, a dict
            mapping statistic names to tensors, and a weight tensor.

        Raises:
            NotImplementedError: Always; concrete subclasses must override.
        """
        raise NotImplementedError

    @abstractmethod
    def collect_feats(self, **batch: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Return a dict of named feature tensors computed from ``batch``.

        Raises:
            NotImplementedError: Always; concrete subclasses must override.
        """
        raise NotImplementedError
|