throaway2854 commited on
Commit
33e672b
·
verified ·
1 Parent(s): 4c9c639

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +367 -64
app.py CHANGED
@@ -1,91 +1,394 @@
1
- ># Import necessary libraries
2
  import gradio as gr
3
- import json
4
  import os
5
  import zipfile
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
- # Define helper functions
 
 
 
 
8
 
9
- def create_dataset(dataset_name):
10
- dataset_path = f'{dataset_name}.zip'
11
- if not os.path.exists(dataset_path):
12
- with zipfile.ZipFile(dataset_path, 'w') as zip_file:
13
- zip_file.writestr('images/', '')
14
- zip_file.writestr('data.jsonl', '')
15
 
16
- return dataset_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
 
 
 
 
 
18
 
19
- def upload_pair(dataset_path, image, prompt):
20
- with zipfile.ZipFile(dataset_path, 'a') as zip_file:
21
- image_path = f'images/{image.name}'
22
- zip_file.writestr(image_path, image.read())
23
- data = {'image': image_path, 'prompt': prompt}
24
- zip_file.writestr('data.jsonl', json.dumps(data) + '\n')
 
 
 
 
25
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
- def edit_prompt(dataset_path, image_path, new_prompt):
28
- with zipfile.ZipFile(dataset_path, 'r') as zip_file:
29
- data = json.load(zip_file.open('data.jsonl'))
30
- for item in data:
31
- if item['image'] == image_path:
32
- item['prompt'] = new_prompt
33
- break
34
 
35
- with zipfile.ZipFile(dataset_path, 'w') as zip_file:
36
- zip_file.writestr('data.jsonl', json.dumps(data))
 
 
 
 
 
 
 
 
 
 
37
 
 
 
 
 
 
38
 
39
- def delete_pair(dataset_path, image_path):
40
- with zipfile.ZipFile(dataset_path, 'r') as zip_file:
41
- data = json.load(zip_file.open('data.jsonl'))
42
- data = [item for item in data if item['image'] != image_path]
 
 
 
 
 
 
 
 
43
 
44
- with zipfile.ZipFile(dataset_path, 'w') as zip_file:
45
- zip_file.writestr('data.jsonl', json.dumps(data))
 
 
 
46
 
 
 
 
 
 
 
 
 
47
 
48
- def download_dataset(dataset_path):
49
- return dataset_path
 
 
 
 
50
 
51
- # Define Gradio application
 
 
 
 
 
52
 
53
- demo = gr.Blocks()
 
 
54
 
55
- with demo:
56
- # Create dataset
57
- dataset_name = gr.Textbox(label='Dataset Name')
58
- create_button = gr.Button('Create Dataset')
59
- create_button.click(create_dataset, inputs=[dataset_name], outputs=[])
 
 
 
 
 
 
 
60
 
61
- # Upload pair
62
- image_upload = gr.File(label='Image')
63
- prompt = gr.Textbox(label='Prompt')
64
- upload_button = gr.Button('Upload Pair')
65
- upload_button.click(upload_pair, inputs=[dataset_name, image_upload, prompt], outputs=[])
66
 
67
- # Edit prompt
68
- image_path = gr.Textbox(label='Image Path')
69
- new_prompt = gr.Textbox(label='New Prompt')
70
- edit_button = gr.Button('Edit Prompt')
71
- edit_button.click(edit_prompt, inputs=[dataset_name, image_path, new_prompt], outputs=[])
 
 
72
 
73
- # Delete pair
74
- delete_button = gr.Button('Delete Pair')
75
- delete_button.click(delete_pair, inputs=[dataset_name, image_path], outputs=[])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
- # Download dataset
78
- download_button = gr.Button('Download Dataset')
79
- download_button.click(download_dataset, inputs=[dataset_name], outputs=[])
 
 
80
 
81
- # Upload dataset
82
- dataset_upload = gr.File(label='Dataset')
83
- upload_dataset_button = gr.Button('Upload Dataset')
84
- upload_dataset_button.click(create_dataset, inputs=[dataset_upload], outputs=[])
85
 
86
- # Horizontal gallery
87
- gallery = gr.Gallery(label='Dataset Gallery')
88
- demo.append(gallery)
 
 
89
 
90
- # Launch Gradio application
91
  demo.launch()
 
1
+
2
  import gradio as gr
 
3
  import os
4
  import zipfile
