Add hairclassifier to the model
Browse files- app.py +54 -9
- hairclassifier_rf.pkl +3 -0
- text_prompts_hair.json +74 -0
app.py
CHANGED
@@ -13,7 +13,9 @@ logger = logging.getLogger("basebody")
|
|
13 |
CLIP_MODEL_NAME = "ViT-B/16"
|
14 |
|
15 |
TEXT_PROMPTS_FILE_NAME = "text_prompts.json"
|
|
|
16 |
LOGISTIC_REGRESSION_MODEL_FILE_NAME = "logistic_regression_l1_oct_2.pkl"
|
|
|
17 |
|
18 |
HF_TOKEN = os.getenv('HF_TOKEN')
|
19 |
hf_writer = gr.HuggingFaceDatasetSaver(
|
@@ -28,6 +30,12 @@ with open(
|
|
28 |
os.path.join(os.path.dirname(__file__), TEXT_PROMPTS_FILE_NAME), "r"
|
29 |
) as f:
|
30 |
text_prompts = json.load(f)
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
with open(
|
32 |
os.path.join(
|
33 |
os.path.dirname(__file__), LOGISTIC_REGRESSION_MODEL_FILE_NAME
|
@@ -36,6 +44,15 @@ with open(
|
|
36 |
) as f:
|
37 |
lr_model = pickle.load(f)
|
38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
logger.info("Logistic regression model loaded, coefficients: ")
|
40 |
|
41 |
|
@@ -50,6 +67,27 @@ with torch.no_grad():
|
|
50 |
all_text_features = all_text_features.cpu()
|
51 |
|
52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
def predict_fn(input_img):
|
54 |
input_img = Image.fromarray(input_img.astype("uint8"), "RGB")
|
55 |
image = preprocess(
|
@@ -57,21 +95,28 @@ def predict_fn(input_img):
|
|
57 |
).unsqueeze(0)
|
58 |
with torch.no_grad():
|
59 |
image_features = clip_model.encode_image(image)
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
)
|
66 |
# logger.info(f"cosine_simlarities shape: {cosine_simlarities.shape}")
|
67 |
-
logger.info(f"cosine_simlarities: {
|
68 |
probabilities = lr_model.predict_proba(
|
69 |
-
|
|
|
|
|
|
|
70 |
)
|
71 |
logger.info(f"probabilities: {probabilities}")
|
72 |
result_probabilty = float(probabilities[0][1].round(3))
|
|
|
73 |
# get decision string
|
74 |
-
if result_probabilty > 0.
|
|
|
|
|
|
|
75 |
decision = "AUTO ACCEPT"
|
76 |
elif result_probabilty < 0.4:
|
77 |
decision = "AUTO REJECT"
|
|
|
13 |
CLIP_MODEL_NAME = "ViT-B/16"
|
14 |
|
15 |
TEXT_PROMPTS_FILE_NAME = "text_prompts.json"
|
16 |
+
HAIR_TEXT_PROMPTS_FILE_NAME = "text_prompts_hair.json"
|
17 |
LOGISTIC_REGRESSION_MODEL_FILE_NAME = "logistic_regression_l1_oct_2.pkl"
|
18 |
+
HAIR_RF_CLASSIFIER_MODEL_FILE_NAME = "hairclassifier_rf.pkl"
|
19 |
|
20 |
HF_TOKEN = os.getenv('HF_TOKEN')
|
21 |
hf_writer = gr.HuggingFaceDatasetSaver(
|
|
|
30 |
os.path.join(os.path.dirname(__file__), TEXT_PROMPTS_FILE_NAME), "r"
|
31 |
) as f:
|
32 |
text_prompts = json.load(f)
|
33 |
+
|
34 |
+
with open(
|
35 |
+
os.path.join(os.path.dirname(__file__), HAIR_TEXT_PROMPTS_FILE_NAME), "r"
|
36 |
+
) as f:
|
37 |
+
hair_text_prompts = json.load(f)
|
38 |
+
|
39 |
with open(
|
40 |
os.path.join(
|
41 |
os.path.dirname(__file__), LOGISTIC_REGRESSION_MODEL_FILE_NAME
|
|
|
44 |
) as f:
|
45 |
lr_model = pickle.load(f)
|
46 |
|
47 |
+
with open(
|
48 |
+
os.path.join(
|
49 |
+
os.path.dirname(__file__), HAIR_RF_CLASSIFIER_MODEL_FILE_NAME
|
50 |
+
),
|
51 |
+
"rb",
|
52 |
+
) as f:
|
53 |
+
hair_rf_model = pickle.load(f)
|
54 |
+
|
55 |
+
|
56 |
logger.info("Logistic regression model loaded, coefficients: ")
|
57 |
|
58 |
|
|
|
67 |
all_text_features = all_text_features.cpu()
|
68 |
|
69 |
|
70 |
+
hair_text_features = []
|
71 |
+
with torch.no_grad():
|
72 |
+
for k, prompts in hair_text_prompts.items():
|
73 |
+
assert len(prompts) == 2
|
74 |
+
inputs = clip.tokenize(prompts)
|
75 |
+
outputs = clip_model.encode_text(inputs)
|
76 |
+
hair_text_features.append(outputs)
|
77 |
+
hair_text_features = torch.cat(hair_text_features, dim=0)
|
78 |
+
hair_text_features = hair_text_features.cpu()
|
79 |
+
|
80 |
+
|
81 |
+
def get_cosine_similarities(image_features, text_features):
|
82 |
+
cosine_simlarities = softmax(
|
83 |
+
(text_features @ image_features.cpu().T)
|
84 |
+
.squeeze()
|
85 |
+
.reshape(len(text_prompts), 2, -1),
|
86 |
+
axis=1,
|
87 |
+
)[:, 0, :]
|
88 |
+
return cosine_simlarities
|
89 |
+
|
90 |
+
|
91 |
def predict_fn(input_img):
|
92 |
input_img = Image.fromarray(input_img.astype("uint8"), "RGB")
|
93 |
image = preprocess(
|
|
|
95 |
).unsqueeze(0)
|
96 |
with torch.no_grad():
|
97 |
image_features = clip_model.encode_image(image)
|
98 |
+
base_body_cosine_simlarities = get_cosine_similarities(
|
99 |
+
image_features, all_text_features
|
100 |
+
)
|
101 |
+
hair_cosine_simlarities = get_cosine_similarities(
|
102 |
+
image_features, hair_text_features
|
103 |
+
)
|
104 |
# logger.info(f"cosine_simlarities shape: {cosine_simlarities.shape}")
|
105 |
+
logger.info(f"cosine_simlarities: {base_body_cosine_simlarities}")
|
106 |
probabilities = lr_model.predict_proba(
|
107 |
+
base_body_cosine_simlarities.reshape(1, -1)
|
108 |
+
)
|
109 |
+
hair_probabilities = hair_rf_model.predict_proba(
|
110 |
+
hair_cosine_simlarities.reshape(1, -1)
|
111 |
)
|
112 |
logger.info(f"probabilities: {probabilities}")
|
113 |
result_probabilty = float(probabilities[0][1].round(3))
|
114 |
+
hair_result_probabilty = float(hair_probabilities[0][1].round(3))
|
115 |
# get decision string
|
116 |
+
if result_probabilty > 0.77:
|
117 |
+
if hair_result_probabilty < 0.5:
|
118 |
+
result_probabilty = hair_result_probabilty
|
119 |
+
decision = "AUTO REJECT"
|
120 |
decision = "AUTO ACCEPT"
|
121 |
elif result_probabilty < 0.4:
|
122 |
decision = "AUTO REJECT"
|
hairclassifier_rf.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4ab2fb4a72a581f1943ee57e51031616a7a4d4c6e8be5fb7dfac7fc67cebd7c7
|
3 |
+
size 83719733
|
text_prompts_hair.json
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"hair": [
|
3 |
+
"does character have hair on its head",
|
4 |
+
"does character not have hair on its head"
|
5 |
+
],
|
6 |
+
"bald": [
|
7 |
+
"is character bald",
|
8 |
+
"is character not bald"
|
9 |
+
],
|
10 |
+
"hat": [
|
11 |
+
"is character wearing a hat",
|
12 |
+
"is character not wearing a hat"
|
13 |
+
],
|
14 |
+
"helmet": [
|
15 |
+
"is character wearing a helmet",
|
16 |
+
"is character not wearing a helmet"
|
17 |
+
],
|
18 |
+
"headband": [
|
19 |
+
"is character wearing a headband",
|
20 |
+
"is character not wearing a headband"
|
21 |
+
],
|
22 |
+
"tiara": [
|
23 |
+
"is character wearing a tiara",
|
24 |
+
"is character not wearing a tiara"
|
25 |
+
],
|
26 |
+
"turban": [
|
27 |
+
"is character wearing a turban",
|
28 |
+
"is character not wearing a turban"
|
29 |
+
],
|
30 |
+
"crown": [
|
31 |
+
"is character wearing a crown",
|
32 |
+
"is character not wearing a crown"
|
33 |
+
],
|
34 |
+
"bandana": [
|
35 |
+
"is character wearing a bandana",
|
36 |
+
"is character not wearing a bandana"
|
37 |
+
],
|
38 |
+
"hood": [
|
39 |
+
"is character wearing a hood",
|
40 |
+
"is character not wearing a hood"
|
41 |
+
],
|
42 |
+
"wig": [
|
43 |
+
"is character wearing a wig",
|
44 |
+
"is character not wearing a wig"
|
45 |
+
],
|
46 |
+
"headphones": [
|
47 |
+
"is character wearing headphones",
|
48 |
+
"is character not wearing headphones"
|
49 |
+
],
|
50 |
+
"earmuffs": [
|
51 |
+
"is character wearing earmuffs",
|
52 |
+
"is character not wearing earmuffs"
|
53 |
+
],
|
54 |
+
"veil": [
|
55 |
+
"is character wearing a veil",
|
56 |
+
"is character not wearing a veil"
|
57 |
+
],
|
58 |
+
"feathers": [
|
59 |
+
"are there feathers on the character's head",
|
60 |
+
"there are no feathers on the character's head"
|
61 |
+
],
|
62 |
+
"horns": [
|
63 |
+
"does character have horns on its head",
|
64 |
+
"does character not have horns on its head"
|
65 |
+
],
|
66 |
+
"antenna": [
|
67 |
+
"does character have antenna on its head",
|
68 |
+
"does character not have antenna on its head"
|
69 |
+
],
|
70 |
+
"head-decoration": [
|
71 |
+
"is there any decoration on the character's head",
|
72 |
+
"there is no decoration on the character's head"
|
73 |
+
]
|
74 |
+
}
|