From d19f48e17478be273584853568ac101c994c37e5 Mon Sep 17 00:00:00 2001
From: zhifu gao
Date: Mon, 8 Apr 2024 18:51:53 +0800
Subject: [PATCH] Dev gzf exp (#1593)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* update

* update with main (#1582)
  * Expose the max_end_silence_time to the user (#1532)
  * finetune updates
  * fix: resolve IndexError when using spk model and the audio contains only 1 segment (#1535)
  * install requirements automatically
  * v1.0.19
  * train and docs updates
  * bugfix seg_dict_file
  * train (#1548); Dev gzf new (#1551, #1553, #1554, #1555, #1557, #1559, #1561, #1562, #1567, #1574)
  * (#1552) Fix the issue where the first and second English words are incorrectly merged when
    adding punctuation (co-authored-by carl.che)
  * (#1556) Fix the regression introduced by commit 87b62d68957a2194b017a43b6c2a15424a05a984, where
    full-sentence English punctuation prediction deleted the space between the last two words
    (co-authored-by carl.che)
  * whisper_lib for sense voice
  * aishell recipe
  * sense voice (#1568)
  * docs
  * bugfix (#1580)
  * v1.0.20

* ctc

* update with main (#1592)
  * same upstream history as #1582 above
  * update demo page (#1585): commit web page vue, optimize web page, remove other private
    component, modify web page, Update index.vue, Update lxwjzxfw.vue, Update sstx.vue,
    update static file

* sensevoice

---------

Co-authored-by: BOBOTANG
Co-authored-by: Atomie CHEN
Co-authored-by: Carl <415692979@qq.com>
Co-authored-by: carl.che
Co-authored-by: bltcn
---
 funasr/frontends/whisper_frontend.py          | 15 ++++-
 funasr/models/llm_asr_nar/model.py            | 59 ++++++++++++++++---
 funasr/models/sense_voice/model.py            |  6 +-
 .../models/sense_voice/whisper_lib/audio.py   |  6 +-
 .../sense_voice/whisper_lib/decoding.py       | 12 +++-
 .../sense_voice/whisper_lib/tokenizer.py      | 11 +++-
 funasr/version.txt                            |  2 +-
 7 files changed, 89 insertions(+), 22 deletions(-)

diff --git a/funasr/frontends/whisper_frontend.py b/funasr/frontends/whisper_frontend.py
index dd61f8ea5..acc99af01 100644
--- a/funasr/frontends/whisper_frontend.py
+++ b/funasr/frontends/whisper_frontend.py
@@ -38,7 +38,13 @@ class WhisperFrontend(nn.Module):
         if whisper_model == "large-v3" or whisper_model == "large":
             self.n_mels = 128
 
-        self.mel_filters = whisper.audio.mel_filters
+        filters_path = kwargs.get("filters_path", None)
+        self.filters_path = filters_path
+        if filters_path is not None:
+            from funasr.models.sense_voice.whisper_lib.audio import mel_filters
+            self.mel_filters = mel_filters
+        else:
+            self.mel_filters = whisper.audio.mel_filters
         self.do_pad_trim = do_pad_trim
         if do_pad_trim:
             self.pad_or_trim = whisper.pad_or_trim
@@ -61,8 +67,10 @@ class WhisperFrontend(nn.Module):
 
         # whisper deletes the last frame by default (Shih-Lun)
         magnitudes = stft[..., :-1].abs() ** 2
-
-        filters = self.mel_filters(audio.device, self.n_mels)
+        if self.filters_path is not None:
+            filters = self.mel_filters(audio.device, self.n_mels, self.filters_path)
+        else:
+            filters = self.mel_filters(audio.device, self.n_mels)
         mel_spec = filters @ magnitudes
 
         log_spec = torch.clamp(mel_spec, min=1e-10).log10()
@@ -86,6 +94,7 @@ class WhisperFrontend(nn.Module):
         batch_size = input.size(0)
         feats = []
         feats_lens = []
+        input = input.to(torch.float32)
         for i in range(batch_size):
             if self.do_pad_trim:
                 feat = self.pad_or_trim(input[i], self.pad_samples)
diff --git a/funasr/models/llm_asr_nar/model.py b/funasr/models/llm_asr_nar/model.py
index 30537cf79..994259a8f 100644
--- a/funasr/models/llm_asr_nar/model.py
+++ b/funasr/models/llm_asr_nar/model.py
@@ -366,7 +366,7 @@ class LLMASRNARPrompt(nn.Module):
         decoder_conf: dict = None,
         ctc: str = None,
         ctc_conf: dict = None,
-        ctc_weight: float = 0.5,
+        ctc_weight: float = 0.0,
         llm: str = None,
         llm_conf: dict = None,
         adaptor: str = None,
@@ -473,6 +473,15 @@ class LLMASRNARPrompt(nn.Module):
         self.length_normalized_loss = length_normalized_loss
         self.beam_search = None
 
+        if ctc_weight > 0.0:
+            if ctc_conf is None:
+                ctc_conf = {}
+
+            ctc = CTC(
+                odim=vocab_size, encoder_output_size=adaptor_conf["encoder_dim"], **ctc_conf
+            )
+        self.ctc_weight = ctc_weight
+        self.ctc = ctc
 
     def forward(
         self,
@@ -502,9 +511,23 @@ class LLMASRNARPrompt(nn.Module):
             speech_lengths = speech_lengths[:, 0]
 
         batch_size = speech.shape[0]
-
+        stats = {}
         # audio encoder
-        encoder_out, encoder_out_lens, loss_pre = self.encode(speech, speech_lengths, audio_mask=audio_mask)
+        outs = self.encode(speech, speech_lengths, audio_mask=audio_mask)
+        enc, enc_lens = outs[0], outs[1]
+        encoder_out, encoder_out_lens, loss_pre = outs[2], outs[3], outs[4]
+
+
+        # decoder: CTC branch
+
+        if self.ctc_weight != 0.0:
+            loss_ctc, cer_ctc = self._calc_ctc_loss(
+                enc, enc_lens, text, text_lengths
+            )
+
+            # Collect CTC branch stats
+            stats["loss_ctc"] = torch.clone(loss_ctc.detach()) if loss_ctc is not None else None
 
         # adaptor
        encoder_out = self.adaptor(encoder_out)
@@ -536,17 +559,19 @@ class LLMASRNARPrompt(nn.Module):
 
         # labels_ids[1:] -> [prompt, input, target, eos] -> [-1, input, target, eos];
         model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids)
         loss_llm = model_outputs.loss
+        stats["loss_llm"] = torch.clone(loss_llm.detach())
+        if self.ctc_weight > 0.0:
+            loss_llm = self.ctc_weight * loss_ctc + loss_llm
         loss = loss_llm + loss_pre * self.predictor_weight
-        stats = {}
+
         with torch.no_grad():
             preds = torch.argmax(model_outputs.logits, -1)
             acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100)
             stats["acc"] = acc_att
-        stats["loss_pre"] = torch.clone(loss_pre.detach())
-        stats["loss_llm"] = torch.clone(loss_llm.detach())
         stats["loss"] = torch.clone(loss.detach())
+        stats["batch_size"] = batch_size
 
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         if self.length_normalized_loss:
@@ -576,7 +601,24 @@ class LLMASRNARPrompt(nn.Module):
         if audio_token_lengths is not None:
             loss_pre = self.criterion_pre(audio_token_lengths.type_as(pre_token_length), pre_token_length)
 
-        return pre_acoustic_embeds, pre_token_length, loss_pre
+        return enc, enc_lens, pre_acoustic_embeds, pre_token_length, loss_pre
+
+    def _calc_ctc_loss(
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        ys_pad: torch.Tensor,
+        ys_pad_lens: torch.Tensor,
+    ):
+        # Calc CTC loss
+        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
+
+        # Calc CER using CTC
+        cer_ctc = None
+        if not self.training and self.error_calculator is not None:
+            ys_hat = self.ctc.argmax(encoder_out).data
+            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
+        return loss_ctc, cer_ctc
 
     def inference(self,
                   data_in,
@@ -648,7 +690,8 @@ class LLMASRNARPrompt(nn.Module):
         else:
             inputs_embeds = self.llm.model.model.model.embed_tokens(prompt_ids)
 
-        inputs_embeds = torch.cat((inputs_embeds[None, :, :], encoder_out, pad[None, :, :]), dim=1) # [prompt, audio]
+        # inputs_embeds = torch.cat((inputs_embeds[None, :, :], encoder_out, pad[None, :, :]), dim=1) # [prompt, audio, pad]
+        inputs_embeds = torch.cat((inputs_embeds[None, :, :], encoder_out), dim=1) # [prompt, audio]
 
         attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to(kwargs["device"])
         # model_outputs = self.llm.generate(
diff --git a/funasr/models/sense_voice/model.py b/funasr/models/sense_voice/model.py
index d6552a6d0..521dec888 100644
--- a/funasr/models/sense_voice/model.py
+++ b/funasr/models/sense_voice/model.py
@@ -91,7 +91,11 @@ class SenseVoice(nn.Module):
 
         # decode the audio
         # initial_prompt = kwargs.get("initial_prompt", "<|startoftranscript|><|ASR|>")
-        options = whisper.DecodingOptions(language=language, fp16=False, without_timestamps=True, initial_prompt=initial_prompt)
+
+        vocab_path = kwargs.get("vocab_path", None)
+        options = whisper.DecodingOptions(language=language, fp16=False, without_timestamps=True, initial_prompt=initial_prompt, vocab_path=vocab_path)
+
+
         result = whisper.decode(self.model, speech, options)
 
         results = []
diff --git a/funasr/models/sense_voice/whisper_lib/audio.py b/funasr/models/sense_voice/whisper_lib/audio.py
index cf6c66ad9..52da32cab 100644
--- a/funasr/models/sense_voice/whisper_lib/audio.py
+++ b/funasr/models/sense_voice/whisper_lib/audio.py
@@ -89,7 +89,7 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
 
 
 @lru_cache(maxsize=None)
-def mel_filters(device, n_mels: int) -> torch.Tensor:
+def mel_filters(device, n_mels: int, filters_path: str=None) -> torch.Tensor:
     """
     load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
     Allows decoupling librosa dependency; saved using:
@@ -101,8 +101,8 @@ def mel_filters(device, n_mels: int) -> torch.Tensor:
         )
     """
     assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
-
-    filters_path = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
+    if filters_path is None:
+        filters_path = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
     with np.load(filters_path, allow_pickle=False) as f:
         return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
 
diff --git a/funasr/models/sense_voice/whisper_lib/decoding.py b/funasr/models/sense_voice/whisper_lib/decoding.py
index 73b02626a..caca114ba 100644
--- a/funasr/models/sense_voice/whisper_lib/decoding.py
+++ b/funasr/models/sense_voice/whisper_lib/decoding.py
@@ -119,6 +119,7 @@ class DecodingOptions:
 
     # FIX(funasr): sense vocie
     initial_prompt: str = None
+    vocab_path: str = None
 
 
 @dataclass(frozen=True)
@@ -527,6 +528,7 @@ class DecodingTask:
             num_languages=model.num_languages,
             language=language,
             task=options.task,
+            vocab_path=options.vocab_path
         )
         self.tokenizer: Tokenizer = tokenizer
         self.options: DecodingOptions = self._verify_options(options)
@@ -616,10 +618,13 @@ class DecodingTask:
                 + prompt_tokens[-(self.n_ctx // 2 - 1) :]
                 + tokens
             )
 
-        #FIX(gzf): sense vocie
+        #FIX(funasr): sense vocie
         if initial_prompt := self.options.initial_prompt:
-            tokens = self.tokenizer.encode(initial_prompt, allowed_special="all")
-            if self.options.language is None:
+            if self.options.language is not None:
+                initial_prompt = f"{initial_prompt}<|{self.options.language}|>"
+                tokens = self.tokenizer.encode(initial_prompt, allowed_special="all")
+            else:
+                tokens = self.tokenizer.encode(initial_prompt, allowed_special="all")
                 tokens += [0]
 
@@ -691,6 +696,7 @@ class DecodingTask:
 
         if self.options.language is None:
             # tokens[:, self.sot_index + 1] = lang_tokens  # write language tokens
             languages = "".join([f"<|{language}|>" for language in languages])
+            n_audio = audio_features.shape[0]
             lang_tokens = torch.tensor([self.tokenizer.encode(languages, allowed_special="all")] * n_audio).to(
                 audio_features.device)  # [n_audio, 1]
diff --git a/funasr/models/sense_voice/whisper_lib/tokenizer.py b/funasr/models/sense_voice/whisper_lib/tokenizer.py
index e941fb2b5..463ce8383 100644
--- a/funasr/models/sense_voice/whisper_lib/tokenizer.py
+++ b/funasr/models/sense_voice/whisper_lib/tokenizer.py
@@ -363,8 +363,10 @@ class Tokenizer:
 
 
 @lru_cache(maxsize=None)
-def get_encoding(name: str = "gpt2", num_languages: int = 99):
-    vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
+def get_encoding(name: str = "gpt2", num_languages: int = 99, vocab_path:str=None):
+    if vocab_path is None:
+        vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
+
     ranks = {
         base64.b64decode(token): int(rank)
         for token, rank in (line.split() for line in open(vocab_path) if line)
@@ -423,6 +425,7 @@ def get_tokenizer(
     language: Optional[str] = None,
     task: Optional[str] = None,  # Literal["transcribe", "translate", None]
     encoding_path: Optional[str] = None,
+    vocab_path: Optional[str] = None,
 ) -> Tokenizer:
     if language is not None:
         language = language.lower()
@@ -443,7 +446,9 @@ def get_tokenizer(
 
     if encoding_path is not None:
         encoding_name = encoding_path
-    encoding = get_encoding(name=encoding_name, num_languages=num_languages)
+
+    encoding = get_encoding(name=encoding_name, num_languages=num_languages, vocab_path=vocab_path)
+
     return Tokenizer(
         encoding=encoding, num_languages=num_languages, language=language, task=task
diff --git a/funasr/version.txt b/funasr/version.txt
index c2320f5be..2fa390179 100644
--- a/funasr/version.txt
+++ b/funasr/version.txt
@@ -1 +1 @@
-1.0.20
+1.0.22
\ No newline at end of file
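
Usage sketch (reviewer note, not part of the diff above): a minimal example of the two new
optional arguments, based only on the signatures visible in this patch. The file paths are
placeholders; leaving either argument as None keeps the previous behaviour of loading the
assets bundled with whisper_lib. The same values can also be supplied as the "filters_path"
kwarg read by WhisperFrontend and the "vocab_path" kwarg read in
funasr/models/sense_voice/model.py. In LLMASRNARPrompt the new CTC branch only contributes
when ctc_weight > 0, giving loss = ctc_weight * loss_ctc + loss_llm + predictor_weight * loss_pre.

    import torch

    from funasr.models.sense_voice.whisper_lib.audio import mel_filters
    from funasr.models.sense_voice.whisper_lib.tokenizer import get_encoding

    # Custom mel filterbank: the new third argument points at an .npz file holding
    # "mel_80"/"mel_128" arrays, overriding the bundled assets/mel_filters.npz.
    filters = mel_filters(torch.device("cpu"), 128, "/path/to/custom_mel_filters.npz")

    # Custom tiktoken vocabulary: the new vocab_path argument overrides the bundled
    # assets/<name>.tiktoken rank file used to build the tokenizer encoding.
    encoding = get_encoding(name="gpt2", num_languages=99, vocab_path="/path/to/custom_vocab.tiktoken")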