jiuface commited on
Commit
ed27447
·
verified ·
1 Parent(s): c5a1c7d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -42
app.py CHANGED
@@ -56,11 +56,47 @@ class calculateDuration:
56
 
57
  @spaces.GPU(duration=120)
58
  @torch.inference_mode()
59
- def generate_image(prompt, steps, seed, cfg_scale, width, height, progress):
60
- with calculateDuration(f"Make a new generator:${seed}"):
61
  pipe.to(device)
62
  generator = torch.Generator(device=device).manual_seed(seed)
63
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  with calculateDuration("Generating image"):
65
  # Generate image
66
  generated_image = pipe(
@@ -102,44 +138,6 @@ def upload_image_to_r2(image, account_id, access_key, secret_key, bucket_name):
102
  def run_lora(prompt, lora_strings_json, cfg_scale, steps, randomize_seed, seed, width, height, upload_to_r2, account_id, access_key, secret_key, bucket, progress=gr.Progress(track_tqdm=True)):
103
  print("run_lora", prompt, lora_strings_json, cfg_scale, steps, width, height)
104
  gr.Info("Starting process")
105
- # Load LoRA weights
106
- lora_configs = None
107
- if lora_strings_json:
108
- try:
109
- lora_configs = json.loads(lora_strings_json)
110
- except:
111
- gr.Warning("Parse lora config json failed")
112
- print("parse lora config json failed")
113
-
114
- if lora_configs:
115
- with calculateDuration("Loading LoRA weights"):
116
- active_adapters = pipe.get_active_adapters()
117
- print("get_active_adapters", active_adapters)
118
- adapter_names = []
119
- adapter_weights = []
120
- for lora_info in lora_configs:
121
- lora_repo = lora_info.get("repo")
122
- weights = lora_info.get("weights")
123
- adapter_name = lora_info.get("adapter_name")
124
- adapter_weight = lora_info.get("adapter_weight")
125
-
126
- adapter_names.append(adapter_name)
127
- adapter_weights.append(adapter_weight)
128
-
129
- if adapter_name in active_adapters:
130
- print(f"Adapter '{adapter_name}' is already loaded, skipping.")
131
- continue
132
- if lora_repo and weights and adapter_name:
133
- # load lora
134
- try:
135
- pipe.load_lora_weights(lora_repo, weight_name=weights, adapter_name=adapter_name)
136
- except ValueError as e:
137
- print(f"Error loading LoRA adapter: {e}")
138
- continue
139
-
140
- # set lora weights
141
- if len(adapter_names) > 0:
142
- pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
143
 
144
  # Set random seed for reproducibility
145
  if randomize_seed:
@@ -150,7 +148,7 @@ def run_lora(prompt, lora_strings_json, cfg_scale, steps, randomize_seed, seed,
150
  error_message = ""
151
  try:
152
  print("Start applying for zeroGPU resources")
153
- final_image = generate_image(prompt, steps, seed, cfg_scale, width, height, progress)
154
  except Exception as e:
155
  error_message = str(e)
156
  gr.Error(error_message)
 
56
 
57
  @spaces.GPU(duration=120)
58
  @torch.inference_mode()
59
+ def generate_image(prompt, lora_strings_json, steps, seed, cfg_scale, width, height, progress):
60
+ with calculateDuration(f"Make a new generator:{seed}"):
61
  pipe.to(device)
62
  generator = torch.Generator(device=device).manual_seed(seed)
63
+
64
+ # Load LoRA weights
65
+ pipe.unload_lora_weights()
66
+ lora_configs = None
67
+ if lora_strings_json:
68
+ try:
69
+ lora_configs = json.loads(lora_strings_json)
70
+ except:
71
+ gr.Warning("Parse lora config json failed")
72
+ print("parse lora config json failed")
73
+
74
+ if lora_configs:
75
+ with calculateDuration("Loading LoRA weights"):
76
+ print("get_active_adapters", active_adapters)
77
+ adapter_names = []
78
+ adapter_weights = []
79
+ for lora_info in lora_configs:
80
+ lora_repo = lora_info.get("repo")
81
+ weights = lora_info.get("weights")
82
+ adapter_name = lora_info.get("adapter_name")
83
+ adapter_weight = lora_info.get("adapter_weight")
84
+
85
+ adapter_names.append(adapter_name)
86
+ adapter_weights.append(adapter_weight)
87
+
88
+ if lora_repo and weights and adapter_name:
89
+ # load lora
90
+ try:
91
+ pipe.load_lora_weights(lora_repo, weight_name=weights, adapter_name=adapter_name)
92
+ except ValueError as e:
93
+ print(f"Error loading LoRA adapter: {e}")
94
+ continue
95
+
96
+ # set lora weights
97
+ if len(adapter_names) > 0:
98
+ pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
99
+
100
  with calculateDuration("Generating image"):
101
  # Generate image
102
  generated_image = pipe(
 
138
  def run_lora(prompt, lora_strings_json, cfg_scale, steps, randomize_seed, seed, width, height, upload_to_r2, account_id, access_key, secret_key, bucket, progress=gr.Progress(track_tqdm=True)):
139
  print("run_lora", prompt, lora_strings_json, cfg_scale, steps, width, height)
140
  gr.Info("Starting process")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
  # Set random seed for reproducibility
143
  if randomize_seed:
 
148
  error_message = ""
149
  try:
150
  print("Start applying for zeroGPU resources")
151
+ final_image = generate_image(prompt, lora_strings_json, steps, seed, cfg_scale, width, height, progress)
152
  except Exception as e:
153
  error_message = str(e)
154
  gr.Error(error_message)