Spaces:
Runtime error
Runtime error
from __future__ import annotations | |
import random | |
import gradio as gr | |
import matplotlib | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import torch | |
from CCAgT_utils.categories import CategoriesInfos | |
from CCAgT_utils.types.mask import Mask | |
from CCAgT_utils.visualization import plot | |
from PIL import Image | |
from torch import nn | |
from transformers import SegformerFeatureExtractor | |
from transformers import SegformerForSemanticSegmentation | |
from transformers.modeling_outputs import SemanticSegmenterOutput | |
# Non-interactive backend: figures are handed to Gradio, never shown in a window.
matplotlib.use('Agg')

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Fine-tuned checkpoint on the Hugging Face Hub; both the model weights and
# the feature extractor (preprocessing config) are loaded from it.
model_hub_name = 'lapix/segformer-b3-finetuned-ccagt-400-300'
model = SegformerForSemanticSegmentation.from_pretrained(model_hub_name)
model = model.to(device)
model.eval()  # evaluation mode: the model is only used for inference here
feature_extractor = SegformerFeatureExtractor.from_pretrained(model_hub_name)
def segment(
    image: Image.Image,
) -> SemanticSegmenterOutput:
    '''Run the SegFormer model on a single image.

    Args:
        image: Input image. It is preprocessed by the module-level
            ``feature_extractor`` into PyTorch tensors and moved to ``device``.

    Returns:
        The raw model output; its ``logits`` are turned into a class-id map by
        ``post_processing``.
    '''
    inputs = feature_extractor(
        image,
        return_tensors='pt',
    ).to(device)
    # Inference only: disable autograd so intermediate activations are not
    # kept alive for a backward pass that never happens.
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs
def post_processing(
    outputs: SemanticSegmenterOutput,
    target_size: tuple[int, int],
) -> np.ndarray:
    '''Convert model logits into a per-pixel class-id array.

    Args:
        outputs: Model output whose ``logits`` are a 4-D tensor laid out as
            (batch, classes, height, width), as required by the ``dim=1``
            argmax and bilinear interpolation below.
        target_size: ``(height, width)`` to resize the logits to before
            taking the argmax.

    Returns:
        The 2-D array of predicted class ids for the first item of the batch.
    '''
    resized_logits = nn.functional.interpolate(
        outputs.logits.cpu(),
        size=target_size,
        mode='bilinear',
        align_corners=False,
    )
    # Pick the highest-scoring class per pixel, then keep batch item 0 only.
    first_prediction = resized_logits.argmax(dim=1)[0]
    return np.array(first_prediction)
def colorize(
    mask: Mask,
) -> np.ndarray:
    '''Render ``mask`` with the CCAgT category colors, scaled by 1/255.

    The division assumes ``Mask.colorized`` returns 0-255 color values, so the
    result lands in [0, 1] -- NOTE(review): confirm against CCAgT_utils.
    '''
    colored = mask.colorized(CategoriesInfos())
    return colored / 255
# Copied from https://github.com/albumentations-team/albumentations/blob/b1af92ab8e57279f5acd5987770a86a8d6b6b0e5/albumentations/augmentations/crops/functional.py#L35
def get_random_crop_coords(
    height: int,
    width: int,
    crop_height: int,
    crop_width: int,
    h_start: float,
    w_start: float,
):
    '''Compute the corners of a crop window positioned by two fractions.

    ``h_start`` and ``w_start`` pick where the window's top-left corner falls
    within the range of valid offsets; with values in [0, 1) the window stays
    inside the image when it is no larger than the image.

    Returns:
        ``(x1, y1, x2, y2)``: left, top, right (exclusive), bottom (exclusive).
    '''
    top = int((height - crop_height + 1) * h_start)
    left = int((width - crop_width + 1) * w_start)
    return left, top, left + crop_width, top + crop_height
# Adapted from https://github.com/albumentations-team/albumentations/blob/b1af92ab8e57279f5acd5987770a86a8d6b6b0e5/albumentations/augmentations/crops/functional.py#L46
def random_crop(
    img: np.ndarray,
    crop_height: int,
    crop_width: int,
    h_start: float,
    w_start: float,
) -> np.ndarray:
    '''Cut a ``crop_height`` x ``crop_width`` window out of ``img``.

    Args:
        img: Array whose first two axes are (height, width).
        crop_height: Requested window height; clamped to the image height.
        crop_width: Requested window width; clamped to the image width.
        h_start: Fraction selecting the vertical offset (see
            ``get_random_crop_coords``).
        w_start: Fraction selecting the horizontal offset.

    Returns:
        A view of ``img`` restricted to the window.
    '''
    height, width = img.shape[:2]
    # A crop larger than the image along an axis would give a negative start
    # coordinate, which NumPy slicing treats as an offset from the END and so
    # silently yields a shifted, smaller crop. Keep that whole axis instead.
    crop_height = min(crop_height, height)
    crop_width = min(crop_width, width)
    x1, y1, x2, y2 = get_random_crop_coords(
        height, width, crop_height, crop_width, h_start, w_start,
    )
    return img[y1:y2, x1:x2]
