abs_espnet_model.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. # Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
  2. # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
  3. from abc import ABC
  4. from abc import abstractmethod
  5. from typing import Dict
  6. from typing import Tuple
  7. import torch
  8. class AbsESPnetModel(torch.nn.Module, ABC):
  9. """The common abstract class among each tasks
  10. "ESPnetModel" is referred to a class which inherits torch.nn.Module,
  11. and makes the dnn-models forward as its member field,
  12. a.k.a delegate pattern,
  13. and defines "loss", "stats", and "weight" for the task.
  14. If you intend to implement new task in ESPNet,
  15. the model must inherit this class.
  16. In other words, the "mediator" objects between
  17. our training system and the your task class are
  18. just only these three values, loss, stats, and weight.
  19. Example:
  20. >>> from funasr.tasks.abs_task import AbsTask
  21. >>> class YourESPnetModel(AbsESPnetModel):
  22. ... def forward(self, input, input_lengths):
  23. ... ...
  24. ... return loss, stats, weight
  25. >>> class YourTask(AbsTask):
  26. ... @classmethod
  27. ... def build_model(cls, args: argparse.Namespace) -> YourESPnetModel:
  28. """
  29. def __init__(self):
  30. super().__init__()
  31. self.num_updates = 0
  32. @abstractmethod
  33. def forward(
  34. self, **batch: torch.Tensor
  35. ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
  36. raise NotImplementedError
  37. @abstractmethod
  38. def collect_feats(self, **batch: torch.Tensor) -> Dict[str, torch.Tensor]:
  39. raise NotImplementedError
  40. def set_num_updates(self, num_updates):
  41. self.num_updates = num_updates
  42. def get_num_updates(self):
  43. return self.num_updates