sadjava committed on
Commit 479c88d
1 Parent(s): c0e0595

Add app.py

Files changed (7)
  1. app.py +102 -0
  2. constants.py +33 -0
  3. examples/1.jpg +0 -0
  4. examples/2.jpg +0 -0
  5. examples/3.jpg +0 -0
  6. models.py +135 -0
  7. utils.py +250 -0
app.py ADDED
@@ -0,0 +1,102 @@
+ import gradio as gr
+ import numpy as np
+ from ultralytics import YOLO
+ from torchvision.transforms.functional import to_tensor
+ from huggingface_hub import hf_hub_download
+ import torch
+ import albumentations as A
+ from albumentations.pytorch.transforms import ToTensorV2
+ import pandas as pd
+ import matplotlib.pyplot as plt
+ from sklearn.metrics.pairwise import cosine_similarity
+
+ from utils import *
+ from models import YOLOStamp, Encoder
+
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+
+ # Detection model 1: YOLOv8 fine-tuned for stamps, loaded as TorchScript.
+ yolov8 = YOLO(hf_hub_download('stamps-labs/yolov8-finetuned', filename='best.torchscript'), task='detect')
+
+ # Detection model 2: custom YOLO-Stamp architecture (see models.py).
+ yolo_stamp = YOLOStamp()
+ yolo_stamp.load_state_dict(torch.load(hf_hub_download('stamps-labs/yolo-stamp', filename='state_dict.pth')))
+ yolo_stamp = yolo_stamp.to(device)
+ yolo_stamp.eval()
+ transform = A.Compose([
+     A.Normalize(),
+     ToTensorV2(p=1.0),
+ ])
+
+ # Embedding model 1: ViT-S/8 fine-tuned on stamps, loaded as TorchScript.
+ vits8 = torch.jit.load(hf_hub_download('stamps-labs/vits8-stamp', filename='vits8stamp-torchscript.pth'))
+ vits8 = vits8.to(device)
+ vits8.eval()
+
+ # Embedding model 2: VAE encoder (see models.py).
+ encoder = Encoder()
+ encoder.load_state_dict(torch.load(hf_hub_download('stamps-labs/vae-encoder', filename='encoder.pth')))
+ encoder = encoder.to(device)
+ encoder.eval()
+
+
+ def predict(image, det_choice, emb_choice):
+     # Remember the original image size so boxes can be rescaled back later.
+     shape = torch.tensor(image.size)
+     image = image.convert('RGB')
+
+     if det_choice == 'yolov8':
+         coef = torch.hstack((shape, shape)) / 640
+         image = image.resize((640, 640))
+         boxes = yolov8(image)[0].boxes.xyxy.cpu()
+         image_with_boxes = visualize_bbox(image, boxes)
+
+     elif det_choice == 'yolo-stamp':
+         coef = torch.hstack((shape, shape)) / 448
+         image = image.resize((448, 448))
+         image_tensor = transform(image=np.array(image))['image']
+         output = yolo_stamp(image_tensor.unsqueeze(0).to(device))
+
+         boxes = output_tensor_to_boxes(output[0].detach().cpu())
+         boxes = nonmax_suppression(boxes)
+         boxes = xywh2xyxy(torch.tensor(boxes)[:, :4])
+         image_with_boxes = visualize_bbox(image, boxes)
+     else:
+         return
+
+     # Crop each detected stamp and embed it with the chosen model.
+     embeddings = []
+     if emb_choice == 'vits8':
+         for box in boxes:
+             cropped_stamp = to_tensor(image.crop(box.tolist()))
+             embeddings.append(vits8(cropped_stamp.unsqueeze(0).to(device))[0].detach().cpu())
+
+     elif emb_choice == 'vae-encoder':
+         for box in boxes:
+             cropped_stamp = to_tensor(image.crop(box.tolist()).resize((118, 118)))
+             embeddings.append(encoder(cropped_stamp.unsqueeze(0).to(device))[0][0].detach().cpu())
+
+     embeddings = np.stack(embeddings)
+
+     # Pairwise cosine similarity between all stamp embeddings.
+     similarities = cosine_similarity(embeddings)
+
+     # Rescale boxes from the model input size back to the original image size.
+     boxes = boxes * coef
+     df_boxes = pd.DataFrame(boxes, columns=['x1', 'y1', 'x2', 'y2'])
+
+     fig, ax = plt.subplots()
+     im, cbar = heatmap(similarities, range(1, len(embeddings) + 1), range(1, len(embeddings) + 1), ax=ax,
+                        cmap="YlGn", cbarlabel="Embeddings similarities")
+     texts = annotate_heatmap(im, valfmt="{x:.3f}")
+     return image_with_boxes, df_boxes, embeddings, fig
+
+
+ examples = [['./examples/1.jpg', 'yolov8', 'vits8'], ['./examples/2.jpg', 'yolov8', 'vae-encoder'], ['./examples/3.jpg', 'yolo-stamp', 'vits8']]
+ inputs = [
+     gr.Image(type="pil"),
+     gr.Dropdown(choices=['yolov8', 'yolo-stamp'], value='yolov8', label='Detection model'),
+     gr.Dropdown(choices=['vits8', 'vae-encoder'], value='vits8', label='Embedding model'),
+ ]
+ outputs = [
+     gr.Image(type="pil"),
+     gr.DataFrame(type='pandas', label="Bounding boxes"),
+     gr.DataFrame(type='numpy', label="Embeddings"),
+     gr.Plot(label="Cosine Similarities"),
+ ]
+ app = gr.Interface(predict, inputs, outputs, examples=examples)
+ app.launch()
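
For reference, the detect → embed → compare pipeline can be smoke-tested without the UI. A minimal sketch, assuming it runs in app.py's namespace (e.g. pasted above the final app.launch() line); the output filename is hypothetical:

from PIL import Image

# Hypothetical smoke test using one of the bundled examples.
img = Image.open('./examples/1.jpg')
image_with_boxes, df_boxes, embeddings, fig = predict(img, 'yolov8', 'vits8')

print(df_boxes)          # one row per detected stamp: x1, y1, x2, y2 in original-image pixels
print(embeddings.shape)  # (num_stamps, dim); dim is 128 for 'vae-encoder' (Z_DIM), 384 for a ViT-S backbone
fig.savefig('similarities.png')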
constants.py ADDED
@@ -0,0 +1,33 @@
+ # shape of input image to YOLO
+ W, H = 448, 448
+ # grid size after the last convolutional layer of YOLO
+ S = 7
+ # anchors of the YOLO model
+ ANCHORS = [[1.5340836003942058, 1.258424277571925],
+            [1.4957766780406023, 2.2319885681948217],
+            [1.2508985343739407, 0.8233350471152914]]
+ # number of anchor boxes
+ BOX = len(ANCHORS)
+ # maximum number of stamps on an image
+ STAMP_NB_MAX = 10
+ # minimal confidence that a stamp is present in a grid cell
+ OUTPUT_THRESH = 0.7
+ # maximal IoU score for boxes to be considered different
+ IOU_THRESH = 0.3
+ # path to the folder containing images
+ IMAGE_FOLDER = './data/images'
+ # path to the .csv file containing annotations
+ ANNOTATIONS_PATH = './data/all_annotations.csv'
+ # standard deviation and mean of pixel values for normalization
+ STD = (0.229, 0.224, 0.225)
+ MEAN = (0.485, 0.456, 0.406)
+ # box color used to draw bounding boxes on an image
+ BOX_COLOR = (0, 0, 255)
+
+
+ # dimension of the image embedding
+ Z_DIM = 128
+ # hidden dimension for the encoder model
+ ENC_HIDDEN_DIM = 16
+ # hidden dimension for the decoder model
+ DEC_HIDDEN_DIM = 64
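
These constants drive how raw YOLO-Stamp outputs are decoded in output_tensor_to_boxes (utils.py). A small worked sketch of the arithmetic for a single anchor; the raw prediction values below are made up:

import torch
from constants import W, H, S, ANCHORS, OUTPUT_THRESH

cell_w, cell_h = W / S, H / S                  # 448 / 7 -> 64 x 64 px grid cells
t = torch.tensor([0.0, 0.0, 0.2, -0.1, 2.0])   # made-up raw (tx, ty, tw, th, obj) for one anchor
anchor_w, anchor_h = ANCHORS[0]

xy = torch.sigmoid(t[:2])                      # offset inside the grid cell, in (0, 1)
w = torch.exp(t[2]) * anchor_w                 # width in cell units: exp-scaled anchor width
h = torch.exp(t[3]) * anchor_h
obj = torch.sigmoid(t[4])                      # ~0.88, kept because it exceeds OUTPUT_THRESH = 0.7
print(w * cell_w, h * cell_h, obj > OUTPUT_THRESH)  # box size in pixels, True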
examples/1.jpg ADDED
examples/2.jpg ADDED
examples/3.jpg ADDED
models.py ADDED
@@ -0,0 +1,135 @@
+ import torch
+ import torch.nn as nn
+
+ from constants import *
+
+
+ class SymReLU(nn.Module):
+     """Custom activation that clamps its input to the range [-1, 1]."""
+     def __init__(self, inplace: bool = False):
+         super().__init__()
+         self.inplace = inplace
+
+     def forward(self, input):
+         return torch.min(torch.max(input, -torch.ones_like(input)), torch.ones_like(input))
+
+     def extra_repr(self) -> str:
+         inplace_str = 'inplace=True' if self.inplace else ''
+         return inplace_str
+
+
+ class YOLOStamp(nn.Module):
+     """
+     YOLO-Stamp architecture described in
+     https://link.springer.com/article/10.1134/S1054661822040046.
+     """
+     def __init__(
+             self,
+             anchors=ANCHORS,
+             in_channels=3,
+     ):
+         super().__init__()
+
+         self.register_buffer('anchors', torch.tensor(anchors))
+
+         self.act = SymReLU()
+         self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
+         self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+         self.norm1 = nn.BatchNorm2d(num_features=8)
+         self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+         self.norm2 = nn.BatchNorm2d(num_features=16)
+         self.conv3 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+         self.norm3 = nn.BatchNorm2d(num_features=16)
+         self.conv4 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+         self.norm4 = nn.BatchNorm2d(num_features=16)
+         self.conv5 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+         self.norm5 = nn.BatchNorm2d(num_features=16)
+         self.conv6 = nn.Conv2d(in_channels=16, out_channels=24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+         self.norm6 = nn.BatchNorm2d(num_features=24)
+         self.conv7 = nn.Conv2d(in_channels=24, out_channels=24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+         self.norm7 = nn.BatchNorm2d(num_features=24)
+         self.conv8 = nn.Conv2d(in_channels=24, out_channels=48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+         self.norm8 = nn.BatchNorm2d(num_features=48)
+         self.conv9 = nn.Conv2d(in_channels=48, out_channels=48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+         self.norm9 = nn.BatchNorm2d(num_features=48)
+         self.conv10 = nn.Conv2d(in_channels=48, out_channels=48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+         self.norm10 = nn.BatchNorm2d(num_features=48)
+         self.conv11 = nn.Conv2d(in_channels=48, out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+         self.norm11 = nn.BatchNorm2d(num_features=64)
+         self.conv12 = nn.Conv2d(in_channels=64, out_channels=256, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
+         self.norm12 = nn.BatchNorm2d(num_features=256)
+         self.conv13 = nn.Conv2d(in_channels=256, out_channels=len(anchors) * 5, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
+
+     def forward(self, x, head=True):
+         x = x.type(self.conv1.weight.dtype)
+         x = self.act(self.pool(self.norm1(self.conv1(x))))
+         x = self.act(self.pool(self.norm2(self.conv2(x))))
+         x = self.act(self.pool(self.norm3(self.conv3(x))))
+         x = self.act(self.pool(self.norm4(self.conv4(x))))
+         x = self.act(self.pool(self.norm5(self.conv5(x))))
+         x = self.act(self.norm6(self.conv6(x)))
+         x = self.act(self.norm7(self.conv7(x)))
+         x = self.act(self.pool(self.norm8(self.conv8(x))))
+         x = self.act(self.norm9(self.conv9(x)))
+         x = self.act(self.norm10(self.conv10(x)))
+         x = self.act(self.norm11(self.conv11(x)))
+         x = self.act(self.norm12(self.conv12(x)))
+         x = self.conv13(x)
+         # Reshape (nb, BOX * 5, nh, nw) -> (nb, nh, nw, BOX, 5).
+         nb, _, nh, nw = x.shape
+         x = x.permute(0, 2, 3, 1).view(nb, nh, nw, self.anchors.shape[0], 5)
+         return x
+
+
+ class Encoder(torch.nn.Module):
+     '''
+     Encoder Class
+     Values:
+         im_chan: the number of channels of the input image, a scalar
+         output_chan: the dimension of the latent embedding (Z_DIM), a scalar
+         hidden_dim: the inner dimension, a scalar
+     '''
+
+     def __init__(self, im_chan=3, output_chan=Z_DIM, hidden_dim=ENC_HIDDEN_DIM):
+         super(Encoder, self).__init__()
+         self.z_dim = output_chan
+         self.disc = torch.nn.Sequential(
+             self.make_disc_block(im_chan, hidden_dim),
+             self.make_disc_block(hidden_dim, hidden_dim * 2),
+             self.make_disc_block(hidden_dim * 2, hidden_dim * 4),
+             self.make_disc_block(hidden_dim * 4, hidden_dim * 8),
+             self.make_disc_block(hidden_dim * 8, output_chan * 2, final_layer=True),
+         )
+
+     def make_disc_block(self, input_channels, output_channels, kernel_size=4, stride=2, final_layer=False):
+         '''
+         Function to return a sequence of operations corresponding to an encoder block of the VAE:
+         a convolution, a batchnorm (except in the final layer), and an activation.
+         Parameters:
+             input_channels: how many channels the input feature representation has
+             output_channels: how many channels the output feature representation should have
+             kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
+             stride: the stride of the convolution
+             final_layer: whether we're on the final layer (affects activation and batchnorm)
+         '''
+         if not final_layer:
+             return torch.nn.Sequential(
+                 torch.nn.Conv2d(input_channels, output_channels, kernel_size, stride),
+                 torch.nn.BatchNorm2d(output_channels),
+                 torch.nn.LeakyReLU(0.2, inplace=True),
+             )
+         else:
+             return torch.nn.Sequential(
+                 torch.nn.Conv2d(input_channels, output_channels, kernel_size, stride),
+             )
+
+     def forward(self, image):
+         '''
+         Function for completing a forward pass of the Encoder: given an image tensor,
+         returns the mean and standard deviation of the latent-space distribution.
+         Parameters:
+             image: an image tensor of shape (batch, im_chan, height, width)
+         '''
+         disc_pred = self.disc(image)
+         encoding = disc_pred.view(len(disc_pred), -1)
+         # The stddev output is treated as the log of the variance of the normal
+         # distribution by convention and for numerical stability
+         return encoding[:, :self.z_dim], encoding[:, self.z_dim:].exp()
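
A quick shape sanity check for both models, as a hedged sketch (random inputs, no pretrained weights; the 118×118 crop size is the one app.py feeds the encoder):

import torch
from models import YOLOStamp, Encoder

# Hypothetical shape check: the six pooling stages take 448 down to a 7x7 grid.
yolo = YOLOStamp().eval()
x = torch.randn(1, 3, 448, 448)        # W, H from constants.py
with torch.no_grad():
    out = yolo(x)
print(out.shape)                       # torch.Size([1, 7, 7, 3, 5]): (S, S, BOX, 5) per image

enc = Encoder().eval()
crop = torch.randn(1, 3, 118, 118)     # app.py resizes stamp crops to 118 x 118
with torch.no_grad():
    mean, std = enc(crop)
print(mean.shape, std.shape)           # torch.Size([1, 128]) each, since Z_DIM = 128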
utils.py ADDED
@@ -0,0 +1,250 @@
+ from PIL import Image, ImageDraw
+ import torch
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import matplotlib
+ from constants import *
+
+
+ def visualize_bbox(image: Image, prediction):
+     """Draws numbered bounding boxes on a copy of the image."""
+     img = image.copy()
+     draw = ImageDraw.Draw(img)
+     for i, box in enumerate(prediction):
+         x1, y1, x2, y2 = box.cpu()
+         # textbbox replaces the textsize API that was removed in Pillow 10.
+         left, top, right, bottom = draw.textbbox((0, 0), str(i + 1))
+         text_w, text_h = right - left, bottom - top
+         label_y = y1 if y1 <= text_h else y1 - text_h
+         draw.rectangle((x1, y1, x2, y2), outline='red')
+         draw.rectangle((x1, label_y, x1 + text_w, label_y + text_h), outline='red', fill='red')
+         draw.text((x1, label_y), str(i + 1), fill='white')
+     return img
+
+
+ def xywh2xyxy(x):
+     """Converts boxes from (x, y, w, h) with a top-left corner to (x1, y1, x2, y2)."""
+     y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+     y[..., 0] = x[..., 0]
+     y[..., 1] = x[..., 1]
+     y[..., 2] = x[..., 0] + x[..., 2]
+     y[..., 3] = x[..., 1] + x[..., 3]
+     return y
+
+
+ def output_tensor_to_boxes(boxes_tensor):
+     """
+     Converts the YOLO output tensor to a list of boxes with probabilities.
+
+     Arguments:
+     boxes_tensor -- tensor of shape (S, S, BOX, 5)
+
+     Returns:
+     boxes -- list of shape (None, 5)
+
+     Note: "None" is here because the exact number of selected boxes depends on the threshold.
+     For example, the actual output size would be (10, 5) if there are 10 boxes.
+     """
+     cell_w, cell_h = W/S, H/S
+     boxes = []
+
+     for i in range(S):
+         for j in range(S):
+             for b in range(BOX):
+                 anchor_wh = torch.tensor(ANCHORS[b])
+                 data = boxes_tensor[i, j, b]
+                 xy = torch.sigmoid(data[:2])
+                 wh = torch.exp(data[2:4]) * anchor_wh
+                 obj_prob = torch.sigmoid(data[4])
+
+                 if obj_prob > OUTPUT_THRESH:
+                     x_center, y_center, w, h = xy[0], xy[1], wh[0], wh[1]
+                     # Convert cell-relative center coordinates to a top-left corner in pixels.
+                     x, y = x_center + j - w/2, y_center + i - h/2
+                     x, y, w, h = x*cell_w, y*cell_h, w*cell_w, h*cell_h
+                     box = [x, y, w, h, obj_prob]
+                     boxes.append(box)
+     return boxes
+
+
+ def overlap(interval_1, interval_2):
+     """
+     Calculates the length of overlap between two intervals.
+
+     Arguments:
+     interval_1 -- list or tuple of shape (2,) containing the endpoints of the first interval
+     interval_2 -- list or tuple of shape (2,) containing the endpoints of the second interval
+
+     Returns:
+     overlap -- length of overlap
+     """
+     x1, x2 = interval_1
+     x3, x4 = interval_2
+     if x3 < x1:
+         if x4 < x1:
+             return 0
+         else:
+             return min(x2, x4) - x1
+     else:
+         if x2 < x3:
+             return 0
+         else:
+             return min(x2, x4) - x3
+
+
+ def compute_iou(box1, box2):
+     """
+     Computes IoU between box1 and box2.
+
+     Arguments:
+     box1 -- list of shape (5,). Represents the first box
+     box2 -- list of shape (5,). Represents the second box
+     Each box is [x, y, w, h, prob]
+
+     Returns:
+     iou -- intersection over union score between the two boxes
+     """
+     x1, y1, w1, h1 = box1[0], box1[1], box1[2], box1[3]
+     x2, y2, w2, h2 = box2[0], box2[1], box2[2], box2[3]
+
+     area1, area2 = w1*h1, w2*h2
+     intersect_w = overlap((x1, x1 + w1), (x2, x2 + w2))
+     intersect_h = overlap((y1, y1 + h1), (y2, y2 + h2))  # fixed: was (y2, y2 + w2)
+     if intersect_w == w1 and intersect_h == h1 or intersect_w == w2 and intersect_h == h2:
+         return 1.
+     intersect_area = intersect_w*intersect_h
+     iou = intersect_area/(area1 + area2 - intersect_area)
+     return iou
+
+
+ def nonmax_suppression(boxes, iou_thresh=IOU_THRESH):
+     """
+     Removes overlapping boxes.
+
+     Arguments:
+     boxes -- list of shape (None, 5)
+     iou_thresh -- maximal value of IoU for boxes to be considered different
+     Each box is [x, y, w, h, prob]
+
+     Returns:
+     boxes -- list of shape (None, 5) with overlapping boxes removed
+     """
+     boxes = sorted(boxes, key=lambda x: x[4], reverse=True)
+     for i, current_box in enumerate(boxes):
+         if current_box[4] <= 0:
+             continue
+         # Suppress every lower-confidence box that overlaps the current one too much.
+         for j in range(i + 1, len(boxes)):
+             iou = compute_iou(current_box, boxes[j])
+             if iou > iou_thresh:
+                 boxes[j][4] = 0
+     boxes = [box for box in boxes if box[4] > 0]
+     return boxes
+
+
+ def heatmap(data, row_labels, col_labels, ax=None,
+             cbar_kw=None, cbarlabel="", **kwargs):
+     """
+     Create a heatmap from a numpy array and two lists of labels.
+
+     Parameters
+     ----------
+     data
+         A 2D numpy array of shape (M, N).
+     row_labels
+         A list or array of length M with the labels for the rows.
+     col_labels
+         A list or array of length N with the labels for the columns.
+     ax
+         A `matplotlib.axes.Axes` instance to which the heatmap is plotted. If
+         not provided, use current axes or create a new one. Optional.
+     cbar_kw
+         A dictionary with arguments to `matplotlib.Figure.colorbar`. Optional.
+     cbarlabel
+         The label for the colorbar. Optional.
+     **kwargs
+         All other arguments are forwarded to `imshow`.
+     """
+
+     if ax is None:
+         ax = plt.gca()
+
+     if cbar_kw is None:
+         cbar_kw = {}
+
+     # Plot the heatmap
+     im = ax.imshow(data, **kwargs)
+
+     # Create colorbar
+     cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw)
+     cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")
+
+     # Show all ticks and label them with the respective list entries.
+     ax.set_xticks(np.arange(data.shape[1]), labels=col_labels)
+     ax.set_yticks(np.arange(data.shape[0]), labels=row_labels)
+
+     # Let the horizontal axes labeling appear on top.
+     ax.tick_params(top=True, bottom=False,
+                    labeltop=True, labelbottom=False)
+
+     # Rotate the tick labels and set their alignment.
+     plt.setp(ax.get_xticklabels(), rotation=-30, ha="right",
+              rotation_mode="anchor")
+
+     # Turn spines off and create white grid.
+     ax.spines[:].set_visible(False)
+
+     ax.set_xticks(np.arange(data.shape[1] + 1) - .5, minor=True)
+     ax.set_yticks(np.arange(data.shape[0] + 1) - .5, minor=True)
+     ax.grid(which="minor", color="w", linestyle='-', linewidth=3)
+     ax.tick_params(which="minor", bottom=False, left=False)
+
+     return im, cbar
+
+
+ def annotate_heatmap(im, data=None, valfmt="{x:.2f}",
+                      textcolors=("black", "white"),
+                      threshold=None, **textkw):
+     """
+     A function to annotate a heatmap.
+
+     Parameters
+     ----------
+     im
+         The AxesImage to be labeled.
+     data
+         Data used to annotate. If None, the image's data is used. Optional.
+     valfmt
+         The format of the annotations inside the heatmap. This should either
+         use the string format method, e.g. "$ {x:.2f}", or be a
+         `matplotlib.ticker.Formatter`. Optional.
+     textcolors
+         A pair of colors. The first is used for values below a threshold,
+         the second for those above. Optional.
+     threshold
+         Value in data units according to which the colors from textcolors are
+         applied. If None (the default) uses the middle of the colormap as
+         separation. Optional.
+     **textkw
+         All other arguments are forwarded to each call to `text` used to create
+         the text labels.
+     """
+
+     if not isinstance(data, (list, np.ndarray)):
+         data = im.get_array()
+
+     # Normalize the threshold to the image's color range.
+     if threshold is not None:
+         threshold = im.norm(threshold)
+     else:
+         threshold = im.norm(data.max())/2.
+
+     # Set default alignment to center, but allow it to be
+     # overwritten by textkw.
+     kw = dict(horizontalalignment="center",
+               verticalalignment="center")
+     kw.update(textkw)
+
+     # Get the formatter in case a string is supplied
+     if isinstance(valfmt, str):
+         valfmt = matplotlib.ticker.StrMethodFormatter(valfmt)
+
+     # Loop over the data and create a `Text` for each "pixel".
+     # Change the text's color depending on the data.
+     texts = []
+     for i in range(data.shape[0]):
+         for j in range(data.shape[1]):
+             kw.update(color=textcolors[int(im.norm(data[i, j]) > threshold)])
+             text = im.axes.text(j, i, valfmt(data[i, j], None), **kw)
+             texts.append(text)
+
+     return texts
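
A toy sketch of the box post-processing helpers (the boxes below are made up for illustration): two near-duplicate detections and one distant one, so nonmax_suppression should keep two boxes.

import torch
from utils import compute_iou, nonmax_suppression, xywh2xyxy

boxes = [
    [10, 10, 50, 40, 0.9],    # [x, y, w, h, prob]
    [12, 12, 50, 40, 0.8],    # overlaps the first box almost entirely
    [200, 100, 30, 30, 0.7],  # far from the others
]
print(compute_iou(boxes[0], boxes[1]))       # ~0.84, well above IOU_THRESH = 0.3
kept = nonmax_suppression([list(b) for b in boxes])  # copies, since NMS zeroes probs in place
print(len(kept))                             # 2
print(xywh2xyxy(torch.tensor(kept)[:, :4]))  # corner format, ready for PIL's Image.crop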