From 62178770dccdbf5da42e831898ea32adeeacba45 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E8=AF=AD=E5=B8=86?=
Date: Wed, 21 Feb 2024 20:04:01 +0800
Subject: [PATCH] test

---
 funasr/auto/auto_model.py                    |  6 +--
 funasr/models/contextual_paraformer/model.py | 29 +++++------
 funasr/models/seaco_paraformer/model.py      | 51 +++++++++++++++++---
 3 files changed, 55 insertions(+), 31 deletions(-)

diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index 60aeb1600..a3202fdb4 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -209,14 +209,12 @@ class AutoModel:
         kwargs.update(cfg)
         model = self.model if model is None else model
         model.eval()
-        pdb.set_trace()
         batch_size = kwargs.get("batch_size", 1)
         # if kwargs.get("device", "cpu") == "cpu":
         #     batch_size = 1

         key_list, data_list = prepare_data_iterator(input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key)
-        pdb.set_trace()

         speed_stats = {}
         asr_result_list = []

@@ -225,14 +223,12 @@ class AutoModel:
         pbar = tqdm(colour="blue", total=num_samples, dynamic_ncols=True) if not disable_pbar else None
         time_speech_total = 0.0
         time_escape_total = 0.0
-        pdb.set_trace()
         for beg_idx in range(0, num_samples, batch_size):
-            pdb.set_trace()
             end_idx = min(num_samples, beg_idx + batch_size)
             data_batch = data_list[beg_idx:end_idx]
             key_batch = key_list[beg_idx:end_idx]
             batch = {"data_in": data_batch, "key": key_batch}
-            pdb.set_trace()
+
             if (end_idx - beg_idx) == 1 and kwargs.get("data_type", None) == "fbank":  # fbank
                 batch["data_in"] = data_batch[0]
                 batch["data_lengths"] = input_len
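Note: the two hunks above only strip stray pdb.set_trace() calls; the surrounding logic batches the prepared keys and samples by index slicing before handing each batch dict to the model. As a minimal standalone sketch of that slicing pattern (file names and sizes are hypothetical, not funasr code):

data_list = [f"sample_{i}.wav" for i in range(10)]
key_list = [f"key_{i}" for i in range(10)]
batch_size = 4

for beg_idx in range(0, len(data_list), batch_size):
    end_idx = min(len(data_list), beg_idx + batch_size)
    # each slice becomes the {"data_in": ..., "key": ...} dict fed to the model
    batch = {"data_in": data_list[beg_idx:end_idx], "key": key_list[beg_idx:end_idx]}
    print(beg_idx, end_idx, batch["key"])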
diff --git a/funasr/models/contextual_paraformer/model.py b/funasr/models/contextual_paraformer/model.py
index 10bbf9d00..1c0805ab0 100644
--- a/funasr/models/contextual_paraformer/model.py
+++ b/funasr/models/contextual_paraformer/model.py
@@ -102,17 +102,16 @@ class ContextualParaformer(Paraformer):
             text_lengths = text_lengths[:, 0]
         if len(speech_lengths.size()) > 1:
             speech_lengths = speech_lengths[:, 0]
-        pdb.set_trace()
+
         batch_size = speech.shape[0]

         hotword_pad = kwargs.get("hotword_pad")
         hotword_lengths = kwargs.get("hotword_lengths")
         dha_pad = kwargs.get("dha_pad")
-        pdb.set_trace()
+

         # 1. Encoder
         encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
-        pdb.set_trace()

         loss_ctc, cer_ctc = None, None
         stats = dict()
@@ -127,12 +126,11 @@ class ContextualParaformer(Paraformer):
         stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
         stats["cer_ctc"] = cer_ctc

-        pdb.set_trace()
         # 2b. Attention decoder branch
         loss_att, acc_att, cer_att, wer_att, loss_pre, loss_ideal = self._calc_att_clas_loss(
             encoder_out, encoder_out_lens, text, text_lengths, hotword_pad, hotword_lengths
         )
-        pdb.set_trace()
+
         # 3. CTC-Att loss definition
         if self.ctc_weight == 0.0:
             loss = loss_att + loss_pre * self.predictor_weight
@@ -170,26 +168,24 @@
                             ):
         encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
             encoder_out.device)
-        pdb.set_trace()
+
         if self.predictor_bias == 1:
             _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
             ys_pad_lens = ys_pad_lens + self.predictor_bias
-        pdb.set_trace()
+
         pre_acoustic_embeds, pre_token_length, _, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id)
-        pdb.set_trace()

         # -1. bias encoder
         if self.use_decoder_embedding:
             hw_embed = self.decoder.embed(hotword_pad)
         else:
             hw_embed = self.bias_embed(hotword_pad)
-        pdb.set_trace()
+
         hw_embed, (_, _) = self.bias_encoder(hw_embed)
-        pdb.set_trace()
         _ind = np.arange(0, hotword_pad.shape[0]).tolist()
         selected = hw_embed[_ind, [i - 1 for i in hotword_lengths.detach().cpu().tolist()]]
         contextual_info = selected.squeeze(0).repeat(ys_pad.shape[0], 1, 1).to(ys_pad.device)
-        pdb.set_trace()
+
         # 0. sampler
         decoder_out_1st = None
         if self.sampling_ratio > 0.0:
@@ -201,7 +197,7 @@
             if self.step_cur < 2:
                 logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
             sematic_embeds = pre_acoustic_embeds
-        pdb.set_trace()
+
         # 1. Forward decoder
         decoder_outs = self.decoder(
             encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info
@@ -217,7 +213,7 @@
             loss_ideal = None
         '''
         loss_ideal = None
-        pdb.set_trace()
+
         if decoder_out_1st is None:
             decoder_out_1st = decoder_out
         # 2. Compute attention loss
@@ -294,11 +290,11 @@
                                                      enforce_sorted=False)
             _, (h_n, _) = self.bias_encoder(hw_embed)
             hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)
-        pdb.set_trace()
+
         decoder_outs = self.decoder(
             encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=hw_embed, clas_scale=clas_scale
         )
-        pdb.set_trace()
+
         decoder_out = decoder_outs[0]
         decoder_out = torch.log_softmax(decoder_out, dim=-1)
         return decoder_out, ys_pad_lens
@@ -363,14 +359,11 @@
                                                     clas_scale=kwargs.get("clas_scale", 1.0))
         decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
-        pdb.set_trace()

         results = []
         b, n, d = decoder_out.size()
-        pdb.set_trace()
         for i in range(b):
             x = encoder_out[i, :encoder_out_lens[i], :]
             am_scores = decoder_out[i, :pre_token_length[i], :]
-            pdb.set_trace()
             if self.beam_search is not None:
                 nbest_hyps = self.beam_search(
                     x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0),
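Note: in _calc_att_clas_loss above, the indexing selected = hw_embed[_ind, [i - 1 for i in hotword_lengths...]] takes, for each padded hotword, the bias-encoder (LSTM) output at its last valid timestep. A minimal sketch of that gather, assuming hypothetical sizes (4 hotwords padded to length 6, embedding dim 8):

import torch
import torch.nn as nn

hw_embed = torch.randn(4, 6, 8)               # (num_hotwords, max_len, dim), hypothetical
hotword_lengths = torch.tensor([3, 6, 2, 4])  # true length of each hotword

bias_encoder = nn.LSTM(8, 8, batch_first=True)
out, (_, _) = bias_encoder(hw_embed)          # out: (4, 6, 8)

idx = torch.arange(out.size(0))
selected = out[idx, hotword_lengths - 1]      # last valid state per hotword: (4, 8)
print(selected.shape)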
diff --git a/funasr/models/seaco_paraformer/model.py b/funasr/models/seaco_paraformer/model.py
index caf2b15c7..b3b913344 100644
--- a/funasr/models/seaco_paraformer/model.py
+++ b/funasr/models/seaco_paraformer/model.py
@@ -32,7 +32,7 @@
 from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
 from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
 from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
-
+import pdb
 if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
     from torch.cuda.amp import autocast
 else:
@@ -130,7 +130,7 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
         hotword_pad = kwargs.get("hotword_pad")
         hotword_lengths = kwargs.get("hotword_lengths")
         dha_pad = kwargs.get("dha_pad")
-
+
         batch_size = speech.shape[0]
         self.step_cur += 1
         # for data-parallel
@@ -212,58 +212,87 @@
                                nfilter=50,
                                seaco_weight=1.0):
         # decoder forward
+        pdb.set_trace()
         decoder_out, decoder_hidden, _ = self.decoder(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens,
                                                       return_hidden=True, return_both=True)
+        pdb.set_trace()
         decoder_pred = torch.log_softmax(decoder_out, dim=-1)
         if hw_list is not None:
+            pdb.set_trace()
             hw_lengths = [len(i) for i in hw_list]
             hw_list_ = [torch.Tensor(i).long() for i in hw_list]
             hw_list_pad = pad_list(hw_list_, 0).to(encoder_out.device)
+            pdb.set_trace()
             selected = self._hotword_representation(hw_list_pad,
                                                     torch.Tensor(hw_lengths).int().to(encoder_out.device))
+            pdb.set_trace()
             contextual_info = selected.squeeze(0).repeat(encoder_out.shape[0], 1, 1).to(encoder_out.device)
+            pdb.set_trace()
             num_hot_word = contextual_info.shape[1]
             _contextual_length = torch.Tensor([num_hot_word]).int().repeat(encoder_out.shape[0]).to(encoder_out.device)
-
+            pdb.set_trace()
             # ASF Core
             if nfilter > 0 and nfilter < num_hot_word:
                 for dec in self.seaco_decoder.decoders:
                     dec.reserve_attn = True
+                pdb.set_trace()
                 # cif_attended, _ = self.decoder2(contextual_info, _contextual_length, sematic_embeds, ys_pad_lens)
                 dec_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, decoder_hidden, ys_pad_lens)
                 # cif_filter = torch.topk(self.decoder2.decoders[-1].attn_mat[0][0].sum(0).sum(0)[:-1], min(nfilter, num_hot_word-1))[1].tolist()
+                pdb.set_trace()
                 hotword_scores = self.seaco_decoder.decoders[-1].attn_mat[0][0].sum(0).sum(0)[:-1]
                 # hotword_scores /= torch.sqrt(torch.tensor(hw_lengths)[:-1].float()).to(hotword_scores.device)
+                pdb.set_trace()
                 dec_filter = torch.topk(hotword_scores, min(nfilter, num_hot_word-1))[1].tolist()
+                pdb.set_trace()
                 add_filter = dec_filter
+                pdb.set_trace()
                 add_filter.append(len(hw_list_pad)-1)
                 # filter hotword embedding
+                pdb.set_trace()
                 selected = selected[add_filter]
                 # again
+                pdb.set_trace()
                 contextual_info = selected.squeeze(0).repeat(encoder_out.shape[0], 1, 1).to(encoder_out.device)
+                pdb.set_trace()
                 num_hot_word = contextual_info.shape[1]
                 _contextual_length = torch.Tensor([num_hot_word]).int().repeat(encoder_out.shape[0]).to(encoder_out.device)
+                pdb.set_trace()
                 for dec in self.seaco_decoder.decoders:
                     dec.attn_mat = []
                     dec.reserve_attn = False
-
+            pdb.set_trace()
             # SeACo Core
             cif_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, sematic_embeds, ys_pad_lens)
+            pdb.set_trace()
             dec_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, decoder_hidden, ys_pad_lens)
+            pdb.set_trace()
             merged = self._merge(cif_attended, dec_attended)
-
+            pdb.set_trace()
+
             dha_output = self.hotword_output_layer(merged)
             # remove the last token in loss calculation
+            pdb.set_trace()
             dha_pred = torch.log_softmax(dha_output, dim=-1)
+            pdb.set_trace()
             def _merge_res(dec_output, dha_output):
+                pdb.set_trace()
                 lmbd = torch.Tensor([seaco_weight] * dha_output.shape[0])
+                pdb.set_trace()
                 dha_ids = dha_output.max(-1)[-1]  # [0]
+                pdb.set_trace()
                 dha_mask = (dha_ids == 8377).int().unsqueeze(-1)
+                pdb.set_trace()
                 a = (1 - lmbd) / lmbd
                 b = 1 / lmbd
+                pdb.set_trace()
                 a, b = a.to(dec_output.device), b.to(dec_output.device)
+                pdb.set_trace()
                 dha_mask = (dha_mask + a.reshape(-1, 1, 1)) / b.reshape(-1, 1, 1)
                 # logits = dec_output * dha_mask + dha_output[:,:,:-1] * (1-dha_mask)
+                pdb.set_trace()
                 logits = dec_output * dha_mask + dha_output[:,:,:] * (1-dha_mask)
                 return logits
+
             merged_pred = _merge_res(decoder_pred, dha_pred)
+            pdb.set_trace()
             # import pdb; pdb.set_trace()
             return merged_pred
         else:
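Note: the _merge_res helper added above blends the main decoder scores with the hotword-decoder (dha) scores. With lambda = seaco_weight and m the 0/1 mask marking frames whose hotword prediction is the no-bias token (id 8377 in this model), the rescaling (m + (1-lambda)/lambda) / (1/lambda) simplifies to m' = lambda*m + (1 - lambda), and the merge is dec*m' + dha*(1 - m'). A standalone sketch of that arithmetic with random tensors (shapes hypothetical):

import torch

seaco_weight = 1.0                  # lambda in the patch
no_bias_id = 8377
dec_output = torch.randn(2, 5, 10)  # (batch, time, vocab)
dha_output = torch.randn(2, 5, 10)

lmbd = torch.full((dec_output.shape[0], 1, 1), seaco_weight)
dha_mask = (dha_output.max(-1)[1] == no_bias_id).float().unsqueeze(-1)
dha_mask = lmbd * dha_mask + (1 - lmbd)  # same as (mask + a) / b in the patch
logits = dec_output * dha_mask + dha_output * (1 - dha_mask)
print(logits.shape)                      # torch.Size([2, 5, 10])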
"sound"), frontend=frontend) time3 = time.perf_counter() @@ -336,14 +366,18 @@ class SeacoParaformer(BiCifParaformer, Paraformer): speech = speech.to(device=kwargs["device"]) speech_lengths = speech_lengths.to(device=kwargs["device"]) + pdb.set_trace() # hotword self.hotword_list = self.generate_hotwords_list(kwargs.get("hotword", None), tokenizer=tokenizer, frontend=frontend) + pdb.set_trace() # Encoder encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) if isinstance(encoder_out, tuple): encoder_out = encoder_out[0] + + pdb.set_trace() # predictor predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens) pre_acoustic_embeds, pre_token_length, _, _ = predictor_outs[0], predictor_outs[1], \ @@ -352,15 +386,16 @@ class SeacoParaformer(BiCifParaformer, Paraformer): if torch.max(pre_token_length) < 1: return [] - + pdb.set_trace() decoder_out = self._seaco_decode_with_ASF(encoder_out, encoder_out_lens, pre_acoustic_embeds, pre_token_length, hw_list=self.hotword_list) + pdb.set_trace() # decoder_out, _ = decoder_outs[0], decoder_outs[1] _, _, us_alphas, us_peaks = self.calc_predictor_timestamp(encoder_out, encoder_out_lens, pre_token_length) - + pdb.set_trace() results = [] b, n, d = decoder_out.size() for i in range(b):