Merge pull request #23 from alibaba-damo-academy/dev

fix uniasr inference bug
hnluo 2022-12-29 14:25:10 +08:00 committed by GitHub
commit cbce36ed3c

@@ -215,14 +215,14 @@ class Speech2Text:
         lfr_factor = max(1, (speech.size()[-1] // 80) - 1)
         # lengths: (1,)
         lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
         speech_raw = speech.clone().to(self.device)
         if self.frontend is not None:
             feats, feats_len = self.frontend.forward(speech, lengths)
             feats = to_device(feats, device=self.device)
             feats_len = feats_len.int()
         else:
-            feats = speech_raw
+            feats = speech
             feats_len = lengths
         feats_raw = feats.clone().to(self.device)
         batch = {"speech": feats, "speech_lengths": feats_len}
         # a. To device
@@ -235,7 +235,7 @@ class Speech2Text:
         if self.decoding_mode == "model1":
             predictor_outs = self.asr_model.calc_predictor_mask(enc, enc_len)
         else:
-            enc, enc_len = self.asr_model.encode2(enc, enc_len, feats, feats_len, ind=self.decoding_ind)
+            enc, enc_len = self.asr_model.encode2(enc, enc_len, feats_raw, feats_len, ind=self.decoding_ind)
             predictor_outs = self.asr_model.calc_predictor_mask2(enc, enc_len)
         scama_mask = predictor_outs[4]
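
Net effect of the change: the else branch now assigns the raw speech tensor to feats directly, and encode2 in the second ("model2") decoding path receives feats_raw, a copy cloned before the batch is built, rather than feats. Below is a minimal, self-contained sketch of that pattern, not the FunASR code itself; FakeFrontend, second_pass_encode, and infer are hypothetical stand-ins used only to illustrate why the clone is taken before the first pass touches the features.

    import torch


    class FakeFrontend:
        """Hypothetical stand-in for an ASR frontend (e.g. fbank + LFR feature extraction)."""

        def forward(self, speech: torch.Tensor, lengths: torch.Tensor):
            feats = speech * 0.5              # placeholder feature computation
            return feats, lengths.int()


    def second_pass_encode(enc: torch.Tensor, feats_raw: torch.Tensor) -> torch.Tensor:
        """Hypothetical stand-in for a second-pass encoder such as asr_model.encode2(...)."""
        return enc + feats_raw.mean()


    def infer(speech: torch.Tensor, frontend=None, device: str = "cpu") -> torch.Tensor:
        # lengths: (1,)
        lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))

        if frontend is not None:
            feats, feats_len = frontend.forward(speech, lengths)
        else:
            # As in the fixed else branch: use `speech` directly.
            feats, feats_len = speech, lengths

        # Clone before the batch is built, so whatever the first pass does to
        # `feats` cannot leak into the features handed to the second pass.
        feats_raw = feats.clone().to(device)

        feats = feats.to(device)
        enc = feats * 2.0                     # placeholder first-pass encoder output

        # The second decoding pass receives the preserved copy,
        # mirroring encode2(..., feats_raw, ...) above.
        return second_pass_encode(enc, feats_raw)


    if __name__ == "__main__":
        print(infer(torch.randn(1, 1600), frontend=FakeFrontend()).shape)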