bgaspra committed on
Commit
935a747
·
verified ·
1 Parent(s): ab76b8b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +228 -0
app.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import gradio as gr
3
+ import torch
4
+ import numpy as np
5
+ from PIL import Image
6
+ from transformers import AutoModel, AutoProcessor
7
+ from torch import nn
8
+ import torch.nn.functional as F
9
+ from datasets import load_dataset
10
+ from torch.utils.data import Dataset, DataLoader
11
+ import os
12
+ from tqdm import tqdm
13
+
14
class SDDataset(Dataset):
    """Torch dataset yielding processed images with model and prompt labels.

    Each sample is a tuple ``(image_inputs, model_label, prompt_label)``:
    the processor output for the row's image, a one-hot vector over
    ``model_to_idx`` and a multi-hot vector over ``token_to_idx``.
    """

    def __init__(self, dataset, processor, model_to_idx, token_to_idx, max_samples=5000):
        """Keep at most ``max_samples`` leading rows of ``dataset``.

        Args:
            dataset: Object supporting ``len()``, ``select(indices)`` and
                integer indexing; rows are read through the ``image``,
                ``model_name`` and ``prompt`` keys.
            processor: Callable invoked as
                ``processor(images=..., return_tensors="pt")``.
            model_to_idx: Mapping from model name to label index.
            token_to_idx: Mapping from prompt token to label index.
            max_samples: Upper bound on the number of rows kept.
        """
        limit = min(max_samples, len(dataset))
        self.dataset = dataset.select(range(limit))
        self.processor = processor
        self.model_to_idx = model_to_idx
        self.token_to_idx = token_to_idx

    def __len__(self):
        """Return the number of retained rows."""
        return len(self.dataset)

    def _load_image(self, source):
        """Return ``source`` unchanged if it is a PIL image, else open it."""
        # Same guard as the inference path: rows may already hold a decoded
        # image, in which case Image.open() would fail.
        if isinstance(source, Image.Image):
            return source
        return Image.open(source)

    def _encode_model(self, model_name):
        """Return a one-hot float vector marking ``model_name``.

        Raises:
            KeyError: If ``model_name`` is not in ``model_to_idx``.
        """
        label = torch.zeros(len(self.model_to_idx))
        label[self.model_to_idx[model_name]] = 1
        return label

    def _encode_prompt(self, prompt):
        """Return a multi-hot float vector of the known tokens in ``prompt``.

        Tokens are whitespace-separated; tokens missing from
        ``token_to_idx`` are ignored. A ``None`` or empty prompt yields an
        all-zero vector.
        """
        label = torch.zeros(len(self.token_to_idx))
        if not prompt:
            return label
        for token in prompt.split():
            position = self.token_to_idx.get(token)
            if position is not None:
                label[position] = 1
        return label

    def __getitem__(self, idx):
        """Return ``(image_inputs, model_label, prompt_label)`` for row ``idx``."""
        item = self.dataset[idx]

        image = self._load_image(item['image'])
        # NOTE(review): with return_tensors="pt" the processor output
        # presumably carries a leading batch axis of 1 per sample; confirm
        # the consumer copes with the extra axis after collation.
        image_inputs = self.processor(images=image, return_tensors="pt")

        model_label = self._encode_model(item['model_name'])
        prompt_label = self._encode_prompt(item['prompt'])

        return image_inputs, model_label, prompt_label
