# Hugging Face Space page status at time of copy (scrape residue): "Spaces: Sleeping"
import os
import sys
import io
import uuid

import torch
import torchvision
import clip
import numpy as np
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision.transforms.functional import to_pil_image

from utils import load_model_weights
from model import NetG, CLIP_TXT_ENCODER
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Hugging Face access token (None when the variable is unset); forwarded to
# hf_hub_download for the checkpoint downloads below.
HF_TOKEN = os.getenv("HF_TOKEN")

# Hub repository and in-repo paths of the two pre-trained checkpoints.
repo_id = "VinayHajare/EfficientCLIP-GAN"
cub_model = "saved_models/state_epoch_1480.pth"
cc12m_model = "saved_models/EfficientCLIP-GAN-CC12M.pth"

# CLIP backbone: used inference-only, wrapped by the custom text encoder and
# also handed to both generators.
clip_text = "ViT-B/32"
clip_model, preprocessor = clip.load(clip_text, device=device)
clip_model = clip_model.eval()
text_encoder = CLIP_TXT_ENCODER(clip_model).to(device)


def _download_checkpoint(filename):
    """Download ``filename`` from ``repo_id`` and load it onto ``device``.

    Returns:
        A ``(local_path, checkpoint)`` tuple, where ``checkpoint`` is whatever
        ``torch.load`` deserializes from the file.
    """
    path = hf_hub_download(repo_id=repo_id, filename=filename, token=HF_TOKEN)
    return path, torch.load(path, map_location=torch.device(device))


# Download and load the checkpoints (CUB first, then CC12M).
cub_model_path, checkpoint_cub = _download_checkpoint(cub_model)
cc12m_model_path, checkpoint_cc12m = _download_checkpoint(cc12m_model)

# Build one generator per checkpoint and load its 'netG' weights.
# NOTE(review): the positional NetG hyper-parameters are opaque here; the
# second one (100) matches the noise dimension used when sampling — confirm
# the rest against model.NetG.
netG = NetG(64, 100, 512, 256, 3, False, clip_model).to(device)
netG1 = NetG(64, 100, 512, 256, 3, False, clip_model).to(device)
cub = load_model_weights(netG, checkpoint_cub['model']['netG'], multi_gpus=False)
cc12m = load_model_weights(netG1, checkpoint_cc12m['model']['netG'], multi_gpus=False)

# The generators are only used for inference. Take them out of training mode,
# as is already done for clip_model, so any mode-dependent layers
# (dropout, batch-norm) behave deterministically.
netG.eval()
netG1.eval()
# Function to generate images from text | |
def generate_image_from_text(caption, model, batch_size=4): | |
if model == "CUB": | |
generator = cub | |
else: | |
generator = cc12m | |
# Create the noise tensor | |
noise = torch.randn((batch_size, 100)).to(device) | |
with torch.no_grad(): | |
# Tokenize caption | |
tokenized_text = clip.tokenize([caption]).to(device) | |
# Extract the sentence and word embedding from Custom CLIP ENCODER | |
sent_emb, word_emb = text_encoder(tokenized_text) | |
# Repeat the sentence embedding to match the batch size | |
sent_emb = sent_emb.repeat(batch_size, 1) | |
# generate the images | |
generated_images = generator(noise, sent_emb, eval=True).float() | |
# Convert the tensor images to PIL format | |
pil_images = [] | |
for image_tensor in generated_images.unbind(0): | |
# Rescale tensor values to [0, 1] | |
image_tensor = image_tensor.data.clamp(-1, 1) | |
image_tensor = (image_tensor + 1.0) / 2.0 | |
# Convert tensor to numpy array | |
image_numpy = image_tensor.permute(1, 2, 0).cpu().numpy() | |
# Clip numpy array values to [0, 1] | |
image_numpy = np.clip(image_numpy, 0, 1) | |
# Create a PIL image from the numpy array | |
pil_image = Image.fromarray((image_numpy * 255).astype(np.uint8)) | |
pil_images.append(pil_image) | |
return pil_images | |
# Function to generate images from text | |
def generate_image_from_text_with_persistent_storage(caption, model, batch_size=4): | |
if model == "CUB": | |
generator = cub | |
else: | |
generator = cc12m | |
# Create the noise tensor | |
noise = torch.randn((batch_size, 100)).to(device) | |
with torch.no_grad(): | |
# Tokenize caption | |
tokenized_text = clip.tokenize([caption]).to(device) | |
# Extract the sentence and word embedding from Custom CLIP ENCODER | |
sent_emb, word_emb = text_encoder(tokenized_text) | |
# Repeat the sentence embedding to match the batch size | |
sent_emb = sent_emb.repeat(batch_size, 1) | |
# generate the images | |
generated_images = generator(noise, sent_emb, eval=True).float() | |
# Create a permanent directory if it doesn't exist | |
permanent_dir = "generated_images" | |
if not os.path.exists(permanent_dir): | |
os.makedirs(permanent_dir) | |
image_paths = [] | |
for idx, image_tensor in enumerate(generated_images.unbind(0)): | |
# Save the image tensor to a permanent file | |
image_path = os.path.join(permanent_dir, f"image_{idx}.png") | |
torchvision.utils.save_image(image_tensor.data, image_path, value_range=(-1, 1), normalize=True) | |
image_paths.append(image_path) | |
return image_paths |