Эх сурвалжийг харах

Merge pull request #260 from alibaba-damo-academy/tmp

update audio type check
zhifu gao 3 жил өмнө
parent
commit
41eb0b4b5e

+ 6 - 12
funasr/datasets/iterable_dataset.py

@@ -228,13 +228,9 @@ class IterableESPnetDataset(IterableDataset):
                 name = self.path_name_type_list[i][1]
                 _type = self.path_name_type_list[i][2]
                 if _type == "sound":
-                    audio_type = os.path.basename(value).split(".")[-1].lower()
-                    if audio_type not in SUPPORT_AUDIO_TYPE_SETS:
-                        raise NotImplementedError(
-                            f'Not supported audio type: {audio_type}')
-                    if audio_type == "pcm":
-                        _type = "pcm"
-
+                   audio_type = os.path.basename(value).lower()
+                   if audio_type.rfind(".pcm") >= 0:
+                       _type = "pcm"
                 func = DATA_TYPES[_type]
                 array = func(value)
                 if self.fs is not None and (name == "speech" or name == "ref_speech"):
@@ -336,11 +332,8 @@ class IterableESPnetDataset(IterableDataset):
                 # 2.a. Load data streamingly
                 for value, (path, name, _type) in zip(values, self.path_name_type_list):
                     if _type == "sound":
-                        audio_type = os.path.basename(value).split(".")[-1].lower()
-                        if audio_type not in SUPPORT_AUDIO_TYPE_SETS:
-                            raise NotImplementedError(
-                                f'Not supported audio type: {audio_type}')
-                        if audio_type == "pcm":
+                        audio_type = os.path.basename(value).lower()
+                        if audio_type.rfind(".pcm") >= 0:
                             _type = "pcm"
                     func = DATA_TYPES[_type]
                     # Load entry
@@ -392,3 +385,4 @@ class IterableESPnetDataset(IterableDataset):
 
         if count == 0:
             raise RuntimeError("No iteration")
+

+ 27 - 25
funasr/utils/asr_utils.py

@@ -58,14 +58,15 @@ def type_checking(audio_in: Union[str, bytes],
     if r_recog_type is None and audio_in is not None:
         # audio_in is wav, recog_type is wav_file
         if os.path.isfile(audio_in):
-            audio_type = os.path.basename(audio_in).split(".")[-1].lower()
-            if audio_type in SUPPORT_AUDIO_TYPE_SETS:
-                r_recog_type = 'wav'
-                r_audio_format = 'wav'
-            elif audio_type == "scp":
+            audio_type = os.path.basename(audio_in).lower()
+            for support_audio_type in SUPPORT_AUDIO_TYPE_SETS:
+                if audio_type.rfind(".{}".format(support_audio_type)) >= 0:
+                    r_recog_type = 'wav'
+                    r_audio_format = 'wav'
+            if audio_type.rfind(".scp") >= 0:
                 r_recog_type = 'wav'
                 r_audio_format = 'scp'
-            else:
+            if r_recog_type is None:
                 raise NotImplementedError(
                     f'Not supported audio type: {audio_type}')
 
@@ -128,13 +129,15 @@ def get_sr_from_bytes(wav: bytes):
 def get_sr_from_wav(fname: str):
     fs = None
     if os.path.isfile(fname):
-        audio_type = os.path.basename(fname).split(".")[-1].lower()
-        if audio_type in SUPPORT_AUDIO_TYPE_SETS:
-            if audio_type == "pcm":
-                fs = None
-            else:
-                audio, fs = torchaudio.load(fname)
-        elif audio_type == "scp":
+        audio_type = os.path.basename(fname).lower()
+        for support_audio_type in SUPPORT_AUDIO_TYPE_SETS:
+            if audio_type.rfind(".{}".format(support_audio_type)) >= 0:
+                if support_audio_type == "pcm":
+                    fs = None
+                else:
+                    audio, fs = torchaudio.load(fname)
+                break
+        if audio_type.rfind(".scp") >= 0:
             with open(fname, encoding="utf-8") as f:
                 for line in f:
                     wav_path = line.split()[1]
@@ -147,9 +150,7 @@ def get_sr_from_wav(fname: str):
         for file in dir_files:
             file_path = os.path.join(fname, file)
             if os.path.isfile(file_path):
-                audio_type = os.path.basename(file_path).split(".")[-1].lower()
-                if audio_type in SUPPORT_AUDIO_TYPE_SETS:
-                    fs = get_sr_from_wav(file_path)
+                fs = get_sr_from_wav(file_path)
             elif os.path.isdir(file_path):
                 fs = get_sr_from_wav(file_path)
 
@@ -165,12 +166,12 @@ def find_file_by_ends(dir_path: str, ends: str):
         file_path = os.path.join(dir_path, file)
         if os.path.isfile(file_path):
             if ends == ".wav" or ends == ".WAV":
-                audio_type = os.path.basename(file_path).split(".")[-1].lower()
-                if audio_type in SUPPORT_AUDIO_TYPE_SETS:
-                    return True
-                else:
-                    raise NotImplementedError(
-                        f'Not supported audio type: {audio_type}')
+                audio_type = os.path.basename(file_path).lower()
+                for support_audio_type in SUPPORT_AUDIO_TYPE_SETS:
+                    if audio_type.rfind(".{}".format(support_audio_type)) >= 0:
+                        return True
+                raise NotImplementedError(
+                    f'Not supported audio type: {audio_type}')
             elif file_path.endswith(ends):
                 return True
         elif os.path.isdir(file_path):
@@ -185,9 +186,10 @@ def recursion_dir_all_wav(wav_list, dir_path: str) -> List[str]:
     for file in dir_files:
         file_path = os.path.join(dir_path, file)
         if os.path.isfile(file_path):
-            audio_type = os.path.basename(file_path).split(".")[-1].lower()
-            if audio_type in SUPPORT_AUDIO_TYPE_SETS:
-                wav_list.append(file_path)
+            audio_type = os.path.basename(file_path).lower()
+            for support_audio_type in SUPPORT_AUDIO_TYPE_SETS:
+                if audio_type.rfind(".{}".format(support_audio_type)) >= 0:
+                    wav_list.append(file_path)
         elif os.path.isdir(file_path):
             recursion_dir_all_wav(wav_list, file_path)