-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
executable file
·30 lines (22 loc) · 933 Bytes
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import os
from keras.models import Sequential, load_model
from keras.layers import LSTM, Dropout, TimeDistributed, Dense, Activation, Embedding
# Directory (relative to the current working directory) under which
# save_weights() writes and load_weights() reads the per-epoch weight files.
MODEL_DIR = './model'
def save_weights(epoch, model):
    """Write the model's weights to ``MODEL_DIR/weights.<epoch>.h5``.

    Args:
        epoch: Value interpolated into the file name via ``str.format``.
        model: Object exposing ``save_weights(path)``.
    """
    # exist_ok=True creates the directory when it is missing and is a no-op
    # when it is already there, which avoids the check-then-create race of
    # calling os.path.exists() before os.makedirs().
    os.makedirs(MODEL_DIR, exist_ok=True)
    weights_path = os.path.join(MODEL_DIR, 'weights.{}.h5'.format(epoch))
    model.save_weights(weights_path)
def load_weights(epoch, model):
    """Load ``MODEL_DIR/weights.<epoch>.h5`` into ``model``.

    Args:
        epoch: Value interpolated into the file name via ``str.format``.
        model: Object exposing ``load_weights(path)``.
    """
    filename = 'weights.{}.h5'.format(epoch)
    checkpoint_path = os.path.join(MODEL_DIR, filename)
    model.load_weights(checkpoint_path)
def build_model(batch_size, seq_len, vocab_size, *, embedding_dim=512,
                lstm_units=256, num_layers=3, dropout_rate=0.2):
    """Build a stacked-LSTM sequence model over an integer vocabulary.

    The network is: an embedding layer, ``num_layers`` repetitions of
    (stateful LSTM returning full sequences, dropout), then a per-time-step
    dense projection to ``vocab_size`` followed by a softmax activation.

    Args:
        batch_size: Fixed batch dimension passed to the embedding layer's
            ``batch_input_shape`` (the LSTM layers are created with
            ``stateful=True``).
        seq_len: Sequence length passed to ``batch_input_shape``.
        vocab_size: Number of distinct input symbols; also the width of the
            output layer.
        embedding_dim: Output width of the embedding layer.
        lstm_units: Number of units in each LSTM layer.
        num_layers: How many LSTM + dropout pairs to stack.
        dropout_rate: Rate given to each ``Dropout`` layer.

    Returns:
        The assembled (uncompiled) ``Sequential`` model.
    """
    model = Sequential()
    model.add(Embedding(vocab_size, embedding_dim,
                        batch_input_shape=(batch_size, seq_len)))

    # Each recurrent layer is followed by its own dropout layer.
    for _ in range(num_layers):
        model.add(LSTM(lstm_units, return_sequences=True, stateful=True))
        model.add(Dropout(dropout_rate))

    # Output head: one dense projection applied at every time step.
    model.add(TimeDistributed(Dense(vocab_size)))
    model.add(Activation('softmax'))
    return model
if __name__ == '__main__':
    # When run as a script, build a network for batches of 16 sequences of
    # length 64 over a 50-symbol vocabulary and print its layer summary.
    model = build_model(batch_size=16, seq_len=64, vocab_size=50)
    model.summary()