From 16d4e0054986cd5036cc311cc45fa6dff36cc9da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=8C=97=E5=BF=B5?= Date: Thu, 9 Feb 2023 17:53:04 +0800 Subject: [PATCH] add BiCifParaformer --- .../bin/asr_inference_paraformer_vad_punc.py | 16 ++- funasr/models/e2e_asr_paraformer.py | 126 +++--------------- funasr/models/predictor/cif.py | 5 +- funasr/utils/timestamp_tools.py | 58 ++++++-- 4 files changed, 83 insertions(+), 122 deletions(-) diff --git a/funasr/bin/asr_inference_paraformer_vad_punc.py b/funasr/bin/asr_inference_paraformer_vad_punc.py index 1d09c790a..629ee4fdb 100644 --- a/funasr/bin/asr_inference_paraformer_vad_punc.py +++ b/funasr/bin/asr_inference_paraformer_vad_punc.py @@ -14,6 +14,7 @@ from typing import Dict from typing import Any from typing import List import math +import copy import numpy as np import torch from typeguard import check_argument_types @@ -38,7 +39,7 @@ from funasr.utils.types import str_or_none from funasr.utils import asr_utils, wav_utils, postprocess_utils from funasr.models.frontend.wav_frontend import WavFrontend from funasr.tasks.vad import VADTask -from funasr.utils.timestamp_tools import time_stamp_lfr6 +from funasr.utils.timestamp_tools import time_stamp_lfr6, time_stamp_lfr6_pl from funasr.bin.punctuation_infer import Text2Punc header_colors = '\033[95m' @@ -234,6 +235,10 @@ class Speech2Text: decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length) decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1] + if isinstance(self.asr_model, BiCifParaformer): + _, _, us_alphas, us_cif_peak = self.asr_model.calc_predictor_timestamp(enc, enc_len, + pre_token_length) # test no bias cif2 + results = [] b, n, d = decoder_out.size() for i in range(b): @@ -276,9 +281,12 @@ class Speech2Text: else: text = None - time_stamp = time_stamp_lfr6(alphas[i:i+1,], enc_len[i:i+1,], token, begin_time, end_time) - - results.append((text, token, token_int, time_stamp, enc_len_batch_total, lfr_factor)) + if isinstance(self.asr_model, BiCifParaformer): + timestamp = time_stamp_lfr6_pl(us_alphas[i], us_cif_peak[i], copy.copy(token), begin_time, end_time) + results.append((text, token, token_int, timestamp, enc_len_batch_total, lfr_factor)) + else: + time_stamp = time_stamp_lfr6(alphas[i:i + 1, ], enc_len[i:i + 1, ], copy.copy(token), begin_time, end_time) + results.append((text, token, token_int, time_stamp, enc_len_batch_total, lfr_factor)) # assert check_return_type(results) return results diff --git a/funasr/models/e2e_asr_paraformer.py b/funasr/models/e2e_asr_paraformer.py index 759689629..34ee35e82 100644 --- a/funasr/models/e2e_asr_paraformer.py +++ b/funasr/models/e2e_asr_paraformer.py @@ -8,6 +8,8 @@ from typing import Tuple from typing import Union import torch +import random +import numpy as np from typeguard import check_argument_types from funasr.layers.abs_normalize import AbsNormalize @@ -24,7 +26,7 @@ from funasr.models.predictor.cif import mae_loss from funasr.models.preencoder.abs_preencoder import AbsPreEncoder from funasr.models.specaug.abs_specaug import AbsSpecAug from funasr.modules.add_sos_eos import add_sos_eos -from funasr.modules.nets_utils import make_pad_mask +from funasr.modules.nets_utils import make_pad_mask, pad_list from funasr.modules.nets_utils import th_accuracy from funasr.torch_utils.device_funcs import force_gatherable from funasr.train.abs_espnet_model import AbsESPnetModel @@ -824,7 +826,10 @@ class ParaformerBert(Paraformer): class BiCifParaformer(Paraformer): - """CTC-attention hybrid Encoder-Decoder model""" + """ + Paraformer model with an extra cif predictor + to conduct accurate timestamp prediction + """ def __init__( self, @@ -891,7 +896,7 @@ class BiCifParaformer(Paraformer): ) assert isinstance(self.predictor, CifPredictorV3), "BiCifParaformer should use CIFPredictorV3" - def _calc_att_loss( + def _calc_pre2_loss( self, encoder_out: torch.Tensor, encoder_out_lens: torch.Tensor, @@ -903,47 +908,12 @@ class BiCifParaformer(Paraformer): if self.predictor_bias == 1: _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) ys_pad_lens = ys_pad_lens + self.predictor_bias - pre_acoustic_embeds, pre_token_length, _, pre_peak_index, pre_token_length2 = self.predictor(encoder_out, ys_pad, encoder_out_mask, - ignore_id=self.ignore_id) + _, _, _, _, pre_token_length2 = self.predictor(encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id) - # 0. sampler - decoder_out_1st = None - if self.sampling_ratio > 0.0: - if self.step_cur < 2: - logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio)) - sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, - pre_acoustic_embeds) - else: - if self.step_cur < 2: - logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio)) - sematic_embeds = pre_acoustic_embeds + # loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length) + loss_pre2 = self.criterion_pre(ys_pad_lens.type_as(pre_token_length2), pre_token_length2) - # 1. Forward decoder - decoder_outs = self.decoder( - encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens - ) - decoder_out, _ = decoder_outs[0], decoder_outs[1] - - if decoder_out_1st is None: - decoder_out_1st = decoder_out - # 2. Compute attention loss - loss_att = self.criterion_att(decoder_out, ys_pad) - acc_att = th_accuracy( - decoder_out_1st.view(-1, self.vocab_size), - ys_pad, - ignore_label=self.ignore_id, - ) - loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length) - loss_pre2 = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length2) - - # Compute cer/wer using attention-decoder - if self.training or self.error_calculator is None: - cer_att, wer_att = None, None - else: - ys_hat = decoder_out_1st.argmax(dim=-1) - cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu()) - - return loss_att, acc_att, cer_att, wer_att, loss_pre, loss_pre2 + return loss_pre2 def calc_predictor(self, encoder_out, encoder_out_lens): @@ -956,8 +926,10 @@ class BiCifParaformer(Paraformer): def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num): encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to( encoder_out.device) - ds_alphas, ds_cif_peak, us_alphas, us_cif_peak = self.predictor.get_upsample_timestamp(encoder_out, None, encoder_out_mask, token_num=token_num, - ignore_id=self.ignore_id) + ds_alphas, ds_cif_peak, us_alphas, us_cif_peak = self.predictor.get_upsample_timestamp(encoder_out, + encoder_out_mask, + token_num) + import pdb; pdb.set_trace() return ds_alphas, ds_cif_peak, us_alphas, us_cif_peak @@ -992,72 +964,16 @@ class BiCifParaformer(Paraformer): # 1. Encoder encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) - intermediate_outs = None - if isinstance(encoder_out, tuple): - intermediate_outs = encoder_out[1] - encoder_out = encoder_out[0] - loss_att, acc_att, cer_att, wer_att = None, None, None, None - loss_ctc, cer_ctc = None, None - loss_pre = None stats = dict() - # 1. CTC branch - if self.ctc_weight != 0.0: - loss_ctc, cer_ctc = self._calc_ctc_loss( - encoder_out, encoder_out_lens, text, text_lengths - ) + loss_pre2 = self._calc_pre2_loss( + encoder_out, encoder_out_lens, text, text_lengths + ) - # Collect CTC branch stats - stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None - stats["cer_ctc"] = cer_ctc - - # Intermediate CTC (optional) - loss_interctc = 0.0 - if self.interctc_weight != 0.0 and intermediate_outs is not None: - for layer_idx, intermediate_out in intermediate_outs: - # we assume intermediate_out has the same length & padding - # as those of encoder_out - loss_ic, cer_ic = self._calc_ctc_loss( - intermediate_out, encoder_out_lens, text, text_lengths - ) - loss_interctc = loss_interctc + loss_ic - - # Collect Intermedaite CTC stats - stats["loss_interctc_layer{}".format(layer_idx)] = ( - loss_ic.detach() if loss_ic is not None else None - ) - stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic - - loss_interctc = loss_interctc / len(intermediate_outs) - - # calculate whole encoder loss - loss_ctc = ( - 1 - self.interctc_weight - ) * loss_ctc + self.interctc_weight * loss_interctc - - # 2b. Attention decoder branch - if self.ctc_weight != 1.0: - loss_att, acc_att, cer_att, wer_att, loss_pre, loss_pre2 = self._calc_att_loss( - encoder_out, encoder_out_lens, text, text_lengths - ) - - # 3. CTC-Att loss definition - if self.ctc_weight == 0.0: - loss = loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight - elif self.ctc_weight == 1.0: - loss = loss_ctc - else: - loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight - - # Collect Attn branch stats - stats["loss_att"] = loss_att.detach() if loss_att is not None else None - stats["acc"] = acc_att - stats["cer"] = cer_att - stats["wer"] = wer_att - stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None - stats["loss_pre2"] = loss_pre2.detach().cpu() if loss_pre is not None else None + loss = loss_pre2 + stats["loss_pre2"] = loss_pre2.detach().cpu() stats["loss"] = torch.clone(loss.detach()) # force_gatherable: to-device and to-tensor if scalar for DataParallel diff --git a/funasr/models/predictor/cif.py b/funasr/models/predictor/cif.py index c34759d0d..561537323 100644 --- a/funasr/models/predictor/cif.py +++ b/funasr/models/predictor/cif.py @@ -544,9 +544,8 @@ class CifPredictorV3(nn.Module): token_num_int = torch.max(token_num).type(torch.int32).item() acoustic_embeds = acoustic_embeds[:, :token_num_int, :] return acoustic_embeds, token_num, alphas, cif_peak, token_num2 - - def get_upsample_timestamp(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None, - target_label_length=None, token_num=None): + + def get_upsample_timestamp(self, hidden, mask=None, token_num=None): h = hidden b = hidden.shape[0] context = h.transpose(1, 2) diff --git a/funasr/utils/timestamp_tools.py b/funasr/utils/timestamp_tools.py index 3afaa4049..12337d166 100644 --- a/funasr/utils/timestamp_tools.py +++ b/funasr/utils/timestamp_tools.py @@ -86,14 +86,52 @@ def time_stamp_lfr6(alphas: torch.Tensor, speech_lengths: torch.Tensor, raw_text else: return time_stamp_list - -def time_stamp_lfr6_advance(tst: List, text: str): - # advanced timestamp prediction for BiCIF_Paraformer using upsampled alphas - ds_alphas, ds_cif_peak, us_alphas, us_cif_peak = tst - if text.endswith(''): - text = text[:-4] +def time_stamp_lfr6_pl(us_alphas, us_cif_peak, char_list, begin_time=0.0, end_time=None): + START_END_THRESHOLD = 5 + TIME_RATE = 10.0 * 6 / 1000 / 3 # 3 times upsampled + if len(us_alphas.shape) == 3: + alphas, cif_peak = us_alphas[0], us_cif_peak[0] # support inference batch_size=1 only else: - text = text[:-1] - logging.warning("found text does not end with ") - assert int(ds_alphas.sum() + 1e-4) - 1 == len(text) - + alphas, cif_peak = us_alphas, us_cif_peak + num_frames = cif_peak.shape[0] + if char_list[-1] == '': + char_list = char_list[:-1] + # char_list = [i for i in text] + timestamp_list = [] + # for bicif model trained with large data, cif2 actually fires when a character starts + # so treat the frames between two peaks as the duration of the former token + fire_place = torch.where(cif_peak>1.0-1e-4)[0].cpu().numpy() - 1.5 + num_peak = len(fire_place) + assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1 + # begin silence + if fire_place[0] > START_END_THRESHOLD: + char_list.insert(0, '') + timestamp_list.append([0.0, fire_place[0]*TIME_RATE]) + # tokens timestamp + for i in range(len(fire_place)-1): + # the peak is always a little ahead of the start time + # timestamp_list.append([(fire_place[i]-1.2)*TIME_RATE, fire_place[i+1]*TIME_RATE]) + timestamp_list.append([(fire_place[i])*TIME_RATE, fire_place[i+1]*TIME_RATE]) + # cut the duration to token and sil of the 0-weight frames last long + # tail token and end silence + if num_frames - fire_place[-1] > START_END_THRESHOLD: + _end = (num_frames + fire_place[-1]) / 2 + timestamp_list[-1][1] = _end*TIME_RATE + timestamp_list.append([_end*TIME_RATE, num_frames*TIME_RATE]) + char_list.append("") + else: + timestamp_list[-1][1] = num_frames*TIME_RATE + if begin_time: # add offset time in model with vad + for i in range(len(timestamp_list)): + timestamp_list[i][0] = timestamp_list[i][0] + begin_time / 1000.0 + timestamp_list[i][1] = timestamp_list[i][1] + begin_time / 1000.0 + res_txt = "" + for char, timestamp in zip(char_list, timestamp_list): + res_txt += "{} {} {};".format(char, timestamp[0], timestamp[1]) + logging.warning(res_txt) # for test + res = [] + for char, timestamp in zip(char_list, timestamp_list): + if char != '': + res.append([int(timestamp[0] * 1000), int(timestamp[1] * 1000)]) + return res +