mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
fix bug, 1 fix cuda oom, 2 fix choose a window size 400 that is [2, 0] (#2075)
Co-authored-by: nixonjin <nixonjin@tencent.com>
This commit is contained in:
parent
e8f535f533
commit
1af68ba6ff
@ -134,7 +134,7 @@ class WavFrontend(nn.Module):
|
|||||||
mat = kaldi.fbank(
|
mat = kaldi.fbank(
|
||||||
waveform,
|
waveform,
|
||||||
num_mel_bins=self.n_mels,
|
num_mel_bins=self.n_mels,
|
||||||
frame_length=self.frame_length,
|
frame_length=min(self.frame_length,waveform_length/self.fs*1000),
|
||||||
frame_shift=self.frame_shift,
|
frame_shift=self.frame_shift,
|
||||||
dither=self.dither,
|
dither=self.dither,
|
||||||
energy_floor=0.0,
|
energy_floor=0.0,
|
||||||
|
|||||||
@ -104,13 +104,13 @@ class MultiHeadedAttention(nn.Module):
|
|||||||
"inf"
|
"inf"
|
||||||
) # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
|
) # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
|
||||||
scores = scores.masked_fill(mask, min_value)
|
scores = scores.masked_fill(mask, min_value)
|
||||||
self.attn = torch.softmax(scores, dim=-1).masked_fill(
|
attn = torch.softmax(scores, dim=-1).masked_fill(
|
||||||
mask, 0.0
|
mask, 0.0
|
||||||
) # (batch, head, time1, time2)
|
) # (batch, head, time1, time2)
|
||||||
else:
|
else:
|
||||||
self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
|
attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
|
||||||
|
|
||||||
p_attn = self.dropout(self.attn)
|
p_attn = self.dropout(attn)
|
||||||
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
|
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
|
||||||
x = (
|
x = (
|
||||||
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
|
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
|
||||||
@ -191,7 +191,7 @@ class MultiHeadedAttentionSANM(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.linear_out = nn.Linear(n_feat, n_feat)
|
self.linear_out = nn.Linear(n_feat, n_feat)
|
||||||
self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
|
self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
|
||||||
self.attn = None
|
attn = None
|
||||||
self.dropout = nn.Dropout(p=dropout_rate)
|
self.dropout = nn.Dropout(p=dropout_rate)
|
||||||
|
|
||||||
self.fsmn_block = nn.Conv1d(
|
self.fsmn_block = nn.Conv1d(
|
||||||
@ -275,13 +275,13 @@ class MultiHeadedAttentionSANM(nn.Module):
|
|||||||
"inf"
|
"inf"
|
||||||
) # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
|
) # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
|
||||||
scores = scores.masked_fill(mask, min_value)
|
scores = scores.masked_fill(mask, min_value)
|
||||||
self.attn = torch.softmax(scores, dim=-1).masked_fill(
|
attn = torch.softmax(scores, dim=-1).masked_fill(
|
||||||
mask, 0.0
|
mask, 0.0
|
||||||
) # (batch, head, time1, time2)
|
) # (batch, head, time1, time2)
|
||||||
else:
|
else:
|
||||||
self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
|
attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
|
||||||
|
|
||||||
p_attn = self.dropout(self.attn)
|
p_attn = self.dropout(attn)
|
||||||
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
|
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
|
||||||
x = (
|
x = (
|
||||||
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
|
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
|
||||||
@ -400,8 +400,8 @@ class MultiHeadedAttentionSANMExport(nn.Module):
|
|||||||
def forward_attention(self, value, scores, mask):
|
def forward_attention(self, value, scores, mask):
|
||||||
scores = scores + mask
|
scores = scores + mask
|
||||||
|
|
||||||
self.attn = torch.softmax(scores, dim=-1)
|
attn = torch.softmax(scores, dim=-1)
|
||||||
context_layer = torch.matmul(self.attn, value) # (batch, head, time1, d_k)
|
context_layer = torch.matmul(attn, value) # (batch, head, time1, d_k)
|
||||||
|
|
||||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||||
@ -459,8 +459,8 @@ class MultiHeadedAttentionSANMExport(nn.Module):
|
|||||||
def forward_attention(self, value, scores, mask):
|
def forward_attention(self, value, scores, mask):
|
||||||
scores = scores + mask
|
scores = scores + mask
|
||||||
|
|
||||||
self.attn = torch.softmax(scores, dim=-1)
|
attn = torch.softmax(scores, dim=-1)
|
||||||
context_layer = torch.matmul(self.attn, value) # (batch, head, time1, d_k)
|
context_layer = torch.matmul(attn, value) # (batch, head, time1, d_k)
|
||||||
|
|
||||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||||
@ -683,18 +683,18 @@ class MultiHeadedAttentionCrossAtt(nn.Module):
|
|||||||
# logging.info(
|
# logging.info(
|
||||||
# "scores: {}, mask_size: {}".format(scores.size(), mask.size()))
|
# "scores: {}, mask_size: {}".format(scores.size(), mask.size()))
|
||||||
scores = scores.masked_fill(mask, min_value)
|
scores = scores.masked_fill(mask, min_value)
|
||||||
self.attn = torch.softmax(scores, dim=-1).masked_fill(
|
attn = torch.softmax(scores, dim=-1).masked_fill(
|
||||||
mask, 0.0
|
mask, 0.0
|
||||||
) # (batch, head, time1, time2)
|
) # (batch, head, time1, time2)
|
||||||
else:
|
else:
|
||||||
self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
|
attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
|
||||||
p_attn = self.dropout(self.attn)
|
p_attn = self.dropout(attn)
|
||||||
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
|
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
|
||||||
x = (
|
x = (
|
||||||
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
|
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
|
||||||
) # (batch, time1, d_model)
|
) # (batch, time1, d_model)
|
||||||
if ret_attn:
|
if ret_attn:
|
||||||
return self.linear_out(x), self.attn # (batch, time1, d_model)
|
return self.linear_out(x), attn # (batch, time1, d_model)
|
||||||
return self.linear_out(x) # (batch, time1, d_model)
|
return self.linear_out(x) # (batch, time1, d_model)
|
||||||
|
|
||||||
def forward(self, x, memory, memory_mask, ret_attn=False):
|
def forward(self, x, memory, memory_mask, ret_attn=False):
|
||||||
@ -782,14 +782,14 @@ class MultiHeadedAttentionCrossAttExport(nn.Module):
|
|||||||
def forward_attention(self, value, scores, mask, ret_attn):
|
def forward_attention(self, value, scores, mask, ret_attn):
|
||||||
scores = scores + mask.to(scores.device)
|
scores = scores + mask.to(scores.device)
|
||||||
|
|
||||||
self.attn = torch.softmax(scores, dim=-1)
|
attn = torch.softmax(scores, dim=-1)
|
||||||
context_layer = torch.matmul(self.attn, value) # (batch, head, time1, d_k)
|
context_layer = torch.matmul(attn, value) # (batch, head, time1, d_k)
|
||||||
|
|
||||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||||
context_layer = context_layer.view(new_context_layer_shape)
|
context_layer = context_layer.view(new_context_layer_shape)
|
||||||
if ret_attn:
|
if ret_attn:
|
||||||
return self.linear_out(context_layer), self.attn
|
return self.linear_out(context_layer), attn
|
||||||
return self.linear_out(context_layer) # (batch, time1, d_model)
|
return self.linear_out(context_layer) # (batch, time1, d_model)
|
||||||
|
|
||||||
|
|
||||||
@ -868,13 +868,13 @@ class MultiHeadSelfAttention(nn.Module):
|
|||||||
"inf"
|
"inf"
|
||||||
) # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
|
) # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
|
||||||
scores = scores.masked_fill(mask, min_value)
|
scores = scores.masked_fill(mask, min_value)
|
||||||
self.attn = torch.softmax(scores, dim=-1).masked_fill(
|
attn = torch.softmax(scores, dim=-1).masked_fill(
|
||||||
mask, 0.0
|
mask, 0.0
|
||||||
) # (batch, head, time1, time2)
|
) # (batch, head, time1, time2)
|
||||||
else:
|
else:
|
||||||
self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
|
attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
|
||||||
|
|
||||||
p_attn = self.dropout(self.attn)
|
p_attn = self.dropout(attn)
|
||||||
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
|
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
|
||||||
x = (
|
x = (
|
||||||
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
|
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user