Jensen-holm commited on
Commit
0987346
·
1 Parent(s): 0a04cd7

stuck on working with computeOutput function, getting dim error every

Browse files
Files changed (7) hide show
  1. go.mod +3 -3
  2. go.sum +7 -0
  3. nn/backprop.go +91 -0
  4. nn/main.go +32 -36
  5. nn/split.go +22 -4
  6. nn/train.go +1 -7
  7. server.go +0 -1
go.mod CHANGED
@@ -18,7 +18,7 @@ require (
18
  github.com/valyala/bytebufferpool v1.0.0 // indirect
19
  github.com/valyala/fasthttp v1.49.0 // indirect
20
  github.com/valyala/tcplisten v1.0.0 // indirect
21
- golang.org/x/net v0.8.0 // indirect
22
- golang.org/x/sys v0.12.0 // indirect
23
- gonum.org/v1/gonum v0.9.1 // indirect
24
  )
 
18
  github.com/valyala/bytebufferpool v1.0.0 // indirect
19
  github.com/valyala/fasthttp v1.49.0 // indirect
20
  github.com/valyala/tcplisten v1.0.0 // indirect
21
+ golang.org/x/net v0.17.0 // indirect
22
+ golang.org/x/sys v0.13.0 // indirect
23
+ gonum.org/v1/gonum v0.14.0 // indirect
24
  )
go.sum CHANGED
@@ -55,6 +55,7 @@ golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL
55
  golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
56
  golang.org/x/exp v0.0.0-20191002040644-a1355ae1e2c3 h1:n9HxLrNxWWtEb1cA950nuEEj3QnKbtsCJ6KjcgisNUs=
57
  golang.org/x/exp v0.0.0-20191002040644-a1355ae1e2c3/go.mod h1:NOZ3BPKG0ec/BKJQgnvsSFpcKLM5xXVWnvZS97DWHgE=
 
58
  golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs=
59
  golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
60
  golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
@@ -71,6 +72,8 @@ golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLL
71
  golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk=
72
  golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ=
73
  golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc=
 
 
74
  golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
75
  golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
76
  golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -82,6 +85,8 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc
82
  golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
83
  golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
84
  golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 
 
85
  golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
86
  golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
87
  golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
@@ -95,6 +100,8 @@ gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod h1:Y+Yx5eoAFn32cQvJ
95
  gonum.org/v1/gonum v0.8.2/go.mod h1:oe/vMfY3deqTw+1EZJhuvEW2iwGF1bW9wwu7XCu0+v0=
96
  gonum.org/v1/gonum v0.9.1 h1:HCWmqqNoELL0RAQeKBXWtkp04mGk8koafcB4He6+uhc=
97
  gonum.org/v1/gonum v0.9.1/go.mod h1:TZumC3NeyVQskjXqmyWt4S3bINhy7B4eYwW69EbyX+0=
 
 
98
  gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0 h1:OE9mWmgKkjJyEmDAAtGMPjXu+YNeGvK9VTSHY6+Qihc=
99
  gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw=
100
  gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc=
 
55
  golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
56
  golang.org/x/exp v0.0.0-20191002040644-a1355ae1e2c3 h1:n9HxLrNxWWtEb1cA950nuEEj3QnKbtsCJ6KjcgisNUs=
57
  golang.org/x/exp v0.0.0-20191002040644-a1355ae1e2c3/go.mod h1:NOZ3BPKG0ec/BKJQgnvsSFpcKLM5xXVWnvZS97DWHgE=
58
+ golang.org/x/exp v0.0.0-20230321023759-10a507213a29 h1:ooxPy7fPvB4kwsA2h+iBNHkAbp/4JxTSwCmvdjEYmug=
59
  golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs=
60
  golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
61
  golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
 
72
  golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk=
73
  golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ=
74
  golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc=
75
+ golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
76
+ golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
77
  golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
78
  golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
79
  golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 
85
  golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
86
  golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
87
  golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
88
+ golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
89
+ golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
90
  golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
