mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
* sensevoice finetune * sensevoice finetune * sensevoice finetune * sensevoice finetune * sensevoice finetune * sensevoice finetune * sensevoice finetune * sensevoice finetune * sensevoice finetune * sensevoice finetune * bugfix * update with main (#1631) * update seaco finetune * v1.0.24 --------- Co-authored-by: 维石 <shixian.shi@alibaba-inc.com> * sensevoice * sensevoice * sensevoice * update with main (#1638) * update seaco finetune * v1.0.24 * update rwkv template --------- Co-authored-by: 维石 <shixian.shi@alibaba-inc.com> * sensevoice * sensevoice * sensevoice * sensevoice * sensevoice * sensevoice * sensevoice * sensevoice * sensevoice * sensevoice * sensevoice * sensevoice * sensevoice * sensevoice * sensevoice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * sense voice * whisper * whisper * update style * update style --------- Co-authored-by: 维石 <shixian.shi@alibaba-inc.com>
305 lines
10 KiB
Python
305 lines
10 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from typing import Optional, Tuple, List
|
|
import numpy as np
|
|
|
|
|
|
def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
    """Build a plain ``torch.nn.LayerNorm``.

    Args:
        normalized_shape: trailing shape to normalize over.
        eps: value added to the variance for numerical stability.
        elementwise_affine: whether the module gets learnable weight and bias.
        export: accepted for signature compatibility; it is not used here.

    Returns:
        A ``torch.nn.LayerNorm`` module.
    """
    norm = torch.nn.LayerNorm(
        normalized_shape,
        eps=eps,
        elementwise_affine=elementwise_affine,
    )
    return norm
|
|
|
|
|
|
class SamePad(nn.Module):
    """Drop ``remove`` trailing steps from the third axis of the input.

    The number of trimmed steps is ``kernel_size - 1`` when ``causal`` is set,
    otherwise 1 for an even ``kernel_size`` and 0 for an odd one. Presumably
    this follows a convolution padded by ``kernel_size // 2`` so the output
    length matches the input length — confirm at the call sites.
    """

    def __init__(self, kernel_size, causal=False):
        super().__init__()
        # int(True) == 1, so non-causal even kernels trim exactly one step.
        self.remove = kernel_size - 1 if causal else int(kernel_size % 2 == 0)

    def forward(self, x):
        n_trim = self.remove
        # Indexing as x[:, :, ...] requires at least three dimensions.
        return x[:, :, :-n_trim] if n_trim > 0 else x
|
|
|
|
|
|
class TransposeLast(nn.Module):
    """Swap the last two dimensions of the input.

    If ``deconstruct_idx`` is given, the input is first indexed with it
    (e.g. to pick one tensor out of a tuple) before transposing.
    """

    def __init__(self, deconstruct_idx=None):
        super().__init__()
        self.deconstruct_idx = deconstruct_idx

    def forward(self, x):
        idx = self.deconstruct_idx
        selected = x if idx is None else x[idx]
        return selected.transpose(-2, -1)
|
|
|
|
|
|
class Fp32LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` that normalizes in float32 and casts back to the input dtype.

    Input, weight and bias (when present) are upcast with ``.float()`` before
    ``F.layer_norm``; the result is converted back with ``type_as(input)``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):
        weight32 = None if self.weight is None else self.weight.float()
        bias32 = None if self.bias is None else self.bias.float()
        normalized = F.layer_norm(
            input.float(), self.normalized_shape, weight32, bias32, self.eps
        )
        return normalized.type_as(input)
|
|
|
|
|
|
class Fp32GroupNorm(nn.GroupNorm):
    """``nn.GroupNorm`` that normalizes in float32 and casts back to the input dtype.

    Input, weight and bias (when present) are upcast with ``.float()`` before
    ``F.group_norm``; the result is converted back with ``type_as(input)``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):
        weight32 = None if self.weight is None else self.weight.float()
        bias32 = None if self.bias is None else self.bias.float()
        normalized = F.group_norm(
            input.float(), self.num_groups, weight32, bias32, self.eps
        )
        return normalized.type_as(input)
|
|
|
|
|
|
class ConvFeatureExtractionModel(nn.Module):
    """Stack of 1-D convolution blocks applied to a single-channel input.

    Every block is ``Conv1d -> Dropout -> [normalization] -> GELU``; the
    normalization depends on ``mode``:

    * ``"default"``: ``Fp32GroupNorm`` with one group per channel, on the
      first block only.
    * ``"layer_norm"``: ``Fp32LayerNorm`` over the channel axis (applied
      between two ``TransposeLast``), on every block.

    Args:
        conv_layers: list of ``(out_channels, kernel_size, stride)`` triples,
            one per block; the first block takes 1 input channel.
        dropout: dropout probability used after each convolution.
        mode: ``"default"`` or ``"layer_norm"``.
        conv_bias: whether the convolutions have a bias term.
    """

    def __init__(
        self,
        conv_layers: List[Tuple[int, int, int]],
        dropout: float = 0.0,
        mode: str = "default",
        conv_bias: bool = False,
    ):
        super().__init__()

        assert mode in {"default", "layer_norm"}

        def block(
            n_in,
            n_out,
            k,
            stride,
            is_layer_norm=False,
            is_group_norm=False,
            conv_bias=False,
        ):
            """Build one conv block; normalization layers are sized by ``n_out``."""

            def make_conv():
                conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
                nn.init.kaiming_normal_(conv.weight)
                return conv

            assert not (
                is_layer_norm and is_group_norm
            ), "layer norm and group norm are exclusive"

            if is_layer_norm:
                # Conv1d output is (B, C, T); transpose so LayerNorm runs over C.
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    nn.Sequential(
                        TransposeLast(),
                        Fp32LayerNorm(n_out, elementwise_affine=True),
                        TransposeLast(),
                    ),
                    nn.GELU(),
                )
            if is_group_norm:
                # num_groups == num_channels: each channel is normalized on its own.
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    Fp32GroupNorm(n_out, n_out, affine=True),
                    nn.GELU(),
                )
            return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())

        in_d = 1
        self.conv_layers = nn.ModuleList()
        for i, cl in enumerate(conv_layers):
            assert len(cl) == 3, "invalid conv definition: " + str(cl)
            (dim, k, stride) = cl

            self.conv_layers.append(
                block(
                    in_d,
                    dim,
                    k,
                    stride,
                    is_layer_norm=mode == "layer_norm",
                    is_group_norm=mode == "default" and i == 0,
                    conv_bias=conv_bias,
                )
            )
            in_d = dim

    def forward(self, x):
        """Run the conv stack.

        Assumes ``x`` is ``(batch, time)`` — TODO confirm at call sites. A
        channel axis of size 1 is inserted before the first convolution.
        """
        # (B, T) -> (B, 1, T)
        x = x.unsqueeze(1)

        for conv in self.conv_layers:
            x = conv(x)

        return x
|
|
|
|
|
|
def compute_mask_indices(
    shape: Tuple[int, int],
    padding_mask: Optional[torch.Tensor],
    mask_prob: float,
    mask_length: int,
    mask_type: str = "static",
    mask_other: float = 0.0,
    min_masks: int = 0,
    no_overlap: bool = False,
    min_space: int = 0,
    require_same_masks: bool = True,
    mask_dropout: float = 0.0,
) -> np.ndarray:
    """
    Compute random mask spans for a given shape.

    Args:
        shape: the shape for which to compute masks. Should be of size 2, where
            the first element is the batch size and the second is the number
            of timesteps.
        padding_mask: optional padding mask of the same size as ``shape``. Only
            the number of padded (True) positions per row is used: masks are
            drawn from the first ``all_sz - num_padded`` positions, which
            assumes padding sits at the end of each row.
        mask_prob: probability for each token to be chosen as start of the span
            to be masked. This is multiplied by the number of timesteps divided
            by the mask span length to mask approximately this fraction of all
            elements; due to overlaps the actual number will be smaller
            (unless no_overlap is True).
        mask_length: span length (or distribution parameter, see mask_type).
        mask_type: how to compute mask lengths:
            static = fixed size
            uniform = sample from uniform distribution [mask_other, mask_length*2]
            normal = sample from normal distribution with mean mask_length and
                stdev mask_other; each span is at least 1 element
            poisson = sample from Poisson distribution with lambda = mask_length
        mask_other: secondary parameter for the uniform / normal mask types.
        min_masks: minimum number of masked spans.
        no_overlap: if True, use an alternative algorithm that repeatedly splits
            the free regions so that spans never overlap.
        min_space: only used if no_overlap is True; the number of elements to
            keep unmasked between spans.
        require_same_masks: if True, randomly drop masked positions until every
            row has the same number of masked positions.
        mask_dropout: randomly drop this fraction of masked positions in each row.

    Returns:
        Boolean array of shape ``(bsz, all_sz)``; True marks a masked position.
        A row whose span count rounds down to zero is left unmasked.

    Raises:
        Exception: if ``mask_type`` is not one of the values listed above.
    """

    bsz, all_sz = shape
    mask = np.full((bsz, all_sz), False)

    all_num_mask = int(
        # add a random number for probabilistic rounding
        mask_prob * all_sz / float(mask_length)
        + np.random.rand()
    )

    all_num_mask = max(min_masks, all_num_mask)

    mask_idcs = []
    for i in range(bsz):
        if padding_mask is not None:
            sz = all_sz - padding_mask[i].long().sum().item()
            num_mask = int(
                # add a random number for probabilistic rounding
                mask_prob * sz / float(mask_length)
                + np.random.rand()
            )
            num_mask = max(min_masks, num_mask)
        else:
            sz = all_sz
            num_mask = all_num_mask

        if mask_type == "static":
            lengths = np.full(num_mask, mask_length)
        elif mask_type == "uniform":
            lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
        elif mask_type == "normal":
            lengths = np.random.normal(mask_length, mask_other, size=num_mask)
            lengths = [max(1, int(round(x))) for x in lengths]
        elif mask_type == "poisson":
            lengths = np.random.poisson(mask_length, size=num_mask)
            lengths = [int(round(x)) for x in lengths]
        else:
            raise Exception("unknown mask selection " + mask_type)

        if len(lengths) == 0:
            # No spans for this row; `lengths[0]` / `min(lengths)` below would
            # fail on an empty sequence, so record an empty (integer) index set.
            mask_idcs.append(np.zeros(0, dtype=int))
            continue

        if sum(lengths) == 0:
            lengths[0] = min(mask_length, sz - 1)

        if no_overlap:
            mask_idc = []

            def arrange(s, e, length, keep_length):
                # Place one span inside [s, e) and return the leftover regions
                # that are still large enough to host another span.
                span_start = np.random.randint(s, e - length)
                mask_idc.extend(span_start + offset for offset in range(length))

                new_parts = []
                if span_start - s - min_space >= keep_length:
                    new_parts.append((s, span_start - min_space + 1))
                if e - span_start - length - min_space > keep_length:
                    new_parts.append((span_start + length + min_space, e))
                return new_parts

            parts = [(0, sz)]
            min_length = min(lengths)
            for length in sorted(lengths, reverse=True):
                # `int` replaces the `np.int` alias removed in NumPy 1.24.
                lens = np.fromiter(
                    (e - s if e - s >= length + min_space else 0 for s, e in parts),
                    int,
                )
                l_sum = np.sum(lens)
                if l_sum == 0:
                    break
                # Choose a region with probability proportional to its usable size.
                probs = lens / l_sum
                c = np.random.choice(len(parts), p=probs)
                s, e = parts.pop(c)
                parts.extend(arrange(s, e, length, min_length))
            # dtype=int keeps an empty result usable as an index array.
            mask_idc = np.asarray(mask_idc, dtype=int)
        else:
            shortest = min(lengths)
            if sz - shortest <= num_mask:
                shortest = sz - num_mask - 1

            starts = np.random.choice(sz - shortest, num_mask, replace=False)

            mask_idc = np.asarray(
                [
                    start + offset
                    for start, span_len in zip(starts, lengths)
                    for offset in range(span_len)
                ],
                dtype=int,
            )

        mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))

    # `default=0` handles an empty batch (bsz == 0).
    target_len = min((len(m) for m in mask_idcs), default=0)
    for i, mask_idc in enumerate(mask_idcs):
        if len(mask_idc) > target_len and require_same_masks:
            mask_idc = np.random.choice(mask_idc, target_len, replace=False)
        if mask_dropout > 0:
            num_holes = np.rint(len(mask_idc) * mask_dropout).astype(int)
            mask_idc = np.random.choice(mask_idc, len(mask_idc) - num_holes, replace=False)

        mask[i, mask_idc] = True

    return mask
|
|
|
|
|
|
class GradMultiply(torch.autograd.Function):
    """Identity in the forward pass; multiplies the gradient by ``scale`` in backward."""

    @staticmethod
    def forward(ctx, x, scale):
        # Remember the factor for the backward pass.
        ctx.scale = scale
        return x.new(x)

    @staticmethod
    def backward(ctx, grad):
        scaled_grad = grad * ctx.scale
        # No gradient is returned for the ``scale`` argument.
        return scaled_grad, None
|
|
|
|
|
|
def is_xla_tensor(tensor):
    """Return True only for a torch tensor whose device type is ``"xla"``."""
    if not torch.is_tensor(tensor):
        return False
    return tensor.device.type == "xla"
|
|
|
|
|
|
def index_put(tensor, indices, value):
    """Write ``value`` at ``indices`` of ``tensor`` and return the result.

    Off XLA this is the in-place assignment ``tensor[indices] = value`` and
    the same tensor object is returned. On XLA the write is expressed as an
    arithmetic blend that returns a new tensor instead of mutating the input.
    NOTE(review): the XLA branch uses ``~indices``, so it assumes ``indices``
    is a boolean mask — confirm at call sites.
    """
    if not is_xla_tensor(tensor):
        tensor[indices] = value
        return tensor

    # Append trailing singleton axes until the mask has the tensor's rank.
    while indices.dim() < tensor.dim():
        indices = indices.unsqueeze(-1)
    # Broadcast the mask to the full tensor shape if its last axis is smaller.
    if indices.size(-1) < tensor.size(-1):
        indices = indices.expand_as(tensor)
    # Keep unmasked entries from `tensor`, masked entries from `value`.
    return torch.mul(tensor, ~indices) + torch.mul(value, indices)
|