mirror of https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00

commit 84b4a01979 (parent f630892863)

    add paraformer online infer and finetune
@@ -1618,6 +1618,8 @@ def inference_launch(**kwargs):
         return inference_uniasr(**kwargs)
     elif mode == "paraformer":
         return inference_paraformer(**kwargs)
+    elif mode == "paraformer_online":
+        return inference_paraformer(**kwargs)
     elif mode == "paraformer_streaming":
         return inference_paraformer_online(**kwargs)
     elif mode.startswith("paraformer_vad"):
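For orientation, here is a minimal, self-contained sketch of the dispatch pattern this hunk extends. The stand-in functions below are placeholders rather than FunASR's real entry points; only the mode strings and the routing come from the diff:

def inference_paraformer(**kwargs):
    # stand-in for FunASR's offline Paraformer inference builder
    return ("paraformer pipeline", kwargs)

def inference_paraformer_online(**kwargs):
    # stand-in for the streaming inference builder added by this commit
    return ("paraformer streaming pipeline", kwargs)

def inference_launch(mode, **kwargs):
    # "paraformer_online" reuses the offline entry point, while
    # "paraformer_streaming" routes to the dedicated streaming one
    if mode in ("paraformer", "paraformer_online"):
        return inference_paraformer(**kwargs)
    elif mode == "paraformer_streaming":
        return inference_paraformer_online(**kwargs)
    raise ValueError(f"unknown mode: {mode}")

print(inference_launch("paraformer_streaming", chunk_size=10))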
@@ -1900,4 +1902,4 @@ def main(cmd=None):


 if __name__ == "__main__":
     main()
@@ -956,14 +956,14 @@ class ParaformerSANMDecoder(BaseTransformerDecoder):
         """
         tgt = ys_in_pad
         tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]

-        memory = hs_pad
-        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
+        memory = hs_pad
+        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
+        if chunk_mask is not None:
+            memory_mask = memory_mask * chunk_mask
+            if tgt_mask.size(1) != memory_mask.size(1):
+                memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)

         x = tgt
         x, tgt_mask, memory, memory_mask, _ = self.decoders(
             x, tgt_mask, memory, memory_mask
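To see how the mask shapes in this hunk combine, here is a runnable sketch with a simplified stand-in for myutils.sequence_mask; the tensor sizes are illustrative, not taken from FunASR:

import torch

def sequence_mask(lengths, maxlen=None, device=None):
    # simplified stand-in for funasr's myutils.sequence_mask:
    # 1.0 on valid positions, 0.0 on padding
    maxlen = int(lengths.max()) if maxlen is None else maxlen
    return (torch.arange(maxlen, device=device)[None, :] < lengths[:, None]).float()

ys_in_lens = torch.tensor([3, 2])  # decoder-side token lengths
hlens = torch.tensor([6, 4])       # encoder-side frame lengths

tgt_mask = sequence_mask(ys_in_lens)[:, :, None]  # (B, T_out, 1)
memory_mask = sequence_mask(hlens)[:, None, :]    # (B, 1, T_in)

# A chunk_mask of shape (B, T_out, T_in) broadcasts against memory_mask, so
# each target step attends only to the encoder positions it is allowed to see.
chunk_mask = torch.ones(2, 3, 6)
memory_mask = memory_mask * chunk_mask            # -> (B, T_out, T_in)

# The diff pads the mask by repeating its second-to-last row whenever the
# target axis ends up one step shorter than tgt_mask:
if tgt_mask.size(1) != memory_mask.size(1):
    memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)

print(tgt_mask.shape, memory_mask.shape)  # torch.Size([2, 3, 1]) torch.Size([2, 3, 6])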
@@ -279,7 +279,7 @@ class Paraformer(FunASRModel):

     def encode(
             self, speech: torch.Tensor, speech_lengths: torch.Tensor, ind: int = 0,
-    ) -> Tuple[Tuple[Any, Optional[Any]], Any]:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Frontend + Encoder. Note that this method is used by asr_inference.py
         Args:
             speech: (Batch, Length, ...)
@@ -649,7 +649,35 @@ class ParaformerOnline(Paraformer):
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight
         assert 0.0 <= interctc_weight < 1.0, interctc_weight

-        super().__init__()
+        super().__init__(
+            vocab_size=vocab_size,
+            token_list=token_list,
+            frontend=frontend,
+            specaug=specaug,
+            normalize=normalize,
+            preencoder=preencoder,
+            encoder=encoder,
+            postencoder=postencoder,
+            decoder=decoder,
+            ctc=ctc,
+            ctc_weight=ctc_weight,
+            interctc_weight=interctc_weight,
+            ignore_id=ignore_id,
+            blank_id=blank_id,
+            sos=sos,
+            eos=eos,
+            lsm_weight=lsm_weight,
+            length_normalized_loss=length_normalized_loss,
+            report_cer=report_cer,
+            report_wer=report_wer,
+            sym_space=sym_space,
+            sym_blank=sym_blank,
+            extract_feats_in_collect_stats=extract_feats_in_collect_stats,
+            predictor=predictor,
+            predictor_weight=predictor_weight,
+            predictor_bias=predictor_bias,
+            sampling_ratio=sampling_ratio,
+        )
         # note that eos is the same as sos (equivalent ID)
         self.blank_id = blank_id
         self.sos = vocab_size - 1 if sos is None else sos
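The fix above replaces a bare super().__init__() with explicit forwarding of every shared hyperparameter, so the parent Paraformer initializes its own state from them. A generic, self-contained illustration of the pattern (class and argument names are simplified, not FunASR's):

class Offline:
    def __init__(self, vocab_size, sos=None, eos=None):
        self.vocab_size = vocab_size
        # parent derives defaults exactly once, from the forwarded values
        self.sos = vocab_size - 1 if sos is None else sos
        self.eos = vocab_size - 1 if eos is None else eos

class Online(Offline):
    def __init__(self, vocab_size, sos=None, eos=None, chunk_size=10):
        # forward the shared arguments instead of calling a bare super().__init__()
        super().__init__(vocab_size=vocab_size, sos=sos, eos=eos)
        self.chunk_size = chunk_size  # streaming-only state stays in the subclass

m = Online(vocab_size=4234)
print(m.sos, m.eos, m.chunk_size)  # 4233 4233 10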
@@ -705,6 +733,7 @@ class ParaformerOnline(Paraformer):
         self.sampling_ratio = sampling_ratio
         self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
         self.step_cur = 0
+        self.scama_mask = None
         if hasattr(self.encoder, "overlap_chunk_cls") and self.encoder.overlap_chunk_cls is not None:
             from funasr.modules.streaming_utils.chunk_utilis import build_scama_mask_for_cross_attention_decoder
             self.build_scama_mask_for_cross_attention_decoder_fn = build_scama_mask_for_cross_attention_decoder
@@ -859,7 +888,7 @@ class ParaformerOnline(Paraformer):
         # Pre-encoder, e.g. used for raw input data
         if self.preencoder is not None:
             feats, feats_lengths = self.preencoder(feats, feats_lengths)

         # 4. Forward encoder
         # feats: (Batch, Length, Dim)
         # -> encoder_out: (Batch, Length2, Dim2)
@@ -1111,12 +1140,73 @@ class ParaformerOnline(Paraformer):

         return sematic_embeds * tgt_mask, decoder_out * tgt_mask, pre_loss_att

     def calc_predictor(self, encoder_out, encoder_out_lens):
         encoder_out_mask = (
             ~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]
         ).to(encoder_out.device)
         mask_chunk_predictor = None
         if self.encoder.overlap_chunk_cls is not None:
             mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(
                 None, device=encoder_out.device, batch_size=encoder_out.size(0)
             )
             mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(
                 None, device=encoder_out.device, batch_size=encoder_out.size(0)
             )
             encoder_out = encoder_out * mask_shfit_chunk
         pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index = self.predictor(
             encoder_out,
             None,
             encoder_out_mask,
             ignore_id=self.ignore_id,
             mask_chunk_predictor=mask_chunk_predictor,
             target_label_length=None,
         )
         predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(
             pre_alphas[:, :-1], encoder_out_lens
         )

         scama_mask = None
         if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == 'chunk':
             encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
             attention_chunk_center_bias = 0
             attention_chunk_size = encoder_chunk_size
             decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
             mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls.get_mask_shift_att_chunk_decoder(
                 None, device=encoder_out.device, batch_size=encoder_out.size(0)
             )
             scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
                 predictor_alignments=predictor_alignments,
                 encoder_sequence_length=encoder_out_lens,
                 chunk_size=1,
                 encoder_chunk_size=encoder_chunk_size,
                 attention_chunk_center_bias=attention_chunk_center_bias,
                 attention_chunk_size=attention_chunk_size,
                 attention_chunk_type=self.decoder_attention_chunk_type,
                 step=None,
                 predictor_mask_chunk_hopping=mask_chunk_predictor,
                 decoder_att_look_back_factor=decoder_att_look_back_factor,
                 mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
                 target_length=None,
                 is_training=self.training,
             )
         self.scama_mask = scama_mask

         return pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index
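The mask assembled above follows the SCAMA idea: each token fired by the predictor may only attend to encoder frames up to the end of the chunk it fired in, plus any configured look-back. A heavily simplified, self-contained sketch of that constraint; the real build_scama_mask_for_cross_attention_decoder additionally handles hopping chunks, look-back factors, and padding:

import torch

def toy_scama_mask(fire_frames, enc_len, chunk_size):
    # fire_frames[i] = encoder frame index at which token i fires (from the
    # predictor alignments); token i sees every frame up to the end of the
    # chunk containing its fire frame, and nothing after it
    mask = torch.zeros(fire_frames.numel(), enc_len)
    for i, f in enumerate(fire_frames.tolist()):
        visible_end = min(((f // chunk_size) + 1) * chunk_size, enc_len)
        mask[i, :visible_end] = 1.0
    return mask

# three tokens firing at frames 2, 7 and 11 of a 12-frame encoding, chunk size 5
print(toy_scama_mask(torch.tensor([2, 7, 11]), enc_len=12, chunk_size=5))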
     def calc_predictor_chunk(self, encoder_out, cache=None):
         pre_acoustic_embeds, pre_token_length = self.predictor.forward_chunk(
             encoder_out, cache["encoder"]
         )
         return pre_acoustic_embeds, pre_token_length

     def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
         decoder_outs = self.decoder(
             encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, self.scama_mask
         )
         decoder_out = decoder_outs[0]
         decoder_out = torch.log_softmax(decoder_out, dim=-1)
         return decoder_out, ys_pad_lens

     def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None):
         decoder_outs = self.decoder.forward_chunk(
             encoder_out, sematic_embeds, cache["decoder"]
         )
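Taken together, the _chunk variants suggest a streaming decode loop of the following shape. Only the method names and the {"encoder": ..., "decoder": ...} cache layout come from the diff; model.encode_chunk, the return values, and the loop itself are a hypothetical sketch, not FunASR's actual inference script:

import torch

def streaming_decode(model, feature_chunks):
    # per-stream state; the two-key layout mirrors the cache["encoder"] /
    # cache["decoder"] accesses in the diff, the inner structure is assumed
    cache = {"encoder": {}, "decoder": {}}
    hyps = []
    for feats in feature_chunks:
        encoder_out = model.encode_chunk(feats, cache)  # hypothetical helper
        # the predictor decides how many tokens fire inside this chunk
        acoustic_embeds, token_length = model.calc_predictor_chunk(encoder_out, cache)
        if int(token_length) == 0:
            continue  # nothing fired yet; wait for more audio
        # the decoder consumes the fired acoustic embeddings incrementally
        decoder_out = model.cal_decoder_with_predictor_chunk(
            encoder_out, acoustic_embeds, cache
        )
        hyps.append(decoder_out.argmax(dim=-1))  # greedy pick, for illustration
    return torch.cat(hyps, dim=-1) if hyps else torch.empty(0, dtype=torch.long)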