pbernabeu commited on
Commit
1051963
·
1 Parent(s): ec3621f

Release Model

Browse files
README.md CHANGED
@@ -1,3 +1,256 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ datasets:
4
+ - HuggingFaceM4/COCO
5
+ - ehristoforu/dalle-3-images
6
+ - poloclub/diffusiondb
7
+ - ehristoforu/midjourney-images
8
+ - nateraw/midjourney-texttoimage
9
+ - duchaiten/duchaiten-realistic-sdxl
10
+ tags:
11
+ - vision
12
+ - image-classification
13
+ pipeline_tag: multi-class-image-classification
14
+ ---
15
+
16
+ # SuSy - Synthetic Image Detector
17
+
18
+ ## Model Details
19
+
20
+ <!-- Provide a longer summary of what this model is. -->
21
+
22
+ SuSy is a Spatial-Based Synthetic Image Detection and Recognition Model, designed and trained to detect synthetic images and attribute them to a generative model (i.e., two StableDiffusion models, two Midjourney versions and DALL·E 3). The model takes image patches of size 224x224 as input, and outputs the probability of the image being authentic or having been created by each of the aforementioned generative models.
23
+
24
+ ![model-architecture](model_architecture.png)
25
+
26
+ The model is based on a CNN architecture and is trained using a supervised learning approach. It's design is based on [previous work](https://upcommons.upc.edu/handle/2117/395959), originally intended for video superresolution detection, adapted here for the tasks of synthetic image detection and recognition. The architecture consists of two modules: a feature extractor and a multi-layer perceptron (MLP), as it's quite light weight. SuSy has a total of 12.7M parameters, with the feature extractor accounting for 12.5M parameters and the MLP accounting for the remaining 197K.
27
+
28
+ The CNN feature extractor consists of five stages following a [ResNet-18](https://arxiv.org/abs/1512.03385) scheme. The output of each of the blocks is used as input for various bottleneck modules that are arranged in a staircase pattern. The bottleneck modules consist of three 2D convolutional layers. Each level of bottlenecks takes input at a later stage than the previous level, and each bottleneck module takes input from the current stage and, except the first bottleneck of each level, from the previous bottleneck module.
29
+
30
+ The outputs of each level of bottlenecks and stage 4 are passed to a 2D adaptative average pooling layer and then concatenated to form the feature map feeding the MLP. The MLP consists of three fully connected layers with 512, 256 and 256 units, respectively. Between each layer, a dropout layer (rate of 0.5) prevents overfitting. The output of the MLP has 6 units, corresponding to the number of classes in the dataset (5 synthetic models and 1 real image class).
31
+
32
+ The model can be used as a detector by either taking the class with the highest probability as the output or summing the probabilities of the synthetic classes and comparing them to the real class. The model can also be used as an recognition model by taking the class with the highest probability as the output.
33
+
34
+ ### Model Description
35
+
36
+ - **Developed by:** Pablo Bernabeu, Enrique Lopez and Dario Garcia-Gasulla from [HPAI](https://hpai.bsc.es/)
37
+ - **Model type:** Spatial-Based Synthetic Image Detection and Recognition Convolutional Neural Network
38
+ - **License:** [Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0)
39
+
40
+ ## Uses
41
+
42
+ This model can be used to detect synthetic images in a scalable manner, thanks to its small size. Since it operates on patches of 224x224, a moving window should be implemented in inference when applied on larger inputs (the most likely scenario, and the one it was trained under). This also enables the capacity for synthetic content localization within a high resolution input.
43
+
44
+ Any individual or organization seeking for support on the identification of synthetic content can use this model. However, it should not be used as the only source of evidence, particularly when applied to inputs produced by generative models not included in its training (see details in Training Data below).
45
+
46
+ ### Intended Uses
47
+
48
+ Intended uses include the following:
49
+
50
+ * Detection of authentic and synthetic images
51
+ * Attribution of synthetic images to their generative model (if included in the training data)
52
+ * Localization of image patches likely to be synthetic or tampered.
53
+
54
+ ### Out-of-Scope Uses
55
+
56
+ Out-of-scope uses include the following:
57
+
58
+ * Detection of manually edited images using traditional tools.
59
+ * Detection of images automatically downscaled and/or upscaled. These are considered as non-synthetic samples in the model training phase.
60
+ * Detection of inpainted images.
61
+ * Detection of synthetic vs manually crafted illustrations. The model is trained only on photorealistic samples.
62
+ * Attribution of synthetic images to their generative model if the model was not included in the training data. Although some generalization capabilities are expected, reliability in this case cannot be estimated.
63
+
64
+ ## Bias, Risks, and Limitations
65
+
66
+ The model may be biased in the following ways:
67
+
68
+ * The model may be biased towards the training data, which may not be representative of all authentic and synthetic images. Particularly for the class of real world images, which were obtained from a single source.
69
+ * The model may be biased towards the generative models included in the training data, which may not be representative of all possible generative models. Particularly new ones, since all models included were released between 2022 and 2023.
70
+ * The model may be biased towards certain type of images or contents. While it is trained using roughly 18K synthetic images, no assessment was made on which domains and profiles are included in those.
71
+
72
+ The model has the following technical limitations:
73
+
74
+ * The performance of the model may be influenced by transformations and editions performed on the images. While the model was trained on some alterations (JPEG compression, downscaling, and downscaling+upscaling) there are other alterations applicable to images that could reduce the model accuracy.
75
+ * The model will not be able to attribute synthetic images to their generative model if the model was not included in the training data.
76
+ * The model is trained on patches with high gray-level contrast. For images composed entirely by low contrast regions, the model may not work as expected.
77
+
78
+ ### Recommendations
79
+
80
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
81
+
82
+ ## How to Get Started with the Model
83
+
84
+ Use the code below to get started with the model.
85
+
86
+ ```python
87
+ import torch
88
+ from PIL import Image
89
+ from torchvision import transforms
90
+
91
+ # Load the model
92
+ model = torch.jit.load("SuSy.pt")
93
+
94
+ # Load patch
95
+ patch = Image.open("midjourney-images-example-patch0.png")
96
+
97
+ # Transform patch to tensor
98
+ patch = transforms.PILToTensor()(patch).unsqueeze(0) / 255.
99
+
100
+ # Predict patch
101
+ model.eval()
102
+ with torch.no_grad():
103
+ preds = model(patch)
104
+
105
+ print(preds)
106
+ ```
107
+
108
+ See `test_image.py` and `test_patch.py` for other examples on how to use the model.
109
+
110
+ ## Training Details
111
+
112
+ ### Training Data
113
+
114
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
115
+
116
+ | Dataset | Year | Test | Train |
117
+ |:-----------------:|:----:|------:|------:|
118
+ | COCO | 2017 | 1,234 | 4,201 |
119
+ | dalle-3-images | 2023 | 330 | 1,317 |
120
+ | diffusiondb | 2022 | 1,234 | 4,201 |
121
+ | midjourney-images | 2023 | 246 | 980 |
122
+ | midjourney-tti | 2022 | 906 | 3,624 |
123
+ | realisticSDXL | 2023 | 1,234 | 4,201 |
124
+
125
+ #### Authentic Images
126
+
127
+ - [COCO](https://cocodataset.org/)
128
+
129
+ We use a random subset of the COCO dataset, containing 5,435 images, for the authentic images in our training dataset. The partitions are made respecting the original COCO splits, with 4,201 images in the training partition and 1,234 in the test partition.
130
+
131
+ #### Synthetic Images
132
+
133
+ - [dalle-3-images](https://huggingface.co/datasets/ehristoforu/dalle-3-images)
134
+ - [diffusiondb](https://poloclub.github.io/diffusiondb/)
135
+ - [midjourney-images](https://huggingface.co/datasets/ehristoforu/midjourney-images)
136
+ - [midjourney-texttoimage](https://www.kaggle.com/datasets/succinctlyai/midjourney-texttoimage)
137
+ - [realistic-SDXL](https://huggingface.co/datasets/DucHaiten/DucHaiten-realistic-SDXL)
138
+
139
+ For the diffusiondb dataset, we use a random subset of 5,435 images, with 4,201 in the training partition and 1,234 in the test partition. We use only the realistic images from the realisticSDXL dataset, with images in the realistic-2.2 split in our training data and the realistic-1 split for our test partition. The remaining datasets are used in their entirety, with 80% of the images in the training partition and 20% in the test partition.
140
+
141
+ ### Training Procedure
142
+
143
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
144
+
145
+ #### Preprocessing
146
+
147
+ **Patch Extraction**
148
+
149
+ To prepare the training data, we extract 240x240 patches from the images, minimizing the overlap between them. We then select the most informative patches by calculating the gray-level co-occurrence matrix (GLCM) for each patch. Given the GLCM, we calculate the contrast and select the five patches with the highest contrast. These patches are then passed to the model in their original RGB format and cropped to 224x224.
150
+
151
+ **Data Augmentation**
152
+
153
+ | Technique | Probability | Other Parameters |
154
+ |--------------------------|:-----------:|---------------------------------------------------------------------------------------------------------|
155
+ | HorizontalFlip | 0.35 | - |
156
+ | RandomBrightnessContrast | 0.50 | brightness\_limit=0.2 contrast\_limit=0.2 |
157
+ | RandomGamma | 0.50 | gamma\_limit=(80, 120) |
158
+ | CoarseDropout | 0.50 | min\_holes=1, max\_holes=3 min\_height=64, max\_height=100, min\_width=64, max\_width=100 fill\_value=0 |
159
+
160
+
161
+ #### Training Hyperparameters
162
+
163
+ - Loss Function: Cross-Entropy Loss
164
+ - Optimizer: Adam
165
+ - Learning Rate: 0.0001
166
+ - Weight Decay: 0
167
+ - Scheduler: ReduceLROnPlateau
168
+ - Factor: 0.1
169
+ - Patience: 4
170
+ - Batch Size: 128
171
+ - Epochs: 50
172
+ - Early Stopping: 8
173
+
174
+ ## Evaluation
175
+
176
+ <!-- This section describes the evaluation protocols and provides the results. -->
177
+
178
+ ### Testing Data, Factors & Metrics
179
+
180
+ #### Testing Data
181
+
182
+ <!-- This should link to a Dataset Card if possible. -->
183
+
184
+ - Test Split of our Training Dataset
185
+ - Synthetic Images in the Wild: Dataset containing 210 Authentic and Synthetic Images obtained from Social Media Platforms
186
+ - [Flickr 30k Dataset](https://www.kaggle.com/datasets/hsankesara/flickr-image-dataset)
187
+
188
+ #### Metrics
189
+
190
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
191
+
192
+ - Accuracy: The proportion of correctly classified images.
193
+ - F1 Score: The harmonic mean of precision and recall.
194
+
195
+ ### Results
196
+
197
+ <!-- This section provides the results of the evaluation. -->
198
+
199
+ #### Test Split
200
+
201
+ Task | Detection F1 Score | Recognition F1 Score
202
+ --- | --- | ---
203
+ Original Images | 0.9867 | 0.9041
204
+ JPEG Compressed Images | 0.9918 | 0.9141
205
+ Downscaled Images | 0.9761 | 0.7470
206
+ Downscaled+Upscaled Images | 0.9868 | 0.8266
207
+
208
+ #### Synthetic Images in the Wild
209
+
210
+ 79.55% Detection Accuracy
211
+
212
+ #### Flickr30k Dataset
213
+
214
+ 99.19% Detection Accuracy
215
+
216
+ ### Summary
217
+
218
+ The model obtains performs well in the test split, with high detection and recognition F1 scores. The model shows robustness to the JPEG compressed images for both tasks while the performance in the downscaled and rescaled images suffers in the recognition task, but the detection task remains stable.
219
+
220
+ The model is also evaluated in our Synthetic Images in the Wild dataset, which contains 220 images obtained from social media platforms, with 121 real images and 99 AI-generated images. The difficuly of this dataset lies in the fact that the images are uploaded to social media by a wide range of users, so the images may have different resolutions, lighting conditions and quality, additionally they may have been edited or compressed. Regarding the synthetic images, the generation process is unknown, so the model has to generalize to unseen generative models. The dataset was tested by 10 human evaluators, which achieved an average detection accuracy of 72.22% and a best detection accuracy of 78.73%. The model achieves a detection accuracy of 79.55% in this dataset at its best threshold.
221
+
222
+ Finally, the model shows excellent performance in the Flickr30k dataset. This dataset contains authentic images, so it serves the purpose of testing the number of false positives generated by the model. The model achieves a detection accuracy of 99.19% in this dataset at its best threshold.
223
+
224
+ ## Environmental Impact
225
+
226
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
227
+
228
+ - **Hardware Type:** 2xH100
229
+ - **Hours used:** 15
230
+ - **Hardware Provider:** Barcelona Supercomputing Center (BSC)
231
+ - **Compute Region:** Spain
232
+ - **Carbon Emitted:** 2.11kg
233
+
234
+ ## Citation
235
+
236
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
237
+
238
+ **BibTeX:**
239
+
240
+ ```bibtex
241
+ @thesis{bernabeu2024stair,
242
+ title={Detecting and Attributing AI-Generated Images with Machine Learning},
243
+ author={Bernabeu Pérez, Pablo},
244
+ school={UPC, Facultat d'Informàtica de Barcelona, Departament de Ciències de la Computació},
245
+ year={2024},
246
+ month={06}
247
+ }
248
+ ```
249
+
250
+ ## Model Card Authors
251
+
252
+ [Pablo Bernabeu](https://huggingface.co/pabberpe) and [Dario Garcia-Gasulla](https://huggingface.co/dariog)
253
+
254
+ ## Model Card Contact
255
+
256
+ For further inquiries, please contact [HPAI](mailto:[email protected])
SuSy.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0f2916d1fe0967380340860a0d046f625fd4bb34359157bc318707770a002ce9
3
+ size 50810328
midjourney-images-example-patch0.png ADDED
midjourney-images-example.jpg ADDED
model_architecture.png ADDED
test_image.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+ import torch
4
+ from PIL import Image
5
+ from skimage.feature import graycomatrix, graycoprops
6
+ from torchvision import transforms
7
+
8
+ # Load the model
9
+ model = torch.jit.load("SuSy.pt")
10
+
11
+ # Load the image
12
+ image = Image.open("midjourney-images-example.jpg")
13
+
14
+ # Set Parameters
15
+ top_k_patches = 5
16
+ patch_size = 224
17
+
18
+ # Get the image dimensions
19
+ width, height = image.size
20
+
21
+ # Calculate the number of patches
22
+ num_patches_x = width // patch_size
23
+ num_patches_y = height // patch_size
24
+
25
+ # Divide the image in patches
26
+ patches = np.zeros((num_patches_x * num_patches_y, patch_size, patch_size, 3), dtype=np.uint8)
27
+ for i in range(num_patches_x):
28
+ for j in range(num_patches_y):
29
+ x = i * patch_size
30
+ y = j * patch_size
31
+ patch = image.crop((x, y, x + patch_size, y + patch_size))
32
+ patches[i * num_patches_y + j] = np.array(patch)
33
+
34
+ # Compute the most relevant patches (optional)
35
+ dissimilarity_scores = []
36
+ for patch in patches:
37
+ transform_patch = transforms.Compose([transforms.PILToTensor(), transforms.Grayscale()])
38
+ grayscale_patch = transform_patch(Image.fromarray(patch)).squeeze(0)
39
+ glcm = graycomatrix(grayscale_patch, [5], [0], 256, symmetric=True, normed=True)
40
+ dissimilarity_scores.append(graycoprops(glcm, "contrast")[0, 0])
41
+
42
+ # Sort patch indices by their dissimilarity score
43
+ sorted_indices = np.argsort(dissimilarity_scores)[::-1]
44
+
45
+ # Extract top k patches and convert them to tensor
46
+ top_patches = patches[sorted_indices[:top_k_patches]]
47
+ top_patches = torch.from_numpy(np.transpose(top_patches, (0, 3, 1, 2))) / 255.0
48
+
49
+ # Predict patches
50
+ model.eval()
51
+ with torch.no_grad():
52
+ preds = model(top_patches)
53
+
54
+ # Print results
55
+ classes = ['authentic', 'dalle-3-images', 'diffusiondb', 'midjourney-images', 'midjourney_tti', 'realisticSDXL']
56
+ result = pd.DataFrame(preds.numpy(), columns=classes)
57
+ print(result)
test_patch.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import torch
3
+ from PIL import Image
4
+ from torchvision import transforms
5
+
6
+ # Load the model
7
+ model = torch.jit.load("SuSy.pt")
8
+
9
+ # Load patch
10
+ patch = Image.open("midjourney-images-example-patch0.png")
11
+
12
+ # Transform patch to tensor
13
+ patch = transforms.PILToTensor()(patch).unsqueeze(0) / 255.
14
+
15
+ # Predict patch
16
+ model.eval()
17
+ with torch.no_grad():
18
+ preds = model(patch)
19
+
20
+ # Print results
21
+ classes = ['authentic', 'dalle-3-images', 'diffusiondb', 'midjourney-images', 'midjourney_tti', 'realisticSDXL']
22
+ result = pd.DataFrame(preds.numpy(), columns=classes)
23
+ print(result)