taesiri committed on
Commit
04e8185
1 Parent(s): d7d9d3e
Files changed (2)
  1. app.py +337 -0
  2. utils.py +154 -0
app.py ADDED
@@ -0,0 +1,337 @@
+ import gradio as gr
+ import base64
+ import json
+ import os
+ import shutil
+ import uuid
+ import glob
+ from huggingface_hub import CommitScheduler, HfApi, snapshot_download
+ from pathlib import Path
+ import git
+ from datasets import Dataset, Features, Value, Sequence, Image as ImageFeature
+ import threading
+ import time
+ from utils import process_and_push_dataset
+ from datasets import load_dataset
+
+ api = HfApi(token=os.environ["HF_TOKEN"])
+
+ VALID_DATASET = load_dataset("taesiri/IERv2-Subset", split="train")
+
+ VALID_DATASET_POST_IDS = (
+     load_dataset("taesiri/IERv2-Subset", split="train", columns=["post_id"])
+     .to_pandas()["post_id"]
+     .tolist()
+ )
+
+ POST_ID_TO_ID_MAP = {post_id: idx for idx, post_id in enumerate(VALID_DATASET_POST_IDS)}
+
+ DATASET_REPO = "taesiri/AIImageEditingResults_Intemediate"
+ FINAL_DATASET_REPO = "taesiri/AIImageEditingResults"
+
+
+ # Download existing data from hub
+ def sync_with_hub():
+     """
+     Synchronize local data with the hub by cloning the dataset repo
+     """
+     print("Starting sync with hub...")
+     data_dir = Path("./data")
+     if data_dir.exists():
+         # Backup existing data
+         backup_dir = Path("./data_backup")
+         if backup_dir.exists():
+             shutil.rmtree(backup_dir)
+         shutil.copytree(data_dir, backup_dir)
+
+     # Clone/pull latest data from hub
+     repo_url = f"https://huggingface.co/datasets/{DATASET_REPO}"
+     hub_data_dir = Path("hub_data")
+
+     if hub_data_dir.exists():
+         # If repo exists, do a git pull
+         print("Pulling latest changes...")
+         repo = git.Repo(hub_data_dir)
+         origin = repo.remotes.origin
+         origin.pull()
+     else:
+         # Clone the repo
+         print("Cloning repository...")
+         git.Repo.clone_from(repo_url, hub_data_dir)
+
+     # Merge hub data with local data
+     hub_data_source = hub_data_dir / "data"
+     if hub_data_source.exists():
+         # Create data dir if it doesn't exist
+         data_dir.mkdir(exist_ok=True)
+
+         # Copy files from hub
+         for item in hub_data_source.glob("*"):
+             if item.is_dir():
+                 dest = data_dir / item.name
+                 if not dest.exists():  # Only copy if it doesn't exist locally
+                     shutil.copytree(item, dest)
+
+     # Clean up cloned repo
+     if hub_data_dir.exists():
+         shutil.rmtree(hub_data_dir)
+     print("Finished syncing with hub!")
+
+
+ scheduler = CommitScheduler(
+     repo_id=DATASET_REPO,
+     repo_type="dataset",
+     folder_path="./data",
+     path_in_repo="data",
+     every=1,
+ )
+
+
+ def load_question_data(question_id):
+     """
+     Load a specific question's data
+     Returns a tuple of all form fields
+     """
+     if not question_id:
+         return [None] * 11  # Reduced number of fields
+
+     # Extract the ID part before the colon from the dropdown selection
+     question_id = (
+         question_id.split(":")[0].strip() if ":" in question_id else question_id
+     )
+
+     json_path = os.path.join("./data", question_id, "question.json")
+     if not os.path.exists(json_path):
+         print(f"Question file not found: {json_path}")
+         return [None] * 11
+
+     try:
+         with open(json_path, "r", encoding="utf-8") as f:
+             data = json.loads(f.read().strip())
+
+         # Load images
+         def load_image(image_path):
+             if not image_path:
+                 return None
+             full_path = os.path.join(
+                 "./data", question_id, os.path.basename(image_path)
+             )
+             return full_path if os.path.exists(full_path) else None
+
+         question_images = data.get("question_images", [])
+         rationale_images = data.get("rationale_images", [])
+
+         return [
+             (
+                 ",".join(data["question_categories"])
+                 if isinstance(data["question_categories"], list)
+                 else data["question_categories"]
+             ),
+             data["question"],
+             data["final_answer"],
+             data.get("rationale_text", ""),
+             load_image(question_images[0] if question_images else None),
+             load_image(question_images[1] if len(question_images) > 1 else None),
+             load_image(question_images[2] if len(question_images) > 2 else None),
+             load_image(question_images[3] if len(question_images) > 3 else None),
+             load_image(rationale_images[0] if rationale_images else None),
+             load_image(rationale_images[1] if len(rationale_images) > 1 else None),
+             question_id,
+         ]
+     except Exception as e:
+         print(f"Error loading question {question_id}: {str(e)}")
+         return [None] * 11
+
+
+ def load_post_image(post_id):
+     if not post_id:
+         return [None] * 21  # source image + 10 pairs of (image, text)
+
+     idx = POST_ID_TO_ID_MAP[post_id]
+     source_image = VALID_DATASET[idx]["image"]
+
+     # Load existing responses if any
+     post_folder = os.path.join("./data", str(post_id))
+     metadata_path = os.path.join(post_folder, "metadata.json")
+
+     if os.path.exists(metadata_path):
+         with open(metadata_path, "r") as f:
+             metadata = json.load(f)
+
+         # Initialize response data
+         responses = [(None, "")] * 10
+
+         # Fill in existing responses
+         for response in metadata["responses"]:
+             idx = response["response_id"]
+             if idx < 10:  # Ensure we don't exceed our UI limit
+                 image_path = os.path.join(post_folder, response["image_path"])
+                 responses[idx] = (image_path, response["answer_text"])
+
+         # Flatten responses for output
+         flat_responses = [item for pair in responses for item in pair]
+         return [source_image] + flat_responses
+
+     # If no existing responses, return source image and empty responses
+     return [source_image] + [None] * 20
+
+
+ def generate_json_files(source_image, responses, post_id):
+     """
+     Save the source image and multiple responses to the data directory
+
+     Args:
+         source_image: Path to the source image
+         responses: List of (image, answer) tuples
+         post_id: The post ID from the dataset
+     """
+     # Create parent data folder if it doesn't exist
+     parent_data_folder = "./data"
+     os.makedirs(parent_data_folder, exist_ok=True)
+
+     # Create/clear post_id folder
+     post_folder = os.path.join(parent_data_folder, str(post_id))
+     if os.path.exists(post_folder):
+         shutil.rmtree(post_folder)
+     os.makedirs(post_folder)
+
+     # Save source image
+     source_image_path = os.path.join(post_folder, "source_image.png")
+     if isinstance(source_image, str):
+         shutil.copy2(source_image, source_image_path)
+     else:
+         source_image.save(source_image_path)  # fallback for in-memory (PIL) images
+
+     # Create responses data
+     responses_data = []
+     for idx, (response_image, answer_text) in enumerate(responses):
+         if response_image and answer_text:  # Only process if both image and text exist
+             response_folder = os.path.join(post_folder, f"response_{idx}")
+             os.makedirs(response_folder)
+
+             # Save response image
+             response_image_path = os.path.join(response_folder, "response_image.png")
+             if isinstance(response_image, str):
+                 shutil.copy2(response_image, response_image_path)
+             else:
+                 response_image.save(response_image_path)  # fallback for in-memory (PIL) images
+
+             # Add to responses data
+             responses_data.append(
+                 {
+                     "response_id": idx,
+                     "answer_text": answer_text,
+                     "image_path": f"response_{idx}/response_image.png",
+                 }
+             )
+
+     # Create metadata JSON
+     metadata = {
+         "post_id": post_id,
+         "source_image": "source_image.png",
+         "responses": responses_data,
+     }
+
+     # Save metadata
+     with open(os.path.join(post_folder, "metadata.json"), "w", encoding="utf-8") as f:
+         json.dump(metadata, f, ensure_ascii=False, indent=2)
+
+     return post_folder
+
+
+ # Build the Gradio app
+ with gr.Blocks() as demo:
+     gr.Markdown("# Image Response Collector")
+
+     # Source image selection at the top
+     with gr.Column():
+         post_id_dropdown = gr.Dropdown(
+             label="Select Post ID to Load Image",
+             choices=VALID_DATASET_POST_IDS,
+             type="value",
+             allow_custom_value=False,
+         )
+         source_image = gr.Image(label="Source Image", type="filepath")
+
+     # Responses in tabs
+     with gr.Tabs() as response_tabs:
+         responses = []
+         for i in range(10):
+             with gr.Tab(f"Response {i+1}"):
+                 img = gr.Image(label=f"Response Image {i+1}", type="filepath")
+                 txt = gr.Textbox(label=f"Model Name {i+1}", lines=2)
+                 responses.append((img, txt))
+
+     with gr.Row():
+         submit_btn = gr.Button("Submit All Responses")
+         clear_btn = gr.Button("Clear Form")
+
+     def submit_responses(source_img, post_id, *response_data):
+         if not source_img:
+             gr.Warning("Please select a source image first!")
+             return
+
+         if not post_id:
+             gr.Warning("Please select a post ID first!")
+             return
+
+         # Convert flat response_data into pairs of (image, text)
+         response_pairs = list(zip(response_data[::2], response_data[1::2]))
+
+         # Filter out empty responses
+         valid_responses = [
+             (img, txt) for img, txt in response_pairs if img is not None and txt
+         ]
+
+         if not valid_responses:
+             gr.Warning("Please provide at least one response (image + text)!")
+             return
+
+         generate_json_files(source_img, valid_responses, post_id)
+         gr.Info("Responses saved successfully! 🎉")
+
+     def clear_form():
+         outputs = [None] * (1 + 20)  # 1 source image + 10 pairs of (image, text)
+         return outputs
+
+     # Connect components
+     post_id_dropdown.change(
+         fn=load_post_image,
+         inputs=[post_id_dropdown],
+         outputs=[source_image] + [comp for pair in responses for comp in pair],
+     )
+
+     submit_inputs = [source_image, post_id_dropdown] + [
+         comp for pair in responses for comp in pair
+     ]
+     submit_btn.click(fn=submit_responses, inputs=submit_inputs)
+
+     clear_outputs = [source_image] + [comp for pair in responses for comp in pair]
+     clear_btn.click(fn=clear_form, outputs=clear_outputs)
+
+
+ def process_thread():
+     while True:
+         try:
+             pass
+             # process_and_push_dataset(
+             #     "./data",
+             #     FINAL_DATASET_REPO,
+             #     token=os.environ["HF_TOKEN"],
+             #     private=True,
+             # )
+         except Exception as e:
+             print(f"Error in process thread: {e}")
+         time.sleep(120)  # Sleep for 2 minutes
+
+
+ if __name__ == "__main__":
+     print("Initializing app...")
+     sync_with_hub()  # Sync before launching the app
+     print("Starting Gradio interface...")
+
+     # Start the processing thread when the app starts
+     processing_thread = threading.Thread(target=process_thread, daemon=True)
+     processing_thread.start()
+
+     demo.launch()
utils.py ADDED
@@ -0,0 +1,154 @@
+ import os
+ import json
+ import pandas as pd
+ from pathlib import Path
+ from datasets import Dataset, Features, Value, Sequence, Image as ImageFeature
+
+
+ def process_and_push_dataset(
+     data_dir: str, hub_repo: str, token: str, private: bool = True
+ ):
+     """
+     Process local dataset files and push to Hugging Face Hub.
+
+     Args:
+         data_dir (str): Path to the data directory containing submission folders
+         hub_repo (str): Name of the Hugging Face repository to push to
+         token (str): Hugging Face access token used to authenticate the push
+         private (bool): Whether to make the pushed dataset private
+
+     Returns:
+         datasets.Dataset: The processed dataset
+     """
+     # List to store all records
+     all_records = []
+
+     # Walk through all subdirectories in data_dir
+     for root, dirs, files in os.walk(data_dir):
+         for file in files:
+             if file == "question.json":
+                 file_path = Path(root) / file
+                 try:
+                     # Read the JSON file
+                     with open(file_path, "r", encoding="utf-8") as f:
+                         record = json.load(f)
+
+                     # Get the folder path for this record
+                     folder_path = os.path.dirname(file_path)
+
+                     # Fix image paths to include full path
+                     if "question_images" in record:
+                         record["question_images"] = [
+                             str(Path(folder_path) / img_path)
+                             for img_path in record["question_images"]
+                             if img_path
+                         ]
+
+                     if "rationale_images" in record:
+                         record["rationale_images"] = [
+                             str(Path(folder_path) / img_path)
+                             for img_path in record["rationale_images"]
+                             if img_path
+                         ]
+
+                     # Flatten author_info dictionary
+                     author_info = record.pop("author_info", {})
+                     record.update(
+                         {f"author_{k}": v for k, v in author_info.items()}
+                     )
+
+                     # Add the record
+                     all_records.append(record)
+                 except Exception as e:
+                     print(f"Error processing {file_path}: {e}")
+
+     # Convert to DataFrame
+     df = pd.DataFrame(all_records)
+
+     # Sort by custom_id for consistency
+     if not df.empty and "custom_id" in df.columns:
+         df = df.sort_values("custom_id")
+
+     # Ensure all required columns exist with default values
+     required_columns = {
+         "custom_id": "",
+         "author_name": "",
+         "author_email_address": "",
+         "author_institution": "",
+         "question_categories": [],
+         "question": "",
+         "question_images": [],
+         "final_answer": "",
+         "rationale_text": "",
+         "rationale_images": [],
+         "image_attribution": "",
+         "subquestions_1_text": "",
+         "subquestions_1_answer": "",
+         "subquestions_2_text": "",
+         "subquestions_2_answer": "",
+         "subquestions_3_text": "",
+         "subquestions_3_answer": "",
+         "subquestions_4_text": "",
+         "subquestions_4_answer": "",
+         "subquestions_5_text": "",
+         "subquestions_5_answer": "",
+     }
+
+     for col, default_value in required_columns.items():
+         if col not in df.columns:
+             df[col] = [default_value] * len(df)  # broadcast the default to every row
+
+     # Define features
+     features = Features(
+         {
+             "custom_id": Value("string"),
+             "question": Value("string"),
+             "question_images": Sequence(ImageFeature()),
+             "question_categories": Sequence(Value("string")),
+             "final_answer": Value("string"),
+             "rationale_text": Value("string"),
+             "rationale_images": Sequence(ImageFeature()),
+             "image_attribution": Value("string"),
+             "subquestions_1_text": Value("string"),
+             "subquestions_1_answer": Value("string"),
+             "subquestions_2_text": Value("string"),
+             "subquestions_2_answer": Value("string"),
+             "subquestions_3_text": Value("string"),
+             "subquestions_3_answer": Value("string"),
+             "subquestions_4_text": Value("string"),
+             "subquestions_4_answer": Value("string"),
+             "subquestions_5_text": Value("string"),
+             "subquestions_5_answer": Value("string"),
+             "author_name": Value("string"),
+             "author_email_address": Value("string"),
+             "author_institution": Value("string"),
+         }
+     )
+
+     # Convert DataFrame to dict of lists (Hugging Face Dataset format)
+     dataset_dict = {col: df[col].tolist() for col in features.keys()}
+
+     # Create Dataset directly from dict
+     dataset = Dataset.from_dict(dataset_dict, features=features)
+
+     # Push to hub
+     dataset.push_to_hub(hub_repo, private=private, max_shard_size="200MB", token=token)
+
+     print("\nDataset Statistics:")
+     print(f"Total number of submissions: {len(dataset)}")
+     print(f"\nSuccessfully pushed dataset to {hub_repo}")
+
+     return dataset
+
+
+ def save_metadata(post_id, metadata):
+     # Create directory named after post_id
+     directory = os.path.join("data", str(post_id))
+     os.makedirs(directory, exist_ok=True)
+
+     # Add post_id to metadata
+     metadata["post_id"] = post_id
+
+     # Save metadata to JSON file
+     metadata_path = os.path.join(directory, "metadata.json")
+     with open(metadata_path, "w") as f:
+         json.dump(metadata, f, indent=4)