游雁 1 年之前
父節點
當前提交
cb8b09e085
共有 2 個文件被更改,包括 10 次插入5 次删除
  1. 4 3
      funasr/datasets/audio_datasets/preprocessor.py
  2. 6 2
      funasr/train_utils/trainer.py

+ 4 - 3
funasr/datasets/audio_datasets/preprocessor.py

@@ -26,9 +26,10 @@ class SpeechPreprocessSpeedPerturb(nn.Module):
 			return waveform
 		speed = random.choice(self.speed_perturb)
 		if speed != 1.0:
-			waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
-				torch.tensor(waveform).view(1, -1), fs, [['speed', str(speed)], ['rate', str(fs)]])
-			waveform = waveform.view(-1)
+			with torch.no_grad():
+				waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
+					torch.tensor(waveform).view(1, -1), fs, [['speed', str(speed)], ['rate', str(fs)]])
+				waveform = waveform.view(-1)
 			
 		return waveform
 

+ 6 - 2
funasr/train_utils/trainer.py

@@ -273,8 +273,9 @@ class Trainer:
                 speed_stats["total_time"] = total_time
 
 
-            pbar.update(1)
+            
             if self.local_rank == 0:
+                pbar.update(1)
                 gpu_info = "GPU, memory: {:.3f} GB, " \
                            "{:.3f} GB, "\
                            "{:.3f} GB, "\
@@ -290,6 +291,7 @@ class Trainer:
                     f"(loss: {loss.detach().cpu().item():.3f}), "
                     f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}"
                     f"{gpu_info}"
+                    f"rank: {self.local_rank}"
                 )
                 pbar.set_description(description)
                 if self.writer:
@@ -344,14 +346,16 @@ class Trainer:
                 loss = loss
                 time4 = time.perf_counter()
 
-                pbar.update(1)
+                
                 if self.local_rank == 0:
+                    pbar.update(1)
                     description = (
                         f"validation epoch: {epoch}/{self.max_epoch}, "
                         f"step {batch_idx}/{len(self.dataloader_train)}, "
                         f"{speed_stats}, "
                         f"(loss: {loss.detach().cpu().item():.3f}), "
                         f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}"
+                        f"rank: {self.local_rank}"
                     )
                     pbar.set_description(description)
                     if self.writer: