From 26eb6279766e03d46333f2ae8046640ff5ffb49e Mon Sep 17 00:00:00 2001 From: Virgile Andreani Date: Wed, 9 Oct 2024 10:23:43 +0200 Subject: [PATCH 1/7] Tweak typing in docstrings --- .../contributing/implementing_distribution.md | 5 +- pymc/_version.py | 58 +++++++++---------- pymc/backends/__init__.py | 2 +- pymc/backends/arviz.py | 8 +-- pymc/backends/base.py | 4 +- pymc/distributions/custom.py | 24 ++++---- pymc/distributions/distribution.py | 4 +- pymc/gp/hsgp_approx.py | 2 +- pymc/model/fgraph.py | 2 +- pymc/sampling/forward.py | 16 ++--- pymc/sampling/jax.py | 8 +-- 11 files changed, 65 insertions(+), 68 deletions(-) diff --git a/docs/source/contributing/implementing_distribution.md b/docs/source/contributing/implementing_distribution.md index 8d0c1750ad..33e2a72892 100644 --- a/docs/source/contributing/implementing_distribution.md +++ b/docs/source/contributing/implementing_distribution.md @@ -37,7 +37,6 @@ The following snippet illustrates how to create a new `RandomVariable`: from pytensor.tensor.var import TensorVariable from pytensor.tensor.random.op import RandomVariable -from typing import List, Tuple # Create your own `RandomVariable`... class BlahRV(RandomVariable): @@ -53,7 +52,7 @@ class BlahRV(RandomVariable): dtype: str = "floatX" # A pretty text and LaTeX representation for the RV - _print_name: Tuple[str, str] = ("blah", "\\operatorname{blah}") + _print_name: tuple[str, str] = ("blah", "\\operatorname{blah}") # If you want to add a custom signature and default values for the # parameters, do it like this. Otherwise this can be left out. @@ -70,7 +69,7 @@ class BlahRV(RandomVariable): rng: np.random.RandomState, loc: np.ndarray, scale: np.ndarray, - size: Tuple[int, ...], + size: tuple[int, ...], ) -> np.ndarray: return scipy.stats.blah.rvs(loc, scale, random_state=rng, size=size) diff --git a/pymc/_version.py b/pymc/_version.py index 2f7f80bfad..a7b02b0360 100644 --- a/pymc/_version.py +++ b/pymc/_version.py @@ -29,11 +29,11 @@ import re import subprocess import sys -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable import functools -def get_keywords() -> Dict[str, str]: +def get_keywords() -> dict[str, str]: """Get the keywords needed to look up the version information.""" # these strings will be replaced by git during git-archive. # setup.py/versioneer.py will grep for the variable names, so they must @@ -75,8 +75,8 @@ class NotThisMethod(Exception): """Exception raised if a method is not valid for the current scenario.""" -LONG_VERSION_PY: Dict[str, str] = {} -HANDLERS: Dict[str, Dict[str, Callable]] = {} +LONG_VERSION_PY: dict[str, str] = {} +HANDLERS: dict[str, dict[str, Callable]] = {} def register_vcs_handler(vcs: str, method: str) -> Callable: # decorator @@ -91,18 +91,18 @@ def decorate(f: Callable) -> Callable: def run_command( - commands: List[str], - args: List[str], - cwd: Optional[str] = None, + commands: list[str], + args: list[str], + cwd: str | None = None, verbose: bool = False, hide_stderr: bool = False, - env: Optional[Dict[str, str]] = None, -) -> Tuple[Optional[str], Optional[int]]: + env: dict[str, str] | None = None, +) -> tuple[str | None, int | None]: """Call the given command(s).""" assert isinstance(commands, list) process = None - popen_kwargs: Dict[str, Any] = {} + popen_kwargs: dict[str, Any] = {} if sys.platform == "win32": # This hides the console window if pythonw.exe is used startupinfo = subprocess.STARTUPINFO() @@ -142,7 +142,7 @@ def versions_from_parentdir( parentdir_prefix: str, root: str, verbose: bool, -) -> Dict[str, Any]: +) -> dict[str, Any]: """Try to determine the version from the parent directory name. Source tarballs conventionally unpack into a directory that includes both @@ -167,13 +167,13 @@ def versions_from_parentdir( @register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs: str) -> Dict[str, str]: +def git_get_keywords(versionfile_abs: str) -> dict[str, str]: """Extract version information from the given file.""" # the code embedded in _version.py can just fetch the value of these # keywords. When used from setup.py, we don't want to import _version.py, # so we do it with a regexp instead. This function is not used from # _version.py. - keywords: Dict[str, str] = {} + keywords: dict[str, str] = {} try: with open(versionfile_abs, "r") as fobj: for line in fobj: @@ -196,10 +196,10 @@ def git_get_keywords(versionfile_abs: str) -> Dict[str, str]: @register_vcs_handler("git", "keywords") def git_versions_from_keywords( - keywords: Dict[str, str], + keywords: dict[str, str], tag_prefix: str, verbose: bool, -) -> Dict[str, Any]: +) -> dict[str, Any]: """Get version information from git keywords.""" if "refnames" not in keywords: raise NotThisMethod("Short version file found") @@ -268,7 +268,7 @@ def git_pieces_from_vcs( root: str, verbose: bool, runner: Callable = run_command -) -> Dict[str, Any]: +) -> dict[str, Any]: """Get version from 'git describe' in the root of the source tree. This only gets called if the git-archive 'subst' keywords were *not* @@ -308,7 +308,7 @@ def git_pieces_from_vcs( raise NotThisMethod("'git rev-parse' failed") full_out = full_out.strip() - pieces: Dict[str, Any] = {} + pieces: dict[str, Any] = {} pieces["long"] = full_out pieces["short"] = full_out[:7] # maybe improved later pieces["error"] = None @@ -400,14 +400,14 @@ def git_pieces_from_vcs( return pieces -def plus_or_dot(pieces: Dict[str, Any]) -> str: +def plus_or_dot(pieces: dict[str, Any]) -> str: """Return a + if we don't already have one, else return a .""" if "+" in pieces.get("closest-tag", ""): return "." return "+" -def render_pep440(pieces: Dict[str, Any]) -> str: +def render_pep440(pieces: dict[str, Any]) -> str: """Build up version string, with post-release "local version identifier". Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you @@ -432,7 +432,7 @@ def render_pep440(pieces: Dict[str, Any]) -> str: return rendered -def render_pep440_branch(pieces: Dict[str, Any]) -> str: +def render_pep440_branch(pieces: dict[str, Any]) -> str: """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] . The ".dev0" means not master branch. Note that .dev0 sorts backwards @@ -462,7 +462,7 @@ def render_pep440_branch(pieces: Dict[str, Any]) -> str: return rendered -def pep440_split_post(ver: str) -> Tuple[str, Optional[int]]: +def pep440_split_post(ver: str) -> tuple[str, int | None]: """Split pep440 version string at the post-release segment. Returns the release segments before the post-release and the @@ -472,7 +472,7 @@ def pep440_split_post(ver: str) -> Tuple[str, Optional[int]]: return vc[0], int(vc[1] or 0) if len(vc) == 2 else None -def render_pep440_pre(pieces: Dict[str, Any]) -> str: +def render_pep440_pre(pieces: dict[str, Any]) -> str: """TAG[.postN.devDISTANCE] -- No -dirty. Exceptions: @@ -496,7 +496,7 @@ def render_pep440_pre(pieces: Dict[str, Any]) -> str: return rendered -def render_pep440_post(pieces: Dict[str, Any]) -> str: +def render_pep440_post(pieces: dict[str, Any]) -> str: """TAG[.postDISTANCE[.dev0]+gHEX] . The ".dev0" means dirty. Note that .dev0 sorts backwards @@ -523,7 +523,7 @@ def render_pep440_post(pieces: Dict[str, Any]) -> str: return rendered -def render_pep440_post_branch(pieces: Dict[str, Any]) -> str: +def render_pep440_post_branch(pieces: dict[str, Any]) -> str: """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] . The ".dev0" means not master branch. @@ -552,7 +552,7 @@ def render_pep440_post_branch(pieces: Dict[str, Any]) -> str: return rendered -def render_pep440_old(pieces: Dict[str, Any]) -> str: +def render_pep440_old(pieces: dict[str, Any]) -> str: """TAG[.postDISTANCE[.dev0]] . The ".dev0" means dirty. @@ -574,7 +574,7 @@ def render_pep440_old(pieces: Dict[str, Any]) -> str: return rendered -def render_git_describe(pieces: Dict[str, Any]) -> str: +def render_git_describe(pieces: dict[str, Any]) -> str: """TAG[-DISTANCE-gHEX][-dirty]. Like 'git describe --tags --dirty --always'. @@ -594,7 +594,7 @@ def render_git_describe(pieces: Dict[str, Any]) -> str: return rendered -def render_git_describe_long(pieces: Dict[str, Any]) -> str: +def render_git_describe_long(pieces: dict[str, Any]) -> str: """TAG-DISTANCE-gHEX[-dirty]. Like 'git describe --tags --dirty --always -long'. @@ -614,7 +614,7 @@ def render_git_describe_long(pieces: Dict[str, Any]) -> str: return rendered -def render(pieces: Dict[str, Any], style: str) -> Dict[str, Any]: +def render(pieces: dict[str, Any], style: str) -> dict[str, Any]: """Render the given version pieces into the requested style.""" if pieces["error"]: return {"version": "unknown", @@ -650,7 +650,7 @@ def render(pieces: Dict[str, Any], style: str) -> Dict[str, Any]: "date": pieces.get("date")} -def get_versions() -> Dict[str, Any]: +def get_versions() -> dict[str, Any]: """Get version information or return default if unable to do so.""" # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have # __file__, we can work backwards from there to the root. Some diff --git a/pymc/backends/__init__.py b/pymc/backends/__init__.py index 986a34f4ba..aea1c57b8c 100644 --- a/pymc/backends/__init__.py +++ b/pymc/backends/__init__.py @@ -63,7 +63,7 @@ from collections.abc import Mapping, Sequence from copy import copy -from typing import Optional, TypeAlias, Union +from typing import TypeAlias import numpy as np diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py index d1c27b787b..808801ba56 100644 --- a/pymc/backends/arviz.py +++ b/pymc/backends/arviz.py @@ -118,7 +118,7 @@ class _DefaultTrace: Attributes ---------- - trace_dict : Dict[str, np.ndarray] + trace_dict : dict[str, np.ndarray] A dictionary constituting a trace. Should be extracted after a procedure has filled the `_DefaultTrace` using the `insert()` method @@ -548,7 +548,7 @@ def predictions_to_inference_data( Parameters ---------- - predictions: Dict[str, np.ndarray] + predictions: dict[str, np.ndarray] The predictions are the return value of :func:`~pymc.sample_posterior_predictive`, a dictionary of strings (variable names) to numpy ndarrays (draws). Requires the arrays to follow the convention ``chain, draw, *shape``. @@ -559,9 +559,9 @@ def predictions_to_inference_data( variables must be *removed* from this trace. model: Model The pymc model. It can be omitted if within a model context. - coords: Dict[str, array-like[Any]] + coords: dict[str, array-like[Any]] Coordinates for the variables. Map from coordinate names to coordinate values. - dims: Dict[str, array-like[str]] + dims: dict[str, array-like[str]] Map from variable name to ordered set of coordinate names. idata_orig: InferenceData, optional If supplied, then modify this inference data in place, adding ``predictions`` and diff --git a/pymc/backends/base.py b/pymc/backends/base.py index c0239f8dec..544450e1c9 100644 --- a/pymc/backends/base.py +++ b/pymc/backends/base.py @@ -329,11 +329,11 @@ class MultiTrace: ---------- nchains: int Number of chains in the `MultiTrace`. - chains: `List[int]` + chains: list[int] List of chain indices report: str Report on the sampling process. - varnames: `List[str]` + varnames: list[str] List of variable names in the trace(s) """ diff --git a/pymc/distributions/custom.py b/pymc/distributions/custom.py index 3238680bb3..aad14c687e 100644 --- a/pymc/distributions/custom.py +++ b/pymc/distributions/custom.py @@ -481,10 +481,10 @@ class CustomDist: Parameters ---------- name : str - dist_params : Tuple + dist_params : tuple A sequence of the distribution's parameter. These will be converted into Pytensor tensor variables internally. - dist: Optional[Callable] + dist: Callable | None A callable that returns a PyTensor graph built from simpler PyMC distributions which represents the distribution. This can be used by PyMC to take random draws as well as to infer the logp of the distribution in some cases. In that case @@ -494,7 +494,7 @@ class CustomDist: The symbolic tensor distribution parameters are passed as positional arguments in the same order as they are supplied when the ``CustomDist`` is constructed. - random : Optional[Callable] + random : Callable | None A callable that can be used to generate random draws from the distribution It must have the following signature: ``random(*dist_params, rng=None, size=None)``. @@ -506,7 +506,7 @@ class CustomDist: error will be raised when trying to draw random samples from the distribution's prior or posterior predictive. - logp : Optional[Callable] + logp : Callable | None A callable that calculates the log probability of some given ``value`` conditioned on certain distribution parameter values. It must have the following signature: ``logp(value, *dist_params)``, where ``value`` is @@ -519,7 +519,7 @@ class CustomDist: Otherwise, a ``NotImplementedError`` will be raised when trying to compute the distribution's logp. - logcdf : Optional[Callable] + logcdf : Callable | None A callable that calculates the log cumulative log probability of some given ``value`` conditioned on certain distribution parameter values. It must have the following signature: ``logcdf(value, *dist_params)``, where ``value`` is @@ -527,7 +527,7 @@ class CustomDist: are the tensors that hold the values of the distribution parameters. This function must return a PyTensor tensor. If ``None``, a ``NotImplementedError`` will be raised when trying to compute the distribution's logcdf. - support_point : Optional[Callable] + support_point : Callable | None A callable that can be used to compute the finete logp point of the distribution. It must have the following signature: ``support_point(rv, size, *rv_inputs)``. The distribution's variable is passed as the first argument ``rv``. ``size`` @@ -536,15 +536,15 @@ class CustomDist: distribution parameters, in the same order as they were supplied when the CustomDist was created. If ``None``, a default ``support_point`` function will be assigned that will always return 0, or an array of zeros. - ndim_supp : Optional[int] + ndim_supp : int | None The number of dimensions in the support of the distribution. Inferred from signature, if provided. Defaults to assuming a scalar distribution, i.e. ``ndim_supp = 0`` - ndims_params : Optional[Sequence[int]] + ndims_params : Sequence[int] | None The list of number of dimensions in the support of each of the distribution's parameters. Inferred from signature, if provided. Defaults to assuming all parameters are scalars, i.e. ``ndims_params=[0, ...]``. - signature : Optional[str] + signature : str | None A numpy vectorize-like signature that indicates the number and core dimensionality of the input parameters and sample outputs of the CustomDist. When specified, `ndim_supp` and `ndims_params` are not needed. See examples below. @@ -591,8 +591,6 @@ def logp(value: TensorVariable, mu: TensorVariable) -> TensorVariable: .. code-block:: python - from typing import Optional, Tuple - import numpy as np import pymc as pm from pytensor.tensor import TensorVariable @@ -604,8 +602,8 @@ def logp(value: TensorVariable, mu: TensorVariable) -> TensorVariable: def random( mu: np.ndarray | float, - rng: Optional[np.random.Generator] = None, - size: Optional[Tuple[int]] = None, + rng: np.random.Generator | None = None, + size: tuple[int, ...] | None = None, ) -> np.ndarray | float: return rng.normal(loc=mu, scale=1, size=size) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 8e55f649d4..178eeeb094 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -734,9 +734,9 @@ def create_partial_observed_rv( Returns ------- - observed_rv and mask : Tuple of TensorVariable + observed_rv and mask : tuple of TensorVariable The observed component of the RV and respective indexing mask - unobserved_rv and mask : Tuple of TensorVariable + unobserved_rv and mask : tuple of TensorVariable The unobserved component of the RV and respective indexing mask joined_rv : TensorVariable The symbolic join of the observed and unobserved components. diff --git a/pymc/gp/hsgp_approx.py b/pymc/gp/hsgp_approx.py index cf434ebe3f..5699be6878 100644 --- a/pymc/gp/hsgp_approx.py +++ b/pymc/gp/hsgp_approx.py @@ -631,7 +631,7 @@ def prior_linearized(self, X: TensorLike): Returns ------- - (phi_cos, phi_sin): Tuple[array-like] + (phi_cos, phi_sin): tuple[array-like, ...] List of either Numpy or PyTensor 2D array of the cosine and sine fixed basis vectors. There are n rows, one per row of `Xs` and `m` columns, one for each basis vector. psd: array-like diff --git a/pymc/model/fgraph.py b/pymc/model/fgraph.py index 78ad61306e..02f20fdd41 100644 --- a/pymc/model/fgraph.py +++ b/pymc/model/fgraph.py @@ -146,7 +146,7 @@ def fgraph_from_model( FunctionGraph that includes a copy of model variables, wrapped in dummy `ModelVar` Ops. It should be possible to reconstruct a valid PyMC model using `model_from_fgraph`. - memo: Dict + memo: dict A dictionary mapping original model variables to the equivalent nodes in the fgraph. """ if any(v is not None for v in model.rvs_to_initial_values.values()): diff --git a/pymc/sampling/forward.py b/pymc/sampling/forward.py index db706f2101..0212c68c4c 100644 --- a/pymc/sampling/forward.py +++ b/pymc/sampling/forward.py @@ -151,21 +151,21 @@ def compile_forward_sampling_function( Parameters ---------- - outputs : List[pytensor.graph.basic.Variable] + outputs : list[pytensor.graph.basic.Variable] The list of variables that will be returned by the compiled function - vars_in_trace : List[pytensor.graph.basic.Variable] + vars_in_trace : list[pytensor.graph.basic.Variable] The list of variables that are assumed to have values stored in the trace - basic_rvs : Optional[List[pytensor.graph.basic.Variable]] + basic_rvs : list[pytensor.graph.basic.Variable] | None A list of random variables that are defined in the model. This list (which could be the output of ``model.basic_RVs``) should have a reference to the variables that should be considered as random variable instances. This includes variables that have a ``RandomVariable`` owner op, but also unpure random variables like Mixtures, or Censored distributions. - givens_dict : Optional[Dict[pytensor.graph.basic.Variable, Any]] + givens_dict : dict[pytensor.graph.basic.Variable, Any] | None A dictionary that maps tensor variables to the values that should be used to replace them in the compiled function. The types of the key and value should match or an error will be raised during compilation. - constant_data : Optional[Dict[str, numpy.ndarray]] + constant_data : dict[str, numpy.ndarray] | None A dictionary that maps the names of ``Data`` instances to their corresponding values at inference time. If a model was created with ``Data``, these are stored as ``SharedVariable`` with the name of the data variable and a value equal to @@ -176,7 +176,7 @@ def compile_forward_sampling_function( the ``SharedVariable`` is assumed to not be volatile. If a ``SharedVariable`` is not found in either ``constant_data`` or ``constant_coords``, then it is assumed to be volatile. Setting ``constant_data`` to ``None`` is equivalent to passing an empty dictionary. - constant_coords : Optional[Set[str]] + constant_coords : Set[str] | None A set with the names of the mutable coordinates that have not changed their shape after inference. If a model was created with mutable coordinates, these are stored as ``SharedVariable`` with the name of the coordinate and a value equal to the length of said @@ -392,7 +392,7 @@ def sample_prior_predictive( Returns ------- - arviz.InferenceData or Dict + arviz.InferenceData or dict An ArviZ ``InferenceData`` object containing the prior and prior predictive samples (default), or a dictionary with variable names as keys and samples as numpy arrays. """ @@ -530,7 +530,7 @@ def sample_posterior_predictive( Returns ------- - arviz.InferenceData or Dict + arviz.InferenceData or dict An ArviZ ``InferenceData`` object containing the posterior predictive samples (default), or a dictionary with variable names as keys, and samples as numpy arrays. diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py index 43e1baa87f..20ec393c46 100644 --- a/pymc/sampling/jax.py +++ b/pymc/sampling/jax.py @@ -314,7 +314,7 @@ def _sample_blackjax_nuts( values like 0.9 or 0.95 often work better for problematic posteriors. random_seed : int, RandomState or Generator, optional Random seed used by the sampling steps. - initvals: StartDict or Sequence[Optional[StartDict]], optional + initvals: StartDict or Sequence[StartDict | None], optional Initial values for random variables provided as a dictionary (or sequence of dictionaries) mapping the random variable (by name or reference) to desired starting values. @@ -332,7 +332,7 @@ def _sample_blackjax_nuts( chain_method : str, default "parallel" Specify how samples should be drawn. The choices include "parallel", and "vectorized". - postprocessing_backend: Optional[Literal["cpu", "gpu"]], default None, + postprocessing_backend: Literal["cpu", "gpu"] | None, default None, Specify how postprocessing should be computed. gpu or cpu postprocessing_vectorize: Literal["vmap", "scan"], default "scan" How to vectorize the postprocessing: vmap or sequential scan @@ -507,7 +507,7 @@ def sample_jax_nuts( values like 0.9 or 0.95 often work better for problematic posteriors. random_seed : int, RandomState or Generator, optional Random seed used by the sampling steps. - initvals: StartDict or Sequence[Optional[StartDict]], optional + initvals: StartDict or Sequence[StartDict | None], optional Initial values for random variables provided as a dictionary (or sequence of dictionaries) mapping the random variable (by name or reference) to desired starting values. @@ -529,7 +529,7 @@ def sample_jax_nuts( chain_method : str, default "parallel" Specify how samples should be drawn. The choices include "parallel", and "vectorized". - postprocessing_backend : Optional[Literal["cpu", "gpu"]], default None, + postprocessing_backend : Literal["cpu", "gpu"] | None, default None, Specify how postprocessing should be computed. gpu or cpu postprocessing_vectorize : Literal["vmap", "scan"], default "scan" How to vectorize the postprocessing: vmap or sequential scan From 8f9ca1c7a0a114b65a62e16944d7f376dcab3ca5 Mon Sep 17 00:00:00 2001 From: Virgile Andreani Date: Thu, 10 Oct 2024 07:45:30 +0200 Subject: [PATCH 2/7] Improve typing in initial_point.py --- pymc/initial_point.py | 80 ++++++++++++++++++++++++++----------------- 1 file changed, 49 insertions(+), 31 deletions(-) diff --git a/pymc/initial_point.py b/pymc/initial_point.py index 15f4f887c0..2c72946108 100644 --- a/pymc/initial_point.py +++ b/pymc/initial_point.py @@ -15,6 +15,7 @@ import warnings from collections.abc import Callable, Sequence +from typing import Literal import numpy as np import pytensor @@ -30,6 +31,8 @@ StartDict = dict[Variable | str, np.ndarray | Variable | str] PointType = dict[str, np.ndarray] +SeedSequenceSeed = None | int | Sequence[int] | np.ndarray | np.random.SeedSequence +SeededInitialPointFn = Callable[[SeedSequenceSeed], dict[str, np.ndarray]] def convert_str_to_rv_dict( @@ -61,7 +64,7 @@ def make_initial_point_fns_per_chain( overrides: StartDict | Sequence[StartDict | None] | None, jitter_rvs: set[TensorVariable] | None = None, chains: int, -) -> list[Callable]: +) -> list[SeededInitialPointFn]: """Create an initial point function for each chain, as defined by initvals. If a single initval dictionary is passed, the function is replicated for each @@ -76,12 +79,23 @@ def make_initial_point_fns_per_chain( Random variable tensors for which U(-1, 1) jitter shall be applied. (To the transformed space if applicable.) + Returns + ------- + initial_point_fns : list[SeededInitialPointFn] + A list, one element per chain. Each element is a function that takes + a seed and returns a dictionary of variable names to initial points + (numpy arrays). + Raises ------ ValueError If the number of entries in initvals is different than the number of chains """ + if isinstance(overrides, Sequence) and len(overrides) != chains: + msg = f"Number of initval dicts ({len(overrides)}) must match the number of chains ({chains})." + raise ValueError(msg) + if isinstance(overrides, dict) or overrides is None: # One strategy for all chains # Only one function compilation is needed. @@ -93,21 +107,18 @@ def make_initial_point_fns_per_chain( return_transformed=True, ) ] * chains - elif len(overrides) == chains: - ipfns = [ - make_initial_point_fn( - model=model, - jitter_rvs=jitter_rvs, - overrides=chain_overrides, - return_transformed=True, - ) - for chain_overrides in overrides - ] - else: - raise ValueError( - f"Number of initval dicts ({len(overrides)}) does not match the number of chains ({chains})." + return ipfns + + assert isinstance(overrides, Sequence) and len(overrides) == chains + ipfns = [ + make_initial_point_fn( + model=model, + jitter_rvs=jitter_rvs, + overrides=chain_overrides, + return_transformed=True, ) - + for chain_overrides in overrides + ] return ipfns @@ -116,22 +127,27 @@ def make_initial_point_fn( model, overrides: StartDict | None = None, jitter_rvs: set[TensorVariable] | None = None, - default_strategy: str = "support_point", + default_strategy: Literal["support_point", "prior"] = "support_point", return_transformed: bool = True, -) -> Callable: +) -> SeededInitialPointFn: """Create seeded function that computes initial values for all free model variables. Parameters ---------- - jitter_rvs : set + overrides : StartDict or None (default: None) + Initial value (strategies) to use instead of what's specified in `Model.initial_values`. + jitter_rvs : set or None (default: None) The set (or list or tuple) of random variables for which a U(-1, +1) jitter should be added to the initial value. Only available for variables that have a transform or real-valued support. - default_strategy : str + default_strategy : either "support_point" or "prior" (default: "support_point") Which of { "support_point", "prior" } to prefer if the initval setting for an RV is None. - overrides : dict - Initial value (strategies) to use instead of what's specified in `Model.initial_values`. - return_transformed : bool + return_transformed : bool (default: True) If `True` the returned variables will correspond to transformed initial values. + + Returns + ------- + initial_point_fn : SeededInitialPointFn + A function that takes a seed and returns a dictionary of variable names to initial points (numpy arrays). """ sdict_overrides = convert_str_to_rv_dict(model, overrides or {}) initval_strats = { @@ -162,11 +178,13 @@ def make_initial_point_fn( name = var.name varnames.append(name) - def make_seeded_function(func): + def make_seeded_function( + func: pytensor.compile.Function, + ) -> SeededInitialPointFn: rngs = find_rng_nodes(func.maker.fgraph.outputs) @functools.wraps(func) - def inner(seed, *args, **kwargs): + def inner(seed: SeedSequenceSeed, *args, **kwargs) -> dict[str, np.ndarray]: reseed_rngs(rngs, seed) values = func(*args, **kwargs) return dict(zip(varnames, values)) @@ -182,26 +200,26 @@ def make_initial_point_expression( rvs_to_transforms: dict[TensorVariable, Transform], initval_strategies: dict[TensorVariable, np.ndarray | Variable | str | None], jitter_rvs: set[TensorVariable] | None = None, - default_strategy: str = "support_point", + default_strategy: Literal["support_point", "prior"] = "support_point", return_transformed: bool = False, ) -> list[TensorVariable]: """Create the tensor variables that need to be evaluated to obtain an initial point. Parameters ---------- - free_rvs : list + free_rvs : list of `TensorVariable`s Tensors of free random variables in the model. - rvs_to_values : dict + rvs_to_transforms : dict[TensorVariable, Transform] Mapping of free random variable tensors to value variable tensors. - initval_strategies : dict + initval_strategies : dict[TensorVariable, np.ndarray | Variable | str | None] Mapping of free random variable tensors to initial value strategies. For example the `Model.initial_values` dictionary. - jitter_rvs : set + jitter_rvs : set[TensorVariable] | None (default: None) The set (or list or tuple) of random variables for which a U(-1, +1) jitter should be added to the initial value. Only available for variables that have a transform or real-valued support. - default_strategy : str + default_strategy : either "support_point" or "prior" (default: "support_point") Which of { "support_point", "prior" } to prefer if the initval strategy setting for an RV is None. - return_transformed : bool + return_transformed : bool (default: False) Switches between returning the tensors for untransformed or transformed initial points. Returns From bfffd66978cfcbde0d8a9a25ca77b683f65f66dd Mon Sep 17 00:00:00 2001 From: Virgile Andreani Date: Thu, 10 Oct 2024 08:10:32 +0200 Subject: [PATCH 3/7] Fix type of tuples with only one element --- pymc/distributions/distribution.py | 6 ++++-- pymc/logprob/mixture.py | 6 +++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 178eeeb094..80c0d85f12 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -112,7 +112,7 @@ def __new__(cls, name, bases, clsdict): # Create dispatch functions size_idx: int | None = None - params_idxs: tuple[int] | None = None + params_idxs: tuple[int, ...] | None = None if issubclass(rv_type, SymbolicRandomVariable): extended_signature = getattr(rv_type, "extended_signature", None) if extended_signature is not None: @@ -308,7 +308,9 @@ def default_output(cls_or_self) -> int | None: @staticmethod def get_input_output_type_idxs( extended_signature: str | None, - ) -> tuple[tuple[tuple[int], int | None, tuple[int]], tuple[tuple[int], tuple[int]]]: + ) -> tuple[ + tuple[tuple[int, ...], int | None, tuple[int, ...]], tuple[tuple[int, ...], tuple[int, ...]] + ]: """Parse extended_signature and return indexes for *[rng], [size] and parameters as well as outputs.""" if extended_signature is None: raise ValueError("extended_signature must be provided") diff --git a/pymc/logprob/mixture.py b/pymc/logprob/mixture.py index 55e506ad99..8595f7aa6f 100644 --- a/pymc/logprob/mixture.py +++ b/pymc/logprob/mixture.py @@ -92,8 +92,8 @@ def is_newaxis(x): def expand_indices( - indices: tuple[Variable | slice | None, ...], shape: tuple[TensorVariable] -) -> tuple[TensorVariable]: + indices: tuple[Variable | slice | None, ...], shape: tuple[TensorVariable, ...] +) -> tuple[TensorVariable, ...]: """Convert basic and/or advanced indices into a single, broadcasted advanced indexing operation. Parameters @@ -206,7 +206,7 @@ def expand_indices( adv_indices.append(expanded_idx) - return cast(tuple[TensorVariable], tuple(pt.broadcast_arrays(*adv_indices))) + return tuple(pt.broadcast_arrays(*adv_indices)) def rv_pull_down(x: TensorVariable) -> TensorVariable: From 5145b0c68c97fb99a2b2dbc77995c5ee7776ece2 Mon Sep 17 00:00:00 2001 From: Virgile Andreani Date: Thu, 10 Oct 2024 08:30:25 +0200 Subject: [PATCH 4/7] Fix mypy errors in run_mypy.py --- scripts/run_mypy.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/scripts/run_mypy.py b/scripts/run_mypy.py index 842fb0a132..1f8d9cff7e 100755 --- a/scripts/run_mypy.py +++ b/scripts/run_mypy.py @@ -18,7 +18,7 @@ import subprocess import sys -from collections.abc import Iterator +from collections.abc import Iterable import pandas @@ -60,13 +60,13 @@ def enforce_pep561(module_name): return -def mypy_to_pandas(input_lines: Iterator[str]) -> pandas.DataFrame: +def mypy_to_pandas(input_lines: Iterable[str]) -> pandas.DataFrame: """Reformats mypy output with error codes to a DataFrame. Adapted from: https://gist.github.com/michaelosthege/24d0703e5f37850c9e5679f69598930a """ - current_section = None - data = { + current_section = "" + data: dict[str, list[str]] = { "file": [], "line": [], "type": [], @@ -97,7 +97,7 @@ def mypy_to_pandas(input_lines: Iterator[str]) -> pandas.DataFrame: return pandas.DataFrame(data=data).set_index(["file", "line"]) -def check_no_unexpected_results(mypy_lines: Iterator[str]): +def check_no_unexpected_results(mypy_lines: Iterable[str]): """Compare mypy results with list of known FAILING files. Exits the process with non-zero exit code upon unexpected results. From b19744d363b745a9f26f2082cd1c4c4e80fa78ff Mon Sep 17 00:00:00 2001 From: Virgile Andreani Date: Thu, 10 Oct 2024 08:36:18 +0200 Subject: [PATCH 5/7] Upgrade mypy version --- conda-envs/environment-dev.yml | 2 +- conda-envs/environment-jax.yml | 2 +- conda-envs/environment-test.yml | 2 +- conda-envs/windows-environment-dev.yml | 2 +- conda-envs/windows-environment-test.yml | 2 +- pymc/backends/__init__.py | 2 +- requirements-dev.txt | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/conda-envs/environment-dev.yml b/conda-envs/environment-dev.yml index 85e6694a95..d0e634c2ee 100644 --- a/conda-envs/environment-dev.yml +++ b/conda-envs/environment-dev.yml @@ -37,7 +37,7 @@ dependencies: - watermark - polyagamma - sphinx-remove-toctrees -- mypy=1.5.1 +- mypy=1.11.2 - types-cachetools - pip: - git+https://github.com/pymc-devs/pymc-sphinx-theme diff --git a/conda-envs/environment-jax.yml b/conda-envs/environment-jax.yml index 97d25dd5b8..639f641eeb 100644 --- a/conda-envs/environment-jax.yml +++ b/conda-envs/environment-jax.yml @@ -33,7 +33,7 @@ dependencies: - pre-commit>=2.8.0 - pytest-cov>=2.5 - pytest>=3.0 -- mypy=1.5.1 +- mypy=1.11.2 - types-cachetools - pip: - numdifftools>=0.9.40 diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml index 58cde0d327..c319bab0d9 100644 --- a/conda-envs/environment-test.yml +++ b/conda-envs/environment-test.yml @@ -28,7 +28,7 @@ dependencies: - pre-commit>=2.8.0 - pytest-cov>=2.5 - pytest>=3.0 -- mypy=1.5.1 +- mypy=1.11.2 - types-cachetools - pip: - numdifftools>=0.9.40 diff --git a/conda-envs/windows-environment-dev.yml b/conda-envs/windows-environment-dev.yml index 6d785e2cac..bdc01369cd 100644 --- a/conda-envs/windows-environment-dev.yml +++ b/conda-envs/windows-environment-dev.yml @@ -34,7 +34,7 @@ dependencies: - sphinx>=1.5 - watermark - sphinx-remove-toctrees -- mypy=1.5.1 +- mypy=1.11.2 - types-cachetools - pip: - git+https://github.com/pymc-devs/pymc-sphinx-theme diff --git a/conda-envs/windows-environment-test.yml b/conda-envs/windows-environment-test.yml index fd17c31711..85d86214bf 100644 --- a/conda-envs/windows-environment-test.yml +++ b/conda-envs/windows-environment-test.yml @@ -28,7 +28,7 @@ dependencies: - pre-commit>=2.8.0 - pytest-cov>=2.5 - pytest>=3.0 -- mypy=1.5.1 +- mypy=1.11.2 - types-cachetools - pip: - numdifftools>=0.9.40 diff --git a/pymc/backends/__init__.py b/pymc/backends/__init__.py index aea1c57b8c..29a76ec27c 100644 --- a/pymc/backends/__init__.py +++ b/pymc/backends/__init__.py @@ -85,7 +85,7 @@ RunType: TypeAlias = Run HAS_MCB = True except ImportError: - TraceOrBackend = BaseTrace # type: ignore[misc] + TraceOrBackend = BaseTrace # type: ignore[assignment, misc] RunType = type(None) # type: ignore[assignment, misc] diff --git a/requirements-dev.txt b/requirements-dev.txt index 082eab73ce..26d8ec8fee 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -9,7 +9,7 @@ h5py>=2.7 ipython>=7.16 jupyter-sphinx mcbackend>=0.4.0 -mypy==1.5.1 +mypy==1.11.2 myst-nb<=1.0.0 numdifftools>=0.9.40 numpy>=1.15.0 From 166c35cbda8b1bdc9c2e99450e157b36eb60fa5b Mon Sep 17 00:00:00 2001 From: Virgile Andreani Date: Thu, 10 Oct 2024 09:25:52 +0200 Subject: [PATCH 6/7] Remove biwrap from util.py --- pymc/util.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/pymc/util.py b/pymc/util.py index 8ec8aa84de..6d4b1bf4de 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import functools import warnings from collections.abc import Sequence @@ -248,24 +247,6 @@ def get_transformed(z): return z -def biwrap(wrapper): - @functools.wraps(wrapper) - def enhanced(*args, **kwargs): - is_bound_method = hasattr(args[0], wrapper.__name__) if args else False - if is_bound_method: - count = 1 - else: - count = 0 - if len(args) > count: - newfn = wrapper(*args, **kwargs) - return newfn - else: - newwrapper = functools.partial(wrapper, *args, **kwargs) - return newwrapper - - return enhanced - - def drop_warning_stat(idata: arviz.InferenceData) -> arviz.InferenceData: """Return a new ``InferenceData`` object with the "warning" stat removed from sample stats groups. From b98c47ee66e1d9a45f93d6ecbf5459bc091bab6d Mon Sep 17 00:00:00 2001 From: Virgile Andreani Date: Thu, 10 Oct 2024 09:40:01 +0200 Subject: [PATCH 7/7] Type some functions in util.py --- pymc/sampling/mcmc.py | 2 +- pymc/util.py | 37 ++++++++++++++++++++++--------------- 2 files changed, 23 insertions(+), 16 deletions(-) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 4b26bb51c8..fec0d7c145 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -246,7 +246,7 @@ def _print_step_hierarchy(s: Step, level: int = 0) -> None: else: varnames = ", ".join( [ - get_untransformed_name(v.name) if is_transformed_name(v.name) else v.name + get_untransformed_name(v.name) if is_transformed_name(v.name) else v.name # type: ignore[arg-type, misc] for v in s.vars ] ) diff --git a/pymc/util.py b/pymc/util.py index 6d4b1bf4de..42835915f6 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -14,9 +14,9 @@ import warnings -from collections.abc import Sequence +from collections.abc import Callable, Iterable, Sequence from copy import deepcopy -from typing import NewType, cast +from typing import Any, NewType, TypeVar, cast import arviz import cloudpickle @@ -32,8 +32,10 @@ from pymc.exceptions import BlockModelAccessError +T = TypeVar("T") -def __getattr__(name): + +def __getattr__(name: str) -> Callable: if name == "dataset_to_point_list": warnings.warn( f"{name} has been moved to backends.arviz. Importing from util will fail in a future release.", @@ -160,8 +162,8 @@ def tree_contains(self, item): return dict.__contains__(self, item) -def get_transformed_name(name, transform): - r""" +def get_transformed_name(name: str, transform) -> str: + """ Consistent way of transforming names. Parameters @@ -179,8 +181,8 @@ def get_transformed_name(name, transform): return f"{name}_{transform.name}__" -def is_transformed_name(name): - r""" +def is_transformed_name(name: str) -> bool: + """ Quickly check if a name was transformed with `get_transformed_name`. Parameters @@ -196,8 +198,8 @@ def is_transformed_name(name): return name.endswith("__") and name.count("_") >= 3 -def get_untransformed_name(name): - r""" +def get_untransformed_name(name: str) -> str: + """ Undo transformation in `get_transformed_name`. Throws ValueError if name wasn't transformed. Parameters @@ -215,8 +217,13 @@ def get_untransformed_name(name): return "_".join(name.split("_")[:-3]) -def get_default_varnames(var_iterator, include_transformed): - r"""Extract default varnames from a trace. +VarOrVarName = TypeVar("VarOrVarName", Variable, str) + + +def get_default_varnames( + var_iterator: Iterable[VarOrVarName], include_transformed: bool +) -> list[VarOrVarName]: + """Extract default varnames from a trace. Parameters ---------- @@ -236,7 +243,7 @@ def get_default_varnames(var_iterator, include_transformed): return [var for var in var_iterator if not is_transformed_name(get_var_name(var))] -def get_var_name(var) -> VarName: +def get_var_name(var: VarOrVarName) -> VarName: """Get an appropriate, plain variable name for a variable.""" return VarName(str(getattr(var, "name", var))) @@ -280,7 +287,7 @@ def chains_and_samples(data: xarray.Dataset | arviz.InferenceData) -> tuple[int, return nchains, nsamples -def hashable(a=None) -> int: +def hashable(a: Any = None) -> int: """ Hash many kinds of objects, including some that are unhashable through the builtin `hash` function. @@ -516,8 +523,8 @@ def _add_future_warning_tag(var) -> None: var.tag = new_tag -def makeiter(a): - if isinstance(a, tuple | list): +def makeiter(a: Sequence[T] | T) -> Sequence[T]: + if isinstance(a, Sequence): return a else: return [a]