From 11586f7ebdd353659059a6fbebbd3e2ecbba7fcc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Fri, 2 Aug 2024 11:20:07 +0800 Subject: [PATCH] update --- funasr/models/sense_voice/model_small.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/funasr/models/sense_voice/model_small.py b/funasr/models/sense_voice/model_small.py index 2a40f4ccf..1f572444e 100644 --- a/funasr/models/sense_voice/model_small.py +++ b/funasr/models/sense_voice/model_small.py @@ -1,3 +1,4 @@ +import logging import time import torch from torch import Tensor, nn @@ -1321,6 +1322,7 @@ class MultiHeadAttentionFSMNRoPE(nn.Module): left_padding = left_padding + kwargs.get("sanm_shfit", 0) right_padding = kwargs.get("kernel_size", 15) - 1 - left_padding self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0) + self.dropout = torch.nn.Dropout(kwargs.get("dropout_rate", 0.0)) def fsmn(self, inputs, mask): b, t, d = inputs.size() @@ -1332,7 +1334,7 @@ class MultiHeadAttentionFSMNRoPE(nn.Module): x = self.pad_fn(x) x = self.fsmn_block(x) x = x.transpose(1, 2) + inputs - # x = self.dropout(x) + x = self.dropout(x) if mask is not None: x = x * mask return x @@ -1417,18 +1419,19 @@ class MultiHeadAttentionFSMNSdpaRoPE(nn.Module): left_padding = left_padding + kwargs.get("sanm_shfit", 0) right_padding = kwargs.get("kernel_size", 15) - 1 - left_padding self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0) + self.dropout = torch.nn.Dropout(kwargs.get("dropout_rate", 0.0)) def fsmn(self, inputs, mask): - b, t, d = inputs.size() + b, t, d = inputs.size() # b, t, d if mask is not None: - mask = torch.reshape(mask, (b, -1, 1)) + mask = torch.reshape(mask, (b, -1, 1)) # b, t, 1 inputs = inputs * mask x = inputs.transpose(1, 2) x = self.pad_fn(x) x = self.fsmn_block(x) x = x.transpose(1, 2) + inputs - # x = self.dropout(x) + x = self.dropout(x) if mask is not None: x = x * mask return x @@ -1615,6 +1618,7 @@ from funasr.train_utils.device_funcs import force_gatherable from . import whisper_lib as whisper from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank from funasr.utils.datadir_writer import DatadirWriter +import logging @tables.register("model_classes", "SenseVoiceL") @@ -1627,6 +1631,13 @@ class SenseVoiceL(nn.Module): encoder_class = tables.encoder_classes.get(encoder) encoder = encoder_class(**encoder_conf) + if encoder_conf.get("freeze", False): + freeze_exclude_key = encoder_conf.get("freeze_exclude_key", "fsmn_block") + for name, param in encoder.named_parameters(): + if not freeze_exclude_key in name: + logging.info(f"name: {name} is freeze") + param.requires_grad = False + dims = kwargs.get("dims", {}) dims = whisper.model.ModelDimensions(**dims) model = whisper.model.Whisper(dims=dims)