mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
* sensevoice finetune * sensevoice finetune * sensevoice finetune * sensevoice finetune * sensevoice finetune * sensevoice finetune * sensevoice finetune * sensevoice finetune * sensevoice finetune * sensevoice finetune * bugfix * update with main (#1631) * update seaco finetune * v1.0.24 --------- Co-authored-by: 维石 <shixian.shi@alibaba-inc.com> * sensevoice * sensevoice * sensevoice * update with main (#1638) * update seaco finetune * v1.0.24 * update rwkv template --------- Co-authored-by: 维石 <shixian.shi@alibaba-inc.com> * sensevoice * sensevoice * sensevoice * sensevoice * sensevoice * sensevoice * sensevoice * sensevoice * sensevoice * sensevoice * sensevoice * sensevoice * sensevoice * sensevoice * sensevoice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * whisper * whisper * update style * update style --------- Co-authored-by: 维石 <shixian.shi@alibaba-inc.com>
41 lines
1.3 KiB
Python
41 lines
1.3 KiB
Python
import torch
|
|
|
|
|
|
def th_accuracy(pad_outputs, pad_targets, ignore_label):
    """Calculate token-level accuracy from flattened prediction scores.

    Args:
        pad_outputs (Tensor): Prediction scores, shape (B * Lmax, D).
        pad_targets (LongTensor): Target label ids, shape (B, Lmax).
        ignore_label (int): Label id marking positions to exclude
            (e.g. padding).

    Returns:
        float: Accuracy over the non-ignored positions (0.0 - 1.0).
            Returns 0.0 when every position is ignored, instead of
            raising ZeroDivisionError.

    """
    # Restore the (B, Lmax, D) layout so predictions line up with targets,
    # then take the best-scoring class id at each position.
    pad_pred = pad_outputs.view(
        pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1)
    ).argmax(2)
    mask = pad_targets != ignore_label
    denominator = int(torch.sum(mask))
    if denominator == 0:
        # No position to score (all targets are ignore_label).
        return 0.0
    numerator = torch.sum(pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
    return float(numerator) / float(denominator)
|
|
|
|
|
|
def compute_accuracy(pad_outputs, pad_targets, ignore_label):
    """Compute accuracy between predicted and reference label ids.

    Args:
        pad_outputs (LongTensor): Predicted label ids (B, Lmax).
        pad_targets (LongTensor): Reference label ids (B, Lmax).
        ignore_label (int): Label id whose positions are excluded.

    Returns:
        Tensor: Scalar float tensor holding the accuracy (0.0 - 1.0).
            The result is kept as a torch.Tensor rather than a Python float.

    """
    valid = pad_targets != ignore_label
    predicted = torch.masked_select(pad_outputs, valid)
    expected = torch.masked_select(pad_targets, valid)
    n_correct = torch.sum(predicted == expected)
    n_valid = torch.sum(valid)
    return n_correct.float() / n_valid.float()
|