Update infrance_text_pop.py
Browse files- infrance_text_pop.py +4 -2
infrance_text_pop.py
CHANGED
@@ -128,9 +128,11 @@ if __name__ == '__main__':
|
|
128 |
import torch.nn as nn
|
129 |
|
130 |
# Initialize the model
|
131 |
-
|
132 |
device = 'mps'
|
133 |
-
model = InrenceTextVAR(
|
|
|
|
|
134 |
model.to(device)
|
135 |
|
136 |
def generate_image_gradio(text, beta=1.0, seed=None, more_smooth=False, top_k=0, top_p=0.9):
|
|
|
128 |
import torch.nn as nn
|
129 |
|
130 |
# Initialize the model
|
131 |
+
checkpoint = 'VARtext_v1.pth' # Replace with your actual checkpoint path
|
132 |
device = 'mps'
|
133 |
+
model = InrenceTextVAR(device=device)
|
134 |
+
state_dict = torch.load(checkpoint,map_location = "cpu")
|
135 |
+
model.load_state_dict(state_dict)
|
136 |
model.to(device)
|
137 |
|
138 |
def generate_image_gradio(text, beta=1.0, seed=None, more_smooth=False, top_k=0, top_p=0.9):
|