add paraformer online infer and finetune

This commit is contained in:
haoneng.lhn 2023-05-26 11:43:27 +08:00
parent f630892863
commit 84b4a01979
3 changed files with 99 additions and 7 deletions

View File

@ -1618,6 +1618,8 @@ def inference_launch(**kwargs):
return inference_uniasr(**kwargs)
elif mode == "paraformer":
return inference_paraformer(**kwargs)
elif mode == "paraformer_online":
return inference_paraformer(**kwargs)
elif mode == "paraformer_streaming":
return inference_paraformer_online(**kwargs)
elif mode.startswith("paraformer_vad"):
@ -1900,4 +1902,4 @@ def main(cmd=None):
if __name__ == "__main__":
main()
main()

View File

@ -956,14 +956,14 @@ class ParaformerSANMDecoder(BaseTransformerDecoder):
"""
tgt = ys_in_pad
tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
memory = hs_pad
memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
if chunk_mask is not None:
memory_mask = memory_mask * chunk_mask
if tgt_mask.size(1) != memory_mask.size(1):
memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)
memory = hs_pad
memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
x = tgt
x, tgt_mask, memory, memory_mask, _ = self.decoders(
x, tgt_mask, memory, memory_mask

View File

@ -279,7 +279,7 @@ class Paraformer(FunASRModel):
def encode(
self, speech: torch.Tensor, speech_lengths: torch.Tensor, ind: int = 0,
) -> Tuple[Tuple[Any, Optional[Any]], Any]:
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
@ -649,7 +649,35 @@ class ParaformerOnline(Paraformer):
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert 0.0 <= interctc_weight < 1.0, interctc_weight
super().__init__()
super().__init__(
vocab_size=vocab_size,
token_list=token_list,
frontend=frontend,
specaug=specaug,
normalize=normalize,
preencoder=preencoder,
encoder=encoder,
postencoder=postencoder,
decoder=decoder,
ctc=ctc,
ctc_weight=ctc_weight,
interctc_weight=interctc_weight,
ignore_id=ignore_id,
blank_id=blank_id,
sos=sos,
eos=eos,
lsm_weight=lsm_weight,
length_normalized_loss=length_normalized_loss,
report_cer=report_cer,
report_wer=report_wer,
sym_space=sym_space,
sym_blank=sym_blank,
extract_feats_in_collect_stats=extract_feats_in_collect_stats,
predictor=predictor,
predictor_weight=predictor_weight,
predictor_bias=predictor_bias,
sampling_ratio=sampling_ratio,
)
# note that eos is the same as sos (equivalent ID)
self.blank_id = blank_id
self.sos = vocab_size - 1 if sos is None else sos
@ -705,6 +733,7 @@ class ParaformerOnline(Paraformer):
self.sampling_ratio = sampling_ratio
self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
self.step_cur = 0
self.scama_mask = None
if hasattr(self.encoder, "overlap_chunk_cls") and self.encoder.overlap_chunk_cls is not None:
from funasr.modules.streaming_utils.chunk_utilis import build_scama_mask_for_cross_attention_decoder
self.build_scama_mask_for_cross_attention_decoder_fn = build_scama_mask_for_cross_attention_decoder
@ -859,7 +888,7 @@ class ParaformerOnline(Paraformer):
# Pre-encoder, e.g. used for raw input data
if self.preencoder is not None:
feats, feats_lengths = self.preencoder(feats, feats_lengths)
# 4. Forward encoder
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
@ -1111,12 +1140,73 @@ class ParaformerOnline(Paraformer):
return sematic_embeds * tgt_mask, decoder_out * tgt_mask, pre_loss_att
def calc_predictor(self, encoder_out, encoder_out_lens):
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
encoder_out.device)
mask_chunk_predictor = None
if self.encoder.overlap_chunk_cls is not None:
mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(None,
device=encoder_out.device,
batch_size=encoder_out.size(
0))
mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(None, device=encoder_out.device,
batch_size=encoder_out.size(0))
encoder_out = encoder_out * mask_shfit_chunk
pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index = self.predictor(encoder_out,
None,
encoder_out_mask,
ignore_id=self.ignore_id,
mask_chunk_predictor=mask_chunk_predictor,
target_label_length=None,
)
predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(pre_alphas[:, :-1],
encoder_out_lens)
scama_mask = None
if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == 'chunk':
encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
attention_chunk_center_bias = 0
attention_chunk_size = encoder_chunk_size
decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls.\
get_mask_shift_att_chunk_decoder(None,
device=encoder_out.device,
batch_size=encoder_out.size(0)
)
scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
predictor_alignments=predictor_alignments,
encoder_sequence_length=encoder_out_lens,
chunk_size=1,
encoder_chunk_size=encoder_chunk_size,
attention_chunk_center_bias=attention_chunk_center_bias,
attention_chunk_size=attention_chunk_size,
attention_chunk_type=self.decoder_attention_chunk_type,
step=None,
predictor_mask_chunk_hopping=mask_chunk_predictor,
decoder_att_look_back_factor=decoder_att_look_back_factor,
mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
target_length=None,
is_training=self.training,
)
self.scama_mask = scama_mask
return pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index
def calc_predictor_chunk(self, encoder_out, cache=None):
pre_acoustic_embeds, pre_token_length = \
self.predictor.forward_chunk(encoder_out, cache["encoder"])
return pre_acoustic_embeds, pre_token_length
def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
decoder_outs = self.decoder(
encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, self.scama_mask
)
decoder_out = decoder_outs[0]
decoder_out = torch.log_softmax(decoder_out, dim=-1)
return decoder_out, ys_pad_lens
def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None):
decoder_outs = self.decoder.forward_chunk(
encoder_out, sematic_embeds, cache["decoder"]