LuyangZ commited on
Commit
6718f50
·
verified ·
1 Parent(s): db8bc30

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -6
app.py CHANGED
@@ -18,11 +18,13 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
18
  base_model_id = "runwayml/stable-diffusion-v1-5"
19
  model_id = "LuyangZ/FloorAI"
20
 
21
- controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float32)
 
22
  controlnet.to(device)
23
  torch.cuda.empty_cache()
24
 
25
- pipeline = StableDiffusionControlNetPipeline.from_pretrained(base_model_id , controlnet=controlnet, torch_dtype=torch.float32)
 
26
  pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
27
 
28
  pipeline = pipeline.to(device)
@@ -95,7 +97,7 @@ def floorplan_generation(outline, num_of_rooms):
95
  validation_image = n_outline
96
 
97
  image_lst = []
98
- for i in range(2):
99
  seed = randrange(500)
100
  generator = torch.Generator(device=device).manual_seed(seed)
101
 
@@ -111,15 +113,19 @@ def floorplan_generation(outline, num_of_rooms):
111
  image = clean_img(image, mask)
112
  image_lst.append(image)
113
 
114
- return image_lst[0], image_lst[1]
115
 
116
 
117
  gradio_interface = gradio.Interface(
118
  fn=floorplan_generation,
119
  inputs=[gradio.Image(label="Floor Plan Outline, Entrance"),
120
  gradio.Textbox(type="text", label="number of rooms", placeholder="number of rooms")],
121
- outputs=[gradio.Image(label="Generated Floor Plan 1"), gradio.Image(label="Generated Floor Plan 2")],
122
- title="floorplan generation")
 
 
 
 
123
 
124
 
125
  gradio_interface.queue(max_size=10, status_update_rate="auto")
 
18
  base_model_id = "runwayml/stable-diffusion-v1-5"
19
  model_id = "LuyangZ/FloorAI"
20
 
21
+ controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype="auto")
22
+ # controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float32)
23
  controlnet.to(device)
24
  torch.cuda.empty_cache()
25
 
26
+ # pipeline = StableDiffusionControlNetPipeline.from_pretrained(base_model_id , controlnet=controlnet, torch_dtype=torch.float32)
27
+ pipeline = StableDiffusionControlNetPipeline.from_pretrained(base_model_id , controlnet=controlnet, torch_dtype="auto")
28
  pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
29
 
30
  pipeline = pipeline.to(device)
 
97
  validation_image = n_outline
98
 
99
  image_lst = []
100
+ for i in range(5):
101
  seed = randrange(500)
102
  generator = torch.Generator(device=device).manual_seed(seed)
103
 
 
113
  image = clean_img(image, mask)
114
  image_lst.append(image)
115
 
116
+ return image_lst[0], image_lst[1], image_lst[2], image_lst[3], image_lst[4]
117
 
118
 
119
  gradio_interface = gradio.Interface(
120
  fn=floorplan_generation,
121
  inputs=[gradio.Image(label="Floor Plan Outline, Entrance"),
122
  gradio.Textbox(type="text", label="number of rooms", placeholder="number of rooms")],
123
+ outputs=[gradio.Image(label="Generated Floor Plan 1"),
124
+ gradio.Image(label="Generated Floor Plan 2"),
125
+ gradio.Image(label="Generated Floor Plan 3"),
126
+ gradio.Image(label="Generated Floor Plan 4"),
127
+ gradio.Image(label="Generated Floor Plan 5")],
128
+ title="FloorAI")
129
 
130
 
131
  gradio_interface.queue(max_size=10, status_update_rate="auto")