-
Notifications
You must be signed in to change notification settings - Fork 1
/
rnn.delta.rnn.go
78 lines (70 loc) · 2.08 KB
/
rnn.delta.rnn.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
package gortex
import (
	"fmt"
	"sort"
)
// DeltaRNN is a Delta-RNN recurrent cell after
// https://arxiv.org/pdf/1703.08864.pdf; Step shows the exact update this
// package implements.
//
// The shapes noted below are the ones MakeDeltaRNN allocates, with X the
// input size, H the hidden size and O the output size.
type DeltaRNN struct {
	// Gate weights: Wr (H×X) multiplies the input x, Ur (H×H) the previous
	// hidden state.
	Wr, Ur *Matrix
	// Candidate projections: Wx (H×X) for the input, Wh (H×H) for the
	// previous hidden state.
	Wx, Wh *Matrix
	// Wo (O×H) maps tanh(h) to the cell output y.
	Wo *Matrix
	// Br (H×1) is the gate bias; Bias (H×1) is added to the candidate z.
	Br, Bias *Matrix
	// A, B and C (H×1) element-wise scale Wx·x, Wh·h_prev and their
	// element-wise product, respectively.
	A, B, C *Matrix
}
// MakeDeltaRNN allocates a DeltaRNN cell for inputs of length x_size, a
// hidden state of length h_size and outputs of length out_size. Every weight
// matrix and every bias/gain vector is created with RandXavierMat.
func MakeDeltaRNN(x_size, h_size, out_size int) *DeltaRNN {
	// Elements are evaluated left to right in exactly this order, so any
	// random draws made by RandXavierMat happen in a stable sequence.
	return &DeltaRNN{
		Wr:   RandXavierMat(h_size, x_size),
		Ur:   RandXavierMat(h_size, h_size),
		Wx:   RandXavierMat(h_size, x_size),
		Wh:   RandXavierMat(h_size, h_size),
		Wo:   RandXavierMat(out_size, h_size),
		Br:   RandXavierMat(h_size, 1),
		Bias: RandXavierMat(h_size, 1),
		A:    RandXavierMat(h_size, 1),
		B:    RandXavierMat(h_size, 1),
		C:    RandXavierMat(h_size, 1),
	}
}
// GetParameters returns the cell's matrices keyed by namespace followed by
// "_" and the field name (for example namespace+"_Wr"). The map values are
// the cell's own *Matrix pointers, not copies.
func (rnn *DeltaRNN) GetParameters(namespace string) map[string]*Matrix {
	table := []struct {
		suffix string
		mat    *Matrix
	}{
		{"_Wr", rnn.Wr},
		{"_Ur", rnn.Ur},
		{"_Wx", rnn.Wx},
		{"_Wh", rnn.Wh},
		{"_Wo", rnn.Wo},
		{"_A", rnn.A},
		{"_B", rnn.B},
		{"_C", rnn.C},
		{"_Br", rnn.Br},
		{"_Bias", rnn.Bias},
	}
	out := make(map[string]*Matrix, len(table))
	for _, entry := range table {
		out[namespace+entry.suffix] = entry.mat
	}
	return out
}
// SetParameters loads the cell's weights from parameters. That map must hold
// an entry for every key reported by GetParameters(namespace).
//
// Each cell matrix takes over the W field of the matching entry. The data is
// shared, not copied. Every entry is validated before anything is assigned,
// so on error the cell is left unchanged. Keys are visited in sorted order,
// so the reported error is deterministic.
//
// An error is returned when a key is missing or maps to nil, when the cell's
// own matrix for a key is nil, or when the element counts (len of W) differ.
func (rnn *DeltaRNN) SetParameters(namespace string, parameters map[string]*Matrix) error {
	own := rnn.GetParameters(namespace)

	keys := make([]string, 0, len(own))
	for k := range own {
		keys = append(keys, k)
	}
	// Map iteration order is random; sort so failures are reproducible.
	sort.Strings(keys)

	// Validate first: assigning while validating could leave the cell
	// half-loaded when a later key turns out to be missing.
	for _, k := range keys {
		src, ok := parameters[k]
		if !ok {
			return fmt.Errorf("Model geometry is not compatible, parameter %s is unknown", k)
		}
		if src == nil {
			return fmt.Errorf("Model geometry is not compatible, parameter %s is nil", k)
		}
		dst := own[k]
		if dst == nil {
			return fmt.Errorf("Model geometry is not compatible, parameter %s is not allocated in the cell", k)
		}
		// NOTE(review): W is assumed to be the flat weight slice, so only the
		// element count can be compared here, not the individual dimensions.
		if len(src.W) != len(dst.W) {
			return fmt.Errorf("Model geometry is not compatible, parameter %s has %d values, want %d", k, len(src.W), len(dst.W))
		}
	}

	for _, k := range keys {
		own[k].W = parameters[k].W
	}
	return nil
}
// Step records one DeltaRNN time step on g and returns the new hidden state h
// and the output y. The operations recorded are (⊙ = element-wise product):
//
//	r = sigmoid(Wr·x + Ur·h_prev + Br)
//	z = A⊙(Wx·x) + B⊙(Wh·h_prev) + C⊙(Wx·x)⊙(Wh·h_prev) + Bias
//	h = r⊙h_prev + (1 − r)⊙z
//	y = Wo·tanh(h)
//
// A simpler gate variant conditions on the input only:
// r = sigmoid(Wr·x + Br).
//
// NOTE(review): the referenced paper passes z through a hidden nonlinearity
// before mixing. Here z is used as-is; confirm this deviation is intentional.
func (rnn *DeltaRNN) Step(g *Graph, x, h_prev *Matrix) (h, y *Matrix) {
	// Shared projections of the input and the previous state.
	inProj := g.Mul(rnn.Wx, x)
	stateProj := g.Mul(rnn.Wh, h_prev)

	// Gate r, driven by both x and h_prev.
	gateFromX := g.Mul(rnn.Wr, x)
	gateFromH := g.Mul(rnn.Ur, h_prev)
	gatePre := g.Add(gateFromX, gateFromH)
	gate := g.Sigmoid(g.Add(gatePre, rnn.Br))

	// Candidate z built from Hadamard-scaled terms.
	termX := g.EMul(rnn.A, inProj)
	termH := g.EMul(rnn.B, stateProj)
	cross := g.EMul(inProj, stateProj)
	termXH := g.EMul(rnn.C, cross)
	sum := g.Add(termX, termH)
	sum = g.Add(sum, termXH)
	candidate := g.Add(sum, rnn.Bias)

	// Interpolate between the previous state and the candidate.
	keep := g.EMul(gate, h_prev)
	ones := gate.OnesAs()
	mix := g.Sub(ones, gate)
	update := g.EMul(mix, candidate)
	h = g.Add(keep, update)

	y = g.Mul(rnn.Wo, g.Tanh(h))
	return h, y
}