mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
fix uniasr inference bug
This commit is contained in:
parent
fa61ffd3dd
commit
32650335eb
@ -105,48 +105,48 @@ class DecoderLayerSANM(nn.Module):
|
||||
|
||||
return x, tgt_mask, memory, memory_mask, cache
|
||||
|
||||
#def forward_chunk(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
|
||||
# """Compute decoded features.
|
||||
def forward_one_step(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
|
||||
"""Compute decoded features.
|
||||
|
||||
# Args:
|
||||
# tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
|
||||
# tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
|
||||
# memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
|
||||
# memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
|
||||
# cache (List[torch.Tensor]): List of cached tensors.
|
||||
# Each tensor shape should be (#batch, maxlen_out - 1, size).
|
||||
Args:
|
||||
tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
|
||||
tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
|
||||
memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
|
||||
memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
|
||||
cache (List[torch.Tensor]): List of cached tensors.
|
||||
Each tensor shape should be (#batch, maxlen_out - 1, size).
|
||||
|
||||
# Returns:
|
||||
# torch.Tensor: Output tensor(#batch, maxlen_out, size).
|
||||
# torch.Tensor: Mask for output tensor (#batch, maxlen_out).
|
||||
# torch.Tensor: Encoded memory (#batch, maxlen_in, size).
|
||||
# torch.Tensor: Encoded memory mask (#batch, maxlen_in).
|
||||
Returns:
|
||||
torch.Tensor: Output tensor(#batch, maxlen_out, size).
|
||||
torch.Tensor: Mask for output tensor (#batch, maxlen_out).
|
||||
torch.Tensor: Encoded memory (#batch, maxlen_in, size).
|
||||
torch.Tensor: Encoded memory mask (#batch, maxlen_in).
|
||||
|
||||
# """
|
||||
# # tgt = self.dropout(tgt)
|
||||
# residual = tgt
|
||||
# if self.normalize_before:
|
||||
# tgt = self.norm1(tgt)
|
||||
# tgt = self.feed_forward(tgt)
|
||||
"""
|
||||
# tgt = self.dropout(tgt)
|
||||
residual = tgt
|
||||
if self.normalize_before:
|
||||
tgt = self.norm1(tgt)
|
||||
tgt = self.feed_forward(tgt)
|
||||
|
||||
# x = tgt
|
||||
# if self.self_attn:
|
||||
# if self.normalize_before:
|
||||
# tgt = self.norm2(tgt)
|
||||
# if self.training:
|
||||
# cache = None
|
||||
# x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
|
||||
# x = residual + self.dropout(x)
|
||||
x = tgt
|
||||
if self.self_attn:
|
||||
if self.normalize_before:
|
||||
tgt = self.norm2(tgt)
|
||||
if self.training:
|
||||
cache = None
|
||||
x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
|
||||
x = residual + self.dropout(x)
|
||||
|
||||
# if self.src_attn is not None:
|
||||
# residual = x
|
||||
# if self.normalize_before:
|
||||
# x = self.norm3(x)
|
||||
if self.src_attn is not None:
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.norm3(x)
|
||||
|
||||
# x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
|
||||
x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
|
||||
|
||||
|
||||
# return x, tgt_mask, memory, memory_mask, cache
|
||||
return x, tgt_mask, memory, memory_mask, cache
|
||||
|
||||
def forward_chunk(self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0):
|
||||
"""Compute decoded features.
|
||||
@ -438,7 +438,7 @@ class FsmnDecoderSCAMAOpt(BaseTransformerDecoder):
|
||||
for i in range(self.att_layer_num):
|
||||
decoder = self.decoders[i]
|
||||
c = cache[i]
|
||||
x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
|
||||
x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
|
||||
x, tgt_mask, memory, memory_mask, cache=c
|
||||
)
|
||||
new_cache.append(c_ret)
|
||||
@ -448,13 +448,13 @@ class FsmnDecoderSCAMAOpt(BaseTransformerDecoder):
|
||||
j = i + self.att_layer_num
|
||||
decoder = self.decoders2[i]
|
||||
c = cache[j]
|
||||
x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
|
||||
x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
|
||||
x, tgt_mask, memory, memory_mask, cache=c
|
||||
)
|
||||
new_cache.append(c_ret)
|
||||
|
||||
for decoder in self.decoders3:
|
||||
x, tgt_mask, memory, memory_mask, _ = decoder.forward_chunk(
|
||||
x, tgt_mask, memory, memory_mask, _ = decoder.forward_one_step(
|
||||
x, tgt_mask, memory, None, cache=None
|
||||
)
|
||||
|
||||
@ -878,6 +878,7 @@ class ParaformerSANMDecoder(BaseTransformerDecoder):
|
||||
lora_rank: int = 8,
|
||||
lora_alpha: int = 16,
|
||||
lora_dropout: float = 0.1,
|
||||
chunk_multiply_factor: tuple = (1,),
|
||||
tf2torch_tensor_name_prefix_torch: str = "decoder",
|
||||
tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
|
||||
):
|
||||
@ -970,6 +971,7 @@ class ParaformerSANMDecoder(BaseTransformerDecoder):
|
||||
)
|
||||
self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
|
||||
self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
|
||||
self.chunk_multiply_factor = chunk_multiply_factor
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -1190,7 +1192,7 @@ class ParaformerSANMDecoder(BaseTransformerDecoder):
|
||||
for i in range(self.att_layer_num):
|
||||
decoder = self.decoders[i]
|
||||
c = cache[i]
|
||||
x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
|
||||
x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
|
||||
x, tgt_mask, memory, None, cache=c
|
||||
)
|
||||
new_cache.append(c_ret)
|
||||
@ -1200,14 +1202,14 @@ class ParaformerSANMDecoder(BaseTransformerDecoder):
|
||||
j = i + self.att_layer_num
|
||||
decoder = self.decoders2[i]
|
||||
c = cache[j]
|
||||
x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
|
||||
x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
|
||||
x, tgt_mask, memory, None, cache=c
|
||||
)
|
||||
new_cache.append(c_ret)
|
||||
|
||||
for decoder in self.decoders3:
|
||||
|
||||
x, tgt_mask, memory, memory_mask, _ = decoder.forward_chunk(
|
||||
x, tgt_mask, memory, memory_mask, _ = decoder.forward_one_step(
|
||||
x, tgt_mask, memory, None, cache=None
|
||||
)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user