# build_streaming_iterator.py
import numpy as np
from torch.utils.data import DataLoader

from funasr.datasets.iterable_dataset import IterableESPnetDataset
from funasr.datasets.small_datasets.collate_fn import CommonCollateFn
from funasr.datasets.small_datasets.preprocessor import build_preprocess
  6. def build_streaming_iterator(
  7. task_name,
  8. preprocess_args,
  9. data_path_and_name_and_type,
  10. key_file: str = None,
  11. batch_size: int = 1,
  12. fs: dict = None,
  13. mc: bool = False,
  14. dtype: str = np.float32,
  15. num_workers: int = 1,
  16. use_collate_fn: bool = True,
  17. preprocess_fn=None,
  18. ngpu: int = 0,
  19. train: bool = False,
  20. ) -> DataLoader:
  21. """Build DataLoader using iterable dataset"""
  22. # preprocess
  23. if preprocess_fn is not None:
  24. preprocess_fn = preprocess_fn
  25. elif preprocess_args is not None:
  26. preprocess_args.task_name = task_name
  27. preprocess_fn = build_preprocess(preprocess_args, train)
  28. else:
  29. preprocess_fn = None
  30. # collate
  31. if not use_collate_fn:
  32. collate_fn = None
  33. elif task_name in ["punc", "lm"]:
  34. collate_fn = CommonCollateFn(int_pad_value=0)
  35. else:
  36. collate_fn = CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
  37. if collate_fn is not None:
  38. kwargs = dict(collate_fn=collate_fn)
  39. else:
  40. kwargs = {}
  41. dataset = IterableESPnetDataset(
  42. data_path_and_name_and_type,
  43. float_dtype=dtype,
  44. fs=fs,
  45. mc=mc,
  46. preprocess=preprocess_fn,
  47. key_file=key_file,
  48. )
  49. if dataset.apply_utt2category:
  50. kwargs.update(batch_size=1)
  51. else:
  52. kwargs.update(batch_size=batch_size)
  53. return DataLoader(
  54. dataset=dataset,
  55. pin_memory=ngpu > 0,
  56. num_workers=num_workers,
  57. **kwargs,
  58. )