5
+ import json
6
+ from io import BytesIO
7
+ import base64
8
+ from PIL import Image
9
+ import uuid
10
+ import tempfile
11
+ import numpy as np
12
+
13
+ def save_dataset_to_zip(dataset_name, dataset):
14
+ temp_dir = tempfile.mkdtemp()
15
+ dataset_path = os.path.join(temp_dir, dataset_name)
16
+ os.makedirs(dataset_path, exist_ok=True)
17
+ images_dir = os.path.join(dataset_path, 'images')
18
+ os.makedirs(images_dir, exist_ok=True)
19
+
20
+ annotations = []
21
+ for idx, entry in enumerate(dataset):
22
+ image_data = entry['image']
23
+ prompt = entry['prompt']
24
+
25
+ # Save image to images directory
26
+ image_filename = f"{uuid.uuid4().hex}.png"
27
+ image_path = os.path.join(images_dir, image_filename)
28
+ # Decode the base64 image data
29
+ image = Image.open(BytesIO(base64.b64decode(image_data.split(",")[1])))
30
+ image.save(image_path)
31
+
32
+ # Add annotation
33
+ annotations.append({
34
+ 'file_name': os.path.join('images', image_filename),
35
+ 'text': prompt
36
+ })
37
+
38
+ # Save annotations to JSONL file
39
+ annotations_path = os.path.join(dataset_path, 'annotations.jsonl')
40
+ with open(annotations_path, 'w') as f:
41
+ for ann in annotations:
42
+ f.write(json.dumps(ann) + '\n')
43
+
44
+ # Create a zip file with the dataset_name as the top-level folder
45
+ zip_buffer = BytesIO()
46
+ with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zipf:
47
+ for root, dirs, files in os.walk(dataset_path):
48
+ for file in files:
49
+ abs_file = os.path.join(root, file)
50
+ rel_file = os.path.relpath(abs_file, temp_dir)
51
+ zipf.write(abs_file, rel_file)
52
+
53
+ zip_buffer.seek(0)
54
+ return zip_buffer
55
+
56
+ def load_dataset_from_zip(zip_file_path):
57
+ temp_dir = tempfile.mkdtemp()
58
+ try:
59
+ with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
60
+ zip_ref.extractall(temp_dir)
61
+
62
+ # Get dataset name from zip file name
63
+ dataset_name_guess = os.path.splitext(os.path.basename(zip_file_path))[0]
64
+ dataset_path = os.path.join(temp_dir, dataset_name_guess)
65
+
66
+ if os.path.exists(dataset_path):
67
+ dataset_name = dataset_name_guess
68
+ else:
69
+ # If the dataset_name directory doesn't exist, try to find the top-level directory
70
+ entries = [entry for entry in os.listdir(temp_dir) if os.path.isdir(os.path.join(temp_dir, entry))]
71
+ if entries:
72
+ dataset_name = entries[0]
73
+ dataset_path = os.path.join(temp_dir, dataset_name)
74
+ else:
75
+ # Files are directly in temp_dir
76
+ dataset_name = dataset_name_guess
77
+ dataset_path = temp_dir
78
+
79
+ annotations_path = os.path.join(dataset_path, 'annotations.jsonl')
80
+ dataset = []
81
+
82
+ if os.path.exists(annotations_path):
83
+ with open(annotations_path, 'r') as f:
84
+ for line in f:
85
+ ann = json.loads(line)
86
+ file_name = ann['file_name']
87
+ prompt = ann['text']
88
+ image_path = os.path.join(dataset_path, file_name)
89
+
90
+ # Read image and convert to base64
91
+ with open(image_path, 'rb') as img_f:
92
+ image_bytes = img_f.read()
93
+ encoded = base64.b64encode(image_bytes).decode()
94
+ mime_type = "image/png"
95
+ image_data = f"data:{mime_type};base64,{encoded}"
96
+
97
+ dataset.append({
98
+ 'image': image_data,
99
+ 'prompt': prompt
100
+ })
101
+ else:
102
+ # If annotations file not found
103
+ return None, []
104
+
105
+ return dataset_name, dataset
106
+ except Exception as e:
107
+ print(f"Error loading dataset: {e}")
108
+ return None, []
109
+
110
+ def display_dataset_html(dataset, page_number=0, items_per_page=2):
111
+ if dataset:
112
+ start_idx = page_number * items_per_page
113
+ end_idx = start_idx + items_per_page
114
+ dataset_slice = dataset[start_idx:end_idx]
115
+ html_content = '''
116
+ <div style="display: flex; overflow-x: auto; padding: 10px; border: 1px solid #ccc;">
117
+ '''
118
+ for idx_offset, entry in enumerate(dataset_slice):
119
+ idx = start_idx + idx_offset
120
+ image_data = entry['image']
121
+ prompt = entry['prompt']
122
+ html_content += f"""
123
+ <div style="display: flex; flex-direction: column; align-items: center; margin-right: 20px;">
124
+ <div style="margin-bottom: 5px;">{idx}</div>
125
+ <img src="{image_data}" alt="Image {idx}" style="max-height: 150px;"/>
126
+ <div style="max-width: 150px; word-wrap: break-word; text-align: center;">{prompt}</div>
127
+ </div>
128
+ """
129
+ html_content += '</div>'
130
+ return html_content
131
+ else:
132
+ return "<div>No entries in dataset.</div>"
133
+
134
+ #Interface
135
+ with gr.Blocks() as demo:
136
+ gr.Markdown("<h1 style='text-align: center; margin-bottom: 1px;'>Dataset Creator</h1>")
137
+ gr.Markdown("You must create/upload a dataset before selecting one")
138
+ datasets = gr.State({})
139
+ current_dataset_name = gr.State("")
140
+ current_page_number = gr.State(0)
141
+
142
+ # Top-level components
143
+ with gr.Column():
144
+ dataset_selector = gr.Dropdown(label="Select Dataset", interactive=True)
145
+ message_box = gr.Textbox(interactive=False, label="Message")
146
+
147
+ # Dataset Viewer and Pagination Controls at the Bottom
148
+ with gr.Column():
149
+ gr.Markdown("### Dataset Viewer")
150
+ dataset_html = gr.HTML()
151
+ with gr.Row():
152
+ prev_button = gr.Button("Previous Page")
153
+ next_button = gr.Button("Next Page")
154
+
155
+ # Tabs
156
+ with gr.Tabs():
157
+ with gr.TabItem("Create / Upload Dataset"):
158
+ with gr.Row():
159
+ with gr.Column():
160
+ gr.Markdown("### Create a New Dataset")
161
+ dataset_name_input = gr.Textbox(label="New Dataset Name")
162
+ create_button = gr.Button("Create Dataset")
163
+ with gr.Column():
164
+ gr.Markdown("### Upload Existing Dataset")
165
+ upload_input = gr.File(label="Upload Dataset Zip", type="filepath", file_types=['.zip'])
166
+ upload_button = gr.Button("Upload Dataset")
167
+
168
+ def create_dataset(name, datasets):
169
+ if not name:
170
+ return gr.update(), "Please enter a dataset name."
171
+ if name in datasets:
172
+ return gr.update(), f"Dataset '{name}' already exists."
173
+ datasets[name] = []
174
+ return gr.update(choices=list(datasets.keys()), value=name), f"Dataset '{name}' created."
175
+
176
+ create_button.click(
177
+ create_dataset,
178
+ inputs=[dataset_name_input, datasets],
179
+ outputs=[dataset_selector, message_box]
180
+ )
181
+
182
+ def upload_dataset(zip_file_path, datasets):
183
+ if not zip_file_path:
184
+ return gr.update(), "Please upload a zip file."
185
+ dataset_name, dataset = load_dataset_from_zip(zip_file_path)
186
+ if dataset_name is None:
187
+ return gr.update(), "Failed to load dataset from zip file."
188
+ if dataset_name in datasets:
189
+ return gr.update(), f"Dataset '{dataset_name}' already exists."
190
+ datasets[dataset_name] = dataset
191
+ return gr.update(choices=list(datasets.keys()), value=dataset_name), f"Dataset '{dataset_name}' uploaded."
192
 
