jiuface commited on
Commit
1241278
·
verified ·
1 Parent(s): 8f770b9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -38
app.py CHANGED
@@ -56,46 +56,14 @@ class calculateDuration:
56
 
57
  @spaces.GPU(duration=120)
58
  @torch.inference_mode()
59
- def generate_image(prompt, lora_strings_json, steps, seed, cfg_scale, width, height, progress):
 
 
 
60
  with calculateDuration(f"Make a new generator:{seed}"):
61
  pipe.to(device)
62
  generator = torch.Generator(device=device).manual_seed(seed)
63
 
64
- # Load LoRA weights
65
- pipe.unload_lora_weights()
66
- lora_configs = None
67
- if lora_strings_json:
68
- try:
69
- lora_configs = json.loads(lora_strings_json)
70
- except:
71
- gr.Warning("Parse lora config json failed")
72
- print("parse lora config json failed")
73
-
74
- if lora_configs:
75
- with calculateDuration("Loading LoRA weights"):
76
- adapter_names = []
77
- adapter_weights = []
78
- for lora_info in lora_configs:
79
- lora_repo = lora_info.get("repo")
80
- weights = lora_info.get("weights")
81
- adapter_name = lora_info.get("adapter_name")
82
- adapter_weight = lora_info.get("adapter_weight")
83
-
84
- adapter_names.append(adapter_name)
85
- adapter_weights.append(adapter_weight)
86
-
87
- if lora_repo and weights and adapter_name:
88
- # load lora
89
- try:
90
- pipe.load_lora_weights(lora_repo, weight_name=weights, adapter_name=adapter_name)
91
- except ValueError as e:
92
- print(f"Error loading LoRA adapter: {e}")
93
- continue
94
-
95
- # set lora weights
96
- if len(adapter_names) > 0:
97
- pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
98
-
99
  with calculateDuration("Generating image"):
100
  # Generate image
101
  generated_image = pipe(
@@ -142,12 +110,48 @@ def run_lora(prompt, lora_strings_json, cfg_scale, steps, randomize_seed, seed,
142
  if randomize_seed:
143
  with calculateDuration("Set random seed"):
144
  seed = random.randint(0, MAX_SEED)
145
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  # Generate image
147
  error_message = ""
148
  try:
149
  print("Start applying for zeroGPU resources")
150
- final_image = generate_image(prompt, lora_strings_json, steps, seed, cfg_scale, width, height, progress)
151
  except Exception as e:
152
  error_message = str(e)
153
  gr.Error(error_message)
 
56
 
57
  @spaces.GPU(duration=120)
58
  @torch.inference_mode()
59
+ def generate_image(prompt, steps, seed, cfg_scale, width, height, progress):
60
+
61
+ gr.Info("Start to generate images ...")
62
+
63
  with calculateDuration(f"Make a new generator:{seed}"):
64
  pipe.to(device)
65
  generator = torch.Generator(device=device).manual_seed(seed)
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  with calculateDuration("Generating image"):
68
  # Generate image
69
  generated_image = pipe(
 
110
  if randomize_seed:
111
  with calculateDuration("Set random seed"):
112
  seed = random.randint(0, MAX_SEED)
113
+
114
+ # Load LoRA weights
115
+ gr.Info("Start to load loras ...")
116
+ pipe.unload_lora_weights()
117
+ lora_configs = None
118
+ if lora_strings_json:
119
+ try:
120
+ lora_configs = json.loads(lora_strings_json)
121
+ except:
122
+ gr.Warning("Parse lora config json failed")
123
+ print("parse lora config json failed")
124
+
125
+ if lora_configs:
126
+ with calculateDuration("Loading LoRA weights"):
127
+ adapter_names = []
128
+ adapter_weights = []
129
+ for lora_info in lora_configs:
130
+ lora_repo = lora_info.get("repo")
131
+ weights = lora_info.get("weights")
132
+ adapter_name = lora_info.get("adapter_name")
133
+ adapter_weight = lora_info.get("adapter_weight")
134
+
135
+ adapter_names.append(adapter_name)
136
+ adapter_weights.append(adapter_weight)
137
+
138
+ if lora_repo and weights and adapter_name:
139
+ # load lora
140
+ try:
141
+ pipe.load_lora_weights(lora_repo, weight_name=weights, adapter_name=adapter_name)
142
+ except ValueError as e:
143
+ print(f"Error loading LoRA adapter: {e}")
144
+ continue
145
+
146
+ # set lora weights
147
+ if len(adapter_names) > 0:
148
+ pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
149
+
150
  # Generate image
151
  error_message = ""
152
  try:
153
  print("Start applying for zeroGPU resources")
154
+ final_image = generate_image(prompt, steps, seed, cfg_scale, width, height, progress)
155
  except Exception as e:
156
  error_message = str(e)
157
  gr.Error(error_message)