collate_fn.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. from typing import Collection
  2. from typing import Dict
  3. from typing import List
  4. from typing import Tuple
  5. from typing import Union
  6. import numpy as np
  7. import torch
  8. from typeguard import check_argument_types
  9. from typeguard import check_return_type
  10. from funasr.modules.nets_utils import pad_list
  11. class CommonCollateFn:
  12. """Functor class of common_collate_fn()"""
  13. def __init__(
  14. self,
  15. float_pad_value: Union[float, int] = 0.0,
  16. int_pad_value: int = -32768,
  17. not_sequence: Collection[str] = (),
  18. max_sample_size=None
  19. ):
  20. assert check_argument_types()
  21. self.float_pad_value = float_pad_value
  22. self.int_pad_value = int_pad_value
  23. self.not_sequence = set(not_sequence)
  24. self.max_sample_size = max_sample_size
  25. def __repr__(self):
  26. return (
  27. f"{self.__class__}(float_pad_value={self.float_pad_value}, "
  28. f"int_pad_value={self.float_pad_value})"
  29. )
  30. def __call__(
  31. self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
  32. ) -> Tuple[List[str], Dict[str, torch.Tensor]]:
  33. return common_collate_fn(
  34. data,
  35. float_pad_value=self.float_pad_value,
  36. int_pad_value=self.int_pad_value,
  37. not_sequence=self.not_sequence,
  38. )
  39. def common_collate_fn(
  40. data: Collection[Tuple[str, Dict[str, np.ndarray]]],
  41. float_pad_value: Union[float, int] = 0.0,
  42. int_pad_value: int = -32768,
  43. not_sequence: Collection[str] = (),
  44. ) -> Tuple[List[str], Dict[str, torch.Tensor]]:
  45. """Concatenate ndarray-list to an array and convert to torch.Tensor.
  46. """
  47. assert check_argument_types()
  48. uttids = [u for u, _ in data]
  49. data = [d for _, d in data]
  50. assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
  51. assert all(
  52. not k.endswith("_lengths") for k in data[0]
  53. ), f"*_lengths is reserved: {list(data[0])}"
  54. output = {}
  55. for key in data[0]:
  56. if data[0][key].dtype.kind == "i":
  57. pad_value = int_pad_value
  58. else:
  59. pad_value = float_pad_value
  60. array_list = [d[key] for d in data]
  61. tensor_list = [torch.from_numpy(a) for a in array_list]
  62. tensor = pad_list(tensor_list, pad_value)
  63. output[key] = tensor
  64. if key not in not_sequence:
  65. lens = torch.tensor([d[key].shape[0] for d in data], dtype=torch.long)
  66. output[key + "_lengths"] = lens
  67. output = (uttids, output)
  68. assert check_return_type(output)
  69. return output