From eaf9dda9e4d970af3d09db695e9e10c83ef94e25 Mon Sep 17 00:00:00 2001 From: zhifu gao Date: Wed, 17 Apr 2024 15:05:37 +0800 Subject: [PATCH] Dev gzf exp (#1624) * sensevoice finetune * sensevoice finetune * sensevoice finetune * sensevoice finetune * sensevoice finetune --- .../sense_voice/demo.py | 4 +- .../sense_voice/finetune.sh | 69 +++++++++ funasr/auto/auto_model.py | 2 + funasr/bin/train.py | 2 +- funasr/datasets/audio_datasets/index_ds.py | 23 +-- .../datasets/sense_voice_datasets/__init__.py | 0 .../datasets/sense_voice_datasets/datasets.py | 118 ++++++++++++++++ funasr/losses/label_smoothing_loss.py | 4 +- funasr/models/sense_voice/decoder.py | 66 +++++++++ funasr/models/sense_voice/encoder.py | 67 +++++++++ funasr/models/sense_voice/model.py | 131 +++++++++++++++++- .../models/sense_voice/whisper_lib/model.py | 27 +++- funasr/tokenizer/whisper_tokenizer.py | 22 +++ 13 files changed, 513 insertions(+), 22 deletions(-) create mode 100644 examples/industrial_data_pretraining/sense_voice/finetune.sh create mode 100644 funasr/datasets/sense_voice_datasets/__init__.py create mode 100644 funasr/datasets/sense_voice_datasets/datasets.py create mode 100644 funasr/models/sense_voice/decoder.py create mode 100644 funasr/models/sense_voice/encoder.py diff --git a/examples/industrial_data_pretraining/sense_voice/demo.py b/examples/industrial_data_pretraining/sense_voice/demo.py index b2fca4749..0d8ef9756 100644 --- a/examples/industrial_data_pretraining/sense_voice/demo.py +++ b/examples/industrial_data_pretraining/sense_voice/demo.py @@ -5,13 +5,13 @@ from funasr import AutoModel -model = AutoModel(model="/Users/zhifu/Downloads/modelscope_models/SenseVoice", +model = AutoModel(model="/Users/zhifu/Downloads/modelscope_models/SenseVoiceModelscope", vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch", vad_kwargs={"max_single_segment_time": 30000}, ) -input_wav = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/SenseVoice/aed_ser/asr_bgm.wav" +input_wav = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav" DecodingOptions = { "task": ("ASR", "AED", "SER"), diff --git a/examples/industrial_data_pretraining/sense_voice/finetune.sh b/examples/industrial_data_pretraining/sense_voice/finetune.sh new file mode 100644 index 000000000..cb079014e --- /dev/null +++ b/examples/industrial_data_pretraining/sense_voice/finetune.sh @@ -0,0 +1,69 @@ +# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. +# MIT License (https://opensource.org/licenses/MIT) + + +# which gpu to train or finetune +export CUDA_VISIBLE_DEVICES="0" +gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') + +# model_name from model_hub, or model_dir in local path + +## option 1, download model automatically +model_name_or_model_dir="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" +model_name_or_model_dir="/Users/zhifu/Downloads/modelscope_models/SenseVoiceModelscope" +## option 2, download model by git +#local_path_root=${workspace}/modelscope_models +#mkdir -p ${local_path_root}/${model_name_or_model_dir} +#git clone https://www.modelscope.cn/${model_name_or_model_dir}.git ${local_path_root}/${model_name_or_model_dir} +#model_name_or_model_dir=${local_path_root}/${model_name_or_model_dir} + + +# data dir, which contains: train.json, val.json +data_dir="../../../data/list" + +train_data="${data_dir}/train.jsonl" +val_data="${data_dir}/val.jsonl" + +# generate train.jsonl and val.jsonl from wav.scp and text.txt +scp2jsonl \ +++scp_file_list='["../../../data/list/train_wav.scp", "../../../data/list/train_text.txt"]' \ +++data_type_list='["source", "target"]' \ +++jsonl_file_out="${train_data}" + +scp2jsonl \ +++scp_file_list='["../../../data/list/val_wav.scp", "../../../data/list/val_text.txt"]' \ +++data_type_list='["source", "target"]' \ +++jsonl_file_out="${val_data}" + + +# exp output dir +output_dir="./outputs" +log_file="${output_dir}/log.txt" + + +mkdir -p ${output_dir} +echo "log_file: ${log_file}" + +#torchrun \ +#--nnodes 1 \ +#--node_rank 0 \ +#--nproc_per_node ${gpu_num} \ +python \ +../../../funasr/bin/train.py \ +++model="${model_name_or_model_dir}" \ +++train_data_set_list="${train_data}" \ +++valid_data_set_list="${val_data}" \ +++dataset_conf.batch_size=500 \ +++dataset_conf.batch_type="token" \ +++dataset_conf.num_workers=0 \ +++train_conf.max_epoch=50 \ +++train_conf.log_interval=1 \ +++train_conf.resume=false \ +++train_conf.validate_interval=2000 \ +++train_conf.save_checkpoint_interval=2000 \ +++train_conf.keep_nbest_models=20 \ +++train_conf.avg_nbest_model=10 \ +++optim_conf.lr=0.0002 \ +++debug=true \ +++device="cpu" \ +++output_dir="${output_dir}" #&> ${log_file} \ No newline at end of file diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py index 630c390f7..d173a533c 100644 --- a/funasr/auto/auto_model.py +++ b/funasr/auto/auto_model.py @@ -175,6 +175,8 @@ class AutoModel: kwargs["token_list"] = tokenizer.token_list if hasattr(tokenizer, "token_list") else None kwargs["token_list"] = tokenizer.get_vocab() if hasattr(tokenizer, "get_vocab") else kwargs["token_list"] vocab_size = len(kwargs["token_list"]) if kwargs["token_list"] is not None else -1 + if vocab_size == -1 and hasattr(tokenizer, "get_vocab_size"): + vocab_size = tokenizer.get_vocab_size() else: vocab_size = -1 kwargs["tokenizer"] = tokenizer diff --git a/funasr/bin/train.py b/funasr/bin/train.py index 880bb63b3..353ce6813 100644 --- a/funasr/bin/train.py +++ b/funasr/bin/train.py @@ -102,7 +102,7 @@ def main(**kwargs): if use_ddp: model = model.cuda(local_rank) model = DDP(model, device_ids=[local_rank], - find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False)) + find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", True)) elif use_fsdp: # model = FSDP(model).cuda(local_rank) diff --git a/funasr/datasets/audio_datasets/index_ds.py b/funasr/datasets/audio_datasets/index_ds.py index 34f7b4fa6..5396c8a07 100644 --- a/funasr/datasets/audio_datasets/index_ds.py +++ b/funasr/datasets/audio_datasets/index_ds.py @@ -92,7 +92,7 @@ class IndexDSJsonlRankFull(torch.utils.data.Dataset): for line in fin: data = json.loads(line.strip()) if "text" in data: # for sft - self.contents.append(data['text']) + contents.append(data['text']) if "source" in data: # for speech lab pretrain prompt = data.get("prompt", "") source = data["source"] @@ -101,13 +101,20 @@ class IndexDSJsonlRankFull(torch.utils.data.Dataset): target_len = data.get("target_len", 0) if "aishell" in source: target = target.replace(" ", "") - contents.append({"source": source, - "prompt": prompt, - "target": target, - "source_len": source_len, - "target_len": target_len, - } - ) + + contents_i = {"source": source, + "prompt": prompt, + "target": target, + "source_len": source_len, + "target_len": target_len, + } + text_language = data.get("text_language", None) + if text_language is not None: + contents_i["text_language"] = text_language + audio_language = data.get("audio_language", None) + if audio_language is not None: + contents_i["audio_language"] = audio_language + contents.append(contents_i) self.contents = contents diff --git a/funasr/datasets/sense_voice_datasets/__init__.py b/funasr/datasets/sense_voice_datasets/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/funasr/datasets/sense_voice_datasets/datasets.py b/funasr/datasets/sense_voice_datasets/datasets.py new file mode 100644 index 000000000..956cf79ab --- /dev/null +++ b/funasr/datasets/sense_voice_datasets/datasets.py @@ -0,0 +1,118 @@ +import torch +import random + +from funasr.register import tables +from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video + + +@tables.register("dataset_classes", "SenseVoiceDataset") +class SenseVoiceDataset(torch.utils.data.Dataset): + """ + SenseVoiceDataset + """ + def __init__(self, + path, + index_ds: str = None, + frontend=None, + tokenizer=None, + int_pad_value: int = -1, + float_pad_value: float = 0.0, + **kwargs): + super().__init__() + index_ds_class = tables.index_ds_classes.get(index_ds) + self.index_ds = index_ds_class(path, **kwargs) + preprocessor_speech = kwargs.get("preprocessor_speech", None) + if preprocessor_speech: + preprocessor_speech_class = tables.preprocessor_classes.get(preprocessor_speech) + preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf")) + self.preprocessor_speech = preprocessor_speech + preprocessor_text = kwargs.get("preprocessor_text", None) + if preprocessor_text: + preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text) + preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf")) + self.preprocessor_text = preprocessor_text + + self.frontend = frontend + self.fs = 16000 if frontend is None else frontend.fs + self.data_type = "sound" + self.tokenizer = tokenizer + + self.int_pad_value = int_pad_value + self.float_pad_value = float_pad_value + self.sos = kwargs.get("sos", "<|startoftranscript|>") + self.eos = kwargs.get("eos", "<|endoftext|>") + + def get_source_len(self, index): + item = self.index_ds[index] + return self.index_ds.get_source_len(item) + + def get_target_len(self, index): + item = self.index_ds[index] + return self.index_ds.get_target_len(item) + + def __len__(self): + return len(self.index_ds) + + def __getitem__(self, index): + item = self.index_ds[index] + # import pdb; + # pdb.set_trace() + source = item["source"] + data_src = load_audio_text_image_video(source, fs=self.fs) + if self.preprocessor_speech: + data_src = self.preprocessor_speech(data_src, fs=self.fs) + speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend, is_final=True) # speech: [b, T, d] + speech = speech.permute(0, 2, 1) + target = item["target"] + if self.preprocessor_text: + target = self.preprocessor_text(target) + + task = item.get("prompt", "<|ASR|>") + text_language = item.get("text_language", "<|zh|>") + + prompt = f"{self.sos}{task}{text_language}" + prompt_ids = self.tokenizer.encode(prompt, allowed_special="all") + prompt_ids_len = len(prompt_ids) - 1 # [sos, task] + + target_ids = self.tokenizer.encode(target, allowed_special="all") + target_ids_len = len(target_ids) + 1 # [lid, text] + + eos = self.tokenizer.encode(self.eos, allowed_special="all") # [eos] + + ids = prompt_ids + target_ids + eos + ids_lengths = len(ids) + + text = torch.tensor(ids, dtype=torch.int64) + text_lengths = torch.tensor([ids_lengths], dtype=torch.int32) + + target_mask = [0] * (prompt_ids_len) + [1] * (target_ids_len) + [1] # [sos, task, lid, text, eos]: [0, 0, 1, 1, 1] + target_mask = torch.tensor(target_mask, dtype=torch.float32) + + return {"speech": speech[0, :, :], + "speech_lengths": speech_lengths, + "text": text, + "text_lengths": text_lengths, + "target_mask": target_mask, + } + + + def collator(self, samples: list=None): + outputs = {} + for sample in samples: + for key in sample.keys(): + if key not in outputs: + outputs[key] = [] + outputs[key].append(sample[key]) + + for key, data_list in outputs.items(): + if isinstance(data_list[0], torch.Tensor): + if data_list[0].dtype == torch.int64 or data_list[0].dtype == torch.int32: + + pad_value = self.int_pad_value + else: + pad_value = self.float_pad_value + + outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value) + return outputs + + diff --git a/funasr/losses/label_smoothing_loss.py b/funasr/losses/label_smoothing_loss.py index 8f0809a71..385025dab 100644 --- a/funasr/losses/label_smoothing_loss.py +++ b/funasr/losses/label_smoothing_loss.py @@ -50,8 +50,8 @@ class LabelSmoothingLoss(nn.Module): """ assert x.size(2) == self.size batch_size = x.size(0) - x = x.view(-1, self.size) - target = target.view(-1) + x = x.contiguous().view(-1, self.size) + target = target.contiguous().view(-1) with torch.no_grad(): true_dist = x.clone() true_dist.fill_(self.smoothing / (self.size - 1)) diff --git a/funasr/models/sense_voice/decoder.py b/funasr/models/sense_voice/decoder.py new file mode 100644 index 000000000..bae2832ae --- /dev/null +++ b/funasr/models/sense_voice/decoder.py @@ -0,0 +1,66 @@ +import copy +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from funasr.models.transformer.utils.nets_utils import make_pad_mask + +def sense_voice_decode_forward( + self, + x: torch.Tensor, + xa: torch.Tensor, + kv_cache: Optional[dict] = None, + **kwargs, +): + """Forward decoder. + + Args: + hs_pad: encoded memory, float32 (batch, maxlen_in, feat) + hlens: (batch) + ys_in_pad: + input token ids, int64 (batch, maxlen_out) + if input_layer == "embed" + input tensor (batch, maxlen_out, #mels) in the other cases + ys_in_lens: (batch) + Returns: + (tuple): tuple containing: + + x: decoded token score before softmax (batch, maxlen_out, token) + if use_output_layer is True, + olens: (batch, ) + """ + # import pdb;pdb.set_trace() + use_padmask = self.use_padmask + hlens = kwargs.get("hlens", None) + + ys_in_lens = kwargs.get("ys_in_lens", None) + + offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0 + tgt, memory = x, xa + tgt[tgt==-1] = 0 + tgt = ( + self.token_embedding(tgt) + + self.positional_embedding[offset : offset + tgt.size(1)] + ) + # tgt = self.dropout(tgt) + + x = tgt.to(memory.dtype) + + if use_padmask and hlens is not None: + memory_mask = (~make_pad_mask(hlens)[:, None, :]).to(memory.device) + else: + memory_mask = None + + for layer, block in enumerate(self.blocks): + x = block(x, memory, mask=self.mask, memory_mask=memory_mask, is_pad_mask=False, is_pad_memory_mask=True) + + + x = self.ln(x) + x = ( + x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1) + ).float() + + + return x + \ No newline at end of file diff --git a/funasr/models/sense_voice/encoder.py b/funasr/models/sense_voice/encoder.py new file mode 100644 index 000000000..3870c52cf --- /dev/null +++ b/funasr/models/sense_voice/encoder.py @@ -0,0 +1,67 @@ +import copy +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from funasr.models.transformer.utils.nets_utils import make_pad_mask + + +def sense_voice_encode_forward( + self, + x: torch.Tensor, + ilens: torch.Tensor = None, + **kwargs, +): + use_padmask = self.use_padmask + x = F.gelu(self.conv1(x)) + x = F.gelu(self.conv2(x)) + x = x.permute(0, 2, 1) + + n_frames = x.size(1) + max_pos = self.positional_embedding.size(0) + max_pos = n_frames if n_frames < max_pos else max_pos + x = (x[:, :max_pos, :] + self.positional_embedding[None, :max_pos, :]).to(x.dtype) + + + if ilens is not None: + if self.downsample_rate == 4: + olens = ( + 1 + + ( + ilens + - self.conv1.kernel_size[0] + + 2 * self.conv1.padding[0] + ) + // self.conv1.stride[0] + ) + else: + olens = ilens + olens = ( + 1 + + ( + olens + - self.conv2.kernel_size[0] + + 2 * self.conv2.padding[0] + ) + // self.conv2.stride[0] + ) + olens = torch.clamp(olens, max=max_pos) + else: + olens = None + + if use_padmask and olens is not None: + padding_mask = (~make_pad_mask(olens)[:, None, :]).to(x.device) + else: + padding_mask = None + + for layer, block in enumerate(self.blocks): + x = block(x, mask=padding_mask, is_pad_mask=True) + + + x = self.ln_post(x) + + if ilens is None: + return x + else: + return x, olens diff --git a/funasr/models/sense_voice/model.py b/funasr/models/sense_voice/model.py index 4ee2fa55c..b5272a1ef 100644 --- a/funasr/models/sense_voice/model.py +++ b/funasr/models/sense_voice/model.py @@ -1,35 +1,158 @@ from dataclasses import dataclass from typing import Dict from typing import Iterable, Optional +import types import time import numpy as np import torch import torch.nn.functional as F from torch import Tensor from torch import nn +from torch.cuda.amp import autocast +from funasr.metrics.compute_acc import compute_accuracy +from funasr.losses.label_smoothing_loss import LabelSmoothingLoss +from funasr.train_utils.device_funcs import force_gatherable from . import whisper_lib as whisper from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank from funasr.register import tables + + @tables.register("model_classes", "SenseVoice") class SenseVoice(nn.Module): def __init__(self, *args, **kwargs): super().__init__() - hub = kwargs.get("hub", "funasr") - + dims = kwargs.get("dims", {}) dims = whisper.model.ModelDimensions(**dims) model = whisper.model.Whisper(dims=dims) + # encoder + model.encoder.downsample_rate = kwargs.get("downsample_rate", 4) + model.encoder.use_padmask = kwargs.get("use_padmask", True) + from .encoder import sense_voice_encode_forward + model.encoder.forward = types.MethodType(sense_voice_encode_forward, model.encoder) + + # decoder + model.decoder.use_padmask = kwargs.get("use_padmask", True) + from .decoder import sense_voice_decode_forward + model.decoder.forward = types.MethodType(sense_voice_decode_forward, model.decoder) + self.model = model self.encoder_output_size = self.model.dims.n_audio_state - def forward(self, ): - pass + self.activation_checkpoint = kwargs.get("activation_checkpoint", False) + self.ignore_id = kwargs.get("ignore_id", -1) + self.vocab_size = kwargs.get("vocab_size", -1) + self.length_normalized_loss = kwargs.get("length_normalized_loss", True) + self.criterion_att = LabelSmoothingLoss( + size=self.vocab_size, + padding_idx=self.ignore_id, + smoothing=kwargs.get("lsm_weight", 0.0), + normalize_length=self.length_normalized_loss, + ) + + specaug = kwargs.get("specaug", None) + if specaug is not None: + specaug_class = tables.specaug_classes.get(specaug) + specaug = specaug_class(**kwargs.get("specaug_conf", {})) + self.specaug = specaug + + + def forward( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + **kwargs, + ): + target_mask = kwargs.get("target_mask", None) + # import pdb; + # pdb.set_trace() + if len(text_lengths.size()) > 1: + text_lengths = text_lengths[:, 0] + if len(speech_lengths.size()) > 1: + speech_lengths = speech_lengths[:, 0] + + batch_size = speech.shape[0] + + if self.activation_checkpoint: + from torch.utils.checkpoint import checkpoint + encoder_out, encoder_out_lens = checkpoint(self.encode, speech, speech_lengths, use_reentrant=False) + else: + encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) + + loss_att, acc_att, cer_att, wer_att = self._calc_att_loss( + encoder_out, encoder_out_lens, text, text_lengths, target_mask=target_mask + ) + loss = loss_att + stats = {} + stats["acc"] = acc_att + stats["loss"] = torch.clone(loss.detach()) + stats["batch_size"] = batch_size + + # force_gatherable: to-device and to-tensor if scalar for DataParallel + if self.length_normalized_loss: + batch_size = int((text_lengths + 1).sum()) + loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) + return loss, stats, weight + + def encode( + self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs, + ) : + """Encoder. Note that this method is used by asr_inference.py + Args: + speech: (Batch, Length, ...) + speech_lengths: (Batch, ) + ind: int + """ + with autocast(False): + + # Data augmentation + if self.specaug is not None and self.training: + speech, speech_lengths = self.specaug(speech, speech_lengths) + + + # Forward encoder + encoder_out, encoder_out_lens = self.model.encoder(speech.permute(0, 2, 1), speech_lengths) + + return encoder_out, encoder_out_lens + + + def _calc_att_loss( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + **kwargs, + ): + target_mask = kwargs.get("target_mask", None) + stats = {} + + # 1. Forward decoder + decoder_out = self.model.decoder( + x=ys_pad, xa=encoder_out, hlens=encoder_out_lens, ys_in_lens=ys_pad_lens + ) + + # 2. Compute attention loss + mask = torch.ones_like(ys_pad) * (-1) + ys_pad_mask = (ys_pad * target_mask + mask * (1-target_mask)).to(torch.int64) + ys_pad_mask[ys_pad_mask == 0] = -1 + loss_att = self.criterion_att(decoder_out[:, :-1, :], ys_pad_mask[:, 1:]) + + with torch.no_grad(): + preds = torch.argmax(decoder_out, -1) + acc_att = compute_accuracy(preds[:, :-1], ys_pad_mask[:, 1:], ignore_label=self.ignore_id) + + return loss_att, acc_att, None, None + + def inference(self, data_in, data_lengths=None, diff --git a/funasr/models/sense_voice/whisper_lib/model.py b/funasr/models/sense_voice/whisper_lib/model.py index 0e8f09b1b..ca960f197 100644 --- a/funasr/models/sense_voice/whisper_lib/model.py +++ b/funasr/models/sense_voice/whisper_lib/model.py @@ -74,7 +74,10 @@ class MultiHeadAttention(nn.Module): xa: Optional[Tensor] = None, mask: Optional[Tensor] = None, kv_cache: Optional[dict] = None, + **kwargs, ): + is_pad_mask = kwargs.get("is_pad_mask", False) + q = self.query(x) if kv_cache is None or xa is None or self.key not in kv_cache: @@ -87,12 +90,13 @@ class MultiHeadAttention(nn.Module): k = kv_cache[self.key] v = kv_cache[self.value] - wv, qk = self.qkv_attention(q, k, v, mask) + wv, qk = self.qkv_attention(q, k, v, mask, is_pad_mask=is_pad_mask) return self.out(wv), qk def qkv_attention( - self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None + self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, **kwargs, ): + is_pad_mask = kwargs.get("is_pad_mask", False) n_batch, n_ctx, n_state = q.shape scale = (n_state // self.n_head) ** -0.25 q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale @@ -101,10 +105,20 @@ class MultiHeadAttention(nn.Module): qk = q @ k if mask is not None: - qk = qk + mask[:n_ctx, :n_ctx] + if not is_pad_mask: + qk = qk + mask[:n_ctx, :n_ctx] + else: + mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) + min_value = float( + np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min + ) + qk = qk.masked_fill(mask, min_value) + qk = qk.float() w = F.softmax(qk, dim=-1).to(q.dtype) + if mask is not None and is_pad_mask: + w = w.masked_fill(mask, 0.0) return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach() @@ -132,10 +146,13 @@ class ResidualAttentionBlock(nn.Module): xa: Optional[Tensor] = None, mask: Optional[Tensor] = None, kv_cache: Optional[dict] = None, + **kwargs, ): - x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0] + is_pad_mask = kwargs.get("is_pad_mask", False) + is_pad_memory_mask = kwargs.get("is_pad_memory_mask", False) + x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0] if self.cross_attn: - x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0] + x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache, is_pad_mask=is_pad_memory_mask)[0] x = x + self.mlp(self.mlp_ln(x)) return x diff --git a/funasr/tokenizer/whisper_tokenizer.py b/funasr/tokenizer/whisper_tokenizer.py index 6684f2598..0a34d19d1 100644 --- a/funasr/tokenizer/whisper_tokenizer.py +++ b/funasr/tokenizer/whisper_tokenizer.py @@ -22,3 +22,25 @@ def WhisperTokenizer(**kwargs): return tokenizer + +@tables.register("tokenizer_classes", "SenseVoiceTokenizer") +def SenseVoiceTokenizer(**kwargs): + try: + from funasr.models.sense_voice.whisper_lib.tokenizer import get_tokenizer + except: + print("Notice: If you want to use whisper, please `pip install -U openai-whisper`") + + language = kwargs.get("language", None) + task = kwargs.get("task", None) + is_multilingual = kwargs.get("is_multilingual", True) + num_languages = kwargs.get("num_languages", 8749) + vocab_path = kwargs.get("vocab_path", None) + tokenizer = get_tokenizer( + multilingual=is_multilingual, + num_languages=num_languages, + language=language, + task=task, + vocab_path=vocab_path, + ) + + return tokenizer