From 8a03879937fd50ca9f554f22490ecb43da05cab8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BB=B4=E7=9F=B3?= Date: Sun, 29 Sep 2024 17:37:55 +0800 Subject: [PATCH] update sensevoice with pitch --- .../datasets/sense_voice_datasets/datasets.py | 253 +++++++++++++++- funasr/models/extract_tokens/model.py | 152 ++++++++++ funasr/models/sense_voice/model_small.py | 269 ++++++++++++++++++ 3 files changed, 670 insertions(+), 4 deletions(-) diff --git a/funasr/datasets/sense_voice_datasets/datasets.py b/funasr/datasets/sense_voice_datasets/datasets.py index e8c838530..1a55593f9 100644 --- a/funasr/datasets/sense_voice_datasets/datasets.py +++ b/funasr/datasets/sense_voice_datasets/datasets.py @@ -1,9 +1,7 @@ -import logging - -import re import torch -import random +import logging import traceback +import numpy as np from funasr.register import tables from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video @@ -438,3 +436,250 @@ class SenseVoiceCTCDataset(torch.utils.data.Dataset): outputs["text"] = outputs["text"][:, :text_lengths_max] return outputs + + +@tables.register("dataset_classes", "SenseVoicePitchDataset") +class SenseVoicePitchDataset(torch.utils.data.Dataset): + """ + SenseVoiceDataset + """ + + def __init__( + self, + path, + index_ds: str = None, + frontend=None, + tokenizer=None, + int_pad_value: int = -1, + float_pad_value: float = 0.0, + **kwargs, + ): + super().__init__() + index_ds_class = tables.index_ds_classes.get(index_ds) + self.index_ds = index_ds_class(path, **kwargs) + preprocessor_speech = kwargs.get("preprocessor_speech", None) + if preprocessor_speech: + preprocessor_speech_class = tables.preprocessor_classes.get(preprocessor_speech) + preprocessor_speech = preprocessor_speech_class( + **kwargs.get("preprocessor_speech_conf") + ) + self.preprocessor_speech = preprocessor_speech + preprocessor_text = kwargs.get("preprocessor_text", None) + if preprocessor_text: + preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text) + preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf")) + self.preprocessor_text = preprocessor_text + + self.frontend = frontend + self.fs = 16000 if frontend is None else frontend.fs + self.data_type = "sound" + self.tokenizer = tokenizer + + self.int_pad_value = int_pad_value + self.float_pad_value = float_pad_value + self.sos = kwargs.get("sos", "<|startoftranscript|>") + self.eos = kwargs.get("eos", "<|endoftext|>") + self.batch_size = kwargs.get("batch_size") + self.batch_type = kwargs.get("batch_type") + self.prompt_ids_len = 0 + self.retry = kwargs.get("retry", 5) + + self.permute = False + from funasr.frontends.whisper_frontend import WhisperFrontend + + if isinstance(self.frontend, WhisperFrontend): + self.permute = True + self.max_token_length = kwargs.get("max_token_length", 1500) + self.batch_size_scale_ratio_max = kwargs.get("batch_size_scale_ratio_max", 1.5) + self.batch_size_token_max = kwargs.get("batch_size_token_max", 2500) + self.multiturn_num_max = kwargs.get("multiturn_num_max", 5) + self.max_source_length = kwargs.get("max_source_length", 3000) + + def get_source_len(self, index): + item = self.index_ds[index] + return self.index_ds.get_source_len(item) + + def get_target_len(self, index): + item = self.index_ds[index] + return self.index_ds.get_target_len(item) + + def __len__(self): + return len(self.index_ds) + + def __getitem__(self, index): + + output = None + for idx in range(self.retry): + if idx == 0: + index_cur = index + else: + index_cur = torch.randint(0, len(self.index_ds), ()).item() + + item = self.index_ds[index_cur] + + source = item["source"] + try: + data_src = load_audio_text_image_video(source, fs=self.fs) + except Exception as e: + logging.error(f"Loading wav failed! {str(e)}, {traceback.format_exc()}") + continue + + if self.preprocessor_speech: + data_src = self.preprocessor_speech(data_src, fs=self.fs) + speech, speech_lengths = extract_fbank( + data_src, data_type=self.data_type, frontend=self.frontend, is_final=True + ) # speech: [b, T, d] + if speech_lengths > self.max_source_length: + logging.info( + f"speech_lengths > max_source_length: {speech_lengths}>{self.max_source_length}, {item}" + ) + continue + + if self.permute: + speech = speech.permute(0, 2, 1) + target = item["target"] + if self.preprocessor_text: + target = self.preprocessor_text(target) + + task = item.get("prompt", "<|ASR|>") + text_language = item.get("text_language", "<|zh|>") + + if isinstance(self.sos, str): + prompt = f"{self.sos}{task}{text_language}" + prompt_ids = self.tokenizer.encode(prompt, allowed_special="all") + else: + prompt = f"{task}{text_language}" + prompt_ids = self.tokenizer.encode(prompt, allowed_special="all") + prompt_ids = [self.sos] + prompt_ids + + prompt_ids_len = len(prompt_ids) - 1 # [sos, task] + self.prompt_ids_len = prompt_ids_len + + target_ids = self.tokenizer.encode(target, allowed_special="all") + target_ids_len = len(target_ids) + 1 # [lid, text] + if target_ids_len > 200: + continue + + if isinstance(self.eos, str): + eos = self.tokenizer.encode(self.eos, allowed_special="all") # [eos] + else: + eos = [self.eos] + + ids = prompt_ids + target_ids + eos # [sos, task, lid, text, eos] + ids_lengths = len(ids) + + text = torch.tensor(ids, dtype=torch.int64) + text_lengths = torch.tensor([ids_lengths], dtype=torch.int32) + + target_mask = ( + [0] * (prompt_ids_len) + [1] * (target_ids_len) + [1] + ) # [sos, task, lid, text, eos]: [0, 0, 1, 1, 1] + target_mask_lengths = len(target_mask) + target_mask = torch.tensor(target_mask, dtype=torch.float32) + target_mask_lengths = torch.tensor([target_mask_lengths], dtype=torch.int32) + + # pitch + if 'f0' in item: + f0_file = item['f0'] + f0 = torch.tensor(np.load(f0_file), dtype=torch.float32) + if f0.shape > speech.shape[1]: + f0 = f0[:speech.shape[1]] + elif f0.shape < speech.shape[1]: + last_value = f0[-1] + f0 = torch.cat([f0, last_value.repeat(speech.shape[1] - f0.shape)]) + f0_tag = torch.Tensor([1], dtype=torch.int32) + else: + f0 = torch.tensor([0.0], dtype=torch.float32) + f0_tag = torch.Tensor([0], dtype=torch.int32) + + output = { + "speech": speech[0, :, :], + "speech_lengths": speech_lengths, + "text": text, + "text_lengths": text_lengths, + "target_mask": target_mask, + "target_mask_lengths": target_mask_lengths, + "f0": f0, + "f0_tag": f0_tag, + } + break + + return output + + def collator(self, samples: list = None): + outputs = {} + for sample in samples: + if sample is None: + continue + for key in sample.keys(): + if key not in outputs: + outputs[key] = [] + outputs[key].append(sample[key]) + + if len(outputs) < 1: + logging.error(f"ERROR: data is empty!") + outputs = { + "speech": torch.rand((10, 128), dtype=torch.float32)[None, :, :], + "speech_lengths": torch.tensor( + [ + 10, + ], + dtype=torch.int32, + )[:, None], + "text": torch.tensor( + [ + 58836, + ], + dtype=torch.int32, + )[None, :], + "text_lengths": torch.tensor( + [ + 1, + ], + dtype=torch.int32, + )[:, None], + "target_mask": torch.tensor([[0] * (self.prompt_ids_len) + [1] * (1) + [1]])[ + None, : + ], + } + return outputs + + for key, data_list in outputs.items(): + if isinstance(data_list[0], torch.Tensor): + if data_list[0].dtype == torch.int64 or data_list[0].dtype == torch.int32: + + pad_value = self.int_pad_value + else: + pad_value = self.float_pad_value + + outputs[key] = torch.nn.utils.rnn.pad_sequence( + data_list, batch_first=True, padding_value=pad_value + ) + + if self.batch_type != "example": + for i in range(10): + outputs = self._filter_badcase(outputs, i=i) + + return outputs + + def _filter_badcase(self, outputs, i=0): + b, t, _ = outputs["speech"].shape + + if b * t > self.batch_size * 1.25: + beg = torch.randint(0, 2, ()).item() + if b < 2: + beg = 0 + logging.info( + f"Warning, b * t: {b * t} > {self.batch_size}, drop half data {i}th, beg:{beg}" + ) + for key, data_list in outputs.items(): + outputs[key] = outputs[key][beg : beg + b : 2] + + speech_lengths_max = outputs["speech_lengths"].max().item() + outputs["speech"] = outputs["speech"][:, :speech_lengths_max, :] + text_lengths_max = outputs["text_lengths"].max().item() + outputs["text"] = outputs["text"][:, :text_lengths_max] + target_mask_lengths_max = outputs["target_mask_lengths"].max().item() + outputs["target_mask"] = outputs["target_mask"][:, :target_mask_lengths_max] + + return outputs \ No newline at end of file diff --git a/funasr/models/extract_tokens/model.py b/funasr/models/extract_tokens/model.py index cb1e76625..16fe5aaa0 100644 --- a/funasr/models/extract_tokens/model.py +++ b/funasr/models/extract_tokens/model.py @@ -1047,6 +1047,158 @@ class EncoderLayerSANMLarge(nn.Module): x = x + self.mlp(self.mlp_ln(x)) return x + + +@tables.register("encoder_classes", "SenseVoiceQuantizedEncoderPitch") +class SenseVoiceQuantizedEncoderPitch(nn.Module): + def __init__( + self, + input_size, + linear_units: int, + attention_heads: int, + num_blocks: int, + quantize_layer_idx: int, + normalized_quant_input: bool, + quantizer_config: dict, + units: int, + **kwargs, + ): + super().__init__() + self.conv1 = Conv1d(input_size, linear_units, kernel_size=3, stride=2, padding=1) + self.conv2 = Conv1d(linear_units, linear_units, kernel_size=3, stride=2, padding=1) + + self.blocks = nn.ModuleList( + [ + EncoderLayerSANMLarge(linear_units, attention_heads, **kwargs) + for _ in range(num_blocks) + ] + ) + self.ln_post = LayerNorm(linear_units) + self.use_padmask = kwargs.get("use_padmask", True) + self.downsample_rate = kwargs.get("downsample_rate", 4) + + self.linear_units = linear_units + self.quantize_layer_idx = quantize_layer_idx + self.normalized_quant_input = normalized_quant_input + self.quantizer = self.build_quantizer(quantizer_config) + + self.pitch_predictor = torch.Linear(units, 1) + self.pitch_act = torch.nn.ReLU() + + def build_quantizer(self, vq_config): + if vq_config is None: + return None + name = vq_config.pop("name", "costume_quantizer") + if name == "costume_quantizer": + from funasr.models.sense_voice.quantizer.costume_quantizer import CostumeQuantizer + + quantizer = CostumeQuantizer( + input_size=self.linear_units, + **vq_config, + ) + vq_config["name"] = "costume_quantizer" + return quantizer + elif name == "lookup_free_quantizer": + from funasr.models.sense_voice.quantizer.lookup_free_quantizer import LFQ + + quantizer = LFQ( + input_size=self.linear_units, + **vq_config, + ) + vq_config["name"] = "lookup_free_quantizer" + return quantizer + elif name == "finite_scalar_quantizer": + from funasr.models.sense_voice.quantizer.finite_scalar_quantizer import FSQ + + quantizer = FSQ( + input_size=self.linear_units, + **vq_config, + ) + vq_config["name"] = "finite_scalar_quantizer" + return quantizer + else: + raise NotImplemented("quantizer {} not implemented".format(name)) + + def cal_f0(self, x): + x = self.pitch_predictor(x) + x = self.pitch_act(x) + return x + + def quantize_enc_outs(self, x): + ret_dict = {} + + if self.normalized_quant_input: + x = F.normalize(x, dim=-1) + ret_dict["quant_in"] = x + x, indices, commit_loss, sub_quants = self.quantizer(x) + ret_dict["quant_out"] = x + ret_dict["indices"] = indices + ret_dict["quant_loss"] = commit_loss + + return x, ret_dict + + def forward( + self, + x: torch.Tensor, + ilens: torch.Tensor = None, + **kwargs, + ): + use_padmask = self.use_padmask + x = F.gelu(self.conv1(x)) + x = F.gelu(self.conv2(x)) + x = x.permute(0, 2, 1) + only_extract_tokens = kwargs.get("only_extract_tokens", False) + + n_frames = x.size(1) + max_pos = n_frames + + if ilens is not None: + if self.downsample_rate == 4: + olens = ( + 1 + + (ilens - self.conv1.kernel_size[0] + 2 * self.conv1.padding[0]) + // self.conv1.stride[0] + ) + else: + olens = ilens + olens = ( + 1 + + (olens - self.conv2.kernel_size[0] + 2 * self.conv2.padding[0]) + // self.conv2.stride[0] + ) + olens = torch.clamp(olens, max=max_pos) + else: + olens = None + + if use_padmask and olens is not None: + padding_mask = (~make_pad_mask(olens)[:, None, :]).to(torch.bool).to(x.device) + else: + padding_mask = None + + device = x.device + seq_length = x.shape[1] + position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + + for layer, block in enumerate(self.blocks): + x = block(x, mask=padding_mask, position_ids=position_ids) + if self.quantize_layer_idx is not None and self.quantizer is not None: + if layer == self.quantize_layer_idx: + hint_once( + f"Quantization at layer {layer} wit {self.quantizer}", + "normalize_quant_enc_out", + rank=0, + ) + x, ret_dict = self.quantize_enc_outs(x) + if only_extract_tokens: + return (x, ret_dict), olens + + x = self.ln_post(x) + + if ilens is None: + return x, self.cal_f0(x) + else: + return x, self.cal_f0(x), olens @tables.register("encoder_classes", "SenseVoiceQuantizedEncoder") diff --git a/funasr/models/sense_voice/model_small.py b/funasr/models/sense_voice/model_small.py index 33d3ba527..d61c8ae9c 100644 --- a/funasr/models/sense_voice/model_small.py +++ b/funasr/models/sense_voice/model_small.py @@ -1950,3 +1950,272 @@ class SenseVoiceL(nn.Module): # meta_data["write_tokens"] = f"{time5 - time4:0.3f}" # # return results, meta_data + + +class SenseVoiceLF0Pred(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + + encoder = kwargs.get("encoder") + encoder_conf = kwargs.get("encoder_conf", {}) + encoder_class = tables.encoder_classes.get(encoder) + encoder = encoder_class(**encoder_conf) + + if encoder_conf.get("freeze", False): + freeze_exclude_key = encoder_conf.get("freeze_exclude_key", None) + for name, param in encoder.named_parameters(): + if not freeze_exclude_key in name: + logging.info(f"name: {name} is freeze") + param.requires_grad = False + + dims = kwargs.get("dims", {}) + dims = whisper.model.ModelDimensions(**dims) + model = whisper.model.Whisper(dims=dims) + + # encoder + del model.encoder + model.encoder = encoder + + # decoder + model.decoder.use_padmask = kwargs.get("use_padmask", True) + from .decoder import sense_voice_decode_forward + + model.decoder.forward = types.MethodType(sense_voice_decode_forward, model.decoder) + + self.model = model + + self.encoder_output_size = self.model.dims.n_audio_state + + self.activation_checkpoint = kwargs.get("activation_checkpoint", False) + self.ignore_id = kwargs.get("ignore_id", -1) + self.vocab_size = kwargs.get("vocab_size", -1) + self.length_normalized_loss = kwargs.get("length_normalized_loss", True) + self.criterion_att = LabelSmoothingLoss( + size=self.vocab_size, + padding_idx=self.ignore_id, + smoothing=kwargs.get("lsm_weight", 0.0), + normalize_length=self.length_normalized_loss, + ) + + specaug = kwargs.get("specaug", None) + if specaug is not None: + specaug_class = tables.specaug_classes.get(specaug) + specaug = specaug_class(**kwargs.get("specaug_conf", {})) + self.specaug = specaug + + self.loss_f0_weight = kwargs.get("loss_f0_weight", 0.3) + + def forward( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + f0: torch.Tensor, + f0_tag: torch.Tensor, + **kwargs, + ): + target_mask = kwargs.get("target_mask", None) + + if len(text_lengths.size()) > 1: + text_lengths = text_lengths[:, 0] + if len(speech_lengths.size()) > 1: + speech_lengths = speech_lengths[:, 0] + + batch_size = speech.shape[0] + + if self.activation_checkpoint: + from torch.utils.checkpoint import checkpoint + + encoder_out, encoder_out_lens = checkpoint( + self.encode, speech, speech_lengths, use_reentrant=False + ) + else: + encoder_out, encoder_f0_out, encoder_out_lens = self.encode(speech, speech_lengths) + + loss_att, acc_att, cer_att, wer_att = self._calc_att_loss( + encoder_out, encoder_out_lens, text, text_lengths, target_mask=target_mask + ) + + + loss_f0 = self._cal_f0_loss( + encoder_f0_out, speech_lengths, f0, f0_tag + ) + + loss = loss_att + loss_f0 * self.loss_f0_weight + stats = {} + stats["acc"] = acc_att + stats["loss"] = torch.clone(loss.detach()) + stats["batch_size"] = batch_size + + # force_gatherable: to-device and to-tensor if scalar for DataParallel + if self.length_normalized_loss: + batch_size = int((text_lengths + 1).sum()) + loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) + return loss, stats, weight + + def _cal_f0_loss(self, encoder_f0_out, speech_lengths, f0, f0_tag): + if self.encoder.downsample_rate == 4: + olens = ( + 1 + + (speech_lengths - self.conv1.kernel_size[0] + 2 * self.conv1.padding[0]) + // self.conv1.stride[0] + ) + f0 = f0[::2][::2] + else: + olens = speech_lengths + olens = ( + 1 + + (olens - self.conv2.kernel_size[0] + 2 * self.conv2.padding[0]) + // self.conv2.stride[0] + ) + # olens = torch.clamp(olens, max=encoder_f0_out.shape[1]) + padding_mask = (make_pad_mask(olens)[:, :]).to(torch.bool).to(encoder_f0_out.device) + padding_mask = padding_mask * f0_tag.unsqueeze(1) # B*T * B*1 + + f0_loss = torch.abs(f0 - encoder_f0_out.squeeze()) * padding_mask + f0_loss = f0_loss.mean() + return f0_loss + + def encode( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + **kwargs, + ): + """Encoder. Note that this method is used by asr_inference.py + Args: + speech: (Batch, Length, ...) + speech_lengths: (Batch, ) + ind: int + """ + with autocast(False): + + # Data augmentation + if self.specaug is not None and self.training: + speech, speech_lengths = self.specaug(speech, speech_lengths) + + # Forward encoder + encoder_out, encoder_f0_out, encoder_out_lens = self.model.encoder(speech.permute(0, 2, 1), speech_lengths) + + return encoder_out, encoder_f0_out, encoder_out_lens + + def _calc_att_loss( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + **kwargs, + ): + target_mask = kwargs.get("target_mask", None) + stats = {} + + # 1. Forward decoder + decoder_out = self.model.decoder( + x=ys_pad, xa=encoder_out, hlens=encoder_out_lens, ys_in_lens=ys_pad_lens + ) + + # 2. Compute attention loss + mask = torch.ones_like(ys_pad) * (-1) + ys_pad_mask = (ys_pad * target_mask + mask * (1 - target_mask)).to(torch.int64) + ys_pad_mask[ys_pad_mask == 0] = -1 + loss_att = self.criterion_att(decoder_out[:, :-1, :], ys_pad_mask[:, 1:]) + + with torch.no_grad(): + preds = torch.argmax(decoder_out, -1) + acc_att = compute_accuracy( + preds[:, :-1], ys_pad_mask[:, 1:], ignore_label=self.ignore_id + ) + + return loss_att, acc_att, None, None + + def inference( + self, + data_in, + data_lengths=None, + key: list = None, + tokenizer=None, + frontend=None, + **kwargs, + ): + if kwargs.get("batch_size", 1) > 1: + raise NotImplementedError("batch decoding is not implemented") + + if frontend is None and not hasattr(self, "frontend"): + frontend_class = tables.frontend_classes.get("WhisperFrontend") + frontend = frontend_class( + n_mels=self.model.dims.n_mels, do_pad_trim=kwargs.get("do_pad_trim", True) + ) + self.frontend = frontend + else: + frontend = frontend if frontend is not None else self.frontend + + meta_data = {} + if ( + isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank" + ): # fbank + speech, speech_lengths = data_in, data_lengths + if len(speech.shape) < 3: + speech = speech[None, :, :] + if speech_lengths is None: + speech_lengths = speech.shape[1] + else: + # extract fbank feats + time1 = time.perf_counter() + audio_sample_list = load_audio_text_image_video( + data_in, + fs=frontend.fs if hasattr(frontend, "fs") else 16000, + audio_fs=kwargs.get("fs", 16000), + data_type=kwargs.get("data_type", "sound"), + tokenizer=tokenizer, + ) + time2 = time.perf_counter() + meta_data["load_data"] = f"{time2 - time1:0.3f}" + speech, speech_lengths = extract_fbank( + audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend + ) + time3 = time.perf_counter() + meta_data["extract_feat"] = f"{time3 - time2:0.3f}" + frame_shift = frontend.frame_shift if hasattr(frontend, "frame_shift") else 10 + lfr_n = frontend.lfr_n if hasattr(frontend, "lfr_n") else 1 + meta_data["batch_data_time"] = speech_lengths.sum().item() * frame_shift * lfr_n / 1000 + + speech = speech.to(device=kwargs["device"])[0, :, :] + speech_lengths = speech_lengths.to(device=kwargs["device"]) + + DecodingOptions = kwargs.get("DecodingOptions", {"fp16": kwargs.get("fp16", True)}) + task = DecodingOptions.get("task", "ASR") + if isinstance(task, str): + task = [task] + task = "".join([f"<|{x}|>" for x in task]) + initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}") + DecodingOptions["initial_prompt"] = initial_prompt + + language = DecodingOptions.get("language", None) + language = None if language == "auto" else language + DecodingOptions["language"] = language + + DecodingOptions["vocab_path"] = kwargs["tokenizer_conf"].get("vocab_path", None) + + if "without_timestamps" not in DecodingOptions: + DecodingOptions["without_timestamps"] = True + + options = whisper.DecodingOptions(**DecodingOptions) + + result = whisper.decode(self.model, speech, options) + text = f"{result.text}" + results = [] + result_i = {"key": key[0], "text": text} + + results.append(result_i) + + ibest_writer = None + if kwargs.get("output_dir") is not None: + if not hasattr(self, "writer"): + self.writer = DatadirWriter(kwargs.get("output_dir")) + ibest_writer = self.writer[f"1best_recog"] + if ibest_writer is not None: + ibest_writer["text"][key[0]] = text + + return results, meta_data \ No newline at end of file