Julien Ajdenbaum commited on
Commit
62844d7
·
1 Parent(s): 5389132

added predict

Browse files
models/bald_classifity.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:83403b2aab2884f2ff3ed409a19017a6ece437c376897fc2c14267efdeefc4bd
3
+ size 1760072
predict.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+ import keras
4
+ import os
5
+ import matplotlib.pyplot as plt
6
+ import time
7
+
8
+ bald_path = "/home/julien/Downloads/archive (1)/Dataset/Test/Bald"
9
+ notbald_path = "/home/julien/Downloads/archive (1)/Dataset/Test/NotBald"
10
+
11
+ def predict(im):
12
+ # im_names = os.listdirpath)
13
+ im_names= [im]
14
+ ims = np.zeros((len(im_names), 64, 64, 3))
15
+ for i, f in enumerate(im_names):
16
+ # img = cv2.imread(os.path.join(path, f))
17
+ # img = im
18
+ img = cv2.imread(f)
19
+ data = cv2.resize(img, (64, 64))
20
+ data = cv2.cvtColor(data, cv2.COLOR_BGR2RGB)
21
+ ims[i] = np.reshape(data, (1, 64, 64, 3)) / 255.0
22
+ print(ims.shape)
23
+ res = model.predict(ims)
24
+ """
25
+ for i in range(len(im_names)):
26
+ print('[没秃]' if res[i][0] <0.3 else '[秃了]', 'Bald:',res[i][0])
27
+ """
28
+ return res
29
+
30
+ model = keras.models.load_model('/home/julien/Documents/bald_classification/models/bald_classifity.h5')
31
+
32
+ """
33
+ bald_res = predict(bald_path)
34
+ notbald_res = predict(notbald_path)
35
+
36
+ print(notbald_res.shape)
37
+ print(notbald_res[:,0])
38
+ print(f"Bald mean : {np.mean(bald_res[:, 0])}")
39
+ print(f"Not Bald mean : {np.mean(notbald_res[:, 0])}")
40
+ print(bald_res[:, 0]>0.3)
41
+ print(f"Accuracy : {(np.sum(bald_res[:, 0]>0.3) + np.sum(notbald_res[:, 0]<0.3))/(len(bald_res)+len(notbald_res))}")
42
+ """
test_data/1.png ADDED
test_data/2.png ADDED
test_data/3.png ADDED
test_data/4.png ADDED
test_data/5.png ADDED