Skip to content

Commit

Permalink
ignore pyright errors, fix proposed in #1102
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthijs committed Mar 22, 2024
1 parent 915e77f commit 899c063
Showing 1 changed file with 32 additions and 29 deletions.
61 changes: 32 additions & 29 deletions sbi/analysis/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.

import collections
from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union
from warnings import warn
from typing import Any, Dict, List, Optional, Tuple, Union, TypeVar

import matplotlib as mpl
import numpy as np
Expand All @@ -14,14 +14,16 @@
from matplotlib.figure import Figure, FigureBase
from scipy.stats import binom, gaussian_kde, iqr
from torch import Tensor

from sbi.analysis import eval_conditional_density

try:
collectionsAbc = collections.abc # type: ignore
except AttributeError:
collectionsAbc = collections

T=TypeVar('T')
T = TypeVar('T')


def hex2rgb(hex):
"""Pass 16 to the integer function for change of base"""
Expand All @@ -36,12 +38,13 @@ def rgb2hex(RGB):
])


def to_list(x:Optional[Union[T,List[Optional[T]]]], len:int) -> List[Optional[T]]:
def to_list(x: Optional[Union[T, List[Optional[T]]]], len: int) -> List[Optional[T]]:
"""If x is not a list, make it a list of length `len`."""
if not isinstance(x, list):
return [x for _ in range(len)]
return x


def _update(d, u):
if u is not None:
"""update dictionary with user input, see: https://stackoverflow.com/a/3233356"""
Expand Down Expand Up @@ -524,16 +527,16 @@ def pairplot(
upper: Optional[Union[List[str], str]] = "hist",
lower: Optional[Union[List[str], str]] = None,
diag: Optional[Union[List[str], str]] = "hist",
figsize:Tuple = (10, 10),
figsize: Tuple = (10, 10),
labels: Optional[List[str]] = None,
ticks: Optional[Union[List, torch.Tensor]] = None,
offdiag: Optional[Union[List[str], str]] = None,
diag_kwargs: Optional[Union[List[Dict],Dict]] =None,
upper_kwargs: Optional[Union[List[Dict],Dict]] =None,
lower_kwargs: Optional[Union[List[Dict],Dict]] =None,
fig_kwargs: Optional[Union[List[Dict],Dict]] =None,
fig: Optional[FigureBase]=None,
axes: Optional[Axes]=None,
diag_kwargs: Optional[Union[List[Dict], Dict]] = None,
upper_kwargs: Optional[Union[List[Dict], Dict]] = None,
lower_kwargs: Optional[Union[List[Dict], Dict]] = None,
fig_kwargs: Optional[Union[List[Dict], Dict]] = None,
fig: Optional[FigureBase] = None,
axes: Optional[Axes] = None,
**kwargs: Optional[Any],
):
"""
Expand Down Expand Up @@ -616,7 +619,7 @@ def pairplot(
upper = offdiag

# Prepare diag
diag_list= to_list(diag, len(samples))
diag_list = to_list(diag, len(samples))
diag_kwargs_list = to_list(diag_kwargs, len(samples))
diag_func = get_diag_funcs(diag_list)
diag_kwargs_filled = []
Expand Down Expand Up @@ -679,10 +682,10 @@ def marginal_plot(
figsize: Optional[Tuple] = (10, 2),
labels: Optional[List[str]] = None,
ticks: Optional[Union[List, torch.Tensor]] = None,
diag_kwargs: Optional[Union[List[Dict],Dict]] =None,
fig_kwargs: Optional[Union[List[Dict],Dict]] =None,
fig: Optional[FigureBase]=None,
axes: Optional[Axes]=None,
diag_kwargs: Optional[Union[List[Dict], Dict]] = None,
fig_kwargs: Optional[Union[List[Dict], Dict]] = None,
fig: Optional[FigureBase] = None,
axes: Optional[Axes] = None,
**kwargs: Optional[Any],
):
"""
Expand Down Expand Up @@ -792,7 +795,7 @@ def _get_default_offdiag_kwargs(offdiag, i=0):
elif offdiag == "scatter":
offdiag_kwargs = {
"mpl_kwargs": {
"color": plt.rcParams["axes.prop_cycle"].by_key()["color"][i * 2],
"color": plt.rcParams["axes.prop_cycle"].by_key()["color"][i * 2], # pyright: ignore[reportOptionalMemberAccess]
"edgecolor": "white",
"alpha": 0.5,
"rasterized": False,
Expand All @@ -805,13 +808,13 @@ def _get_default_offdiag_kwargs(offdiag, i=0):
"levels": [0.68, 0.95, 0.99],
"percentile": True,
"mpl_kwargs": {
"colors": plt.rcParams["axes.prop_cycle"].by_key()["color"][i * 2]
"colors": plt.rcParams["axes.prop_cycle"].by_key()["color"][i * 2] # pyright: ignore[reportOptionalMemberAccess]
},
}
elif offdiag == "plot":
offdiag_kwargs = {
"mpl_kwargs": {
"color": plt.rcParams["axes.prop_cycle"].by_key()["color"][i * 2]
"color": plt.rcParams["axes.prop_cycle"].by_key()["color"][i * 2] # pyright: ignore[reportOptionalMemberAccess]
}
}
else:
Expand All @@ -826,23 +829,23 @@ def _get_default_diag_kwargs(diag, i=0):
"bw_method": "scott",
"bins": 50,
"mpl_kwargs": {
"color": plt.rcParams["axes.prop_cycle"].by_key()["color"][i * 2]
"color": plt.rcParams["axes.prop_cycle"].by_key()["color"][i * 2] # pyright: ignore[reportOptionalMemberAccess]
},
}

