-
Notifications
You must be signed in to change notification settings - Fork 1
/
recursivenetwork.py
67 lines (58 loc) · 2.36 KB
/
recursivenetwork.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import genericlayer
class NodeGenerator:
    """Hold a node factory together with the arguments used to call it.

    The stored pieces allow a caller to build any number of nodes later
    with ``self.node(*self.args, **self.kwargs)``.
    """

    def __init__(self, node, *args, **kwargs):
        """Store the factory and its call arguments.

        Args:
            node: Callable (typically a class) that produces a node.
            *args: Positional arguments kept for later calls to ``node``.
            **kwargs: Keyword arguments kept for later calls to ``node``.
        """
        self.node, self.args, self.kwargs = node, args, kwargs
class RNN(genericlayer.GenericLayer, NodeGenerator):
    """Recurrent layer built from a node factory.

    Two kinds of nodes are kept, all created by the same factory call
    ``self.node(*self.args, **self.kwargs)``:

    * ``self.net`` -- a single node used by ``forward(x)`` when
      ``update`` is false; its ``state`` is fed back into it on every call.
    * ``self.nodes`` -- one node per time step of a window, used by
      ``forward(x, update=True)`` and walked back in reverse by
      ``backward``.

    Each node is expected to expose ``forward([x, state])`` returning
    ``[y, new_state]``, ``backward([dJdy, dJdstate], optimizer)`` returning
    ``[dJdx, dJdh]``, and the attributes ``state`` and ``dJdstate``.
    """

    def __init__(self, Node, *args, **kwargs):
        """Store the node factory and create the stand-alone node.

        Args:
            Node: Callable (typically a class) that builds one node.
            *args: Positional arguments forwarded to every ``Node`` call.
            **kwargs: Keyword arguments forwarded to every ``Node`` call.
        """
        # NOTE(review): GenericLayer.__init__ is not invoked here; confirm
        # that the base class needs no initialisation.
        NodeGenerator.__init__(self, Node, *args, **kwargs)
        self.window_size = 0
        self.window_step = 0
        self.nodes = []
        self.net = self.node(*self.args, **self.kwargs)
        # Dispatch table used by on_message.
        self.message_fun = {
            'delete_nodes': self.delete_nodes,
            'init_nodes': self.init_nodes,
            'clear_memory': self.clear_memory
        }

    def on_message(self, message, *args, **kwargs):
        """Run the handler registered for ``message``.

        Args:
            message: One of ``'delete_nodes'``, ``'init_nodes'`` or
                ``'clear_memory'``.
            *args: Positional arguments passed to the handler.
            **kwargs: Keyword arguments passed to the handler.

        Raises:
            KeyError: If ``message`` has no registered handler.
        """
        self.message_fun[message](*args, **kwargs)

    def delete_nodes(self):
        """Drop all window nodes; ``window_size`` is left unchanged."""
        self.nodes = []

    def init_nodes(self, window_size):
        """Create a fresh window of ``window_size`` nodes and rewind it.

        Args:
            window_size: Number of nodes to build for the window.
        """
        self.window_size = window_size
        self.window_step = 0
        self.nodes = [
            self.node(*self.args, **self.kwargs) for _ in range(window_size)
        ]

    def clear_memory(self):
        """Zero the state of ``self.net`` and of every window node.

        For window nodes the ``dJdstate`` buffers are zeroed as well. When
        ``self.net.state`` is a list, each element is zeroed separately and
        the window nodes are assumed to hold lists of the same length.
        """
        # Iterate over the nodes that actually exist instead of
        # range(self.window_size): after delete_nodes() the list is empty
        # while window_size still holds its old value, and indexing by
        # window_size would raise IndexError.
        if isinstance(self.net.state, list):
            for e in range(len(self.net.state)):
                self.net.state[e].fill(0.0)
                for node in self.nodes:
                    node.state[e].fill(0.0)
                    node.dJdstate[e].fill(0.0)
        else:
            self.net.state.fill(0.0)
            for node in self.nodes:
                node.state.fill(0.0)
                node.dJdstate.fill(0.0)

    def forward(self, x, update=False):
        """Advance the layer by one input.

        Args:
            x: Input handed to the node's ``forward``.
            update: When true, use the next node of the window and advance
                ``window_step``; otherwise step the stand-alone ``self.net``.

        Returns:
            The first element of the list returned by the node's ``forward``.

        Raises:
            Exception: ``'Window Exceeded'`` if ``update`` is true and every
                node of the window has already been used.
        """
        if not update:
            y, self.net.state = self.net.forward([x, self.net.state])
            return y

        if self.window_step == self.window_size:
            raise Exception('Window Exceeded')
        step = self.window_step
        node = self.nodes[step]
        # At step 0 the index -1 wraps around to the last node of the
        # window, so the state read is whatever that node currently holds
        # (presumably carrying state over from the previous window -- confirm).
        prev_state = self.nodes[step - 1].state
        y, node.state = node.forward([x, prev_state])
        self.window_step = step + 1
        return y

    def backward(self, dJdy, optimizer=None):
        """Back-propagate through the most recently used window node.

        Each call rewinds ``window_step`` by one, so calls must mirror the
        preceding ``forward(..., update=True)`` calls in reverse order.

        Args:
            dJdy: Gradient with respect to the output of the current step.
            optimizer: Passed through unchanged to the node's ``backward``.

        Returns:
            The first element (``dJdx``) of the list returned by the node's
            ``backward``.

        Raises:
            Exception: ``'Window Exceeded'`` if there is no forward step left
                to undo (``window_step`` is already 0).
        """
        if self.window_step == 0:
            raise Exception('Window Exceeded')
        self.window_step -= 1
        step = self.window_step
        node = self.nodes[step]
        dJdx, dJdh = node.backward([dJdy, node.dJdstate], optimizer)
        # Hand the state gradient to the previous time step, if there is one.
        if step > 0:
            self.nodes[step - 1].dJdstate = dJdh
        return dJdx