#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2017 Sining Sun
# utils.py (forked from aishoot/LSTM_PIT_Speech_Separation)
from __future__ import absolute_import
import sys, os, time
import pprint
import numpy as np
import warnings
import tensorflow as tf
import tensorflow.contrib.slim as slim
# Module-wide pretty-printer; the arguments spelled out are PrettyPrinter's defaults.
pp = pprint.PrettyPrinter(indent=1, width=80)
def show_all_variables():
    """Print a report of all trainable TensorFlow variables, then flush stdout.

    The report comes from ``slim.model_analyzer.analyze_vars`` called with
    ``print_info=True`` on ``tf.trainable_variables()``. Flushing stdout
    afterwards pushes any buffered output out immediately.
    """
    slim.model_analyzer.analyze_vars(tf.trainable_variables(), print_info=True)
    sys.stdout.flush()
def mkdir_p(path):
    """Create ``path`` and any missing parent directories, like ``mkdir -p``.

    Succeeds silently if ``path`` already exists as a directory. The create is
    attempted first and the error inspected afterwards, so a directory created
    concurrently by another process between "check" and "create" is not an error.

    :param path: directory path to create
    :return: None
    :raises OSError: if the directory cannot be created, including when ``path``
        already exists but is not a directory (e.g. a regular file is in the way).
    """
    try:
        os.makedirs(path)
    except OSError:
        # Only "a directory is now there" counts as success; anything else
        # (permissions, a file occupying the path, ...) is propagated.
        # (os.makedirs(..., exist_ok=True) is Python 3 only; this form also runs on 2.7.)
        if not os.path.isdir(path):
            raise
def segment_axis(a, length, overlap=0, axis=None, end='cut', endvalue=0):
    """Generate a new array that chops the given array along the given axis into overlapping frames.

    example:
    >>> segment_axis(np.arange(10), 4, 2)
    array([[0, 1, 2, 3],
           [2, 3, 4, 5],
           [4, 5, 6, 7],
           [6, 7, 8, 9]])

    arguments:
        a         The array to segment
        length    The length of each frame
        overlap   The number of array elements by which the frames should overlap
        axis      The axis to operate on; if None, act on the flattened array.
                  Negative values count from the last axis.
        end       What to do with the last frame, if the array is not evenly
                  divisible into pieces. Options are:

                  'cut'   Simply discard the extra values
                  'wrap'  Copy values from the beginning of the array
                          (cycling through it if it is shorter than the gap)
                  'pad'   Pad with a constant value

        endvalue  The value to use for end='pad'

    returns:
        An array in which the segmented axis is replaced by two axes,
        (number of frames, length).

    raises:
        ValueError  if overlap >= length, overlap < 0 or length <= 0; if the
                    segmented axis is empty; if ``end`` is not one of the options
                    above and the array is not evenly divisible; or if 'cut'
                    leaves no complete frame.
        IndexError  if ``axis`` is out of range for ``a``.

    The array is not copied unless necessary (either because it is
    unevenly strided and being flattened, because end is set to
    'pad' or 'wrap', or because it is not C-contiguous). When no copy is
    made, the result is a view that shares memory with the input.
    """
    if axis is None:
        a = np.ravel(a)  # may copy
        axis = 0

    size = a.shape[axis]  # raises IndexError for an out-of-range axis
    # The shape/stride tuples below are spliced at `axis` and `axis + 1`, which only
    # works for a non-negative index (axis=-1 would give a.shape[0:] as the tail).
    axis %= a.ndim

    if overlap >= length:
        raise ValueError("frames cannot overlap by more than 100%")
    if overlap < 0 or length <= 0:
        raise ValueError("overlap must be nonnegative and length must be positive")
    if size == 0:
        raise ValueError("cannot segment an empty axis")

    step = length - overlap  # distance between the starts of consecutive frames

    if size < length or (size - length) % step:
        # Nearest lengths above / below `size` that tile exactly into frames.
        if size > length:
            roundup = length + (1 + (size - length) // step) * step
            rounddown = length + ((size - length) // step) * step
        else:
            roundup = length
            rounddown = 0
        assert rounddown < size < roundup
        assert roundup == rounddown + step or (roundup == length and rounddown == 0)

        # Move the segmented axis last so `...` slicing addresses it.
        a = a.swapaxes(-1, axis)
        if end == 'cut':
            a = a[..., :rounddown]
        elif end in ('pad', 'wrap'):  # copying will be necessary
            padded_shape = list(a.shape)
            padded_shape[-1] = roundup
            b = np.empty(padded_shape, dtype=a.dtype)
            b[..., :size] = a
            if end == 'pad':
                b[..., size:] = endvalue
            else:
                # Index modulo `size` so an axis shorter than the gap wraps around
                # repeatedly instead of failing to broadcast.
                b[..., size:] = a[..., np.arange(roundup - size) % size]
            a = b
        else:
            raise ValueError("end must be 'cut', 'pad' or 'wrap', got %r" % (end,))
        a = a.swapaxes(-1, axis)

        size = a.shape[axis]
        if size == 0:
            raise ValueError("Not enough data points to segment array in 'cut' mode; try 'pad' or 'wrap'")

    assert size >= length
    assert (size - length) % step == 0
    n = 1 + (size - length) // step
    newshape = a.shape[:axis] + (n, length) + a.shape[axis + 1:]

    def _frames(buf):
        # Strides must be read from the array actually used as the buffer; strides
        # captured before a copy describe the old layout and give wrong frames or a
        # buffer-size ValueError.
        s = buf.strides[axis]
        newstrides = buf.strides[:axis] + (step * s, s) + buf.strides[axis + 1:]
        return np.ndarray.__new__(np.ndarray, strides=newstrides, shape=newshape,
                                  buffer=buf, dtype=buf.dtype)

    if not a.flags.c_contiguous:
        # The `buffer=` construction needs a contiguous buffer; copy to C order.
        a = a.copy()
    try:
        return _frames(a)
    except (TypeError, ValueError):  # a tuple; `TypeError or ValueError` is just TypeError
        warnings.warn("Problem with ndarray creation forces copy.")
        return _frames(a.copy())