auto frontend

This commit is contained in:
游雁 2024-06-05 17:22:21 +08:00
parent a6441441cb
commit 79b09f1d67

View File

@ -27,9 +27,24 @@ class ModelDimensions:
n_text_layer: int
# class LayerNorm(nn.LayerNorm):
# def forward(self, x: Tensor) -> Tensor:
# return super().forward(x.float()).type(x.dtype)
class LayerNorm(nn.LayerNorm):
def forward(self, x: Tensor) -> Tensor:
return super().forward(x.float()).type(x.dtype)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, input):
output = F.layer_norm(
input.float(),
self.normalized_shape,
self.weight.float() if self.weight is not None else None,
self.bias.float() if self.bias is not None else None,
self.eps,
)
return output.type_as(input)
class Linear(nn.Linear):