Spaces:
Runtime error
Runtime error
import cv2 | |
import gradio as gr | |
import matplotlib | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import torch | |
from CCAgT_utils.categories import CategoriesInfos | |
from CCAgT_utils.slice import __create_xy_slice | |
from CCAgT_utils.types.mask import Mask | |
from CCAgT_utils.visualization import plot | |
from PIL import Image | |
from torch import nn | |
from transformers import SegformerFeatureExtractor | |
from transformers import SegformerForSemanticSegmentation | |
from transformers.modeling_outputs import SemanticSegmenterOutput | |
# Non-interactive backend: figures are only rendered to images, never shown.
matplotlib.use('Agg')

# Prefer the GPU when torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

model_hub_name = 'lapix/segformer-b3-finetuned-ccagt-400-300'

# Fine-tuned SegFormer weights from the HF Hub, placed on `device` and
# switched to evaluation mode.
model = SegformerForSemanticSegmentation.from_pretrained(model_hub_name)
model = model.to(device)
model.eval()

# Preprocessor matching the checkpoint above.
feature_extractor = SegformerFeatureExtractor.from_pretrained(model_hub_name)
def segment(
    image: Image.Image,
) -> SemanticSegmenterOutput:
    """Run the SegFormer model on a single image.

    The image is preprocessed by the module-level ``feature_extractor``,
    moved to ``device`` and passed through ``model``.

    Args:
        image: the image to segment.

    Returns:
        The raw model output; its ``logits`` are consumed by
        ``post_processing``.
    """
    inputs = feature_extractor(
        image,
        return_tensors='pt',
    ).to(device)
    # Inference only: disable autograd so no computation graph (and the
    # intermediate activations it keeps alive) is built for every tile.
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs
def post_processing(
    outputs: SemanticSegmenterOutput,
    target_size: tuple[int, int],
) -> np.ndarray:
    """Convert model logits into a per-pixel class-id array.

    The logits are copied to the CPU, bilinearly resized to ``target_size``
    (``align_corners=False``) and reduced with an argmax over dimension 1.
    Only the first entry along dimension 0 is returned.

    Args:
        outputs: model output exposing a ``logits`` tensor.
        target_size: ``(height, width)`` to resize the logits to.

    Returns:
        A NumPy array of class indices of shape ``target_size``.
    """
    resized = nn.functional.interpolate(
        outputs.logits.cpu(),
        size=target_size,
        mode='bilinear',
        align_corners=False,
    )
    class_ids = torch.argmax(resized, dim=1)
    first_item = class_ids[0]
    return np.array(first_item)
def colorize(
    mask: Mask,
) -> np.ndarray:
    """Return a colour rendering of ``mask`` divided by 255.

    Colours come from ``Mask.colorized`` using the default
    ``CategoriesInfos``. Dividing by 255 presumably maps 8-bit colour values
    into [0, 1] for display; confirm against CCAgT_utils.
    """
    palette = CategoriesInfos()
    colored = mask.colorized(palette)
    return colored / 255
def check_and_resize(
    image: np.ndarray,
    *,
    max_height: int = 1200,
    max_width: int = 1600,
) -> np.ndarray:
    """Downscale ``image`` so it fits within ``max_width`` x ``max_height``.

    The aspect ratio is preserved. Images already within both limits are
    returned unchanged (the same object).

    Args:
        image: array whose first two dimensions are height and width.
        max_height: largest allowed height after resizing.
        max_width: largest allowed width after resizing.

    Returns:
        ``image`` itself, or a resized copy produced with ``cv2.INTER_AREA``.
    """
    height, width = image.shape[0], image.shape[1]
    if height <= max_height and width <= max_width:
        return image
    scale_w = max_width / width
    scale_h = max_height / height
    # Scale by the tighter constraint. Scaling by the width alone would
    # enlarge tall, narrow images instead of shrinking them.
    if scale_w <= scale_h:
        new_w, new_h = max_width, int(height * scale_w)
    else:
        new_w, new_h = int(width * scale_h), max_height
    # cv2.resize rejects zero-sized targets (possible for extreme aspect
    # ratios); dsize is given as (width, height).
    dim = (max(new_w, 1), max(new_h, 1))
    return cv2.resize(image, dim, interpolation=cv2.INTER_AREA)
