Spaces:
Running
Running
Julien Ajdenbaum
commited on
Commit
·
62844d7
1
Parent(s):
5389132
added predict
Browse files- models/bald_classifity.h5 +3 -0
- predict.py +42 -0
- test_data/1.png +0 -0
- test_data/2.png +0 -0
- test_data/3.png +0 -0
- test_data/4.png +0 -0
- test_data/5.png +0 -0
models/bald_classifity.h5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:83403b2aab2884f2ff3ed409a19017a6ece437c376897fc2c14267efdeefc4bd
|
3 |
+
size 1760072
|
predict.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import cv2
|
3 |
+
import keras
|
4 |
+
import os
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
import time
|
7 |
+
|
8 |
+
bald_path = "/home/julien/Downloads/archive (1)/Dataset/Test/Bald"
|
9 |
+
notbald_path = "/home/julien/Downloads/archive (1)/Dataset/Test/NotBald"
|
10 |
+
|
11 |
+
def predict(im):
|
12 |
+
# im_names = os.listdirpath)
|
13 |
+
im_names= [im]
|
14 |
+
ims = np.zeros((len(im_names), 64, 64, 3))
|
15 |
+
for i, f in enumerate(im_names):
|
16 |
+
# img = cv2.imread(os.path.join(path, f))
|
17 |
+
# img = im
|
18 |
+
img = cv2.imread(f)
|
19 |
+
data = cv2.resize(img, (64, 64))
|
20 |
+
data = cv2.cvtColor(data, cv2.COLOR_BGR2RGB)
|
21 |
+
ims[i] = np.reshape(data, (1, 64, 64, 3)) / 255.0
|
22 |
+
print(ims.shape)
|
23 |
+
res = model.predict(ims)
|
24 |
+
"""
|
25 |
+
for i in range(len(im_names)):
|
26 |
+
print('[没秃]' if res[i][0] <0.3 else '[秃了]', 'Bald:',res[i][0])
|
27 |
+
"""
|
28 |
+
return res
|
29 |
+
|
30 |
+
model = keras.models.load_model('/home/julien/Documents/bald_classification/models/bald_classifity.h5')
|
31 |
+
|
32 |
+
"""
|
33 |
+
bald_res = predict(bald_path)
|
34 |
+
notbald_res = predict(notbald_path)
|
35 |
+
|
36 |
+
print(notbald_res.shape)
|
37 |
+
print(notbald_res[:,0])
|
38 |
+
print(f"Bald mean : {np.mean(bald_res[:, 0])}")
|
39 |
+
print(f"Not Bald mean : {np.mean(notbald_res[:, 0])}")
|
40 |
+
print(bald_res[:, 0]>0.3)
|
41 |
+
print(f"Accuracy : {(np.sum(bald_res[:, 0]>0.3) + np.sum(notbald_res[:, 0]<0.3))/(len(bald_res)+len(notbald_res))}")
|
42 |
+
"""
|
test_data/1.png
ADDED
![]() |
test_data/2.png
ADDED
![]() |
test_data/3.png
ADDED
![]() |
test_data/4.png
ADDED
![]() |
test_data/5.png
ADDED
![]() |