mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
auto frontend
This commit is contained in:
parent
545d69ae92
commit
a6441441cb
@ -1,5 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from funasr.models.transformer.utils.nets_utils import make_pad_mask
|
||||||
|
|
||||||
from funasr.register import tables
|
from funasr.register import tables
|
||||||
|
|
||||||
@ -119,9 +121,8 @@ class Transformer(nn.Module):
|
|||||||
x = self.linear2(x)
|
x = self.linear2(x)
|
||||||
|
|
||||||
olens = None
|
olens = None
|
||||||
if ilens is not None:
|
|
||||||
olens = (ilens - 1) // self.k + 1
|
olens = (ilens - 1) // self.k + 1
|
||||||
mask = (~make_pad_mask(olens)[:, None, :]).to(x.device)
|
masks = (~make_pad_mask(olens)[:, None, :]).to(x.device)
|
||||||
for layer, block in enumerate(self.blocks):
|
for layer, block in enumerate(self.blocks):
|
||||||
x, masks = block(x, masks)
|
x, masks = block(x, masks)
|
||||||
return x, olens
|
return x, olens
|
||||||
|
|||||||
@ -621,7 +621,6 @@ class Trainer:
|
|||||||
self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
|
self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
|
||||||
|
|
||||||
def forward_step(self, model, batch, loss_dict={}):
|
def forward_step(self, model, batch, loss_dict={}):
|
||||||
dtype = torch.bfloat16
|
|
||||||
with maybe_autocast(dtype=self.dtype, use_deepspeed=self.use_deepspeed):
|
with maybe_autocast(dtype=self.dtype, use_deepspeed=self.use_deepspeed):
|
||||||
retval = model(**batch)
|
retval = model(**batch)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user