yucornetto committed on
Commit 15a2a80
Parent: 995325f

Update app.py

Files changed (1): app.py (+17, -14)
app.py CHANGED

@@ -10,9 +10,7 @@ import argparse
 import demo_util
 import os
 import spaces
-
-device = "cuda" if torch.cuda.is_available() else "cpu"
-
+from functools import partial
 
 model2ckpt = {
     "TiTok-L-32": ("tokenizer_titok_l32.bin", "generator_titok_l32.bin"),
@@ -32,19 +30,24 @@ parser.add_argument("--seed", type=int, default=42)
 parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with")
 args = parser.parse_args()
 
-config = demo_util.get_config("configs/titok_l32.yaml")
-print(config)
-titok_tokenizer = demo_util.get_titok_tokenizer(config)
-print(titok_tokenizer)
-titok_generator = demo_util.get_titok_generator(config)
-print(titok_generator)
 
-titok_tokenizer = titok_tokenizer.to(device)
-titok_generator = titok_generator.to(device)
+@spaces.GPU
+def load_model():
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    config = demo_util.get_config("configs/titok_l32.yaml")
+    print(config)
+    titok_tokenizer = demo_util.get_titok_tokenizer(config)
+    print(titok_tokenizer)
+    titok_generator = demo_util.get_titok_generator(config)
+    print(titok_generator)
+
+    titok_tokenizer = titok_tokenizer.to(device)
+    titok_generator = titok_generator.to(device)
+    return titok_tokenizer, titok_generator
 
 
 @spaces.GPU
-def demo_infer(tokenizer,
+def demo_infer_(tokenizer,
                generator,
                guidance_scale, randomize_temperature, num_sample_steps,
                class_label, seed):
@@ -70,7 +73,8 @@ def demo_infer(tokenizer,
     samples = [Image.fromarray(sample) for sample in generated_image]
     return samples
 
-
+titok_tokenizer, titok_generator = load_model()
+demo_infer = partial(demo_infer_, tokenizer=titok_tokenizer, generator=titok_generator)
 with gr.Blocks() as demo:
     gr.Markdown("<h1 style='text-align: center'>An Image is Worth 32 Tokens for Reconstruction and Generation</h1>")
 
@@ -92,7 +96,6 @@ with gr.Blocks() as demo:
         with gr.Column():
             output = gr.Gallery(label='Generated Images', height=700)
     button.click(demo_infer, inputs=[
-        titok_tokenizer, titok_generator,
         guidance_scale, randomize_temperature, num_sample_steps,
         i1k_class, seed],
         outputs=[output])
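
The commit moves model construction into a load_model() helper decorated with @spaces.GPU and binds the loaded tokenizer and generator to the inference function with functools.partial, so they no longer have to be passed through the Gradio inputs list. Below is a minimal, self-contained sketch of that binding pattern using stand-in objects rather than the real TiTok tokenizer and generator; the infer_ stand-in and the component choices (gr.Slider, gr.Number, gr.Textbox) are illustrative assumptions, not part of the commit.

from functools import partial

import gradio as gr


def load_model():
    # Stand-in for the real loaders (demo_util.get_titok_tokenizer /
    # get_titok_generator): build the heavy objects once at startup.
    tokenizer = "stand-in tokenizer"
    generator = "stand-in generator"
    return tokenizer, generator


def infer_(tokenizer, generator, guidance_scale, seed):
    # Only the lightweight, user-facing values remain Gradio inputs;
    # tokenizer and generator are pre-bound below.
    return f"{generator} + {tokenizer}: cfg={guidance_scale}, seed={int(seed)}"


tokenizer, generator = load_model()
# Bind the loaded objects so the handler's remaining parameters match
# the Gradio inputs list exactly (guidance_scale, seed).
infer = partial(infer_, tokenizer, generator)

with gr.Blocks() as demo:
    guidance_scale = gr.Slider(1.0, 25.0, value=4.5, label="CFG scale")
    seed = gr.Number(value=42, label="Seed", precision=0)
    button = gr.Button("Generate")
    output = gr.Textbox(label="Result")
    button.click(infer, inputs=[guidance_scale, seed], outputs=[output])

if __name__ == "__main__":
    demo.launch()

The sketch binds the models positionally because Gradio supplies the inputs values to the handler as positional arguments, so the pre-bound parameters need to come first in infer_'s signature and the remaining parameters line up with the inputs list.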