| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126 |
- from collections import defaultdict
- import logging
- from pathlib import Path
- from typing import Dict
- from typing import Iterable
- from typing import List
- from typing import Optional
- from typing import Tuple
- import numpy as np
- import torch
- from torch.nn.parallel import data_parallel
- from torch.utils.data import DataLoader
- from typeguard import check_argument_types
- from funasr.fileio.datadir_writer import DatadirWriter
- from funasr.fileio.npy_scp import NpyScpWriter
- from funasr.torch_utils.device_funcs import to_device
- from funasr.torch_utils.forward_adaptor import ForwardAdaptor
- from funasr.train.abs_espnet_model import AbsESPnetModel
@torch.no_grad()
def collect_stats(
    model: AbsESPnetModel,
    train_iter: Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
    valid_iter: Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
    output_dir: Path,
    ngpu: Optional[int],
    log_interval: Optional[int],
    write_collected_feats: bool,
) -> None:
    """Perform on collect_stats mode.

    Running for deriving the shape information from data and gathering
    statistics: for every feature returned by ``model.collect_feats`` the
    per-dimension sum, squared sum and frame count are accumulated and
    written to ``output_dir/{train,valid}/{key}_stats.npz``. Per-utterance
    shape files and the ``batch_keys`` / ``stats_keys`` listings consumed by
    ``aggregate_stats_dirs.py`` are written as well.
    This method is used before executing ``train()``.

    Args:
        model: Model exposing ``collect_feats(**batch)``.
        train_iter: Iterable (typically a DataLoader) yielding
            ``(keys, batch)`` pairs of training data.
        valid_iter: Same as ``train_iter`` but for validation data.
        output_dir: Base directory for all written files.
        ngpu: Number of GPUs to use; ``None`` or ``0`` means CPU only.
        log_interval: Log progress every N iterations; if ``None`` it is
            derived from ``len(itr)`` (or 100 when the length is unknown).
        write_collected_feats: If True, additionally dump every extracted
            feature as an npy file under ``collect_feats/``.
    """
    assert check_argument_types()
    # The annotation allows None, but the GPU checks below need an int.
    ngpu = ngpu or 0
    npy_scp_writers = {}
    try:
        for itr, mode in zip([train_iter, valid_iter], ["train", "valid"]):
            if log_interval is None:
                try:
                    log_interval = max(len(itr) // 20, 10)
                except TypeError:
                    # Iterator without __len__ (e.g. an iterable dataset)
                    log_interval = 100

            sum_dict = defaultdict(int)
            sq_dict = defaultdict(int)
            count_dict = defaultdict(int)
            # Names of the non-length batch entries, taken from the last
            # seen batch; kept outside the loop so an empty iterator does
            # not crash the "batch_keys" write below (original code raised
            # NameError on an unbound `batch` in that case).
            batch_key_names: List[str] = []

            with DatadirWriter(output_dir / mode) as datadir_writer:
                for iiter, (keys, batch) in enumerate(itr, 1):
                    batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
                    batch_key_names = [
                        name for name in batch if not name.endswith("_lengths")
                    ]

                    # 1. Write shape file
                    for name in batch_key_names:
                        for i, (key, tensor) in enumerate(zip(keys, batch[name])):
                            if f"{name}_lengths" in batch:
                                # Strip the zero-padding before recording shape
                                lg = int(batch[f"{name}_lengths"][i])
                                tensor = tensor[:lg]
                            datadir_writer[f"{name}_shape"][key] = ",".join(
                                map(str, tensor.shape)
                            )

                    # 2. Extract feats
                    if ngpu <= 1:
                        feats = model.collect_feats(**batch)
                    else:
                        # Note that data_parallel can parallelize only "forward()"
                        feats = data_parallel(
                            ForwardAdaptor(model, "collect_feats"),
                            (),
                            range(ngpu),
                            module_kwargs=batch,
                        )

                    # 3. Calculate sum and square sum
                    for key, v in feats.items():
                        for i, (uttid, seq) in enumerate(zip(keys, v.cpu().numpy())):
                            # Truncate zero-padding region
                            if f"{key}_lengths" in feats:
                                length = feats[f"{key}_lengths"][i]
                                # seq: (Length, Dim, ...)
                                seq = seq[:length]
                            else:
                                # seq: (Dim, ...) -> (1, Dim, ...)
                                seq = seq[None]
                            # Accumulate value, its square, and count
                            sum_dict[key] += seq.sum(0)
                            sq_dict[key] += (seq**2).sum(0)
                            count_dict[key] += len(seq)

                            # 4. [Option] Write derived features as npy format file.
                            if write_collected_feats:
                                # Instantiate NpyScpWriter for the first iteration
                                if (key, mode) not in npy_scp_writers:
                                    p = output_dir / mode / "collect_feats"
                                    npy_scp_writers[(key, mode)] = NpyScpWriter(
                                        p / f"data_{key}", p / f"{key}.scp"
                                    )
                                # Save array as npy file
                                npy_scp_writers[(key, mode)][uttid] = seq

                    if iiter % log_interval == 0:
                        logging.info(f"Niter: {iiter}")

            for key in sum_dict:
                np.savez(
                    output_dir / mode / f"{key}_stats.npz",
                    count=count_dict[key],
                    sum=sum_dict[key],
                    sum_square=sq_dict[key],
                )

            # batch_keys and stats_keys are used by aggregate_stats_dirs.py
            with (output_dir / mode / "batch_keys").open("w", encoding="utf-8") as f:
                f.write("\n".join(batch_key_names) + "\n")
            with (output_dir / mode / "stats_keys").open("w", encoding="utf-8") as f:
                f.write("\n".join(sum_dict) + "\n")
    finally:
        # Close all npy writers so their scp file handles are released even
        # if an iterator raises mid-way (original leaked them).
        for writer in npy_scp_writers.values():
            writer.close()
|