Spaces:
Sleeping
Sleeping
Jensen-holm
commited on
Commit
·
0987346
1
Parent(s):
0a04cd7
stuck on working with computeOutput function, getting dim error every
Browse files- go.mod +3 -3
- go.sum +7 -0
- nn/backprop.go +91 -0
- nn/main.go +32 -36
- nn/split.go +22 -4
- nn/train.go +1 -7
- server.go +0 -1
go.mod
CHANGED
@@ -18,7 +18,7 @@ require (
|
|
18 |
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
19 |
github.com/valyala/fasthttp v1.49.0 // indirect
|
20 |
github.com/valyala/tcplisten v1.0.0 // indirect
|
21 |
-
golang.org/x/net v0.
|
22 |
-
golang.org/x/sys v0.
|
23 |
-
gonum.org/v1/gonum v0.
|
24 |
)
|
|
|
18 |
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
19 |
github.com/valyala/fasthttp v1.49.0 // indirect
|
20 |
github.com/valyala/tcplisten v1.0.0 // indirect
|
21 |
+
golang.org/x/net v0.17.0 // indirect
|
22 |
+
golang.org/x/sys v0.13.0 // indirect
|
23 |
+
gonum.org/v1/gonum v0.14.0 // indirect
|
24 |
)
|
go.sum
CHANGED
@@ -55,6 +55,7 @@ golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL
|
|
55 |
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
56 |
golang.org/x/exp v0.0.0-20191002040644-a1355ae1e2c3 h1:n9HxLrNxWWtEb1cA950nuEEj3QnKbtsCJ6KjcgisNUs=
|
57 |
golang.org/x/exp v0.0.0-20191002040644-a1355ae1e2c3/go.mod h1:NOZ3BPKG0ec/BKJQgnvsSFpcKLM5xXVWnvZS97DWHgE=
|
|
|
58 |
golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs=
|
59 |
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
|
60 |
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
|
@@ -71,6 +72,8 @@ golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLL
|
|
71 |
golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk=
|
72 |
golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ=
|
73 |
golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc=
|
|
|
|
|
74 |
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
75 |
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
76 |
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
@@ -82,6 +85,8 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc
|
|
82 |
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
83 |
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
|
84 |
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
|
|
|
|
85 |
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
86 |
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
87 |
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
@@ -95,6 +100,8 @@ gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod h1:Y+Yx5eoAFn32cQvJ
|
|
95 |
gonum.org/v1/gonum v0.8.2/go.mod h1:oe/vMfY3deqTw+1EZJhuvEW2iwGF1bW9wwu7XCu0+v0=
|
96 |
gonum.org/v1/gonum v0.9.1 h1:HCWmqqNoELL0RAQeKBXWtkp04mGk8koafcB4He6+uhc=
|
97 |
gonum.org/v1/gonum v0.9.1/go.mod h1:TZumC3NeyVQskjXqmyWt4S3bINhy7B4eYwW69EbyX+0=
|
|
|
|
|
98 |
gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0 h1:OE9mWmgKkjJyEmDAAtGMPjXu+YNeGvK9VTSHY6+Qihc=
|
99 |
gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw=
|
100 |
gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc=
|
|
|
55 |
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
56 |
golang.org/x/exp v0.0.0-20191002040644-a1355ae1e2c3 h1:n9HxLrNxWWtEb1cA950nuEEj3QnKbtsCJ6KjcgisNUs=
|
57 |
golang.org/x/exp v0.0.0-20191002040644-a1355ae1e2c3/go.mod h1:NOZ3BPKG0ec/BKJQgnvsSFpcKLM5xXVWnvZS97DWHgE=
|
58 |
+
golang.org/x/exp v0.0.0-20230321023759-10a507213a29 h1:ooxPy7fPvB4kwsA2h+iBNHkAbp/4JxTSwCmvdjEYmug=
|
59 |
golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs=
|
60 |
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
|
61 |
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
|
|
|
72 |
golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk=
|
73 |
golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ=
|
74 |
golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc=
|
75 |
+
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
|
76 |
+
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
|
77 |
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
78 |
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
79 |
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
|
|
85 |
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
86 |
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
|
87 |
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
88 |
+
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
|
89 |
+
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
90 |
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
91 |
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
92 |
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
|
|
100 |
gonum.org/v1/gonum v0.8.2/go.mod h1:oe/vMfY3deqTw+1EZJhuvEW2iwGF1bW9wwu7XCu0+v0=
|
101 |
gonum.org/v1/gonum v0.9.1 h1:HCWmqqNoELL0RAQeKBXWtkp04mGk8koafcB4He6+uhc=
|
102 |
gonum.org/v1/gonum v0.9.1/go.mod h1:TZumC3NeyVQskjXqmyWt4S3bINhy7B4eYwW69EbyX+0=
|
103 |
+
gonum.org/v1/gonum v0.14.0 h1:2NiG67LD1tEH0D7kM+ps2V+fXmsAnpUeec7n8tcr4S0=
|
104 |
+
gonum.org/v1/gonum v0.14.0/go.mod h1:AoWeoz0becf9QMWtE8iWXNXc27fK4fNeHNf/oMejGfU=
|
105 |
gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0 h1:OE9mWmgKkjJyEmDAAtGMPjXu+YNeGvK9VTSHY6+Qihc=
|
106 |
gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw=
|
107 |
gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc=
|
nn/backprop.go
CHANGED
@@ -1,8 +1,99 @@
|
|
1 |
package nn
|
2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
func (nn *NN) Backprop() {
|
|
|
|
|
|
|
|
|
4 |
|
5 |
for i := 0; i < nn.Epochs; i++ {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
}
|
7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
}
|
|
|
1 |
package nn
|
2 |
|
3 |
+
import (
|
4 |
+
"fmt"
|
5 |
+
|
6 |
+
"gonum.org/v1/gonum/mat"
|
7 |
+
)
|
8 |
+
|
9 |
func (nn *NN) Backprop() {
|
10 |
+
var (
|
11 |
+
activation = *nn.ActivationFunc
|
12 |
+
// lossHist []float64
|
13 |
+
)
|
14 |
|
15 |
for i := 0; i < nn.Epochs; i++ {
|
16 |
+
// compute output with current w + b
|
17 |
+
// then compute loss & backprop
|
18 |
+
hiddenOutput, err := computeOutput(
|
19 |
+
nn.XTrain,
|
20 |
+
nn.Wh,
|
21 |
+
nn.Bh,
|
22 |
+
activation,
|
23 |
+
)
|
24 |
+
if err != nil {
|
25 |
+
fmt.Printf("error computing hidden output: %v", err)
|
26 |
+
}
|
27 |
+
|
28 |
+
yHat, err := computeOutput(
|
29 |
+
hiddenOutput,
|
30 |
+
nn.Wo,
|
31 |
+
nn.Bo,
|
32 |
+
activation,
|
33 |
+
)
|
34 |
+
if err != nil {
|
35 |
+
fmt.Printf("error computing yHat: %v", err)
|
36 |
+
}
|
37 |
+
|
38 |
+
mse := meanSquaredError(nn.YTrain, yHat)
|
39 |
+
fmt.Println(mse)
|
40 |
+
|
41 |
+
}
|
42 |
+
|
43 |
+
}
|
44 |
+
|
45 |
+
func computeOutput(arr, w, b *mat.Dense, activationFunc func(float64) float64) (*mat.Dense, error) {
|
46 |
+
// Check if any of the input matrices is nil
|
47 |
+
if arr == nil || w == nil || b == nil {
|
48 |
+
return nil, fmt.Errorf("Input matrices cannot be nil")
|
49 |
+
}
|
50 |
+
|
51 |
+
// Check input dimensions
|
52 |
+
arrRows, arrCols := arr.Dims()
|
53 |
+
wRows, wCols := w.Dims()
|
54 |
+
bRows, bCols := b.Dims()
|
55 |
+
|
56 |
+
if arrCols != wRows || bCols != wCols {
|
57 |
+
return nil, fmt.Errorf("Matrix dimension mismatch: arr[%d, %d], w[%d, %d], b[%d, %d]", arrRows, arrCols, wRows, wCols, bRows, bCols)
|
58 |
}
|
59 |
|
60 |
+
// Compute the dot product between the input matrix 'arr' and the weight matrix 'w'
|
61 |
+
var product mat.Dense
|
62 |
+
product.Mul(arr, w)
|
63 |
+
|
64 |
+
// Check dimensions of product and bias
|
65 |
+
productRows, productCols := product.Dims()
|
66 |
+
if productCols != bCols {
|
67 |
+
return nil, fmt.Errorf("Matrix dimension mismatch: product[%d, %d], b[%d, %d]", productRows, productCols, bRows, bCols)
|
68 |
+
}
|
69 |
+
|
70 |
+
// Add the bias matrix 'b' to the product
|
71 |
+
var result mat.Dense
|
72 |
+
result.Add(&product, b)
|
73 |
+
|
74 |
+
// Apply the activation function to the result
|
75 |
+
applyActivation(&result, activationFunc)
|
76 |
+
|
77 |
+
return &result, nil
|
78 |
+
}
|
79 |
+
|
80 |
+
func applyActivation(m *mat.Dense, f func(float64) float64) {
|
81 |
+
r, c := m.Dims()
|
82 |
+
data := m.RawMatrix().Data
|
83 |
+
for i := 0; i < r*c; i++ {
|
84 |
+
data[i] = f(data[i])
|
85 |
+
}
|
86 |
+
}
|
87 |
+
|
88 |
+
func meanSquaredError(y, yHat *mat.Dense) float64 {
|
89 |
+
var sum float64
|
90 |
+
r, c := y.Dims()
|
91 |
+
|
92 |
+
for row := 0; row < r; row++ {
|
93 |
+
for col := 0; col < c; col++ {
|
94 |
+
diff := y.At(row, col) - yHat.At(row, col)
|
95 |
+
sum += (diff * diff)
|
96 |
+
}
|
97 |
+
}
|
98 |
+
return sum / float64((r * c))
|
99 |
}
|
nn/main.go
CHANGED
@@ -7,6 +7,7 @@ import (
|
|
7 |
|
8 |
"github.com/go-gota/gota/dataframe"
|
9 |
"github.com/gofiber/fiber/v2"
|
|
|
10 |
)
|
11 |
|
12 |
type NN struct {
|
@@ -23,14 +24,14 @@ type NN struct {
|
|
23 |
// attributes set after args above are parsed
|
24 |
ActivationFunc *func(float64) float64
|
25 |
Df *dataframe.DataFrame
|
26 |
-
XTrain *
|
27 |
-
YTrain *
|
28 |
-
XTest *
|
29 |
-
YTest *
|
30 |
-
Wh *
|
31 |
-
Bh *
|
32 |
-
Wo *
|
33 |
-
Bo *
|
34 |
}
|
35 |
|
36 |
func NewNN(c *fiber.Ctx) (*NN, error) {
|
@@ -53,36 +54,31 @@ func (nn *NN) InitWnB() {
|
|
53 |
hiddenSize := nn.HiddenSize
|
54 |
outputSize := 1 // only predicting one thing
|
55 |
|
56 |
-
// input hidden layer weights
|
57 |
-
wh :=
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
wh[i][j] = rand.Float64() - 0.5
|
62 |
-
}
|
63 |
-
}
|
64 |
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
|
|
69 |
|
70 |
-
//
|
71 |
-
wo :=
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
wo[i][j] = rand.Float64() - 0.5
|
76 |
-
}
|
77 |
-
}
|
78 |
|
79 |
-
bo :=
|
80 |
-
|
81 |
-
|
82 |
-
}
|
83 |
|
84 |
-
nn.Wh =
|
85 |
-
nn.Bh =
|
86 |
-
nn.Wo =
|
87 |
-
nn.Bo =
|
88 |
}
|
|
|
7 |
|
8 |
"github.com/go-gota/gota/dataframe"
|
9 |
"github.com/gofiber/fiber/v2"
|
10 |
+
"gonum.org/v1/gonum/mat"
|
11 |
)
|
12 |
|
13 |
type NN struct {
|
|
|
24 |
// attributes set after args above are parsed
|
25 |
ActivationFunc *func(float64) float64
|
26 |
Df *dataframe.DataFrame
|
27 |
+
XTrain *mat.Dense
|
28 |
+
YTrain *mat.Dense
|
29 |
+
XTest *mat.Dense
|
30 |
+
YTest *mat.Dense
|
31 |
+
Wh *mat.Dense
|
32 |
+
Bh *mat.Dense
|
33 |
+
Wo *mat.Dense
|
34 |
+
Bo *mat.Dense
|
35 |
}
|
36 |
|
37 |
func NewNN(c *fiber.Ctx) (*NN, error) {
|
|
|
54 |
hiddenSize := nn.HiddenSize
|
55 |
outputSize := 1 // only predicting one thing
|
56 |
|
57 |
+
// Initialize input hidden layer weights as a Gonum matrix
|
58 |
+
wh := mat.NewDense(inputSize, hiddenSize, nil)
|
59 |
+
wh.Apply(func(i, j int, v float64) float64 {
|
60 |
+
return rand.Float64() - 0.5
|
61 |
+
}, wh)
|
|
|
|
|
|
|
62 |
|
63 |
+
// Initialize hidden layer bias as a Gonum matrix
|
64 |
+
bh := mat.NewDense(1, hiddenSize, nil)
|
65 |
+
bh.Apply(func(i, j int, v float64) float64 {
|
66 |
+
return rand.Float64() - 0.5
|
67 |
+
}, bh)
|
68 |
|
69 |
+
// Initialize weights and biases for hidden -> output layer as Gonum matrices
|
70 |
+
wo := mat.NewDense(hiddenSize, outputSize, nil)
|
71 |
+
wo.Apply(func(i, j int, v float64) float64 {
|
72 |
+
return rand.Float64() - 0.5
|
73 |
+
}, wo)
|
|
|
|
|
|
|
74 |
|
75 |
+
bo := mat.NewDense(1, outputSize, nil)
|
76 |
+
bo.Apply(func(i, j int, v float64) float64 {
|
77 |
+
return rand.Float64() - 0.5
|
78 |
+
}, bo)
|
79 |
|
80 |
+
nn.Wh = wh
|
81 |
+
nn.Bh = bh
|
82 |
+
nn.Wo = wo
|
83 |
+
nn.Bo = bo
|
84 |
}
|
nn/split.go
CHANGED
@@ -3,6 +3,9 @@ package nn
|
|
3 |
import (
|
4 |
"math"
|
5 |
"math/rand"
|
|
|
|
|
|
|
6 |
)
|
7 |
|
8 |
func (nn *NN) TrainTestSplit() {
|
@@ -34,9 +37,24 @@ func (nn *NN) TrainTestSplit() {
|
|
34 |
XTest := test.Select(nn.Features)
|
35 |
YTest := test.Select(nn.Target)
|
36 |
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
}
|
|
|
3 |
import (
|
4 |
"math"
|
5 |
"math/rand"
|
6 |
+
|
7 |
+
"github.com/go-gota/gota/dataframe"
|
8 |
+
"gonum.org/v1/gonum/mat"
|
9 |
)
|
10 |
|
11 |
func (nn *NN) TrainTestSplit() {
|
|
|
37 |
XTest := test.Select(nn.Features)
|
38 |
YTest := test.Select(nn.Target)
|
39 |
|
40 |
+
// to make linear algebra easier & faster,
|
41 |
+
// we convert these dataframes that we are
|
42 |
+
// performing potentially expensive computations
|
43 |
+
// on into gonum matrices since we no longer need the
|
44 |
+
// column names.
|
45 |
+
nn.XTrain = df2mat(&XTrain)
|
46 |
+
nn.YTrain = df2mat(&YTrain)
|
47 |
+
nn.XTest = df2mat(&XTest)
|
48 |
+
nn.YTest = df2mat(&YTest)
|
49 |
+
}
|
50 |
|
51 |
+
// df2mat -> converts gota dataframe into gonum matrix
|
52 |
+
func df2mat(df *dataframe.DataFrame) *mat.Dense {
|
53 |
+
m := mat.NewDense(df.Nrow(), df.Ncol(), nil)
|
54 |
+
for i := 0; i < df.Nrow(); i++ {
|
55 |
+
for j := 0; j < df.Ncol(); j++ {
|
56 |
+
m.Set(i, j, df.Elem(i, j).Float())
|
57 |
+
}
|
58 |
+
}
|
59 |
+
return m
|
60 |
}
|
nn/train.go
CHANGED
@@ -3,11 +3,5 @@ package nn
|
|
3 |
func (nn *NN) Train() {
|
4 |
nn.InitWnB()
|
5 |
nn.TrainTestSplit()
|
6 |
-
|
7 |
-
// iterate n times where n = nn.Epochs
|
8 |
-
// use backprop algorithm on each iteration
|
9 |
-
// to fit the model to the data
|
10 |
-
for i := 0; i < nn.Epochs; i++ {
|
11 |
-
}
|
12 |
-
|
13 |
}
|
|
|
3 |
func (nn *NN) Train() {
|
4 |
nn.InitWnB()
|
5 |
nn.TrainTestSplit()
|
6 |
+
nn.Backprop()
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
}
|
server.go
CHANGED
@@ -19,7 +19,6 @@ func main() {
|
|
19 |
}
|
20 |
|
21 |
nn.Train()
|
22 |
-
|
23 |
return c.SendString("No error")
|
24 |
})
|
25 |
|
|
|
19 |
}
|
20 |
|
21 |
nn.Train()
|
|
|
22 |
return c.SendString("No error")
|
23 |
})
|
24 |
|