mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
auto frontend
This commit is contained in:
parent
545d69ae92
commit
a6441441cb
@ -1,5 +1,7 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from funasr.models.transformer.utils.nets_utils import make_pad_mask
|
||||
|
||||
from funasr.register import tables
|
||||
|
||||
@ -119,9 +121,8 @@ class Transformer(nn.Module):
|
||||
x = self.linear2(x)
|
||||
|
||||
olens = None
|
||||
if ilens is not None:
|
||||
olens = (ilens - 1) // self.k + 1
|
||||
mask = (~make_pad_mask(olens)[:, None, :]).to(x.device)
|
||||
olens = (ilens - 1) // self.k + 1
|
||||
masks = (~make_pad_mask(olens)[:, None, :]).to(x.device)
|
||||
for layer, block in enumerate(self.blocks):
|
||||
x, masks = block(x, masks)
|
||||
return x, olens
|
||||
|
||||
@ -621,7 +621,6 @@ class Trainer:
|
||||
self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
|
||||
|
||||
def forward_step(self, model, batch, loss_dict={}):
|
||||
dtype = torch.bfloat16
|
||||
with maybe_autocast(dtype=self.dtype, use_deepspeed=self.use_deepspeed):
|
||||
retval = model(**batch)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user