auto frontend

This commit is contained in:
游雁 2024-06-05 16:54:33 +08:00
parent 545d69ae92
commit a6441441cb
2 changed files with 4 additions and 4 deletions

View File

@ -1,5 +1,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.register import tables
@ -119,9 +121,8 @@ class Transformer(nn.Module):
x = self.linear2(x)
olens = None
if ilens is not None:
olens = (ilens - 1) // self.k + 1
mask = (~make_pad_mask(olens)[:, None, :]).to(x.device)
olens = (ilens - 1) // self.k + 1
masks = (~make_pad_mask(olens)[:, None, :]).to(x.device)
for layer, block in enumerate(self.blocks):
x, masks = block(x, masks)
return x, olens

View File

@ -621,7 +621,6 @@ class Trainer:
self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
def forward_step(self, model, batch, loss_dict={}):
dtype = torch.bfloat16
with maybe_autocast(dtype=self.dtype, use_deepspeed=self.use_deepspeed):
retval = model(**batch)