Fabrice-TIERCELIN commited on
Commit
6f1239e
·
verified ·
1 Parent(s): fc362b8

Comment more code

Browse files
Files changed (1) hide show
  1. app.py +70 -102
app.py CHANGED
@@ -94,81 +94,79 @@ class Tango:
94
  return outputs
95
  return list(self.chunks(outputs, samples))
96
 
97
- # Initialize TANGO
98
-
99
- tango = Tango(device = "cpu")
100
- tango.vae.to(device_type)
101
- tango.stft.to(device_type)
102
- tango.model.to(device_type)
103
-
104
- def update_seed(is_randomize_seed, seed):
105
- if is_randomize_seed:
106
- return random.randint(0, max_64_bit_int)
107
- return seed
108
-
109
- def check(
110
- prompt,
111
- output_number,
112
- steps,
113
- guidance,
114
- is_randomize_seed,
115
- seed
116
- ):
117
- if prompt is None or prompt == "":
118
- raise gr.Error("Please provide a prompt input.")
119
- if not output_number in [1, 2, 3]:
120
- raise gr.Error("Please ask for 1, 2 or 3 output files.")
121
-
122
- def update_output(output_format, output_number):
123
- return [
124
- gr.update(format = output_format),
125
- gr.update(format = output_format, visible = (2 <= output_number)),
126
- gr.update(format = output_format, visible = (output_number == 3)),
127
- gr.update(visible = False)
128
- ]
129
-
130
- def text2audio(
131
- prompt,
132
- output_number,
133
- steps,
134
- guidance,
135
- is_randomize_seed,
136
- seed
137
- ):
138
- start = time.time()
139
-
140
- if seed is None:
141
- seed = random.randint(0, max_64_bit_int)
142
-
143
- random.seed(seed)
144
- torch.manual_seed(seed)
145
-
146
- output_wave = tango.generate(prompt, steps, guidance, output_number)
147
-
148
- output_wave_1 = gr.make_waveform((16000, output_wave[0]))
149
- output_wave_2 = gr.make_waveform((16000, output_wave[1])) if (2 <= output_number) else None
150
- output_wave_3 = gr.make_waveform((16000, output_wave[2])) if (output_number == 3) else None
151
-
152
- end = time.time()
153
- secondes = int(end - start)
154
- minutes = secondes // 60
155
- secondes = secondes - (minutes * 60)
156
- hours = minutes // 60
157
- minutes = minutes - (hours * 60)
158
- return [
159
- output_wave_1,
160
- output_wave_2,
161
- output_wave_3,
162
- gr.update(visible = True, value = "Start again to get a different result. The output have been generated in " + ((str(hours) + " h, ") if hours != 0 else "") + ((str(minutes) + " min, ") if hours != 0 or minutes != 0 else "") + str(secondes) + " sec.")
163
- ]
164
-
165
- if is_space_imported:
166
- text2audio = spaces.GPU(text2audio, duration = 420)
167
 
168
  # Old code
169
  net=BriaRMBG()
170
- # model_path = "./model1.pth"
171
- #model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
172
  model_path = hf_hub_download("cocktailpeanut/gbmr", 'model.pth')
173
  if torch.cuda.is_available():
174
  net.load_state_dict(torch.load(model_path))
@@ -220,36 +218,8 @@ def process(image):
220
  # paste the mask on the original image
221
  new_im = Image.new("RGBA", pil_im.size, (0,0,0,0))
222
  new_im.paste(orig_image, mask=pil_im)
223
- # new_orig_image = orig_image.convert('RGBA')
224
 
225
  return new_im
226
- # return [new_orig_image, new_im]
227
-
228
-
229
- # block = gr.Blocks().queue()
230
-
231
- # with block:
232
- # gr.Markdown("## BRIA RMBG 1.4")
233
- # gr.HTML('''
234
- # <p style="margin-bottom: 10px; font-size: 94%">
235
- # This is a demo for BRIA RMBG 1.4 that using
236
- # <a href="https://huggingface.co/briaai/RMBG-1.4" target="_blank">BRIA RMBG-1.4 image matting model</a> as backbone.
237
- # </p>
238
- # ''')
239
- # with gr.Row():
240
- # with gr.Column():
241
- # input_image = gr.Image(sources=None, type="pil") # None for upload, ctrl+v and webcam
242
- # # input_image = gr.Image(sources=None, type="numpy") # None for upload, ctrl+v and webcam
243
- # run_button = gr.Button(value="Run")
244
-
245
- # with gr.Column():
246
- # result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery", columns=[1], height='auto')
247
- # ips = [input_image]
248
- # run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
249
-
250
- # block.launch(debug = True)
251
-
252
- # block = gr.Blocks().queue()
253
 
254
  gr.Markdown("## BRIA RMBG 1.4")
