Spaces:
Running
on
Zero
Running
on
Zero
Update models/unet.py
Browse files- models/unet.py +1 -1
models/unet.py
CHANGED
@@ -25,7 +25,7 @@ class CustomLayerNorm(nn.LayerNorm):
|
|
25 |
def replace_layer_norm(model):
|
26 |
for name, module in model.named_children():
|
27 |
if isinstance(module, nn.LayerNorm):
|
28 |
-
setattr(model, name, CustomLayerNorm(module.normalized_shape, elementwise_affine=module.elementwise_affine))
|
29 |
else:
|
30 |
replace_layer_norm(module) # Recursively apply to all submodules
|
31 |
|
|
|
25 |
def replace_layer_norm(model):
|
26 |
for name, module in model.named_children():
|
27 |
if isinstance(module, nn.LayerNorm):
|
28 |
+
setattr(model, name, CustomLayerNorm(module.normalized_shape, elementwise_affine=module.elementwise_affine)).cuda()
|
29 |
else:
|
30 |
replace_layer_norm(module) # Recursively apply to all submodules
|
31 |
|