# Author: VinitT
# Last commit: "Update app.py" (2c2a1bc, verified)
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
from diffusers import StableDiffusionPipeline
from transformers import CLIPTokenizer
import os
import zipfile
import gradio as gr
# Pick the compute device once, at import time: CUDA when PyTorch can see a
# GPU, otherwise the CPU. Models and batches below are moved onto it.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Define your custom dataset
class CustomImageDataset(Dataset):
def __init__(self, images, prompts, transform=None):
self.images = images
self.prompts = prompts
self.transform = transform
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
image = self.images[idx]
if self.transform:
image = self.transform(image)
prompt = self.prompts[idx]
return image, prompt
def _resolve_prompts(prompts, num_images):
    """Return a list holding exactly one prompt per training image.

    ``prompts`` may be a sequence of strings or a single comma-separated
    string (the format typed into the Gradio textbox). A raw string must be
    split, because indexing it per image would yield single characters.
    A lone prompt is reused for every image.

    Raises:
        ValueError: if the prompt count cannot be matched to ``num_images``.
    """
    if isinstance(prompts, str):
        prompts = [part.strip() for part in prompts.split(",") if part.strip()]
    prompt_list = list(prompts)
    if len(prompt_list) == 1:
        prompt_list *= num_images
    if len(prompt_list) != num_images:
        raise ValueError(
            f"Expected {num_images} prompt(s), one per image, "
            f"but got {len(prompt_list)}."
        )
    return prompt_list


