From 5cf512419c282f833ee35a2f31890bff00d94343 Mon Sep 17 00:00:00 2001 From: zhifu gao Date: Thu, 7 Dec 2023 16:57:03 +0800 Subject: [PATCH] update with main (#1158) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * v0.8.7 * update cmd version * set openfst HAVE_BIN/HAVE_SCRIPT off for win32 * fix to support the new hotword version (#1137) * update CMakeLists.txt * Revert "update CMakeLists.txt" This reverts commit 54bcd1f6742269fc1ce90d9871245db5cd6a1cbf. * rm log.h for wins-websocket * fix websocket lock blocking bug * update funasr-wss-server * update model-revision by model name * update funasr-wss-server-2pass * add ERes2Net model support for speaker-attributed speech recognition * Update README.md (#1140) minor fix * automatically configure parameters such as decoder-thread-num * update docs * update docs * update docs * speaker-attributed speech recognition supports more models * update spk inference * remove unused code (#1151) * fix loss normalization for DDP training --------- Co-authored-by: 雾聪 Co-authored-by: 夜雨飘零 Co-authored-by: Ikko Eltociear Ashimine Co-authored-by: Shi Xian <40013335+R1ckShi@users.noreply.github.com> Co-authored-by: shixian.shi Co-authored-by: haoneng.lhn --- funasr/models/e2e_asr.py | 3 + .../models/e2e_asr_contextual_paraformer.py | 3 + funasr/models/e2e_asr_paraformer.py | 57 ++++++++++++------- funasr/models/e2e_uni_asr.py | 3 + 4 files changed, 46 insertions(+), 20 deletions(-) diff --git a/funasr/models/e2e_asr.py b/funasr/models/e2e_asr.py index 79c5387ae..162bfba92 100644 --- a/funasr/models/e2e_asr.py +++ b/funasr/models/e2e_asr.py @@ -122,6 +122,7 @@ class ASRModel(FunASRModel): self.ctc = ctc self.extract_feats_in_collect_stats = extract_feats_in_collect_stats + self.length_normalized_loss = length_normalized_loss def forward( self, @@ -220,6 +221,8 @@ class ASRModel(FunASRModel): stats["loss"] = torch.clone(loss.detach()) # force_gatherable: to-device and to-tensor if scalar for DataParallel + if self.length_normalized_loss: + batch_size = (text_lengths + 1).sum().type_as(batch_size) loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) return loss, stats, weight diff --git a/funasr/models/e2e_asr_contextual_paraformer.py b/funasr/models/e2e_asr_contextual_paraformer.py index a2f7078ae..d4dc784ac 100644 --- a/funasr/models/e2e_asr_contextual_paraformer.py +++ b/funasr/models/e2e_asr_contextual_paraformer.py @@ -125,6 +125,7 @@ class NeatContextualParaformer(Paraformer): if self.crit_attn_weight > 0: self.attn_loss = torch.nn.L1Loss() self.crit_attn_smooth = crit_attn_smooth + self.length_normalized_loss = length_normalized_loss def forward( self, @@ -231,6 +232,8 @@ class NeatContextualParaformer(Paraformer): stats["loss"] = torch.clone(loss.detach()) # force_gatherable: to-device and to-tensor if scalar for DataParallel + if self.length_normalized_loss: + batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size) loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) return loss, stats, weight diff --git a/funasr/models/e2e_asr_paraformer.py b/funasr/models/e2e_asr_paraformer.py index e157454e4..b793d529c 100644 --- a/funasr/models/e2e_asr_paraformer.py +++ b/funasr/models/e2e_asr_paraformer.py @@ -137,6 +137,7 @@ class Paraformer(FunASRModel): self.predictor_bias = predictor_bias self.sampling_ratio = sampling_ratio self.criterion_pre = mae_loss(normalize_length=length_normalized_loss) + self.length_normalized_loss = length_normalized_loss self.step_cur = 0 self.share_embedding = share_embedding @@ -253,6 +254,8 @@ class
Paraformer(FunASRModel): stats["loss"] = torch.clone(loss.detach()) # force_gatherable: to-device and to-tensor if scalar for DataParallel + if self.length_normalized_loss: + batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size) loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) return loss, stats, weight @@ -352,8 +355,9 @@ class Paraformer(FunASRModel): encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to( encoder_out.device) - pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(encoder_out, None, encoder_out_mask, - ignore_id=self.ignore_id) + pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(encoder_out, None, + encoder_out_mask, + ignore_id=self.ignore_id) return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens): @@ -487,8 +491,9 @@ class Paraformer(FunASRModel): if self.step_cur < 2: logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio)) if self.use_1st_decoder_loss: - sematic_embeds, decoder_out_1st, pre_loss_att = self.sampler_with_grad(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, - pre_acoustic_embeds) + sematic_embeds, decoder_out_1st, pre_loss_att = self.sampler_with_grad(encoder_out, encoder_out_lens, + ys_pad, ys_pad_lens, + pre_acoustic_embeds) else: sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds) @@ -727,6 +732,7 @@ class ParaformerOnline(Paraformer): self.predictor = predictor self.predictor_weight = predictor_weight self.predictor_bias = predictor_bias + self.length_normalized_loss = length_normalized_loss self.sampling_ratio = sampling_ratio self.criterion_pre = mae_loss(normalize_length=length_normalized_loss) self.step_cur = 0 @@ -860,11 +866,13 @@ class ParaformerOnline(Paraformer): stats["loss"] = torch.clone(loss.detach()) # force_gatherable: to-device and to-tensor if scalar for DataParallel + if self.length_normalized_loss: + batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size) loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) return loss, stats, weight def encode( - self, speech: torch.Tensor, speech_lengths: torch.Tensor, ind: int = 0, + self, speech: torch.Tensor, speech_lengths: torch.Tensor, ind: int = 0, ) -> Tuple[torch.Tensor, torch.Tensor]: """Frontend + Encoder. Note that this method is used by asr_inference.py Args: @@ -885,7 +893,7 @@ class ParaformerOnline(Paraformer): # Pre-encoder, e.g. used for raw input data if self.preencoder is not None: feats, feats_lengths = self.preencoder(feats, feats_lengths) - + # 4. 
Forward encoder # feats: (Batch, Length, Dim) # -> encoder_out: (Batch, Length2, Dim2) @@ -970,11 +978,11 @@ class ParaformerOnline(Paraformer): return encoder_out, torch.tensor([encoder_out.size(1)]) def _calc_att_predictor_loss( - self, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - ys_pad: torch.Tensor, - ys_pad_lens: torch.Tensor, + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, ): encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to( encoder_out.device) @@ -1006,7 +1014,7 @@ class ParaformerOnline(Paraformer): attention_chunk_center_bias = 0 attention_chunk_size = encoder_chunk_size decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur - mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls.\ + mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls. \ get_mask_shift_att_chunk_decoder(None, device=encoder_out.device, batch_size=encoder_out.size(0) @@ -1106,7 +1114,8 @@ class ParaformerOnline(Paraformer): input_mask_expand_dim, 0) return sematic_embeds * tgt_mask, decoder_out * tgt_mask - def sampler_with_grad(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, chunk_mask=None): + def sampler_with_grad(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, + chunk_mask=None): tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device) ys_pad_masked = ys_pad * tgt_mask[:, :, 0] if self.share_embedding: @@ -1158,7 +1167,7 @@ class ParaformerOnline(Paraformer): target_label_length=None, ) predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(pre_alphas, - encoder_out_lens+1 if self.predictor.tail_threshold > 0.0 else encoder_out_lens) + encoder_out_lens + 1 if self.predictor.tail_threshold > 0.0 else encoder_out_lens) scama_mask = None if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == 'chunk': @@ -1166,7 +1175,7 @@ class ParaformerOnline(Paraformer): attention_chunk_center_bias = 0 attention_chunk_size = encoder_chunk_size decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur - mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls.\ + mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls. \ get_mask_shift_att_chunk_decoder(None, device=encoder_out.device, batch_size=encoder_out.size(0) @@ -1484,6 +1493,8 @@ class ParaformerBert(Paraformer): stats["loss"] = torch.clone(loss.detach()) # force_gatherable: to-device and to-tensor if scalar for DataParallel + if self.length_normalized_loss: + batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size) loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) return loss, stats, weight @@ -1589,8 +1600,9 @@ class BiCifParaformer(Paraformer): if self.predictor_bias == 1: _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) ys_pad_lens = ys_pad_lens + self.predictor_bias - pre_acoustic_embeds, pre_token_length, _, pre_peak_index, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask, - ignore_id=self.ignore_id) + pre_acoustic_embeds, pre_token_length, _, pre_peak_index, _ = self.predictor(encoder_out, ys_pad, + encoder_out_mask, + ignore_id=self.ignore_id) # 0. 
sampler decoder_out_1st = None @@ -1739,7 +1751,7 @@ class BiCifParaformer(Paraformer): loss = loss_ctc else: loss = self.ctc_weight * loss_ctc + ( - 1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5 + 1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5 # Collect Attn branch stats stats["loss_att"] = loss_att.detach() if loss_att is not None else None @@ -1752,6 +1764,8 @@ class BiCifParaformer(Paraformer): stats["loss"] = torch.clone(loss.detach()) # force_gatherable: to-device and to-tensor if scalar for DataParallel + if self.length_normalized_loss: + batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size) loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) return loss, stats, weight @@ -1952,6 +1966,8 @@ class ContextualParaformer(Paraformer): stats["loss"] = torch.clone(loss.detach()) # force_gatherable: to-device and to-tensor if scalar for DataParallel + if self.length_normalized_loss: + batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size) loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) return loss, stats, weight @@ -2107,7 +2123,8 @@ class ContextualParaformer(Paraformer): return loss_att, acc_att, cer_att, wer_att, loss_pre - def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None, clas_scale=1.0): + def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None, + clas_scale=1.0): if hw_list is None: # default hotword list hw_list = [torch.Tensor([self.sos]).long().to(encoder_out.device)] # empty hotword list @@ -2245,4 +2262,4 @@ class ContextualParaformer(Paraformer): "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape)) - return var_dict_torch_update + return var_dict_torch_update \ No newline at end of file diff --git a/funasr/models/e2e_uni_asr.py b/funasr/models/e2e_uni_asr.py index 8bc3b4278..45d90f1e7 100644 --- a/funasr/models/e2e_uni_asr.py +++ b/funasr/models/e2e_uni_asr.py @@ -167,6 +167,7 @@ class UniASR(FunASRModel): self.enable_maas_finetune = enable_maas_finetune self.freeze_encoder2 = freeze_encoder2 self.encoder1_encoder2_joint_training = encoder1_encoder2_joint_training + self.length_normalized_loss = length_normalized_loss def forward( self, @@ -440,6 +441,8 @@ class UniASR(FunASRModel): stats["loss2"] = torch.clone(loss2.detach()) stats["loss"] = torch.clone(loss.detach()) # force_gatherable: to-device and to-tensor if scalar for DataParallel + if self.length_normalized_loss: + batch_size = (text_lengths + 1).sum().type_as(batch_size) loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) return loss, stats, weight
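The change repeated across these forward() methods adjusts the weight passed to force_gatherable(): when length_normalized_loss is enabled, the criterion already averages the loss over target tokens, so the gather weight becomes the summed token count (text_lengths plus the sos/eos predictor_bias, or +1 for ASRModel and UniASR) rather than the batch size; otherwise DataParallel/DDP would recombine per-replica token-averaged losses in the wrong proportions. Below is a minimal sketch of that weighting only; the helper name gather_weight and the example lengths are made up for illustration and are not part of the FunASR code.

# Minimal sketch (not FunASR code): choose the DataParallel/DDP gather weight
# to match how the loss was normalized.
import torch

def gather_weight(text_lengths: torch.Tensor,
                  predictor_bias: int,
                  length_normalized_loss: bool) -> torch.Tensor:
    # Batch-size weight is appropriate when the loss is a per-utterance average.
    weight = torch.tensor(float(text_lengths.size(0)))
    if length_normalized_loss:
        # Token-count weight (targets plus sos/eos bias) is appropriate when the
        # loss is already averaged over tokens, as in the patched models.
        weight = (text_lengths + predictor_bias).sum().type_as(weight)
    return weight

# Hypothetical replicas holding equally sized batches of very different lengths:
w_short = gather_weight(torch.tensor([5, 7]), predictor_bias=1, length_normalized_loss=True)
w_long = gather_weight(torch.tensor([40, 38]), predictor_bias=1, length_normalized_loss=True)
print(w_short.item(), w_long.item())  # 14.0 80.0 -> the longer batch carries more weight

With these weights, the recombined loss equals the total per-token loss divided by the total token count across replicas, which is what length_normalized_loss is meant to guarantee.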