mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
auto frontend
This commit is contained in:
parent
a6441441cb
commit
79b09f1d67
@ -27,9 +27,24 @@ class ModelDimensions:
|
|||||||
n_text_layer: int
|
n_text_layer: int
|
||||||
|
|
||||||
|
|
||||||
|
# class LayerNorm(nn.LayerNorm):
|
||||||
|
# def forward(self, x: Tensor) -> Tensor:
|
||||||
|
# return super().forward(x.float()).type(x.dtype)
|
||||||
|
|
||||||
|
|
||||||
class LayerNorm(nn.LayerNorm):
|
class LayerNorm(nn.LayerNorm):
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def __init__(self, *args, **kwargs):
|
||||||
return super().forward(x.float()).type(x.dtype)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
output = F.layer_norm(
|
||||||
|
input.float(),
|
||||||
|
self.normalized_shape,
|
||||||
|
self.weight.float() if self.weight is not None else None,
|
||||||
|
self.bias.float() if self.bias is not None else None,
|
||||||
|
self.eps,
|
||||||
|
)
|
||||||
|
return output.type_as(input)
|
||||||
|
|
||||||
|
|
||||||
class Linear(nn.Linear):
|
class Linear(nn.Linear):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user