bgaspra commited on
Commit
f2ca68f
·
verified ·
1 Parent(s): adcad2f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -220
app.py CHANGED
@@ -1,239 +1,128 @@
1
- # app.py
2
- import gradio as gr
3
  import torch
4
- import numpy as np
 
5
  from PIL import Image
6
- from torch import nn
7
- import torch.nn.functional as F
8
  from datasets import load_dataset
9
- from torch.utils.data import Dataset, DataLoader
10
- import os
11
- from tqdm import tqdm
12
- from transformers import AutoProcessor, AutoModelForCausalLM
13
 
14
- class SDDataset(Dataset):
15
- def __init__(self, dataset, processor, model_to_idx, token_to_idx, max_samples=5000):
16
- self.dataset = dataset.select(range(min(max_samples, len(dataset))))
17
- self.processor = processor
18
- self.model_to_idx = model_to_idx
19
- self.token_to_idx = token_to_idx
20
-
21
- def __len__(self):
22
- return len(self.dataset)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
- def __getitem__(self, idx):
25
- item = self.dataset[idx]
26
-
27
- # Process image
28
- image = Image.open(item['image'])
29
- image_inputs = self.processor(images=image, return_tensors="pt")
30
-
31
- # Create model label
32
- model_label = torch.zeros(len(self.model_to_idx))
33
- model_label[self.model_to_idx[item['model_name']]] = 1
34
-
35
- # Create prompt label (multi-hot encoding)
36
- prompt_label = torch.zeros(len(self.token_to_idx))
37
- for token in item['prompt'].split():
38
- if token in self.token_to_idx:
39
- prompt_label[self.token_to_idx[token]] = 1
40
-
41
- return image_inputs, model_label, prompt_label
42
 
43
- class SDRecommenderModel(nn.Module):
44
- def __init__(self, florence_model, num_models, vocab_size):
45
- super().__init__()
46
- self.florence = florence_model
47
- hidden_size = 1024 # Florence-2-large hidden size
48
- self.model_head = nn.Linear(hidden_size, num_models)
49
- self.prompt_head = nn.Linear(hidden_size, vocab_size)
50
-
51
- def forward(self, pixel_values):
52
- # Get Florence embeddings
53
- outputs = self.florence(pixel_values=pixel_values, output_hidden_states=True)
54
- features = outputs.hidden_states[-1].mean(dim=1) # Use mean pooling of last hidden state
55
-
56
- # Generate model and prompt recommendations
57
- model_logits = self.model_head(features)
58
- prompt_logits = self.prompt_head(features)
59
-
60
- return model_logits, prompt_logits
61
 
62
- class SDRecommender:
63
- def __init__(self, max_samples=500):
64
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
65
- print(f"Using device: {self.device}")
66
-
67
- # Load Florence model and processor
68
- print("Loading Florence model and processor...")
69
- self.processor = AutoProcessor.from_pretrained(
70
- "microsoft/Florence-2-large",
71
- trust_remote_code=True
72
- )
73
- self.florence = AutoModelForCausalLM.from_pretrained(
74
- "microsoft/Florence-2-large",
75
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
76
- trust_remote_code=True
77
- ).to(self.device)
78
-
79
- # Load dataset
80
- print("Loading dataset...")
81
- self.dataset = load_dataset("thefcraft/civitai-stable-diffusion-337k", split="train")
82
- self.dataset = self.dataset.select(range(min(max_samples, len(self.dataset))))
83
- print(f"Using {len(self.dataset)} samples from dataset")
84
-
85
- # Create vocabularies for models and tokens
86
- self.model_to_idx = self._create_model_vocab()
87
- self.token_to_idx = self._create_prompt_vocab()
88
-
89
- # Initialize the recommendation model
90
- self.model = SDRecommenderModel(
91
- self.florence,
92
- len(self.model_to_idx),
93
- len(self.token_to_idx)
94
- ).to(self.device)
95
-
96
- # Load trained weights if available
97
- if os.path.exists("recommender_model.pth"):
98
- self.model.load_state_dict(torch.load("recommender_model.pth", map_location=self.device))
99
- print("Loaded trained model weights")
100
- self.model.eval()
101
-
102
- def _create_model_vocab(self):
103
- print("Creating model vocabulary...")
104
- models = set()
105
- for item in self.dataset:
106
- models.add(item["model_name"])
107
- return {model: idx for idx, model in enumerate(sorted(models))}
108
 
109
- def _create_prompt_vocab(self):
110
- print("Creating prompt vocabulary...")
111
- tokens = set()
112
- for item in self.dataset:
113
- for token in item["prompt"].split():
114
- tokens.add(token)
115
- return {token: idx for idx, token in enumerate(sorted(tokens))}
 
 
 
116
 
