mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
* sensevoice finetune * sensevoice finetune * sensevoice finetune * sensevoice finetune * sensevoice finetune * sensevoice finetune * sensevoice finetune * sensevoice finetune * sensevoice finetune * sensevoice finetune * bugfix * update with main (#1631) * update seaco finetune * v1.0.24 --------- Co-authored-by: 维石 <shixian.shi@alibaba-inc.com> * sensevoice * sensevoice * sensevoice * update with main (#1638) * update seaco finetune * v1.0.24 * update rwkv template --------- Co-authored-by: 维石 <shixian.shi@alibaba-inc.com> * sensevoice * sensevoice * sensevoice * sensevoice * sensevoice * sensevoice * sensevoice * sensevoice * sensevoice * sensevoice * sensevoice * sensevoice * sensevoice * sensevoice * sensevoice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * whisper * whisper * update style * update style --------- Co-authored-by: 维石 <shixian.shi@alibaba-inc.com>
49 lines
1.6 KiB
Python
49 lines
1.6 KiB
Python
"""Torch utility module."""
|
|
|
|
import torch
|
|
|
|
if torch.distributed.is_available():
|
|
from torch.distributed import ReduceOp
|
|
|
|
|
|
def recursive_sum(obj, weight: torch.Tensor, distributed: bool = False):
    """Return the weighted sum of every tensor found in a nested container.

    The structure of ``obj`` is walked recursively: tuples, namedtuples,
    lists and dicts are rebuilt with the same container type, ``None`` is
    passed through untouched, and every tensor leaf is replaced by
    ``(leaf * weight).sum()``.

    Args:
        obj: A tensor, ``None``, or a (possibly nested) tuple / namedtuple /
            list / dict whose leaves are tensors or ``None``.
        weight: 1-D tensor of weights. Every tensor leaf must have exactly
            the same size as ``weight``; it is cast to the leaf's dtype
            before the multiplication.
        distributed: If True, each summed leaf is additionally all-reduced
            (SUM) through ``torch.distributed``.

    Returns:
        An object with the same nesting as ``obj`` whose tensor leaves are
        the weighted sums.

    Raises:
        AssertionError: If ``weight`` is not 1-D or a leaf's size differs
            from ``weight``'s size.
        ValueError: If a leaf is neither a tensor nor ``None``.
    """
    assert weight.dim() == 1, weight.size()
    if isinstance(obj, tuple) and hasattr(obj, "_fields"):
        # namedtuple constructors take one positional argument per field,
        # not a single iterable, so the results must be unpacked.
        return type(obj)(*(recursive_sum(v, weight, distributed) for v in obj))
    elif isinstance(obj, (tuple, list)):
        return type(obj)(recursive_sum(v, weight, distributed) for v in obj)
    elif isinstance(obj, dict):
        return {k: recursive_sum(v, weight, distributed) for k, v in obj.items()}
    elif isinstance(obj, torch.Tensor):
        assert obj.size() == weight.size(), (obj.size(), weight.size())
        obj = (obj * weight.type(obj.dtype)).sum()
        if distributed:
            # In-place reduction of the freshly created sum tensor.
            torch.distributed.all_reduce(obj, op=ReduceOp.SUM)
        return obj
    elif obj is None:
        return None
    else:
        raise ValueError(type(obj))
|
|
|
|
|
|
def recursive_divide(a, b: torch.Tensor):
    """Divide every tensor found in a nested container by ``b``.

    Tuples, namedtuples, lists and dicts are rebuilt with the same container
    type, ``None`` is passed through untouched, and every tensor leaf is
    replaced by ``leaf / b``.

    Args:
        a: A tensor, ``None``, or a (possibly nested) tuple / namedtuple /
            list / dict whose leaves are tensors or ``None``.
        b: Divisor tensor. Every tensor leaf must have exactly the same size
            as ``b``; it is cast to the leaf's dtype before dividing.

    Returns:
        An object with the same nesting as ``a`` holding the quotients.

    Raises:
        AssertionError: If a leaf's size differs from ``b``'s size.
        ValueError: If a leaf is neither a tensor nor ``None``.
    """
    if isinstance(a, tuple) and hasattr(a, "_fields"):
        # namedtuple constructors take one positional argument per field,
        # not a single iterable, so the results must be unpacked.
        return type(a)(*(recursive_divide(v, b) for v in a))
    elif isinstance(a, (tuple, list)):
        return type(a)(recursive_divide(v, b) for v in a)
    elif isinstance(a, dict):
        return {k: recursive_divide(v, b) for k, v in a.items()}
    elif isinstance(a, torch.Tensor):
        assert a.size() == b.size(), (a.size(), b.size())
        return a / b.type(a.dtype)
    elif a is None:
        return None
    else:
        raise ValueError(type(a))
|
|
|
|
|
|
def recursive_average(obj, weight: torch.Tensor, distributed: bool = False):
    """Return the weighted average of a nested container and the total weight.

    The weighted sums are produced by ``recursive_sum`` and then divided by
    the summed weight via ``recursive_divide``.

    Args:
        obj: Nested structure accepted by ``recursive_sum``.
        weight: Weight tensor forwarded to ``recursive_sum``; its sum is the
            divisor.
        distributed: If True, the total weight is all-reduced (SUM) through
            ``torch.distributed``; the flag is also forwarded to
            ``recursive_sum``.

    Returns:
        A ``(averaged_obj, total_weight)`` tuple.
    """
    weighted_total = recursive_sum(obj, weight, distributed)

    total_weight = weight.sum()
    if distributed:
        torch.distributed.all_reduce(total_weight, op=ReduceOp.SUM)

    # Dividing the weighted sums by the total weight gives the weighted mean.
    averaged = recursive_divide(weighted_total, total_weight)
    return averaged, total_weight
|