-
Notifications
You must be signed in to change notification settings - Fork 1
/
utils.py
26 lines (20 loc) · 883 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import tensorflow as tf
import os
def create_dirs_if_not_exists(path):
    """Create directory ``path``, including missing parents, if nothing exists there.

    Args:
        path: Filesystem path of the directory to create.

    Returns:
        None. If ``path`` already exists (as a directory or anything else),
        the call is a no-op.
    """
    if not os.path.exists(path):
        # exist_ok=True closes the window between the existence check above
        # and the creation: if another process or thread creates the
        # directory in between, makedirs no longer raises.
        os.makedirs(path, exist_ok=True)
def printShape(tensor):
    """Print the ``shape`` attribute of ``tensor`` to standard output.

    Args:
        tensor: Any object exposing a ``shape`` attribute.
    """
    shape = tensor.shape
    print(shape)
def squash(s, axis=-1, epsilon=1e-7, name=None):
    """Rescale ``s`` along ``axis`` by ``|s|^2 / (1 + |s|^2)`` times its unit vector.

    Args:
        s: Input tensor.
        axis: Axis over which the squared norm is summed.
        epsilon: Small constant added to the squared norm before the square
            root, keeping the divisor used for normalisation away from zero.
        name: Optional name for the enclosing name scope; defaults to
            ``"squash"``.

    Returns:
        The product of the scaling factor and the normalised input.
    """
    # NOTE(review): `keep_dims` / `default_name` are the TF 1.x spellings;
    # confirm the pinned TensorFlow version before renaming them.
    with tf.name_scope(name, default_name="squash"):
        squares = tf.square(s)
        # The reduced axis is kept so the results below broadcast against `s`.
        sq_norm = tf.reduce_sum(squares, axis=axis, keep_dims=True)
        norm = tf.sqrt(sq_norm + epsilon)
        scale = sq_norm / (1. + sq_norm)
        direction = s / norm
        squashed = scale * direction
        return squashed
def safe_norm(s, axis=-1, epsilon=1e-7, keep_dims=False, name=None):
    """Return ``sqrt(sum(s ** 2, axis) + epsilon)``.

    Args:
        s: Input tensor.
        axis: Axis over which the squares are summed.
        epsilon: Small constant added to the summed squares before the
            square root is taken.
        keep_dims: Forwarded to ``tf.reduce_sum`` as its ``keep_dims``
            argument.
        name: Optional name for the enclosing name scope; defaults to
            ``"safe_norm"``.

    Returns:
        The epsilon-padded norm tensor.
    """
    # NOTE(review): `keep_dims` / `default_name` are the TF 1.x spellings;
    # confirm the pinned TensorFlow version before renaming them.
    with tf.name_scope(name, default_name="safe_norm"):
        squares = tf.square(s)
        sq_norm = tf.reduce_sum(squares, axis=axis, keep_dims=keep_dims)
        padded = sq_norm + epsilon
        return tf.sqrt(padded)