benjamin-paine commited on
Commit
6cf8b24
·
verified ·
1 Parent(s): d28d8ca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -2
app.py CHANGED
@@ -23,6 +23,7 @@ from transformers import (
23
  T5Tokenizer
24
  )
25
  from accelerate import init_empty_weights
 
26
  from safetensors import safe_open
27
 
28
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -50,8 +51,14 @@ with safe_open(model_file, framework="pt") as f:
50
  state_dict[key] = f.get_tensor(key)
51
 
52
  state_dict = convert_sd3_transformer_checkpoint_to_diffusers(state_dict)
53
- transformer.load_state_dict(state_dict)
54
- transformer.to_empty(device=device)
 
 
 
 
 
 
55
 
56
  # Try to keep memory usage down
57
  del state_dict
 
23
  T5Tokenizer
24
  )
25
  from accelerate import init_empty_weights
26
+ from accelerate.utils import set_module_tensor_to_device
27
  from safetensors import safe_open
28
 
29
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
51
  state_dict[key] = f.get_tensor(key)
52
 
53
  state_dict = convert_sd3_transformer_checkpoint_to_diffusers(state_dict)
54
+ for key, value in state_dict.items():
55
+ set_module_tensor_to_device(
56
+ transformer,
57
+ key,
58
+ device,
59
+ value=value,
60
+ dtype=torch_dtype
61
+ )
62
 
63
  # Try to keep memory usage down
64
  del state_dict