bugfix memory leaky

This commit is contained in:
游雁 2024-09-25 15:26:14 +08:00
parent d20c030e5b
commit fc547e14e8
5 changed files with 38 additions and 32 deletions

View File

@ -78,19 +78,19 @@ class MultiHeadedAttentionReturnWeight(nn.Module):
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
min_value = torch.finfo(scores.dtype).min
scores = scores.masked_fill(mask, min_value)
self.attn = torch.softmax(scores, dim=-1).masked_fill(
attn = torch.softmax(scores, dim=-1).masked_fill(
mask, 0.0
) # (batch, head, time1, time2)
else:
self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
p_attn = self.dropout(self.attn)
p_attn = self.dropout(attn)
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
x = (
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
) # (batch, time1, d_model)
return self.linear_out(x), self.attn # (batch, time1, d_model)
return self.linear_out(x), attn # (batch, time1, d_model)
def forward(self, query, key, value, mask):
"""Compute scaled dot product attention.

View File

@ -55,8 +55,8 @@ class MultiHeadedAttentionSANMExport(nn.Module):
def forward_attention(self, value, scores, mask):
scores = scores + mask
self.attn = torch.softmax(scores, dim=-1)
context_layer = torch.matmul(self.attn, value) # (batch, head, time1, d_k)
attn = torch.softmax(scores, dim=-1)
context_layer = torch.matmul(attn, value) # (batch, head, time1, d_k)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
@ -134,8 +134,8 @@ class MultiHeadedAttentionCrossAttExport(nn.Module):
def forward_attention(self, value, scores, mask):
scores = scores + mask
self.attn = torch.softmax(scores, dim=-1)
context_layer = torch.matmul(self.attn, value) # (batch, head, time1, d_k)
attn = torch.softmax(scores, dim=-1)
context_layer = torch.matmul(attn, value) # (batch, head, time1, d_k)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
@ -177,8 +177,8 @@ class OnnxMultiHeadedAttention(nn.Module):
def forward_attention(self, value, scores, mask):
scores = scores + mask
self.attn = torch.softmax(scores, dim=-1)
context_layer = torch.matmul(self.attn, value) # (batch, head, time1, d_k)
attn = torch.softmax(scores, dim=-1)
context_layer = torch.matmul(attn, value) # (batch, head, time1, d_k)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
@ -232,8 +232,8 @@ class OnnxRelPosMultiHeadedAttention(OnnxMultiHeadedAttention):
def forward_attention(self, value, scores, mask):
scores = scores + mask
self.attn = torch.softmax(scores, dim=-1)
context_layer = torch.matmul(self.attn, value) # (batch, head, time1, d_k)
attn = torch.softmax(scores, dim=-1)
context_layer = torch.matmul(attn, value) # (batch, head, time1, d_k)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)

View File

@ -196,13 +196,13 @@ class MultiHeadedAttentionSANM(nn.Module):
"inf"
) # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
scores = scores.masked_fill(mask, min_value)
self.attn = torch.softmax(scores, dim=-1).masked_fill(
attn = torch.softmax(scores, dim=-1).masked_fill(
mask, 0.0
) # (batch, head, time1, time2)
else:
self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
p_attn = self.dropout(self.attn)
p_attn = self.dropout(attn)
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
x = (
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
@ -644,7 +644,13 @@ class SenseVoiceSmall(nn.Module):
self.embed = torch.nn.Embedding(
7 + len(self.lid_dict) + len(self.textnorm_dict), input_size
)
self.emo_dict = {"unk": 25009, "happy": 25001, "sad": 25002, "angry": 25003, "neutral": 25004}
self.emo_dict = {
"unk": 25009,
"happy": 25001,
"sad": 25002,
"angry": 25003,
"neutral": 25004,
}
self.criterion_att = LabelSmoothingLoss(
size=self.vocab_size,
@ -874,7 +880,7 @@ class SenseVoiceSmall(nn.Module):
ctc_logits = self.ctc.log_softmax(encoder_out)
if kwargs.get("ban_emo_unk", False):
ctc_logits[:, :, self.emo_dict["unk"]] = -float("inf")
results = []
b, n, d = encoder_out.size()
if isinstance(key[0], (list, tuple)):

View File

@ -84,13 +84,13 @@ class MultiHeadedAttention(nn.Module):
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
min_value = float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
scores = scores.masked_fill(mask, min_value)
self.attn = torch.softmax(scores, dim=-1).masked_fill(
attn = torch.softmax(scores, dim=-1).masked_fill(
mask, 0.0
) # (batch, head, time1, time2)
else:
self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
p_attn = self.dropout(self.attn)
p_attn = self.dropout(attn)
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
x = (
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
@ -287,13 +287,13 @@ class MultiHeadSelfAttention(nn.Module):
min_value = float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
scores = scores.masked_fill(mask, min_value)
self.attn = torch.softmax(scores, dim=-1).masked_fill(
attn = torch.softmax(scores, dim=-1).masked_fill(
mask, 0.0
) # (batch, head, time1, time2)
else:
self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
p_attn = self.dropout(self.attn)
p_attn = self.dropout(attn)
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
x = (
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)

View File

@ -87,13 +87,13 @@ class MultiHeadedAttention(nn.Module):
"inf"
) # min_value = float(np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min)
scores = scores.masked_fill(mask, min_value)
self.attn = torch.softmax(scores, dim=-1).masked_fill(
attn = torch.softmax(scores, dim=-1).masked_fill(
mask, 0.0
) # (batch, head, time1, time2)
else:
self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
p_attn = self.dropout(self.attn)
p_attn = self.dropout(attn)
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
x = (
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
@ -154,8 +154,8 @@ class MultiHeadedAttentionExport(nn.Module):
def forward_attention(self, value, scores, mask):
scores = scores + mask
self.attn = torch.softmax(scores, dim=-1)
context_layer = torch.matmul(self.attn, value) # (batch, head, time1, d_k)
attn = torch.softmax(scores, dim=-1)
context_layer = torch.matmul(attn, value) # (batch, head, time1, d_k)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
@ -209,8 +209,8 @@ class RelPosMultiHeadedAttentionExport(MultiHeadedAttentionExport):
def forward_attention(self, value, scores, mask):
scores = scores + mask
self.attn = torch.softmax(scores, dim=-1)
context_layer = torch.matmul(self.attn, value) # (batch, head, time1, d_k)
attn = torch.softmax(scores, dim=-1)
context_layer = torch.matmul(attn, value) # (batch, head, time1, d_k)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
@ -575,9 +575,9 @@ class RelPositionMultiHeadedAttentionChunk(torch.nn.Module):
if chunk_mask is not None:
mask = chunk_mask.unsqueeze(0).unsqueeze(1) | mask
scores = scores.masked_fill(mask, float("-inf"))
self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
attn_output = self.dropout(self.attn)
attn_output = self.dropout(attn)
attn_output = torch.matmul(attn_output, value)
attn_output = self.linear_out(