Skip to content

Commit

Permalink
change VAE Keras interface and update VAE TF arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
Mohsen Naghipourfar committed Aug 21, 2019
1 parent 42f7a92 commit 2061dcc
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 16 deletions.
20 changes: 5 additions & 15 deletions scgen/models/_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,15 @@ class VAEArith:
"""

def __init__(self, x_dimension, z_dimension=100, **kwargs):
tf.reset_default_graph()
self.x_dim = x_dimension
self.z_dim = z_dimension
self.learning_rate = kwargs.get("learning_rate", 0.001)
self.dropout_rate = kwargs.get("dropout_rate", 0.2)
self.model_to_use = kwargs.get("model_path", "./models/scgen")
self.alpha = kwargs.get("alpha", 0.00005)
self.is_training = tf.placeholder(tf.bool, name='training_flag')
self.global_step = tf.Variable(0, name='global_step', trainable=False, dtype=tf.int32)
self.x = tf.placeholder(tf.float32, shape=[None, self.x_dim], name="data")
self.z = tf.placeholder(tf.float32, shape=[None, self.z_dim], name="latent")
self.time_step = tf.placeholder(tf.int32)
self.size = tf.placeholder(tf.int32)
self.init_w = tf.contrib.layers.xavier_initializer()
self._create_network()
self._loss_function()
Expand Down Expand Up @@ -119,7 +115,8 @@ def _sample_z(self):
# Returns
The computed Tensor of samples with shape [size, z_dim].
"""
eps = tf.random_normal(shape=[self.size, self.z_dim])
batch_size = tf.shape(self.mu)[0]
eps = tf.random_normal(shape=[batch_size, self.z_dim])
return self.mu + tf.exp(self.log_var / 2) * eps

def _create_network(self):
Expand Down Expand Up @@ -174,7 +171,7 @@ def to_latent(self, data):
latent: numpy nd-array
Returns array containing latent space encoding of 'data'
"""
latent = self.sess.run(self.z_mean, feed_dict={self.x: data, self.size: data.shape[0], self.is_training: False})
latent = self.sess.run(self.z_mean, feed_dict={self.x: data, self.is_training: False})
return latent

def _avg_vector(self, data):
Expand Down Expand Up @@ -429,8 +426,6 @@ def train(self, train_data, use_validation=False, valid_data=None, n_epochs=25,
"""
if initial_run:
log.info("----Training----")
assign_step_zero = tf.assign(self.global_step, 0)
_init_step = self.sess.run(assign_step_zero)
if not initial_run:
self.saver.restore(self.sess, self.model_to_use)
if use_validation and valid_data is None:
Expand All @@ -442,9 +437,6 @@ def train(self, train_data, use_validation=False, valid_data=None, n_epochs=25,
min_delta = threshold
patience_cnt = 0
for it in range(n_epochs):
increment_global_step_op = tf.assign(self.global_step, self.global_step + 1)
_step = self.sess.run(increment_global_step_op)
current_step = self.sess.run(self.global_step)
train_loss = 0.0
for lower in range(0, train_data.shape[0], batch_size):
upper = min(lower + batch_size, train_data.shape[0])
Expand All @@ -454,8 +446,7 @@ def train(self, train_data, use_validation=False, valid_data=None, n_epochs=25,
x_mb = train_data[lower:upper, :].X
if upper - lower > 1:
_, current_loss_train = self.sess.run([self.solver, self.vae_loss],
feed_dict={self.x: x_mb, self.time_step: current_step,
self.size: len(x_mb), self.is_training: True})
feed_dict={self.x: x_mb, self.is_training: True})
train_loss += current_loss_train
if use_validation:
valid_loss = 0
Expand All @@ -466,8 +457,7 @@ def train(self, train_data, use_validation=False, valid_data=None, n_epochs=25,
else:
x_mb = valid_data[lower:upper, :].X
current_loss_valid = self.sess.run(self.vae_loss,
feed_dict={self.x: x_mb, self.time_step: current_step,
self.size: len(x_mb), self.is_training: False})
feed_dict={self.x: x_mb, self.is_training: False})
valid_loss += current_loss_valid
loss_hist.append(valid_loss / valid_data.shape[0])
if it > 0 and loss_hist[it - 1] - loss_hist[it] > min_delta:
Expand Down
1 change: 0 additions & 1 deletion scgen/models/_vae_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ class VAEArithKeras:
"""

def __init__(self, x_dimension, z_dimension=100, **kwargs):
tf.reset_default_graph()
self.x_dim = x_dimension
self.z_dim = z_dimension
self.learning_rate = kwargs.get("learning_rate", 0.001)
Expand Down

0 comments on commit 2061dcc

Please sign in to comment.