utils.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. import os
  2. import torch
  3. from torch.nn import functional as F
  4. import yaml
  5. import numpy as np
  6. def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):
  7. if maxlen is None:
  8. maxlen = lengths.max()
  9. row_vector = torch.arange(0, maxlen, 1).to(lengths.device)
  10. matrix = torch.unsqueeze(lengths, dim=-1)
  11. mask = row_vector < matrix
  12. mask = mask.detach()
  13. return mask.type(dtype).to(device) if device is not None else mask.type(dtype)
  14. def apply_cmvn(inputs, mvn):
  15. device = inputs.device
  16. dtype = inputs.dtype
  17. frame, dim = inputs.shape
  18. meams = np.tile(mvn[0:1, :dim], (frame, 1))
  19. vars = np.tile(mvn[1:2, :dim], (frame, 1))
  20. inputs -= torch.from_numpy(meams).type(dtype).to(device)
  21. inputs *= torch.from_numpy(vars).type(dtype).to(device)
  22. return inputs.type(torch.float32)
  23. def drop_and_add(inputs: torch.Tensor,
  24. outputs: torch.Tensor,
  25. training: bool,
  26. dropout_rate: float = 0.1,
  27. stoch_layer_coeff: float = 1.0):
  28. outputs = F.dropout(outputs, p=dropout_rate, training=training, inplace=True)
  29. outputs *= stoch_layer_coeff
  30. input_dim = inputs.size(-1)
  31. output_dim = outputs.size(-1)
  32. if input_dim == output_dim:
  33. outputs += inputs
  34. return outputs
  35. def proc_tf_vocab(vocab_path):
  36. with open(vocab_path, encoding="utf-8") as f:
  37. token_list = [line.rstrip() for line in f]
  38. if '<unk>' not in token_list:
  39. token_list.append('<unk>')
  40. return token_list
  41. def gen_config_for_tfmodel(config_path, vocab_path, output_dir):
  42. token_list = proc_tf_vocab(vocab_path)
  43. with open(config_path, encoding="utf-8") as f:
  44. config = yaml.safe_load(f)
  45. config['token_list'] = token_list
  46. if not os.path.exists(output_dir):
  47. os.makedirs(output_dir)
  48. with open(os.path.join(output_dir, "config.yaml"), "w", encoding="utf-8") as f:
  49. yaml_no_alias_safe_dump(config, f, indent=4, sort_keys=False)
  50. class NoAliasSafeDumper(yaml.SafeDumper):
  51. # Disable anchor/alias in yaml because looks ugly
  52. def ignore_aliases(self, data):
  53. return True
  54. def yaml_no_alias_safe_dump(data, stream=None, **kwargs):
  55. """Safe-dump in yaml with no anchor/alias"""
  56. return yaml.dump(
  57. data, stream, allow_unicode=True, Dumper=NoAliasSafeDumper, **kwargs
  58. )
  59. if __name__ == '__main__':
  60. import sys
  61. config_path = sys.argv[1]
  62. vocab_path = sys.argv[2]
  63. output_dir = sys.argv[3]
  64. gen_config_for_tfmodel(config_path, vocab_path, output_dir)