diff --git a/README.md b/README.md index c273c06..621288a 100644 --- a/README.md +++ b/README.md @@ -8,34 +8,42 @@ # S2SCAT: Scattering covariance transform on the sphere - - -`S2SCAT` is a Python package for computing third generation scattering covariances on the sphere [(Mousset et al 2024)](https://arxiv.org/abs/xxxx.xxxxx) using JAX. It leverages autodiff to provide differentiable transforms, which are also deployable on hardware accelerators (e.g. GPUs and TPUs). Scattering covariances are useful both for field-level emulation of complex non-Gaussian textures and for statistical compression of high dimensional field-level data, a key step of e.g. simulation based inference [(Cranmer et al 2020)](https://www.pnas.org/doi/abs/10.1073/pnas.1912789117). +`S2SCAT` is a Python package for computing scattering covariances on the sphere ([Mousset et al. 2024](https://arxiv.org/abs/xxxx.xxxxx)) using JAX. It exploits autodiff to provide differentiable transforms, which are also deployable on hardware accelerators (e.g. GPUs and TPUs), leveraging the differentiable and accelerated spherical harmonic and wavelet transforms implemented in [s2fft](https://github.com/astro-informatics/s2fft) and [s2wav](https://github.com/astro-informatics/s2wav), respectively. > [!IMPORTANT] > It is worth highlighting that the input to `S2SCAT` are spherical harmonic coefficients, which can be generated with whichever software package you prefer, e.g. [`S2FFT`](https://github.com/astro-informatics/s2fft) or [`healpy`](https://healpy.readthedocs.io/en/latest/). Just ensure your harmonic coefficients are indexed using our convention; helper functions for this reindexing can be found in [`S2FFT`](https://github.com/astro-informatics/s2fft). > [!TIP] -> At launch `S2SCAT` provides two core transform modes: recursive, which performs underlying spherical harmonic and Wigner transforms through the [Price & McEwen](https://arxiv.org/abs/2311.14670) recursion; and precompute, which a priori computes and caches all Wigner elements required. The precompute approach will be faster but can only be run up to $L \sim 512$, whereas the recursive approach will run up to $L \sim 2048$, depending on GPU hardware. +> At launch `S2SCAT` provides two core transform modes: on-the-fly, which performs underlying spherical harmonic and Wigner transforms through the [Price & McEwen](https://arxiv.org/abs/2311.14670) recursion; and precompute, which a priori computes and caches all Wigner elements required. The precompute approach will be faster but can only be run up to $L \sim 512$, whereas the on-the-fly approach will run up to $L \sim 2048$ and potentially beyond, depending on GPU hardware. + +Ballpark compute times (when running on an 40GB A100 GPU) and compression levels are given in the table below. -| Ballpark Numbers [A100 40GB] | Max resolution | Forward pass | Gradient pass | JIT compilation | Input params | Anisotropic (compression) | Isotropic (compression) | +| Method | Resolution | Forward pass | Gradient pass | JIT compilation | Input params | Anisotropic (compression) | Isotropic (compression) | |:----------------------------:|:--------------:|:------------:|:-------------:|:---------------:|:------------:|:--------------------------:|:------------------------:| -| Recursive | L=512, N=3 | ~90ms | ~190ms | ~20s | 2,618,880 | ~ 63,000 (97.594%) | ~504 (99.981%) | -| Precompute | L=2048, N=3 | ~18s | ~40s | ~5m | 41,932,800 | ~ 123,750 (99.705%) | ~ 990 (99.998%) | +| Precompute | L=512, N=3 | ~90ms | ~190ms | ~20s | 2,618,880 | ~ 63,000 (97.594%) | ~504 (99.981%) | +| On-the-fly | L=2048, N=3 | ~18s | ~40s | ~5m | 41,932,800 | ~ 123,750 (99.705%) | ~ 990 (99.998%) | + +## Scattering covariances :dna: + +

+ +

