Skip to content

Commit

Permalink
Unify plotting behaviour
Browse files Browse the repository at this point in the history
  • Loading branch information
nikhilbhavikatti committed Aug 15, 2024
1 parent d830fa1 commit 8382b6d
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 34 deletions.
10 changes: 9 additions & 1 deletion uadapy/plotting/plots1D.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,10 @@ def plot_1d_distribution(distributions, num_samples, plot_types:list, seed=55, f
This parameter determines the size of the dots used in the 'stripplot' and 'swarmplot'.
If not provided, the size is calculated based on the number of samples and the type of plot.
- showmeans : bool, optional
If True, display means in plot. Only effective on violin plot =.
If True, display means in plot. Only effective on violin plot.
Default is False.
- show_plot : bool, optional
If True, display the plot.
Default is False.
Returns
Expand Down Expand Up @@ -265,4 +268,9 @@ def plot_1d_distribution(distributions, num_samples, plot_types:list, seed=55, f
else:
ax.set_visible(False) # Hide unused subplots

show_plot = kwargs.get('show_plot', False)
if show_plot:
fig.tight_layout()
plt.show()

return fig, axs
93 changes: 77 additions & 16 deletions uadapy/plotting/plots2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,59 @@
from numpy import ma
from matplotlib import ticker

def plot_samples(distributions, num_samples, **kwargs):
def plot_samples(distributions, num_samples, seed=55, **kwargs):
"""
Plot samples from the given distribution. If several distributions should be
plotted together, an array can be passed to this function
:param distributions: Distributions to plot
:param num_samples: Number of samples per distribution
:param kwargs: Optional other arguments to pass:
xlabel for label of x-axis
ylabel for label of y-axis
:return:
plotted together, an array can be passed to this function.
Parameters
----------
distributions : list
List of distributions to plot.
num_samples : int
Number of samples per distribution.
seed : int
Seed for the random number generator for reproducibility. It defaults to 55 if not provided.
**kwargs : additional keyword arguments
Additional optional plotting arguments.
- xlabel : string, optional
label for x-axis.
- ylabel : string, optional
label for y-axis.
- show_plot : bool, optional
If True, display the plot.
Default is False.
Returns
-------
matplotlib.figure.Figure
The figure object containing the plot.
list
List of Axes objects used for plotting.
"""

if isinstance(distributions, distribution):
distributions = [distributions]
for d in distributions:
samples = d.sample(num_samples)
samples = d.sample(num_samples, seed)
plt.scatter(x=samples[:,0], y=samples[:,1])
if 'xlabel' in kwargs:
plt.xlabel(kwargs['xlabel'])
if 'ylabel' in kwargs:
plt.ylabel(kwargs['ylabel'])
plt.show()
if 'title' in kwargs:
plt.title(kwargs['title'])

# Get the current figure and axes
fig = plt.gcf()
axs = plt.gca()

show_plot = kwargs.get('show_plot', False)
if show_plot:
fig.tight_layout()
plt.show()

return fig, axs

def plot_contour(distributions, resolution=128, ranges=None, quantiles:list=None, seed=55, **kwargs):
"""
Expand All @@ -44,11 +76,16 @@ def plot_contour(distributions, resolution=128, ranges=None, quantiles:list=None
Seed for the random number generator for reproducibility. It defaults to 55 if not provided.
**kwargs : additional keyword arguments
Additional optional plotting arguments.
- show_plot : bool, optional
If True, display the plot.
Default is False.
Returns
-------
None
This function does not return a value. It displays a plot using plt.show().
matplotlib.figure.Figure
The figure object containing the plot.
list
List of Axes objects used for plotting.
Raises
------
Expand Down Expand Up @@ -102,7 +139,17 @@ def plot_contour(distributions, resolution=128, ranges=None, quantiles:list=None
isovalues.append(densities[int((1 - quantile/100) * num_samples)])

plt.contour(xv, yv, pdf, levels=isovalues, colors = [color])
plt.show()

# Get the current figure and axes
fig = plt.gcf()
axs = plt.gca()

show_plot = kwargs.get('show_plot', False)
if show_plot:
fig.tight_layout()
plt.show()

return fig, axs

def plot_contour_bands(distributions, num_samples, resolution=128, ranges=None, quantiles:list=None, seed=55, **kwargs):
"""
Expand All @@ -124,11 +171,16 @@ def plot_contour_bands(distributions, num_samples, resolution=128, ranges=None,
Seed for the random number generator for reproducibility. It defaults to 55 if not provided.
**kwargs : additional keyword arguments
Additional optional plotting arguments.
- show_plot : bool, optional
If True, display the plot.
Default is False.
Returns
-------
None
This function does not return a value. It displays a plot using plt.show().
matplotlib.figure.Figure
The figure object containing the plot.
list
List of Axes objects used for plotting.
Raises
------
Expand Down Expand Up @@ -190,7 +242,16 @@ def plot_contour_bands(distributions, num_samples, resolution=128, ranges=None,
# Generate logarithmic levels and create the contour plot with different colormap for each distribution
plt.contourf(xv, yv, pdf, levels=isovalues, locator=ticker.LogLocator(), cmap=colormaps[i % len(colormaps)])

plt.show()
# Get the current figure and axes
fig = plt.gcf()
axs = plt.gca()

show_plot = kwargs.get('show_plot', False)
if show_plot:
fig.tight_layout()
plt.show()

return fig, axs

# HELPER FUNCTIONS
def generate_random_colors(length):
Expand Down
89 changes: 72 additions & 17 deletions uadapy/plotting/plotsND.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,32 @@
from uadapy import distribution
import uadapy.plotting.utils as utils

def plot_samples(distributions, num_samples, **kwargs):
def plot_samples(distributions, num_samples, seed=55, **kwargs):
"""
Plot samples from the multivariate distribution as a SLOM
:param distribution: The multivariate distributions
:param num_samples: Number of samples to draw
:param kwargs: Optional other arguments to pass:
:return:
Plot samples from the multivariate distribution as a SLOM.
Parameters
----------
distributions : list
List of distributions to plot.
num_samples : int
Number of samples per distribution.
seed : int
Seed for the random number generator for reproducibility. It defaults to 55 if not provided.
**kwargs : additional keyword arguments
Additional optional plotting arguments.
- show_plot : bool, optional
If True, display the plot.
Default is False.
Returns
-------
matplotlib.figure.Figure
The figure object containing the plot.
list
List of Axes objects used for plotting.
"""

if isinstance(distributions, distribution):
distributions = [distributions]
# Create matrix
Expand All @@ -26,7 +44,7 @@ def plot_samples(distributions, num_samples, **kwargs):
for k, d in enumerate(distributions):
if d.dim < 2:
raise Exception('Wrong dimension of distribution')
samples = d.sample(num_samples)
samples = d.sample(num_samples, seed)
for i, j in zip(*np.triu_indices_from(axes, k=1)):
for x, y in [(i, j), (j, i)]:
axes[x,y].scatter(samples[:,y], y=samples[:,x], color=contour_colors[k])
Expand All @@ -41,8 +59,17 @@ def plot_samples(distributions, num_samples, **kwargs):
axes[-1,i].xaxis.set_visible(True)
axes[i,0].yaxis.set_visible(True)
axes[0,1].yaxis.set_visible(True)
fig.tight_layout()
plt.show()

# Get the current figure and axes
fig = plt.gcf()
axs = plt.gca()

show_plot = kwargs.get('show_plot', False)
if show_plot:
fig.tight_layout()
plt.show()

return fig, axs

def plot_contour(distributions, num_samples, resolution=128, ranges=None, quantiles:list=None, seed=55, **kwargs):
"""
Expand All @@ -64,11 +91,16 @@ def plot_contour(distributions, num_samples, resolution=128, ranges=None, quanti
Seed for the random number generator for reproducibility. It defaults to 55 if not provided.
**kwargs : additional keyword arguments
Additional optional plotting arguments.
- show_plot : bool, optional
If True, display the plot.
Default is False.
Returns
-------
None
This function does not return a value. It displays a plot using plt.show().
matplotlib.figure.Figure
The figure object containing the plot.
list
List of Axes objects used for plotting.
Raises
------
Expand Down Expand Up @@ -155,8 +187,17 @@ def plot_contour(distributions, num_samples, resolution=128, ranges=None, quanti
axes[-1,i].xaxis.set_visible(True)
axes[i,0].yaxis.set_visible(True)
axes[0,1].yaxis.set_visible(True)
fig.tight_layout()
plt.show()

# Get the current figure and axes
fig = plt.gcf()
axs = plt.gca()

show_plot = kwargs.get('show_plot', False)
if show_plot:
fig.tight_layout()
plt.show()

return fig, axs

def plot_contour_samples(distributions, num_samples, resolution=128, ranges=None, quantiles:list=None, seed=55, **kwargs):
"""
Expand All @@ -179,11 +220,16 @@ def plot_contour_samples(distributions, num_samples, resolution=128, ranges=None
Seed for the random number generator for reproducibility. It defaults to 55 if not provided.
**kwargs : additional keyword arguments
Additional optional plotting arguments.
- show_plot : bool, optional
If True, display the plot.
Default is False.
Returns
-------
None
This function does not return a value. It displays a plot using plt.show().
matplotlib.figure.Figure
The figure object containing the plot.
list
List of Axes objects used for plotting.
Raises
------
Expand Down Expand Up @@ -272,5 +318,14 @@ def plot_contour_samples(distributions, num_samples, resolution=128, ranges=None
axes[-1,i].xaxis.set_visible(True)
axes[i,0].yaxis.set_visible(True)
axes[0,1].yaxis.set_visible(True)
fig.tight_layout()
plt.show()

# Get the current figure and axes
fig = plt.gcf()
axs = plt.gca()

show_plot = kwargs.get('show_plot', False)
if show_plot:
fig.tight_layout()
plt.show()

return fig, axs

0 comments on commit 8382b6d

Please sign in to comment.