Jensen-holm commited on
Commit
9c1370a
·
1 Parent(s): d6cc61a

changing the example to what is actually in the example

Browse files
Files changed (1) hide show
  1. README.md +12 -11
README.md CHANGED
@@ -26,7 +26,7 @@ A small, simple neural network framework built using only [numpy](https://numpy.
26
  from sklearn import datasets
27
  from sklearn.preprocessing import OneHotEncoder
28
  from sklearn.model_selection import train_test_split
29
- from sklearn.metrics import accuracy_score, precision_score, recall_score
30
  import numpy as np
31
  from numpyneuron import (
32
  NN,
@@ -39,7 +39,7 @@ from numpyneuron import (
39
  RANDOM_SEED = 2
40
 
41
 
42
- def _preprocess_digits(
43
  seed: int,
44
  ) -> tuple[np.ndarray, ...]:
45
  digits = datasets.load_digits(as_frame=False)
@@ -55,9 +55,10 @@ def _preprocess_digits(
55
  return X_train, X_test, y_train, y_test
56
 
57
 
58
- def train_nn_classifier() -> None:
59
- X_train, X_test, y_train, y_test = _preprocess_digits(seed=RANDOM_SEED)
60
-
 
61
  nn_classifier = NN(
62
  epochs=2_000,
63
  hidden_size=16,
@@ -75,19 +76,19 @@ def train_nn_classifier() -> None:
75
  X_train=X_train,
76
  y_train=y_train,
77
  )
 
 
78
 
79
- pred = nn_classifier.predict(X_test=X_test)
 
 
80
 
 
81
  pred = np.argmax(pred, axis=1)
82
  y_test = np.argmax(y_test, axis=1)
83
 
84
  accuracy = accuracy_score(y_true=y_test, y_pred=pred)
85
-
86
  print(f"accuracy on validation set: {accuracy:.4f}")
87
-
88
-
89
- if __name__ == "__main__":
90
- train_nn_classifier()
91
  ```
92
 
93
  ## Running Example
 
26
  from sklearn import datasets
27
  from sklearn.preprocessing import OneHotEncoder
28
  from sklearn.model_selection import train_test_split
29
+ from sklearn.metrics import accuracy_score
30
  import numpy as np
31
  from numpyneuron import (
32
  NN,
 
39
  RANDOM_SEED = 2
40
 
41
 
42
+ def preprocess_digits(
43
  seed: int,
44
  ) -> tuple[np.ndarray, ...]:
45
  digits = datasets.load_digits(as_frame=False)
 
55
  return X_train, X_test, y_train, y_test
56
 
57
 
58
+ def train_nn_classifier(
59
+ X_train: np.ndarray,
60
+ y_train: np.ndarray,
61
+ ) -> NN:
62
  nn_classifier = NN(
63
  epochs=2_000,
64
  hidden_size=16,
 
76
  X_train=X_train,
77
  y_train=y_train,
78
  )
79
+ return nn_classifier
80
+
81
 
82
+ if __name__ == "__main__":
83
+ X_train, X_test, y_train, y_test = preprocess_digits(seed=RANDOM_SEED)
84
+ classifier = train_nn_classifier(X_train, y_train)
85
 
86
+ pred = classifier.predict(X_test)
87
  pred = np.argmax(pred, axis=1)
88
  y_test = np.argmax(y_test, axis=1)
89
 
90
  accuracy = accuracy_score(y_true=y_test, y_pred=pred)
 
91
  print(f"accuracy on validation set: {accuracy:.4f}")
 
 
 
 
92
  ```
93
 
94
  ## Running Example