mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Dev gzf exp (#1624)
* sensevoice finetune * sensevoice finetune * sensevoice finetune * sensevoice finetune * sensevoice finetune
This commit is contained in:
parent
149063ced4
commit
eaf9dda9e4
@ -5,13 +5,13 @@
|
||||
|
||||
from funasr import AutoModel
|
||||
|
||||
model = AutoModel(model="/Users/zhifu/Downloads/modelscope_models/SenseVoice",
|
||||
model = AutoModel(model="/Users/zhifu/Downloads/modelscope_models/SenseVoiceModelscope",
|
||||
vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
|
||||
vad_kwargs={"max_single_segment_time": 30000},
|
||||
)
|
||||
|
||||
|
||||
input_wav = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/SenseVoice/aed_ser/asr_bgm.wav"
|
||||
input_wav = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav"
|
||||
|
||||
DecodingOptions = {
|
||||
"task": ("ASR", "AED", "SER"),
|
||||
|
||||
69
examples/industrial_data_pretraining/sense_voice/finetune.sh
Normal file
69
examples/industrial_data_pretraining/sense_voice/finetune.sh
Normal file
@ -0,0 +1,69 @@
|
||||
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
||||
# MIT License (https://opensource.org/licenses/MIT)
|
||||
|
||||
|
||||
# which gpu to train or finetune
|
||||
export CUDA_VISIBLE_DEVICES="0"
|
||||
gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
|
||||
|
||||
# model_name from model_hub, or model_dir in local path
|
||||
|
||||
## option 1, download model automatically
|
||||
model_name_or_model_dir="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
|
||||
model_name_or_model_dir="/Users/zhifu/Downloads/modelscope_models/SenseVoiceModelscope"
|
||||
## option 2, download model by git
|
||||
#local_path_root=${workspace}/modelscope_models
|
||||
#mkdir -p ${local_path_root}/${model_name_or_model_dir}
|
||||
#git clone https://www.modelscope.cn/${model_name_or_model_dir}.git ${local_path_root}/${model_name_or_model_dir}
|
||||
#model_name_or_model_dir=${local_path_root}/${model_name_or_model_dir}
|
||||
|
||||
|
||||
# data dir, which contains: train.json, val.json
|
||||
data_dir="../../../data/list"
|
||||
|
||||
train_data="${data_dir}/train.jsonl"
|
||||
val_data="${data_dir}/val.jsonl"
|
||||
|
||||
# generate train.jsonl and val.jsonl from wav.scp and text.txt
|
||||
scp2jsonl \
|
||||
++scp_file_list='["../../../data/list/train_wav.scp", "../../../data/list/train_text.txt"]' \
|
||||
++data_type_list='["source", "target"]' \
|
||||
++jsonl_file_out="${train_data}"
|
||||
|
||||
scp2jsonl \
|
||||
++scp_file_list='["../../../data/list/val_wav.scp", "../../../data/list/val_text.txt"]' \
|
||||
++data_type_list='["source", "target"]' \
|
||||
++jsonl_file_out="${val_data}"
|
||||
|
||||
|
||||
# exp output dir
|
||||
output_dir="./outputs"
|
||||
log_file="${output_dir}/log.txt"
|
||||
|
||||
|
||||
mkdir -p ${output_dir}
|
||||
echo "log_file: ${log_file}"
|
||||
|
||||
#torchrun \
|
||||
#--nnodes 1 \
|
||||
#--node_rank 0 \
|
||||
#--nproc_per_node ${gpu_num} \
|
||||
python \
|
||||
../../../funasr/bin/train.py \
|
||||
++model="${model_name_or_model_dir}" \
|
||||
++train_data_set_list="${train_data}" \
|
||||
++valid_data_set_list="${val_data}" \
|
||||
++dataset_conf.batch_size=500 \
|
||||
++dataset_conf.batch_type="token" \
|
||||
++dataset_conf.num_workers=0 \
|
||||
++train_conf.max_epoch=50 \
|
||||
++train_conf.log_interval=1 \
|
||||
++train_conf.resume=false \
|
||||
++train_conf.validate_interval=2000 \
|
||||
++train_conf.save_checkpoint_interval=2000 \
|
||||
++train_conf.keep_nbest_models=20 \
|
||||
++train_conf.avg_nbest_model=10 \
|
||||
++optim_conf.lr=0.0002 \
|
||||
++debug=true \
|
||||
++device="cpu" \
|
||||
++output_dir="${output_dir}" #&> ${log_file}
|
||||
@ -175,6 +175,8 @@ class AutoModel:
|
||||
kwargs["token_list"] = tokenizer.token_list if hasattr(tokenizer, "token_list") else None
|
||||
kwargs["token_list"] = tokenizer.get_vocab() if hasattr(tokenizer, "get_vocab") else kwargs["token_list"]
|
||||
vocab_size = len(kwargs["token_list"]) if kwargs["token_list"] is not None else -1
|
||||
if vocab_size == -1 and hasattr(tokenizer, "get_vocab_size"):
|
||||
vocab_size = tokenizer.get_vocab_size()
|
||||
else:
|
||||
vocab_size = -1
|
||||
kwargs["tokenizer"] = tokenizer
|
||||
|
||||
@ -102,7 +102,7 @@ def main(**kwargs):
|
||||
if use_ddp:
|
||||
model = model.cuda(local_rank)
|
||||
model = DDP(model, device_ids=[local_rank],
|
||||
find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False))
|
||||
find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", True))
|
||||
elif use_fsdp:
|
||||
# model = FSDP(model).cuda(local_rank)
|
||||
|
||||
|
||||
@ -92,7 +92,7 @@ class IndexDSJsonlRankFull(torch.utils.data.Dataset):
|
||||
for line in fin:
|
||||
data = json.loads(line.strip())
|
||||
if "text" in data: # for sft
|
||||
self.contents.append(data['text'])
|
||||
contents.append(data['text'])
|
||||
if "source" in data: # for speech lab pretrain
|
||||
prompt = data.get("prompt", "<ASR>")
|
||||
source = data["source"]
|
||||
@ -101,13 +101,20 @@ class IndexDSJsonlRankFull(torch.utils.data.Dataset):
|
||||
target_len = data.get("target_len", 0)
|
||||
if "aishell" in source:
|
||||
target = target.replace(" ", "")
|
||||
contents.append({"source": source,
|
||||
"prompt": prompt,
|
||||
"target": target,
|
||||
"source_len": source_len,
|
||||
"target_len": target_len,
|
||||
}
|
||||
)
|
||||
|
||||
contents_i = {"source": source,
|
||||
"prompt": prompt,
|
||||
"target": target,
|
||||
"source_len": source_len,
|
||||
"target_len": target_len,
|
||||
}
|
||||
text_language = data.get("text_language", None)
|
||||
if text_language is not None:
|
||||
contents_i["text_language"] = text_language
|
||||
audio_language = data.get("audio_language", None)
|
||||
if audio_language is not None:
|
||||
contents_i["audio_language"] = audio_language
|
||||
contents.append(contents_i)
|
||||
|
||||
self.contents = contents
|
||||
|
||||
|
||||
0
funasr/datasets/sense_voice_datasets/__init__.py
Normal file
0
funasr/datasets/sense_voice_datasets/__init__.py
Normal file
118
funasr/datasets/sense_voice_datasets/datasets.py
Normal file
118
funasr/datasets/sense_voice_datasets/datasets.py
Normal file
@ -0,0 +1,118 @@
|
||||
import torch
|
||||
import random
|
||||
|
||||
from funasr.register import tables
|
||||
from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video
|
||||
|
||||
|
||||
@tables.register("dataset_classes", "SenseVoiceDataset")
|
||||
class SenseVoiceDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
SenseVoiceDataset
|
||||
"""
|
||||
def __init__(self,
|
||||
path,
|
||||
index_ds: str = None,
|
||||
frontend=None,
|
||||
tokenizer=None,
|
||||
int_pad_value: int = -1,
|
||||
float_pad_value: float = 0.0,
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
index_ds_class = tables.index_ds_classes.get(index_ds)
|
||||
self.index_ds = index_ds_class(path, **kwargs)
|
||||
preprocessor_speech = kwargs.get("preprocessor_speech", None)
|
||||
if preprocessor_speech:
|
||||
preprocessor_speech_class = tables.preprocessor_classes.get(preprocessor_speech)
|
||||
preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf"))
|
||||
self.preprocessor_speech = preprocessor_speech
|
||||
preprocessor_text = kwargs.get("preprocessor_text", None)
|
||||
if preprocessor_text:
|
||||
preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text)
|
||||
preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf"))
|
||||
self.preprocessor_text = preprocessor_text
|
||||
|
||||
self.frontend = frontend
|
||||
self.fs = 16000 if frontend is None else frontend.fs
|
||||
self.data_type = "sound"
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
self.int_pad_value = int_pad_value
|
||||
self.float_pad_value = float_pad_value
|
||||
self.sos = kwargs.get("sos", "<|startoftranscript|>")
|
||||
self.eos = kwargs.get("eos", "<|endoftext|>")
|
||||
|
||||
def get_source_len(self, index):
|
||||
item = self.index_ds[index]
|
||||
return self.index_ds.get_source_len(item)
|
||||
|
||||
def get_target_len(self, index):
|
||||
item = self.index_ds[index]
|
||||
return self.index_ds.get_target_len(item)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.index_ds)
|
||||
|
||||
def __getitem__(self, index):
|
||||
item = self.index_ds[index]
|
||||
# import pdb;
|
||||
# pdb.set_trace()
|
||||
source = item["source"]
|
||||
data_src = load_audio_text_image_video(source, fs=self.fs)
|
||||
if self.preprocessor_speech:
|
||||
data_src = self.preprocessor_speech(data_src, fs=self.fs)
|
||||
speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend, is_final=True) # speech: [b, T, d]
|
||||
speech = speech.permute(0, 2, 1)
|
||||
target = item["target"]
|
||||
if self.preprocessor_text:
|
||||
target = self.preprocessor_text(target)
|
||||
|
||||
task = item.get("prompt", "<|ASR|>")
|
||||
text_language = item.get("text_language", "<|zh|>")
|
||||
|
||||
prompt = f"{self.sos}{task}{text_language}"
|
||||
prompt_ids = self.tokenizer.encode(prompt, allowed_special="all")
|
||||
prompt_ids_len = len(prompt_ids) - 1 # [sos, task]
|
||||
|
||||
target_ids = self.tokenizer.encode(target, allowed_special="all")
|
||||
target_ids_len = len(target_ids) + 1 # [lid, text]
|
||||
|
||||
eos = self.tokenizer.encode(self.eos, allowed_special="all") # [eos]
|
||||
|
||||
ids = prompt_ids + target_ids + eos
|
||||
ids_lengths = len(ids)
|
||||
|
||||
text = torch.tensor(ids, dtype=torch.int64)
|
||||
text_lengths = torch.tensor([ids_lengths], dtype=torch.int32)
|
||||
|
||||
target_mask = [0] * (prompt_ids_len) + [1] * (target_ids_len) + [1] # [sos, task, lid, text, eos]: [0, 0, 1, 1, 1]
|
||||
target_mask = torch.tensor(target_mask, dtype=torch.float32)
|
||||
|
||||
return {"speech": speech[0, :, :],
|
||||
"speech_lengths": speech_lengths,
|
||||
"text": text,
|
||||
"text_lengths": text_lengths,
|
||||
"target_mask": target_mask,
|
||||
}
|
||||
|
||||
|
||||
def collator(self, samples: list=None):
|
||||
outputs = {}
|
||||
for sample in samples:
|
||||
for key in sample.keys():
|
||||
if key not in outputs:
|
||||
outputs[key] = []
|
||||
outputs[key].append(sample[key])
|
||||
|
||||
for key, data_list in outputs.items():
|
||||
if isinstance(data_list[0], torch.Tensor):
|
||||
if data_list[0].dtype == torch.int64 or data_list[0].dtype == torch.int32:
|
||||
|
||||
pad_value = self.int_pad_value
|
||||
else:
|
||||
pad_value = self.float_pad_value
|
||||
|
||||
outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value)
|
||||
return outputs
|
||||
|
||||
|
||||
@ -50,8 +50,8 @@ class LabelSmoothingLoss(nn.Module):
|
||||
"""
|
||||
assert x.size(2) == self.size
|
||||
batch_size = x.size(0)
|
||||
x = x.view(-1, self.size)
|
||||
target = target.view(-1)
|
||||
x = x.contiguous().view(-1, self.size)
|
||||
target = target.contiguous().view(-1)
|
||||
with torch.no_grad():
|
||||
true_dist = x.clone()
|
||||
true_dist.fill_(self.smoothing / (self.size - 1))
|
||||
|
||||
66
funasr/models/sense_voice/decoder.py
Normal file
66
funasr/models/sense_voice/decoder.py
Normal file
@ -0,0 +1,66 @@
|
||||
import copy
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from funasr.models.transformer.utils.nets_utils import make_pad_mask
|
||||
|
||||
def sense_voice_decode_forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
xa: torch.Tensor,
|
||||
kv_cache: Optional[dict] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Forward decoder.
|
||||
|
||||
Args:
|
||||
hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
|
||||
hlens: (batch)
|
||||
ys_in_pad:
|
||||
input token ids, int64 (batch, maxlen_out)
|
||||
if input_layer == "embed"
|
||||
input tensor (batch, maxlen_out, #mels) in the other cases
|
||||
ys_in_lens: (batch)
|
||||
Returns:
|
||||
(tuple): tuple containing:
|
||||
|
||||
x: decoded token score before softmax (batch, maxlen_out, token)
|
||||
if use_output_layer is True,
|
||||
olens: (batch, )
|
||||
"""
|
||||
# import pdb;pdb.set_trace()
|
||||
use_padmask = self.use_padmask
|
||||
hlens = kwargs.get("hlens", None)
|
||||
|
||||
ys_in_lens = kwargs.get("ys_in_lens", None)
|
||||
|
||||
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
|
||||
tgt, memory = x, xa
|
||||
tgt[tgt==-1] = 0
|
||||
tgt = (
|
||||
self.token_embedding(tgt)
|
||||
+ self.positional_embedding[offset : offset + tgt.size(1)]
|
||||
)
|
||||
# tgt = self.dropout(tgt)
|
||||
|
||||
x = tgt.to(memory.dtype)
|
||||
|
||||
if use_padmask and hlens is not None:
|
||||
memory_mask = (~make_pad_mask(hlens)[:, None, :]).to(memory.device)
|
||||
else:
|
||||
memory_mask = None
|
||||
|
||||
for layer, block in enumerate(self.blocks):
|
||||
x = block(x, memory, mask=self.mask, memory_mask=memory_mask, is_pad_mask=False, is_pad_memory_mask=True)
|
||||
|
||||
|
||||
x = self.ln(x)
|
||||
x = (
|
||||
x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
|
||||
).float()
|
||||
|
||||
|
||||
return x
|
||||
|
||||
67
funasr/models/sense_voice/encoder.py
Normal file
67
funasr/models/sense_voice/encoder.py
Normal file
@ -0,0 +1,67 @@
|
||||
import copy
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from funasr.models.transformer.utils.nets_utils import make_pad_mask
|
||||
|
||||
|
||||
def sense_voice_encode_forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
ilens: torch.Tensor = None,
|
||||
**kwargs,
|
||||
):
|
||||
use_padmask = self.use_padmask
|
||||
x = F.gelu(self.conv1(x))
|
||||
x = F.gelu(self.conv2(x))
|
||||
x = x.permute(0, 2, 1)
|
||||
|
||||
n_frames = x.size(1)
|
||||
max_pos = self.positional_embedding.size(0)
|
||||
max_pos = n_frames if n_frames < max_pos else max_pos
|
||||
x = (x[:, :max_pos, :] + self.positional_embedding[None, :max_pos, :]).to(x.dtype)
|
||||
|
||||
|
||||
if ilens is not None:
|
||||
if self.downsample_rate == 4:
|
||||
olens = (
|
||||
1
|
||||
+ (
|
||||
ilens
|
||||
- self.conv1.kernel_size[0]
|
||||
+ 2 * self.conv1.padding[0]
|
||||
)
|
||||
// self.conv1.stride[0]
|
||||
)
|
||||
else:
|
||||
olens = ilens
|
||||
olens = (
|
||||
1
|
||||
+ (
|
||||
olens
|
||||
- self.conv2.kernel_size[0]
|
||||
+ 2 * self.conv2.padding[0]
|
||||
)
|
||||
// self.conv2.stride[0]
|
||||
)
|
||||
olens = torch.clamp(olens, max=max_pos)
|
||||
else:
|
||||
olens = None
|
||||
|
||||
if use_padmask and olens is not None:
|
||||
padding_mask = (~make_pad_mask(olens)[:, None, :]).to(x.device)
|
||||
else:
|
||||
padding_mask = None
|
||||
|
||||
for layer, block in enumerate(self.blocks):
|
||||
x = block(x, mask=padding_mask, is_pad_mask=True)
|
||||
|
||||
|
||||
x = self.ln_post(x)
|
||||
|
||||
if ilens is None:
|
||||
return x
|
||||
else:
|
||||
return x, olens
|
||||
@ -1,35 +1,158 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict
|
||||
from typing import Iterable, Optional
|
||||
import types
|
||||
import time
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
from torch import nn
|
||||
from torch.cuda.amp import autocast
|
||||
from funasr.metrics.compute_acc import compute_accuracy
|
||||
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
|
||||
from funasr.train_utils.device_funcs import force_gatherable
|
||||
from . import whisper_lib as whisper
|
||||
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
|
||||
|
||||
from funasr.register import tables
|
||||
|
||||
|
||||
|
||||
|
||||
@tables.register("model_classes", "SenseVoice")
|
||||
class SenseVoice(nn.Module):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__()
|
||||
hub = kwargs.get("hub", "funasr")
|
||||
|
||||
|
||||
dims = kwargs.get("dims", {})
|
||||
dims = whisper.model.ModelDimensions(**dims)
|
||||
model = whisper.model.Whisper(dims=dims)
|
||||
|
||||
# encoder
|
||||
model.encoder.downsample_rate = kwargs.get("downsample_rate", 4)
|
||||
model.encoder.use_padmask = kwargs.get("use_padmask", True)
|
||||
from .encoder import sense_voice_encode_forward
|
||||
model.encoder.forward = types.MethodType(sense_voice_encode_forward, model.encoder)
|
||||
|
||||
# decoder
|
||||
model.decoder.use_padmask = kwargs.get("use_padmask", True)
|
||||
from .decoder import sense_voice_decode_forward
|
||||
model.decoder.forward = types.MethodType(sense_voice_decode_forward, model.decoder)
|
||||
|
||||
self.model = model
|
||||
|
||||
self.encoder_output_size = self.model.dims.n_audio_state
|
||||
|
||||
def forward(self, ):
|
||||
pass
|
||||
self.activation_checkpoint = kwargs.get("activation_checkpoint", False)
|
||||
self.ignore_id = kwargs.get("ignore_id", -1)
|
||||
self.vocab_size = kwargs.get("vocab_size", -1)
|
||||
self.length_normalized_loss = kwargs.get("length_normalized_loss", True)
|
||||
self.criterion_att = LabelSmoothingLoss(
|
||||
size=self.vocab_size,
|
||||
padding_idx=self.ignore_id,
|
||||
smoothing=kwargs.get("lsm_weight", 0.0),
|
||||
normalize_length=self.length_normalized_loss,
|
||||
)
|
||||
|
||||
specaug = kwargs.get("specaug", None)
|
||||
if specaug is not None:
|
||||
specaug_class = tables.specaug_classes.get(specaug)
|
||||
specaug = specaug_class(**kwargs.get("specaug_conf", {}))
|
||||
self.specaug = specaug
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
speech: torch.Tensor,
|
||||
speech_lengths: torch.Tensor,
|
||||
text: torch.Tensor,
|
||||
text_lengths: torch.Tensor,
|
||||
**kwargs,
|
||||
):
|
||||
target_mask = kwargs.get("target_mask", None)
|
||||
|
||||
# import pdb;
|
||||
# pdb.set_trace()
|
||||
if len(text_lengths.size()) > 1:
|
||||
text_lengths = text_lengths[:, 0]
|
||||
if len(speech_lengths.size()) > 1:
|
||||
speech_lengths = speech_lengths[:, 0]
|
||||
|
||||
batch_size = speech.shape[0]
|
||||
|
||||
if self.activation_checkpoint:
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
encoder_out, encoder_out_lens = checkpoint(self.encode, speech, speech_lengths, use_reentrant=False)
|
||||
else:
|
||||
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
|
||||
|
||||
loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
|
||||
encoder_out, encoder_out_lens, text, text_lengths, target_mask=target_mask
|
||||
)
|
||||
loss = loss_att
|
||||
stats = {}
|
||||
stats["acc"] = acc_att
|
||||
stats["loss"] = torch.clone(loss.detach())
|
||||
stats["batch_size"] = batch_size
|
||||
|
||||
# force_gatherable: to-device and to-tensor if scalar for DataParallel
|
||||
if self.length_normalized_loss:
|
||||
batch_size = int((text_lengths + 1).sum())
|
||||
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
|
||||
return loss, stats, weight
|
||||
|
||||
def encode(
|
||||
self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
|
||||
) :
|
||||
"""Encoder. Note that this method is used by asr_inference.py
|
||||
Args:
|
||||
speech: (Batch, Length, ...)
|
||||
speech_lengths: (Batch, )
|
||||
ind: int
|
||||
"""
|
||||
with autocast(False):
|
||||
|
||||
# Data augmentation
|
||||
if self.specaug is not None and self.training:
|
||||
speech, speech_lengths = self.specaug(speech, speech_lengths)
|
||||
|
||||
|
||||
# Forward encoder
|
||||
encoder_out, encoder_out_lens = self.model.encoder(speech.permute(0, 2, 1), speech_lengths)
|
||||
|
||||
return encoder_out, encoder_out_lens
|
||||
|
||||
|
||||
def _calc_att_loss(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
ys_pad: torch.Tensor,
|
||||
ys_pad_lens: torch.Tensor,
|
||||
**kwargs,
|
||||
):
|
||||
target_mask = kwargs.get("target_mask", None)
|
||||
stats = {}
|
||||
|
||||
# 1. Forward decoder
|
||||
decoder_out = self.model.decoder(
|
||||
x=ys_pad, xa=encoder_out, hlens=encoder_out_lens, ys_in_lens=ys_pad_lens
|
||||
)
|
||||
|
||||
# 2. Compute attention loss
|
||||
mask = torch.ones_like(ys_pad) * (-1)
|
||||
ys_pad_mask = (ys_pad * target_mask + mask * (1-target_mask)).to(torch.int64)
|
||||
ys_pad_mask[ys_pad_mask == 0] = -1
|
||||
loss_att = self.criterion_att(decoder_out[:, :-1, :], ys_pad_mask[:, 1:])
|
||||
|
||||
with torch.no_grad():
|
||||
preds = torch.argmax(decoder_out, -1)
|
||||
acc_att = compute_accuracy(preds[:, :-1], ys_pad_mask[:, 1:], ignore_label=self.ignore_id)
|
||||
|
||||
return loss_att, acc_att, None, None
|
||||
|
||||
|
||||
def inference(self,
|
||||
data_in,
|
||||
data_lengths=None,
|
||||
|
||||
@ -74,7 +74,10 @@ class MultiHeadAttention(nn.Module):
|
||||
xa: Optional[Tensor] = None,
|
||||
mask: Optional[Tensor] = None,
|
||||
kv_cache: Optional[dict] = None,
|
||||
**kwargs,
|
||||
):
|
||||
is_pad_mask = kwargs.get("is_pad_mask", False)
|
||||
|
||||
q = self.query(x)
|
||||
|
||||
if kv_cache is None or xa is None or self.key not in kv_cache:
|
||||
@ -87,12 +90,13 @@ class MultiHeadAttention(nn.Module):
|
||||
k = kv_cache[self.key]
|
||||
v = kv_cache[self.value]
|
||||
|
||||
wv, qk = self.qkv_attention(q, k, v, mask)
|
||||
wv, qk = self.qkv_attention(q, k, v, mask, is_pad_mask=is_pad_mask)
|
||||
return self.out(wv), qk
|
||||
|
||||
def qkv_attention(
|
||||
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
|
||||
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, **kwargs,
|
||||
):
|
||||
is_pad_mask = kwargs.get("is_pad_mask", False)
|
||||
n_batch, n_ctx, n_state = q.shape
|
||||
scale = (n_state // self.n_head) ** -0.25
|
||||
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
|
||||
@ -101,10 +105,20 @@ class MultiHeadAttention(nn.Module):
|
||||
|
||||
qk = q @ k
|
||||
if mask is not None:
|
||||
qk = qk + mask[:n_ctx, :n_ctx]
|
||||
if not is_pad_mask:
|
||||
qk = qk + mask[:n_ctx, :n_ctx]
|
||||
else:
|
||||
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
|
||||
min_value = float(
|
||||
np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min
|
||||
)
|
||||
qk = qk.masked_fill(mask, min_value)
|
||||
|
||||
qk = qk.float()
|
||||
|
||||
w = F.softmax(qk, dim=-1).to(q.dtype)
|
||||
if mask is not None and is_pad_mask:
|
||||
w = w.masked_fill(mask, 0.0)
|
||||
return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
|
||||
|
||||
|
||||
@ -132,10 +146,13 @@ class ResidualAttentionBlock(nn.Module):
|
||||
xa: Optional[Tensor] = None,
|
||||
mask: Optional[Tensor] = None,
|
||||
kv_cache: Optional[dict] = None,
|
||||
**kwargs,
|
||||
):
|
||||
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
|
||||
is_pad_mask = kwargs.get("is_pad_mask", False)
|
||||
is_pad_memory_mask = kwargs.get("is_pad_memory_mask", False)
|
||||
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0]
|
||||
if self.cross_attn:
|
||||
x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
|
||||
x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache, is_pad_mask=is_pad_memory_mask)[0]
|
||||
x = x + self.mlp(self.mlp_ln(x))
|
||||
return x
|
||||
|
||||
|
||||
@ -22,3 +22,25 @@ def WhisperTokenizer(**kwargs):
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
@tables.register("tokenizer_classes", "SenseVoiceTokenizer")
|
||||
def SenseVoiceTokenizer(**kwargs):
|
||||
try:
|
||||
from funasr.models.sense_voice.whisper_lib.tokenizer import get_tokenizer
|
||||
except:
|
||||
print("Notice: If you want to use whisper, please `pip install -U openai-whisper`")
|
||||
|
||||
language = kwargs.get("language", None)
|
||||
task = kwargs.get("task", None)
|
||||
is_multilingual = kwargs.get("is_multilingual", True)
|
||||
num_languages = kwargs.get("num_languages", 8749)
|
||||
vocab_path = kwargs.get("vocab_path", None)
|
||||
tokenizer = get_tokenizer(
|
||||
multilingual=is_multilingual,
|
||||
num_languages=num_languages,
|
||||
language=language,
|
||||
task=task,
|
||||
vocab_path=vocab_path,
|
||||
)
|
||||
|
||||
return tokenizer
|
||||
|
||||
Loading…
Reference in New Issue
Block a user