EvanTHU commited on
Commit
eee8228
·
verified ·
1 Parent(s): 34c81fb

Update models/unet.py

Browse files
Files changed (1) hide show
  1. models/unet.py +1 -1
models/unet.py CHANGED
@@ -25,7 +25,7 @@ class CustomLayerNorm(nn.LayerNorm):
25
  def replace_layer_norm(model):
26
  for name, module in model.named_children():
27
  if isinstance(module, nn.LayerNorm):
28
- setattr(model, name, CustomLayerNorm(module.normalized_shape, elementwise_affine=module.elementwise_affine)).cuda()
29
  else:
30
  replace_layer_norm(module) # Recursively apply to all submodules
31
 
 
25
  def replace_layer_norm(model):
26
  for name, module in model.named_children():
27
  if isinstance(module, nn.LayerNorm):
28
+ setattr(model, name, CustomLayerNorm(module.normalized_shape, elementwise_affine=module.elementwise_affine).cuda())
29
  else:
30
  replace_layer_norm(module) # Recursively apply to all submodules
31