def process_big_images(
    image: Image.Image,
) -> Mask:
    """Segment an image of arbitrary size region by region.

    The image is first passed through ``check_and_resize``. The resulting
    array is split with ``__create_xy_slice`` (called with 300, 400 —
    presumably tile height and width, confirm against CCAgT_utils). For each
    region, the model runs on a border-padded copy of the array, and the
    predicted ids inside that region are copied into the output mask.

    Args:
        image: the input image.

    Returns:
        A ``Mask`` with the height and width of the (possibly resized) image.
    """
    img = check_and_resize(np.asarray(image))
    height, width = img.shape[0], img.shape[1]
    mask = np.zeros(shape=(height, width), dtype=np.uint8)
    # Tile over the array that is actually segmented and written into
    # ``mask``. The original PIL size differs from it whenever
    # check_and_resize downscaled the image, and the slices would then run
    # past ``mask``.
    for bbox in __create_xy_slice(height, width, 300, 400):
        # NOTE(review): copyMakeBorder pads the *whole* image by the bbox
        # coordinates (top, bottom, left, right) rather than cropping a
        # tile; confirm this is the intended tiling strategy.
        part = cv2.copyMakeBorder(
            img,
            bbox.y_init,
            bbox.y_end,
            bbox.x_init,
            bbox.x_end,
            cv2.BORDER_REFLECT,
        )
        target_size = (part.shape[0], part.shape[1])
        outputs = segment(Image.fromarray(part))
        msk = post_processing(outputs, target_size)
        mask[bbox.slice_y, bbox.slice_x] = msk[bbox.slice_y, bbox.slice_x]
    return Mask(mask)
def image_with_mask(
    image: Image.Image,
    mask: Mask,
) -> plt.Figure:
    """Draw ``image`` with the categorical ``mask`` overlaid at alpha 0.4.

    The figure is created directly as a ``Figure`` instead of through
    ``plt.figure``, so pyplot keeps no reference to it. This runs once per
    request in a long-lived server, and figures opened via pyplot are only
    released by an explicit ``plt.close``.

    Args:
        image: background image.
        mask: segmentation mask drawn on top with the category colormap.

    Returns:
        The matplotlib figure (600 dpi, axes hidden).
    """
    fig = plt.Figure(dpi=600)
    ax = fig.add_subplot()
    ax.imshow(image)
    ax.imshow(
        mask.categorical,
        cmap=mask.cmap(CategoriesInfos()),
        vmax=max(mask.unique_ids),
        vmin=min(mask.unique_ids),
        interpolation='nearest',
        alpha=0.4,
    )
    ax.axis('off')
    return fig
def categories_map(
    mask: Mask,
) -> plt.Figure:
    """Build a legend-only figure for the categories selected from ``mask``.

    Legend handles come from ``plot.create_handles`` with
    ``selected_categories=mask.unique_ids``. The figure is created directly
    as a ``Figure`` (not via ``plt.figure``) so pyplot does not retain it
    across requests.

    Args:
        mask: segmentation mask whose ``unique_ids`` select the legend
            entries.

    Returns:
        The matplotlib figure (600 dpi, centred legend, axes hidden).
    """
    fig = plt.Figure(dpi=600)
    ax = fig.add_subplot()
    handles = plot.create_handles(
        CategoriesInfos(), selected_categories=mask.unique_ids,
    )
    ax.legend(handles=handles, fontsize=24, loc='center')
    ax.axis('off')
    return fig
def main(image):
    """Gradio callback: segment the input image and build the three outputs.

    Args:
        image: array delivered by the ``gr.Image`` input; converted with
            ``Image.fromarray``.

    Returns:
        ``(categories map figure, colorized mask, overlay figure)`` in the
        same order as the interface's output components.
    """
    pil_image = Image.fromarray(image)
    predicted = process_big_images(pil_image)
    colored = colorize(predicted)
    overlay = image_with_mask(pil_image, predicted)
    legend = categories_map(predicted)
    return legend, colored, overlay
title = 'SegFormer (b3) - CCAgT dataset'
description = f"""
This is a demo of SegFormer fine-tuned on a subset of the
[CCAgT dataset](https://huggingface.co/datasets/lapix/CCAgT). The model
was trained to segment images of silver-stained (AgNOR technique) cervical
cells at a resolution of 400x300. The model is available on the HF Hub at
[{model_hub_name}](https://huggingface.co/{model_hub_name}).
"""
# Example inputs: two samples hosted with the model, then several images
# from the dataset's test split (per the URL path). A tuple rather than a set
# keeps the examples in a fixed, readable order.
examples = [
    [f'https://hf.co/{model_hub_name}/resolve/main/sampleA.png'],
    [f'https://hf.co/{model_hub_name}/resolve/main/sampleB.png'],
] + [
    [f'https://datasets-server.huggingface.co/assets/lapix/CCAgT/--/semantic_segmentation/test/{x}/image/image.jpg']
    for x in (3, 10, 12, 18, 35, 78, 89)
]
# Components are created in the same order the interface lists them:
# the single input image, then the three outputs returned by `main`.
_input_components = [gr.Image()]
_output_components = [
    gr.Plot(label='Categories map'),
    gr.Image(label='Mask'),
    gr.Plot(label='Image with mask'),
]

demo = gr.Interface(
    main,
    inputs=_input_components,
    outputs=_output_components,
    title=title,
    description=description,
    examples=examples,
    allow_flagging='never',
    cache_examples=False,
)

if __name__ == '__main__':
    demo.launch()