-
Notifications
You must be signed in to change notification settings - Fork 1
/
genericlayer.py
135 lines (112 loc) · 3.67 KB
/
genericlayer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import inspect, os
import dill as pickle
import numpy as np
class StoreNetwork:
    """Mixin that persists an object to disk with the module bound to ``pickle``.

    NOTE(review): ``load`` and ``load_or_create`` unpickle whatever file they
    are given. Unpickling can execute arbitrary code, so only load files from
    a trusted source.
    """

    def save(self, file):
        """Serialize this object to the path ``file``, overwriting any existing file."""
        # Pickle streams are bytes: binary mode is required on Python 3 and
        # avoids newline translation on Windows. ``with`` guarantees the file
        # is flushed and closed even if dumping raises.
        with open(file, "wb") as f:
            pickle.dump(self, f)

    @staticmethod
    def load(file):
        """Return the object unpickled from the path ``file``.

        Raises:
            Exception: with message 'File does not exist!' if ``file`` is not
                an existing regular file.
        """
        if not os.path.isfile(file):
            raise Exception('File does not exist!')
        with open(file, "rb") as f:
            return pickle.load(f)

    @staticmethod
    def load_or_create(file, net):
        """Return the object stored at ``file`` if it exists, otherwise ``net``.

        ``net`` is returned unchanged (it is not saved to ``file``).
        """
        if os.path.isfile(file):
            return StoreNetwork.load(file)
        return net
class GenericLayer(StoreNetwork):
    """Base layer: identity forward/backward plus a finite-difference Jacobian.

    Subclasses override ``forward`` and ``backward``. They may also define
    ``printelements(level)`` to extend the text produced by ``printlayer``.
    Persistence (``save``/``load``) is inherited from ``StoreNetwork``.
    """

    @staticmethod
    def _forward_difference(evaluate, fx, shape, dx):
        """Return the forward-difference Jacobian of ``evaluate`` w.r.t. one input.

        ``evaluate(step)`` returns the layer output for the input perturbed by
        ``step`` (an array of ``shape``). ``fx`` is the unperturbed output.
        Column ``r`` of the result corresponds to the r-th input component in
        flattened (C) order.
        """
        size = int(np.prod(shape))
        jac = np.zeros([np.size(fx), size])
        for r in range(size):
            # Perturb one component at a time with a step of the input's
            # full shape, so n-D inputs broadcast correctly.
            step = np.zeros(shape)
            step.flat[r] = dx
            # ravel() lets a multi-dimensional output fill a single column.
            jac[:, r] = np.ravel(evaluate(step) - fx) / dx
        return jac

    def numeric_gradient(self, x, dx=1e-8):
        """Estimate the Jacobian of ``forward`` at ``x`` by forward differences.

        Args:
            x: an array, or a list of arrays for layers whose ``forward``
                takes several inputs.
            dx: finite-difference step added to one input component at a time.

        Returns:
            For an array ``x``: an array of shape ``(fx.size, x.size)`` whose
            column ``r`` is ``(forward(x + dx * e_r) - forward(x)) / dx``, with
            ``e_r`` the r-th component of ``x`` in flattened (C) order.
            For a list ``x``: a list with one such matrix per list element.
        """
        fx = self.forward(x)
        if isinstance(x, list):
            jacobians = []
            for ind, element in enumerate(x):
                def evaluate(step):
                    # Copies every input on each evaluation, so a forward()
                    # that mutates its arguments in place cannot alter ``x``.
                    xin = [el.copy() for el in x]
                    xin[ind] = xin[ind] + step
                    return self.forward(xin)
                jacobians.append(
                    self._forward_difference(evaluate, fx, np.shape(element), dx))
            return jacobians
        return self._forward_difference(
            lambda step: self.forward(x + step), fx, np.shape(x), dx)

    def on_message(self, message, *args, **kwargs):
        """Handle a message; the base implementation ignores it."""
        pass

    def forward(self, x, update = False):
        """Identity forward pass: return ``x`` unchanged (``update`` is unused here)."""
        return x

    def backward(self, dJdy, optimizer = None):
        """Identity backward pass: return ``dJdy`` unchanged (``optimizer`` is unused here)."""
        return dJdy

    def printlayer(self, level):
        """Return the class name, followed by ``printelements(level)`` if defined."""
        strlab = self.__class__.__name__
        if hasattr(self, 'printelements'):
            strlab += self.printelements(level)
        return strlab

    def __str__(self):
        return self.printlayer(1)
class WithNet(GenericLayer):
    """Layer that wraps one inner network and delegates all work to it.

    ``forward``/``backward`` pass straight through to ``self.net``.
    ``printelements`` shows the inner network's description inside
    parentheses, indented by ``level`` tabs.
    """

    def __init__(self, net):
        # The wrapped network; it is used through forward, backward and printlayer.
        self.net = net

    def forward(self, x, update=False):
        """Return the wrapped network's ``forward(x, update)``."""
        inner = self.net
        return inner.forward(x, update)

    def backward(self, dJdy, optimizer=None):
        """Return the wrapped network's ``backward(dJdy, optimizer)``."""
        inner = self.net
        return inner.backward(dJdy, optimizer)

    def printelements(self, level):
        """Return the inner network's text wrapped in parentheses.

        Layout: an opening parenthesis and newline; then ``level`` tabs and
        ``self.net.printlayer(level + 1)``; then a newline, ``level - 1``
        tabs and a closing parenthesis.
        """
        pieces = [
            '(\n',
            '\t' * level,
            self.net.printlayer(level + 1),
            '\n',
            '\t' * (level - 1),
            ')',
        ]
        return ''.join(pieces)
class WithElements:
    """Mixin for layers composed of an ordered list of child elements.

    Elements may be given as instances or as classes. A class is
    instantiated with no arguments before it is stored in ``self.elements``.
    """

    def __init__(self, *args):
        """Store the given elements in order.

        Accepts either ``WithElements(a, b, ...)`` or a single list
        ``WithElements([a, b, ...])``.
        """
        self.elements = []
        if len(args) == 1 and isinstance(args[0], list):
            args = args[0]
        for element in args:
            self.add(element)

    @staticmethod
    def _instantiate(element):
        """Return ``element()`` if ``element`` is a class, otherwise ``element`` itself."""
        return element() if inspect.isclass(element) else element

    def insert(self, index, element):
        """Insert ``element`` (instantiating it if it is a class) at ``index``; return self."""
        self.elements.insert(index, self._instantiate(element))
        return self

    def add(self, element):
        """Append ``element`` (instantiating it if it is a class); return self."""
        self.elements.append(self._instantiate(element))
        return self

    def on_message(self, message, *args, **kwargs):
        """Forward ``message`` to every element that has a callable ``on_message``."""
        for element in self.elements:
            handler = getattr(element, "on_message", None)
            if callable(handler):
                handler(message, *args, **kwargs)

    def printelements(self, level):
        """Return each element's ``printlayer(level + 1)`` on its own line, in parentheses.

        Element lines are indented by ``level`` tabs. The closing parenthesis
        is indented by ``level - 1`` tabs.
        """
        indent = '\t' * level
        lines = ['(']
        for element in self.elements:
            lines.append(indent + element.printlayer(level + 1))
        lines.append('\t' * (level - 1) + ')')
        return '\n'.join(lines)