193
+ upload_button.click(
194
+ upload_dataset,
195
+ inputs=[upload_input, datasets],
196
+ outputs=[dataset_selector, message_box]
197
+ )
198
 
199
+ with gr.TabItem("Add Entry"):
200
+ with gr.Row():
201
+ image_input = gr.Image(label="Upload Image", type="numpy")
202
+ prompt_input = gr.Textbox(label="Prompt")
203
+ add_button = gr.Button("Add Entry")
 
204
 
205
+ def add_entry(image_data, prompt, current_dataset_name, datasets):
206
+ if not current_dataset_name:
207
+ return datasets, gr.update(), gr.update(), "No dataset selected."
208
+ if image_data is None or not prompt:
209
+ return datasets, gr.update(), gr.update(), "Please provide both an image and a prompt."
210
+ # Convert image_data to base64
211
+ image = Image.fromarray(image_data.astype('uint8'))
212
+ buffered = BytesIO()
213
+ image.save(buffered, format="PNG")
214
+ img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
215
+ img_data = f"data:image/png;base64,{img_str}"
216
+ datasets[current_dataset_name].append({'image': img_data, 'prompt': prompt})
217
+ dataset = datasets[current_dataset_name]
218
+ # Reset page number to 0 and refresh HTML
219
+ page_number = 0
220
+ dataset = datasets[current_dataset_name]
221
+ html_content = display_dataset_html(dataset, page_number=page_number)
222
+ return datasets, page_number, gr.update(value=html_content), f"Entry added to dataset '{current_dataset_name}'."
223
 
