Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -23,6 +23,7 @@ from transformers import (
|
|
23 |
T5Tokenizer
|
24 |
)
|
25 |
from accelerate import init_empty_weights
|
|
|
26 |
from safetensors import safe_open
|
27 |
|
28 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
@@ -50,8 +51,14 @@ with safe_open(model_file, framework="pt") as f:
|
|
50 |
state_dict[key] = f.get_tensor(key)
|
51 |
|
52 |
state_dict = convert_sd3_transformer_checkpoint_to_diffusers(state_dict)
|
53 |
-
|
54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
|
56 |
# Try to keep memory usage down
|
57 |
del state_dict
|
|
|
23 |
T5Tokenizer
|
24 |
)
|
25 |
from accelerate import init_empty_weights
|
26 |
+
from accelerate.utils import set_module_tensor_to_device
|
27 |
from safetensors import safe_open
|
28 |
|
29 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
51 |
state_dict[key] = f.get_tensor(key)
|
52 |
|
53 |
state_dict = convert_sd3_transformer_checkpoint_to_diffusers(state_dict)
|
54 |
+
for key, value in state_dict.items():
|
55 |
+
set_module_tensor_to_device(
|
56 |
+
transformer,
|
57 |
+
key,
|
58 |
+
device,
|
59 |
+
value=value,
|
60 |
+
dtype=torch_dtype
|
61 |
+
)
|
62 |
|
63 |
# Try to keep memory usage down
|
64 |
del state_dict
|