collate_fn.py 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. from typing import Collection
  2. from typing import Dict
  3. from typing import List
  4. from typing import Tuple
  5. from typing import Union
  6. import numpy as np
  7. import torch
  8. from funasr.modules.nets_utils import pad_list
  9. class CommonCollateFn:
  10. """Functor class of common_collate_fn()"""
  11. def __init__(
  12. self,
  13. float_pad_value: Union[float, int] = 0.0,
  14. int_pad_value: int = -32768,
  15. not_sequence: Collection[str] = (),
  16. max_sample_size=None
  17. ):
  18. self.float_pad_value = float_pad_value
  19. self.int_pad_value = int_pad_value
  20. self.not_sequence = set(not_sequence)
  21. self.max_sample_size = max_sample_size
  22. def __repr__(self):
  23. return (
  24. f"{self.__class__}(float_pad_value={self.float_pad_value}, "
  25. f"int_pad_value={self.float_pad_value})"
  26. )
  27. def __call__(
  28. self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
  29. ) -> Tuple[List[str], Dict[str, torch.Tensor]]:
  30. return common_collate_fn(
  31. data,
  32. float_pad_value=self.float_pad_value,
  33. int_pad_value=self.int_pad_value,
  34. not_sequence=self.not_sequence,
  35. )
  36. def common_collate_fn(
  37. data: Collection[Tuple[str, Dict[str, np.ndarray]]],
  38. float_pad_value: Union[float, int] = 0.0,
  39. int_pad_value: int = -32768,
  40. not_sequence: Collection[str] = (),
  41. ) -> Tuple[List[str], Dict[str, torch.Tensor]]:
  42. """Concatenate ndarray-list to an array and convert to torch.Tensor.
  43. """
  44. uttids = [u for u, _ in data]
  45. data = [d for _, d in data]
  46. assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
  47. assert all(
  48. not k.endswith("_lengths") for k in data[0]
  49. ), f"*_lengths is reserved: {list(data[0])}"
  50. output = {}
  51. for key in data[0]:
  52. if data[0][key].dtype.kind == "i":
  53. pad_value = int_pad_value
  54. else:
  55. pad_value = float_pad_value
  56. array_list = [d[key] for d in data]
  57. tensor_list = [torch.from_numpy(a) for a in array_list]
  58. tensor = pad_list(tensor_list, pad_value)
  59. output[key] = tensor
  60. if key not in not_sequence:
  61. lens = torch.tensor([d[key].shape[0] for d in data], dtype=torch.long)
  62. output[key + "_lengths"] = lens
  63. output = (uttids, output)
  64. return output
  65. def crop_to_max_size(feature, target_size):
  66. size = len(feature)
  67. diff = size - target_size
  68. if diff <= 0:
  69. return feature
  70. start = np.random.randint(0, diff + 1)
  71. end = size - diff + start
  72. return feature[start:end]