91
  golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
92
  golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 
100
  gonum.org/v1/gonum v0.8.2/go.mod h1:oe/vMfY3deqTw+1EZJhuvEW2iwGF1bW9wwu7XCu0+v0=
101
  gonum.org/v1/gonum v0.9.1 h1:HCWmqqNoELL0RAQeKBXWtkp04mGk8koafcB4He6+uhc=
102
  gonum.org/v1/gonum v0.9.1/go.mod h1:TZumC3NeyVQskjXqmyWt4S3bINhy7B4eYwW69EbyX+0=
103
+ gonum.org/v1/gonum v0.14.0 h1:2NiG67LD1tEH0D7kM+ps2V+fXmsAnpUeec7n8tcr4S0=
104
+ gonum.org/v1/gonum v0.14.0/go.mod h1:AoWeoz0becf9QMWtE8iWXNXc27fK4fNeHNf/oMejGfU=
105
  gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0 h1:OE9mWmgKkjJyEmDAAtGMPjXu+YNeGvK9VTSHY6+Qihc=
106
  gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw=
107
  gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc=
nn/backprop.go CHANGED
@@ -1,8 +1,99 @@
1
  package nn
2
 
 
 
 
 
 
 
3
  func (nn *NN) Backprop() {
 
 
 
 
4
 
5
  for i := 0; i < nn.Epochs; i++ {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  }
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  }
 
1
  package nn
2
 
3
+ import (
4
+ "fmt"
5
+
6
+ "gonum.org/v1/gonum/mat"
7
+ )
8
+
9
  func (nn *NN) Backprop() {
10
+ var (
11
+ activation = *nn.ActivationFunc
12
+ // lossHist []float64
13
+ )
14
 
15
  for i := 0; i < nn.Epochs; i++ {
16
+ // compute output with current w + b
17
+ // then compute loss & backprop
18
+ hiddenOutput, err := computeOutput(
19
+ nn.XTrain,
20
+ nn.Wh,
21
+ nn.Bh,
22
+ activation,
23
+ )
24
+ if err != nil {
25
+ fmt.Printf("error computing hidden output: %v", err)
26
+ }
27
+
28
+ yHat, err := computeOutput(
29
+ hiddenOutput,
30
+ nn.Wo,
31
+ nn.Bo,
32
+ activation,
33
+ )
34
+ if err != nil {
35
+ fmt.Printf("error computing yHat: %v", err)
36
+ }
37
+
38
+ mse := meanSquaredError(nn.YTrain, yHat)
39
+ fmt.Println(mse)
40
+
41
+ }
42
+
43
+ }
44
+
45
+ func computeOutput(arr, w, b *mat.Dense, activationFunc func(float64) float64) (*mat.Dense, error) {
46
+ // Check if any of the input matrices is nil
47
+ if arr == nil || w == nil || b == nil {
48
+ return nil, fmt.Errorf("Input matrices cannot be nil")
49
+ }
50
+
51
+ // Check input dimensions
52
+ arrRows, arrCols := arr.Dims()
53
+ wRows, wCols := w.Dims()
54
+ bRows, bCols := b.Dims()
55
+
56
+ if arrCols != wRows || bCols != wCols {
57
+ return nil, fmt.Errorf("Matrix dimension mismatch: arr[%d, %d], w[%d, %d], b[%d, %d]", arrRows, arrCols, wRows, wCols, bRows, bCols)
58
  }
59
 
60
+ // Compute the dot product between the input matrix 'arr' and the weight matrix 'w'
61
+ var product mat.Dense
62
+ product.Mul(arr, w)
63
+
64
+ // Check dimensions of product and bias
65
+ productRows, productCols := product.Dims()
66
+ if productCols != bCols {
67
+ return nil, fmt.Errorf("Matrix dimension mismatch: product[%d, %d], b[%d, %d]", productRows, productCols, bRows, bCols)
68
+ }
69
+
70
+ // Add the bias matrix 'b' to the product
71
+ var result mat.Dense
72
+ result.Add(&product, b)
73
+
74
+ // Apply the activation function to the result
75
+ applyActivation(&result, activationFunc)
76
+
77
+ return &result, nil
78
+ }
79
+
80
+ func applyActivation(m *mat.Dense, f func(float64) float64) {
81
+ r, c := m.Dims()
82
+ data := m.RawMatrix().Data
83
+ for i := 0; i < r*c; i++ {
84
+ data[i] = f(data[i])
85
+ }
86
+ }
87
+
88
+ func meanSquaredError(y, yHat *mat.Dense) float64 {
89
+ var sum float64
90
+ r, c := y.Dims()
91
+
92
+ for row := 0; row < r; row++ {
93
+ for col := 0; col < c; col++ {
94
+ diff := y.At(row, col) - yHat.At(row, col)
95
+ sum += (diff * diff)
96
+ }
97
+ }
98
+ return sum / float64((r * c))
99
  }