elif diag == "hist":
diag_kwargs = {
"bin_heuristic": "Freedman-Diaconis",
"mpl_kwargs": {
"color": plt.rcParams["axes.prop_cycle"].by_key()["color"][i * 2],
"color": plt.rcParams["axes.prop_cycle"].by_key()["color"][i * 2], # pyright: ignore[reportOptionalMemberAccess]
"density": False,
"histtype": "step",
},
}
elif diag == "scatter":
diag_kwargs = {
"mpl_kwargs": {
"color": plt.rcParams["axes.prop_cycle"].by_key()["color"][i * 2]
"color": plt.rcParams["axes.prop_cycle"].by_key()["color"][i * 2] # pyright: ignore[reportOptionalMemberAccess]
}
}
else:
Expand All @@ -859,8 +862,8 @@ def _get_default_fig_kwargs():
"points_labels": [f"points_{idx}" for idx in range(10)], # for points
"samples_labels": [f"samples_{idx}" for idx in range(10)], # for samples
# colors: take even colors for samples, odd colors for points
"samples_colors": plt.rcParams["axes.prop_cycle"].by_key()["color"][0::2],
"points_colors": plt.rcParams["axes.prop_cycle"].by_key()["color"][1::2],
"samples_colors": plt.rcParams["axes.prop_cycle"].by_key()["color"][0::2], # pyright: ignore[reportOptionalMemberAccess]
"points_colors": plt.rcParams["axes.prop_cycle"].by_key()["color"][1::2], # pyright: ignore[reportOptionalMemberAccess]
# ticks
"tickformatter": mpl.ticker.FormatStrFormatter("%g"), # type: ignore
"tick_labels": None,
Expand Down Expand Up @@ -1718,8 +1721,8 @@ def pairplot_dep(
labels: Optional[List[str]] = None,
ticks: Optional[Union[List, torch.Tensor]] = None,
upper: Optional[Union[List[str], str]] = None,
fig: Optional[FigureBase]=None,
axes: Optional[Axes]=None,
fig: Optional[FigureBase] = None,
axes: Optional[Axes] = None,
**kwargs: Optional[Any],
):
"""
Expand Down Expand Up @@ -1897,8 +1900,8 @@ def marginal_plot_dep(
figsize: Optional[Tuple] = (10, 10),
labels: Optional[List[str]] = None,
ticks: Optional[Union[List, torch.Tensor]] = None,
fig: Optional[FigureBase]=None,
axes: Optional[Axes]=None,
fig: Optional[FigureBase] = None,
axes: Optional[Axes] = None,
**kwargs: Optional[Any],
):
"""
Expand Down Expand Up @@ -2240,8 +2243,8 @@ def _get_default_opts():
"points_labels": [f"points_{idx}" for idx in range(10)], # for points
"samples_labels": [f"samples_{idx}" for idx in range(10)], # for samples
# colors: take even colors for samples, odd colors for points
"samples_colors": plt.rcParams["axes.prop_cycle"].by_key()["color"][0::2],
"points_colors": plt.rcParams["axes.prop_cycle"].by_key()["color"][1::2],
"samples_colors": plt.rcParams["axes.prop_cycle"].by_key()["color"][0::2], # pyright: ignore[reportOptionalMemberAccess]
"points_colors": plt.rcParams["axes.prop_cycle"].by_key()["color"][1::2], # pyright: ignore[reportOptionalMemberAccess]
# ticks
"ticks": [],
"tickformatter": mpl.ticker.FormatStrFormatter("%g"), # type: ignore
Expand Down

0 comments on commit 899c063

Please sign in to comment.