hugging_face_transformers_postencoder.py

#!/usr/bin/env python3
# 2021, University of Stuttgart; Pavel Denisov
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Hugging Face Transformers PostEncoder."""
import copy
import logging
from typing import Tuple

import torch

from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.modules.nets_utils import make_pad_mask

# `transformers` is an optional dependency: defer the failure to __init__
# so that importing this module does not itself require it.
try:
    from transformers import AutoModel

    is_transformers_available = True
except ImportError:
    is_transformers_available = False

class HuggingFaceTransformersPostEncoder(AbsPostEncoder):
    """Hugging Face Transformers PostEncoder."""

    def __init__(
        self,
        input_size: int,
        model_name_or_path: str,
    ):
        """Initialize the module."""
        super().__init__()

        if not is_transformers_available:
            raise ImportError(
                "`transformers` is not available. Please install it via `pip install"
                " transformers` or `cd /path/to/espnet/tools && . ./activate_python.sh"
                " && ./installers/install_transformers.sh`."
            )

        model = AutoModel.from_pretrained(model_name_or_path)

        # For encoder-decoder models (e.g. T5, BART), keep only the encoder.
        if hasattr(model, "encoder"):
            self.transformer = model.encoder
        else:
            self.transformer = model

        # The inputs here are continuous features, not token ids, so the
        # token embedding table is unused; drop it to save memory.
        # Architectures name it differently (T5/BART: embed_tokens,
        # GPT-2: wte, XLNet: word_embedding).
        if hasattr(self.transformer, "embed_tokens"):
            del self.transformer.embed_tokens
        if hasattr(self.transformer, "wte"):
            del self.transformer.wte
        if hasattr(self.transformer, "word_embedding"):
            del self.transformer.word_embedding

        # Snapshot the pretrained weights so reload_pretrained_parameters()
        # can restore them later.
        self.pretrained_params = copy.deepcopy(self.transformer.state_dict())

        # How inputs and masks are passed differs across architectures:
        # encoder-decoder encoders, XLNet, and T5 take `inputs_embeds` with
        # a plain (batch, length) mask; GPT-2 takes `inputs_embeds` with an
        # extended additive mask; other plain encoders (e.g. BERT) take
        # `hidden_states` with an extended additive mask.
        if (
            self.transformer.config.is_encoder_decoder
            or self.transformer.config.model_type in ["xlnet", "t5"]
        ):
            self.use_inputs_embeds = True
            self.extend_attention_mask = False
        elif self.transformer.config.model_type == "gpt2":
            self.use_inputs_embeds = True
            self.extend_attention_mask = True
        else:
            self.use_inputs_embeds = False
            self.extend_attention_mask = True

        # Project the upstream encoder output to the Transformer's width.
        self.linear_in = torch.nn.Linear(
            input_size, self.transformer.config.hidden_size
        )

    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward."""
        input = self.linear_in(input)

        args = {"return_dict": True}

        # Binary mask: 1.0 at valid frames, 0.0 at padding.
        mask = (~make_pad_mask(input_lengths)).to(input.device).float()

        if self.extend_attention_mask:
            args["attention_mask"] = _extend_attention_mask(mask)
        else:
            args["attention_mask"] = mask

        if self.use_inputs_embeds:
            args["inputs_embeds"] = input
        else:
            args["hidden_states"] = input

        # MPNet's encoder expects an explicit per-layer head mask.
        if self.transformer.config.model_type == "mpnet":
            args["head_mask"] = [None for _ in self.transformer.layer]

        output = self.transformer(**args).last_hidden_state

        return output, input_lengths

    def reload_pretrained_parameters(self):
        """Restore the snapshot of the pretrained weights."""
        self.transformer.load_state_dict(self.pretrained_params)
        logging.info("Pretrained Transformers model parameters reloaded!")

    def output_size(self) -> int:
        """Get the output size."""
        return self.transformer.config.hidden_size

def _extend_attention_mask(mask: torch.Tensor) -> torch.Tensor:
    """Turn a (batch, length) binary mask into a broadcastable additive mask."""
    # 0.0 at valid positions, -10000.0 at padding; shape (batch, 1, 1, length).
    mask = mask[:, None, None, :]
    mask = (1.0 - mask) * -10000.0
    return mask
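

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative addition, not part of the original file).
# It assumes `transformers` is installed; "bert-base-uncased" is only an
# example checkpoint, and any AutoModel-compatible name should behave the
# same way modulo its hidden size.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    postencoder = HuggingFaceTransformersPostEncoder(
        input_size=256,
        model_name_or_path="bert-base-uncased",
    )

    # A batch of two upstream-encoder outputs with different true lengths.
    feats = torch.randn(2, 50, 256)
    feats_lengths = torch.tensor([50, 35])

    out, out_lengths = postencoder(feats, feats_lengths)
    print(out.shape)  # torch.Size([2, 50, 768]) for BERT-base
    print(out_lengths)  # lengths pass through unchanged: tensor([50, 35])
    print(postencoder.output_size())  # 768

    # The extended mask is additive: 0.0 at valid positions, -10000.0 at
    # padding, broadcast over heads and query positions.
    mask = (~make_pad_mask(feats_lengths)).float()
    print(_extend_attention_mask(mask).shape)  # torch.Size([2, 1, 1, 50])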