224
+ add_button.click(
225
+ add_entry,
226
+ inputs=[image_input, prompt_input, current_dataset_name, datasets],
227
+ outputs=[datasets, current_page_number, dataset_html, message_box]
228
+ )
229
 
230
+ with gr.TabItem("Edit / Delete Entry"):
231
+ with gr.Column():
232
+ selected_image = gr.Image(label="Selected Image", interactive=False, type="numpy")
233
+ selected_prompt = gr.Textbox(label="Current Prompt", interactive=False)
234
+ # Define entry_selector here
235
+ entry_selector = gr.Dropdown(label="Select Entry to Edit/Delete")
236
+ new_prompt_input = gr.Textbox(label="New Prompt (for Edit)")
237
+ with gr.Row():
238
+ edit_button = gr.Button("Edit Entry")
239
+ delete_button = gr.Button("Delete Entry")
240
 
241
+ def update_selected_entry(entry_option, current_dataset_name, datasets):
242
+ if not current_dataset_name or not entry_option:
243
+ return gr.update(), gr.update()
244
+ index = int(entry_option.split(":")[0])
245
+ entry = datasets[current_dataset_name][index]
246
+ image_data = entry['image']
247
+ prompt = entry['prompt']
248
+ # Decode base64 image data to numpy array
249
+ image_bytes = base64.b64decode(image_data.split(",")[1])
250
+ image = Image.open(BytesIO(image_bytes))
251
+ image_array = np.array(image)
252
+ return gr.update(value=image_array), gr.update(value=prompt)
253
 
254
+ entry_selector.change(
255
+ update_selected_entry,
256
+ inputs=[entry_selector, current_dataset_name, datasets],
257
+ outputs=[selected_image, selected_prompt]
258
+ )
 
 
259
 
260
+ def edit_entry(entry_option, new_prompt, current_dataset_name, datasets, current_page_number):
261
+ if not current_dataset_name:
262
+ return datasets, gr.update(), gr.update(), gr.update(), f"No dataset selected."
263
+ if not entry_option or not new_prompt.strip():
264
+ return datasets, gr.update(), gr.update(), gr.update(), f"Please select an entry and provide a new prompt."
265
+ index = int(entry_option.split(":")[0])
266
+ datasets[current_dataset_name][index]['prompt'] = new_prompt
267
+ dataset = datasets[current_dataset_name]
268
+ html_content = display_dataset_html(dataset, page_number=current_page_number)
269
+ # Update entry_selector options
270
+ entry_options = [f"{idx}: {entry['prompt'][:30]}" for idx, entry in enumerate(dataset)]
271
+ return datasets, gr.update(value=html_content), gr.update(choices=entry_options), gr.update(value=""), f"Entry {index} updated."
272
 
273
+ edit_button.click(
274
+ edit_entry,
275
+ inputs=[entry_selector, new_prompt_input, current_dataset_name, datasets, current_page_number],
276
+ outputs=[datasets, dataset_html, entry_selector, new_prompt_input, message_box]
277
+ )
278
 
279
+ def delete_entry(entry_option, current_dataset_name, datasets, current_page_number):
280
+ if not current_dataset_name:
281
+ return datasets, gr.update(), gr.update(), gr.update(), gr.update(), "No dataset selected."
282
+ if not entry_option:
283
+ return datasets, gr.update(), gr.update(), gr.update(), gr.update(), "Please select an entry to delete."
284
+ index = int(entry_option.split(":")[0])
285
+ del datasets[current_dataset_name][index]
286
+ dataset = datasets[current_dataset_name]
287
+ html_content = display_dataset_html(dataset, page_number=current_page_number)
288
+ # Update entry_selector options
289
+ entry_options = [f"{idx}: {entry['prompt'][:30]}" for idx, entry in enumerate(dataset)]
290
+ return datasets, gr.update(value=html_content), gr.update(choices=entry_options), gr.update(value=None), f"Entry {index} deleted."
291
 