117
- def train(self, num_epochs=5, batch_size=8, learning_rate=1e-4):
118
- print("Starting training...")
119
-
120
- # Create dataset and dataloader
121
- train_dataset = SDDataset(
122
- self.dataset,
123
- self.processor,
124
- self.model_to_idx,
125
- self.token_to_idx
126
- )
127
- train_loader = DataLoader(
128
- train_dataset,
129
- batch_size=batch_size,
130
- shuffle=True,
131
- num_workers=2
132
- )
133
-
134
- # Setup optimizer
135
- optimizer = torch.optim.AdamW(self.model.parameters(), lr=learning_rate)
136
-
137
- # Training loop
138
- self.model.train()
139
- for epoch in range(num_epochs):
140
- total_loss = 0
141
- progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
142
-
143
- for batch_idx, (images, model_labels, prompt_labels) in enumerate(progress_bar):
144
- # Move everything to device
145
- images = {k: v.to(self.device) for k, v in images.items()}
146
- model_labels = model_labels.to(self.device)
147
- prompt_labels = prompt_labels.to(self.device)
148
-
149
- # Forward pass
150
- model_logits, prompt_logits = self.model(images)
151
-
152
- # Calculate loss
153
- model_loss = F.cross_entropy(model_logits, model_labels)
154
- prompt_loss = F.binary_cross_entropy_with_logits(prompt_logits, prompt_labels)
155
- loss = model_loss + prompt_loss
156
-
157
- # Backward pass
158
- optimizer.zero_grad()
159
- loss.backward()
160
- optimizer.step()
161
-
162
- # Update progress
163
- total_loss += loss.item()
164
- progress_bar.set_postfix({"loss": total_loss / (batch_idx + 1)})
165
-
166
- # Save trained model
167
- torch.save(self.model.state_dict(), "recommender_model.pth")
168
- print("Training completed and model saved")
169
 
170
- def get_recommendations(self, image):
171
- # Convert uploaded image to PIL if needed
172
- if not isinstance(image, Image.Image):
173
- image = Image.open(image)
174
 
175
- # Process image
176
- inputs = self.processor(images=image, return_tensors="pt")
177
- inputs = {k: v.to(self.device) for k, v in inputs.items()}
178
 
179
- # Get model predictions
180
- self.model.eval()
181
- with torch.no_grad():
182
- model_logits, prompt_logits = self.model(inputs)
183
-
184
- # Get top 5 model recommendations
185
- model_probs = F.softmax(model_logits, dim=-1)
186
- top_models = torch.topk(model_probs, k=5)
187
- model_recommendations = [
188
- (list(self.model_to_idx.keys())[idx.item()], prob.item())
189
- for prob, idx in zip(top_models.values[0], top_models.indices[0])
190
- ]
191
-
192
- # Get top tokens for prompt recommendations
193
- prompt_probs = F.softmax(prompt_logits, dim=-1)
194
- top_tokens = torch.topk(prompt_probs, k=20)
195
- recommended_tokens = [
196
- list(self.token_to_idx.keys())[idx.item()]
197
- for idx in top_tokens.indices[0]
198
- ]
199
-
200
- # Create 5 prompt combinations
201
- prompt_recommendations = [
202
- " ".join(np.random.choice(recommended_tokens, size=8, replace=False))
203
- for _ in range(5)
204
- ]
205
-
206
- return (
207
- "\n".join(f"{model} (confidence: {conf:.2f})" for model, conf in model_recommendations),
208
- "\n".join(prompt_recommendations)
209
- )
210
 
211
- # Create Gradio interface
212
- def create_interface():
213
- recommender = SDRecommender(max_samples=5000)
214
 
215
- # Train the model if no trained weights exist
216
- if not os.path.exists("recommender_model.pth"):
217
- recommender.train()
218
 
219
- def process_image(image):
220
- model_recs, prompt_recs = recommender.get_recommendations(image)
221
- return model_recs, prompt_recs
222
 
223
- interface = gr.Interface(
224
- fn=process_image,
225
- inputs=gr.Image(type="pil"),
226
- outputs=[
227
- gr.Textbox(label="Recommended Models"),
228
- gr.Textbox(label="Recommended Prompts")
229
- ],
230
- title="Stable Diffusion Model & Prompt Recommender",
231
- description="Upload an AI-generated image to get model and prompt recommendations",
232
- )
233
 
234
- return interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
 
236
  # Launch the interface
237
- if __name__ == "__main__":
238
- interface = create_interface()
239
- interface.launch()
 
 
 
1
  import torch
2
+ import gradio as gr
3
+ from transformers import AutoProcessor, AutoModelForCausalLM
4
  from PIL import Image
5
+ import pandas as pd
 
6
  from datasets import load_dataset
7
+ from sklearn.metrics.pairwise import cosine_similarity
8
+ import numpy as np
 
 
9
 
