AmitIsraeli commited on
Commit
f6ef98f
·
verified ·
1 Parent(s): 84da114

Update infrance_text_pop.py

Browse files
Files changed (1) hide show
  1. infrance_text_pop.py +4 -2
infrance_text_pop.py CHANGED
@@ -128,9 +128,11 @@ if __name__ == '__main__':
128
  import torch.nn as nn
129
 
130
  # Initialize the model
131
- pl_checkpoint = '/Users/mac/Downloads/model-step-step=35000.ckpt' # Replace with your actual checkpoint path
132
  device = 'mps'
133
- model = InrenceTextVAR(pl_checkpoint=pl_checkpoint, device=device)
 
 
134
  model.to(device)
135
 
136
  def generate_image_gradio(text, beta=1.0, seed=None, more_smooth=False, top_k=0, top_p=0.9):
 
128
  import torch.nn as nn
129
 
130
  # Initialize the model
131
+ checkpoint = 'VARtext_v1.pth' # Replace with your actual checkpoint path
132
  device = 'mps'
133
+ model = InrenceTextVAR(device=device)
134
+ state_dict = torch.load(checkpoint,map_location = "cpu")
135
+ model.load_state_dict(state_dict)
136
  model.to(device)
137
 
138
  def generate_image_gradio(text, beta=1.0, seed=None, more_smooth=False, top_k=0, top_p=0.9):