build_streaming_iterator.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. import numpy as np
  2. from torch.utils.data import DataLoader
  3. from typeguard import check_argument_types
  4. from funasr.datasets.iterable_dataset import IterableESPnetDataset
  5. from funasr.datasets.small_datasets.collate_fn import CommonCollateFn
  6. from funasr.datasets.small_datasets.preprocessor import build_preprocess
  7. def build_streaming_iterator(
  8. task_name,
  9. preprocess_args,
  10. data_path_and_name_and_type,
  11. key_file: str = None,
  12. batch_size: int = 1,
  13. fs: dict = None,
  14. mc: bool = False,
  15. dtype: str = np.float32,
  16. num_workers: int = 1,
  17. use_collate_fn: bool = True,
  18. preprocess_fn=None,
  19. ngpu: int = 0,
  20. train: bool = False,
  21. ) -> DataLoader:
  22. """Build DataLoader using iterable dataset"""
  23. assert check_argument_types()
  24. # preprocess
  25. if preprocess_fn is not None:
  26. preprocess_fn = preprocess_fn
  27. elif preprocess_args is not None:
  28. preprocess_args.task_name = task_name
  29. preprocess_fn = build_preprocess(preprocess_args, train)
  30. else:
  31. preprocess_fn = None
  32. # collate
  33. if not use_collate_fn:
  34. collate_fn = None
  35. elif task_name in ["punc", "lm"]:
  36. collate_fn = CommonCollateFn(int_pad_value=0)
  37. else:
  38. collate_fn = CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
  39. if collate_fn is not None:
  40. kwargs = dict(collate_fn=collate_fn)
  41. else:
  42. kwargs = {}
  43. dataset = IterableESPnetDataset(
  44. data_path_and_name_and_type,
  45. float_dtype=dtype,
  46. fs=fs,
  47. mc=mc,
  48. preprocess=preprocess_fn,
  49. key_file=key_file,
  50. )
  51. if dataset.apply_utt2category:
  52. kwargs.update(batch_size=1)
  53. else:
  54. kwargs.update(batch_size=batch_size)
  55. return DataLoader(
  56. dataset=dataset,
  57. pin_memory=ngpu > 0,
  58. num_workers=num_workers,
  59. **kwargs,
  60. )