292
+ delete_button.click(
293
+ delete_entry,
294
+ inputs=[entry_selector, current_dataset_name, datasets, current_page_number],
295
+ outputs=[datasets, dataset_html, entry_selector, selected_image, message_box]
296
+ )
297
 
298
+ # Function to update entry_selector options
299
+ def update_entry_selector(current_dataset_name, datasets):
300
+ if current_dataset_name in datasets:
301
+ dataset = datasets[current_dataset_name]
302
+ entry_options = [f"{idx}: {entry['prompt'][:30]}" for idx, entry in enumerate(dataset)]
303
+ return gr.update(choices=entry_options)
304
+ else:
305
+ return gr.update(choices=[])
306
 
307
+ # Update entry_selector when dataset is selected
308
+ dataset_selector.change(
309
+ update_entry_selector,
310
+ inputs=[current_dataset_name, datasets],
311
+ outputs=[entry_selector]
312
+ )
313
 
314
+ # Also update entry_selector when an entry is added in "Add Entry" tab
315
+ add_button.click(
316
+ update_entry_selector,
317
+ inputs=[current_dataset_name, datasets],
318
+ outputs=[entry_selector]
319
+ )
320
 
321
+ with gr.TabItem("Download Dataset"):
322
+ download_button = gr.Button("Download Dataset")
323
+ download_output = gr.File(label="Download Zip", interactive=False)
324
 
325
+ def download_dataset(current_dataset_name, datasets):
326
+ if not current_dataset_name:
327
+ return None, "No dataset selected."
328
+ if not datasets[current_dataset_name]:
329
+ return None, "Dataset is empty."
330
+ zip_buffer = save_dataset_to_zip(current_dataset_name, datasets[current_dataset_name])
331
+ # Write zip_buffer to a temporary file
332
+ temp_dir = tempfile.mkdtemp()
333
+ zip_path = os.path.join(temp_dir, f"{current_dataset_name}.zip")
334
+ with open(zip_path, 'wb') as f:
335
+ f.write(zip_buffer.getvalue())
336
+ return zip_path, f"Dataset '{current_dataset_name}' is ready for download."
337
 
338
+ download_button.click(
339
+ download_dataset,
340
+ inputs=[current_dataset_name, datasets],
341
+ outputs=[download_output, message_box]
342
+ )
343
 
344
+ def select_dataset(dataset_name, datasets):
345
+ if dataset_name in datasets:
346
+ dataset = datasets[dataset_name]
347
+ html_content = display_dataset_html(dataset, page_number=0)
348
+ return dataset_name, 0, gr.update(value=html_content), f"Dataset '{dataset_name}' selected."
349
+ else:
350
+ return "", 0, gr.update(value="<div>Select a dataset.</div>"), ""
351
 
352
+ dataset_selector.change(
353
+ select_dataset,
354
+ inputs=[dataset_selector, datasets],
355
+ outputs=[current_dataset_name, current_page_number, dataset_html, message_box]
356
+ )
357
+
358
+ def change_page(action, current_page_number, datasets, current_dataset_name):
359
+ if not current_dataset_name:
360
+ return current_page_number, gr.update(), "No dataset selected."
361
+ dataset = datasets[current_dataset_name]
362
+ total_pages = (len(dataset) - 1) // 5 + 1
363
+ if action == "next":
364
+ if current_page_number + 1 < total_pages:
365
+ current_page_number += 1
366
+ elif action == "prev":
367
+ if current_page_number > 0:
368
+ current_page_number -= 1
369
+ html_content = display_dataset_html(dataset, page_number=current_page_number)
370
+ return current_page_number, gr.update(value=html_content), ""
371
+
372
+ prev_button.click(
373
+ fn=lambda current_page_number, datasets, current_dataset_name: change_page("prev", current_page_number, datasets, current_dataset_name),
374
+ inputs=[current_page_number, datasets, current_dataset_name],
375
+ outputs=[current_page_number, dataset_html, message_box]
376
+ )
377
 
378
+ next_button.click(
379
+ fn=lambda current_page_number, datasets, current_dataset_name: change_page("next", current_page_number, datasets, current_dataset_name),
380
+ inputs=[current_page_number, datasets, current_dataset_name],
381
+ outputs=[current_page_number, dataset_html, message_box]
382
+ )
383
 
384
+ # Initialize dataset_selector
385
+ def initialize_components(datasets):
386
+ return gr.update(choices=list(datasets.keys()))
 
387
 
388
+ demo.load(
389
+ initialize_components,
390
+ inputs=[datasets],
391
+ outputs=[dataset_selector]
392
+ )
393
 
 
394
  demo.launch()