# NOTE: hosting-page banner captured together with this file
# ("Spaces: Runtime error"); it is not part of the program.
import argparse
import os
import pathlib

import gradio as gr
import numpy as np
import torch
from PIL import Image
from torch.nn import functional as F

from show import *
from per_segment_anything import sam_model_registry, SamPredictor
# Command-line options, parsed at import time.
# NOTE(review): `args` is not read anywhere in the visible code — confirm
# whether --output-path is still needed.
parser = argparse.ArgumentParser()
parser.add_argument("-op", "--output-path", type=str, default='default')
args = parser.parse_args()
class ImageMask(gr.components.Image):
    """
    Image component preset.

    Sets: source="upload", tool="select", interactive=True; every other
    keyword argument is forwarded unchanged to ``gr.components.Image``.

    NOTE(review): the interface built in this file uses ``gr.ImageMask``,
    not this class — confirm whether this subclass is still needed.
    """

    is_template = True

    def __init__(self, **kwargs):
        super().__init__(source="upload", tool='select', interactive=True, **kwargs)

    def preprocess(self, x):
        # Pure pass-through to the parent implementation.
        return super().preprocess(x)
def point_selection(mask_sim, topk=1):
    """Pick the highest- and lowest-scoring positions of a 2-D score map.

    Args:
        mask_sim: 2-D torch tensor of scores.
        topk: number of positions to take from each end of the ranking.

    Returns:
        A tuple ``(topk_xy, topk_label, last_xy, last_label)``:
        ``topk_xy`` / ``last_xy`` are NumPy arrays of shape ``(topk, 2)``
        whose rows are ``(column index, row index)`` of the ``topk``
        largest / smallest entries; ``topk_label`` is an array of ones and
        ``last_label`` an array of zeros, one entry per selected point.
    """
    # Second dimension = number of entries per row of the score map.
    _, row_len = mask_sim.shape
    flat_scores = mask_sim.flatten(0)

    def _flat_index_to_xy(flat_idx):
        # Convert flat indices into (column, row) pairs, one pair per row.
        rows = (flat_idx // row_len).unsqueeze(0)
        cols = flat_idx - rows * row_len
        pairs = torch.cat((cols, rows), dim=0).permute(1, 0)
        return pairs.cpu().numpy()

    # Positive prompts: the most similar positions.
    best_idx = flat_scores.topk(topk)[1]
    topk_xy = _flat_index_to_xy(best_idx)
    topk_label = np.array([1] * topk)

    # Negative prompts: the least similar positions.
    worst_idx = flat_scores.topk(topk, largest=False)[1]
    last_xy = _flat_index_to_xy(worst_idx)
    last_label = np.array([0] * topk)

    return topk_xy, topk_label, last_xy, last_label
# Process-wide cache of built SAM models, keyed by (model type, checkpoint).
_SAM_MODEL_CACHE = {}


def _load_sam_model(sam_type, sam_ckpt):
    """Build the SAM model for ``(sam_type, sam_ckpt)`` once and reuse it."""
    key = (sam_type, sam_ckpt)
    if key not in _SAM_MODEL_CACHE:
        _SAM_MODEL_CACHE[key] = sam_model_registry[sam_type](checkpoint=sam_ckpt)
    return _SAM_MODEL_CACHE[key]


def inference_scribble(image):
    """Segment the region the user scribbled on and return two previews.

    The scribbled image serves both as the in-context reference (image +
    mask) and as the test image: a target embedding is averaged from the
    encoder features under the scribble, turned into a similarity map over
    the same image, and used to prompt the predictor in three cascaded
    ``predict`` calls.

    Args:
        image: mapping with keys ``"image"`` and ``"mask"``; both values
            must support ``.convert("RGB")`` (PIL images).

    Returns:
        A list of two RGB PIL images: the mask colour blended over the
        input (0.6 mask / 0.4 image) and the mask colour on black. Mask
        pixels are painted ``[128, 0, 0]``.

    Raises:
        gr.Error: if the scribble selects no feature position, so no
            target embedding can be computed.
    """
    # In-context image and the scribbled mask drawn on top of it.
    ic_image = np.array(image["image"].convert("RGB"))
    ic_mask = np.array(image["mask"].convert("RGB"))

    # MobileSAM weights. For the full SAM model use
    # 'vit_h' with 'sam_vit_h_4b8939.pth' instead.
    sam_type, sam_ckpt = 'vit_t', 'weights/mobile_sam.pt'
    # The model is built once per process; a fresh predictor is created per
    # call because the predictor stores per-image state.
    sam = _load_sam_model(sam_type, sam_ckpt)
    predictor = SamPredictor(sam)

    # Image features encoding
    ref_mask = predictor.set_image(ic_image, ic_mask)
    ref_feat = predictor.features.squeeze().permute(1, 2, 0)
    ref_mask = F.interpolate(ref_mask, size=ref_feat.shape[0: 2], mode="bilinear")
    ref_mask = ref_mask.squeeze()[0]

    # Target feature extraction
    print("======> Obtain Location Prior" )
    target_feat = ref_feat[ref_mask > 0]
    if target_feat.shape[0] == 0:
        # Averaging zero feature vectors would give NaN and poison every
        # later step, so stop with a message the user can act on.
        raise gr.Error(
            "No scribble detected. Please draw on the region you want to segment."
        )
    target_embedding = target_feat.mean(0).unsqueeze(0)
    target_feat = target_embedding / target_embedding.norm(dim=-1, keepdim=True)
    target_embedding = target_embedding.unsqueeze(0)

    test_image = ic_image
    print("======> Testing Image")

    # Image feature encoding
    predictor.set_image(test_image)
    test_feat = predictor.features.squeeze()

    # Cosine similarity
    C, h, w = test_feat.shape
    test_feat = test_feat / test_feat.norm(dim=0, keepdim=True)
    test_feat = test_feat.reshape(C, h * w)
    sim = target_feat @ test_feat

    sim = sim.reshape(1, 1, h, w)
    sim = F.interpolate(sim, scale_factor=4, mode="bilinear")
    sim = predictor.model.postprocess_masks(
        sim,
        input_size=predictor.input_size,
        original_size=predictor.original_size).squeeze()

    # Positive-negative location prior
    topk_xy_i, topk_label_i, last_xy_i, last_label_i = point_selection(sim, topk=1)
    topk_xy = np.concatenate([topk_xy_i, last_xy_i], axis=0)
    topk_label = np.concatenate([topk_label_i, last_label_i], axis=0)

    # Obtain the target guidance for cross-attention layers
    sim = (sim - sim.mean()) / torch.std(sim)
    sim = F.interpolate(sim.unsqueeze(0).unsqueeze(0), size=(64, 64), mode="bilinear")
    attn_sim = sim.sigmoid_().unsqueeze(0).flatten(3)

    # First-step prediction
    masks, scores, logits, _ = predictor.predict(
        point_coords=topk_xy,
        point_labels=topk_label,
        multimask_output=True,
        attn_sim=attn_sim,  # Target-guided Attention
        target_embedding=target_embedding  # Target-semantic Prompting
    )
    # NOTE(review): the first candidate is used here rather than the
    # highest-scoring one — confirm this is intended with multimask output.
    best_idx = 0

    # Cascaded Post-refinement-1
    masks, scores, logits, _ = predictor.predict(
        point_coords=topk_xy,
        point_labels=topk_label,
        mask_input=logits[best_idx: best_idx + 1, :, :],
        multimask_output=True)
    best_idx = np.argmax(scores)

    # Cascaded Post-refinement-2
    y, x = np.nonzero(masks[best_idx])
    if x.size:
        # Tight bounding box of the current mask, used as a box prompt.
        input_box = np.array([x.min(), y.min(), x.max(), y.max()])
        masks, scores, logits, _ = predictor.predict(
            point_coords=topk_xy,
            point_labels=topk_label,
            box=input_box[None, :],
            mask_input=logits[best_idx: best_idx + 1, :, :],
            multimask_output=True)
        best_idx = np.argmax(scores)
    # else: the mask is empty, so there is no box to refine with; keep the
    # current (empty) prediction instead of failing on min()/max().

    final_mask = masks[best_idx]
    mask_colors = np.zeros((final_mask.shape[0], final_mask.shape[1], 3), dtype=np.uint8)
    mask_colors[final_mask, :] = np.array([[128, 0, 0]])

    # Blended overlay first, mask-only image second.
    return [Image.fromarray((mask_colors * 0.6 + test_image * 0.4).astype('uint8'), 'RGB'),
            Image.fromarray((mask_colors).astype('uint8'), 'RGB')]
# Gradio UI: one sketchable image input, two image outputs (overlay and
# mask). Outputs use gr.Image directly; the legacy gr.outputs namespace is
# deprecated and missing from newer Gradio releases.
main_scribble = gr.Interface(
    fn=inference_scribble,
    inputs=gr.ImageMask(label="[Stroke] Draw on Image", type='pil'),
    outputs=[
        gr.Image(type="pil", label="Mask with Image"),
        gr.Image(type="pil", label="Mask"),
    ],
    allow_flagging="never",
    title="SAM based Segment Annotator.",
    description='Sketch the portion where you want to create Mask.',
    examples=[
        "./cardamage_example/0006.JPEG",
        "./cardamage_example/0008.JPEG",
        "./cardamage_example/0206.JPEG",
    ],
)

main_scribble.launch(share=True)