mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
fix paraformer online last chunk decoding strategy
This commit is contained in:
parent
33693c4182
commit
a7814a7bc3
@ -762,23 +762,6 @@ class Speech2TextParaformerOnline:
|
||||
feats_len = speech_lengths
|
||||
|
||||
if feats.shape[1] != 0:
|
||||
if cache_en["is_final"]:
|
||||
if feats.shape[1] + cache_en["chunk_size"][2] < cache_en["chunk_size"][1]:
|
||||
cache_en["last_chunk"] = True
|
||||
else:
|
||||
# first chunk
|
||||
feats_chunk1 = feats[:, :cache_en["chunk_size"][1], :]
|
||||
feats_len = torch.tensor([feats_chunk1.shape[1]])
|
||||
results_chunk1 = self.infer(feats_chunk1, feats_len, cache)
|
||||
|
||||
# last chunk
|
||||
cache_en["last_chunk"] = True
|
||||
feats_chunk2 = feats[:, -(feats.shape[1] + cache_en["chunk_size"][2] - cache_en["chunk_size"][1]):, :]
|
||||
feats_len = torch.tensor([feats_chunk2.shape[1]])
|
||||
results_chunk2 = self.infer(feats_chunk2, feats_len, cache)
|
||||
|
||||
return [" ".join(results_chunk1 + results_chunk2)]
|
||||
|
||||
results = self.infer(feats, feats_len, cache)
|
||||
|
||||
return results
|
||||
|
||||
@ -355,18 +355,9 @@ class SANMEncoder(AbsEncoder):
|
||||
def _add_overlap_chunk(self, feats: np.ndarray, cache: dict = {}):
|
||||
if len(cache) == 0:
|
||||
return feats
|
||||
# process last chunk
|
||||
cache["feats"] = to_device(cache["feats"], device=feats.device)
|
||||
overlap_feats = torch.cat((cache["feats"], feats), dim=1)
|
||||
if cache["is_final"]:
|
||||
cache["feats"] = overlap_feats[:, -cache["chunk_size"][0]:, :]
|
||||
if not cache["last_chunk"]:
|
||||
padding_length = sum(cache["chunk_size"]) - overlap_feats.shape[1]
|
||||
overlap_feats = overlap_feats.transpose(1, 2)
|
||||
overlap_feats = F.pad(overlap_feats, (0, padding_length))
|
||||
overlap_feats = overlap_feats.transpose(1, 2)
|
||||
else:
|
||||
cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
|
||||
cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
|
||||
return overlap_feats
|
||||
|
||||
def forward_chunk(self,
|
||||
|
||||
@ -221,13 +221,14 @@ class CifPredictorV2(nn.Module):
|
||||
|
||||
if cache is not None and "chunk_size" in cache:
|
||||
alphas[:, :cache["chunk_size"][0]] = 0.0
|
||||
alphas[:, sum(cache["chunk_size"][:2]):] = 0.0
|
||||
if "is_final" in cache and not cache["is_final"]:
|
||||
alphas[:, sum(cache["chunk_size"][:2]):] = 0.0
|
||||
if cache is not None and "cif_alphas" in cache and "cif_hidden" in cache:
|
||||
cache["cif_hidden"] = to_device(cache["cif_hidden"], device=hidden.device)
|
||||
cache["cif_alphas"] = to_device(cache["cif_alphas"], device=alphas.device)
|
||||
hidden = torch.cat((cache["cif_hidden"], hidden), dim=1)
|
||||
alphas = torch.cat((cache["cif_alphas"], alphas), dim=1)
|
||||
if cache is not None and "last_chunk" in cache and cache["last_chunk"]:
|
||||
if cache is not None and "is_final" in cache and cache["is_final"]:
|
||||
tail_hidden = torch.zeros((batch_size, 1, hidden_size), device=hidden.device)
|
||||
tail_alphas = torch.tensor([[self.tail_threshold]], device=alphas.device)
|
||||
tail_alphas = torch.tile(tail_alphas, (batch_size, 1))
|
||||
|
||||
Loading…
Reference in New Issue
Block a user