fix uniasr inference bug

This commit is contained in:
haoneng.lhn 2023-09-21 16:11:48 +08:00
parent fa61ffd3dd
commit 32650335eb

View File

@ -105,48 +105,48 @@ class DecoderLayerSANM(nn.Module):
return x, tgt_mask, memory, memory_mask, cache
#def forward_chunk(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
# """Compute decoded features.
def forward_one_step(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
"""Compute decoded features.
# Args:
# tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
# tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
# memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
# memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
# cache (List[torch.Tensor]): List of cached tensors.
# Each tensor shape should be (#batch, maxlen_out - 1, size).
Args:
tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
cache (List[torch.Tensor]): List of cached tensors.
Each tensor shape should be (#batch, maxlen_out - 1, size).
# Returns:
# torch.Tensor: Output tensor(#batch, maxlen_out, size).
# torch.Tensor: Mask for output tensor (#batch, maxlen_out).
# torch.Tensor: Encoded memory (#batch, maxlen_in, size).
# torch.Tensor: Encoded memory mask (#batch, maxlen_in).
Returns:
torch.Tensor: Output tensor(#batch, maxlen_out, size).
torch.Tensor: Mask for output tensor (#batch, maxlen_out).
torch.Tensor: Encoded memory (#batch, maxlen_in, size).
torch.Tensor: Encoded memory mask (#batch, maxlen_in).
# """
# # tgt = self.dropout(tgt)
# residual = tgt
# if self.normalize_before:
# tgt = self.norm1(tgt)
# tgt = self.feed_forward(tgt)
"""
# tgt = self.dropout(tgt)
residual = tgt
if self.normalize_before:
tgt = self.norm1(tgt)
tgt = self.feed_forward(tgt)
# x = tgt
# if self.self_attn:
# if self.normalize_before:
# tgt = self.norm2(tgt)
# if self.training:
# cache = None
# x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
# x = residual + self.dropout(x)
x = tgt
if self.self_attn:
if self.normalize_before:
tgt = self.norm2(tgt)
if self.training:
cache = None
x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
x = residual + self.dropout(x)
# if self.src_attn is not None:
# residual = x
# if self.normalize_before:
# x = self.norm3(x)
if self.src_attn is not None:
residual = x
if self.normalize_before:
x = self.norm3(x)
# x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
# return x, tgt_mask, memory, memory_mask, cache
return x, tgt_mask, memory, memory_mask, cache
def forward_chunk(self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0):
"""Compute decoded features.
@ -438,7 +438,7 @@ class FsmnDecoderSCAMAOpt(BaseTransformerDecoder):
for i in range(self.att_layer_num):
decoder = self.decoders[i]
c = cache[i]
x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
x, tgt_mask, memory, memory_mask, cache=c
)
new_cache.append(c_ret)
@ -448,13 +448,13 @@ class FsmnDecoderSCAMAOpt(BaseTransformerDecoder):
j = i + self.att_layer_num
decoder = self.decoders2[i]
c = cache[j]
x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
x, tgt_mask, memory, memory_mask, cache=c
)
new_cache.append(c_ret)
for decoder in self.decoders3:
x, tgt_mask, memory, memory_mask, _ = decoder.forward_chunk(
x, tgt_mask, memory, memory_mask, _ = decoder.forward_one_step(
x, tgt_mask, memory, None, cache=None
)
@ -878,6 +878,7 @@ class ParaformerSANMDecoder(BaseTransformerDecoder):
lora_rank: int = 8,
lora_alpha: int = 16,
lora_dropout: float = 0.1,
chunk_multiply_factor: tuple = (1,),
tf2torch_tensor_name_prefix_torch: str = "decoder",
tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
):
@ -970,6 +971,7 @@ class ParaformerSANMDecoder(BaseTransformerDecoder):
)
self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
self.chunk_multiply_factor = chunk_multiply_factor
def forward(
self,
@ -1190,7 +1192,7 @@ class ParaformerSANMDecoder(BaseTransformerDecoder):
for i in range(self.att_layer_num):
decoder = self.decoders[i]
c = cache[i]
x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
x, tgt_mask, memory, None, cache=c
)
new_cache.append(c_ret)
@ -1200,14 +1202,14 @@ class ParaformerSANMDecoder(BaseTransformerDecoder):
j = i + self.att_layer_num
decoder = self.decoders2[i]
c = cache[j]
x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
x, tgt_mask, memory, None, cache=c
)
new_cache.append(c_ret)
for decoder in self.decoders3:
x, tgt_mask, memory, memory_mask, _ = decoder.forward_chunk(
x, tgt_mask, memory, memory_mask, _ = decoder.forward_one_step(
x, tgt_mask, memory, None, cache=None
)