collate_fn.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. from typing import Collection
  2. from typing import Dict
  3. from typing import List
  4. from typing import Tuple
  5. from typing import Union
  6. import numpy as np
  7. import torch
  8. from typeguard import check_argument_types
  9. from typeguard import check_return_type
  10. from funasr.modules.nets_utils import pad_list
  11. class CommonCollateFn:
  12. """Functor class of common_collate_fn()"""
  13. def __init__(
  14. self,
  15. float_pad_value: Union[float, int] = 0.0,
  16. int_pad_value: int = -32768,
  17. not_sequence: Collection[str] = (),
  18. max_sample_size=None
  19. ):
  20. assert check_argument_types()
  21. self.float_pad_value = float_pad_value
  22. self.int_pad_value = int_pad_value
  23. self.not_sequence = set(not_sequence)
  24. self.max_sample_size = max_sample_size
  25. def __repr__(self):
  26. return (
  27. f"{self.__class__}(float_pad_value={self.float_pad_value}, "
  28. f"int_pad_value={self.float_pad_value})"
  29. )
  30. def __call__(
  31. self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
  32. ) -> Tuple[List[str], Dict[str, torch.Tensor]]:
  33. return common_collate_fn(
  34. data,
  35. float_pad_value=self.float_pad_value,
  36. int_pad_value=self.int_pad_value,
  37. not_sequence=self.not_sequence,
  38. )
  39. def common_collate_fn(
  40. data: Collection[Tuple[str, Dict[str, np.ndarray]]],
  41. float_pad_value: Union[float, int] = 0.0,
  42. int_pad_value: int = -32768,
  43. not_sequence: Collection[str] = (),
  44. ) -> Tuple[List[str], Dict[str, torch.Tensor]]:
  45. """Concatenate ndarray-list to an array and convert to torch.Tensor.
  46. """
  47. assert check_argument_types()
  48. uttids = [u for u, _ in data]
  49. data = [d for _, d in data]
  50. assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
  51. assert all(
  52. not k.endswith("_lengths") for k in data[0]
  53. ), f"*_lengths is reserved: {list(data[0])}"
  54. output = {}
  55. for key in data[0]:
  56. if data[0][key].dtype.kind == "i":
  57. pad_value = int_pad_value
  58. else:
  59. pad_value = float_pad_value
  60. array_list = [d[key] for d in data]
  61. tensor_list = [torch.from_numpy(a) for a in array_list]
  62. tensor = pad_list(tensor_list, pad_value)
  63. output[key] = tensor
  64. if key not in not_sequence:
  65. lens = torch.tensor([d[key].shape[0] for d in data], dtype=torch.long)
  66. output[key + "_lengths"] = lens
  67. output = (uttids, output)
  68. assert check_return_type(output)
  69. return output
  70. def crop_to_max_size(feature, target_size):
  71. size = len(feature)
  72. diff = size - target_size
  73. if diff <= 0:
  74. return feature
  75. start = np.random.randint(0, diff + 1)
  76. end = size - diff + start
  77. return feature[start:end]
  78. def clipping_collate_fn(
  79. data: Collection[Tuple[str, Dict[str, np.ndarray]]],
  80. max_sample_size=None,
  81. not_sequence: Collection[str] = (),
  82. ) -> Tuple[List[str], Dict[str, torch.Tensor]]:
  83. # mainly for pre-training
  84. assert check_argument_types()
  85. uttids = [u for u, _ in data]
  86. data = [d for _, d in data]
  87. assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
  88. assert all(
  89. not k.endswith("_lengths") for k in data[0]
  90. ), f"*_lengths is reserved: {list(data[0])}"
  91. output = {}
  92. for key in data[0]:
  93. array_list = [d[key] for d in data]
  94. tensor_list = [torch.from_numpy(a) for a in array_list]
  95. sizes = [len(s) for s in tensor_list]
  96. if max_sample_size is None:
  97. target_size = min(sizes)
  98. else:
  99. target_size = min(min(sizes), max_sample_size)
  100. tensor = tensor_list[0].new_zeros(len(tensor_list), target_size, tensor_list[0].shape[1])
  101. for i, (source, size) in enumerate(zip(tensor_list, sizes)):
  102. diff = size - target_size
  103. if diff == 0:
  104. tensor[i] = source
  105. else:
  106. tensor[i] = crop_to_max_size(source, target_size)
  107. output[key] = tensor
  108. if key not in not_sequence:
  109. lens = torch.tensor([source.shape[0] for source in tensor], dtype=torch.long)
  110. output[key + "_lengths"] = lens
  111. output = (uttids, output)
  112. assert check_return_type(output)
  113. return output