kaldi_data.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. # Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita)
  2. # Licensed under the MIT license.
  3. #
  4. # This library provides utilities for kaldi-style data directory.
  5. from __future__ import print_function
  6. import os
  7. import sys
  8. import numpy as np
  9. import subprocess
  10. import librosa as sf
  11. import io
  12. from functools import lru_cache
  13. def load_segments(segments_file):
  14. """ load segments file as array """
  15. if not os.path.exists(segments_file):
  16. return None
  17. return np.loadtxt(
  18. segments_file,
  19. dtype=[('utt', 'object'),
  20. ('rec', 'object'),
  21. ('st', 'f'),
  22. ('et', 'f')],
  23. ndmin=1)
  24. def load_segments_hash(segments_file):
  25. ret = {}
  26. if not os.path.exists(segments_file):
  27. return None
  28. for line in open(segments_file):
  29. utt, rec, st, et = line.strip().split()
  30. ret[utt] = (rec, float(st), float(et))
  31. return ret
  32. def load_segments_rechash(segments_file):
  33. ret = {}
  34. if not os.path.exists(segments_file):
  35. return None
  36. for line in open(segments_file):
  37. utt, rec, st, et = line.strip().split()
  38. if rec not in ret:
  39. ret[rec] = []
  40. ret[rec].append({'utt':utt, 'st':float(st), 'et':float(et)})
  41. return ret
  42. def load_wav_scp(wav_scp_file):
  43. """ return dictionary { rec: wav_rxfilename } """
  44. lines = [line.strip().split(None, 1) for line in open(wav_scp_file)]
  45. return {x[0]: x[1] for x in lines}
  46. @lru_cache(maxsize=1)
  47. def load_wav(wav_rxfilename, start=0, end=None):
  48. """ This function reads audio file and return data in numpy.float32 array.
  49. "lru_cache" holds recently loaded audio so that can be called
  50. many times on the same audio file.
  51. OPTIMIZE: controls lru_cache size for random access,
  52. considering memory size
  53. """
  54. if wav_rxfilename.endswith('|'):
  55. # input piped command
  56. p = subprocess.Popen(wav_rxfilename[:-1], shell=True,
  57. stdout=subprocess.PIPE)
  58. data, samplerate = sf.load(io.BytesIO(p.stdout.read()),
  59. dtype='float32')
  60. # cannot seek
  61. data = data[start:end]
  62. elif wav_rxfilename == '-':
  63. # stdin
  64. data, samplerate = sf.load(sys.stdin, dtype='float32')
  65. # cannot seek
  66. data = data[start:end]
  67. else:
  68. # normal wav file
  69. data, samplerate = sf.load(wav_rxfilename, start=start, stop=end)
  70. return data, samplerate
  71. def load_utt2spk(utt2spk_file):
  72. """ returns dictionary { uttid: spkid } """
  73. lines = [line.strip().split(None, 1) for line in open(utt2spk_file)]
  74. return {x[0]: x[1] for x in lines}
  75. def load_spk2utt(spk2utt_file):
  76. """ returns dictionary { spkid: list of uttids } """
  77. if not os.path.exists(spk2utt_file):
  78. return None
  79. lines = [line.strip().split() for line in open(spk2utt_file)]
  80. return {x[0]: x[1:] for x in lines}
  81. def load_reco2dur(reco2dur_file):
  82. """ returns dictionary { recid: duration } """
  83. if not os.path.exists(reco2dur_file):
  84. return None
  85. lines = [line.strip().split(None, 1) for line in open(reco2dur_file)]
  86. return {x[0]: float(x[1]) for x in lines}
  87. def process_wav(wav_rxfilename, process):
  88. """ This function returns preprocessed wav_rxfilename
  89. Args:
  90. wav_rxfilename: input
  91. process: command which can be connected via pipe,
  92. use stdin and stdout
  93. Returns:
  94. wav_rxfilename: output piped command
  95. """
  96. if wav_rxfilename.endswith('|'):
  97. # input piped command
  98. return wav_rxfilename + process + "|"
  99. else:
  100. # stdin "-" or normal file
  101. return "cat {} | {} |".format(wav_rxfilename, process)
  102. def extract_segments(wavs, segments=None):
  103. """ This function returns generator of segmented audio as
  104. (utterance id, numpy.float32 array)
  105. TODO?: sampling rate is not converted.
  106. """
  107. if segments is not None:
  108. # segments should be sorted by rec-id
  109. for seg in segments:
  110. wav = wavs[seg['rec']]
  111. data, samplerate = load_wav(wav)
  112. st_sample = np.rint(seg['st'] * samplerate).astype(int)
  113. et_sample = np.rint(seg['et'] * samplerate).astype(int)
  114. yield seg['utt'], data[st_sample:et_sample]
  115. else:
  116. # segments file not found,
  117. # wav.scp is used as segmented audio list
  118. for rec in wavs:
  119. data, samplerate = load_wav(wavs[rec])
  120. yield rec, data
  121. class KaldiData:
  122. def __init__(self, data_dir):
  123. self.data_dir = data_dir
  124. self.segments = load_segments_rechash(
  125. os.path.join(self.data_dir, 'segments'))
  126. self.utt2spk = load_utt2spk(
  127. os.path.join(self.data_dir, 'utt2spk'))
  128. self.wavs = load_wav_scp(
  129. os.path.join(self.data_dir, 'wav.scp'))
  130. self.reco2dur = load_reco2dur(
  131. os.path.join(self.data_dir, 'reco2dur'))
  132. self.spk2utt = load_spk2utt(
  133. os.path.join(self.data_dir, 'spk2utt'))
  134. def load_wav(self, recid, start=0, end=None):
  135. data, rate = load_wav(
  136. self.wavs[recid], start, end)
  137. return data, rate