"""Gradio demo that masks one SegFormer segmentation class in an image.

Loads the ``mattmdjaga/segformer_b2_clothes`` checkpoint once at import time,
then serves a web UI that takes an image and returns a black/white mask of the
pixels the model assigns to a single target class.
"""
import gradio as gr
from transformers import AutoFeatureExtractor, SegformerForSemanticSegmentation
from PIL import Image
import requests
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

MODEL_ID = "mattmdjaga/segformer_b2_clothes"

# Class index kept in the output mask. Presumably "Upper-clothes" for this
# checkpoint -- NOTE(review): confirm against model.config.id2label.
TARGET_LABEL = 4

extractor = AutoFeatureExtractor.from_pretrained(MODEL_ID)
model = SegformerForSemanticSegmentation.from_pretrained(MODEL_ID)


def predict(inp, target_label=TARGET_LABEL):
    """Return a binary mask of the pixels predicted as ``target_label``.

    Args:
        inp: Input PIL image. Gradio passes ``None`` when no image is given.
        target_label: Segmentation class index to keep. Defaults to
            ``TARGET_LABEL``.

    Returns:
        A single-channel uint8 PIL image with the same size as ``inp``. It is
        255 where the predicted class equals ``target_label`` and 0
        elsewhere. Returns ``None`` if ``inp`` is ``None``.
    """
    if inp is None:
        return None

    inputs = extractor(images=inp, return_tensors="pt")
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits.cpu()

    # Logits may come out at a different resolution than the input, so resize
    # them back. PIL's ``size`` is (width, height) while ``interpolate``
    # expects (height, width), hence the reversal.
    upsampled_logits = nn.functional.interpolate(
        logits,
        size=inp.size[::-1],
        mode="bilinear",
        align_corners=False,
    )
    pred_seg = upsampled_logits.argmax(dim=1)[0]

    # Build the 0/255 mask from a boolean comparison. Multiplying the label
    # value itself by 255 in uint8 would overflow (e.g. 4 * 255 wraps to 252).
    mask = (pred_seg == target_label).to(torch.uint8) * 255
    return Image.fromarray(mask.numpy())


gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="image",
).launch()