hwajjala committed

Commit 2de9666 • Parent(s): 9c9b891

Add hairclassifier to the model

Files changed (3):
  1. app.py (+54 -9)
  2. hairclassifier_rf.pkl (+3 -0)
  3. text_prompts_hair.json (+74 -0)
app.py CHANGED
@@ -13,7 +13,9 @@ logger = logging.getLogger("basebody")
 CLIP_MODEL_NAME = "ViT-B/16"
 
 TEXT_PROMPTS_FILE_NAME = "text_prompts.json"
+HAIR_TEXT_PROMPTS_FILE_NAME = "text_prompts_hair.json"
 LOGISTIC_REGRESSION_MODEL_FILE_NAME = "logistic_regression_l1_oct_2.pkl"
+HAIR_RF_CLASSIFIER_MODEL_FILE_NAME = "hairclassifier_rf.pkl"
 
 HF_TOKEN = os.getenv('HF_TOKEN')
 hf_writer = gr.HuggingFaceDatasetSaver(
@@ -28,6 +30,12 @@ with open(
     os.path.join(os.path.dirname(__file__), TEXT_PROMPTS_FILE_NAME), "r"
 ) as f:
     text_prompts = json.load(f)
+
+with open(
+    os.path.join(os.path.dirname(__file__), HAIR_TEXT_PROMPTS_FILE_NAME), "r"
+) as f:
+    hair_text_prompts = json.load(f)
+
 with open(
     os.path.join(
         os.path.dirname(__file__), LOGISTIC_REGRESSION_MODEL_FILE_NAME
@@ -36,6 +44,15 @@ with open(
 ) as f:
     lr_model = pickle.load(f)
 
+with open(
+    os.path.join(
+        os.path.dirname(__file__), HAIR_RF_CLASSIFIER_MODEL_FILE_NAME
+    ),
+    "rb",
+) as f:
+    hair_rf_model = pickle.load(f)
+
+
 logger.info("Logistic regression model loaded, coefficients: ")
 
 
@@ -50,6 +67,27 @@ with torch.no_grad():
     all_text_features = all_text_features.cpu()
 
 
+hair_text_features = []
+with torch.no_grad():
+    for k, prompts in hair_text_prompts.items():
+        assert len(prompts) == 2
+        inputs = clip.tokenize(prompts)
+        outputs = clip_model.encode_text(inputs)
+        hair_text_features.append(outputs)
+    hair_text_features = torch.cat(hair_text_features, dim=0)
+    hair_text_features = hair_text_features.cpu()
+
+
+def get_cosine_similarities(image_features, text_features):
+    cosine_simlarities = softmax(
+        (text_features @ image_features.cpu().T)
+        .squeeze()
+        .reshape(len(text_prompts), 2, -1),
+        axis=1,
+    )[:, 0, :]
+    return cosine_simlarities
+
+
 def predict_fn(input_img):
     input_img = Image.fromarray(input_img.astype("uint8"), "RGB")
     image = preprocess(
@@ -57,21 +95,28 @@ def predict_fn(input_img):
     ).unsqueeze(0)
     with torch.no_grad():
         image_features = clip_model.encode_image(image)
-    cosine_simlarities = softmax(
-        (all_text_features @ image_features.cpu().T)
-        .squeeze()
-        .reshape(len(text_prompts), 2, -1),
-        axis=1,
-    )[:, 0, :]
+    base_body_cosine_simlarities = get_cosine_similarities(
+        image_features, all_text_features
+    )
+    hair_cosine_simlarities = get_cosine_similarities(
+        image_features, hair_text_features
+    )
     # logger.info(f"cosine_simlarities shape: {cosine_simlarities.shape}")
-    logger.info(f"cosine_simlarities: {cosine_simlarities}")
+    logger.info(f"cosine_simlarities: {base_body_cosine_simlarities}")
     probabilities = lr_model.predict_proba(
-        cosine_simlarities.reshape(1, -1)
+        base_body_cosine_simlarities.reshape(1, -1)
+    )
+    hair_probabilities = hair_rf_model.predict_proba(
+        hair_cosine_simlarities.reshape(1, -1)
     )
     logger.info(f"probabilities: {probabilities}")
    result_probabilty = float(probabilities[0][1].round(3))
+    hair_result_probabilty = float(hair_probabilities[0][1].round(3))
     # get decision string
-    if result_probabilty > 0.95:
+    if result_probabilty > 0.77:
+        if hair_result_probabilty < 0.5:
+            result_probabilty = hair_result_probabilty
+            decision = "AUTO REJECT"
         decision = "AUTO ACCEPT"
     elif result_probabilty < 0.4:
         decision = "AUTO REJECT"
hairclassifier_rf.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4ab2fb4a72a581f1943ee57e51031616a7a4d4c6e8be5fb7dfac7fc67cebd7c7
+size 83719733
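The three lines above are a Git LFS pointer, not the model itself; the roughly 84 MB pickle is fetched from LFS at checkout. A quick sanity check after downloading, assuming the file unpickles to a fitted scikit-learn estimator (the "rf" suffix suggests a random forest, but the diff does not confirm the exact class):

import pickle

with open("hairclassifier_rf.pkl", "rb") as f:
    hair_rf_model = pickle.load(f)  # needs a compatible scikit-learn version installed

# Fitted scikit-learn estimators expose these attributes; fall back gracefully if not.
print(type(hair_rf_model).__name__)
print("n_features_in_:", getattr(hair_rf_model, "n_features_in_", "unknown"))  # likely 18, one per prompt pair
print("classes_:", getattr(hair_rf_model, "classes_", "unknown"))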
text_prompts_hair.json ADDED
@@ -0,0 +1,74 @@
+{
+    "hair": [
+        "does character have hair on its head",
+        "does character not have hair on its head"
+    ],
+    "bald": [
+        "is character bald",
+        "is character not bald"
+    ],
+    "hat": [
+        "is character wearing a hat",
+        "is character not wearing a hat"
+    ],
+    "helmet": [
+        "is character wearing a helmet",
+        "is character not wearing a helmet"
+    ],
+    "headband": [
+        "is character wearing a headband",
+        "is character not wearing a headband"
+    ],
+    "tiara": [
+        "is character wearing a tiara",
+        "is character not wearing a tiara"
+    ],
+    "turban": [
+        "is character wearing a turban",
+        "is character not wearing a turban"
+    ],
+    "crown": [
+        "is character wearing a crown",
+        "is character not wearing a crown"
+    ],
+    "bandana": [
+        "is character wearing a bandana",
+        "is character not wearing a bandana"
+    ],
+    "hood": [
+        "is character wearing a hood",
+        "is character not wearing a hood"
+    ],
+    "wig": [
+        "is character wearing a wig",
+        "is character not wearing a wig"
+    ],
+    "headphones": [
+        "is character wearing headphones",
+        "is character not wearing headphones"
+    ],
+    "earmuffs": [
+        "is character wearing earmuffs",
+        "is character not wearing earmuffs"
+    ],
+    "veil": [
+        "is character wearing a veil",
+        "is character not wearing a veil"
+    ],
+    "feathers": [
+        "are there feathers on the character's head",
+        "there are no feathers on the character's head"
+    ],
+    "horns": [
+        "does character have horns on its head",
+        "does character not have horns on its head"
+    ],
+    "antenna": [
+        "does character have antenna on its head",
+        "does character not have antenna on its head"
+    ],
+    "head-decoration": [
+        "is there any decoration on the character's head",
+        "there is no decoration on the character's head"
+    ]
+}
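Each key above contributes exactly one feature to the hair classifier (the softmax weight of its first, positive prompt), and the assert len(prompts) == 2 in app.py enforces the pair structure. A small hedged helper for keeping the feature order straight, for example when reading feature importances off a tree-based model (the feature_importances_ attribute is an assumption about the pickled classifier):

import json

with open("text_prompts_hair.json") as f:
    hair_text_prompts = json.load(f)

# dicts preserve insertion order (Python 3.7+), so this matches the order in
# which app.py stacks hair_text_features and builds the feature vector.
feature_names = list(hair_text_prompts.keys())
print(feature_names)  # ['hair', 'bald', 'hat', ..., 'head-decoration']

# Only valid if the unpickled model is tree-based (e.g. a random forest):
# for name, importance in zip(feature_names, hair_rf_model.feature_importances_):
#     print(f"{name}: {importance:.3f}")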