# Function to fine-tune the model
def fine_tune_model(images, prompts, model_save_path, num_epochs=3, *,
                    batch_size=4, learning_rate=5e-6):
    """Fine-tune the Stable Diffusion 2 UNet on image/prompt pairs and save it.

    Only the UNet is trained; the VAE and text encoder are frozen. Each step:
      1. Encodes the images to latents.
      2. Noises them at a random timestep drawn from the scheduler's full
         training range.
      3. Regresses the UNet output against the target implied by the
         checkpoint's ``prediction_type``: the noise for ``"epsilon"``, the
         velocity for ``"v_prediction"``.

    Args:
        images: Indexable sequence of 3-channel images the torchvision
            ``Resize``/``ToTensor`` transforms accept (presumably RGB PIL
            images).
        prompts: One prompt per image, or a comma-separated string of them.
            A single prompt is reused for all images.
        model_save_path: Directory passed to ``pipeline.save_pretrained``.
        num_epochs: Number of passes over the dataset.
        batch_size: Images per optimisation step.
        learning_rate: AdamW learning rate for the UNet.

    Raises:
        ValueError: if ``images`` is empty, the prompt count does not match
            the images, or the checkpoint uses an unsupported
            ``prediction_type``.
    """
    # Imported here because only training needs the DDPM noise schedule.
    from diffusers import DDPMScheduler

    if len(images) == 0:
        raise ValueError("No training images were provided.")
    prompt_list = _resolve_prompts(prompts, len(images))

    transform = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        # Map [0, 1] pixels to [-1, 1], the input range of the VAE encoder.
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])
    dataset = CustomImageDataset(images, prompt_list, transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Load Stable Diffusion model and its components.
    pipeline = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2").to(device)
    vae = pipeline.vae
    unet = pipeline.unet
    text_encoder = pipeline.text_encoder
    # Use the checkpoint's own tokenizer so token ids and padding match the
    # text encoder and what the pipeline feeds it at inference time.
    tokenizer = pipeline.tokenizer
    # Training noise schedule built from the checkpoint's scheduler config
    # (same betas, timestep count and prediction_type).
    noise_scheduler = DDPMScheduler.from_config(pipeline.scheduler.config)
    num_train_timesteps = noise_scheduler.config.num_train_timesteps
    prediction_type = noise_scheduler.config.prediction_type
    if prediction_type not in ("epsilon", "v_prediction"):
        raise ValueError(f"Unsupported prediction_type: {prediction_type!r}")
    latent_scale = getattr(vae.config, "scaling_factor", 0.18215)

    # Only the UNet is optimised; freezing the rest saves memory and compute.
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    unet.train()
    optimizer = torch.optim.AdamW(unet.parameters(), lr=learning_rate)

    for _ in range(num_epochs):
        for batch_images, batch_prompts in dataloader:
            batch_images = batch_images.to(device)
            with torch.no_grad():
                latents = vae.encode(batch_images).latent_dist.sample() * latent_scale
                token_ids = tokenizer(
                    list(batch_prompts),
                    padding="max_length",
                    max_length=tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt",
                ).input_ids.to(device)
                text_embeddings = text_encoder(token_ids).last_hidden_state

            # Forward diffusion: noise each latent at its own random integer
            # timestep over the whole training schedule.
            noise = torch.randn_like(latents)
            timesteps = torch.randint(
                0, num_train_timesteps, (latents.shape[0],),
                device=device, dtype=torch.long,
            )
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
            if prediction_type == "epsilon":
                target = noise
            else:
                target = noise_scheduler.get_velocity(latents, noise, timesteps)

            pred = unet(noisy_latents, timesteps, encoder_hidden_states=text_embeddings).sample
            loss = torch.nn.functional.mse_loss(pred.float(), target.float())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    unet.eval()
    # Save the fine-tuned model
    pipeline.save_pretrained(model_save_path)
# Function to convert tensor to PIL Image
def tensor_to_pil(tensor):
    """Convert an image tensor to a ``PIL.Image``.

    Accepts ``(C, H, W)``, ``(H, W)`` or a single-image batch
    ``(1, C, H, W)``. Values are clamped to ``[0, 1]`` before conversion, so
    the input is assumed to be in that range already (not the ``[-1, 1]``
    normalised tensors produced by the training transform).
    """
    tensor = tensor.detach().cpu()  # no autograd tracking needed for display
    # Drop only a leading batch axis of size 1. A blanket squeeze() would
    # also remove genuine size-1 height/width/channel axes and scramble the
    # image layout.
    if tensor.dim() == 4 and tensor.size(0) == 1:
        tensor = tensor[0]
    return transforms.ToPILImage()(tensor.clamp(0, 1))
# Function to generate images
def generate_images(pipeline, prompt):
    """Call ``pipeline(prompt)`` with gradients disabled and return its first image.

    The image comes from the output's ``images`` list. For a diffusers
    pipeline with its default output type this is presumably a PIL image.
    """
    with torch.no_grad():
        return pipeline(prompt).images[0]
# Function to zip the fine-tuned model
def zip_model(model_path):
    """Archive every file under ``model_path`` into ``<model_path>.zip``.

    Member names are relative to ``model_path``, so extracting the archive
    recreates the directory's contents without the top-level folder.

    Args:
        model_path: Directory to archive. A trailing path separator is
            allowed.

    Returns:
        Path of the written archive, a sibling of ``model_path``.
    """
    # Normalise first. With a trailing separator, f"{model_path}.zip" would
    # name the archive ".zip" *inside* the directory being walked, so the
    # walk could pick up the half-written archive itself.
    base = os.path.normpath(model_path)
    zip_path = f"{base}.zip"
    with zipfile.ZipFile(zip_path, "w") as zipf:
        for root, dirs, files in os.walk(base):
            dirs.sort()  # deterministic member order regardless of filesystem
            for file in sorted(files):
                full_path = os.path.join(root, file)
                zipf.write(full_path, os.path.relpath(full_path, base))
    return zip_path
# Function to save uploaded files
def save_uploaded_file(uploaded_file, save_path):
    """Write the bytes held in ``uploaded_file.data`` to ``save_path``.

    The destination is opened (created or truncated) before ``.data`` is
    read, and any existing file there is overwritten.

    Returns:
        A confirmation message naming ``save_path``.
    """
    with open(save_path, 'wb') as destination:
        destination.write(uploaded_file.data)
    confirmation = f"File saved at {save_path}"
    return confirmation
# Gradio interface functions
def start_fine_tuning(uploaded_files, prompts, num_epochs):
    """Gradio callback: fine-tune the base model on the uploaded images.

    Args:
        uploaded_files: Files from the ``gr.File`` component (anything
            ``PIL.Image.open`` accepts), or ``None`` when nothing was uploaded.
        prompts: Comma-separated prompts from the textbox, one per image in
            upload order. A single prompt is reused for every image.
        num_epochs: Epoch count from the ``gr.Number`` component.

    Returns:
        A status message for the output textbox. Input problems are reported
        here instead of raising, so the user sees them in the UI.
    """
    if not uploaded_files:
        return "Please upload at least one image."

    # The textbox yields a single string. It must be split, because the
    # dataset indexes prompts per image and a raw string would give
    # single characters.
    prompt_list = [p.strip() for p in (prompts or "").split(",") if p.strip()]
    if not prompt_list:
        return "Please enter at least one prompt."
    if len(prompt_list) == 1:
        prompt_list = prompt_list * len(uploaded_files)
    if len(prompt_list) != len(uploaded_files):
        return (
            f"Got {len(prompt_list)} prompts for {len(uploaded_files)} images; "
            "enter one prompt per image or a single prompt for all of them."
        )

    if num_epochs is None or int(num_epochs) < 1:
        return "Number of epochs must be at least 1."

    images = []
    for file in uploaded_files:
        # convert() returns an independent copy, so the source can be closed.
        with Image.open(file) as img:
            images.append(img.convert("RGB"))

    model_save_path = "fine_tuned_model"
    fine_tune_model(images, prompt_list, model_save_path, num_epochs=int(num_epochs))
    return "Fine-tuning completed! Model is ready for download."
def download_model():
    """Gradio callback: package the fine-tuned model for download.

    Returns:
        The archive path produced by ``zip_model`` for ``fine_tuned_model``,
        or ``None`` when that path does not exist yet (nothing to download).
    """
    model_save_path = "fine_tuned_model"
    if not os.path.exists(model_save_path):
        return None
    return zip_model(model_save_path)
def generate_new_image(prompt):
    """Gradio callback: render one image for ``prompt`` and save it as a PNG.

    Uses the weights in ``fine_tuned_model`` when that path exists, otherwise
    the base ``stabilityai/stable-diffusion-2`` checkpoint.

    Returns:
        The path ``generated_image.png``, which is overwritten on every call.
    """
    fine_tuned_dir = "fine_tuned_model"
    if os.path.exists(fine_tuned_dir):
        checkpoint = fine_tuned_dir
    else:
        checkpoint = "stabilityai/stable-diffusion-2"
    # NOTE(review): the full pipeline is reloaded on every request; consider
    # caching it (and invalidating after a new fine-tune) if latency matters.
    loaded_pipeline = StableDiffusionPipeline.from_pretrained(checkpoint).to(device)
    result_image = generate_images(loaded_pipeline, prompt)
    output_path = "generated_image.png"
    result_image.save(output_path)
    return output_path
# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Fine-Tune Stable Diffusion and Generate Images")

    with gr.Tab("Fine-Tune Model"):
        with gr.Row():
            uploaded_files = gr.File(
                label="Upload Images",
                file_types=[".png", ".jpg", ".jpeg"],
                file_count="multiple",
            )
        with gr.Row():
            prompts = gr.Textbox(label="Enter Prompts (comma-separated)")
            # precision=0 makes Gradio deliver a whole number of epochs.
            num_epochs = gr.Number(label="Number of Epochs", value=3, precision=0)
        with gr.Row():
            fine_tune_button = gr.Button("Start Fine-Tuning")
            fine_tune_output = gr.Textbox(label="Output")
        fine_tune_button.click(
            start_fine_tuning, [uploaded_files, prompts, num_epochs], fine_tune_output
        )

    with gr.Tab("Download Fine-Tuned Model"):
        download_button = gr.Button("Download Fine-Tuned Model")
        download_output = gr.File()
        download_button.click(download_model, [], download_output)

    with gr.Tab("Generate New Images"):
        prompt_input = gr.Textbox(label="Enter a Prompt")
        generate_button = gr.Button("Generate Image")
        generated_image = gr.Image(label="Generated Image")
        generate_button.click(generate_new_image, [prompt_input], generated_image)

if __name__ == "__main__":
    # Launch only when run as a script, not on import. The request queue
    # keeps long fine-tuning calls from hitting HTTP timeouts; older Gradio
    # releases only queue when asked to.
    demo.queue().launch()