# rand_gen_dataset.py
import collections
import collections.abc
from pathlib import Path
from typing import Union

import numpy as np

from funasr.fileio.read_text import load_num_sequence_text
  6. class FloatRandomGenerateDataset(collections.abc.Mapping):
  7. """Generate float array from shape.txt.
  8. Examples:
  9. shape.txt
  10. uttA 123,83
  11. uttB 34,83
  12. >>> dataset = FloatRandomGenerateDataset("shape.txt")
  13. >>> array = dataset["uttA"]
  14. >>> assert array.shape == (123, 83)
  15. >>> array = dataset["uttB"]
  16. >>> assert array.shape == (34, 83)
  17. """
  18. def __init__(
  19. self,
  20. shape_file: Union[Path, str],
  21. dtype: Union[str, np.dtype] = "float32",
  22. loader_type: str = "csv_int",
  23. ):
  24. shape_file = Path(shape_file)
  25. self.utt2shape = load_num_sequence_text(shape_file, loader_type)
  26. self.dtype = np.dtype(dtype)
  27. def __iter__(self):
  28. return iter(self.utt2shape)
  29. def __len__(self):
  30. return len(self.utt2shape)
  31. def __getitem__(self, item) -> np.ndarray:
  32. shape = self.utt2shape[item]
  33. return np.random.randn(*shape).astype(self.dtype)
  34. class IntRandomGenerateDataset(collections.abc.Mapping):
  35. """Generate float array from shape.txt
  36. Examples:
  37. shape.txt
  38. uttA 123,83
  39. uttB 34,83
  40. >>> dataset = IntRandomGenerateDataset("shape.txt", low=0, high=10)
  41. >>> array = dataset["uttA"]
  42. >>> assert array.shape == (123, 83)
  43. >>> array = dataset["uttB"]
  44. >>> assert array.shape == (34, 83)
  45. """
  46. def __init__(
  47. self,
  48. shape_file: Union[Path, str],
  49. low: int,
  50. high: int = None,
  51. dtype: Union[str, np.dtype] = "int64",
  52. loader_type: str = "csv_int",
  53. ):
  54. shape_file = Path(shape_file)
  55. self.utt2shape = load_num_sequence_text(shape_file, loader_type)
  56. self.dtype = np.dtype(dtype)
  57. self.low = low
  58. self.high = high
  59. def __iter__(self):
  60. return iter(self.utt2shape)
  61. def __len__(self):
  62. return len(self.utt2shape)
  63. def __getitem__(self, item) -> np.ndarray:
  64. shape = self.utt2shape[item]
  65. return np.random.randint(self.low, self.high, size=shape, dtype=self.dtype)