255
  gr.HTML('''
@@ -263,8 +233,6 @@ description = r"""Background removal model developed by <a href='https://BRIA.AI
263
  For test upload your image and wait. Read more at model card <a href='https://huggingface.co/briaai/RMBG-1.4' target='_blank'><b>briaai/RMBG-1.4</b></a>.<br>
264
  """
265
  examples = [['./input.jpg'],]
266
- # output = ImageSlider(position=0.5,label='Image without background', type="pil", show_download_button=True)
267
- # demo = gr.Interface(fn=process,inputs="image", outputs=output, examples=examples, title=title, description=description)
268
  demo = gr.Interface(fn=process,inputs="image", outputs="image", examples=examples, title=title, description=description)
269
 
270
  if __name__ == "__main__":
 
94
  return outputs
95
  return list(self.chunks(outputs, samples))
96
 
97
+ ## Initialize TANGO
98
+ #
99
+ #tango = Tango(device = "cpu")
100
+ #tango.vae.to(device_type)
101
+ #tango.stft.to(device_type)
102
+ #tango.model.to(device_type)
103
+ #
104
+ #def update_seed(is_randomize_seed, seed):
105
+ # if is_randomize_seed:
106
+ # return random.randint(0, max_64_bit_int)
107
+ # return seed
108
+ #
109
+ #def check(
110
+ # prompt,
111
+ # output_number,
112
+ # steps,
113
+ # guidance,
114
+ # is_randomize_seed,
115
+ # seed
116
+ #):
117
+ # if prompt is None or prompt == "":
118
+ # raise gr.Error("Please provide a prompt input.")
119
+ # if not output_number in [1, 2, 3]:
120
+ # raise gr.Error("Please ask for 1, 2 or 3 output files.")
121
+ #
122
+ #def update_output(output_format, output_number):
123
+ # return [
124
+ # gr.update(format = output_format),
125
+ # gr.update(format = output_format, visible = (2 <= output_number)),
126
+ # gr.update(format = output_format, visible = (output_number == 3)),
127
+ # gr.update(visible = False)
128
+ # ]
129
+ #
130
+ #def text2audio(
131
+ # prompt,
132
+ # output_number,
133
+ # steps,
134
+ # guidance,
135
+ # is_randomize_seed,
136
+ # seed
137
+ #):
138
+ # start = time.time()
139
+ #
140
+ # if seed is None:
141
+ # seed = random.randint(0, max_64_bit_int)
142
+ #
143
+ # random.seed(seed)
144
+ # torch.manual_seed(seed)
145
+ #
146
+ # output_wave = tango.generate(prompt, steps, guidance, output_number)
147
+ #
148
+ # output_wave_1 = gr.make_waveform((16000, output_wave[0]))
149
+ # output_wave_2 = gr.make_waveform((16000, output_wave[1])) if (2 <= output_number) else None
150
+ # output_wave_3 = gr.make_waveform((16000, output_wave[2])) if (output_number == 3) else None
151
+ #
152
+ # end = time.time()
153
+ # secondes = int(end - start)
154
+ # minutes = secondes // 60
155
+ # secondes = secondes - (minutes * 60)
156
+ # hours = minutes // 60
157
+ # minutes = minutes - (hours * 60)
158
+ # return [
159
+ # output_wave_1,
160
+ # output_wave_2,
161
+ # output_wave_3,
162
+ # gr.update(visible = True, value = "Start again to get a different result. The output have been generated in " + ((str(hours) + " h, ") if hours != 0 else "") + ((str(minutes) + " min, ") if hours != 0 or minutes != 0 else "") + str(secondes) + " sec.")
163
+ # ]
164
+ #
165
+ #if is_space_imported:
166
+ # text2audio = spaces.GPU(text2audio, duration = 420)
167
 
168
  # Old code
169
  net=BriaRMBG()
 
 
170
  model_path = hf_hub_download("cocktailpeanut/gbmr", 'model.pth')
171
  if torch.cuda.is_available():
172
  net.load_state_dict(torch.load(model_path))
 
218
  # paste the mask on the original image
219
  new_im = Image.new("RGBA", pil_im.size, (0,0,0,0))
220
  new_im.paste(orig_image, mask=pil_im)
 
221
 
222
  return new_im
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
 
224
  gr.Markdown("## BRIA RMBG 1.4")
225
  gr.HTML('''
 
233
  For test upload your image and wait. Read more at model card <a href='https://huggingface.co/briaai/RMBG-1.4' target='_blank'><b>briaai/RMBG-1.4</b></a>.<br>
234
  """
235
  examples = [['./input.jpg'],]
 
 
236
  demo = gr.Interface(fn=process,inputs="image", outputs="image", examples=examples, title=title, description=description)
237
 
238
  if __name__ == "__main__":