Merge pull request #441 from alibaba-damo-academy/dev_lhn

update
This commit is contained in:
hnluo 2023-04-28 16:38:03 +08:00 committed by GitHub
commit 6013d3c4a9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 4 additions and 1 deletion

View File

@@ -205,9 +205,12 @@ class Speech2Text:
         results = []
         cache_en = cache["encoder"]
         if speech.shape[1] < 16 * 60 and cache_en["is_final"]:
+            if cache_en["start_idx"] == 0:
+                return []
             cache_en["tail_chunk"] = True
             feats = cache_en["feats"]
             feats_len = torch.tensor([feats.shape[1]])
+            self.asr_model.frontend = None
             results = self.infer(feats, feats_len, cache)
             return results
         else:

View File

@@ -380,7 +380,7 @@ class SANMEncoder(AbsEncoder):
         else:
             xs_pad = self.embed(xs_pad, cache)
         if cache["tail_chunk"]:
-            xs_pad = cache["feats"]
+            xs_pad = to_device(cache["feats"], device=xs_pad.device)
         else:
             xs_pad = self._add_overlap_chunk(xs_pad, cache)
         encoder_outs = self.encoders0(xs_pad, None, None, None, None)