base_model.py 319 B

1234567891011121314151617
  1. import torch
  class FunASRModel(torch.nn.Module):
      """Common base class shared by FunASR models.

      Extends ``torch.nn.Module`` with a ``num_updates`` counter that callers
      can set and query through ``set_num_updates`` / ``get_num_updates``.
      This class only stores the counter; it never increments it itself.
      """

      def __init__(self) -> None:
          super().__init__()
          # Counter starts at 0 and changes only via set_num_updates().
          # NOTE(review): presumably the training loop pushes its optimizer
          # step count here — confirm against callers.
          self.num_updates = 0

      def set_num_updates(self, num_updates) -> None:
          """Store ``num_updates`` on the model.

          Args:
              num_updates: New counter value. It is stored as given, with no
                  validation or type conversion.
          """
          self.num_updates = num_updates

      def get_num_updates(self):
          """Return the stored counter.

          Returns:
              The value most recently passed to ``set_num_updates``, or 0 if
              it has never been called.
          """
          return self.num_updates