42
+
43
class SDRecommenderModel(nn.Module):
    """Image backbone with two linear heads: model choice and prompt tokens."""

    def __init__(self, florence_model, num_models, vocab_size):
        """Wrap ``florence_model`` and attach the two prediction heads.

        Args:
            florence_model: Backbone exposing ``config.hidden_size`` and
                ``get_image_features(...)``.
            num_models: Output width of the model-recommendation head.
            vocab_size: Output width of the prompt-token head.
        """
        super().__init__()
        self.florence = florence_model
        hidden_size = florence_model.config.hidden_size
        self.model_head = nn.Linear(hidden_size, num_models)
        self.prompt_head = nn.Linear(hidden_size, vocab_size)

    def forward(self, image_features):
        """Return ``(model_logits, prompt_logits)`` for a batch of images.

        Args:
            image_features: Either a pixel tensor, or a mapping (such as a
                processor output) holding the tensor under
                ``"pixel_values"``.
        """
        # Callers in this file pass the whole processor output; the backbone
        # is given the pixel tensor itself.
        if isinstance(image_features, torch.Tensor):
            pixel_values = image_features
        else:
            pixel_values = image_features["pixel_values"]

        # Per-sample processor outputs stacked by a DataLoader end up as
        # (batch, 1, C, H, W); drop the singleton axis.
        if pixel_values.dim() == 5 and pixel_values.size(1) == 1:
            pixel_values = pixel_values.squeeze(1)

        features = self.florence.get_image_features(pixel_values)

        # NOTE(review): if the backbone returns per-token features
        # (batch, tokens, hidden), average over tokens so each image maps to
        # a single vector -- confirm this pooling suits the backbone.
        if features.dim() == 3:
            features = features.mean(dim=1)

        model_logits = self.model_head(features)
        prompt_logits = self.prompt_head(features)

        return model_logits, prompt_logits
