import os

# Force Gradio 3.47.1 at runtime, replacing whatever version the Space preinstalls;
# this must run before `import gradio` below.
os.system("pip uninstall -y gradio")
os.system("pip install gradio==3.47.1")

from transformers import pipeline
import gradio
from PIL import Image
from sentence_transformers import SentenceTransformer, util
import spaces

# Models: scene segmentation, clothing segmentation, human parsing, face parsing,
# face detection, person detection, plus a sentence encoder for label matching.
backgroundPipe = pipeline("image-segmentation", model="facebook/maskformer-swin-large-coco")
personPipe = pipeline("image-segmentation", model="mattmdjaga/segformer_b2_clothes")
sentenceModel = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
personDetailsPipe = pipeline("image-segmentation", model="yolo12138/segformer-b2-human-parse-24")
faceModel = pipeline("image-segmentation", model="jonathandinu/face-parsing")
faceDetectionModel = pipeline("object-detection", model="aditmohan96/detr-finetuned-face")
personDetectionPipe = pipeline("object-detection", model="hustvl/yolos-tiny")
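
# A transformers object-detection pipeline returns a list of dicts shaped like
# the following (values illustrative):
# [{"score": 0.99, "label": "person",
#   "box": {"xmin": 10, "ymin": 20, "xmax": 200, "ymax": 400}}]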

def getPersonDetail(image):
    """Detect each person, run human parsing on the crop, and return {label: mask}."""
    data = personDetectionPipe(image)
    persn = [det["box"] for det in data if det["label"].lower() == "person"]
    ret = {}
    for n, cord in enumerate(persn, start=1):
        crop_box = (cord["xmin"], cord["ymin"], cord["xmax"], cord["ymax"])
        cropped_image = image.crop(crop_box)
        personData = personDetailsPipe(cropped_image)
        for dt in personData:
            # Prefix labels with a person index only when several people were detected.
            if len(persn) > 1:
                ret[f"person {n} {dt['label']}".lower()] = cbiwm(image, dt["mask"], cord)
            else:
                ret[dt["label"].lower()] = cbiwm(image, dt["mask"], cord)
    return ret
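
# Illustrative return shape: {"person 1 hair": <PIL.Image>, "person 1 pants": <PIL.Image>, ...}
# (single-person images omit the "person N" prefix).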

def cbiwm(image, mask, coordinates):
    """Paste a crop-local mask back at its original position on a full-size black canvas."""
    black_image = Image.new("RGBA", image.size, (0, 0, 0, 255))
    black_image.paste(mask, (coordinates["xmin"], coordinates["ymin"]), mask)
    return black_image
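
# Note: with a mask argument, PIL's Image.paste updates only the pixels where the
# mask is nonzero (blending for intermediate values), and converts the pasted image
# to the canvas mode as needed.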

def processFaceDetails(image):
    """Add face-part masks from the face-parsing model on top of the person details."""
    ret = getPersonDetail(image)
    data = faceDetectionModel(image)
    if not data:
        return ret
    # Keeps the original selection rule: the second detection when more than one exists.
    coordinates = data[1]["box"] if len(data) > 1 else data[0]["box"]
    crop_box = (coordinates["xmin"], coordinates["ymin"], coordinates["xmax"], coordinates["ymax"])
    cropped_image = image.crop(crop_box)
    facedata = faceModel(cropped_image)
    for imask in facedata:
        ret[imask["label"].replace(".png", "").lower()] = cbiwm(image, imask["mask"], coordinates)
    return ret
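
# Illustrative labels added by the face-parsing pass: "skin", "nose", "hair",
# "u_lip", "l_lip"; any ".png" suffix in a label is stripped above.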

def getImageDetails(image) -> dict:
    """Collect every available mask: face parts, person details, clothing, and scene segments."""
    ret = processFaceDetails(image)
    person = personPipe(image)
    bg = backgroundPipe(image)
    # Later assignments win, so clothing masks overwrite scene masks on label collisions.
    for imask in bg:
        ret[imask["label"].lower()] = imask["mask"]  # Apply a base64 image converter here if needed
    for mask in person:
        ret[mask["label"].lower()] = mask["mask"]  # Apply a base64 image converter here if needed
    return ret

def processSentence(sentence: str, semilist: list):
    """Return the label in semilist most semantically similar to the query sentence."""
    query_embedding = sentenceModel.encode(sentence)
    passage_embedding = sentenceModel.encode(semilist)
    scores = util.dot_score(query_embedding, passage_embedding)[0]
    return semilist[int(scores.argmax())]
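
# Illustrative call: processSentence("the shirt", ["hair", "upper-clothes", "wall"])
# scores each candidate against the query by embedding dot product and returns the
# best match, likely "upper-clothes" here.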

def process_image(image):
    """Turn a black/white mask into a cutout: black pixels become fully transparent,
    white pixels become black (keeping their alpha), everything else passes through."""
    rgba_image = image.convert("RGBA")
    final_data = [
        (0, 0, 0, 0) if pixel[:3] == (0, 0, 0)
        else (0, 0, 0, pixel[3]) if pixel[:3] == (255, 255, 255)
        else pixel
        for pixel in rgba_image.getdata()
    ]
    processed_image = Image.new("RGBA", rgba_image.size)
    processed_image.putdata(final_data)
    return processed_image
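
# Worked example: (0, 0, 0, 255) -> (0, 0, 0, 0) (transparent) and
# (255, 255, 255, 255) -> (0, 0, 0, 255) (opaque black); any other color is unchanged.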

@spaces.GPU()
def processAndGetMask(image: Image.Image, text: str):
    """Segment the image, pick the mask whose label best matches the text, and cut it out."""
    datas = getImageDetails(image)
    labs = list(datas.keys())
    selector = processSentence(text, labs)
    imageout = datas[selector]
    print(f"Selected: {selector} among: {labs}")
    return process_image(imageout)

demo = gradio.Interface(
    processAndGetMask,
    [gradio.Image(label="Input Image", type="pil"), gradio.Text(label="Input text to segment")],
    gradio.Image(label="Output Image", type="pil"),
)
demo.launch(share=True)