mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update with main (#1158)
* v0.8.7
* update cmd version
* set openfst HAVE_BIN/HAVE_SCRIPT off for win32
* Fix hotword handling to support the new model version (#1137)
* update CMakeLists.txt
* Revert "update CMakeLists.txt"
This reverts commit 54bcd1f674.
* rm log.h for wins-websocket
* fix bug of websocket lock blocking
* update funasr-wss-server
* update model-revision by model name
* update funasr-wss-server-2pass
* Add support for the ERes2Net model in speaker-attributed speech recognition.
* Update README.md (#1140)
minor fix
* automatically configure parameters such as decoder-thread-num
* update docs
* update docs
* update docs
* Support more models for speaker-attributed speech recognition
* update spk inference
* remove never-used code (#1151)
* fix loss normalization for ddp training (see the sketch after the commit metadata below)
---------
Co-authored-by: 雾聪 <wucong.lyb@alibaba-inc.com>
Co-authored-by: 夜雨飘零 <yeyupiaoling@foxmail.com>
Co-authored-by: Ikko Eltociear Ashimine <eltociear@gmail.com>
Co-authored-by: Shi Xian <40013335+R1ckShi@users.noreply.github.com>
Co-authored-by: shixian.shi <shixian.shi@alibaba-inc.com>
Co-authored-by: haoneng.lhn <haoneng.lhn@alibaba-inc.com>
This commit is contained in:
parent fc246ab820
commit 5cf512419c
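
The substantive change, repeated across every model class in the hunks below, is the DDP loss-normalization fix: each model now stores length_normalized_loss, and when it is set, the gather weight passed to force_gatherable becomes the total token count instead of the sentence count. A minimal sketch of the pattern, assuming the force_gatherable body shown here (only its contract is stated in the diff's comment) and a hypothetical finish_forward wrapper; only the tail of forward() mirrors the diff:

    import torch

    def force_gatherable(data, device):
        # Illustrative stand-in for the helper named in the diff's comment:
        # move tensors to `device` and wrap scalars as tensors so
        # DataParallel/DistributedDataParallel can gather them across replicas.
        if isinstance(data, (tuple, list)):
            return type(data)(force_gatherable(v, device) for v in data)
        if isinstance(data, dict):
            return {k: force_gatherable(v, device) for k, v in data.items()}
        if data is None:
            return None
        if isinstance(data, torch.Tensor):
            return data.to(device)
        return torch.tensor(float(data), device=device)

    def finish_forward(loss, stats, text_lengths, batch_size, length_normalized_loss, bias=1):
        # Hypothetical tail of forward(), mirroring the pattern this commit adds.
        # `bias` plays the role of predictor_bias (+1 for the appended sos/eos);
        # `batch_size` is assumed to already be a tensor.
        stats["loss"] = torch.clone(loss.detach())
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if length_normalized_loss:
            # A length-normalized loss is a per-token mean, so the correct
            # gather weight is the token count, not the sentence count.
            batch_size = (text_lengths + bias).sum().type_as(batch_size)
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
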
@@ -122,6 +122,7 @@ class ASRModel(FunASRModel):
         self.ctc = ctc

         self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
+        self.length_normalized_loss = length_normalized_loss

     def forward(
         self,
@@ -220,6 +221,8 @@ class ASRModel(FunASRModel):
         stats["loss"] = torch.clone(loss.detach())

         # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        if self.length_normalized_loss:
+            batch_size = (text_lengths + 1).sum().type_as(batch_size)
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight

@@ -125,6 +125,7 @@ class NeatContextualParaformer(Paraformer):
         if self.crit_attn_weight > 0:
             self.attn_loss = torch.nn.L1Loss()
         self.crit_attn_smooth = crit_attn_smooth
+        self.length_normalized_loss = length_normalized_loss

     def forward(
         self,
@@ -231,6 +232,8 @@ class NeatContextualParaformer(Paraformer):

         stats["loss"] = torch.clone(loss.detach())
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        if self.length_normalized_loss:
+            batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size)
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight

@@ -137,6 +137,7 @@ class Paraformer(FunASRModel):
         self.predictor_bias = predictor_bias
         self.sampling_ratio = sampling_ratio
         self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
+        self.length_normalized_loss = length_normalized_loss
         self.step_cur = 0

         self.share_embedding = share_embedding
@@ -253,6 +254,8 @@ class Paraformer(FunASRModel):
         stats["loss"] = torch.clone(loss.detach())

         # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        if self.length_normalized_loss:
+            batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size)
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight

@@ -352,8 +355,9 @@ class Paraformer(FunASRModel):

         encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
             encoder_out.device)
-        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(encoder_out, None, encoder_out_mask,
-                                                                                       ignore_id=self.ignore_id)
+        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(encoder_out, None,
+                                                                                       encoder_out_mask,
+                                                                                       ignore_id=self.ignore_id)
         return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index

     def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
@@ -487,8 +491,9 @@ class Paraformer(FunASRModel):
            if self.step_cur < 2:
                logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
            if self.use_1st_decoder_loss:
-                sematic_embeds, decoder_out_1st, pre_loss_att = self.sampler_with_grad(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
-                                                                                       pre_acoustic_embeds)
+                sematic_embeds, decoder_out_1st, pre_loss_att = self.sampler_with_grad(encoder_out, encoder_out_lens,
+                                                                                       ys_pad, ys_pad_lens,
+                                                                                       pre_acoustic_embeds)
            else:
                sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
                                                               pre_acoustic_embeds)
@@ -727,6 +732,7 @@ class ParaformerOnline(Paraformer):
         self.predictor = predictor
         self.predictor_weight = predictor_weight
         self.predictor_bias = predictor_bias
+        self.length_normalized_loss = length_normalized_loss
         self.sampling_ratio = sampling_ratio
         self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
         self.step_cur = 0
@@ -860,11 +866,13 @@ class ParaformerOnline(Paraformer):
         stats["loss"] = torch.clone(loss.detach())

         # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        if self.length_normalized_loss:
+            batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size)
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight

     def encode(
-        self, speech: torch.Tensor, speech_lengths: torch.Tensor, ind: int = 0,
+            self, speech: torch.Tensor, speech_lengths: torch.Tensor, ind: int = 0,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Frontend + Encoder. Note that this method is used by asr_inference.py
         Args:
@@ -885,7 +893,7 @@ class ParaformerOnline(Paraformer):
         # Pre-encoder, e.g. used for raw input data
         if self.preencoder is not None:
             feats, feats_lengths = self.preencoder(feats, feats_lengths)
-
+
         # 4. Forward encoder
         # feats: (Batch, Length, Dim)
         # -> encoder_out: (Batch, Length2, Dim2)
@@ -970,11 +978,11 @@ class ParaformerOnline(Paraformer):
         return encoder_out, torch.tensor([encoder_out.size(1)])

     def _calc_att_predictor_loss(
-        self,
-        encoder_out: torch.Tensor,
-        encoder_out_lens: torch.Tensor,
-        ys_pad: torch.Tensor,
-        ys_pad_lens: torch.Tensor,
+            self,
+            encoder_out: torch.Tensor,
+            encoder_out_lens: torch.Tensor,
+            ys_pad: torch.Tensor,
+            ys_pad_lens: torch.Tensor,
     ):
         encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
             encoder_out.device)
@@ -1006,7 +1014,7 @@ class ParaformerOnline(Paraformer):
             attention_chunk_center_bias = 0
             attention_chunk_size = encoder_chunk_size
             decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
-            mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls.\
+            mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls. \
                 get_mask_shift_att_chunk_decoder(None,
                                                  device=encoder_out.device,
                                                  batch_size=encoder_out.size(0)
@@ -1106,7 +1114,8 @@ class ParaformerOnline(Paraformer):
                                          input_mask_expand_dim, 0)
         return sematic_embeds * tgt_mask, decoder_out * tgt_mask

-    def sampler_with_grad(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, chunk_mask=None):
+    def sampler_with_grad(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds,
+                          chunk_mask=None):
         tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
         ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
         if self.share_embedding:
@@ -1158,7 +1167,7 @@ class ParaformerOnline(Paraformer):
                                                      target_label_length=None,
                                                      )
        predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(pre_alphas,
-                                                                                             encoder_out_lens+1 if self.predictor.tail_threshold > 0.0 else encoder_out_lens)
+                                                                                             encoder_out_lens + 1 if self.predictor.tail_threshold > 0.0 else encoder_out_lens)

        scama_mask = None
        if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == 'chunk':
@@ -1166,7 +1175,7 @@ class ParaformerOnline(Paraformer):
             attention_chunk_center_bias = 0
             attention_chunk_size = encoder_chunk_size
             decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
-            mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls.\
+            mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls. \
                 get_mask_shift_att_chunk_decoder(None,
                                                  device=encoder_out.device,
                                                  batch_size=encoder_out.size(0)
@@ -1484,6 +1493,8 @@ class ParaformerBert(Paraformer):
         stats["loss"] = torch.clone(loss.detach())

         # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        if self.length_normalized_loss:
+            batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size)
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight

@@ -1589,8 +1600,9 @@ class BiCifParaformer(Paraformer):
         if self.predictor_bias == 1:
             _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
             ys_pad_lens = ys_pad_lens + self.predictor_bias
-        pre_acoustic_embeds, pre_token_length, _, pre_peak_index, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask,
-                                                                                     ignore_id=self.ignore_id)
+        pre_acoustic_embeds, pre_token_length, _, pre_peak_index, _ = self.predictor(encoder_out, ys_pad,
+                                                                                     encoder_out_mask,
+                                                                                     ignore_id=self.ignore_id)

         # 0. sampler
         decoder_out_1st = None
@@ -1739,7 +1751,7 @@ class BiCifParaformer(Paraformer):
             loss = loss_ctc
         else:
             loss = self.ctc_weight * loss_ctc + (
-                1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5
+                    1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5

         # Collect Attn branch stats
         stats["loss_att"] = loss_att.detach() if loss_att is not None else None
@@ -1752,6 +1764,8 @@ class BiCifParaformer(Paraformer):
         stats["loss"] = torch.clone(loss.detach())

         # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        if self.length_normalized_loss:
+            batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size)
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight

@@ -1952,6 +1966,8 @@ class ContextualParaformer(Paraformer):
         stats["loss"] = torch.clone(loss.detach())

         # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        if self.length_normalized_loss:
+            batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size)
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight

@@ -2107,7 +2123,8 @@ class ContextualParaformer(Paraformer):

         return loss_att, acc_att, cer_att, wer_att, loss_pre

-    def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None, clas_scale=1.0):
+    def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None,
+                                   clas_scale=1.0):
         if hw_list is None:
             # default hotword list
             hw_list = [torch.Tensor([self.sos]).long().to(encoder_out.device)]  # empty hotword list
@@ -2245,4 +2262,4 @@ class ContextualParaformer(Paraformer):
                             "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
                                                                                           var_dict_tf[name_tf].shape))

-        return var_dict_torch_update
+        return var_dict_torch_update
@@ -167,6 +167,7 @@ class UniASR(FunASRModel):
         self.enable_maas_finetune = enable_maas_finetune
         self.freeze_encoder2 = freeze_encoder2
         self.encoder1_encoder2_joint_training = encoder1_encoder2_joint_training
+        self.length_normalized_loss = length_normalized_loss

     def forward(
         self,
@@ -440,6 +441,8 @@ class UniASR(FunASRModel):
         stats["loss2"] = torch.clone(loss2.detach())
         stats["loss"] = torch.clone(loss.detach())
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        if self.length_normalized_loss:
+            batch_size = (text_lengths + 1).sum().type_as(batch_size)
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight
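
For context on why the gather weight matters: a data-parallel trainer averages the replicas' losses weighted by the returned weight, so switching the weight from sentence count to token count keeps that average unbiased when each replica's loss is already a per-token mean. A hedged sketch of such a reduction (not FunASR's actual trainer code; function name is illustrative):

    import torch

    def weighted_loss_reduce(losses, weights):
        # Weighted mean over replicas: with per-token-mean losses and
        # token-count weights, this equals the global per-token loss
        # across all GPUs.
        losses = torch.stack(losses)
        weights = torch.stack(weights).to(losses.dtype)
        return (losses * weights).sum() / weights.sum()

    # Two replicas with per-token losses 2.0 and 4.0 over 10 and 30 tokens:
    # the unbiased global mean is (2*10 + 4*30) / 40 = 3.5, not (2+4)/2 = 3.0.
    l = [torch.tensor(2.0), torch.tensor(4.0)]
    w = [torch.tensor(10), torch.tensor(30)]
    print(weighted_loss_reduce(l, w))  # tensor(3.5000)
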