ironjr commited on
Commit
e43c1c3
1 Parent(s): 7c18416

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -4
app.py CHANGED
@@ -86,8 +86,8 @@ opt = parser.parse_args()
86
 
87
  ### Global variables and data structures
88
 
89
- device = f'cuda:{opt.device}' if opt.device >= 0 else 'cpu'
90
-
91
 
92
  if opt.model is None:
93
  model_dict = {
@@ -98,9 +98,8 @@ else:
98
  opt.model = os.path.abspath(os.path.join('checkpoints', opt.model))
99
  model_dict = {os.path.splitext(os.path.basename(opt.model))[0]: opt.model}
100
 
101
- dtype = torch.float32 if device == 'cpu' else torch.float16
102
  models = {
103
- k: StableMultiDiffusion3Pipeline(device, dtype=dtype, hf_key=v, has_i2t=False)
104
  for k, v in model_dict.items()
105
  }
106
 
 
86
 
87
  ### Global variables and data structures
88
 
89
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
90
+ print(device)
91
 
92
  if opt.model is None:
93
  model_dict = {
 
98
  opt.model = os.path.abspath(os.path.join('checkpoints', opt.model))
99
  model_dict = {os.path.splitext(os.path.basename(opt.model))[0]: opt.model}
100
 
 
101
  models = {
102
+ k: StableMultiDiffusion3Pipeline(device, hf_key=v, has_i2t=False).cuda()
103
  for k, v in model_dict.items()
104
  }
105