This commit is contained in:
游雁 2024-08-02 11:20:07 +08:00
parent 5dd4495406
commit 11586f7ebd

View File

@ -1,3 +1,4 @@
import logging
import time
import torch
from torch import Tensor, nn
@ -1321,6 +1322,7 @@ class MultiHeadAttentionFSMNRoPE(nn.Module):
left_padding = left_padding + kwargs.get("sanm_shfit", 0)
right_padding = kwargs.get("kernel_size", 15) - 1 - left_padding
self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
self.dropout = torch.nn.Dropout(kwargs.get("dropout_rate", 0.0))
def fsmn(self, inputs, mask):
b, t, d = inputs.size()
@ -1332,7 +1334,7 @@ class MultiHeadAttentionFSMNRoPE(nn.Module):
x = self.pad_fn(x)
x = self.fsmn_block(x)
x = x.transpose(1, 2) + inputs
# x = self.dropout(x)
x = self.dropout(x)
if mask is not None:
x = x * mask
return x
@ -1417,18 +1419,19 @@ class MultiHeadAttentionFSMNSdpaRoPE(nn.Module):
left_padding = left_padding + kwargs.get("sanm_shfit", 0)
right_padding = kwargs.get("kernel_size", 15) - 1 - left_padding
self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
self.dropout = torch.nn.Dropout(kwargs.get("dropout_rate", 0.0))
def fsmn(self, inputs, mask):
    """Apply the FSMN memory block with a residual connection and dropout.

    Args:
        inputs: 3-D tensor; the original inline comment labels the axes
            ``(b, t, d)``. A tensor of any other rank fails at the size
            unpacking below.
        mask: Optional tensor that is reshaped to ``(b, -1, 1)`` and
            multiplied into the tensor before and after the memory block.
            Pass ``None`` to skip masking entirely.
            NOTE(review): presumably a 0/1 padding mask over the time
            axis -- confirm against callers.

    Returns:
        Tensor with the same shape as ``inputs``.
    """
    # Unpack all three axes so that non-3-D input is rejected early; only
    # the batch size is needed afterwards.
    batch, _, _ = inputs.size()

    if mask is not None:
        mask = torch.reshape(mask, (batch, -1, 1))  # b, t, 1
        inputs = inputs * mask

    # pad_fn / fsmn_block operate on the last axis, so move the middle
    # axis there, run the block, and move it back afterwards.
    memory = inputs.transpose(1, 2)
    memory = self.pad_fn(memory)
    memory = self.fsmn_block(memory)

    # Residual connection, then dropout on the sum.
    out = memory.transpose(1, 2) + inputs
    out = self.dropout(out)

    # Re-apply the mask after the block and the residual add.
    if mask is not None:
        out = out * mask
    return out
@ -1615,6 +1618,7 @@ from funasr.train_utils.device_funcs import force_gatherable
from . import whisper_lib as whisper
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.utils.datadir_writer import DatadirWriter
import logging
@tables.register("model_classes", "SenseVoiceL")
@ -1627,6 +1631,13 @@ class SenseVoiceL(nn.Module):
encoder_class = tables.encoder_classes.get(encoder)
encoder = encoder_class(**encoder_conf)
if encoder_conf.get("freeze", False):
freeze_exclude_key = encoder_conf.get("freeze_exclude_key", "fsmn_block")
for name, param in encoder.named_parameters():
if not freeze_exclude_key in name:
logging.info(f"name: {name} is freeze")
param.requires_grad = False
dims = kwargs.get("dims", {})
dims = whisper.model.ModelDimensions(**dims)
model = whisper.model.Whisper(dims=dims)