Spaces:
Runtime error
Runtime error
File size: 2,398 Bytes
26be1cc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import torch
import torch.nn as nn
import cv2
import gradio as gr
import glob
from typing import List
import torch.nn.functional as F
import torchvision.transforms as T
from sklearn.decomposition import PCA
import sklearn
import numpy as np
# Constants: images are resized to (patch_h * 14, patch_w * 14) pixels so the
# ViT-L/14 backbone produces a patch_h x patch_w grid of patch tokens.
patch_h = 40
patch_w = 40
# Use GPU if available.
# NOTE(review): the model below is not moved to `device`; placement is left to
# the inference code.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# DINOv2 ViT-L/14 backbone from torch.hub. It is only used for inference, so
# switch it out of training mode.
model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14')
model.eval()
# Transforms: resize to a multiple of the 14-pixel patch size, then normalise
# with ImageNet mean/std (these statistics assume inputs already scaled to
# [0, 1]).
transform = T.Compose([
    T.Resize((patch_h * 14, patch_w * 14)),
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
# Preallocated input buffer for a batch of four 3-channel images.
imgs_tensor = torch.zeros(4, 3, patch_h * 14, patch_w * 14)
# 3-component PCA; it is (re)fitted inside query_image on every request.
pca = PCA(n_components=3)
def _to_model_input(img: np.ndarray) -> torch.Tensor:
    """Convert one H x W x C image array into a normalised C x H x W tensor."""
    arr = np.asarray(img)
    if arr.ndim == 2:
        # Grayscale: replicate the single channel to three channels.
        arr = np.stack([arr] * 3, axis=-1)
    elif arr.shape[-1] == 4:
        # RGBA: drop the alpha channel.
        arr = arr[..., :3]
    tensor = torch.tensor(arr, dtype=torch.float32).permute(2, 0, 1)
    if np.issubdtype(arr.dtype, np.integer):
        # `transform` normalises with ImageNet mean/std, which assume values
        # in [0, 1]; integer images (e.g. uint8 from gr.Image) are in [0, 255].
        tensor = tensor / 255.0
    return transform(tensor)


def _minmax_columns(values: np.ndarray) -> np.ndarray:
    """Scale each column of ``values`` independently to the range [0, 1]."""
    col_min = values.min(axis=0)
    col_range = values.max(axis=0) - col_min
    # A constant column would divide by zero; leave it at 0 instead.
    col_range[col_range == 0] = 1
    return (values - col_min) / col_range


def query_image(img1, img2, img3, img4) -> List[np.ndarray]:
    """Visualise DINOv2 patch features of four images as PCA-coloured maps.

    Each image is resized to ``(patch_h * 14, patch_w * 14)`` and fed through
    ``model.forward_features``. The patch tokens of all four images are
    projected with a 3-component PCA fitted jointly. Patches whose first
    principal component is negative are treated as background (black). A
    second PCA, fitted on the remaining patches and min-max scaled, supplies
    the RGB colours.

    Args:
        img1, img2, img3, img4: ``H x W x C`` image arrays as delivered by
            ``gr.Image()`` (presumably uint8 RGB -- TODO confirm if the
            component's ``image_mode``/``type`` is changed).

    Returns:
        Four ``(patch_h, patch_w, 3)`` float arrays with values in ``[0, 1]``.

    Raises:
        ValueError: if any of the four images is missing.
    """
    imgs = [img1, img2, img3, img4]
    if any(img is None for img in imgs):
        raise ValueError("Please provide all four images.")
    n_images = len(imgs)
    n_patches = n_images * patch_h * patch_w

    # Build the batch on whatever device the model lives on, so the forward
    # pass works for both CPU- and GPU-resident models.
    model_device = next(model.parameters()).device
    batch = torch.stack([_to_model_input(img) for img in imgs]).to(model_device)

    with torch.no_grad():
        features_dict = model.forward_features(batch)
    # Drop token 0 (the class token) and keep only patch tokens; sklearn needs
    # a CPU numpy array.
    features = features_dict['x_prenorm'][:, 1:]
    features = features.reshape(n_patches, -1).cpu().numpy()

    # First PCA over all patches: the sign of the first component is used to
    # separate foreground from background.
    pca.fit(features)
    pca_features = pca.transform(features)
    pca_features_bg = pca_features[:, 0] < 0
    pca_features_fg = ~pca_features_bg

    # Second PCA over foreground patches only, scaled per component to [0, 1]
    # so the three components can be displayed as RGB.
    pca.fit(features[pca_features_fg])
    pca_features_rem = _minmax_columns(pca.transform(features[pca_features_fg]))

    # Background patches keep the zero initialisation (black).
    pca_features_rgb = np.zeros((n_patches, 3))
    pca_features_rgb[pca_features_fg] = pca_features_rem
    pca_features_rgb = pca_features_rgb.reshape(n_images, patch_h, patch_w, 3)
    return [pca_features_rgb[i] for i in range(n_images)]
description = """
DINOV2 PCA
"""
# Four image inputs map to the four PCA visualisations returned by
# query_image (one output image per input image).
demo = gr.Interface(
    query_image,
    inputs=[gr.Image(), gr.Image(), gr.Image(), gr.Image()],
    outputs=[gr.Image(), gr.Image(), gr.Image(), gr.Image()],
    title="DINOV2 PCA",
    description=description,
    examples=[],
)

# Launch only when run as a script, so importing this module (e.g. by tooling
# that looks up `demo`) does not start a server as a side effect.
if __name__ == "__main__":
    demo.launch()
|