hugging_face_transformers_postencoder.py

#!/usr/bin/env python3
#  2021, University of Stuttgart;  Pavel Denisov
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Hugging Face Transformers PostEncoder."""

import copy
import logging
from typing import Tuple

import torch
from typeguard import check_argument_types

from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.modules.nets_utils import make_pad_mask

try:
    from transformers import AutoModel

    is_transformers_available = True
except ImportError:
    is_transformers_available = False
class HuggingFaceTransformersPostEncoder(AbsPostEncoder):
    """Hugging Face Transformers PostEncoder."""

    def __init__(
        self,
        input_size: int,
        model_name_or_path: str,
    ):
        """Initialize the module."""
        assert check_argument_types()
        super().__init__()

        if not is_transformers_available:
            raise ImportError(
                "`transformers` is not available. Please install it via `pip install"
                " transformers` or `cd /path/to/espnet/tools && . ./activate_python.sh"
                " && ./installers/install_transformers.sh`."
            )

        model = AutoModel.from_pretrained(model_name_or_path)

        # Use the inner encoder stack if the loaded model exposes one.
        if hasattr(model, "encoder"):
            self.transformer = model.encoder
        else:
            self.transformer = model

        # The token embedding tables are not needed: the post-encoder feeds
        # continuous features (projected by ``linear_in``) instead of token
        # ids, so delete them to save memory.
        if hasattr(self.transformer, "embed_tokens"):
            del self.transformer.embed_tokens
        if hasattr(self.transformer, "wte"):
            del self.transformer.wte
        if hasattr(self.transformer, "word_embedding"):
            del self.transformer.word_embedding

        self.pretrained_params = copy.deepcopy(self.transformer.state_dict())

        # Decide, per model type, which input argument name the transformer
        # expects and whether the attention mask must be extended to the
        # additive (batch, 1, 1, time) form.
        if (
            self.transformer.config.is_encoder_decoder
            or self.transformer.config.model_type in ["xlnet", "t5"]
        ):
            self.use_inputs_embeds = True
            self.extend_attention_mask = False
        elif self.transformer.config.model_type == "gpt2":
            self.use_inputs_embeds = True
            self.extend_attention_mask = True
        else:
            self.use_inputs_embeds = False
            self.extend_attention_mask = True

        self.linear_in = torch.nn.Linear(
            input_size, self.transformer.config.hidden_size
        )
    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward."""
        input = self.linear_in(input)

        args = {"return_dict": True}

        # 1.0 for valid frames, 0.0 for padded frames.
        mask = (~make_pad_mask(input_lengths)).to(input.device).float()

        if self.extend_attention_mask:
            args["attention_mask"] = _extend_attention_mask(mask)
        else:
            args["attention_mask"] = mask

        if self.use_inputs_embeds:
            args["inputs_embeds"] = input
        else:
            args["hidden_states"] = input

        if self.transformer.config.model_type == "mpnet":
            args["head_mask"] = [None for _ in self.transformer.layer]

        output = self.transformer(**args).last_hidden_state

        return output, input_lengths
    def reload_pretrained_parameters(self):
        self.transformer.load_state_dict(self.pretrained_params)
        logging.info("Pretrained Transformers model parameters reloaded!")

    def output_size(self) -> int:
        """Get the output size."""
        return self.transformer.config.hidden_size


def _extend_attention_mask(mask: torch.Tensor) -> torch.Tensor:
    # Broadcast the (batch, time) mask to (batch, 1, 1, time) and convert it
    # to an additive mask: 0.0 for valid positions, -10000.0 for padding.
    mask = mask[:, None, None, :]
    mask = (1.0 - mask) * -10000.0
    return mask
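

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. The model name
    # "bert-base-uncased" and the tensor shapes below are illustrative choices
    # only; running this downloads the pretrained weights from the Hugging
    # Face hub.
    postencoder = HuggingFaceTransformersPostEncoder(
        input_size=256, model_name_or_path="bert-base-uncased"
    )
    feats = torch.randn(2, 50, 256)        # (batch, time, input_size)
    feat_lengths = torch.tensor([50, 30])  # valid frames per utterance
    out, out_lengths = postencoder(feats, feat_lengths)
    # For bert-base-uncased, output_size() is 768.
    print(out.shape, out_lengths, postencoder.output_size())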