mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update
This commit is contained in:
parent
5dd4495406
commit
11586f7ebd
@ -1,3 +1,4 @@
|
||||
import logging
|
||||
import time
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
@ -1321,6 +1322,7 @@ class MultiHeadAttentionFSMNRoPE(nn.Module):
|
||||
left_padding = left_padding + kwargs.get("sanm_shfit", 0)
|
||||
right_padding = kwargs.get("kernel_size", 15) - 1 - left_padding
|
||||
self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
|
||||
self.dropout = torch.nn.Dropout(kwargs.get("dropout_rate", 0.0))
|
||||
|
||||
def fsmn(self, inputs, mask):
|
||||
b, t, d = inputs.size()
|
||||
@ -1332,7 +1334,7 @@ class MultiHeadAttentionFSMNRoPE(nn.Module):
|
||||
x = self.pad_fn(x)
|
||||
x = self.fsmn_block(x)
|
||||
x = x.transpose(1, 2) + inputs
|
||||
# x = self.dropout(x)
|
||||
x = self.dropout(x)
|
||||
if mask is not None:
|
||||
x = x * mask
|
||||
return x
|
||||
@ -1417,18 +1419,19 @@ class MultiHeadAttentionFSMNSdpaRoPE(nn.Module):
|
||||
left_padding = left_padding + kwargs.get("sanm_shfit", 0)
|
||||
right_padding = kwargs.get("kernel_size", 15) - 1 - left_padding
|
||||
self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
|
||||
self.dropout = torch.nn.Dropout(kwargs.get("dropout_rate", 0.0))
|
||||
|
||||
def fsmn(self, inputs, mask):
|
||||
b, t, d = inputs.size()
|
||||
b, t, d = inputs.size() # b, t, d
|
||||
if mask is not None:
|
||||
mask = torch.reshape(mask, (b, -1, 1))
|
||||
mask = torch.reshape(mask, (b, -1, 1)) # b, t, 1
|
||||
inputs = inputs * mask
|
||||
|
||||
x = inputs.transpose(1, 2)
|
||||
x = self.pad_fn(x)
|
||||
x = self.fsmn_block(x)
|
||||
x = x.transpose(1, 2) + inputs
|
||||
# x = self.dropout(x)
|
||||
x = self.dropout(x)
|
||||
if mask is not None:
|
||||
x = x * mask
|
||||
return x
|
||||
@ -1615,6 +1618,7 @@ from funasr.train_utils.device_funcs import force_gatherable
|
||||
from . import whisper_lib as whisper
|
||||
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
|
||||
from funasr.utils.datadir_writer import DatadirWriter
|
||||
import logging
|
||||
|
||||
|
||||
@tables.register("model_classes", "SenseVoiceL")
|
||||
@ -1627,6 +1631,13 @@ class SenseVoiceL(nn.Module):
|
||||
encoder_class = tables.encoder_classes.get(encoder)
|
||||
encoder = encoder_class(**encoder_conf)
|
||||
|
||||
if encoder_conf.get("freeze", False):
|
||||
freeze_exclude_key = encoder_conf.get("freeze_exclude_key", "fsmn_block")
|
||||
for name, param in encoder.named_parameters():
|
||||
if not freeze_exclude_key in name:
|
||||
logging.info(f"name: {name} is freeze")
|
||||
param.requires_grad = False
|
||||
|
||||
dims = kwargs.get("dims", {})
|
||||
dims = whisper.model.ModelDimensions(**dims)
|
||||
model = whisper.model.Whisper(dims=dims)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user