from __future__ import annotations

import os

import gradio as gr
import PIL
import torch

from transformers import (
    CLIPTextModelWithProjection,
    CLIPTokenizer,
)

from src.priors.lambda_prior_transformer import (
    PriorTransformer,
)  # original huggingface prior transformer without time conditioning
from src.pipelines.pipeline_kandinsky_subject_prior import KandinskyPriorPipeline

from diffusers import DiffusionPipeline
from PIL import Image

__device__ = "cpu"
__dtype__ = torch.float32
if torch.cuda.is_available():
    __device__ = "cuda"
    __dtype__ = torch.float16


class Model:
    def __init__(self):
        self.device = __device__

        self.text_encoder = (
            CLIPTextModelWithProjection.from_pretrained(
                "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
                projection_dim=1280,
                torch_dtype=__dtype__,
            )
            .eval()
            .requires_grad_(False)
        ).to(self.device)

        self.tokenizer = CLIPTokenizer.from_pretrained(
            "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
        )

        prior = PriorTransformer.from_pretrained(
            "ECLIPSE-Community/Lambda-ECLIPSE-Prior-v1.0",
            torch_dtype=__dtype__,
        )

        self.pipe_prior = KandinskyPriorPipeline.from_pretrained(
            "kandinsky-community/kandinsky-2-2-prior",
            prior=prior,
            torch_dtype=__dtype__,
        ).to(self.device)

        self.pipe = DiffusionPipeline.from_pretrained(
            "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=__dtype__
        ).to(self.device)

    def inference(self, raw_data, seed):
        # Create the RNG on the model's device so the demo also runs on CPU.
        generator = torch.Generator(device=self.device).manual_seed(seed)
        image_emb, negative_image_emb = self.pipe_prior(
            raw_data=raw_data,
            generator=generator,
        ).to_tuple()
        image = self.pipe(
            image_embeds=image_emb,
            negative_image_embeds=negative_image_emb,
            num_inference_steps=50,
            guidance_scale=7.5,
            generator=generator,
        ).images[0]
        return image

    def run(
        self,
        image: dict[str, PIL.Image.Image],
        keyword: str,
        image2: dict[str, PIL.Image.Image],
        keyword2: str,
        text: str,
        seed: int = 0,  # the UI does not expose a seed control, so keep a fixed default
    ):
        # gr.ImageEditor returns a dict; the "composite" entry holds the edited (masked) image.
        sub_imgs = [image["composite"]]
        sub_keywords = [keyword]
        # Add the second subject only when both its keyword and image are provided,
        # so the image and keyword lists stay aligned.
        if keyword2 and keyword2 != "no subject" and image2:
            sub_keywords.append(keyword2)
            sub_imgs.append(image2["composite"])
        raw_data = {
            "prompt": text,
            "subject_images": sub_imgs,
            "subject_keywords": sub_keywords,
        }
        image = self.inference(raw_data, seed)
        return image
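
# Example (illustrative sketch, not executed): Model can also be driven outside
# Gradio. The asset paths below are taken from the examples further down in this
# file; any RGB subject image on a white background works. Note that Model.run
# expects the dict format produced by gr.ImageEditor, i.e. a mapping with a
# "composite" key.
#
#     model = Model()
#     result = model.run(
#         image={"composite": Image.open("assets/cat.png")},
#         keyword="cat",
#         image2={"composite": Image.open("assets/blue_sunglasses.png")},
#         keyword2="glasses",
#         text="A cat wearing glasses on a snowy field",
#         seed=0,
#     )
#     result.save("output.png")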


def create_demo():
    USAGE = """## To run the demo, you should:   
    1. Upload your image.   
    2. <span style='color: red;'>**Upload a masked subject image with white blankspace or whiten out manually using brush tool.**
    3. Input a Keyword i.e. 'Dog'
    4. For MultiSubject personalization,
     4-1. Upload another image.
     4-2. Input the Keyword i.e. 'Sunglasses'   
    3. Input proper text prompts, such as "A photo of Dog" or "A Dog wearing sunglasses", Please use the same keyword in the prompt.   
    4. Click the Run button.
    """

    model = Model()

    with gr.Blocks() as demo:
        gr.HTML(
            """<h1 style="text-align: center;"><b><i>λ-ECLIPSE</i>: Multi-Concept Personalized Text-to-Image Diffusion Models by Leveraging CLIP Latent Space</b></h1>
            <h1 style='text-align: center;'><a href='https://eclipse-t2i.github.io/Lambda-ECLIPSE/'>Project Page</a> | <a href='#'>Paper</a>  </h1>
            
            <p style="text-align: center; color: red;">This demo is currently hosted on either a small GPU or CPU. We will soon provide high-end GPU support.</p>
            <p style="text-align: center; color: red;">Please follow the instructions from here to run it locally: <a href="https://github.com/eclipse-t2i/lambda-eclipse-inference">GitHub Inference Code</a></p>

            <a href="https://colab.research.google.com/drive/1VcqzXZmilntec3AsIyzCqlstEhX4Pa1o?usp=sharing" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
            """
        )
        gr.Markdown(USAGE)
        with gr.Row():
            with gr.Column():
                with gr.Group():
                    gr.Markdown(
                        "Upload your first masked subject image or mask out marginal space"
                    )
                    image = gr.ImageEditor(
                        label="Input",
                        type="pil",
                        brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"),
                    )
                    keyword = gr.Text(
                        label="Keyword",
                        placeholder='e.g. "Dog", "Goofie"',
                        info="Keyword for first subject",
                    )
                    gr.Markdown(
                        "For Multi-Subject generation : Upload your second masked subject image or mask out marginal space"
                    )
                    image2 = gr.ImageEditor(
                        label="Input",
                        type="pil",
                        brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"),
                    )
                    keyword2 = gr.Text(
                        label="Keyword",
                        placeholder='e.g. "Sunglasses", "Grand Canyon"',
                        info="Keyword for second subject",
                    )
                    prompt = gr.Text(
                        label="Prompt",
                        placeholder='e.g. "A photo of dog", "A dog wearing sunglasses"',
                        info="Keep the keywords used previously in the prompt",
                    )

                run_button = gr.Button("Run")

            with gr.Column():
                result = gr.Image(label="Result")

        inputs = [
            image,
            keyword,
            image2,
            keyword2,
            prompt,
        ]

        gr.Examples(
            examples=[
                [
                    os.path.join(os.path.dirname(__file__), "./assets/luffy.jpg"),
                    "luffy",
                    os.path.join(os.path.dirname(__file__), "./assets/white.jpg"),
                    "no subject",
                    "luffy holding a sword",
                ],
                [
                    os.path.join(os.path.dirname(__file__), "./assets/luffy.jpg"),
                    "luffy",
                    os.path.join(os.path.dirname(__file__), "./assets/white.jpg"),
                    "no subject",
                    "luffy in the living room",
                ],
                [
                    os.path.join(os.path.dirname(__file__), "./assets/teapot.jpg"),
                    "teapot",
                    os.path.join(os.path.dirname(__file__), "./assets/white.jpg"),
                    "no subject",
                    "teapot on a cobblestone street",
                ],
                [
                    os.path.join(os.path.dirname(__file__), "./assets/trex.jpg"),
                    "trex",
                    os.path.join(os.path.dirname(__file__), "./assets/white.jpg"),
                    "no subject",
                    "trex near a river",
                ],
                [
                    os.path.join(os.path.dirname(__file__), "./assets/cat.png"),
                    "cat",
                    os.path.join(
                        os.path.dirname(__file__), "./assets/blue_sunglasses.png"
                    ),
                    "glasses",
                    "A cat wearing glasses on a snowy field",
                ],
                [
                    os.path.join(os.path.dirname(__file__), "./assets/statue.jpg"),
                    "statue",
                    os.path.join(os.path.dirname(__file__), "./assets/toilet.jpg"),
                    "toilet",
                    "statue sitting on a toilet",
                ],
                [
                    os.path.join(os.path.dirname(__file__), "./assets/teddy.jpg"),
                    "teddy",
                    os.path.join(os.path.dirname(__file__), "./assets/luffy_hat.jpg"),
                    "hat",
                    "a teddy wearing the hat at a beach",
                ],
                [
                    os.path.join(os.path.dirname(__file__), "./assets/chair.jpg"),
                    "chair",
                    os.path.join(os.path.dirname(__file__), "./assets/table.jpg"),
                    "table",
                    "a chair and table in living room",
                ],
            ],
            inputs=inputs,
            fn=model.run,
            outputs=result,
        )

        run_button.click(fn=model.run, inputs=inputs, outputs=result)
    return demo


if __name__ == "__main__":
    demo = create_demo()
    demo.queue(max_size=20).launch()