-## Third Generation Scattering Statistics :dna: +We introduce scattering covariances on the sphere in [Mousset et al. (2024)](https://arxiv.org/abs/xxxx.xxxxx), which extend to spherical settings similar scattering transforms introduced for 1D signals by [Morel et al. (2023)](https://arxiv.org/abs/2204.10177) and for planar 2D signals by [Cheng et al. (2023)](https://arxiv.org/abs/2306.17210). - +Scattering covariances $S$ are computed by -Scattering covariances, or scattering spectra, were previously introduced for 1D signals by [Morel et al (2023)](https://arxiv.org/abs/2204.10177) and for planar 2D signals by [Cheng et al (2023)](https://arxiv.org/abs/2306.17210). The scattering transform is defined by repeated application of directional wavelet transforms followed by a machine learning inspired non-linearity, typically the modulus operator. The wavelet transform $W^{\lambda}$ within each layer has an associated scale $j$ and direction $n$, which we group into a single label $\lambda$. Scattering covariances $S$ are computed from the coefficients of a two-layer scattering transform and are defined as +$$S_1^{\lambda_1} = \langle |W^{\lambda_1} I| \rangle,$$ -$$S_1^{\lambda_1} = \langle |W^{\lambda_1} I| \rangle \quad S_2^{\lambda_1} = \langle|W^{\lambda_1} I|^2 \rangle$$ +$$S_2^{\lambda_1} = \langle|W^{\lambda_1} I|^2 \rangle,$$ -$$S_3^{\lambda_1, \lambda_2} = \text{Cov} \left[ W^{\lambda_1}I, W^{\lambda_1}|W^{\lambda_2} I| \right]$$ +$$S_3^{\lambda_1, \lambda_2} = \text{Cov} \left[ W^{\lambda_1}I, W^{\lambda_1}|W^{\lambda_2} I| \right],$$ -$$S_4^{\lambda_1, \lambda_2, \lambda_3} = \text{Cov} \left[W^{\lambda_1}|W^{\lambda_3}I|, W^{\lambda_1}|W^{\lambda_2}I|\right].$$ +$$S_4^{\lambda_1, \lambda_2, \lambda_3} = \text{Cov} \left[W^{\lambda_1}|W^{\lambda_3}I|, W^{\lambda_1}|W^{\lambda_2}I|\right]$$ -Given that the highest order coefficients are computed from products between $\lambda_1, \lambda_2$ and $\lambda_3$ they encode $6^{\text{th}}$-order statistical information. This statistical representation characterises the power and sparsity at given scales, as well as covariant features between different wavelet scale and directions; which can adequetly capture complex non-Gaussian structural information, e.g. filamentary structure. Using recently release JAX spherical harmonic [(Price & McEwen 2023)](https://arxiv.org/abs/2311.14670) and wavelet transforms [(Price et al 2024)](https://arxiv.org/abs/2402.01282) this work extends scattering covariances to the sphere, which is necessary for their application to e.g. wide-field cosmological surveys [(Mousset et al 2024)](https://arxiv.org/abs/xxxx.xxxxx). +where $W^{\lambda} I$ denotes the wavelet transform of field $I$ at scale $j$ and direction $\gamma$, which we group into a single label $\lambda=(j,\gamma)$. + +This statistical representation characterises the power and sparsity at given scales, as well as covariant features between different wavelet scale and directions, which can effectively capture complex non-Gaussian structural information, e.g. filamentary structure. ## Package Directory Structure :art: @@ -92,19 +100,12 @@ To import and use `S2SCAT` is as simple follows: ``` python import s2scat -L = _ # Harmonic bandlimit -N = _ # Azimuthal bandlimit -flm = _ # Harmonic coefficients of the input signal -# Core GPU transforms +# Given harmonic bandlimit L, azimuthal bandlimit N and spherical harmonic coefficients flm + config = s2scat.configure(L, N) covariances = s2scat.scatter(flm, L, N, config=config) - -# C backend CPU transforms -config = s2scat.configure(L, N, c_backend=True) -covariances = s2scat.scatter_c(flm, L, N, config=config) ``` -`S2SCAT` also provides JAX support for existing C backend libraries which are memory efficient but CPU bound; at launch we support [`SSHT`](https://github.com/astro-informatics/ssht), however this could be extended straightforwardly. This works by wrapping python bindings with custom JAX frontends. For further details on usage see the [documentation](https://astro-informatics.github.io/s2scat/) and associated [notebooks](https://astro-informatics.github.io/s2scat/notebooks/). @@ -138,7 +139,7 @@ referenced. A BibTeX entry for this reference may look like: @article{mousset:s2scat, author = "Louise Mousset et al", title = "TBD", - journal = "Astronomy & Astrophysics, submitted", + journal = "TBD, submitted", year = "2024", eprint = "TBD" } diff --git a/notebooks/compression.ipynb b/notebooks/compression.ipynb index 7b6fb12..f491a1b 100644 --- a/notebooks/compression.ipynb +++ b/notebooks/compression.ipynb @@ -15,7 +15,7 @@ "\n", "### So what's the difference? \n", "\n", - "At highest order, anisotropic coefficients encode covariant information between 3 different wavelet scales $j$ and directions $n$. On the other hand, isotropic coefficients average over directions $n$, sampling the mean covariance structure across scales. Note that isotropic coefficients will still capture directional filamentary structure but will be somewhat less expressive.\n", + "At highest order, anisotropic coefficients encode covariant information between 3 different wavelet scales $j$ and directions $\\gamma$. On the other hand, isotropic coefficients average over directions $\\gamma$, sampling the mean covariance structure across scales. Note that isotropic coefficients will still capture directional filamentary structure but will be somewhat less expressive.\n", "\n", "### What's the latent dimensionality? \n", "\n", diff --git a/tests/test_gradients.py b/tests/test_gradients.py index 23bed6a..6db1d3c 100644 --- a/tests/test_gradients.py +++ b/tests/test_gradients.py @@ -57,5 +57,4 @@ def func(flm): loss += jnp.mean(jnp.abs(coeffs[i])) return loss - rtol = 5e-3 if isotropic else 1e-3 - check_grads(func, (flm,), order=1, modes=("rev"), rtol=rtol) + check_grads(func, (flm,), order=1, modes=("rev"), rtol=5e-3)