
Dev gzf (#1421)

* fixbug

* qwenaudio
zhifu gao, 1 year ago
parent
commit 44a6b59468
3 files changed, 131 insertions and 0 deletions
  1. funasr/models/qwen_audio/__init__.py (+0, -0)
  2. funasr/models/qwen_audio/model.py (+85, -0)
  3. funasr/models/qwen_audio/template.yaml (+46, -0)

+ 0 - 0
funasr/models/qwen_audio/__init__.py


+ 85 - 0
funasr/models/qwen_audio/model.py

@@ -0,0 +1,85 @@
+from dataclasses import dataclass
+from typing import Dict
+from typing import Iterable, Optional
+import time
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+from torch import nn
+import whisper
+from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
+
+from funasr.register import tables
+
+
+
+@tables.register("model_classes", "WhisperWarp")
+class WhisperWarp(nn.Module):
+    def __init__(self, whisper_dims: dict, **kwargs):
+        super().__init__()
+        hub = kwargs.get("hub", "funasr")
+        if hub == "openai":
+            init_param_path = kwargs.get("init_param_path", "large-v3")
+            model = whisper.load_model(init_param_path)
+        else:
+            dims = whisper.model.ModelDimensions(**whisper_dims)
+            model = whisper.model.Whisper(dims=dims)
+        
+        self.model = model
+        
+    def forward(self):
+        pass
+    
+    def inference(self,
+                  data_in,
+                  data_lengths=None,
+                  key: list = None,
+                  tokenizer=None,
+                  frontend=None,
+                  **kwargs,
+                  ):
+        if kwargs.get("batch_size", 1) > 1:
+            raise NotImplementedError("batch decoding is not implemented")
+
+        meta_data = {}
+        if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank":  # fbank
+            speech, speech_lengths = data_in, data_lengths
+            if len(speech.shape) < 3:
+                speech = speech[None, :, :]
+            if speech_lengths is None:
+                speech_lengths = torch.tensor([speech.shape[1]])  # keep as a tensor so the later .to(device) call works
+        else:
+            # extract fbank feats
+            time1 = time.perf_counter()
+            audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
+                                                            data_type=kwargs.get("data_type", "sound"),
+                                                            tokenizer=tokenizer)
+            time2 = time.perf_counter()
+            meta_data["load_data"] = f"{time2 - time1:0.3f}"
+            speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
+                                                   frontend=frontend)
+            time3 = time.perf_counter()
+            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+            frame_shift = frontend.frame_shift if hasattr(frontend, "frame_shift") else 10
+            lfr_n = frontend.lfr_n if hasattr(frontend, "lfr_n") else 1
+            meta_data["batch_data_time"] = speech_lengths.sum().item() * frame_shift * lfr_n / 1000
+
+        speech = speech.to(device=kwargs["device"])[0, :, :]
+        speech_lengths = speech_lengths.to(device=kwargs["device"])
+
+        # detect the spoken language
+        _, probs = self.model.detect_language(speech)
+        print(f"Detected language: {max(probs, key=probs.get)}")
+
+        # decode the audio
+        options = whisper.DecodingOptions(language=kwargs.get("language", None), fp16=False)
+        result = whisper.decode(self.model, speech, options)
+
+        results = []
+        result_i = {"key": key[0], "text": result.text}
+
+        results.append(result_i)
+    
+        return results, meta_data
+    
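
For context, the wrapper added above is a thin adapter over the openai-whisper package: load a checkpoint, detect the language, then decode. The sketch below reproduces that load / detect_language / decode sequence by calling whisper directly; the audio path is a placeholder, and the n_mels handling assumes a whisper version recent enough to accept the n_mels argument (needed for large-v3).

import whisper

# load a Whisper checkpoint; "large-v3" matches the default init_param_path above
model = whisper.load_model("large-v3")

# build the log-Mel spectrogram the way whisper expects (pad/trim to 30 s)
audio = whisper.load_audio("example.wav")  # placeholder path
audio = whisper.pad_or_trim(audio)
mel = whisper.log_mel_spectrogram(audio, n_mels=model.dims.n_mels).to(model.device)

# detect the spoken language, as WhisperWarp.inference does
_, probs = model.detect_language(mel)
print(f"Detected language: {max(probs, key=probs.get)}")

# decode with the same options the wrapper builds (language=None lets whisper pick)
options = whisper.DecodingOptions(language=None, fp16=False)
result = whisper.decode(model, mel, options)
print(result.text)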

+ 46 - 0
funasr/models/qwen_audio/template.yaml

@@ -0,0 +1,46 @@
+# This is an example that demonstrates how to configure a model file.
+# You can modify the configuration according to your own requirements.
+
+# to print the register_table:
+# from funasr.register import tables
+# tables.print()
+
+# network architecture
+model: WhisperWarp
+model_conf:
+    lsm_weight: 0.1
+    length_normalized_loss: true
+    hub: funasr # openai
+    init_param_path: null # large-v2 or large-v3 if hub == "openai"
+
+
+
+# only used when hub == funasr;
+#  if hub == openai, whisper_dims is downloaded automatically
+whisper_dims:
+    'n_mels': 80
+    'n_vocab': 51865
+    'n_audio_ctx': 1500
+    'n_audio_state': 1280
+    'n_audio_head': 20
+    'n_audio_layer': 32
+    'n_text_ctx': 448
+    'n_text_state': 1280
+    'n_text_head': 20
+    'n_text_layer': 32
+
+# frontend related
+frontend: WhisperFrontend
+frontend_conf:
+    fs: 16000
+    n_mels: 80
+    do_pad_trim: true
+
+tokenizer: WhisperTokenizer
+tokenizer_conf:
+  language: null
+  task: transcribe
+  is_multilingual: true
+  num_languages: 99
+
+scope_map: ['none', "model."]
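
As a usage note, a template like this is normally consumed through FunASR's AutoModel entry point rather than instantiated by hand. The snippet below is a minimal sketch under that assumption; the model directory and wav file are placeholders, and passing language through generate simply mirrors the kwargs.get("language", None) lookup in WhisperWarp.inference above.

from funasr import AutoModel
from funasr.register import tables

# print the register table, as suggested in the template comments above
tables.print()

# point AutoModel at a directory that contains this template as its config
# (the directory and the audio file are placeholders, not part of this commit)
model = AutoModel(model="/path/to/qwen_audio_model_dir")
res = model.generate(input="example.wav", language=None)
print(res)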