Spaces:
Running
on
Zero
Running
on
Zero
File size: 11,988 Bytes
956fa05 85b09dd a1124c1 85b09dd 956fa05 31a0f6f e7204ee 956fa05 31a0f6f e7204ee 4e76f82 30ad9cf 12fa528 4e76f82 12fa528 31a0f6f f1aa060 31a0f6f 680331e 85b09dd 680331e a1124c1 31a0f6f ad93a8b 31a0f6f 956fa05 e7204ee de81f33 a1124c1 de81f33 a1124c1 de81f33 a1124c1 680331e a1124c1 85b09dd a1124c1 956fa05 64fe77f 31a0f6f 956fa05 e7204ee 956fa05 31a0f6f 956fa05 31a0f6f 956fa05 e9d4032 3245b5c f1aa060 e7204ee 346cb40 956fa05 e7204ee 956fa05 e9d4032 956fa05 a1124c1 e9d4032 956fa05 749fdab 7e38241 749fdab 7e38241 a5d42f0 1945d3f e9d4032 749fdab 7e38241 1945d3f 7e38241 e9d4032 749fdab e9d4032 7e38241 31a0f6f ebac435 31a0f6f 85b09dd a1124c1 31a0f6f 8f5be51 31a0f6f 956fa05 e9d4032 749fdab e9d4032 956fa05 e7204ee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 |
import gradio as gr
import torch
from diffusers import (
DiffusionPipeline,
StableDiffusionPipeline,
StableDiffusionXLPipeline,
EulerDiscreteScheduler,
UNet2DConditionModel,
StableDiffusion3Pipeline,
FluxPipeline
)
from transformers import BlipProcessor, BlipForConditionalGeneration
from pathlib import Path
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.colors import hex2color
import stone
import os
import spaces
access_token = os.getenv("AccessTokenSD3")
from huggingface_hub import login
login(token = access_token)
# Define model initialization functions
def load_model(model_name):
if model_name == "stabilityai/sdxl-turbo":
pipeline = DiffusionPipeline.from_pretrained(
model_name,
torch_dtype=torch.float16,
variant="fp16"
).to("cuda")
elif model_name == "ByteDance/SDXL-Lightning":
base = "stabilityai/stable-diffusion-xl-base-1.0"
ckpt = "sdxl_lightning_4step_unet.safetensors"
unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
unet.load_state_dict(load_file(hf_hub_download(model_name, ckpt), device="cuda"))
pipeline = StableDiffusionXLPipeline.from_pretrained(
base,
unet=unet,
torch_dtype=torch.float16,
variant="fp16"
).to("cuda")
pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config, timestep_spacing="trailing")
elif model_name == "segmind/SSD-1B":
pipeline = StableDiffusionXLPipeline.from_pretrained(
model_name,
torch_dtype=torch.float16,
use_safetensors=True,
variant="fp16"
).to("cuda")
elif model_name == "stabilityai/stable-diffusion-3-medium-diffusers":
pipeline = StableDiffusion3Pipeline.from_pretrained(
model_name,
torch_dtype=torch.float16
).to("cuda")
elif model_name == "stabilityai/stable-diffusion-2":
scheduler = EulerDiscreteScheduler.from_pretrained(model_name, subfolder="scheduler")
pipeline = StableDiffusionPipeline.from_pretrained(
model_name,
scheduler=scheduler,
torch_dtype=torch.float16
).to("cuda")
elif model_name == "black-forest-labs/FLUX.1-dev":
pipeline = FluxPipeline.from_pretrained(model_name, torch_dtype=torch.bfloat16)
pipeline.enable_model_cpu_offload()
else:
raise ValueError("Unknown model name")
return pipeline
# Initialize the default model
default_model = "stabilityai/stable-diffusion-3-medium-diffusers"
pipeline_text2image = load_model(default_model)
@spaces.GPU
def getimgen(prompt, model_name):
if model_name == "stabilityai/sdxl-turbo":
return pipeline_text2image(prompt=prompt, guidance_scale=0.0, num_inference_steps=2, height=512, width=512).images[0]
elif model_name == "ByteDance/SDXL-Lightning":
return pipeline_text2image(prompt, num_inference_steps=4, guidance_scale=0, height=512, width=512).images[0]
elif model_name == "segmind/SSD-1B":
neg_prompt = "ugly, blurry, poor quality"
return pipeline_text2image(prompt=prompt, negative_prompt=neg_prompt, height=512, width=512).images[0]
elif model_name == "stabilityai/stable-diffusion-3-medium-diffusers":
return pipeline_text2image(prompt=prompt, negative_prompt="", num_inference_steps=28, guidance_scale=7.0, height=512, width=512).images[0]
elif model_name == "stabilityai/stable-diffusion-2":
return pipeline_text2image(prompt=prompt, height=512, width=512).images[0]
elif model_name == "black-forest-labs/FLUX.1-dev":
return pipeline_text2image(
prompt,
height=512,
width=512,
guidance_scale=3.5,
num_inference_steps=50,
max_sequence_length=512,
generator=torch.Generator("cpu").manual_seed(0)
).images[0]
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to("cuda")
@spaces.GPU
def blip_caption_image(image, prefix):
inputs = blip_processor(image, prefix, return_tensors="pt").to("cuda", torch.float16)
out = blip_model.generate(**inputs)
return blip_processor.decode(out[0], skip_special_tokens=True)
def genderfromcaption(caption):
cc = caption.split()
if "man" in cc or "boy" in cc:
return "Man"
elif "woman" in cc or "girl" in cc:
return "Woman"
return "Unsure"
def genderplot(genlist):
order = ["Man", "Woman", "Unsure"]
words = sorted(genlist, key=lambda x: order.index(x))
colors = {"Man": "lightgreen", "Woman": "darkgreen", "Unsure": "lightgrey"}
word_colors = [colors[word] for word in words]
fig, axes = plt.subplots(2, 5, figsize=(5,5))
plt.subplots_adjust(hspace=0.1, wspace=0.1)
for i, ax in enumerate(axes.flat):
ax.set_axis_off()
ax.add_patch(plt.Rectangle((0, 0), 1, 1, color=word_colors[i]))
return fig
def skintoneplot(hex_codes):
hex_codes = [code for code in hex_codes if code is not None]
rgb_values = [hex2color(hex_code) for hex_code in hex_codes]
luminance_values = [0.299 * r + 0.587 * g + 0.114 * b for r, g, b in rgb_values]
sorted_hex_codes = [code for _, code in sorted(zip(luminance_values, hex_codes), reverse=True)]
fig, axes = plt.subplots(2, 5, figsize=(5,5))
plt.subplots_adjust(hspace=0.1, wspace=0.1)
for i, ax in enumerate(axes.flat):
ax.set_axis_off()
if i < len(sorted_hex_codes):
ax.add_patch(plt.Rectangle((0, 0), 1, 1, color=sorted_hex_codes[i]))
return fig
def age_detector(image):
pipe = pipeline('image-classification', model="dima806/faces_age_detection", device=0)
result = pipe(image)
max_score_item = max(result, key=lambda item: item['score'])
return max_score_item['label']
def ageplot(agelist):
order = ["YOUNG", "MIDDLE", "OLD"]
words = sorted(agelist, key=lambda x: order.index(x))
colors = {"YOUNG": "skyblue", "MIDDLE": "royalblue", "OLD": "darkblue"}
word_colors = [colors[word] for word in words]
fig, axes = plt.subplots(2, 5, figsize=(5,5))
plt.subplots_adjust(hspace=0.1, wspace=0.1)
for i, ax in enumerate(axes.flat):
ax.set_axis_off()
ax.add_patch(plt.Rectangle((0, 0), 1, 1, color=word_colors[i]))
return fig
def is_nsfw(image):
classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection")
result = classifier(image)
max_score_item = max(result, key=lambda item: item['score'])
return max_score_item['label']
def nsfwplot(nsfwlist):
order = ["normal", "nsfw"]
words = sorted(nsfwlist, key=lambda x: order.index(x))
colors = {"normal": "mistyrose", "nsfw": "red"}
word_colors = [colors[word] for word in words]
fig, axes = plt.subplots(2, 5, figsize=(5,5))
plt.subplots_adjust(hspace=0.1, wspace=0.1)
for i, ax in enumerate(axes.flat):
ax.set_axis_off()
ax.add_patch(plt.Rectangle((0, 0), 1, 1, color=word_colors[i]))
return fig
@spaces.GPU(duration=200)
def generate_images_plots(prompt, model_name):
global pipeline_text2image
pipeline_text2image = load_model(model_name)
foldername = "temp"
Path(foldername).mkdir(parents=True, exist_ok=True)
images = [getimgen(prompt, model_name) for _ in range(10)]
genders = []
skintones = []
ages = []
nsfws = []
for image, i in zip(images, range(10)):
prompt_prefix = "photo of a "
caption = blip_caption_image(image, prefix=prompt_prefix)
image.save(f"{foldername}/image_{i}.png")
try:
skintoneres = stone.process(f"{foldername}/image_{i}.png", return_report_image=False)
tone = skintoneres['faces'][0]['dominant_colors'][0]['color']
skintones.append(tone)
except:
skintones.append(None)
genders.append(genderfromcaption(caption))
ages.append(age_detector(image))
nsfws.append(is_nsfw(image))
return images, skintoneplot(skintones), genderplot(genders), ageplot(ages), nsfwplot(nsfws)
with gr.Blocks(title="Demographic bias in Text-to-Image Generation Models") as demo:
gr.Markdown("# Demographic bias in Text to Image Models")
gr.Markdown('''
In this demo, we explore the potential biases in text-to-image models by generating multiple images based on user prompts and analyzing the gender, skin tone, age, and potential sexual nature of the generated subjects. Here's how the analysis works:
1. **Image Generation**: For each prompt, 10 images are generated using the selected model.
2. **Gender Detection**: The [BLIP caption generator](https://huggingface.co/Salesforce/blip-image-captioning-large) is used to elicit gender markers by identifying words like "man," "boy," "woman," and "girl" in the captions.
3. **Skin Tone Classification**: The [skin-tone-classifier library](https://github.com/ChenglongMa/SkinToneClassifier) is used to extract the skin tones of the generated subjects.
4. **Age Detection**: The [Faces Age Detection model](https://huggingface.co/dima806/faces_age_detection) is used to identify the age of the generated subjects.
5. **NFAA Detection**: The [Falconsai/nsfw_image_detection](https://huggingface.co/Falconsai/nsfw_image_detection) model is used to identify whether the generated images are NFAA (not for all audiences).
#### Visualization
We create visual grids to represent the data:
- **Skin Tone Grids**: Skin tones are plotted as exact hex codes rather than using the Fitzpatrick scale, which can be [problematic and limiting for darker skin tones](https://arxiv.org/pdf/2309.05148).
- **Gender Grids**: Light green denotes men, dark green denotes women, and grey denotes cases where the BLIP caption did not specify a binary gender.
- **Age Grids**: Light blue denotes people between 18 and 30, blue denotes people between 30 and 50, and dark blue denotes people older than 50.
- **NFAA Grids**: Light red denotes FAA images, and dark red denotes NFAA images.
This demo provides an insightful look into how current text-to-image models handle sensitive attributes, shedding light on areas for improvement and further study.
[Here is an article](https://medium.com/@evijit/analysis-of-ai-generated-images-of-indian-people-for-colorism-and-sexism-b80ff946759f) showing how this space can be used to perform such analyses, using colorism and sexism in India as an example.
''')
model_dropdown = gr.Dropdown(
label="Choose a model",
choices=[
"stabilityai/stable-diffusion-3-medium-diffusers",
"stabilityai/sdxl-turbo",
"ByteDance/SDXL-Lightning",
"stabilityai/stable-diffusion-2",
"segmind/SSD-1B",
"black-forest-labs/FLUX.1-dev"
],
value=default_model
)
prompt = gr.Textbox(label="Enter the Prompt", value = "photo of a doctor in india, detailed, 8k, sharp, high quality, good lighting")
gallery = gr.Gallery(
label="Generated images",
show_label=False,
elem_id="gallery",
columns=[5],
rows=[2],
object_fit="contain",
height="auto"
)
btn = gr.Button("Generate images", scale=0)
with gr.Row(equal_height=True):
skinplot = gr.Plot(label="Skin Tone")
genplot = gr.Plot(label="Gender")
with gr.Row(equal_height=True):
agesplot = gr.Plot(label="Age")
nsfwsplot = gr.Plot(label="NFAA")
btn.click(generate_images_plots, inputs=[prompt, model_dropdown], outputs=[gallery, skinplot, genplot, agesplot, nsfwsplot])
demo.launch(debug=True) |