Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Swpaxes #45

Merged
merged 6 commits into from
Sep 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified ClassificationModels/models/BasicMotions/ResNet
Binary file not shown.
Binary file modified ClassificationModels/models/Epilepsy/ResNet
Binary file not shown.
45 changes: 30 additions & 15 deletions TSInterpret/InterpretabilityModels/Saliency/SaliencyMethods_PTY.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,12 @@ def _getTwoStepRescaling(
newGrad = np.zeros((input_size, sequence_length))
# print("has Sliding Window", hasSliding_window_shapes)
if self.mode == "time":
input = input.reshape(-1, sequence_length, input_size)
newGrad = np.swapaxes(newGrad, -1, -2)
# print(input.shape)
# print('mode timw')
# print('inüut1',input)
# input = np.swapaxes(input,-1,-2)#.reshape(-1, sequence_length, input_size)
# print('inüut1',input)

if hasBaseline is None:
ActualGrad = (
Expand Down Expand Up @@ -283,9 +288,11 @@ def _getTwoStepRescaling(
)
# if self.mode == "time":
# ActualGrad = ActualGrad.reshape(-1, input_size, sequence_length)
if self.mode == "time":
input = np.swapaxes(
input, -1, -2
) # input.reshape(-1, input_size, sequence_length)
for t in range(sequence_length):
if self.mode == "time":
input = input.reshape(-1, input_size, sequence_length)
newInput = input.clone()
# if newInput.shape[-1] == self.NumTimeSteps:
# print('A')
Expand All @@ -294,8 +301,9 @@ def _getTwoStepRescaling(
# print('B')
# newInput[:, t,:] = assignment
if self.mode == "time":
newInput = newInput.reshape(-1, sequence_length, input_size)

newInput = np.swapaxes(
newInput, -1, -2
) # .reshape(-1, sequence_length, input_size)
if hasBaseline is None:
timeGrad_perTime = (
self.Grad.attribute(newInput, target=TestingLabel)
Expand Down Expand Up @@ -339,9 +347,9 @@ def _getTwoStepRescaling(

timeGrad_perTime = np.absolute(ActualGrad - timeGrad_perTime)
if self.mode == "time":
timeGrad_perTime = timeGrad_perTime.reshape(
-1, input_size, sequence_length
)
timeGrad_perTime = np.swapaxes(timeGrad_perTime, -1, -2) # .reshape(
# -1, input_size, sequence_length
# )
timeGrad[:, t] = np.sum(timeGrad_perTime)

timeContribution = preprocessing.minmax_scale(timeGrad, axis=1)
Expand All @@ -354,7 +362,9 @@ def _getTwoStepRescaling(
newInput = input.clone()
newInput[:, c, t] = assignment
if self.mode == "time":
newInput = newInput.reshape(-1, sequence_length, input_size)
newInput = np.swapaxes(
newInput, -1, -2
) # .reshape(-1, sequence_length, input_size)

if hasBaseline is None:
inputGrad_perInput = (
Expand Down Expand Up @@ -395,21 +405,26 @@ def _getTwoStepRescaling(
)

inputGrad_perInput = np.absolute(ActualGrad - inputGrad_perInput)
inputGrad_perInput = inputGrad_perInput.reshape(
-1, input_size, sequence_length
)
inputGrad_perInput = np.swapaxes(
inputGrad_perInput, -1, -2
) # .reshape(
# -1, input_size, sequence_length
# )
inputGrad[c, :] = np.sum(inputGrad_perInput)
featureContribution = preprocessing.minmax_scale(inputGrad, axis=0)

else:
featureContribution = np.ones((input_size, 1)) * 0.1
# print('FC',featureContribution)
newGrad = newGrad.reshape(input_size, sequence_length)
# newGrad = newGrad#.reshape(input_size, sequence_length)
if self.mode == "time":
# newGrad = newGrad.reshape(sequence_length, input_size)
newGrad = np.swapaxes(newGrad, -1, -2)
for c in range(input_size):
newGrad[c, t] = timeContribution[0, t] * featureContribution[c, 0]
if self.mode == "time":
newGrad = newGrad.reshape(sequence_length, input_size)
# print('NewGrad',newGrad.shape)
# newGrad = newGrad.reshape(sequence_length, input_size)
newGrad = np.swapaxes(newGrad, -1, -2)
return newGrad

def _givenAttGetRescaledSaliency(self, attributions, isTensor=True):
Expand Down
27 changes: 19 additions & 8 deletions TSInterpret/InterpretabilityModels/Saliency/SaliencyMethods_TF.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,10 @@ def _getTwoStepRescaling(
):
sequence_length = self.NumTimeSteps
input_size = self.NumFeatures
# print(sequence_length)
# print(input_size)
# print('inputshape',input.shape)
# print('Saliency Rescaling',input)
assignment = input[0, 0, 0]
timeGrad = np.zeros((1, sequence_length))
inputGrad = np.zeros((input_size, 1))
Expand All @@ -162,11 +166,16 @@ def _getTwoStepRescaling(
ActualGrad = self.Grad.explain(
(input, None), self.model, class_index=TestingLabel
) # .data.cpu().numpy()

# print('Actual GRad', ActualGrad)
for t in range(sequence_length):
newInput = input.copy().reshape(1, input_size, sequence_length)
newInput = np.swapaxes(input.copy(), 2, 1).reshape(
1, input_size, sequence_length
)
# print('NEW INPUT',newInput)
newInput[:, :, t] = assignment
newInput = newInput.reshape(1, sequence_length, input_size)
newInput = np.swapaxes(newInput, 2, 1).reshape(
1, sequence_length, input_size
)
if self.method == "FO":
timeGrad_perTime = self.Grad.explain(
(newInput, None),
Expand All @@ -187,13 +196,16 @@ def _getTwoStepRescaling(

timeContibution = preprocessing.minmax_scale(timeGrad, axis=1)
meanTime = np.quantile(timeContibution, 0.55)

for t in range(sequence_length):
if timeContibution[0, t] > meanTime:
for c in range(input_size):
newInput = input.copy().reshape(1, input_size, sequence_length)
newInput = np.swapaxes(input.copy(), 2, 1).reshape(
1, input_size, sequence_length
)
newInput[:, c, t] = assignment
newInput = newInput.reshape(1, sequence_length, input_size)
newInput = np.swapaxes(newInput, 2, 1).reshape(
1, sequence_length, input_size
)
if self.method == "FO":
inputGrad_perInput = self.Grad.explain(
(newInput, None),
Expand All @@ -204,7 +216,6 @@ def _getTwoStepRescaling(
elif self.method == "DLS" or self.method == "GS":
inputGrad_perInput = self.Grad.shap_values(newInput)
inputGrad_perInput = np.array(inputGrad_perInput)
# print(inputGrad_perInput.shape)
else:
newInput = newInput.reshape(1, sequence_length, input_size, 1)
inputGrad_perInput = self.Grad.explain(
Expand All @@ -220,4 +231,4 @@ def _getTwoStepRescaling(

for c in range(input_size):
newGrad[c, t] = timeContibution[0, t] * featureContibution[c, 0]
return newGrad.reshape(sequence_length, input_size)
return np.swapaxes(newGrad, 0, 1)
7 changes: 5 additions & 2 deletions TSInterpret/InterpretabilityModels/Saliency/Saliency_Base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import seaborn as sns

from TSInterpret.InterpretabilityModels.FeatureAttribution import FeatureAttribution
import numpy as np


class Saliency(FeatureAttribution):
Expand Down Expand Up @@ -58,8 +59,10 @@ def plot(self, item, exp, figsize=(6.4, 4.8), heatmap=False, save=None):
i = 0
if self.mode == "time":
print("time mode")
item = item.reshape(1, item.shape[2], item.shape[1])
exp = exp.reshape(exp.shape[-1], -1)
item = np.swapaxes(
item, -1, -2
) # item.reshape(1, item.shape[2], item.shape[1])
exp = np.swapaxes(exp, -1, -2) # exp.reshape(exp.shape[-1], -1)
else:
print("NOT Time mode")

Expand Down
8 changes: 5 additions & 3 deletions TSInterpret/InterpretabilityModels/counterfactual/CF.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,12 +178,14 @@ def plot_in_one(
save_fig str: Path to Save the figure.
"""
if self.mode == "time":
item = item.reshape(item.shape[-1], item.shape[-2])
exp = exp.reshape(exp.shape[-1], exp.shape[-2])
org = item
item = np.swapaxes(item, -1, -2).reshape(org.shape[-1], org.shape[-2])
exp = np.swapaxes(exp, -1, -2).reshape(org.shape[-1], org.shape[-2])
else:
item = item.reshape(item.shape[-2], item.shape[-1])
exp = exp.reshape(item.shape[-2], item.shape[-1])

# item = np.swapaxes(item, -2, -1) # .reshape(item.shape[-1], item.shape[-2])
# exp = np.swapaxes(exp, -2, -1) # exp.reshape(exp.shape[-1], exp.shape[-2])
# TODO This is new and needs to be testes
ind = ""
# print("Item Shape", item.shape[-2])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ def construct_per_class_trees(self):
return
self.per_class_trees = {}
self.per_class_node_indices = {c: [] for c in np.unique(self.labels)}

input_ = self.timeseries.reshape(-1, self.channels, self.window_size)
input_ = self.timeseries

preds = np.argmax(self.clf(input_), axis=1)
true_positive_node_ids = {c: [] for c in np.unique(self.labels)}
Expand Down Expand Up @@ -246,7 +245,7 @@ def _find_best(self, x_test, distractor, label_idx):
CLASSIFIER = self.clf
X_TEST = x_test
DISTRACTOR = distractor
input_ = x_test.reshape(1, -1, self.window_size)
input_ = x_test
best_case = self.clf(input_)[0][label_idx]
best_column = None
tuples = []
Expand All @@ -271,7 +270,7 @@ def _find_best(self, x_test, distractor, label_idx):
return best_column, best_case

def explain(self, x_test, to_maximize=None, num_features=10):
input_ = x_test.reshape(1, -1, self.window_size)
input_ = x_test
orig_preds = self.clf(input_)
if to_maximize is None:
to_maximize = np.argsort(orig_preds)[0][-2:-1][0]
Expand All @@ -292,10 +291,8 @@ def explain(self, x_test, to_maximize=None, num_features=10):
prev_best = 0
# best_dist = dist
while True:
input_ = modified.reshape(1, -1, self.window_size)
input_ = modified
probas = self.clf(input_)
# print('Current may',np.argmax(probas))
# print(to_maximize)
if np.argmax(probas) == to_maximize:
current_best = np.max(probas)
if current_best > best_explanation_score:
Expand Down Expand Up @@ -376,16 +373,15 @@ def _prune_explanation(
modified = x_test.copy()
for c in short_explanation:
modified[0][c] = dist[0][c]
input_ = modified.reshape(1, -1, self.window_size)
input_ = modified
prev_proba = self.clf(input_)[0][to_maximize]
best_col = None
best_diff = 0
for c in explanation:
tmp = modified.copy()

tmp[0][c] = dist[0][c]

input_ = tmp.reshape(1, self.channels, self.window_size)
input_ = tmp
cur_proba = self.clf(input_)[0][to_maximize]
if cur_proba - prev_proba > best_diff:
best_col = c
Expand All @@ -399,7 +395,7 @@ def _prune_explanation(
def explain(
self, x_test, num_features=None, to_maximize=None
) -> Tuple[np.array, int]:
input_ = x_test.reshape(1, -1, self.window_size)
input_ = x_test
orig_preds = self.clf(input_)

orig_label = np.argmax(orig_preds)
Expand All @@ -419,8 +415,6 @@ def explain(
x_test, num_features=num_features, to_maximize=to_maximize
)
best, other = explanation
# print('Other',np.array(other).shape)
# print('Best',np.array(best).shape)
target = np.argmax(self.clf(best), axis=1)

return best, target
Expand All @@ -429,8 +423,6 @@ def _get_explanation(self, x_test, to_maximize, num_features):
distractors = self._get_distractors(
x_test, to_maximize, n_distractors=self.num_distractors
)
# print('distracotr shape',np.array(distractors).shape)
# print('distracotr classification',np.argmax(self.clf(np.array(distractors).reshape(2,6,100)), axis=1))

# Avoid constructing KDtrees twice
self.backup.per_class_trees = self.per_class_trees
Expand Down Expand Up @@ -469,7 +461,7 @@ def _get_explanation(self, x_test, to_maximize, num_features):
for c in columns:
if c in explanation:
modified[0][c] = dist[0][c]
input_ = modified.reshape(1, -1, self.window_size)
input_ = modified # .reshape(1, -1, self.window_size)
probas = self.clf(input_)

if not self.silent:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,32 +85,17 @@ def evaluate(self, feature_matrix):

for col_replace, a in zip(self.cols_swap, feature_matrix):
if a == 1:
# print(self.distractor.shape)
new_case[0][col_replace] = self.distractor[0][col_replace]

replaced_feature_count = np.sum(feature_matrix)
# print('replaced_Feature', replaced_feature_count)

# print('NEW CASE', new_case)
# print('self xtest', self.x_test)
# print('NEW CASE', new_case.shape)
# print('self xtest', self.x_test.shape)
# print('DIFF', np.where((self.x_test.reshape(-1)-new_case.reshape(-1)) != 0) )

input_ = new_case.reshape(1, self.channels, self.window_size)
input_ = new_case
result_org = self.clf(input_)
result = result_org[0][self.target]
# print('RESULT',result)
feature_loss = self.reg * np.maximum(
0, replaced_feature_count - self.max_features
)

# print('FEATURE LOSS',feature_loss)
loss_pred = np.square(np.maximum(0, 0.95 - result))
# print('losspred ',loss_pred)
# if np.argmax(result_org[0]) != self.target:
# loss_pred=np.inf

loss_pred = loss_pred + feature_loss

return loss_pred
8 changes: 6 additions & 2 deletions TSInterpret/InterpretabilityModels/counterfactual/COMTECF.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ def __init__(
# Parse test data into (1, feat, time):
change = True
self.ts_length = shape[-2]
test_x = test_x.reshape(test_x.shape[0], test_x.shape[2], test_x.shape[1])
test_x = np.swapaxes(
test_x, 2, 1
) # test_x.reshape(test_x.shape[0], test_x.shape[2], test_x.shape[1])
elif mode == "feat":
change = False
self.ts_length = shape[-1]
Expand Down Expand Up @@ -87,7 +89,7 @@ def explain(
"""
org_shape = x.shape
if self.mode != "feat":
x = x.reshape(-1, x.shape[-1], x.shape[-2])
x = np.swapaxes(x, -1, -2) # x.reshape(-1, x.shape[-1], x.shape[-2])
train_x, train_y = self.referenceset
if len(train_y.shape) > 1:
train_y = np.argmax(train_y, axis=1)
Expand All @@ -106,4 +108,6 @@ def explain(
elif self.method == "brute":
opt = BruteForceSearch(self.predict, train_x, train_y, threads=1)
exp, label = opt.explain(x, to_maximize=target)
if self.mode != "feat":
exp = np.swapaxes(exp, -1, -2)
return exp.reshape(org_shape), label
2 changes: 1 addition & 1 deletion TSInterpret/Models/PyTorchModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def predict(self, item) -> List:
"""
item = np.array(item.tolist()) # , dtype=np.float64)
if self.change:
item = torch.from_numpy(item.reshape(-1, item.shape[-1], item.shape[-2]))
item = torch.from_numpy(np.swapaxes(item, -1, -2))

else:
item = torch.from_numpy(item)
Expand Down
5 changes: 4 additions & 1 deletion TSInterpret/Models/TensorflowModel.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import tensorflow as tf

from TSInterpret.Models.base_model import BaseModel
import numpy as np


class TensorFlowModel(BaseModel):
Expand All @@ -20,7 +21,9 @@ def predict(self, item):
an array of output scores for a classifier.
"""
if self.change:
item = item.reshape(item.shape[0], item.shape[2], item.shape[1])
item = np.swapaxes(
item, 2, 1
) # item.reshape(item.shape[0], item.shape[2], item.shape[1])
out = self.model.predict(item)
return out

Expand Down
2 changes: 1 addition & 1 deletion TSInterpret/__version__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
VERSION = (0, 3, 4)
VERSION = (0, 4, 0)
__version__ = ".".join(map(str, VERSION)) # noqa: F401
Binary file modified docs/Notebooks/Ates.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Loading