| 123456789101112131415161718192021222324252627282930313233343536373839404142 |
- import torch
def th_accuracy(pad_outputs, pad_targets, ignore_label):
    """Calculate accuracy from prediction scores.

    The predicted label at each position is the argmax of ``pad_outputs``
    over its last axis; positions whose target equals ``ignore_label`` are
    excluded from both the numerator and the denominator.

    Args:
        pad_outputs (Tensor): Prediction scores, either flattened as
            (B * Lmax, D) or already shaped (B, Lmax, D).
        pad_targets (LongTensor): Target label tensors (B, Lmax).
        ignore_label (int): Ignore label id.

    Returns:
        float: Accuracy value (0.0 - 1.0).

    Raises:
        ZeroDivisionError: If every entry of ``pad_targets`` equals
            ``ignore_label`` (there are no positions to score).
    """
    batch_size, max_len = pad_targets.size(0), pad_targets.size(1)
    # Take the class count from the LAST axis so that both the flattened
    # (B * Lmax, D) layout and an unflattened (B, Lmax, D) layout reshape
    # correctly; size(1) would be Lmax, not D, for 3-D input.
    num_classes = pad_outputs.size(-1)
    pad_pred = pad_outputs.view(batch_size, max_len, num_classes).argmax(-1)
    mask = pad_targets != ignore_label
    numerator = torch.sum(
        pad_pred.masked_select(mask) == pad_targets.masked_select(mask)
    )
    denominator = torch.sum(mask)
    # float() converts the 0-dim count tensors to Python scalars; Python
    # float division raises ZeroDivisionError when nothing was counted.
    return float(numerator) / float(denominator)
def compute_accuracy(pad_outputs, pad_targets, ignore_label):
    """Calculate accuracy from already-decoded label predictions.

    Positions whose target equals ``ignore_label`` are excluded from both
    the numerator and the denominator.

    Args:
        pad_outputs (LongTensor): Prediction tensors (B, Lmax).
        pad_targets (LongTensor): Target label tensors (B, Lmax).
        ignore_label (int): Ignore label id.

    Returns:
        Tensor: 0-dim float tensor holding the accuracy value (0.0 - 1.0).
        It is NaN when every entry of ``pad_targets`` equals
        ``ignore_label`` (0 / 0 in floating point).
    """
    mask = pad_targets != ignore_label
    numerator = torch.sum(
        pad_outputs.masked_select(mask) == pad_targets.masked_select(mask)
    )
    denominator = torch.sum(mask)
    # Intentionally returns a torch.Tensor rather than a Python float, so
    # the result stays on the inputs' device and no scalar is extracted.
    return numerator.float() / denominator.float()
|