import os
import shutil
import zipfile

import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
from diffusers import StableDiffusionPipeline
import gradio as gr
# Use the GPU when available, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Custom dataset pairing PIL images with their text prompts
class CustomImageDataset(Dataset):
    def __init__(self, images, prompts, transform=None):
        self.images = images
        self.prompts = prompts
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        if self.transform:
            image = self.transform(image)
        prompt = self.prompts[idx]
        return image, prompt
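# A minimal sketch of what the dataset yields (file name is hypothetical):
# each item is a (transformed image tensor, prompt) pair.
# ds = CustomImageDataset([Image.open("cat.png").convert("RGB")],
#                         ["a photo of a cat"], transform=transforms.ToTensor())
# image_tensor, prompt = ds[0]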
# Fine-tune the UNet of a Stable Diffusion model on the uploaded images
def fine_tune_model(images, prompts, model_save_path, num_epochs=3):
    transform = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # scale pixels to [-1, 1]
    ])
    dataset = CustomImageDataset(images, prompts, transform)
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

    # Load the Stable Diffusion pipeline and its components
    pipeline = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2").to(device)
    vae = pipeline.vae
    unet = pipeline.unet
    text_encoder = pipeline.text_encoder
    # Use the pipeline's own tokenizer; Stable Diffusion 2 does not use the
    # openai/clip-vit-base-patch32 tokenizer
    tokenizer = pipeline.tokenizer
    noise_scheduler = pipeline.scheduler

    # Only the UNet is trained; freeze the VAE and text encoder
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    unet.train()

    optimizer = torch.optim.AdamW(unet.parameters(), lr=5e-6)
    # Fine-tuning loop
    for epoch in range(num_epochs):
        for step, (batch_images, batch_prompts) in enumerate(dataloader):
            batch_images = batch_images.to(device)

            # Tokenize the prompts
            inputs = tokenizer(
                list(batch_prompts),
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            ).to(device)

            with torch.no_grad():
                # Encode images to latents; 0.18215 is the VAE scaling factor
                latents = vae.encode(batch_images).latent_dist.sample() * 0.18215
                text_embeddings = text_encoder(inputs.input_ids).last_hidden_state

            # Forward diffusion: add noise at randomly sampled timesteps,
            # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
            noise = torch.randn_like(latents)
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps,
                (latents.size(0),), device=device,
            ).long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Predict the added noise and minimize the MSE against it
            pred_noise = unet(noisy_latents, timestep=timesteps, encoder_hidden_states=text_embeddings).sample
            loss = torch.nn.functional.mse_loss(pred_noise, noise)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    # Save the fine-tuned pipeline
    pipeline.save_pretrained(model_save_path)
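# A minimal usage sketch (hypothetical paths, not executed here): fine-tune on
# a couple of images and save the result.
# imgs = [Image.open("dog1.png").convert("RGB"), Image.open("dog2.png").convert("RGB")]
# fine_tune_model(imgs, ["a photo of a dog", "a dog in the park"],
#                 "fine_tuned_model", num_epochs=1)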
# Convert a normalized tensor back to a PIL image
def tensor_to_pil(tensor):
    tensor = tensor.squeeze().cpu()  # drop the batch dimension if present
    # Undo the [-1, 1] normalization before converting
    tensor = (tensor * 0.5 + 0.5).clamp(0, 1)
    return transforms.ToPILImage()(tensor)
# Generate an image from a text prompt
def generate_images(pipeline, prompt):
    with torch.no_grad():
        output = pipeline(prompt)
    # The pipeline already returns PIL images; take the first one
    return output.images[0]
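# The pipeline call accepts the usual generation knobs if you want more
# control (values here are illustrative):
# output = pipeline(prompt, num_inference_steps=30, guidance_scale=7.5)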
# Zip the fine-tuned model directory for download
def zip_model(model_path):
    zip_path = f"{model_path}.zip"
    with zipfile.ZipFile(zip_path, "w") as zipf:
        for root, _, files in os.walk(model_path):
            for file in files:
                full_path = os.path.join(root, file)
                zipf.write(full_path, os.path.relpath(full_path, model_path))
    return zip_path
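# Usage sketch (assumes the directory exists):
# archive = zip_model("fine_tuned_model")  # -> "fine_tuned_model.zip",
# with paths inside the archive relative to the model directory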
# Save an uploaded file to a target path
def save_uploaded_file(uploaded_file, save_path):
    # Gradio's File component passes a temp-file path, not an object with a
    # .data attribute, so copy the file to the target location
    shutil.copy(uploaded_file, save_path)
    return f"File saved at {save_path}"
# Gradio interface functions
def start_fine_tuning(uploaded_files, prompts, num_epochs):
    images = [Image.open(file).convert("RGB") for file in uploaded_files]
    # The textbox yields one comma-separated string; split it into one prompt per image
    prompt_list = [p.strip() for p in prompts.split(",")]
    if len(prompt_list) != len(images):
        return "Please provide exactly one prompt per uploaded image."
    model_save_path = "fine_tuned_model"
    fine_tune_model(images, prompt_list, model_save_path, num_epochs=int(num_epochs))
    return "Fine-tuning completed! Model is ready for download."
def download_model():
    model_save_path = "fine_tuned_model"
    if os.path.exists(model_save_path):
        return zip_model(model_save_path)
    return None
def generate_new_image(prompt):
    model_save_path = "fine_tuned_model"
    # Prefer the fine-tuned model when one has been saved
    if os.path.exists(model_save_path):
        pipeline = StableDiffusionPipeline.from_pretrained(model_save_path).to(device)
    else:
        pipeline = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2").to(device)
    image = generate_images(pipeline, prompt)
    image_path = "generated_image.png"
    image.save(image_path)
    return image_path
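# Quick sanity check outside the UI (hypothetical prompt, not executed here);
# gr.Image accepts the returned file path directly:
# path = generate_new_image("a watercolor painting of a lighthouse")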
# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Fine-Tune Stable Diffusion and Generate Images")

    with gr.Tab("Fine-Tune Model"):
        with gr.Row():
            uploaded_files = gr.File(label="Upload Images", file_types=[".png", ".jpg", ".jpeg"], file_count="multiple")
        with gr.Row():
            prompts = gr.Textbox(label="Enter Prompts (comma-separated)")
            num_epochs = gr.Number(label="Number of Epochs", value=3)
        with gr.Row():
            fine_tune_button = gr.Button("Start Fine-Tuning")
            fine_tune_output = gr.Textbox(label="Output")
        fine_tune_button.click(start_fine_tuning, [uploaded_files, prompts, num_epochs], fine_tune_output)

    with gr.Tab("Download Fine-Tuned Model"):
        download_button = gr.Button("Download Fine-Tuned Model")
        download_output = gr.File()
        download_button.click(download_model, [], download_output)

    with gr.Tab("Generate New Images"):
        prompt_input = gr.Textbox(label="Enter a Prompt")
        generate_button = gr.Button("Generate Image")
        generated_image = gr.Image(label="Generated Image")
        generate_button.click(generate_new_image, [prompt_input], generated_image)

demo.launch()
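# If you need a temporary public link (e.g. when testing outside a Space),
# launch() also accepts:
# demo.launch(share=True)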