# File size: 1,978 Bytes
# 6e00714
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import io
from pathlib import Path

import matplotlib as mpl
import matplotlib.pyplot as plt
import requests
import torch
import torch.nn as nn
from PIL import Image
from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation

# Segment clothing in a single image with a pretrained SegFormer model,
# display the per-pixel class map with a labelled colorbar, and print the
# class-name -> RGB color mapping used in the plot.

processor = SegformerImageProcessor.from_pretrained("mattmdjaga/segformer_b2_clothes")
model = AutoModelForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")

# Fallback sample image, downloaded only when the local file below is missing.
url = "https://plus.unsplash.com/premium_photo-1673210886161-bfcc40f54d1f?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxzZWFyY2h8MXx8cGVyc29uJTIwc3RhbmRpbmd8ZW58MHx8MHx8&w=1000&q=80"
image_path = "C:/Users/Admin/Downloads/dress1.jpg"

if Path(image_path).is_file():
    image = Image.open(image_path)
else:
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    image = Image.open(io.BytesIO(response.content))
# Force 3 channels: RGBA (e.g. PNG) or greyscale inputs would not match the
# 3-channel normalisation the processor applies.
image = image.convert("RGB")

inputs = processor(images=image, return_tensors="pt")

# Inference only: don't build an autograd graph (saves memory and time).
with torch.no_grad():
    outputs = model(**inputs)
logits = outputs.logits.cpu()

# Upsample class scores to the input resolution. PIL's Image.size is
# (width, height) while interpolate expects (height, width), hence [::-1].
upsampled_logits = nn.functional.interpolate(
    logits,
    size=image.size[::-1],
    mode="bilinear",
    align_corners=False,
)
# Per-pixel class id for the first (and only) image in the batch.
pred_seg = upsampled_logits.argmax(dim=1)[0]

# id2label maps class id -> name; list(id2label) would yield the integer ids,
# not the names. Order names by class id so that index i is class i.
id2label = model.config.id2label
label_names = [id2label[label_id] for label_id in sorted(id2label)]
n_labels = len(label_names)

# One palette entry per class, taken from tab20 in order (cycling if the
# model has more than 20 classes), so colors[i] is exactly class i's color.
base_colors = mpl.colormaps["tab20"].colors
colors = [base_colors[i % len(base_colors)] for i in range(n_labels)]
cmap = mpl.colors.ListedColormap(colors)

# Create the figure and axes for the plot and the colorbar
fig, ax = plt.subplots()

# Fixed half-integer vmin/vmax centre class i on colors[i] no matter which
# classes occur in this image (imshow would otherwise auto-scale to the ids
# present); "nearest" keeps resampling from blending neighbouring class ids.
im = ax.imshow(
    pred_seg.numpy(),
    cmap=cmap,
    vmin=-0.5,
    vmax=n_labels - 0.5,
    interpolation="nearest",
)

# One colorbar tick per class id, labelled with the class name.
cbar = fig.colorbar(im, ax=ax, ticks=range(n_labels))
cbar.ax.set_yticklabels(label_names)

plt.show()

# Drop any alpha component so each entry is an RGB triple.
rgb_colors = [color[:3] for color in colors]

# Map each class name to the RGB color it is drawn with above.
label_to_color = dict(zip(label_names, rgb_colors))

# Display the mapping
for label, color in label_to_color.items():
    print(f"{label}: {color}")