mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Merge pull request #23 from alibaba-damo-academy/dev
fix uniasr inference bug
This commit is contained in:
commit
cbce36ed3c
@@ -215,14 +215,14 @@ class Speech2Text:
|
||||
lfr_factor = max(1, (speech.size()[-1] // 80) - 1)
|
||||
# lengths: (1,)
|
||||
lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
|
||||
speech_raw = speech.clone().to(self.device)
|
||||
if self.frontend is not None:
|
||||
feats, feats_len = self.frontend.forward(speech, lengths)
|
||||
feats = to_device(feats, device=self.device)
|
||||
feats_len = feats_len.int()
|
||||
else:
|
||||
feats = speech_raw
|
||||
feats = speech
|
||||
feats_len = lengths
|
||||
feats_raw = feats.clone().to(self.device)
|
||||
batch = {"speech": feats, "speech_lengths": feats_len}
|
||||
|
||||
# a. To device
|
||||
@@ -235,7 +235,7 @@ class Speech2Text:
|
||||
if self.decoding_mode == "model1":
|
||||
predictor_outs = self.asr_model.calc_predictor_mask(enc, enc_len)
|
||||
else:
|
||||
enc, enc_len = self.asr_model.encode2(enc, enc_len, feats, feats_len, ind=self.decoding_ind)
|
||||
enc, enc_len = self.asr_model.encode2(enc, enc_len, feats_raw, feats_len, ind=self.decoding_ind)
|
||||
predictor_outs = self.asr_model.calc_predictor_mask2(enc, enc_len)
|
||||
|
||||
scama_mask = predictor_outs[4]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user