59
+
60
class SDRecommender:
    """Builds, trains and queries the model/prompt recommendation network."""

    # Checkpoint file written by train() and read back by __init__().
    _WEIGHTS_PATH = "recommender_model.pth"

    def __init__(self, max_samples=1000):
        """Load the backbone, the dataset slice and any saved weights.

        Args:
            max_samples: Maximum number of leading dataset rows to keep.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        # Load Florence model and processor
        self.processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large")
        self.florence = AutoModel.from_pretrained("microsoft/Florence-2-large")

        # Load dataset
        print("Loading dataset...")
        self.dataset = load_dataset("thefcraft/civitai-stable-diffusion-337k", split="train")
        self.dataset = self.dataset.select(range(min(max_samples, len(self.dataset))))
        print(f"Using {len(self.dataset)} samples from dataset")

        # Create vocabularies for models and tokens
        self.model_to_idx = self._create_model_vocab()
        self.token_to_idx = self._create_prompt_vocab()

        # Initialize the recommendation model
        self.model = SDRecommenderModel(
            self.florence,
            len(self.model_to_idx),
            len(self.token_to_idx),
        ).to(self.device)

        # Load trained weights if available. map_location lets a checkpoint
        # saved on a GPU machine load on a CPU-only host.
        if os.path.exists(self._WEIGHTS_PATH):
            state = torch.load(self._WEIGHTS_PATH, map_location=self.device)
            self.model.load_state_dict(state)
            print("Loaded trained model weights")
        self.model.eval()

    def _create_model_vocab(self):
        """Return ``{model_name: index}`` over the sorted distinct model names."""
        print("Creating model vocabulary...")
        # Read only the needed column instead of materialising whole rows.
        names = set(self.dataset["model_name"])
        return {name: idx for idx, name in enumerate(sorted(names))}

    def _create_prompt_vocab(self):
        """Return ``{token: index}`` over the sorted distinct prompt tokens.

        Prompts are split on whitespace; ``None`` or empty prompts are
        skipped.
        """
        print("Creating prompt vocabulary...")
        tokens = set()
        for prompt in self.dataset["prompt"]:
            if prompt:
                tokens.update(prompt.split())
        return {token: idx for idx, token in enumerate(sorted(tokens))}

    def train(self, num_epochs=5, batch_size=8, learning_rate=1e-4):
        """Train the network on the loaded dataset and save its weights.

        Args:
            num_epochs: Number of passes over the training data.
            batch_size: Samples per optimisation step.
            learning_rate: AdamW learning rate.
        """
        print("Starting training...")

        # Create dataset and dataloader
        train_dataset = SDDataset(
            self.dataset,
            self.processor,
            self.model_to_idx,
            self.token_to_idx,
        )
        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=2,
        )

        # Setup optimizer
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=learning_rate)

        # Training loop
        self.model.train()
        for epoch in range(num_epochs):
            total_loss = 0
            progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

            for batch_idx, (images, model_labels, prompt_labels) in enumerate(progress_bar):
                # Move everything to device
                images = images.to(self.device)
                model_labels = model_labels.to(self.device)
                prompt_labels = prompt_labels.to(self.device)

                # Forward pass
                model_logits, prompt_logits = self.model(images)

                # Model labels are one-hot float vectors; prompt labels are
                # multi-hot, hence the per-token binary loss.
                model_loss = F.cross_entropy(model_logits, model_labels)
                prompt_loss = F.binary_cross_entropy_with_logits(prompt_logits, prompt_labels)
                loss = model_loss + prompt_loss

                # Backward pass
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Report the running mean loss for this epoch
                total_loss += loss.item()
                progress_bar.set_postfix({"loss": total_loss / (batch_idx + 1)})

        # Save trained model
        torch.save(self.model.state_dict(), self._WEIGHTS_PATH)
        print("Training completed and model saved")

    def _format_recommendations(self, model_logits, prompt_logits, num_models=5,
                                num_tokens=20, tokens_per_prompt=8, num_prompts=5):
        """Turn head logits into the two text blocks returned to the caller.

        Only the first row of each logits tensor is used. Every count is
        clamped to the size of the corresponding vocabulary, so small
        vocabularies no longer raise.

        Returns:
            Tuple ``(models_text, prompts_text)`` of newline-joined strings.
        """
        # Index -> name tables, built once rather than once per lookup.
        model_names = list(self.model_to_idx)
        token_names = list(self.token_to_idx)

        model_probs = F.softmax(model_logits, dim=-1)
        top_models = torch.topk(model_probs, k=min(num_models, model_probs.shape[-1]))
        model_lines = [
            f"{model_names[idx]} (confidence: {prob:.2f})"
            for prob, idx in zip(top_models.values[0].tolist(), top_models.indices[0].tolist())
        ]

        # Softmax is monotonic, so ranking the raw logits selects the same
        # tokens in the same order.
        top_tokens = torch.topk(prompt_logits, k=min(num_tokens, prompt_logits.shape[-1]))
        candidates = [token_names[idx] for idx in top_tokens.indices[0].tolist()]

        sample_size = min(tokens_per_prompt, len(candidates))
        if sample_size:
            prompt_lines = [
                " ".join(np.random.choice(candidates, size=sample_size, replace=False))
                for _ in range(num_prompts)
            ]
        else:
            prompt_lines = []

        return "\n".join(model_lines), "\n".join(prompt_lines)

    def get_recommendations(self, image):
        """Return ``(models_text, prompts_text)`` recommendations for ``image``.

        Args:
            image: A PIL image, or anything ``PIL.Image.open`` accepts.
        """
        # Convert uploaded image to PIL if needed
        if not isinstance(image, Image.Image):
            image = Image.open(image)

        # Process image
        inputs = self.processor(images=image, return_tensors="pt").to(self.device)

        # Get model predictions
        self.model.eval()
        with torch.no_grad():
            model_logits, prompt_logits = self.model(inputs)

        return self._format_recommendations(model_logits, prompt_logits)
199
+
200
+ # Create Gradio interface
201
def create_interface():
    """Build the Gradio interface, training the recommender first if needed.

    Returns:
        The configured ``gr.Interface``; the caller is expected to launch it.
    """
    recommender = SDRecommender(max_samples=5000)

    # Train the model if no trained weights exist
    if not os.path.exists("recommender_model.pth"):
        recommender.train()

    def process_image(image):
        """Return the (models, prompts) text pair for an uploaded image."""
        # Submitting with no image hands the callback None; report that to
        # the user instead of failing while trying to open the image.
        if image is None:
            raise gr.Error("Please upload an image first.")
        return recommender.get_recommendations(image)

    interface = gr.Interface(
        fn=process_image,
        inputs=gr.Image(type="pil"),
        outputs=[
            gr.Textbox(label="Recommended Models"),
            gr.Textbox(label="Recommended Prompts"),
        ],
        title="Stable Diffusion Model & Prompt Recommender",
        description="Upload an AI-generated image to get model and prompt recommendations",
    )

    return interface
224
+
225
+ # Launch the interface
226
if __name__ == "__main__":
    # Build the UI (training first when no checkpoint exists) and serve it.
    create_interface().launch()