def process_big_images(
    image: Image.Image,
) -> tuple[np.ndarray, Mask]:
    '''Segment ``image``, randomly cropping it first if it exceeds 400x300.

    Images taller than 300 rows or wider than 400 columns are cropped at a
    random position to at most 300 rows x 400 columns; smaller images are
    segmented as-is.

    Args:
        image: Input image.

    Returns:
        ``(img, mask)``: the (possibly cropped) image as a NumPy array and the
        predicted ``Mask`` at the same height and width.
    '''
    img = np.asarray(image)
    height, width = img.shape[:2]
    if height > 300 or width > 400:
        # Clamp per axis: an image oversized along only one axis must keep
        # its full extent along the other one.
        img = random_crop(
            img,
            min(300, height),
            min(400, width),
            random.random(),
            random.random(),
        )
    target_size = (img.shape[0], img.shape[1])
    outputs = segment(Image.fromarray(img))
    msk = post_processing(outputs, target_size)
    return img, Mask(msk)
def image_with_mask(
    image: Image.Image,
    mask: Mask,
) -> plt.Figure:
    '''Draw ``mask`` semi-transparently (alpha 0.4) over ``image``.

    Args:
        image: Background image; anything ``plt.imshow`` accepts.
        mask: Predicted mask; its categorical ids are colored with the CCAgT
            category colormap.

    Returns:
        A new figure (dpi 600, axes hidden). It is already detached from
        pyplot, so repeated calls do not accumulate open figures.
    '''
    fig = plt.figure(dpi=600)
    plt.imshow(image)
    plt.imshow(
        mask.categorical,
        cmap=mask.cmap(CategoriesInfos()),
        # Scale the colormap over the ids present in this mask.
        vmax=max(mask.unique_ids),
        vmin=min(mask.unique_ids),
        interpolation='nearest',
        alpha=0.4,
    )
    plt.axis('off')
    plt.tight_layout(pad=0)
    # pyplot keeps every figure it creates alive until closed; in a
    # long-running server that grows without bound. Closing only detaches the
    # figure from pyplot -- the returned object can still be rendered/saved.
    plt.close(fig)
    return fig
def categories_map(
    mask: Mask,
) -> plt.Figure:
    '''Build a legend-only figure for the categories present in ``mask``.

    Args:
        mask: Predicted mask; only the categories in ``mask.unique_ids`` get
            a legend entry.

    Returns:
        A new figure (dpi 600, axes hidden) holding just the centered legend.
        It is already detached from pyplot, so repeated calls do not
        accumulate open figures.
    '''
    fig = plt.figure(dpi=600)
    handles = plot.create_handles(
        CategoriesInfos(), selected_categories=mask.unique_ids,
    )
    plt.legend(handles=handles, fontsize=24, loc='center')
    plt.axis('off')
    # Detach from pyplot's figure registry to avoid an unbounded build-up of
    # open figures across requests; the figure itself stays usable.
    plt.close(fig)
    return fig
def main(image):
    '''Gradio callback: segment one image and build the four demo outputs.

    Args:
        image: Image array handed over by the Gradio input component
            (presumably a NumPy array, since it goes through
            ``Image.fromarray``).

    Returns:
        ``(legend_figure, cropped_image, colorized_mask, overlay_figure)``,
        in the order of the interface's output components.
    '''
    cropped, mask = process_big_images(Image.fromarray(image))
    colored_mask = colorize(mask)
    overlay = image_with_mask(cropped, mask)
    legend = categories_map(mask)
    return legend, Image.fromarray(cropped), colored_mask, overlay
title = 'SegFormer (b3) - CCAgT dataset'
description = f"""
This is a demo of SegFormer fine-tuned on a sub-dataset of the
[CCAgT dataset](https://huggingface.co/datasets/lapix/CCAgT). The model
was trained to segment images of silver-stained (AgNOR technique) cervical
cells with a resolution of 400x300. The model is available on the HF Hub at
[{model_hub_name}](https://huggingface.co/{model_hub_name}). If the input
image is bigger than 400x300, the demo randomly crops it.
"""
# Two samples from the model repo, then test images from the dataset. The ids
# are a tuple (not a set) so the examples keep a fixed, readable order; set
# iteration order is an implementation detail.
examples = [
    [f'https://hf.co/{model_hub_name}/resolve/main/sampleA.png'],
    [f'https://hf.co/{model_hub_name}/resolve/main/sampleB.png'],
] + [
    [f'https://datasets-server.huggingface.co/assets/lapix/CCAgT/--/semantic_segmentation/test/{x}/image/image.jpg']
    for x in (3, 10, 12, 18, 35, 78, 89)
]
# Gradio UI: one image in; the outputs are listed in the same order that
# ``main`` returns its values.
demo = gr.Interface(
    fn=main,
    inputs=[gr.Image()],
    outputs=[
        gr.Plot(label='Categories map'),
        gr.Image(label='Image'),
        gr.Image(label='Mask'),
        gr.Plot(label='Image with mask'),
    ],
    title=title,
    description=description,
    examples=examples,
    # Do not precompute example outputs at startup.
    cache_examples=False,
    allow_flagging='never',
)

if __name__ == '__main__':
    demo.launch()