nn/main.go CHANGED
@@ -7,6 +7,7 @@ import (
7
 
8
  "github.com/go-gota/gota/dataframe"
9
  "github.com/gofiber/fiber/v2"
 
10
  )
11
 
12
  type NN struct {
@@ -23,14 +24,14 @@ type NN struct {
23
  // attributes set after args above are parsed
24
  ActivationFunc *func(float64) float64
25
  Df *dataframe.DataFrame
26
- XTrain *dataframe.DataFrame
27
- YTrain *dataframe.DataFrame
28
- XTest *dataframe.DataFrame
29
- YTest *dataframe.DataFrame
30
- Wh *[][]float64
31
- Bh *[]float64
32
- Wo *[][]float64
33
- Bo *[]float64
34
  }
35
 
36
  func NewNN(c *fiber.Ctx) (*NN, error) {
@@ -53,36 +54,31 @@ func (nn *NN) InitWnB() {
53
  hiddenSize := nn.HiddenSize
54
  outputSize := 1 // only predicting one thing
55
 
56
- // input hidden layer weights
57
- wh := make([][]float64, inputSize)
58
- for i := range wh {
59
- wh[i] = make([]float64, hiddenSize)
60
- for j := range wh[i] {
61
- wh[i][j] = rand.Float64() - 0.5
62
- }
63
- }
64
 
65
- bh := make([]float64, hiddenSize)
66
- for i := range bh {
67
- bh[i] = rand.Float64() - 0.5
68
- }
 
69
 
70
- // initialize weights and biases for hidden -> output layer
71
- wo := make([][]float64, hiddenSize)
72
- for i := range wo {
73
- wo[i] = make([]float64, outputSize)
74
- for j := range wo[i] {
75
- wo[i][j] = rand.Float64() - 0.5
76
- }
77
- }
78
 
79
- bo := make([]float64, outputSize)
80
- for i := range bo {
81
- bo[i] = rand.Float64() - 0.5
82
- }
83
 
84
- nn.Wh = &wh
85
- nn.Bh = &bh
86
- nn.Wo = &wo
87
- nn.Bo = &bo
88
  }
 
7
 
8
  "github.com/go-gota/gota/dataframe"
9
  "github.com/gofiber/fiber/v2"
10
+ "gonum.org/v1/gonum/mat"
11
  )
12
 
13
  type NN struct {
 
24
  // attributes set after args above are parsed
25
  ActivationFunc *func(float64) float64
26
  Df *dataframe.DataFrame
27
+ XTrain *mat.Dense
28
+ YTrain *mat.Dense
29
+ XTest *mat.Dense
30
+ YTest *mat.Dense
31
+ Wh *mat.Dense
32
+ Bh *mat.Dense
33
+ Wo *mat.Dense
34
+ Bo *mat.Dense
35
  }
36
 
37
  func NewNN(c *fiber.Ctx) (*NN, error) {
 
54
  hiddenSize := nn.HiddenSize
55
  outputSize := 1 // only predicting one thing
56
 
57
+ // Initialize input hidden layer weights as a Gonum matrix
58
+ wh := mat.NewDense(inputSize, hiddenSize, nil)
59
+ wh.Apply(func(i, j int, v float64) float64 {
60
+ return rand.Float64() - 0.5
61
+ }, wh)
 
 
 
62
 
63
+ // Initialize hidden layer bias as a Gonum matrix
64
+ bh := mat.NewDense(1, hiddenSize, nil)
65
+ bh.Apply(func(i, j int, v float64) float64 {
66
+ return rand.Float64() - 0.5
67
+ }, bh)
68
 
69
+ // Initialize weights and biases for hidden -> output layer as Gonum matrices
70
+ wo := mat.NewDense(hiddenSize, outputSize, nil)
71
+ wo.Apply(func(i, j int, v float64) float64 {
72
+ return rand.Float64() - 0.5
73
+ }, wo)
 
 
 
74
 
75
+ bo := mat.NewDense(1, outputSize, nil)
76
+ bo.Apply(func(i, j int, v float64) float64 {
77
+ return rand.Float64() - 0.5
78
+ }, bo)
79
 
80
+ nn.Wh = wh
81
+ nn.Bh = bh
82
+ nn.Wo = wo
83
+ nn.Bo = bo
84
  }
nn/split.go CHANGED
@@ -3,6 +3,9 @@ package nn
3
  import (
4
  "math"
5
  "math/rand"
 
 
 
6
  )
7
 
8
  func (nn *NN) TrainTestSplit() {
@@ -34,9 +37,24 @@ func (nn *NN) TrainTestSplit() {
34
  XTest := test.Select(nn.Features)
35
  YTest := test.Select(nn.Target)
36
 
37
- nn.XTrain = &XTrain
38
- nn.YTrain = &YTrain
39
- nn.XTest = &XTest
40
- nn.YTest = &YTest
 
 
 
 
 
 
41
 
 
 
 
 
 
 
 
 
 
42
  }
 
3
  import (
4
  "math"
5
  "math/rand"
6
+
7
+ "github.com/go-gota/gota/dataframe"
8
+ "gonum.org/v1/gonum/mat"
9
  )
10
 
11
  func (nn *NN) TrainTestSplit() {
 
37
  XTest := test.Select(nn.Features)
38
  YTest := test.Select(nn.Target)
39
 
40
+ // to make linear algebra easier & faster,
41
+ // we convert these dataframes that we are
42
+ // performing potentially expensive computations
43
+ // on into gonum matrices since we no longer need the
44
+ // column names.
45
+ nn.XTrain = df2mat(&XTrain)
46
+ nn.YTrain = df2mat(&YTrain)
47
+ nn.XTest = df2mat(&XTest)
48
+ nn.YTest = df2mat(&YTest)
49
+ }
50
 
51
+ // df2mat -> converts gota dataframe into gonum matrix
52
+ func df2mat(df *dataframe.DataFrame) *mat.Dense {
53
+ m := mat.NewDense(df.Nrow(), df.Ncol(), nil)
54
+ for i := 0; i < df.Nrow(); i++ {
55
+ for j := 0; j < df.Ncol(); j++ {
56
+ m.Set(i, j, df.Elem(i, j).Float())
57
+ }
58
+ }
59
+ return m
60
  }
nn/train.go CHANGED
@@ -3,11 +3,5 @@ package nn
3
  func (nn *NN) Train() {
4
  nn.InitWnB()
5
  nn.TrainTestSplit()
6
-
7
- // iterate n times where n = nn.Epochs
8
- // use backprop algorithm on each iteration
9
- // to fit the model to the data
10
- for i := 0; i < nn.Epochs; i++ {
11
- }
12
-
13
  }
 
3
  func (nn *NN) Train() {
4
  nn.InitWnB()
5
  nn.TrainTestSplit()
6
+ nn.Backprop()
 
 
 
 
 
 
7
  }
server.go CHANGED
@@ -19,7 +19,6 @@ func main() {
19
  }
20
 
21
  nn.Train()
22
-
23
  return c.SendString("No error")
24
  })
25
 
 
19
  }
20
 
21
  nn.Train()
 
22
  return c.SendString("No error")
23
  })
24