# dataloader_fn.py
  1. import torch
  2. from funasr.datasets.dataset_jsonl import AudioDataset
  3. from funasr.datasets.data_sampler import BatchSampler
  4. from funasr.models.frontend.wav_frontend import WavFrontend
  5. from funasr.tokenizer.build_tokenizer import build_tokenizer
  6. from funasr.tokenizer.token_id_converter import TokenIDConverter
  7. collate_fn = None
  8. # collate_fn = collate_fn,
  9. jsonl = "/Users/zhifu/funasr_github/test_local/all_task_debug_len.jsonl"
  10. frontend = WavFrontend()
  11. token_type = 'char'
  12. bpemodel = None
  13. delimiter = None
  14. space_symbol = "<space>"
  15. non_linguistic_symbols = None
  16. g2p_type = None
  17. tokenizer = build_tokenizer(
  18. token_type=token_type,
  19. bpemodel=bpemodel,
  20. delimiter=delimiter,
  21. space_symbol=space_symbol,
  22. non_linguistic_symbols=non_linguistic_symbols,
  23. g2p_type=g2p_type,
  24. )
  25. token_list = ""
  26. unk_symbol = "<unk>"
  27. token_id_converter = TokenIDConverter(
  28. token_list=token_list,
  29. unk_symbol=unk_symbol,
  30. )
  31. dataset = AudioDataset(jsonl, frontend=frontend, tokenizer=tokenizer)
  32. batch_sampler = BatchSampler(dataset)
  33. dataloader_tr = torch.utils.data.DataLoader(dataset,
  34. collate_fn=dataset.collator,
  35. batch_sampler=batch_sampler,
  36. shuffle=False,
  37. num_workers=0,
  38. pin_memory=True)
  39. print(len(dataset))
  40. for i in range(3):
  41. print(i)
  42. for data in dataloader_tr:
  43. print(len(data), data)
  44. # data_iter = iter(dataloader_tr)
  45. # data = next(data_iter)
  46. pass