# File size: 1,978 Bytes
# 6e00714
import matplotlib.pyplot as plt
import requests
import torch
import torch.nn as nn
from PIL import Image
from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
# --- Model setup and inference ----------------------------------------------
# Load the image pre-processor and the clothes-segmentation checkpoint.
processor = SegformerImageProcessor.from_pretrained("mattmdjaga/segformer_b2_clothes")
model = AutoModelForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")

# Sample remote image, kept as an alternative input source. To use it instead
# of the local file below:
#     image = Image.open(requests.get(url, stream=True).raw)
url = "https://plus.unsplash.com/premium_photo-1673210886161-bfcc40f54d1f?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxzZWFyY2h8MXx8cGVyc29uJTIwc3RhbmRpbmd8ZW58MHx8MHx8&w=1000&q=80"
image_path = "C:/Users/Admin/Downloads/dress1.jpg"

# Open inside a context manager so the file handle is released as soon as the
# pixels are decoded. Converting to RGB gives the processor a 3-channel image
# even when the file is RGBA, palette-based or greyscale.
with Image.open(image_path) as source_image:
    image = source_image.convert("RGB")

inputs = processor(images=image, return_tensors="pt")

# Inference only: skip autograd bookkeeping so no gradient graph is built.
with torch.no_grad():
    outputs = model(**inputs)
logits = outputs.logits.cpu()
print("here")

# Resize the per-class logits to the original image size. PIL reports size as
# (width, height) while interpolate expects (height, width), hence the [::-1].
upsampled_logits = nn.functional.interpolate(
    logits,
    size=image.size[::-1],
    mode="bilinear",
    align_corners=False,
)

# The highest-scoring class per pixel is the predicted label id. Compute it
# once and reuse it for both the debug print and the mask.
predicted_ids = upsampled_logits.argmax(dim=1)
print(predicted_ids)
pred_seg = predicted_ids[0]  # first (only) image in the batch

# Quick unlabelled preview; it is displayed by the next plt.show() call.
plt.imshow(pred_seg)
import matplotlib as mpl
from matplotlib.colors import ListedColormap

# --- Labelled visualisation --------------------------------------------------
# id2label maps class id -> human-readable name. Iterating the dict directly
# would yield the ids, so look the names up explicitly, ordered by id, so that
# label_names[i] is the name of class i.
# NOTE(review): assumes the ids are the contiguous range 0..n-1 - confirm for
# other checkpoints.
id2label = model.config.id2label
label_names = [id2label[class_id] for class_id in sorted(id2label, key=int)]

# Get the number of labels
n_labels = len(label_names)

# Use the updated method to get the colormap
cmap = mpl.colormaps['tab20']
if n_labels > len(cmap.colors):
    raise ValueError(
        f"colormap provides {len(cmap.colors)} colors but the model has "
        f"{n_labels} labels"
    )

# One fixed color per class id: color i always belongs to label i.
colors = cmap.colors[:n_labels]
label_cmap = ListedColormap(colors)

# Create the figure and axes for the plot and the colorbar
fig, ax = plt.subplots()

# Display the segmentation. Pinning vmin/vmax to the full id range (with
# half-step padding) keeps the id -> color assignment independent of which
# classes happen to appear in this image and centres each integer tick on its
# color band. Nearest-neighbour interpolation avoids blending neighbouring
# class ids into unrelated colors.
im = ax.imshow(
    pred_seg,
    cmap=label_cmap,
    vmin=-0.5,
    vmax=n_labels - 0.5,
    interpolation="nearest",
)

# Create a colorbar with one named tick per class
cbar = fig.colorbar(im, ax=ax, ticks=range(n_labels))
cbar.ax.set_yticklabels(label_names)
plt.show()

# Keep only the first three components (R, G, B) of each color
rgb_colors = [color[:3] for color in colors]

# Create a dictionary mapping label names to the RGB colors used in the plot
label_to_color = dict(zip(label_names, rgb_colors))

# Display the mapping
for label, color in label_to_color.items():
    print(f"{label}: {color}")