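# FloorAI demo (app.py): a Gradio Space that turns a floor-plan outline image
# into five generated floor plans using a Stable Diffusion v1.5 + ControlNet
# pipeline from diffusers.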
import gradio
import cv2
from PIL import Image
import numpy as np

import spaces
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
from diffusers.utils import load_image
import torch
import accelerate
import transformers
from random import randrange


# Migrate any legacy transformers cache to the current layout before loading models.
from transformers.utils.hub import move_cache
move_cache()
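
# ---- Model setup: load the FloorAI ControlNet, attach it to the SD v1.5 base
# pipeline, and switch to the UniPCMultistepScheduler. ----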


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

base_model_id = "runwayml/stable-diffusion-v1-5"
model_id = "LuyangZ/FloorAI"
# model_id = "LuyangZ/controlnet_Neufert4_64_100"

# controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float16)
# controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype="auto")
# controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float32, force_download=True)
controlnet = ControlNetModel.from_pretrained(model_id, force_download=True)

controlnet.to(device)
torch.cuda.empty_cache()


# pipeline = StableDiffusionControlNetPipeline.from_pretrained(base_model_id , controlnet=controlnet, torch_dtype=torch.float32,  force_download=True)
# pipeline = StableDiffusionControlNetPipeline.from_pretrained(base_model_id , controlnet=controlnet, torch_dtype="auto")
# pipeline = StableDiffusionControlNetPipeline.from_pretrained(base_model_id , controlnet=controlnet, torch_dtype=torch.float16)
pipeline = StableDiffusionControlNetPipeline.from_pretrained(base_model_id, controlnet=controlnet, force_download=True)
pipeline.safety_checker = None
pipeline.requires_safety_checker = False

pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)



# pipeline.enable_xformers_memory_efficient_attention()
# pipeline.enable_model_cpu_offload()
# pipeline.enable_attention_slicing()


pipeline = pipeline.to(device)
torch.cuda.empty_cache()
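
# ---- Image helpers: expand2square pads the outline onto a square canvas;
# clean_img whitens everything that lies on the blank background of the outline. ----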




def expand2square(ol_img, background_color):
    """Pad a PIL image onto a square canvas (side = longest dimension + ~20%),
    centering the original and filling the margin with background_color."""
    width, height = ol_img.size
    
    if width == height:
        pad = int(width*0.2)
        width_new = width + pad
        halfpad = int(pad/2)
        
        ol_result = Image.new(ol_img.mode, (width_new, width_new), background_color)
        ol_result.paste(ol_img, (halfpad, halfpad))

        # return the padded copy so the square case matches the other branches
        return ol_result
    
    elif width > height:
        
        pad = int(width*0.2)
        width_new = width + pad
        halfpad = int(pad/2)

        ol_result = Image.new(ol_img.mode, (width_new, width_new), background_color)
        ol_result.paste(ol_img, (halfpad, (width_new - height) // 2))

        return ol_result
    
    else:
        pad = int(height*0.2)
        height_new = height + pad
        halfpad = int(pad/2)

        ol_result = Image.new(ol_img.mode, (height_new, height_new), background_color)
        ol_result.paste(ol_img, ((height_new - width) // 2, halfpad))

        return ol_result
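
# Example: expand2square(img, (255, 255, 255)) returns a white square canvas
# roughly 1.2x the longest side of img, with img centered on it.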

def clean_img(image, mask):
    """Reset to white every pixel of `image` where the outline `mask` is near-white."""
    mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
    # Near-white outline pixels (> 250) map to 0, everything else to 255.
    mask = cv2.threshold(mask, 250, 255, cv2.THRESH_BINARY_INV)[1]

    # Force the blank background of the outline back to pure white in the output.
    image[mask < 250] = (255, 255, 255)
    image = Image.fromarray(image).convert('RGB')
    return image

@spaces.GPU  # request a GPU for the duration of each call (Hugging Face ZeroGPU)
def floorplan_generation(outline, num_of_rooms):
    """Generate five floor-plan candidates from an outline image and a room count."""
    new_width = 512
    new_height = 512

    outline = cv2.cvtColor(outline, cv2.COLOR_RGB2BGR)
    outline_original = outline.copy()

    # Find the drawn outline and crop to its bounding box.
    gray = cv2.cvtColor(outline, cv2.COLOR_BGR2GRAY)
    thresh = cv2.threshold(gray, 240, 255, cv2.THRESH_BINARY_INV)[1]

    x, y, w, h = cv2.boundingRect(thresh)
    n_outline = outline_original[y:y+h, x:x+w]
    n_outline = cv2.cvtColor(n_outline, cv2.COLOR_BGR2RGB)
    n_outline = Image.fromarray(n_outline).convert('RGB')
    # Pad to a square with a white margin and resize to the model's 512x512 input.
    n_outline = expand2square(n_outline, (255, 255, 255))
    n_outline = n_outline.resize((new_width, new_height))

    num_of_rooms = str(num_of_rooms)
    validation_prompt = "floor plan, " + num_of_rooms + " rooms"
    validation_image = n_outline

    # Sample five variants with random seeds and clean up the background of each.
    image_lst = []
    for i in range(5):
        seed = randrange(5000)
        generator = torch.Generator(device=device).manual_seed(seed)

        image = pipeline(validation_prompt,
                         validation_image,
                         num_inference_steps=20,
                         generator=generator).images[0]

        image = np.array(image)
        mask = np.array(n_outline)
        mask = cv2.cvtColor(mask, cv2.COLOR_RGB2BGR)
        image = clean_img(image, mask)
        image_lst.append(image)

    return image_lst[0], image_lst[1], image_lst[2], image_lst[3], image_lst[4]
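
# ---- Gradio UI: one image input (the outline), one text input (room count),
# five image outputs, plus example inputs shipped with the Space. ----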


gradio_interface = gradio.Interface(
  fn=floorplan_generation,
  inputs=[gradio.Image(label="Floor Plan Outline, Entrance"),
          gradio.Textbox(type="text", label="Number of Rooms", placeholder="Number of Rooms")],
  outputs=[gradio.Image(label="Generated Floor Plan 1"), 
           gradio.Image(label="Generated Floor Plan 2"),
           gradio.Image(label="Generated Floor Plan 3"),
           gradio.Image(label="Generated Floor Plan 4"),
           gradio.Image(label="Generated Floor Plan 5")],
  title="FloorAI",
  examples=[["example_1.png", "4"], ["example_2.png", "3"], ["example_3.png", "2"], ["example_4.png", "4"], ["example_5.png", "4"]])


gradio_interface.queue(max_size=10, status_update_rate="auto", api_open=True)
gradio_interface.launch(share=True, show_api=True, show_error=True)
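
# Note: share=True has no effect when the app is hosted on Hugging Face Spaces,
# and on a CPU-only machine each request (5 generations x 20 steps) will be slow.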