10
+ # Load Florence-2 model and processor
11
+ model_name = "microsoft/Florence-2-base"
12
+ device = "cuda" if torch.cuda.is_available() else "cpu"
13
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
14
+
15
+ model = AutoModelForCausalLM.from_pretrained(
16
+ model_name,
17
+ torch_dtype=torch_dtype,
18
+ trust_remote_code=True
19
+ ).to(device)
20
+ processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
21
+
22
+ # Load CivitAI dataset (limited to 1000 samples)
23
+ dataset = load_dataset("thefcraft/civitai-stable-diffusion-337k", split="train[:1000]")
24
+ df = pd.DataFrame(dataset)
25
+
26
+ # Create cache for embeddings to improve performance
27
+ text_embedding_cache = {}
28
+
29
+ def get_image_embedding(image):
30
+ inputs = processor(images=image, return_tensors="pt").to(device, torch_dtype)
31
+ with torch.no_grad():
32
+ outputs = model.get_image_features(**inputs)
33
+ return outputs.cpu().numpy()
34
+
35
+ def get_text_embedding(text):
36
+ if text in text_embedding_cache:
37
+ return text_embedding_cache[text]
38
 
39
+ inputs = processor(text=text, return_tensors="pt").to(device, torch_dtype)
40
+ with torch.no_grad():
41
+ outputs = model.get_text_features(**inputs)
42
+
43
+ embedding = outputs.cpu().numpy()
44
+ text_embedding_cache[text] = embedding
45
+ return embedding
 
 
 
 
 
 
 
 
 
 
 
46
 
47
+ # Pre-compute text embeddings for all prompts in the dataset
48
+ def precompute_embeddings():
49
+ print("Pre-computing text embeddings...")
50
+ for idx, row in df.iterrows():
51
+ if row['prompt'] not in text_embedding_cache:
52
+ _ = get_text_embedding(row['prompt'])
53
+ if idx % 100 == 0:
54
+ print(f"Processed {idx}/1000 embeddings")
55
+ print("Finished pre-computing embeddings")
 
 
 
 
 
 
 
 
 
56
 
57
+ def find_similar_images(uploaded_image, top_k=5):
58
+ # Get embedding for uploaded image
59
+ query_embedding = get_image_embedding(uploaded_image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
+ # Calculate similarities with dataset
62
+ similarities = []
63
+ for idx, row in df.iterrows():
64
+ prompt_embedding = get_text_embedding(row['prompt'])
65
+ similarity = cosine_similarity(query_embedding, prompt_embedding)[0][0]
66
+ similarities.append({
67
+ 'similarity': similarity,
68
+ 'model': row['Model'],
69
+ 'prompt': row['prompt']
70
+ })
71
 
72
+ # Sort by similarity and get top k results
73
+ sorted_results = sorted(similarities, key=lambda x: x['similarity'], reverse=True)
74
+ top_models = []
75
+ top_prompts = []
76
+ seen_models = set()
77
+ seen_prompts = set()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
+ for result in sorted_results:
80
+ if len(top_models) < top_k and result['model'] not in seen_models:
81
+ top_models.append(result['model'])
82
+ seen_models.add(result['model'])
83
 
84
+ if len(top_prompts) < top_k and result['prompt'] not in seen_prompts:
85
+ top_prompts.append(result['prompt'])
86
+ seen_prompts.add(result['prompt'])
87
 
88
+ if len(top_models) == top_k and len(top_prompts) == top_k:
89
+ break
90
+
91
+ return top_models, top_prompts
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
+ def process_image(input_image):
94
+ if input_image is None:
95
+ return "Please upload an image.", "Please upload an image."
96
 
97
+ # Convert to PIL Image if needed
98
+ if not isinstance(input_image, Image.Image):
99
+ input_image = Image.fromarray(input_image)
100
 
101
+ # Get recommendations
102
+ recommended_models, recommended_prompts = find_similar_images(input_image)
 
103
 
104
+ # Format output
105
+ models_text = "Recommended Models:\n" + "\n".join([f"{i+1}. {model}" for i, model in enumerate(recommended_models)])
106
+ prompts_text = "Recommended Prompts:\n" + "\n".join([f"{i+1}. {prompt}" for i, prompt in enumerate(recommended_prompts)])
 
 
 
 
 
 
 
107
 
108
+ return models_text, prompts_text
109
+
110
+ # Pre-compute embeddings when starting the application
111
+ precompute_embeddings()
112
+
113
+ # Create Gradio interface
114
+ iface = gr.Interface(
115
+ fn=process_image,
116
+ inputs=gr.Image(type="pil", label="Upload AI-generated image"),
117
+ outputs=[
118
+ gr.Textbox(label="Recommended Models", lines=6),
119
+ gr.Textbox(label="Recommended Prompts", lines=6)
120
+ ],
121
+ title="AI Image Model & Prompt Recommender",
122
+ description="Upload an AI-generated image to get recommendations for Stable Diffusion models and prompts.",
123
+ examples=[],
124
+ cache_examples=False
125
+ )
126
 
127
  # Launch the interface
128
+ iface.launch()