# compute_acc.py
import torch


  2. def th_accuracy(pad_outputs, pad_targets, ignore_label):
  3. """Calculate accuracy.
  4. Args:
  5. pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
  6. pad_targets (LongTensor): Target label tensors (B, Lmax, D).
  7. ignore_label (int): Ignore label id.
  8. Returns:
  9. float: Accuracy value (0.0 - 1.0).
  10. """
  11. pad_pred = pad_outputs.view(
  12. pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1)
  13. ).argmax(2)
  14. mask = pad_targets != ignore_label
  15. numerator = torch.sum(
  16. pad_pred.masked_select(mask) == pad_targets.masked_select(mask)
  17. )
  18. denominator = torch.sum(mask)
  19. return float(numerator) / float(denominator)
  20. def compute_accuracy(pad_outputs, pad_targets, ignore_label):
  21. """Calculate accuracy.
  22. Args:
  23. pad_outputs (LongTensor): Prediction tensors (B, Lmax).
  24. pad_targets (LongTensor): Target label tensors (B, Lmax).
  25. ignore_label (int): Ignore label id.
  26. Returns:
  27. float: Accuracy value (0.0 - 1.0).
  28. """
  29. mask = pad_targets != ignore_label
  30. numerator = torch.sum(
  31. pad_outputs.masked_select(mask) == pad_targets.masked_select(mask)
  32. )
  33. denominator = torch.sum(mask)
  34. return numerator.float() / denominator.float() #(FIX:MZY):return torch.Tensor type