import os

import torch
import torchvision
import clip
import numpy as np
from huggingface_hub import hf_hub_download
from PIL import Image

from utils import load_model_weights
from model import NetG, CLIP_TXT_ENCODER

# Select the device
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Hugging Face token used to download the model checkpoints
HF_TOKEN = os.getenv("HF_TOKEN")

# Model repository and checkpoint files (CUB and CC12M generators)
repo_id = "VinayHajare/EfficientCLIP-GAN"
cub_model = "saved_models/state_epoch_1480.pth"
cc12m_model = "saved_models/EfficientCLIP-GAN-CC12M.pth"

# CLIP model wrapped with the custom text encoder
clip_model_name = "ViT-B/32"
clip_model, preprocessor = clip.load(clip_model_name, device=device)
clip_model = clip_model.eval()
text_encoder = CLIP_TXT_ENCODER(clip_model).to(device)

# Download both checkpoints from the Hub and load them onto the target device
cub_model_path = hf_hub_download(repo_id=repo_id, filename=cub_model, token=HF_TOKEN)
checkpoint_cub = torch.load(cub_model_path, map_location=torch.device(device))
cc12m_model_path = hf_hub_download(repo_id=repo_id, filename=cc12m_model, token=HF_TOKEN)
checkpoint_cc12m = torch.load(cc12m_model_path, map_location=torch.device(device))

# Create one generator per dataset and initialize it with the pre-trained weights
netG_cub = NetG(64, 100, 512, 256, 3, False, clip_model).to(device)
netG_cc12m = NetG(64, 100, 512, 256, 3, False, clip_model).to(device)
cub = load_model_weights(netG_cub, checkpoint_cub['model']['netG'], multi_gpus=False)
cc12m = load_model_weights(netG_cc12m, checkpoint_cc12m['model']['netG'], multi_gpus=False)

# Function to generate images from text
def generate_image_from_text(caption, model, batch_size=4):
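    """Generate a batch of PIL images from a text caption.

    caption: text prompt describing the desired image.
    model: "CUB" to use the CUB-trained generator; any other value uses the CC12M generator.
    batch_size: number of images to sample.
    Returns a list of PIL.Image objects.
    """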
    if model == "CUB":
        generator = cub
    else:
        generator = cc12m

    # Create the noise tensor
    noise = torch.randn((batch_size, 100)).to(device)
    with torch.no_grad():
        # Tokenize caption
        tokenized_text = clip.tokenize([caption]).to(device)
        # Extract the sentence and word embeddings from the custom CLIP text encoder
        sent_emb, word_emb = text_encoder(tokenized_text)
        # Repeat the sentence embedding to match the batch size
        sent_emb = sent_emb.repeat(batch_size, 1)
        # Generate the images
        generated_images = generator(noise, sent_emb, eval=True).float()

        # Convert the tensor images to PIL format
        pil_images = []
        for image_tensor in generated_images.unbind(0):
            # Rescale tensor values to [0, 1]
            image_tensor = image_tensor.data.clamp(-1, 1)
            image_tensor = (image_tensor + 1.0) / 2.0

            # Convert tensor to numpy array
            image_numpy = image_tensor.permute(1, 2, 0).cpu().numpy()

            # Clip numpy array values to [0, 1]
            image_numpy = np.clip(image_numpy, 0, 1)

            # Create a PIL image from the numpy array
            pil_image = Image.fromarray((image_numpy * 255).astype(np.uint8))

            pil_images.append(pil_image)

        return pil_images

# Function to generate images from text and save them to persistent storage
def generate_image_from_text_with_persistent_storage(caption, model, batch_size=4):
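    """Generate images from a text caption and save them as PNG files.

    Takes the same arguments as generate_image_from_text, but writes the images
    to the "generated_images" directory and returns the list of file paths.
    """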
    if model == "CUB":
        generator = cub
    else:
        generator = cc12m

    # Create the noise tensor
    noise = torch.randn((batch_size, 100)).to(device)
    with torch.no_grad():
        # Tokenize caption
        tokenized_text = clip.tokenize([caption]).to(device)
        # Extract the sentence and word embeddings from the custom CLIP text encoder
        sent_emb, word_emb = text_encoder(tokenized_text)
        # Repeat the sentence embedding to match the batch size
        sent_emb = sent_emb.repeat(batch_size, 1)
        # Generate the images
        generated_images = generator(noise, sent_emb, eval=True).float()

        # Create a permanent output directory if it doesn't exist
        permanent_dir = "generated_images"
        os.makedirs(permanent_dir, exist_ok=True)

        image_paths = []
        for idx, image_tensor in enumerate(generated_images.unbind(0)):
            # Save the image tensor to a permanent file
            image_path = os.path.join(permanent_dir, f"image_{idx}.png")
            torchvision.utils.save_image(image_tensor.data, image_path, value_range=(-1, 1), normalize=True)
            image_paths.append(image_path)

        return image_paths
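
# Example usage (a minimal sketch, not part of the original script). It assumes
# HF_TOKEN is set in the environment, the checkpoints above can be downloaded,
# and this file is run directly so the module-level setup has already executed.
# The caption below is only an illustrative prompt.
if __name__ == "__main__":
    caption = "a small bird with a red head and a short beak"

    # In-memory PIL images from the CUB-trained generator
    images = generate_image_from_text(caption, model="CUB", batch_size=4)
    for i, img in enumerate(images):
        img.save(f"preview_{i}.png")

    # The persistent-storage variant writes PNGs under ./generated_images
    # and returns their file paths
    paths = generate_image_from_text_with_persistent_storage(caption, model="CC12M", batch_size=4)
    print(paths)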