# Author: AndreiUrsu
# Commit 917022f ("Update main.py"), verified.
import torch
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from diffusers import UNet2DConditionModel, AutoencoderKL, DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from huggingface_hub import HfApi
from torch.optim import AdamW
from tqdm import tqdm
import gc
from torch.cuda.amp import autocast
# Ask the CUDA caching allocator for expandable segments to reduce memory
# fragmentation (this must be set before the allocator is first used).
os.environ.update({"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True"})

# Report whether PyTorch can see a GPU (prints True or False).
gpu_visible = torch.cuda.is_available()
print(gpu_visible)

# Directory holding the fine-tuning images.
img_dir = '/media/andrei_ursu/storage2/chess/branches/chessgpt/backend/src/experiments/full/primulTest/SD21data'
# Dataset definition
class ManualCaptionDataset(Dataset):
    """Dataset pairing every image in a directory with one fixed caption.

    Only regular files with a common image extension are used, sorted by name
    so the sample order is reproducible (``os.listdir`` returns entries in
    arbitrary order).

    Args:
        img_dir: Directory containing the training images.
        transform: Optional callable applied to each RGB PIL image.
        caption: Text prompt attached to every image.
    """

    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".webp", ".gif", ".tif", ".tiff")
    DEFAULT_CAPTION = 'Photo of Andrei smiling and dressed in winter clothes at a Christmas market'

    def __init__(self, img_dir, transform=None, caption=DEFAULT_CAPTION):
        self.img_dir = img_dir
        # Skip sub-directories and non-image files: opening them in
        # __getitem__ would crash the training run partway through.
        self.img_names = sorted(
            name
            for name in os.listdir(img_dir)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(img_dir, name))
        )
        self.transform = transform
        # One caption per image, kept as a list so per-image captions remain possible.
        self.captions = [caption] * len(self.img_names)

    def __len__(self):
        """Return the number of images found in ``img_dir``."""
        return len(self.img_names)

    def __getitem__(self, idx):
        """Return ``(image, caption)`` for *idx*; the image is transformed if a transform was given."""
        img_path = os.path.join(self.img_dir, self.img_names[idx])
        # The context manager closes the file handle; convert() returns an
        # independent, fully loaded copy that stays valid afterwards.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, self.captions[idx]
# Preprocessing: resize to a reduced resolution (to save GPU memory), convert
# to a [0, 1] tensor, then rescale to [-1, 1], the pixel range the Stable
# Diffusion VAE encoder expects (as in the diffusers training examples).
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # reduced image size
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),  # x -> (x - 0.5) / 0.5, i.e. [0, 1] -> [-1, 1]
])

# Dataset and loader; batch size 1 keeps memory use low.
dataset = ManualCaptionDataset(img_dir=img_dir, transform=transform)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
# Hub id of the base checkpoint. Every component (UNet, VAE, text encoder,
# tokenizer, scheduler) is loaded from it so that they match each other.
model_id = "stabilityai/stable-diffusion-2-1-base"

# UNet: the only module being trained. Its weights stay in float32 because
# AdamW updates of size ~lr (5e-6) fall below float16 resolution for typical
# weight magnitudes, and AdamW's eps (1e-8) underflows to 0 in float16, which
# can yield inf/NaN updates. Mixed precision comes from autocast + GradScaler.
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
# Recompute activations during backward to offset the float32 memory cost.
unet.enable_gradient_checkpointing()
unet.to("cuda")

# VAE: frozen and only used to encode images, so float16 weights suffice.
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float16)
vae.requires_grad_(False)
vae.to("cuda")

# Text encoder and tokenizer shipped with the checkpoint. A different CLIP
# model (e.g. openai/clip-vit-large-patch14) has a different hidden size than
# the UNet's cross-attention width and would need an untrained projection.
text_model = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
text_model.requires_grad_(False)
text_model.to("cuda")
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
if text_model.config.hidden_size != unet.config.cross_attention_dim:
    raise ValueError(
        f"Text encoder width {text_model.config.hidden_size} does not match "
        f"UNet cross_attention_dim {unet.config.cross_attention_dim}"
    )

# Noise scheduler defining the forward diffusion process used in training.
scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")

# Only UNet parameters are optimised.
optimizer = AdamW(unet.parameters(), lr=5e-6)
# Loss scaling keeps small float16 gradients from underflowing under autocast.
scaler = torch.cuda.amp.GradScaler()

# The UNet trains; the frozen encoders run in inference mode.
unet.train()
vae.eval()
text_model.eval()

num_epochs = 5
prediction_type = scheduler.config.prediction_type

for epoch in range(num_epochs):
    epoch_loss = 0.0
    num_steps = 0
    for images, captions in tqdm(dataloader):
        images = images.to("cuda", dtype=torch.float16)

        # Frozen encoders: no autograd graph needed, which saves memory.
        with torch.no_grad():
            # Encode images into the VAE latent space and apply the model's
            # latent scaling factor; continue in float32 for precision.
            latents = vae.encode(images).latent_dist.sample()
            latents = latents.float() * vae.config.scaling_factor

            inputs = tokenizer(
                list(captions),
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            ).to("cuda")
            encoder_hidden_states = text_model(inputs.input_ids)[0]

        # Forward diffusion: noise with the *latent* shape, one random
        # timestep per sample, and the noised latents the UNet must denoise.
        noise = torch.randn_like(latents)
        timesteps = torch.randint(
            0, scheduler.config.num_train_timesteps, (latents.shape[0],), device=latents.device
        ).long()
        noisy_latents = scheduler.add_noise(latents, noise, timesteps)

        # The regression target depends on how the checkpoint was parameterised.
        if prediction_type == "epsilon":
            target = noise
        elif prediction_type == "v_prediction":
            target = scheduler.get_velocity(latents, noise, timesteps)
        else:
            raise ValueError(f"Unsupported prediction_type: {prediction_type}")

        with autocast():
            noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
        # Compute the loss in float32 for numerical stability.
        loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float())

        optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        epoch_loss += loss.item()
        num_steps += 1

    if num_steps:
        print(f"Epoch {epoch + 1}, Mean loss: {epoch_loss / num_steps}")
    else:
        print(f"Epoch {epoch + 1}: dataloader produced no batches")
# Save the fine-tuned UNet locally.
unet.save_pretrained("./finetuned-unet")
# The text encoder is not among the optimizer's parameters, so this writes its
# weights unchanged; kept so the output directory layout stays the same.
text_model.save_pretrained("./finetuned-text-model")

# Upload the UNet to the Hugging Face Hub. upload_folder requires the repo to
# exist; create_repo(exist_ok=True) creates it on the first run and does
# nothing on later runs.
unet_repo_id = "AndreiUrsu/finetuned-stable-diffusion-unet"
api = HfApi()
api.create_repo(repo_id=unet_repo_id, repo_type="model", exist_ok=True)
api.upload_folder(
    folder_path="./finetuned-unet",
    path_in_repo=".",
    repo_id=unet_repo_id,
    repo_type="model",
)

# Release cached GPU memory at the end of the run.
gc.collect()
torch.cuda.empty_cache()