|
|
@@ -27,7 +27,7 @@ def get_version():
|
|
|
def sample_rate_checking(audio_in: Union[str, bytes], audio_format: str):
|
|
|
r_audio_fs = None
|
|
|
|
|
|
- if audio_format == 'wav':
|
|
|
+ if audio_format == 'wav' or audio_format == 'scp':
|
|
|
r_audio_fs = get_sr_from_wav(audio_in)
|
|
|
elif audio_format == 'pcm' and isinstance(audio_in, bytes):
|
|
|
r_audio_fs = get_sr_from_bytes(audio_in)
|
|
|
@@ -134,6 +134,13 @@ def get_sr_from_wav(fname: str):
|
|
|
fs = None
|
|
|
else:
|
|
|
audio, fs = torchaudio.load(fname)
|
|
|
+ elif audio_type == "scp":
|
|
|
+ with open(fname, encoding="utf-8") as f:
|
|
|
+ for line in f:
|
|
|
+ wav_path = line.split()[1]
|
|
|
+ fs = get_sr_from_wav(wav_path)
|
|
|
+ if fs is not None:
|
|
|
+ break
|
|
|
return fs
|
|
|
elif os.path.isdir(fname):
|
|
|
dir_files = os.listdir(fname)
|