Skip to content

Commit

Permalink
versioning
Browse files Browse the repository at this point in the history
  • Loading branch information
JHoelli committed Aug 25, 2023
1 parent 2985d2c commit a248f6b
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 11 deletions.
2 changes: 1 addition & 1 deletion TSInterpret/InterpretabilityModels/FeatureAttribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def plot(
self,
item,
exp,
figsize= (6.4 ,4.8),
figsize=(6.4, 4.8),
heatmap=False,
normelize_saliency=True,
vmin=-1,
Expand Down
14 changes: 8 additions & 6 deletions TSInterpret/InterpretabilityModels/Saliency/Saliency_Base.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(
def explain(self):
raise NotImplementedError("Don't use the base CF class directly")

def plot(self, item, exp, figsize=(6.4 ,4.8), heatmap=False, save=None):
def plot(self, item, exp, figsize=(6.4, 4.8), heatmap=False, save=None):
"""
Plots explanation on the explained Sample.
Expand All @@ -62,11 +62,9 @@ def plot(self, item, exp, figsize=(6.4 ,4.8), heatmap=False, save=None):
exp = exp.reshape(exp.shape[-1], -1)
else:
print("NOT Time mode")



if heatmap:
fig,ax011 = plt.subplots(1, 1, figsize= figsize)
fig, ax011 = plt.subplots(1, 1, figsize=figsize)
sns.heatmap(
exp,
fmt="g",
Expand All @@ -79,7 +77,9 @@ def plot(self, item, exp, figsize=(6.4 ,4.8), heatmap=False, save=None):
)
elif len(item[0]) == 1:
# if only onedimensional input
fig, axn = plt.subplots(len(item[0]), 1, sharex=True, sharey=True, figsize= figsize)
fig, axn = plt.subplots(
len(item[0]), 1, sharex=True, sharey=True, figsize=figsize
)
# cbar_ax = fig.add_axes([.91, .3, .03, .4])
axn012 = axn.twinx()
sns.heatmap(
Expand All @@ -100,7 +100,9 @@ def plot(self, item, exp, figsize=(6.4 ,4.8), heatmap=False, save=None):
else:
ax011 = []

fig, axn = plt.subplots(len(item[0]), 1, sharex=True, sharey=True, figsize= figsize)
fig, axn = plt.subplots(
len(item[0]), 1, sharex=True, sharey=True, figsize=figsize
)
cbar_ax = fig.add_axes([0.91, 0.3, 0.03, 0.4])

for channel in item[0]:
Expand Down
6 changes: 3 additions & 3 deletions TSInterpret/InterpretabilityModels/counterfactual/CF.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def plot(
vis_change=True,
all_in_one=False,
save_fig=None,
figsize=(6.4 ,4.8)
figsize=(6.4, 4.8),
):
"""
Basic Plot Function for visualizing Coutnerfactuals.
Expand Down Expand Up @@ -165,7 +165,7 @@ def plot(
plt.savefig(save_fig)

def plot_in_one(
self, item, org_label, exp, cf_label, save_fig=None, figsize=(6.4 ,4.8)
self, item, org_label, exp, cf_label, save_fig=None, figsize=(6.4, 4.8)
):
"""
Plot Function for Counterfactuals in uni-and multivariate setting. In the multivariate setting only the changed features are visualized.
Expand Down Expand Up @@ -245,7 +245,7 @@ def plot_in_one(
plt.savefig(save_fig)

def plot_multi(
self, item, org_label, exp, cf_label, figsize=(6.4 ,4.8), save_fig=None
self, item, org_label, exp, cf_label, figsize=(6.4, 4.8), save_fig=None
):
"""Plot Function for Ates et al., used if multiple features are changed in a Multivariate Setting.
Also called via plot_in_one. Preferably, do not use directly.
Expand Down
2 changes: 1 addition & 1 deletion TSInterpret/__version__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
VERSION = (0, 3, 3)
VERSION = (0, 3, 4)
__version__ = ".".join(map(str, VERSION)) # noqa: F401

0 comments on commit a248f6b

Please sign in to comment.