Skip to content

Commit

Permalink
Merge pull request #10 from astro-informatics/revisions/readme
Browse files Browse the repository at this point in the history
Revisions/readme
  • Loading branch information
jasonmcewen authored May 8, 2024
2 parents 667e9df + 2a07fbe commit a8e0042
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 27 deletions.
49 changes: 25 additions & 24 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,34 +8,42 @@

# S2SCAT: Scattering covariance transform on the sphere

<img align="center" src="./docs/assets/synthesis_zoom.gif">

`S2SCAT` is a Python package for computing third generation scattering covariances on the sphere [(Mousset et al 2024)](https://arxiv.org/abs/xxxx.xxxxx) using JAX. It leverages autodiff to provide differentiable transforms, which are also deployable on hardware accelerators (e.g. GPUs and TPUs). Scattering covariances are useful both for field-level emulation of complex non-Gaussian textures and for statistical compression of high dimensional field-level data, a key step of e.g. simulation based inference [(Cranmer et al 2020)](https://www.pnas.org/doi/abs/10.1073/pnas.1912789117).
`S2SCAT` is a Python package for computing scattering covariances on the sphere ([Mousset et al. 2024](https://arxiv.org/abs/xxxx.xxxxx)) using JAX. It exploits autodiff to provide differentiable transforms, which are also deployable on hardware accelerators (e.g. GPUs and TPUs), leveraging the differentiable and accelerated spherical harmonic and wavelet transforms implemented in [s2fft](https://github.com/astro-informatics/s2fft) and [s2wav](https://github.com/astro-informatics/s2wav), respectively.

> [!IMPORTANT]
> It is worth highlighting that the input to `S2SCAT` are spherical harmonic coefficients, which can be generated with whichever software package you prefer, e.g. [`S2FFT`](https://github.com/astro-informatics/s2fft) or [`healpy`](https://healpy.readthedocs.io/en/latest/). Just ensure your harmonic coefficients are indexed using our convention; helper functions for this reindexing can be found in [`S2FFT`](https://github.com/astro-informatics/s2fft).
> [!TIP]
> At launch `S2SCAT` provides two core transform modes: recursive, which performs underlying spherical harmonic and Wigner transforms through the [Price & McEwen](https://arxiv.org/abs/2311.14670) recursion; and precompute, which a priori computes and caches all Wigner elements required. The precompute approach will be faster but can only be run up to $L \sim 512$, whereas the recursive approach will run up to $L \sim 2048$, depending on GPU hardware.
> At launch `S2SCAT` provides two core transform modes: on-the-fly, which performs underlying spherical harmonic and Wigner transforms through the [Price & McEwen](https://arxiv.org/abs/2311.14670) recursion; and precompute, which a priori computes and caches all Wigner elements required. The precompute approach will be faster but can only be run up to $L \sim 512$, whereas the on-the-fly approach will run up to $L \sim 2048$ and potentially beyond, depending on GPU hardware.
Ballpark compute times (when running on an 40GB A100 GPU) and compression levels are given in the table below.

| Ballpark Numbers [A100 40GB] | Max resolution | Forward pass | Gradient pass | JIT compilation | Input params | Anisotropic (compression) | Isotropic (compression) |
| Method | Resolution | Forward pass | Gradient pass | JIT compilation | Input params | Anisotropic (compression) | Isotropic (compression) |
|:----------------------------:|:--------------:|:------------:|:-------------:|:---------------:|:------------:|:--------------------------:|:------------------------:|
| Recursive | L=512, N=3 | ~90ms | ~190ms | ~20s | 2,618,880 | ~ 63,000 (97.594%) | ~504 (99.981%) |
| Precompute | L=2048, N=3 | ~18s | ~40s | ~5m | 41,932,800 | ~ 123,750 (99.705%) | ~ 990 (99.998%) |
| Precompute | L=512, N=3 | ~90ms | ~190ms | ~20s | 2,618,880 | ~ 63,000 (97.594%) | ~504 (99.981%) |
| On-the-fly | L=2048, N=3 | ~18s | ~40s | ~5m | 41,932,800 | ~ 123,750 (99.705%) | ~ 990 (99.998%) |

## Scattering covariances :dna:

<p align="center">
<img width="300" height="300" src="./docs/assets/synthesis.gif">
</p>

## Third Generation Scattering Statistics :dna:
We introduce scattering covariances on the sphere in [Mousset et al. (2024)](https://arxiv.org/abs/xxxx.xxxxx), which extend to spherical settings similar scattering transforms introduced for 1D signals by [Morel et al. (2023)](https://arxiv.org/abs/2204.10177) and for planar 2D signals by [Cheng et al. (2023)](https://arxiv.org/abs/2306.17210).

<img align="right" width="300" height="300" src="./docs/assets/synthesis.gif">
Scattering covariances $S$ are computed by

Scattering covariances, or scattering spectra, were previously introduced for 1D signals by [Morel et al (2023)](https://arxiv.org/abs/2204.10177) and for planar 2D signals by [Cheng et al (2023)](https://arxiv.org/abs/2306.17210). The scattering transform is defined by repeated application of directional wavelet transforms followed by a machine learning inspired non-linearity, typically the modulus operator. The wavelet transform $W^{\lambda}$ within each layer has an associated scale $j$ and direction $n$, which we group into a single label $\lambda$. Scattering covariances $S$ are computed from the coefficients of a two-layer scattering transform and are defined as
$$S_1^{\lambda_1} = \langle |W^{\lambda_1} I| \rangle,$$

$$S_1^{\lambda_1} = \langle |W^{\lambda_1} I| \rangle \quad S_2^{\lambda_1} = \langle|W^{\lambda_1} I|^2 \rangle$$
$$S_2^{\lambda_1} = \langle|W^{\lambda_1} I|^2 \rangle,$$

$$S_3^{\lambda_1, \lambda_2} = \text{Cov} \left[ W^{\lambda_1}I, W^{\lambda_1}|W^{\lambda_2} I| \right]$$
$$S_3^{\lambda_1, \lambda_2} = \text{Cov} \left[ W^{\lambda_1}I, W^{\lambda_1}|W^{\lambda_2} I| \right],$$

$$S_4^{\lambda_1, \lambda_2, \lambda_3} = \text{Cov} \left[W^{\lambda_1}|W^{\lambda_3}I|, W^{\lambda_1}|W^{\lambda_2}I|\right].$$
$$S_4^{\lambda_1, \lambda_2, \lambda_3} = \text{Cov} \left[W^{\lambda_1}|W^{\lambda_3}I|, W^{\lambda_1}|W^{\lambda_2}I|\right]$$

Given that the highest order coefficients are computed from products between $\lambda_1, \lambda_2$ and $\lambda_3$ they encode $6^{\text{th}}$-order statistical information. This statistical representation characterises the power and sparsity at given scales, as well as covariant features between different wavelet scale and directions; which can adequetly capture complex non-Gaussian structural information, e.g. filamentary structure. Using recently release JAX spherical harmonic [(Price & McEwen 2023)](https://arxiv.org/abs/2311.14670) and wavelet transforms [(Price et al 2024)](https://arxiv.org/abs/2402.01282) this work extends scattering covariances to the sphere, which is necessary for their application to e.g. wide-field cosmological surveys [(Mousset et al 2024)](https://arxiv.org/abs/xxxx.xxxxx).
where $W^{\lambda} I$ denotes the wavelet transform of field $I$ at scale $j$ and direction $\gamma$, which we group into a single label $\lambda=(j,\gamma)$.

This statistical representation characterises the power and sparsity at given scales, as well as covariant features between different wavelet scale and directions, which can effectively capture complex non-Gaussian structural information, e.g. filamentary structure.

## Package Directory Structure :art:

Expand Down Expand Up @@ -92,19 +100,12 @@ To import and use `S2SCAT` is as simple follows:

``` python
import s2scat
L = _ # Harmonic bandlimit
N = _ # Azimuthal bandlimit
flm = _ # Harmonic coefficients of the input signal

# Core GPU transforms
# Given harmonic bandlimit L, azimuthal bandlimit N and spherical harmonic coefficients flm

config = s2scat.configure(L, N)
covariances = s2scat.scatter(flm, L, N, config=config)

# C backend CPU transforms
config = s2scat.configure(L, N, c_backend=True)
covariances = s2scat.scatter_c(flm, L, N, config=config)
```
`S2SCAT` also provides JAX support for existing C backend libraries which are memory efficient but CPU bound; at launch we support [`SSHT`](https://github.com/astro-informatics/ssht), however this could be extended straightforwardly. This works by wrapping python bindings with custom JAX frontends.

For further details on usage see the [documentation](https://astro-informatics.github.io/s2scat/) and associated [notebooks](https://astro-informatics.github.io/s2scat/notebooks/).

Expand Down Expand Up @@ -138,7 +139,7 @@ referenced. A BibTeX entry for this reference may look like:
@article{mousset:s2scat,
author = "Louise Mousset et al",
title = "TBD",
journal = "Astronomy & Astrophysics, submitted",
journal = "TBD, submitted",
year = "2024",
eprint = "TBD"
}
Expand Down
2 changes: 1 addition & 1 deletion notebooks/compression.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"\n",
"### So what's the difference? \n",
"\n",
"At highest order, anisotropic coefficients encode covariant information between 3 different wavelet scales $j$ and directions $n$. On the other hand, isotropic coefficients average over directions $n$, sampling the mean covariance structure across scales. Note that isotropic coefficients will still capture directional filamentary structure but will be somewhat less expressive.\n",
"At highest order, anisotropic coefficients encode covariant information between 3 different wavelet scales $j$ and directions $\\gamma$. On the other hand, isotropic coefficients average over directions $\\gamma$, sampling the mean covariance structure across scales. Note that isotropic coefficients will still capture directional filamentary structure but will be somewhat less expressive.\n",
"\n",
"### What's the latent dimensionality? \n",
"\n",
Expand Down
3 changes: 1 addition & 2 deletions tests/test_gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,5 +57,4 @@ def func(flm):
loss += jnp.mean(jnp.abs(coeffs[i]))
return loss

rtol = 5e-3 if isotropic else 1e-3
check_grads(func, (flm,), order=1, modes=("rev"), rtol=rtol)
check_grads(func, (flm,), order=1, modes=("rev"), rtol=5e-3)

0 comments on commit a8e0042

Please sign in to comment.