diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 3c1dcfed1..db4f9dd5b 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -4,15 +4,14 @@ $ py.test -v -s africanus ``` - If the pep8 tests fail, the quickest way to correct - this is to run `autopep8` and then `flake8` and - `pycodestyle` to fix the remaining issues. + If the pre-commit tests fail, install and + run the pre-commit hooks in your development + virtuale environment: ``` - $ pip install -U autopep8 flake8 pycodestyle - $ autopep8 -r -i africanus - $ flake8 africanus - $ pycodestyle africanus + $ pip install pre-commit + $ pre-commit install + $ pre-commit run -a ``` - [ ] Fully documented, including `HISTORY.rst` for all changes diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml new file mode 100644 index 000000000..00b45094a --- /dev/null +++ b/.github/workflows/pre-commit.yml @@ -0,0 +1,15 @@ +name: pre-commit + +on: [push] + +jobs: + pre-commit: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 1 + - uses: actions/setup-python@v5.0.0 + with: + python-version: 3.11 + - uses: pre-commit/action@v3.0.0 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 000000000..28fb60bf1 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,15 @@ +# See https://pre-commit.com for more information +# See https://pre-commit.com/hooks.html for more hooks +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - id: check-added-large-files +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.1.3 + hooks: + - id: ruff-format + name: ruff format diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index 0448ee6bc..48c685d3f 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -66,34 +66,35 @@ Ready to contribute? Here's how to set up `codex-africanus` for local developmen 3. Install your local copy into a virtualenv. Assuming you have virtualenvwrapper installed, this is how you set up your fork for local development:: - $ mkvirtualenv codex-africanus - $ cd codex-africanus/ + $ python -m venv ./africanus + $ source ./africanus/bin/activate $ pip install -e . -4. Create a branch for local development:: +4. Install the pre-commit hooks + + $ pip install pre-commit + $ pre-commit install + +5. Create a branch for local development:: $ git checkout -b name-of-your-bugfix-or-feature Now you can make your changes locally. -5. When you're done making changes, check that your changes +6. When you're done making changes, check that your changes pass the test cases, fixup your PEP8 compliance, and check for any code style issues: - $ py.test -v africanus - $ autopep8 -r -i africanus - $ flake8 africanus - $ pycodestyle africanus - - To get autopep8 and pycodestyle, just pip install them into your virtualenv. + $ py.test -vvv africanus + $ pre-commit run -a -6. Commit your changes and push your branch to GitHub:: +7. Commit your changes and push your branch to GitHub:: $ git add . $ git commit -m "Your detailed description of your changes." $ git push origin name-of-your-bugfix-or-feature -7. Submit a pull request through the GitHub website. +8. Submit a pull request through the GitHub website. Pull Request Guidelines ----------------------- @@ -104,9 +105,7 @@ Before you submit a pull request, check that it meets these guidelines: 2. If the pull request adds functionality, the docs should be updated. Put your new functionality into a function with a docstring, and add the feature to the list in HISTORY.rst. -3. The pull request should work for Python 2.7, 3.5 and 3.6. Check - https://travis-ci.org/ska-sa/codex-africanus/pull_requests - and make sure that the tests pass for all supported Python versions. +3. The pull request should work for Python 3.9 and above. Tips ---- diff --git a/README.rst b/README.rst index 86d6199aa..28d0076fe 100644 --- a/README.rst +++ b/README.rst @@ -27,4 +27,3 @@ Documentation ------------- https://codex-africanus.readthedocs.io. - diff --git a/africanus/averaging/bda_avg.py b/africanus/averaging/bda_avg.py index 12574a161..1589c2042 100644 --- a/africanus/averaging/bda_avg.py +++ b/africanus/averaging/bda_avg.py @@ -4,43 +4,81 @@ import numpy as np -from africanus.averaging.bda_mapping import (bda_mapper, - RowMapOutput) -from africanus.averaging.shared import (chan_corrs, - merge_flags, - vis_output_arrays) +from africanus.averaging.bda_mapping import bda_mapper, RowMapOutput +from africanus.averaging.shared import chan_corrs, merge_flags, vis_output_arrays from africanus.util.docs import DocstringTemplate -from africanus.util.numba import (njit, - overload, - JIT_OPTIONS, - intrinsic, - is_numba_type_none) - - -_row_output_fields = ["antenna1", "antenna2", "time_centroid", "exposure", - "uvw", "weight", "sigma"] +from africanus.util.numba import ( + njit, + overload, + JIT_OPTIONS, + intrinsic, + is_numba_type_none, +) + + +_row_output_fields = [ + "antenna1", + "antenna2", + "time_centroid", + "exposure", + "uvw", + "weight", + "sigma", +] RowAverageOutput = namedtuple("RowAverageOutput", _row_output_fields) @njit(**JIT_OPTIONS) -def row_average(meta, ant1, ant2, flag_row=None, - time_centroid=None, exposure=None, uvw=None, - weight=None, sigma=None): - return row_average_impl(meta, ant1, ant2, flag_row=flag_row, - time_centroid=time_centroid, exposure=exposure, - uvw=uvw, weight=weight, sigma=sigma) - - -def row_average_impl(meta, ant1, ant2, flag_row=None, - time_centroid=None, exposure=None, uvw=None, - weight=None, sigma=None): +def row_average( + meta, + ant1, + ant2, + flag_row=None, + time_centroid=None, + exposure=None, + uvw=None, + weight=None, + sigma=None, +): + return row_average_impl( + meta, + ant1, + ant2, + flag_row=flag_row, + time_centroid=time_centroid, + exposure=exposure, + uvw=uvw, + weight=weight, + sigma=sigma, + ) + + +def row_average_impl( + meta, + ant1, + ant2, + flag_row=None, + time_centroid=None, + exposure=None, + uvw=None, + weight=None, + sigma=None, +): return NotImplementedError @overload(row_average_impl, jit_options=JIT_OPTIONS) -def nb_row_average_impl(meta, ant1, ant2, flag_row=None, - time_centroid=None, exposure=None, uvw=None, - weight=None, sigma=None): +def nb_row_average_impl( + meta, + ant1, + ant2, + flag_row=None, + time_centroid=None, + exposure=None, + uvw=None, + weight=None, + sigma=None, +): have_flag_row = not is_numba_type_none(flag_row) have_time_centroid = not is_numba_type_none(time_centroid) have_exposure = not is_numba_type_none(exposure) @@ -48,10 +86,17 @@ def nb_row_average_impl(meta, ant1, ant2, flag_row=None, have_weight = not is_numba_type_none(weight) have_sigma = not is_numba_type_none(sigma) - def impl(meta, ant1, ant2, flag_row=None, - time_centroid=None, exposure=None, uvw=None, - weight=None, sigma=None): - + def impl( + meta, + ant1, + ant2, + flag_row=None, + time_centroid=None, + exposure=None, + uvw=None, + weight=None, + sigma=None, + ): out_rows = meta.time.shape[0] counts = np.zeros(out_rows, dtype=np.uint32) @@ -61,34 +106,42 @@ def impl(meta, ant1, ant2, flag_row=None, # Possibly present outputs for possibly present inputs uvw_avg = ( - None if not have_uvw else - np.zeros((out_rows,) + uvw.shape[1:], - dtype=uvw.dtype)) + None + if not have_uvw + else np.zeros((out_rows,) + uvw.shape[1:], dtype=uvw.dtype) + ) time_centroid_avg = ( - None if not have_time_centroid else - np.zeros((out_rows,) + time_centroid.shape[1:], - dtype=time_centroid.dtype)) + None + if not have_time_centroid + else np.zeros( + (out_rows,) + time_centroid.shape[1:], dtype=time_centroid.dtype + ) + ) exposure_avg = ( - None if not have_exposure else - np.zeros((out_rows,) + exposure.shape[1:], - dtype=exposure.dtype)) + None + if not have_exposure + else np.zeros((out_rows,) + exposure.shape[1:], dtype=exposure.dtype) + ) weight_avg = ( - None if not have_weight else - np.zeros((out_rows,) + weight.shape[1:], - dtype=weight.dtype)) + None + if not have_weight + else np.zeros((out_rows,) + weight.shape[1:], dtype=weight.dtype) + ) sigma_avg = ( - None if not have_sigma else - np.zeros((out_rows,) + sigma.shape[1:], - dtype=sigma.dtype)) + None + if not have_sigma + else np.zeros((out_rows,) + sigma.shape[1:], dtype=sigma.dtype) + ) sigma_weight_sum = ( - None if not have_sigma else - np.zeros((out_rows,) + sigma.shape[1:], - dtype=sigma.dtype)) + None + if not have_sigma + else np.zeros((out_rows,) + sigma.shape[1:], dtype=sigma.dtype) + ) # Average each array, if present # The output is a flattened row-channel array @@ -131,7 +184,7 @@ def impl(meta, ant1, ant2, flag_row=None, wt = weight[ri, co] if have_weight else 1.0 # Aggregate - sigma_avg[ro, co] += sigma[ri, co]**2 * wt**2 + sigma_avg[ro, co] += sigma[ri, co] ** 2 * wt**2 sigma_weight_sum[ro, co] += wt # Normalise and copy @@ -191,20 +244,21 @@ def impl(meta, ant1, ant2, flag_row=None, for co in range(sigma.shape[1]): sigma_avg[bro + c, co] = sigma_avg[bro, co] - return RowAverageOutput(ant1_avg, ant2_avg, - time_centroid_avg, - exposure_avg, uvw_avg, - weight_avg, sigma_avg) + return RowAverageOutput( + ant1_avg, + ant2_avg, + time_centroid_avg, + exposure_avg, + uvw_avg, + weight_avg, + sigma_avg, + ) return impl -_rowchan_output_fields = ["visibilities", - "flag", - "weight_spectrum", - "sigma_spectrum"] -RowChanAverageOutput = namedtuple("RowChanAverageOutput", - _rowchan_output_fields) +_rowchan_output_fields = ["visibilities", "flag", "weight_spectrum", "sigma_spectrum"] +RowChanAverageOutput = namedtuple("RowChanAverageOutput", _rowchan_output_fields) class RowChannelAverageException(Exception): @@ -212,9 +266,9 @@ class RowChannelAverageException(Exception): @intrinsic -def average_visibilities(typingctx, vis, vis_avg, vis_weight_sum, - weight, ri, fi, ro, co): - +def average_visibilities( + typingctx, vis, vis_avg, vis_weight_sum, weight, ri, fi, ro, co +): import numba.core.types as nbtypes have_array = isinstance(vis, nbtypes.Array) @@ -226,8 +280,7 @@ def avg_fn(vis, vis_avg, vis_ws, wt, ri, fi, ro, co): return_type = nbtypes.NoneType("none") - sig = return_type(vis, vis_avg, vis_weight_sum, - weight, ri, fi, ro, co) + sig = return_type(vis, vis_avg, vis_weight_sum, weight, ri, fi, ro, co) def codegen(context, builder, signature, args): vis, vis_type = args[0], signature.args[0] @@ -241,34 +294,45 @@ def codegen(context, builder, signature, args): return_type = signature.return_type if have_array: - avg_sig = return_type(vis_type, - vis_avg_type, - vis_weight_sum_type, - weight_type, - ri_type, fi_type, - ro_type, co_type) - avg_args = [vis, vis_avg, vis_weight_sum, - weight, ri, fi, ro, co] + avg_sig = return_type( + vis_type, + vis_avg_type, + vis_weight_sum_type, + weight_type, + ri_type, + fi_type, + ro_type, + co_type, + ) + avg_args = [vis, vis_avg, vis_weight_sum, weight, ri, fi, ro, co] # Compile function and get handle to output - context.compile_internal(builder, avg_fn, - avg_sig, avg_args) + context.compile_internal(builder, avg_fn, avg_sig, avg_args) elif have_tuple: for i in range(len(vis_type)): - avg_sig = return_type(vis_type.types[i], - vis_avg_type.types[i], - vis_weight_sum_type.types[i], - weight_type, - ri_type, fi_type, - ro_type, co_type) - avg_args = [builder.extract_value(vis, i), - builder.extract_value(vis_avg, i), - builder.extract_value(vis_weight_sum, i), - weight, ri, fi, ro, co] + avg_sig = return_type( + vis_type.types[i], + vis_avg_type.types[i], + vis_weight_sum_type.types[i], + weight_type, + ri_type, + fi_type, + ro_type, + co_type, + ) + avg_args = [ + builder.extract_value(vis, i), + builder.extract_value(vis_avg, i), + builder.extract_value(vis_weight_sum, i), + weight, + ri, + fi, + ro, + co, + ] # Compile function and get handle to output - context.compile_internal(builder, avg_fn, - avg_sig, avg_args) + context.compile_internal(builder, avg_fn, avg_sig, avg_args) else: raise TypeError("Unhandled visibility array type") @@ -300,26 +364,28 @@ def codegen(context, builder, signature, args): if have_array: # Normalise single array - norm_sig = return_type(vis_avg_type, - vis_weight_sum_type, - ro_type, co_type) + norm_sig = return_type(vis_avg_type, vis_weight_sum_type, ro_type, co_type) norm_args = [vis_avg, vis_weight_sum, ro, co] - context.compile_internal(builder, normalise_fn, - norm_sig, norm_args) + context.compile_internal(builder, normalise_fn, norm_sig, norm_args) elif have_tuple: # Normalise each array in the tuple for i in range(len(vis_avg_type)): - norm_sig = return_type(vis_avg_type.types[i], - vis_weight_sum_type.types[i], - ro_type, co_type) - norm_args = [builder.extract_value(vis_avg, i), - builder.extract_value(vis_weight_sum, i), - ro, co] + norm_sig = return_type( + vis_avg_type.types[i], + vis_weight_sum_type.types[i], + ro_type, + co_type, + ) + norm_args = [ + builder.extract_value(vis_avg, i), + builder.extract_value(vis_weight_sum, i), + ro, + co, + ] # Compile function and get handle to output - context.compile_internal(builder, normalise_fn, - norm_sig, norm_args) + context.compile_internal(builder, normalise_fn, norm_sig, norm_args) else: raise TypeError("Unhandled visibility array type") @@ -327,34 +393,48 @@ def codegen(context, builder, signature, args): @njit(**JIT_OPTIONS) -def row_chan_average(meta, flag_row=None, weight=None, - visibilities=None, - flag=None, - weight_spectrum=None, - sigma_spectrum=None): - - return row_chan_average_impl(meta, flag_row=flag_row, weight=weight, - visibilities=visibilities, flag=flag, - weight_spectrum=weight_spectrum, - sigma_spectrum=sigma_spectrum) - - -def row_chan_average_impl(meta, flag_row=None, weight=None, - visibilities=None, - flag=None, - weight_spectrum=None, - sigma_spectrum=None): - +def row_chan_average( + meta, + flag_row=None, + weight=None, + visibilities=None, + flag=None, + weight_spectrum=None, + sigma_spectrum=None, +): + return row_chan_average_impl( + meta, + flag_row=flag_row, + weight=weight, + visibilities=visibilities, + flag=flag, + weight_spectrum=weight_spectrum, + sigma_spectrum=sigma_spectrum, + ) + + +def row_chan_average_impl( + meta, + flag_row=None, + weight=None, + visibilities=None, + flag=None, + weight_spectrum=None, + sigma_spectrum=None, +): return NotImplementedError @overload(row_chan_average_impl, jit_options=JIT_OPTIONS) -def nb_row_chan_average(meta, flag_row=None, weight=None, - visibilities=None, - flag=None, - weight_spectrum=None, - sigma_spectrum=None): - +def nb_row_chan_average( + meta, + flag_row=None, + weight=None, + visibilities=None, + flag=None, + weight_spectrum=None, + sigma_spectrum=None, +): have_vis = not is_numba_type_none(visibilities) have_flag = not is_numba_type_none(flag) have_flag_row = not is_numba_type_none(flag_row) @@ -364,17 +444,19 @@ def nb_row_chan_average(meta, flag_row=None, weight=None, have_weight_spectrum = not is_numba_type_none(weight_spectrum) have_sigma_spectrum = not is_numba_type_none(sigma_spectrum) - def impl(meta, flag_row=None, weight=None, - visibilities=None, - flag=None, - weight_spectrum=None, - sigma_spectrum=None): - + def impl( + meta, + flag_row=None, + weight=None, + visibilities=None, + flag=None, + weight_spectrum=None, + sigma_spectrum=None, + ): out_rows = meta.time.shape[0] - nchan, ncorrs = chan_corrs(visibilities, flag, - weight_spectrum, sigma_spectrum, - None, None, - None, None) + nchan, ncorrs = chan_corrs( + visibilities, flag, weight_spectrum, sigma_spectrum, None, None, None, None + ) out_shape = (out_rows, ncorrs) @@ -401,8 +483,7 @@ def impl(meta, flag_row=None, weight=None, ro = meta.map[ri, fi] for co in range(ncorrs): - flagged = (row_flagged or - (have_flag and flag[ri, fi, co] != 0)) + flagged = row_flagged or (have_flag and flag[ri, fi, co] != 0) if have_flags and flagged: flag_counts[ro, co] += 1 @@ -454,8 +535,7 @@ def impl(meta, flag_row=None, weight=None, # unflagged samples never contribute to a # completely flagged bin if have_flags: - in_flag = (row_flagged or - (have_flag and flag[ri, fi, co] != 0)) + in_flag = row_flagged or (have_flag and flag[ri, fi, co] != 0) flags_match[ri, fi, co] = in_flag == out_flag # ------------- @@ -464,8 +544,7 @@ def impl(meta, flag_row=None, weight=None, if not have_vis: vis_avg = None else: - vis_avg, vis_weight_sum = vis_output_arrays( - visibilities, out_shape) + vis_avg, vis_weight_sum = vis_output_arrays(visibilities, out_shape) # Aggregate for ri in range(meta.map.shape[0]): @@ -476,13 +555,17 @@ def impl(meta, flag_row=None, weight=None, if have_flags and not flags_match[ri, fi, co]: continue - wt = (weight_spectrum[ri, fi, co] - if have_weight_spectrum else - weight[ri, co] if have_weight else 1.0) + wt = ( + weight_spectrum[ri, fi, co] + if have_weight_spectrum + else weight[ri, co] + if have_weight + else 1.0 + ) - average_visibilities(visibilities, - vis_avg, vis_weight_sum, - wt, ri, fi, ro, co) + average_visibilities( + visibilities, vis_avg, vis_weight_sum, wt, ri, fi, ro, co + ) # Normalise for ro in range(out_rows): @@ -506,8 +589,7 @@ def impl(meta, flag_row=None, weight=None, if have_flags and not flags_match[ri, fi, co]: continue - weight_spectrum_avg[ro, co] += ( - weight_spectrum[ri, fi, co]) + weight_spectrum_avg[ro, co] += weight_spectrum[ri, fi, co] # --------------- # Sigma Spectrum @@ -527,11 +609,15 @@ def impl(meta, flag_row=None, weight=None, if have_flags and not flags_match[ri, fi, co]: continue - wt = (weight_spectrum[ri, fi, co] - if have_weight_spectrum else - weight[ri, co] if have_weight else 1.0) + wt = ( + weight_spectrum[ri, fi, co] + if have_weight_spectrum + else weight[ri, co] + if have_weight + else 1.0 + ) - ssv = sigma_spectrum[ri, fi, co]**2 * wt**2 + ssv = sigma_spectrum[ri, fi, co] ** 2 * wt**2 sigma_spectrum_avg[ro, co] += ssv sigma_spectrum_weight_sum[ro, co] += wt @@ -543,9 +629,9 @@ def impl(meta, flag_row=None, weight=None, sswsum = sigma_spectrum_weight_sum[ro, co] sigma_spectrum_avg[ro, co] = np.sqrt(ssv / sswsum**2) - return RowChanAverageOutput(vis_avg, flag_avg, - weight_spectrum_avg, - sigma_spectrum_avg) + return RowChanAverageOutput( + vis_avg, flag_avg, weight_spectrum_avg, sigma_spectrum_avg + ) return impl @@ -554,114 +640,193 @@ def impl(meta, flag_row=None, weight=None, ChannelAverageOutput = namedtuple("ChannelAverageOutput", _chan_output_fields) -AverageOutput = namedtuple("AverageOutput", - list(RowMapOutput._fields) + - _row_output_fields + - # _chan_output_fields + - _rowchan_output_fields) +AverageOutput = namedtuple( + "AverageOutput", + list(RowMapOutput._fields) + + _row_output_fields + + + # _chan_output_fields + + _rowchan_output_fields, +) @njit(**JIT_OPTIONS) -def bda(time, interval, antenna1, antenna2, - time_centroid=None, exposure=None, flag_row=None, - uvw=None, weight=None, sigma=None, - chan_freq=None, chan_width=None, - effective_bw=None, resolution=None, - visibilities=None, flag=None, - weight_spectrum=None, sigma_spectrum=None, - max_uvw_dist=None, max_fov=3.0, - decorrelation=0.98, - time_bin_secs=None, - min_nchan=1): - - return bda_impl(time, interval, antenna1, antenna2, - time_centroid=time_centroid, exposure=exposure, - flag_row=flag_row, uvw=uvw, weight=weight, sigma=sigma, - chan_freq=chan_freq, chan_width=chan_width, - effective_bw=effective_bw, resolution=resolution, - visibilities=visibilities, flag=flag, - weight_spectrum=weight_spectrum, - sigma_spectrum=sigma_spectrum, - max_uvw_dist=max_uvw_dist, max_fov=max_fov, - decorrelation=decorrelation, - time_bin_secs=time_bin_secs, min_nchan=min_nchan) - - -def bda_impl(time, interval, antenna1, antenna2, - time_centroid=None, exposure=None, flag_row=None, - uvw=None, weight=None, sigma=None, - chan_freq=None, chan_width=None, - effective_bw=None, resolution=None, - visibilities=None, flag=None, - weight_spectrum=None, sigma_spectrum=None, - max_uvw_dist=None, max_fov=3.0, - decorrelation=0.98, - time_bin_secs=None, - min_nchan=1): +def bda( + time, + interval, + antenna1, + antenna2, + time_centroid=None, + exposure=None, + flag_row=None, + uvw=None, + weight=None, + sigma=None, + chan_freq=None, + chan_width=None, + effective_bw=None, + resolution=None, + visibilities=None, + flag=None, + weight_spectrum=None, + sigma_spectrum=None, + max_uvw_dist=None, + max_fov=3.0, + decorrelation=0.98, + time_bin_secs=None, + min_nchan=1, +): + return bda_impl( + time, + interval, + antenna1, + antenna2, + time_centroid=time_centroid, + exposure=exposure, + flag_row=flag_row, + uvw=uvw, + weight=weight, + sigma=sigma, + chan_freq=chan_freq, + chan_width=chan_width, + effective_bw=effective_bw, + resolution=resolution, + visibilities=visibilities, + flag=flag, + weight_spectrum=weight_spectrum, + sigma_spectrum=sigma_spectrum, + max_uvw_dist=max_uvw_dist, + max_fov=max_fov, + decorrelation=decorrelation, + time_bin_secs=time_bin_secs, + min_nchan=min_nchan, + ) + + +def bda_impl( + time, + interval, + antenna1, + antenna2, + time_centroid=None, + exposure=None, + flag_row=None, + uvw=None, + weight=None, + sigma=None, + chan_freq=None, + chan_width=None, + effective_bw=None, + resolution=None, + visibilities=None, + flag=None, + weight_spectrum=None, + sigma_spectrum=None, + max_uvw_dist=None, + max_fov=3.0, + decorrelation=0.98, + time_bin_secs=None, + min_nchan=1, +): return NotImplementedError @overload(bda_impl, jit_options=JIT_OPTIONS) -def nb_bda_impl(time, interval, antenna1, antenna2, - time_centroid=None, exposure=None, flag_row=None, - uvw=None, weight=None, sigma=None, - chan_freq=None, chan_width=None, - effective_bw=None, resolution=None, - visibilities=None, flag=None, - weight_spectrum=None, sigma_spectrum=None, - max_uvw_dist=None, max_fov=3.0, - decorrelation=0.98, - time_bin_secs=None, - min_nchan=1): +def nb_bda_impl( + time, + interval, + antenna1, + antenna2, + time_centroid=None, + exposure=None, + flag_row=None, + uvw=None, + weight=None, + sigma=None, + chan_freq=None, + chan_width=None, + effective_bw=None, + resolution=None, + visibilities=None, + flag=None, + weight_spectrum=None, + sigma_spectrum=None, + max_uvw_dist=None, + max_fov=3.0, + decorrelation=0.98, + time_bin_secs=None, + min_nchan=1, +): # Merge flag_row and flag arrays flag_row = merge_flags(flag_row, flag) - meta = bda_mapper(time, interval, antenna1, antenna2, uvw, - chan_width, chan_freq, - max_uvw_dist, - flag_row=flag_row, - max_fov=max_fov, - decorrelation=decorrelation, - time_bin_secs=time_bin_secs, - min_nchan=min_nchan) - - row_avg = row_average(meta, antenna1, antenna2, flag_row, # noqa: F841 - time_centroid, exposure, uvw, - weight=weight, sigma=sigma) - - row_chan_avg = row_chan_average(meta, # noqa: F841 - flag_row=flag_row, - visibilities=visibilities, flag=flag, - weight_spectrum=weight_spectrum, - sigma_spectrum=sigma_spectrum) + meta = bda_mapper( + time, + interval, + antenna1, + antenna2, + uvw, + chan_width, + chan_freq, + max_uvw_dist, + flag_row=flag_row, + max_fov=max_fov, + decorrelation=decorrelation, + time_bin_secs=time_bin_secs, + min_nchan=min_nchan, + ) + + row_avg = row_average( + meta, + antenna1, + antenna2, + flag_row, # noqa: F841 + time_centroid, + exposure, + uvw, + weight=weight, + sigma=sigma, + ) + + row_chan_avg = row_chan_average( + meta, # noqa: F841 + flag_row=flag_row, + visibilities=visibilities, + flag=flag, + weight_spectrum=weight_spectrum, + sigma_spectrum=sigma_spectrum, + ) # Have to explicitly write it out because numba tuples # are highly constrained types - return AverageOutput(meta.map, - meta.offsets, - meta.decorr_chan_width, - meta.time, - meta.interval, - meta.chan_width, - meta.flag_row, - row_avg.antenna1, - row_avg.antenna2, - row_avg.time_centroid, - row_avg.exposure, - row_avg.uvw, - row_avg.weight, - row_avg.sigma, - # None, # chan_data.chan_freq, - # None, # chan_data.chan_width, - # None, # chan_data.effective_bw, - # None, # chan_data.resolution, - row_chan_avg.visibilities, - row_chan_avg.flag, - row_chan_avg.weight_spectrum, - row_chan_avg.sigma_spectrum) - - -BDA_DOCS = DocstringTemplate(""" + return AverageOutput( + meta.map, + meta.offsets, + meta.decorr_chan_width, + meta.time, + meta.interval, + meta.chan_width, + meta.flag_row, + row_avg.antenna1, + row_avg.antenna2, + row_avg.time_centroid, + row_avg.exposure, + row_avg.uvw, + row_avg.weight, + row_avg.sigma, + # None, # chan_data.chan_freq, + # None, # chan_data.chan_width, + # None, # chan_data.effective_bw, + # None, # chan_data.resolution, + row_chan_avg.visibilities, + row_chan_avg.flag, + row_chan_avg.weight_spectrum, + row_chan_avg.sigma_spectrum, + ) + + +BDA_DOCS = DocstringTemplate( + """ Averages in time and channel, dependent on baseline length. Parameters @@ -756,7 +921,8 @@ def nb_bda_impl(time, interval, antenna1, antenna2, A namedtuple whose entries correspond to the input arrays. Output arrays will be ``None`` if the inputs were ``None``. See the Notes for an explanation of the output formats. -""") +""" +) try: bda.__doc__ = BDA_DOCS.substitute(array_type=":class:`numpy.ndarray`") diff --git a/africanus/averaging/bda_mapping.py b/africanus/averaging/bda_mapping.py index a2abdbf7f..0b1fbe786 100644 --- a/africanus/averaging/bda_mapping.py +++ b/africanus/averaging/bda_mapping.py @@ -8,11 +8,7 @@ from numba import types from africanus.constants import c as lightspeed -from africanus.util.numba import ( - JIT_OPTIONS, - overload, - njit, - is_numba_type_none) +from africanus.util.numba import JIT_OPTIONS, overload, njit, is_numba_type_none from africanus.averaging.support import unique_time, unique_baselines @@ -20,13 +16,13 @@ class RowMapperError(Exception): pass -@njit(nogil=True, cache=True, inline='always') +@njit(nogil=True, cache=True, inline="always") def factors(n): assert n >= 1 result = [] i = 1 - while i*i <= n: + while i * i <= n: quot, rem = divmod(n, i) if rem == 0: @@ -40,7 +36,7 @@ def factors(n): return np.unique(np.array(result)) -@njit(nogil=True, cache=True, inline='always') +@njit(nogil=True, cache=True, inline="always") def max_chan_width(ref_freq, fractional_bandwidth): """ Derive max_𝞓𝝼, the maximum change in bandwidth @@ -58,15 +54,15 @@ def max_chan_width(ref_freq, fractional_bandwidth): return 2 * ref_freq * fractional_bandwidth -FinaliseOutput = namedtuple("FinaliseOutput", - ["tbin", "time", "interval", - "nchan", "flag"]) +FinaliseOutput = namedtuple( + "FinaliseOutput", ["tbin", "time", "interval", "nchan", "flag"] +) class Binner: - def __init__(self, row_start, row_end, - max_lm, decorrelation, time_bin_secs, - max_chan_freq): + def __init__( + self, row_start, row_end, max_lm, decorrelation, time_bin_secs, max_chan_freq + ): # Index of the time bin to which all rows in the bin will contribute self.tbin = 0 # Number of rows in the bin @@ -94,10 +90,14 @@ def __init__(self, row_start, row_end, self.time_bin_secs = time_bin_secs def reset(self): - self.__init__(0, 0, self.max_lm, - self.decorrelation, - self.time_bin_secs, - self.max_chan_freq) + self.__init__( + 0, + 0, + self.max_lm, + self.decorrelation, + self.time_bin_secs, + self.max_chan_freq, + ) def start_bin(self, row, time, interval, flag_row): """ @@ -106,8 +106,7 @@ def start_bin(self, row, time, interval, flag_row): self.rs = row self.re = row self.bin_count = 1 - self.bin_flag_count = (1 if flag_row is not None and flag_row[row] != 0 - else 0) + self.bin_flag_count = 1 if flag_row is not None and flag_row[row] != 0 else 0 def add_row(self, row, auto_corr, time, interval, uvw, flag_row): """ @@ -123,8 +122,9 @@ def add_row(self, row, auto_corr, time, interval, uvw, flag_row): re = self.re if re == row: - raise ValueError("start_bin should be called to start a bin " - "before add_row is called.") + raise ValueError( + "start_bin should be called to start a bin " "before add_row is called." + ) if auto_corr: # Fast path for auto-correlated baseline. @@ -147,10 +147,13 @@ def add_row(self, row, auto_corr, time, interval, uvw, flag_row): dv = uvw[row, 1] - uvw[rs, 1] dw = uvw[row, 2] - uvw[rs, 2] dt = time_end - time_start - half_𝞓𝞇 = (np.sqrt(du**2 + dv**2 + dw**2) * - self.max_chan_freq * - np.sin(np.abs(self.max_lm)) * - np.pi / lightspeed) + 1.0e-8 + half_𝞓𝞇 = ( + np.sqrt(du**2 + dv**2 + dw**2) + * self.max_chan_freq + * np.sin(np.abs(self.max_lm)) + * np.pi + / lightspeed + ) + 1.0e-8 bldecorr = np.sin(half_𝞓𝞇) / half_𝞓𝞇 # fringe rate at the equator @@ -168,8 +171,7 @@ def add_row(self, row, auto_corr, time, interval, uvw, flag_row): # Do not add the row to the bin as it # would exceed the decorrelation tolerance # or the required number of seconds in the bin - if (bldecorr < np.sinc(self.decorrelation) or - dt > self.time_bin_secs): + if bldecorr < np.sinc(self.decorrelation) or dt > self.time_bin_secs: return False # Add the row by making it the end of the bin @@ -187,18 +189,21 @@ def add_row(self, row, auto_corr, time, interval, uvw, flag_row): def empty(self): return self.bin_count == 0 - def finalise_bin(self, auto_corr, uvw, time, interval, - nchan_factors, chan_width, chan_freq): - """ Finalise the contents of this bin """ + def finalise_bin( + self, auto_corr, uvw, time, interval, nchan_factors, chan_width, chan_freq + ): + """Finalise the contents of this bin""" if self.bin_count == 0: raise ValueError("Attempted to finalise empty bin") elif self.bin_count == 1: # Single entry in the bin, no averaging occurs - out = FinaliseOutput(self.tbin, - time[self.rs], - interval[self.rs], - chan_width.size, - self.bin_count == self.bin_flag_count) + out = FinaliseOutput( + self.tbin, + time[self.rs], + interval[self.rs], + chan_width.size, + self.bin_count == self.bin_flag_count, + ) self.tbin += 1 @@ -221,8 +226,9 @@ def finalise_bin(self, auto_corr, uvw, time, interval, cuv = np.sqrt(cu**2 + cv**2) - max_abs_dist = np.sqrt(np.abs(cuv)*np.abs(self.max_lm) + - np.abs(cw)*np.abs(self.n_max)) + max_abs_dist = np.sqrt( + np.abs(cuv) * np.abs(self.max_lm) + np.abs(cw) * np.abs(self.n_max) + ) if max_abs_dist == 0.0: raise ValueError("max_abs_dist == 0.0") @@ -240,18 +246,17 @@ def finalise_bin(self, auto_corr, uvw, time, interval, # The following is copied from DDFacet. Variables names could # be changed but wanted to keep the correspondence clear. # BH: I strongly suspect this is wrong: see eq. 18-19 in SI II - delta_nu = (lightspeed / (2*np.pi)) * \ - (self.decorrelation / max_abs_dist) + delta_nu = (lightspeed / (2 * np.pi)) * (self.decorrelation / max_abs_dist) fracsizeChanBlock = delta_nu / chan_width fracsizeChanBlockMin = max(fracsizeChanBlock.min(), 1) assert fracsizeChanBlockMin >= 1 - nchan = np.ceil(chan_width.size/fracsizeChanBlockMin) + nchan = np.ceil(chan_width.size / fracsizeChanBlockMin) # Now find the next highest integer factorisation # of the input number of channels - s = np.searchsorted(nchan_factors, nchan, side='left') + s = np.searchsorted(nchan_factors, nchan, side="left") nchan = nchan_factors[min(nchan_factors.shape[0] - 1, s)] time_start = time[rs] - (interval[rs] / 2.0) @@ -259,72 +264,111 @@ def finalise_bin(self, auto_corr, uvw, time, interval, # Finalise bin values for return assert self.bin_count >= 1 - out = FinaliseOutput(self.tbin, - (time_start + time_end) / 2.0, - time_end - time_start, - nchan, - self.bin_count == self.bin_flag_count) + out = FinaliseOutput( + self.tbin, + (time_start + time_end) / 2.0, + time_end - time_start, + nchan, + self.bin_count == self.bin_flag_count, + ) self.tbin += 1 return out -RowMapOutput = namedtuple("RowMapOutput", - ["map", "offsets", "decorr_chan_width", - "time", "interval", "chan_width", "flag_row"]) +RowMapOutput = namedtuple( + "RowMapOutput", + [ + "map", + "offsets", + "decorr_chan_width", + "time", + "interval", + "chan_width", + "flag_row", + ], +) @njit(**JIT_OPTIONS) -def bda_mapper(time, interval, ant1, ant2, uvw, - chan_width, chan_freq, - max_uvw_dist, - flag_row=None, - max_fov=3.0, - decorrelation=0.98, - time_bin_secs=None, - min_nchan=1): - return bda_mapper_impl(time, interval, ant1, ant2, uvw, - chan_width, chan_freq, - max_uvw_dist, - flag_row=flag_row, - max_fov=max_fov, - decorrelation=decorrelation, - time_bin_secs=time_bin_secs, - min_nchan=min_nchan) - - -def bda_mapper_impl(time, interval, ant1, ant2, uvw, - chan_width, chan_freq, - max_uvw_dist, - flag_row=None, - max_fov=3.0, - decorrelation=0.98, - time_bin_secs=None, - min_nchan=1): +def bda_mapper( + time, + interval, + ant1, + ant2, + uvw, + chan_width, + chan_freq, + max_uvw_dist, + flag_row=None, + max_fov=3.0, + decorrelation=0.98, + time_bin_secs=None, + min_nchan=1, +): + return bda_mapper_impl( + time, + interval, + ant1, + ant2, + uvw, + chan_width, + chan_freq, + max_uvw_dist, + flag_row=flag_row, + max_fov=max_fov, + decorrelation=decorrelation, + time_bin_secs=time_bin_secs, + min_nchan=min_nchan, + ) + + +def bda_mapper_impl( + time, + interval, + ant1, + ant2, + uvw, + chan_width, + chan_freq, + max_uvw_dist, + flag_row=None, + max_fov=3.0, + decorrelation=0.98, + time_bin_secs=None, + min_nchan=1, +): return NotImplementedError @overload(bda_mapper_impl, jit_options={"nogil": True}) -def nb_bda_mapper(time, interval, ant1, ant2, uvw, - chan_width, chan_freq, - max_uvw_dist, - flag_row=None, - max_fov=3.0, - decorrelation=0.98, - time_bin_secs=None, - min_nchan=1): +def nb_bda_mapper( + time, + interval, + ant1, + ant2, + uvw, + chan_width, + chan_freq, + max_uvw_dist, + flag_row=None, + max_fov=3.0, + decorrelation=0.98, + time_bin_secs=None, + min_nchan=1, +): have_time_bin_secs = not is_numba_type_none(time_bin_secs) Omitted = types.misc.Omitted - decorr_type = (numba.typeof(decorrelation.value) - if isinstance(decorrelation, Omitted) - else decorrelation) + decorr_type = ( + numba.typeof(decorrelation.value) + if isinstance(decorrelation, Omitted) + else decorrelation + ) - fov_type = (numba.typeof(max_fov.value) - if isinstance(max_fov, Omitted) - else max_fov) + fov_type = numba.typeof(max_fov.value) if isinstance(max_fov, Omitted) else max_fov # If time_bin_secs is None, # then we set it to the max of the time dtype @@ -332,31 +376,39 @@ def nb_bda_mapper(time, interval, ant1, ant2, uvw, time_bin_secs_type = time_bin_secs if have_time_bin_secs else time.dtype spec = [ - ('tbin', numba.uintp), - ('bin_count', numba.uintp), - ('bin_flag_count', numba.uintp), - ('time_sum', time.dtype), - ('interval_sum', interval.dtype), - ('rs', numba.uintp), - ('re', numba.uintp), - ('bin_half_Δψ', uvw.dtype), - ('max_lm', fov_type), - ('n_max', fov_type), - ('decorrelation', decorr_type), - ('time_bin_secs', time_bin_secs_type), - ('max_chan_freq', chan_freq.dtype), - ('max_uvw_dist', max_uvw_dist)] + ("tbin", numba.uintp), + ("bin_count", numba.uintp), + ("bin_flag_count", numba.uintp), + ("time_sum", time.dtype), + ("interval_sum", interval.dtype), + ("rs", numba.uintp), + ("re", numba.uintp), + ("bin_half_Δψ", uvw.dtype), + ("max_lm", fov_type), + ("n_max", fov_type), + ("decorrelation", decorr_type), + ("time_bin_secs", time_bin_secs_type), + ("max_chan_freq", chan_freq.dtype), + ("max_uvw_dist", max_uvw_dist), + ] JitBinner = jitclass(spec)(Binner) - def impl(time, interval, ant1, ant2, uvw, - chan_width, chan_freq, - max_uvw_dist, - flag_row=None, - max_fov=3.0, - decorrelation=0.98, - time_bin_secs=None, - min_nchan=1): + def impl( + time, + interval, + ant1, + ant2, + uvw, + chan_width, + chan_freq, + max_uvw_dist, + flag_row=None, + max_fov=3.0, + decorrelation=0.98, + time_bin_secs=None, + min_nchan=1, + ): # 𝞓 𝝿 𝞇 𝞍 𝝼 if decorrelation < 0.0 or decorrelation > 1.0: @@ -375,8 +427,9 @@ def impl(time, interval, ant1, ant2, uvw, nbl = ubl.shape[0] nchan = chan_width.shape[0] if nchan == 0: - raise ValueError("Number of channels passed into " - "averager must be at least size 1") + raise ValueError( + "Number of channels passed into " "averager must be at least size 1" + ) nchan_factors = factors(nchan) bandwidth = chan_width.sum() @@ -384,7 +437,7 @@ def impl(time, interval, ant1, ant2, uvw, min_nchan = 1 else: min_nchan = min(min_nchan, nchan) - s = np.searchsorted(nchan_factors, min_nchan, side='left') + s = np.searchsorted(nchan_factors, min_nchan, side="left") min_nchan = max(min_nchan, nchan_factors[s]) if nchan == 0: @@ -453,12 +506,9 @@ def update_lookups(finalised, bl): # dphi = np.sqrt(6. / np.pi**2 * (1. - decorrelation)) # better approximation - dphi = np.arccos(decorrelation)*np.sqrt(3)/np.pi + dphi = np.arccos(decorrelation) * np.sqrt(3) / np.pi - binner = JitBinner(0, 0, max_lm, - dphi, - time_bin_secs, - chan_freq.max()) + binner = JitBinner(0, 0, max_lm, dphi, time_bin_secs, chan_freq.max()) for bl in range(nbl): # Reset the binner for this baseline @@ -480,12 +530,16 @@ def update_lookups(finalised, bl): # Try add the row to the bin # If this fails, finalise the current bin and start a new one - elif not binner.add_row(r, auto_corr, - time, interval, - uvw, flag_row): - f = binner.finalise_bin(auto_corr, uvw, time, interval, - nchan_factors, - chan_width, chan_freq) + elif not binner.add_row(r, auto_corr, time, interval, uvw, flag_row): + f = binner.finalise_bin( + auto_corr, + uvw, + time, + interval, + nchan_factors, + chan_width, + chan_freq, + ) update_lookups(f, bl) # Post-finalisation, the bin is empty, start a new bin binner.start_bin(r, time, interval, flag_row) @@ -495,8 +549,9 @@ def update_lookups(finalised, bl): # Finalise any remaining data in the bin if not binner.empty: - f = binner.finalise_bin(auto_corr, uvw, time, interval, - nchan_factors, chan_width, chan_freq) + f = binner.finalise_bin( + auto_corr, uvw, time, interval, nchan_factors, chan_width, chan_freq + ) update_lookups(f, bl) nr_of_time_bins += binner.tbin @@ -510,7 +565,7 @@ def update_lookups(finalised, bl): # Flatten the time lookup and argsort it flat_time = time_lookup.ravel() - argsort = np.argsort(flat_time, kind='mergesort') + argsort = np.argsort(flat_time, kind="mergesort") inv_argsort = np.empty_like(argsort) # Generate lookup from flattened (bl, time) to output row @@ -539,8 +594,9 @@ def update_lookups(finalised, bl): chan_width_ret = np.full(out_row_chans, 0, dtype=chan_width.dtype) # Construct output flag row, if necessary - out_flag_row = (None if flag_row is None else - np.empty(out_row_chans, dtype=flag_row.dtype)) + out_flag_row = ( + None if flag_row is None else np.empty(out_row_chans, dtype=flag_row.dtype) + ) # foreach input row for in_row in range(time.shape[0]): @@ -553,7 +609,7 @@ def update_lookups(finalised, bl): bin_time = time_lookup[bl, tbin] bin_interval = interval_lookup[bl, tbin] flagged = bin_flagged[bl, tbin] - out_row = inv_argsort[bl*ntime + tbin] + out_row = inv_argsort[bl * ntime + tbin] decorr_chan_width[out_row] = bin_chan_width[bl, tbin] @@ -563,10 +619,12 @@ def update_lookups(finalised, bl): # Handle output row flagging if flag_row is not None and flag_row[in_row] == 0 and flagged: - raise RowMapperError("Unflagged input row " - "contributing to " - "flagged output row. " - "This should never happen!") + raise RowMapperError( + "Unflagged input row " + "contributing to " + "flagged output row. " + "This should never happen!" + ) # Set up the row channel map, populate # time, interval and chan_width @@ -590,9 +648,14 @@ def update_lookups(finalised, bl): if flag_row is not None: out_flag_row[out_offset] = 1 if flagged else 0 - return RowMapOutput(row_chan_map, offsets, - decorr_chan_width, - time_ret, int_ret, - chan_width_ret, out_flag_row) + return RowMapOutput( + row_chan_map, + offsets, + decorr_chan_width, + time_ret, + int_ret, + chan_width_ret, + out_flag_row, + ) return impl diff --git a/africanus/averaging/dask.py b/africanus/averaging/dask.py index 6f8937b0a..c2283444d 100644 --- a/africanus/averaging/dask.py +++ b/africanus/averaging/dask.py @@ -3,28 +3,30 @@ from operator import getitem -from africanus.averaging.bda_mapping import ( - bda_mapper as np_bda_mapper) +from africanus.averaging.bda_mapping import bda_mapper as np_bda_mapper from africanus.averaging.bda_avg import ( - BDA_DOCS, - row_average as np_bda_row_avg, - row_chan_average as np_bda_row_chan_avg, - AverageOutput as BDAAverageOutput, - RowAverageOutput as BDARowAverageOutput, - RowChanAverageOutput as BDARowChanAverageOutput) + BDA_DOCS, + row_average as np_bda_row_avg, + row_chan_average as np_bda_row_chan_avg, + AverageOutput as BDAAverageOutput, + RowAverageOutput as BDARowAverageOutput, + RowChanAverageOutput as BDARowChanAverageOutput, +) from africanus.averaging.time_and_channel_mapping import ( - row_mapper as np_tc_row_mapper, - channel_mapper as np_tc_channel_mapper) + row_mapper as np_tc_row_mapper, + channel_mapper as np_tc_channel_mapper, +) from africanus.averaging.time_and_channel_avg import ( - row_average as np_tc_row_average, - row_chan_average as np_tc_row_chan_average, - chan_average as np_tc_chan_average, - merge_flags as np_merge_flags, - AVERAGING_DOCS as TC_AVERAGING_DOCS, - AverageOutput as TcAverageOutput, - ChannelAverageOutput as TcChannelAverageOutput, - RowAverageOutput as TcRowAverageOutput, - RowChanAverageOutput as TcRowChanAverageOutput) + row_average as np_tc_row_average, + row_chan_average as np_tc_row_chan_average, + chan_average as np_tc_chan_average, + merge_flags as np_merge_flags, + AVERAGING_DOCS as TC_AVERAGING_DOCS, + AverageOutput as TcAverageOutput, + ChannelAverageOutput as TcChannelAverageOutput, + RowAverageOutput as TcRowAverageOutput, + RowChanAverageOutput as TcRowChanAverageOutput, +) from africanus.util.requirements import requires_optional @@ -42,7 +44,7 @@ def tc_chan_metadata(row_chan_arrays, chan_arrays, chan_bin_size): - """ Create dask array with channel metadata for each chunk channel """ + """Create dask array with channel metadata for each chunk channel""" chan_chunks = None for array in row_chan_arrays: @@ -67,8 +69,10 @@ def tc_chan_metadata(row_chan_arrays, chan_arrays, chan_bin_size): # Create a dask channel mapping structure name = "channel-mapper-" + tokenize(chan_chunks, chan_bin_size) - layers = {(name, i): (np_tc_channel_mapper, c, chan_bin_size) - for i, c in enumerate(chan_chunks)} + layers = { + (name, i): (np_tc_channel_mapper, c, chan_bin_size) + for i, c in enumerate(chan_chunks) + } graph = HighLevelGraph.from_collections(name, layers, ()) chunks = (chan_chunks,) chan_mapper = da.Array(graph, name, chunks, dtype=object) @@ -76,108 +80,163 @@ def tc_chan_metadata(row_chan_arrays, chan_arrays, chan_bin_size): return chan_mapper -def tc_row_mapper(time, interval, antenna1, antenna2, - flag_row=None, time_bin_secs=1.0): - """ Create a dask row mapping structure for each row chunk """ - return da.blockwise(np_tc_row_mapper, ("row",), - time, ("row",), - interval, ("row",), - antenna1, ("row",), - antenna2, ("row",), - flag_row, None if flag_row is None else ("row",), - adjust_chunks={"row": lambda x: np.nan}, - time_bin_secs=time_bin_secs, - meta=np.empty((0,), dtype=object), - dtype=object) +def tc_row_mapper(time, interval, antenna1, antenna2, flag_row=None, time_bin_secs=1.0): + """Create a dask row mapping structure for each row chunk""" + return da.blockwise( + np_tc_row_mapper, + ("row",), + time, + ("row",), + interval, + ("row",), + antenna1, + ("row",), + antenna2, + ("row",), + flag_row, + None if flag_row is None else ("row",), + adjust_chunks={"row": lambda x: np.nan}, + time_bin_secs=time_bin_secs, + meta=np.empty((0,), dtype=object), + dtype=object, + ) def _getitem_row(avg, idx, array, dims): - """ Extract row-like arrays from a dask array of tuples """ + """Extract row-like arrays from a dask array of tuples""" assert dims[0] == "row" name = ("row-average-getitem-%d-" % idx) + tokenize(avg, idx) - layers = db.blockwise(getitem, name, dims, - avg.name, ("row",), - idx, None, - new_axes=dict(zip(dims[1:], array.shape[1:])), - numblocks={avg.name: avg.numblocks}) + layers = db.blockwise( + getitem, + name, + dims, + avg.name, + ("row",), + idx, + None, + new_axes=dict(zip(dims[1:], array.shape[1:])), + numblocks={avg.name: avg.numblocks}, + ) graph = HighLevelGraph.from_collections(name, layers, (avg,)) chunks = avg.chunks + tuple((s,) for s in array.shape[1:]) - return da.Array(graph, name, chunks, - meta=np.empty((0,)*len(dims), dtype=array.dtype), - dtype=array.dtype) - - -def _tc_row_average_wrapper(row_meta, ant1, ant2, flag_row, - time_centroid, exposure, uvw, - weight, sigma): - return np_tc_row_average(row_meta, ant1, ant2, flag_row, - time_centroid, exposure, - uvw[0] if uvw is not None else None, - weight[0] if weight is not None else None, - sigma[0] if sigma is not None else None) - - -def tc_row_average(row_meta, ant1, ant2, flag_row=None, - time_centroid=None, exposure=None, uvw=None, - weight=None, sigma=None): - """ Average row-based dask arrays """ + return da.Array( + graph, + name, + chunks, + meta=np.empty((0,) * len(dims), dtype=array.dtype), + dtype=array.dtype, + ) + + +def _tc_row_average_wrapper( + row_meta, ant1, ant2, flag_row, time_centroid, exposure, uvw, weight, sigma +): + return np_tc_row_average( + row_meta, + ant1, + ant2, + flag_row, + time_centroid, + exposure, + uvw[0] if uvw is not None else None, + weight[0] if weight is not None else None, + sigma[0] if sigma is not None else None, + ) + + +def tc_row_average( + row_meta, + ant1, + ant2, + flag_row=None, + time_centroid=None, + exposure=None, + uvw=None, + weight=None, + sigma=None, +): + """Average row-based dask arrays""" rd = ("row",) rcd = ("row", "corr") # (output, array, dims) - args = [(False, row_meta, rd), - (True, ant1, rd), - (True, ant2, rd), - (False, flag_row, None if flag_row is None else rd), - (True, time_centroid, None if time_centroid is None else rd), - (True, exposure, None if exposure is None else rd), - (True, uvw, None if uvw is None else ("row", "uvw")), - (True, weight, None if weight is None else rcd), - (True, sigma, None if sigma is None else rcd)] - - avg = da.blockwise(_tc_row_average_wrapper, rd, - *(v for pair in args for v in pair[1:]), - align_arrays=False, - adjust_chunks={"row": lambda x: np.nan}, - meta=np.empty((0,)*len(rd), dtype=object), - dtype=object) + args = [ + (False, row_meta, rd), + (True, ant1, rd), + (True, ant2, rd), + (False, flag_row, None if flag_row is None else rd), + (True, time_centroid, None if time_centroid is None else rd), + (True, exposure, None if exposure is None else rd), + (True, uvw, None if uvw is None else ("row", "uvw")), + (True, weight, None if weight is None else rcd), + (True, sigma, None if sigma is None else rcd), + ] + + avg = da.blockwise( + _tc_row_average_wrapper, + rd, + *(v for pair in args for v in pair[1:]), + align_arrays=False, + adjust_chunks={"row": lambda x: np.nan}, + meta=np.empty((0,) * len(rd), dtype=object), + dtype=object, + ) # ant1, ant2, time_centroid, exposure, uvw, weight, sigma out_args = [(a, dims) for out, a, dims in args if out is True] - tuple_gets = [None if a is None else _getitem_row(avg, i, a, dims) - for i, (a, dims) in enumerate(out_args)] + tuple_gets = [ + None if a is None else _getitem_row(avg, i, a, dims) + for i, (a, dims) in enumerate(out_args) + ] return TcRowAverageOutput(*tuple_gets) def _getitem_row_chan(avg, idx, dtype): - """ Extract (row,chan,corr) arrays from dask array of tuples """ + """Extract (row,chan,corr) arrays from dask array of tuples""" name = ("row-chan-average-getitem-%d-" % idx) + tokenize(avg, idx) dim = ("row", "chan", "corr") - layers = db.blockwise(getitem, name, dim, - avg.name, dim, - idx, None, - numblocks={avg.name: avg.numblocks}) + layers = db.blockwise( + getitem, + name, + dim, + avg.name, + dim, + idx, + None, + numblocks={avg.name: avg.numblocks}, + ) graph = HighLevelGraph.from_collections(name, layers, (avg,)) - return da.Array(graph, name, avg.chunks, - meta=np.empty((0,)*len(dim), dtype=object), - dtype=dtype) + return da.Array( + graph, + name, + avg.chunks, + meta=np.empty((0,) * len(dim), dtype=object), + dtype=dtype, + ) _row_chan_avg_dims = ("row", "chan", "corr") -def tc_row_chan_average(row_meta, chan_meta, flag_row=None, weight=None, - visibilities=None, flag=None, - weight_spectrum=None, sigma_spectrum=None, - chan_bin_size=1): - """ Average (row,chan,corr)-based dask arrays """ +def tc_row_chan_average( + row_meta, + chan_meta, + flag_row=None, + weight=None, + visibilities=None, + flag=None, + weight_spectrum=None, + sigma_spectrum=None, + chan_bin_size=1, +): + """Average (row,chan,corr)-based dask arrays""" if chan_meta is None: return TcRowChanAverageOutput(None, None, None, None) @@ -186,7 +245,7 @@ def tc_row_chan_average(row_meta, chan_meta, flag_row=None, weight=None, # but we can simply divide each channel chunk size by the bin size adjust_chunks = { "row": lambda r: np.nan, - "chan": lambda c: (c + chan_bin_size - 1) // chan_bin_size + "chan": lambda c: (c + chan_bin_size - 1) // chan_bin_size, } flag_row_dims = None if flag_row is None else ("row",) @@ -203,37 +262,48 @@ def tc_row_chan_average(row_meta, chan_meta, flag_row=None, weight=None, # convert them into an array of tuples of visibilities if isinstance(visibilities, (tuple, list)): if not all(isinstance(a, da.Array) for a in visibilities): - raise ValueError("Visibility tuple must exclusively " - "contain dask arrays") + raise ValueError("Visibility tuple must exclusively " "contain dask arrays") have_vis_tuple = True nvis_elements = len(visibilities) meta = np.empty((0, 0, 0), visibilities[0].dtype) - visibilities = da.blockwise(lambda *a: a, _row_chan_avg_dims, - *[elem for a in visibilities - for elem in (a, _row_chan_avg_dims)], - meta=meta) - - avg = da.blockwise(np_tc_row_chan_average, _row_chan_avg_dims, - row_meta, ("row",), - chan_meta, ("chan",), - flag_row, flag_row_dims, - weight, weight_dims, - visibilities, vis_dims, - flag, flag_dims, - weight_spectrum, ws_dims, - sigma_spectrum, ss_dims, - align_arrays=False, - adjust_chunks=adjust_chunks, - meta=np.empty((0,)*len(_row_chan_avg_dims), - dtype=object), - dtype=object) - - tuple_gets = (None if a is None else _getitem_row_chan(avg, i, a.dtype) - for i, a in enumerate([visibilities, flag, - weight_spectrum, - sigma_spectrum])) + visibilities = da.blockwise( + lambda *a: a, + _row_chan_avg_dims, + *[elem for a in visibilities for elem in (a, _row_chan_avg_dims)], + meta=meta, + ) + + avg = da.blockwise( + np_tc_row_chan_average, + _row_chan_avg_dims, + row_meta, + ("row",), + chan_meta, + ("chan",), + flag_row, + flag_row_dims, + weight, + weight_dims, + visibilities, + vis_dims, + flag, + flag_dims, + weight_spectrum, + ws_dims, + sigma_spectrum, + ss_dims, + align_arrays=False, + adjust_chunks=adjust_chunks, + meta=np.empty((0,) * len(_row_chan_avg_dims), dtype=object), + dtype=object, + ) + + tuple_gets = ( + None if a is None else _getitem_row_chan(avg, i, a.dtype) + for i, a in enumerate([visibilities, flag, weight_spectrum, sigma_spectrum]) + ) # If we received an array of tuples of visibilities # convert them into a tuple of visibility arrays @@ -243,10 +313,15 @@ def tc_row_chan_average(row_meta, chan_meta, flag_row=None, weight=None, tuple_vis = [] for v in range(nvis_elements): - v = da.blockwise(getitem, _row_chan_avg_dims, - vis_tuple, _row_chan_avg_dims, - v, None, - dtype=vis_tuple.dtype) + v = da.blockwise( + getitem, + _row_chan_avg_dims, + vis_tuple, + _row_chan_avg_dims, + v, + None, + dtype=vis_tuple.dtype, + ) tuple_vis.append(v) tuple_gets = (tuple(tuple_vis),) + tuple_gets[1:] @@ -255,80 +330,120 @@ def tc_row_chan_average(row_meta, chan_meta, flag_row=None, weight=None, def _getitem_chan(avg, idx, dtype): - """ Extract row-like arrays from a dask array of tuples """ + """Extract row-like arrays from a dask array of tuples""" name = ("chan-average-getitem-%d-" % idx) + tokenize(avg, idx) - layers = db.blockwise(getitem, name, ("chan",), - avg.name, ("chan",), - idx, None, - numblocks={avg.name: avg.numblocks}) + layers = db.blockwise( + getitem, + name, + ("chan",), + avg.name, + ("chan",), + idx, + None, + numblocks={avg.name: avg.numblocks}, + ) graph = HighLevelGraph.from_collections(name, layers, (avg,)) - return da.Array(graph, name, avg.chunks, - meta=np.empty((0,), dtype=dtype), - dtype=dtype) - - -def tc_chan_average(chan_meta, chan_freq=None, chan_width=None, - effective_bw=None, resolution=None, chan_bin_size=1): - + return da.Array( + graph, name, avg.chunks, meta=np.empty((0,), dtype=dtype), dtype=dtype + ) + + +def tc_chan_average( + chan_meta, + chan_freq=None, + chan_width=None, + effective_bw=None, + resolution=None, + chan_bin_size=1, +): if chan_meta is None: return TcChannelAverageOutput(None, None) - adjust_chunks = { - "chan": lambda c: (c + chan_bin_size - 1) // chan_bin_size - } + adjust_chunks = {"chan": lambda c: (c + chan_bin_size - 1) // chan_bin_size} cdim = ("chan",) - avg = da.blockwise(np_tc_chan_average, cdim, - chan_meta, cdim, - chan_freq, None if chan_freq is None else cdim, - chan_width, None if chan_width is None else cdim, - effective_bw, None if effective_bw is None else cdim, - resolution, None if resolution is None else cdim, - adjust_chunks=adjust_chunks, - meta=np.empty((0,), dtype=object), - dtype=object) - - tuple_gets = (None if a is None else _getitem_chan(avg, i, a.dtype) - for i, a in enumerate([chan_freq, chan_width, - effective_bw, resolution])) + avg = da.blockwise( + np_tc_chan_average, + cdim, + chan_meta, + cdim, + chan_freq, + None if chan_freq is None else cdim, + chan_width, + None if chan_width is None else cdim, + effective_bw, + None if effective_bw is None else cdim, + resolution, + None if resolution is None else cdim, + adjust_chunks=adjust_chunks, + meta=np.empty((0,), dtype=object), + dtype=object, + ) + + tuple_gets = ( + None if a is None else _getitem_chan(avg, i, a.dtype) + for i, a in enumerate([chan_freq, chan_width, effective_bw, resolution]) + ) return TcChannelAverageOutput(*tuple_gets) def merge_flags(flag_row, flag): - """ Perform flag merging on dask arrays """ + """Perform flag merging on dask arrays""" if flag_row is None and flag is not None: - return da.blockwise(np_merge_flags, "r", - flag_row, None, - flag, "rfc", - concatenate=True, - dtype=flag.dtype) + return da.blockwise( + np_merge_flags, + "r", + flag_row, + None, + flag, + "rfc", + concatenate=True, + dtype=flag.dtype, + ) elif flag_row is not None and flag is None: - return da.blockwise(np_merge_flags, "r", - flag_row, "r", - None, None, - dtype=flag_row.dtype) + return da.blockwise( + np_merge_flags, "r", flag_row, "r", None, None, dtype=flag_row.dtype + ) elif flag_row is not None and flag is not None: - return da.blockwise(np_merge_flags, "r", - flag_row, "r", - flag, "rfc", - concatenate=True, - dtype=flag_row.dtype) + return da.blockwise( + np_merge_flags, + "r", + flag_row, + "r", + flag, + "rfc", + concatenate=True, + dtype=flag_row.dtype, + ) else: return None @requires_optional("dask.array", dask_import_error) -def time_and_channel(time, interval, antenna1, antenna2, - time_centroid=None, exposure=None, flag_row=None, - uvw=None, weight=None, sigma=None, - chan_freq=None, chan_width=None, - effective_bw=None, resolution=None, - visibilities=None, flag=None, - weight_spectrum=None, sigma_spectrum=None, - time_bin_secs=1.0, chan_bin_size=1): - +def time_and_channel( + time, + interval, + antenna1, + antenna2, + time_centroid=None, + exposure=None, + flag_row=None, + uvw=None, + weight=None, + sigma=None, + chan_freq=None, + chan_width=None, + effective_bw=None, + resolution=None, + visibilities=None, + flag=None, + weight_spectrum=None, + sigma_spectrum=None, + time_bin_secs=1.0, + chan_bin_size=1, +): row_chan_arrays = (visibilities, flag, weight_spectrum, sigma_spectrum) chan_arrays = (chan_freq, chan_width, effective_bw, resolution) @@ -339,109 +454,171 @@ def time_and_channel(time, interval, antenna1, antenna2, flag_row = merge_flags(flag_row, flag) # Generate row mapping metadata - row_meta = tc_row_mapper(time, interval, - antenna1, antenna2, - flag_row=flag_row, - time_bin_secs=time_bin_secs) + row_meta = tc_row_mapper( + time, + interval, + antenna1, + antenna2, + flag_row=flag_row, + time_bin_secs=time_bin_secs, + ) # Generate channel mapping metadata chan_meta = tc_chan_metadata(row_chan_arrays, chan_arrays, chan_bin_size) # Average row data - row_data = tc_row_average(row_meta, antenna1, antenna2, - flag_row=flag_row, - time_centroid=time_centroid, - exposure=exposure, uvw=uvw, - weight=weight, sigma=sigma) + row_data = tc_row_average( + row_meta, + antenna1, + antenna2, + flag_row=flag_row, + time_centroid=time_centroid, + exposure=exposure, + uvw=uvw, + weight=weight, + sigma=sigma, + ) # Average channel data - row_chan_data = tc_row_chan_average(row_meta, chan_meta, - flag_row=flag_row, weight=weight, - visibilities=visibilities, flag=flag, - weight_spectrum=weight_spectrum, - sigma_spectrum=sigma_spectrum, - chan_bin_size=chan_bin_size) - - chan_data = tc_chan_average(chan_meta, - chan_freq=chan_freq, - chan_width=chan_width, - effective_bw=effective_bw, - resolution=resolution) + row_chan_data = tc_row_chan_average( + row_meta, + chan_meta, + flag_row=flag_row, + weight=weight, + visibilities=visibilities, + flag=flag, + weight_spectrum=weight_spectrum, + sigma_spectrum=sigma_spectrum, + chan_bin_size=chan_bin_size, + ) + + chan_data = tc_chan_average( + chan_meta, + chan_freq=chan_freq, + chan_width=chan_width, + effective_bw=effective_bw, + resolution=resolution, + ) # Merge output tuples - return TcAverageOutput(_getitem_row(row_meta, 1, time, ("row",)), - _getitem_row(row_meta, 2, interval, ("row",)), - (_getitem_row(row_meta, 3, flag_row, ("row",)) - if flag_row is not None else None), - row_data.antenna1, - row_data.antenna2, - row_data.time_centroid, - row_data.exposure, - row_data.uvw, - row_data.weight, - row_data.sigma, - chan_data.chan_freq, - chan_data.chan_width, - chan_data.effective_bw, - chan_data.resolution, - row_chan_data.visibilities, - row_chan_data.flag, - row_chan_data.weight_spectrum, - row_chan_data.sigma_spectrum) - - -def _bda_mapper_wrapper(time, interval, ant1, ant2, - uvw, chan_width, chan_freq, - max_uvw_dist, flag_row, - max_fov=None, - decorrelation=None, - time_bin_secs=None, - min_nchan=None): - return np_bda_mapper(time, interval, ant1, ant2, - None if uvw is None else uvw[0], - chan_width[0], chan_freq[0], - max_uvw_dist=max_uvw_dist, - flag_row=flag_row, - max_fov=max_fov, - decorrelation=decorrelation, - time_bin_secs=time_bin_secs, - min_nchan=min_nchan) - - -def bda_mapper(time, interval, antenna1, antenna2, uvw, - chan_width, chan_freq, - max_uvw_dist, - flag_row=None, max_fov=None, - decorrelation=None, - time_bin_secs=None, - min_nchan=None): - """ Createask row mapping structure for each row chunk """ - return da.blockwise(_bda_mapper_wrapper, ("row",), - time, ("row",), - interval, ("row",), - antenna1, ("row",), - antenna2, ("row",), - uvw, ("row", "uvw"), - chan_width, ("chan",), - chan_freq, ("chan",), - max_uvw_dist, None if max_uvw_dist is None else (), - flag_row, None if flag_row is None else ("row",), - max_fov=max_fov, - decorrelation=decorrelation, - time_bin_secs=time_bin_secs, - min_nchan=min_nchan, - adjust_chunks={"row": lambda x: np.nan}, - meta=np.empty((0, 0), dtype=object)) - - -def _bda_row_average_wrapper(meta, ant1, ant2, flag_row, - time_centroid, exposure, uvw, - weight, sigma): - return np_bda_row_avg(meta, ant1, ant2, flag_row, - time_centroid, exposure, - None if uvw is None else uvw[0], - None if weight is None else weight[0], - None if sigma is None else sigma[0]) + return TcAverageOutput( + _getitem_row(row_meta, 1, time, ("row",)), + _getitem_row(row_meta, 2, interval, ("row",)), + ( + _getitem_row(row_meta, 3, flag_row, ("row",)) + if flag_row is not None + else None + ), + row_data.antenna1, + row_data.antenna2, + row_data.time_centroid, + row_data.exposure, + row_data.uvw, + row_data.weight, + row_data.sigma, + chan_data.chan_freq, + chan_data.chan_width, + chan_data.effective_bw, + chan_data.resolution, + row_chan_data.visibilities, + row_chan_data.flag, + row_chan_data.weight_spectrum, + row_chan_data.sigma_spectrum, + ) + + +def _bda_mapper_wrapper( + time, + interval, + ant1, + ant2, + uvw, + chan_width, + chan_freq, + max_uvw_dist, + flag_row, + max_fov=None, + decorrelation=None, + time_bin_secs=None, + min_nchan=None, +): + return np_bda_mapper( + time, + interval, + ant1, + ant2, + None if uvw is None else uvw[0], + chan_width[0], + chan_freq[0], + max_uvw_dist=max_uvw_dist, + flag_row=flag_row, + max_fov=max_fov, + decorrelation=decorrelation, + time_bin_secs=time_bin_secs, + min_nchan=min_nchan, + ) + + +def bda_mapper( + time, + interval, + antenna1, + antenna2, + uvw, + chan_width, + chan_freq, + max_uvw_dist, + flag_row=None, + max_fov=None, + decorrelation=None, + time_bin_secs=None, + min_nchan=None, +): + """Createask row mapping structure for each row chunk""" + return da.blockwise( + _bda_mapper_wrapper, + ("row",), + time, + ("row",), + interval, + ("row",), + antenna1, + ("row",), + antenna2, + ("row",), + uvw, + ("row", "uvw"), + chan_width, + ("chan",), + chan_freq, + ("chan",), + max_uvw_dist, + None if max_uvw_dist is None else (), + flag_row, + None if flag_row is None else ("row",), + max_fov=max_fov, + decorrelation=decorrelation, + time_bin_secs=time_bin_secs, + min_nchan=min_nchan, + adjust_chunks={"row": lambda x: np.nan}, + meta=np.empty((0, 0), dtype=object), + ) + + +def _bda_row_average_wrapper( + meta, ant1, ant2, flag_row, time_centroid, exposure, uvw, weight, sigma +): + return np_bda_row_avg( + meta, + ant1, + ant2, + flag_row, + time_centroid, + exposure, + None if uvw is None else uvw[0], + None if weight is None else weight[0], + None if sigma is None else sigma[0], + ) def _ragged_row_getitem(avg, idx, meta): @@ -449,7 +626,7 @@ def _ragged_row_getitem(avg, idx, meta): def _bda_getitem_row(avg, idx, array, dims, meta, format="flat"): - """ Extract row-like arrays from a dask array of tuples """ + """Extract row-like arrays from a dask array of tuples""" assert dims[0] == "row" name = "row-average-getitem-%s-" % idx @@ -458,26 +635,39 @@ def _bda_getitem_row(avg, idx, array, dims, meta, format="flat"): numblocks = {avg.name: avg.numblocks} if format == "flat": - layers = db.blockwise(getitem, name, dims, - avg.name, ("row",), - idx, None, - new_axes=new_axes, - numblocks=numblocks) + layers = db.blockwise( + getitem, + name, + dims, + avg.name, + ("row",), + idx, + None, + new_axes=new_axes, + numblocks=numblocks, + ) elif format == "ragged": numblocks[meta.name] = meta.numblocks - layers = db.blockwise(_ragged_row_getitem, name, dims, - avg.name, ("row",), - idx, None, - meta.name, ("row",), - new_axes=new_axes, - numblocks=numblocks) + layers = db.blockwise( + _ragged_row_getitem, + name, + dims, + avg.name, + ("row",), + idx, + None, + meta.name, + ("row",), + new_axes=new_axes, + numblocks=numblocks, + ) else: raise ValueError("Invalid format %s" % format) graph = HighLevelGraph.from_collections(name, layers, (avg,)) chunks = avg.chunks + tuple((s,) for s in array.shape[1:]) - meta = np.empty((0,)*len(dims), dtype=array.dtype) + meta = np.empty((0,) * len(dims), dtype=array.dtype) return da.Array(graph, name, chunks, meta=meta) @@ -486,18 +676,22 @@ def _ragged_row_chan_getitem(avg, idx, meta): data = avg[idx] if isinstance(data, tuple): - return tuple({"r%d" % (r+1): d[None, s:e, ...] - for r, (s, e) - in enumerate(zip(meta.offsets[:-1], meta.offsets[1:]))} - for d in data) - - return {"r%d" % (r+1): data[None, s:e, ...] - for r, (s, e) - in enumerate(zip(meta.offsets[:-1], meta.offsets[1:]))} + return tuple( + { + "r%d" % (r + 1): d[None, s:e, ...] + for r, (s, e) in enumerate(zip(meta.offsets[:-1], meta.offsets[1:])) + } + for d in data + ) + + return { + "r%d" % (r + 1): data[None, s:e, ...] + for r, (s, e) in enumerate(zip(meta.offsets[:-1], meta.offsets[1:])) + } def _bda_getitem_row_chan(avg, idx, dtype, format, avg_meta, nchan): - """ Extract (row, corr) arrays from dask array of tuples """ + """Extract (row, corr) arrays from dask array of tuples""" f = BDARowChanAverageOutput._fields[idx] name = "row-chan-average-getitem-%s-%s-" % (f, format) name += tokenize(avg, idx) @@ -506,10 +700,16 @@ def _bda_getitem_row_chan(avg, idx, dtype, format, avg_meta, nchan): dims = ("row", "corr") new_axes = None - layers = db.blockwise(getitem, name, dims, - avg.name, ("row", "corr"), - idx, None, - numblocks={avg.name: avg.numblocks}) + layers = db.blockwise( + getitem, + name, + dims, + avg.name, + ("row", "corr"), + idx, + None, + numblocks={avg.name: avg.numblocks}, + ) chunks = avg.chunks meta = np.empty((0, 0), dtype=object) @@ -517,14 +717,19 @@ def _bda_getitem_row_chan(avg, idx, dtype, format, avg_meta, nchan): dims = ("row", "chan", "corr") new_axes = {"chan": nchan} - layers = db.blockwise(_ragged_row_chan_getitem, name, dims, - avg.name, ("row", "corr"), - idx, None, - avg_meta.name, ("row",), - new_axes=new_axes, - numblocks={ - avg.name: avg.numblocks, - avg_meta.name: avg_meta.numblocks}) + layers = db.blockwise( + _ragged_row_chan_getitem, + name, + dims, + avg.name, + ("row", "corr"), + idx, + None, + avg_meta.name, + ("row",), + new_axes=new_axes, + numblocks={avg.name: avg.numblocks, avg_meta.name: avg_meta.numblocks}, + ) chunks = (avg.chunks[0], (nchan,), avg.chunks[1]) meta = np.empty((0, 0, 0), dtype=object) @@ -535,65 +740,83 @@ def _bda_getitem_row_chan(avg, idx, dtype, format, avg_meta, nchan): return da.Array(graph, name, chunks, meta=meta) -def bda_row_average(meta, ant1, ant2, flag_row=None, - time_centroid=None, exposure=None, uvw=None, - weight=None, sigma=None, - format="flat"): - """ Average row-based dask arrays """ +def bda_row_average( + meta, + ant1, + ant2, + flag_row=None, + time_centroid=None, + exposure=None, + uvw=None, + weight=None, + sigma=None, + format="flat", +): + """Average row-based dask arrays""" rd = ("row",) rcd = ("row", "corr") # (output, array, dims) - args = [(False, meta, ("row",)), - (True, ant1, rd), - (True, ant2, rd), - (False, flag_row, None if flag_row is None else rd), - (True, time_centroid, None if time_centroid is None else rd), - (True, exposure, None if exposure is None else rd), - (True, uvw, None if uvw is None else ("row", "uvw")), - (True, weight, None if weight is None else rcd), - (True, sigma, None if sigma is None else rcd)] - - avg = da.blockwise(_bda_row_average_wrapper, rd, - *(v for pair in args for v in pair[1:]), - align_arrays=False, - adjust_chunks={"row": lambda x: np.nan}, - meta=np.empty((0,)*len(rd), dtype=object), - dtype=object) + args = [ + (False, meta, ("row",)), + (True, ant1, rd), + (True, ant2, rd), + (False, flag_row, None if flag_row is None else rd), + (True, time_centroid, None if time_centroid is None else rd), + (True, exposure, None if exposure is None else rd), + (True, uvw, None if uvw is None else ("row", "uvw")), + (True, weight, None if weight is None else rcd), + (True, sigma, None if sigma is None else rcd), + ] + + avg = da.blockwise( + _bda_row_average_wrapper, + rd, + *(v for pair in args for v in pair[1:]), + align_arrays=False, + adjust_chunks={"row": lambda x: np.nan}, + meta=np.empty((0,) * len(rd), dtype=object), + dtype=object, + ) # ant1, ant2, time_centroid, exposure, uvw, weight, sigma out_args = [(a, dims) for out, a, dims in args if out is True] - tuple_gets = [None if a is None - else _bda_getitem_row(avg, i, a, dims, meta, format=format) - for i, (a, dims) in enumerate(out_args)] + tuple_gets = [ + None if a is None else _bda_getitem_row(avg, i, a, dims, meta, format=format) + for i, (a, dims) in enumerate(out_args) + ] return BDARowAverageOutput(*tuple_gets) -def _bda_row_chan_average_wrapper(avg_meta, flag_row, weight, - vis, flag, - weight_spectrum, - sigma_spectrum): +def _bda_row_chan_average_wrapper( + avg_meta, flag_row, weight, vis, flag, weight_spectrum, sigma_spectrum +): return np_bda_row_chan_avg( - avg_meta, flag_row, weight, - None if vis is None else vis[0], - None if flag is None else flag[0], - None if weight_spectrum is None else weight_spectrum[0], - None if sigma_spectrum is None else sigma_spectrum[0]) - - -def bda_row_chan_average(avg_meta, flag_row=None, weight=None, - visibilities=None, flag=None, - weight_spectrum=None, - sigma_spectrum=None, - format="flat"): - """ Average (row,chan,corr)-based dask arrays """ - if all(v is None for v in (visibilities, - flag, - weight_spectrum, - sigma_spectrum)): + avg_meta, + flag_row, + weight, + None if vis is None else vis[0], + None if flag is None else flag[0], + None if weight_spectrum is None else weight_spectrum[0], + None if sigma_spectrum is None else sigma_spectrum[0], + ) + + +def bda_row_chan_average( + avg_meta, + flag_row=None, + weight=None, + visibilities=None, + flag=None, + weight_spectrum=None, + sigma_spectrum=None, + format="flat", +): + """Average (row,chan,corr)-based dask arrays""" + if all(v is None for v in (visibilities, flag, weight_spectrum, sigma_spectrum)): return BDARowChanAverageOutput(None, None, None, None) # We don't know how many rows are in each row chunk, @@ -622,19 +845,20 @@ def bda_row_chan_average(avg_meta, flag_row=None, weight=None, nchan = visibilities[0].shape[1] if not all(isinstance(a, da.Array) for a in visibilities): - raise ValueError("Visibility tuple must exclusively " - "contain dask arrays") + raise ValueError("Visibility tuple must exclusively " "contain dask arrays") # If we received a tuple of visibility arrays # convert them into an array of tuples of visibilities have_vis_tuple = True nvis_elements = len(visibilities) - meta = np.empty((0,)*len(bda_dims), dtype=visibilities[0].dtype) - - visibilities = da.blockwise(lambda *a: a, _row_chan_avg_dims, - *[elem for a in visibilities - for elem in (a, vis_dims)], - meta=meta) + meta = np.empty((0,) * len(bda_dims), dtype=visibilities[0].dtype) + + visibilities = da.blockwise( + lambda *a: a, + _row_chan_avg_dims, + *[elem for a in visibilities for elem in (a, vis_dims)], + meta=meta, + ) elif isinstance(flag, da.Array): nchan = flag.shape[1] elif isinstance(weight_spectrum, da.Array): @@ -644,25 +868,35 @@ def bda_row_chan_average(avg_meta, flag_row=None, weight=None, else: raise ValueError("Couldn't infer nchan") - avg = da.blockwise(_bda_row_chan_average_wrapper, ("row", "corr"), - avg_meta, ("row",), - flag_row, flag_row_dims, - weight, weight_dims, - visibilities, vis_dims, - flag, flag_dims, - weight_spectrum, ws_dims, - sigma_spectrum, ss_dims, - align_arrays=False, - adjust_chunks=adjust_chunks, - meta=np.empty((0, 0), dtype=object), - dtype=object) - - tuple_gets = (None if a is None else - _bda_getitem_row_chan(avg, i, a.dtype, - format, avg_meta, nchan) - for i, a in enumerate([visibilities, flag, - weight_spectrum, - sigma_spectrum])) + avg = da.blockwise( + _bda_row_chan_average_wrapper, + ("row", "corr"), + avg_meta, + ("row",), + flag_row, + flag_row_dims, + weight, + weight_dims, + visibilities, + vis_dims, + flag, + flag_dims, + weight_spectrum, + ws_dims, + sigma_spectrum, + ss_dims, + align_arrays=False, + adjust_chunks=adjust_chunks, + meta=np.empty((0, 0), dtype=object), + dtype=object, + ) + + tuple_gets = ( + None + if a is None + else _bda_getitem_row_chan(avg, i, a.dtype, format, avg_meta, nchan) + for i, a in enumerate([visibilities, flag, weight_spectrum, sigma_spectrum]) + ) # If we received an array of tuples of visibilities # convert them into a tuple of visibility arrays @@ -672,10 +906,9 @@ def bda_row_chan_average(avg_meta, flag_row=None, weight=None, tuple_vis = [] for v in range(nvis_elements): - v = da.blockwise(getitem, bda_dims, - vis_array, bda_dims, - v, None, - dtype=vis_array.dtype) + v = da.blockwise( + getitem, bda_dims, vis_array, bda_dims, v, None, dtype=vis_array.dtype + ) tuple_vis.append(v) tuple_gets = (tuple(tuple_vis),) + tuple_gets[1:] @@ -684,21 +917,32 @@ def bda_row_chan_average(avg_meta, flag_row=None, weight=None, @requires_optional("dask.array", dask_import_error) -def bda(time, interval, antenna1, antenna2, - time_centroid=None, exposure=None, flag_row=None, - uvw=None, weight=None, sigma=None, - chan_freq=None, chan_width=None, - effective_bw=None, resolution=None, - visibilities=None, flag=None, - weight_spectrum=None, - sigma_spectrum=None, - max_uvw_dist=None, - max_fov=3.0, - decorrelation=0.98, - time_bin_secs=None, - min_nchan=1, - format="flat"): - +def bda( + time, + interval, + antenna1, + antenna2, + time_centroid=None, + exposure=None, + flag_row=None, + uvw=None, + weight=None, + sigma=None, + chan_freq=None, + chan_width=None, + effective_bw=None, + resolution=None, + visibilities=None, + flag=None, + weight_spectrum=None, + sigma_spectrum=None, + max_uvw_dist=None, + max_fov=3.0, + decorrelation=0.98, + time_bin_secs=None, + min_nchan=1, + format="flat", +): if uvw is None: raise ValueError("uvw must be supplied") @@ -709,8 +953,7 @@ def bda(time, interval, antenna1, antenna2, raise ValueError("chan_freq must be supplied") if not len(chan_width.chunks[0]) == 1: - raise ValueError("Chunking in channel is not " - "currently supported.") + raise ValueError("Chunking in channel is not " "currently supported.") if max_uvw_dist is None: max_uvw_dist = da.sqrt((uvw**2).sum(axis=1)).max() @@ -725,31 +968,47 @@ def bda(time, interval, antenna1, antenna2, flag_row = merge_flags(flag_row, flag) # Generate row mapping metadata - meta = bda_mapper(time, interval, antenna1, antenna2, uvw, - chan_width, chan_freq, - max_uvw_dist, - flag_row=flag_row, - max_fov=max_fov, - decorrelation=decorrelation, - time_bin_secs=time_bin_secs, - min_nchan=min_nchan) + meta = bda_mapper( + time, + interval, + antenna1, + antenna2, + uvw, + chan_width, + chan_freq, + max_uvw_dist, + flag_row=flag_row, + max_fov=max_fov, + decorrelation=decorrelation, + time_bin_secs=time_bin_secs, + min_nchan=min_nchan, + ) # Average row data - row_data = bda_row_average(meta, antenna1, antenna2, - flag_row=flag_row, - time_centroid=time_centroid, - exposure=exposure, - uvw=uvw, - weight=weight, sigma=sigma, - format=format) + row_data = bda_row_average( + meta, + antenna1, + antenna2, + flag_row=flag_row, + time_centroid=time_centroid, + exposure=exposure, + uvw=uvw, + weight=weight, + sigma=sigma, + format=format, + ) # Average channel data - row_chan_data = bda_row_chan_average(meta, - flag_row=flag_row, weight=weight, - visibilities=visibilities, flag=flag, - weight_spectrum=weight_spectrum, - sigma_spectrum=sigma_spectrum, - format=format) + row_chan_data = bda_row_chan_average( + meta, + flag_row=flag_row, + weight=weight, + visibilities=visibilities, + flag=flag, + weight_spectrum=weight_spectrum, + sigma_spectrum=sigma_spectrum, + format=format, + ) # chan_data = chan_average(chan_meta, # chan_freq=chan_freq, @@ -757,9 +1016,11 @@ def bda(time, interval, antenna1, antenna2, # effective_bw=effective_bw, # resolution=resolution) - fake_map = da.zeros((time.shape[0], chan_width.shape[0]), - chunks=time.chunks + chan_width.chunks, - dtype=np.uint32) + fake_map = da.zeros( + (time.shape[0], chan_width.shape[0]), + chunks=time.chunks + chan_width.chunks, + dtype=np.uint32, + ) fake_ints = da.zeros_like(time, dtype=np.uint32) fake_floats = da.zeros_like(chan_width) @@ -767,43 +1028,46 @@ def bda(time, interval, antenna1, antenna2, meta_map = _bda_getitem_row(meta, 0, fake_map, ("row", "chan"), meta) meta_offsets = _bda_getitem_row(meta, 1, fake_ints, ("row",), meta) meta_decorr_cw = _bda_getitem_row(meta, 2, fake_floats, ("row",), meta) - meta_time = _bda_getitem_row(meta, 3, time, ("row",), - meta, format=format) - meta_interval = _bda_getitem_row(meta, 4, interval, ("row",), - meta, format=format) + meta_time = _bda_getitem_row(meta, 3, time, ("row",), meta, format=format) + meta_interval = _bda_getitem_row(meta, 4, interval, ("row",), meta, format=format) meta_chan_width = _bda_getitem_row(meta, 5, chan_width, ("row",), meta) - meta_flag_row = (_bda_getitem_row(meta, 6, flag_row, ("row",), - meta, format=format) - if flag_row is not None else None) + meta_flag_row = ( + _bda_getitem_row(meta, 6, flag_row, ("row",), meta, format=format) + if flag_row is not None + else None + ) # Merge output tuples - return BDAAverageOutput(meta_map, - meta_offsets, - meta_decorr_cw, - meta_time, - meta_interval, - meta_chan_width, - meta_flag_row, - row_data.antenna1, - row_data.antenna2, - row_data.time_centroid, - row_data.exposure, - row_data.uvw, - row_data.weight, - row_data.sigma, - # None, # chan_data.chan_freq, - # None, # chan_data.chan_width, - # None, # chan_data.effective_bw, - # None, # chan_data.resolution, - row_chan_data.visibilities, - row_chan_data.flag, - row_chan_data.weight_spectrum, - row_chan_data.sigma_spectrum) + return BDAAverageOutput( + meta_map, + meta_offsets, + meta_decorr_cw, + meta_time, + meta_interval, + meta_chan_width, + meta_flag_row, + row_data.antenna1, + row_data.antenna2, + row_data.time_centroid, + row_data.exposure, + row_data.uvw, + row_data.weight, + row_data.sigma, + # None, # chan_data.chan_freq, + # None, # chan_data.chan_width, + # None, # chan_data.effective_bw, + # None, # chan_data.resolution, + row_chan_data.visibilities, + row_chan_data.flag, + row_chan_data.weight_spectrum, + row_chan_data.sigma_spectrum, + ) try: time_and_channel.__doc__ = TC_AVERAGING_DOCS.substitute( - array_type=":class:`dask.array.Array`") + array_type=":class:`dask.array.Array`" + ) bda.__doc__ = BDA_DOCS.substitute(array_type=":class:`dask.array.Array`") except AttributeError: pass diff --git a/africanus/averaging/shared.py b/africanus/averaging/shared.py index 089bd9da9..d3e83f4e6 100644 --- a/africanus/averaging/shared.py +++ b/africanus/averaging/shared.py @@ -2,11 +2,13 @@ import numpy as np -from africanus.util.numba import (is_numba_type_none, - intrinsic, - JIT_OPTIONS, - njit, - overload) +from africanus.util.numba import ( + is_numba_type_none, + intrinsic, + JIT_OPTIONS, + njit, + overload, +) def shape_or_invalid_shape(array, ndim): @@ -28,8 +30,9 @@ def nb_merge_flags(flag_row, flag): have_flag = not is_numba_type_none(flag) if have_flag_row and have_flag: + def impl(flag_row, flag): - """ Check flag_row and flag agree """ + """Check flag_row and flag agree""" for r in range(flag.shape[0]): all_flagged = True @@ -48,13 +51,15 @@ def impl(flag_row, flag): return flag_row elif have_flag_row and not have_flag: + def impl(flag_row, flag): - """ Return flag_row """ + """Return flag_row""" return flag_row elif not have_flag_row and have_flag: + def impl(flag_row, flag): - """ Construct flag_row from flag """ + """Construct flag_row from flag""" new_flag_row = np.empty(flag.shape[0], dtype=flag.dtype) for r in range(flag.shape[0]): @@ -69,26 +74,26 @@ def impl(flag_row, flag): if not all_flagged: break - new_flag_row[r] = (1 if all_flagged else 0) + new_flag_row[r] = 1 if all_flagged else 0 return new_flag_row else: + def impl(flag_row, flag): return None return impl -@overload(shape_or_invalid_shape, inline='always') +@overload(shape_or_invalid_shape, inline="always") def _shape_or_invalid_shape(array, ndim): - """ Return array shape tuple or (-1,)*ndim if the array is None """ + """Return array shape tuple or (-1,)*ndim if the array is None""" import numba.core.types as nbtypes from numba.extending import SentryLiteralArgs - SentryLiteralArgs(['ndim']).for_function( - _shape_or_invalid_shape).bind(array, ndim) + SentryLiteralArgs(["ndim"]).for_function(_shape_or_invalid_shape).bind(array, ndim) try: ndim_lit = getattr(ndim, "literal_value") @@ -96,24 +101,25 @@ def _shape_or_invalid_shape(array, ndim): raise ValueError("ndim must be a integer literal") if is_numba_type_none(array): - tup = (-1,)*ndim_lit + tup = (-1,) * ndim_lit def impl(array, ndim): return tup return impl elif isinstance(array, nbtypes.Array): + def impl(array, ndim): return array.shape return impl - elif (isinstance(array, nbtypes.UniTuple) and - isinstance(array.dtype, nbtypes.Array)): - + elif isinstance(array, nbtypes.UniTuple) and isinstance(array.dtype, nbtypes.Array): if len(array) == 1: + def impl(array, ndim): return array[0].shape else: + def impl(array, ndim): shape = array[0].shape @@ -132,9 +138,11 @@ def impl(array, ndim): raise ValueError("Array ndims in Tuple don't match") if len(array) == 1: + def impl(array, ndim): return array[0].shape else: + def impl(array, ndim): shape = array[0].shape @@ -187,8 +195,7 @@ def find_chan_corr(chan, corr, shape, chan_idx, corr_idx): chan = array_chan # Check consistency elif chan != array_chan: - raise ValueError("Inconsistent Channel Dimension " - "in Input Arrays") + raise ValueError("Inconsistent Channel Dimension " "in Input Arrays") if corr_idx != -1: array_corr = shape[corr_idx] @@ -201,8 +208,7 @@ def find_chan_corr(chan, corr, shape, chan_idx, corr_idx): corr = array_corr # Check consistency elif corr != array_corr: - raise ValueError("Inconsistent Correlation Dimension " - "in Input Arrays") + raise ValueError("Inconsistent Correlation Dimension " "in Input Arrays") return chan, corr @@ -211,10 +217,16 @@ def find_chan_corr(chan, corr, shape, chan_idx, corr_idx): # maybe inline='always' if # https://github.com/numba/numba/issues/4693 is resolved @njit(nogil=True, cache=True) -def chan_corrs(vis, flag, - weight_spectrum, sigma_spectrum, - chan_freq, chan_width, - effective_bw, resolution): +def chan_corrs( + vis, + flag, + weight_spectrum, + sigma_spectrum, + chan_freq, + chan_width, + effective_bw, + resolution, +): """ Infer channel and correlation size from input dimensions @@ -253,12 +265,14 @@ def flags_match(flag_row, ri, out_flag_row, ro): pass -@overload(flags_match, inline='always') +@overload(flags_match, inline="always") def _flags_match(flag_row, ri, out_flag_row, ro): if is_numba_type_none(flag_row): + def impl(flag_row, ri, out_flag_row, ro): return True else: + def impl(flag_row, ri, out_flag_row, ro): return flag_row[ri] == out_flag_row[ro] @@ -271,7 +285,7 @@ def vis_output_arrays(typingctx, vis, out_shape): from numba.np import numpy_support def vis_weight_types(vis): - """ Determine output visibility and weight types """ + """Determine output visibility and weight types""" if isinstance(vis.dtype, types.Complex): # Use the float representation as dtype if vis is complex @@ -345,10 +359,9 @@ def gen_array_factory(numba_dtype): factory_args = [out_shape] # Compile function and get handle to output array - inner_value = context.compile_internal(builder, - array_factory, - factory_sig, - factory_args) + inner_value = context.compile_internal( + builder, array_factory, factory_sig, factory_args + ) # Insert inner tuple into outer tuple elif have_vis_tuple: @@ -363,8 +376,9 @@ def gen_array_factory(numba_dtype): factory_args = [out_shape] # Compile function and get handle to output - data = context.compile_internal(builder, array_factory, - factory_sig, factory_args) + data = context.compile_internal( + builder, array_factory, factory_sig, factory_args + ) # Insert data into inner_value inner_value = builder.insert_value(inner_value, data, j) diff --git a/africanus/averaging/splines.py b/africanus/averaging/splines.py index 983c0cf77..5a889c1e8 100644 --- a/africanus/averaging/splines.py +++ b/africanus/averaging/splines.py @@ -11,8 +11,7 @@ @njit(nogil=True, cache=True) -def solve_trid_system(x, y, left_type=2, right_type=2, - left_value=0.0, right_value=0.0): +def solve_trid_system(x, y, left_type=2, right_type=2, left_value=0.0, right_value=0.0): """ Solves a tridiagonal matrix @@ -23,12 +22,13 @@ def solve_trid_system(x, y, left_type=2, right_type=2, n = x.shape[0] # Construct tridiagonal matrix - for i in range(1, n-1): - diag[i, A] = (1.0 / 3.0) * (x[i] - x[i-1]) - diag[i, B] = (2.0 / 3.0) * (x[i+1] - x[i-1]) - diag[i, C] = (1.0 / 3.0) * (x[i+1] - x[i]) - v[i] = ((y[i+1] - y[i])/(x[i+1] - x[i]) - - (y[i] - y[i-1])/(x[i] - x[i-1])) + for i in range(1, n - 1): + diag[i, A] = (1.0 / 3.0) * (x[i] - x[i - 1]) + diag[i, B] = (2.0 / 3.0) * (x[i + 1] - x[i - 1]) + diag[i, C] = (1.0 / 3.0) * (x[i + 1] - x[i]) + v[i] = (y[i + 1] - y[i]) / (x[i + 1] - x[i]) - (y[i] - y[i - 1]) / ( + x[i] - x[i - 1] + ) # Configure left end point if left_type == 2: @@ -44,35 +44,34 @@ def solve_trid_system(x, y, left_type=2, right_type=2, # Configure right endpoint if right_type == 2: - diag[n-1, B] = 2.0 - diag[n-1, C] = 0.0 - v[n-1] = right_value + diag[n - 1, B] = 2.0 + diag[n - 1, C] = 0.0 + v[n - 1] = right_value elif left_type == 1: - diag[n-1, B] = 2.0 * (x[n-1] - x[n-2]) - diag[n-1, C] = 1.0 * (x[n-1] - x[n-2]) - v[n-1] = 3.0 * (right_value - (y[n-1] - y[n-2]) / (x[n-1] - x[n-2])) + diag[n - 1, B] = 2.0 * (x[n - 1] - x[n - 2]) + diag[n - 1, C] = 1.0 * (x[n - 1] - x[n - 2]) + v[n - 1] = 3.0 * (right_value - (y[n - 1] - y[n - 2]) / (x[n - 1] - x[n - 2])) else: raise ValueError("right_type not in (1, 2)") # Solve tridiagonal system in place for i in range(1, n): - w = diag[i, A] - diag[i-1, B] - diag[i, B] -= w*diag[i-1, C] - v[i] -= w*v[i-1] + w = diag[i, A] - diag[i - 1, B] + diag[i, B] -= w * diag[i - 1, C] + v[i] -= w * v[i - 1] # Compute solution z = np.zeros_like(v) - z[n-1] = v[n-1] / diag[n-1, B] + z[n - 1] = v[n - 1] / diag[n - 1, B] for i in range(n - 1, -1, -1): - z[i] = (v[i] - diag[i, C]*z[i+1])/diag[i, B] + z[i] = (v[i] - diag[i, C] * z[i + 1]) / diag[i, B] return z @njit(nogil=True, cache=True) -def fit_cubic_spline(x, y, left_type=2, right_type=2, - left_value=0.0, right_value=0.0): +def fit_cubic_spline(x, y, left_type=2, right_type=2, left_value=0.0, right_value=0.0): b = solve_trid_system(x, y, left_type, right_type, left_value, right_value) a = np.empty_like(b) c = np.empty_like(b) @@ -80,13 +79,14 @@ def fit_cubic_spline(x, y, left_type=2, right_type=2, n = x.shape[0] for i in range(n - 1): - a[i] = (b[i+1] - b[i]) / (3*(x[i+1] - x[i])) - c[i] = ((y[i+1] - y[i]) / (x[i+1] - x[i]) - - (2.0*b[i] + b[i+1]) * (x[i+1] - x[i]) / 3.0) + a[i] = (b[i + 1] - b[i]) / (3 * (x[i + 1] - x[i])) + c[i] = (y[i + 1] - y[i]) / (x[i + 1] - x[i]) - (2.0 * b[i] + b[i + 1]) * ( + x[i + 1] - x[i] + ) / 3.0 - h = x[n-2] - x[n-1] - a[n-1] = 0 - c[n-1] = 3.0*a[n-2]*h*h + 2.0*b[n-2]*h + c[n-2] + h = x[n - 2] - x[n - 1] + a[n - 1] = 0 + c[n - 1] = 3.0 * a[n - 2] * h * h + 2.0 * b[n - 2] * h + c[n - 2] return Spline(a, b, c, x, y) @@ -104,38 +104,38 @@ def evaluate_spline(spline, x, order=0): if order == 0: for i, p in enumerate(x): - j = max(np.searchsorted(mx, p, side='right') - 1, 0) + j = max(np.searchsorted(mx, p, side="right") - 1, 0) h = p - mx[j] if p < x[0]: - values[i] = (mb0*h + mc0)*h + my[0] - elif p > x[n-1]: - values[i] = (mb[n-1]*h + mc[n-1])*h + my[n-1] + values[i] = (mb0 * h + mc0) * h + my[0] + elif p > x[n - 1]: + values[i] = (mb[n - 1] * h + mc[n - 1]) * h + my[n - 1] else: - values[i] = ((ma[j]*h + mb[j])*h + mc[j])*h + my[j] + values[i] = ((ma[j] * h + mb[j]) * h + mc[j]) * h + my[j] elif order == 1: for i, p in enumerate(x): - j = max(np.searchsorted(mx, p, side='right') - 1, 0) + j = max(np.searchsorted(mx, p, side="right") - 1, 0) h = p - mx[j] if p < x[0]: - values[i] = 2.0*mb0*h + mc0 - elif p > x[n-1]: - values[i] = 2.0*mb[n-1]*h + mc[n-1] + values[i] = 2.0 * mb0 * h + mc0 + elif p > x[n - 1]: + values[i] = 2.0 * mb[n - 1] * h + mc[n - 1] else: - values[i] = (3.0*ma[j]*h + 2.0*mb[j])*h + mc[j] + values[i] = (3.0 * ma[j] * h + 2.0 * mb[j]) * h + mc[j] elif order == 2: for i, p in enumerate(x): - j = max(np.searchsorted(mx, p, side='right') - 1, 0) + j = max(np.searchsorted(mx, p, side="right") - 1, 0) h = p - mx[j] if p < x[0]: - values[i] = 2.0*mb0*h - elif p > x[n-1]: - values[i] = 2.0*mb[n-1] + values[i] = 2.0 * mb0 * h + elif p > x[n - 1]: + values[i] = 2.0 * mb[n - 1] else: - values[i] = 6.0*ma[j]*h + 2.0*mb[j] + values[i] = 6.0 * ma[j] * h + 2.0 * mb[j] else: raise ValueError("order not in (0, 1, 2)") diff --git a/africanus/averaging/support.py b/africanus/averaging/support.py index a4ca769ef..fa23de7b7 100644 --- a/africanus/averaging/support.py +++ b/africanus/averaging/support.py @@ -10,18 +10,19 @@ @njit(nogil=True, cache=True) def _unique_internal(data): if len(data.shape) != 1: - raise ValueError("_unique_internal currently " - "only supports 1D arrays") + raise ValueError("_unique_internal currently " "only supports 1D arrays") # Handle the empty array case if data.shape[0] == 0: - return (data, - np.empty((0,), dtype=np.intp), - np.empty((0,), dtype=np.intp), - np.empty((0,), dtype=np.intp)) + return ( + data, + np.empty((0,), dtype=np.intp), + np.empty((0,), dtype=np.intp), + np.empty((0,), dtype=np.intp), + ) # See numpy's unique1d - perm = np.argsort(data, kind='mergesort') + perm = np.argsort(data, kind="mergesort") # Combine these arrays to save on allocations? aux = np.empty_like(data) @@ -64,7 +65,7 @@ def unique_time_impl(time): @overload(unique_time_impl, jit_options=JIT_OPTIONS) def nb_unique_time(time): - """ Return unique time, inverse index and counts """ + """Return unique time, inverse index and counts""" if time.dtype not in (numba.float32, numba.float64): raise ValueError("time must be floating point but is %s" % time.dtype) @@ -85,12 +86,13 @@ def unique_baselines_impl(ant1, ant2): @overload(unique_baselines_impl, jit_options=JIT_OPTIONS) def nb_unique_baselines(ant1, ant2): - """ Return unique baselines, inverse index and counts """ + """Return unique baselines, inverse index and counts""" if not ant1.dtype == numba.int32 or not ant2.dtype == numba.int32: # Need these to be int32 for the bl_32bit.view(np.int64) trick - raise ValueError("ant1 and ant2 must be np.int32 " - "but received %s and %s" % - (ant1.dtype, ant2.dtype)) + raise ValueError( + "ant1 and ant2 must be np.int32 " + "but received %s and %s" % (ant1.dtype, ant2.dtype) + ) def impl(ant1, ant2): # Trickery, stack the two int32 antenna pairs in an array diff --git a/africanus/averaging/tests/test_bda_averaging.py b/africanus/averaging/tests/test_bda_averaging.py index 7ea2688c1..4ee723521 100644 --- a/africanus/averaging/tests/test_bda_averaging.py +++ b/africanus/averaging/tests/test_bda_averaging.py @@ -11,84 +11,52 @@ from africanus.averaging.dask import bda as dask_bda -@pytest.fixture(params=[ - # 5 rows, 4 channels => 3 rows - # - # row 1 contains 2 channels - # row 2 contains 3 channels - # row 3 contains 1 channel - [[0, 0, 1, 1], - [0, 0, 1, 1], - [2, 3, 3, 4], - [2, 3, 3, 4], - [5, 5, 5, 5]] -]) +@pytest.fixture( + params=[ + # 5 rows, 4 channels => 3 rows + # + # row 1 contains 2 channels + # row 2 contains 3 channels + # row 3 contains 1 channel + [[0, 0, 1, 1], [0, 0, 1, 1], [2, 3, 3, 4], [2, 3, 3, 4], [5, 5, 5, 5]] + ] +) def bda_test_map(request): return np.asarray(request.param) -@pytest.fixture(params=[ - # No flags - [[0, 0, 0, 0], - [0, 0, 0, 0], - [0, 0, 0, 0], - [0, 0, 0, 0], - [0, 0, 0, 0]], - - # Row 0 and 1 flagged - [[1, 1, 1, 1], - [1, 1, 1, 1], - [0, 0, 0, 0], - [0, 0, 0, 0], - [0, 0, 0, 0]], - - # Row 2 flagged - [[0, 0, 0, 0], - [0, 0, 0, 0], - [1, 1, 1, 1], - [0, 0, 0, 0], - [0, 0, 0, 0]], - - # Row 0, 2, 4 flagged - [[1, 1, 1, 1], - [0, 0, 0, 0], - [1, 1, 1, 1], - [0, 0, 0, 0], - [1, 1, 1, 1]], - - - # All flagged - [[1, 1, 1, 1], - [1, 1, 1, 1], - [1, 1, 1, 1], - [1, 1, 1, 1], - [1, 1, 1, 1]], - - - # Partially flagged - [[0, 1, 0, 1], - [0, 1, 0, 0], - [0, 0, 0, 0], - [1, 1, 1, 1], - [1, 0, 0, 0]], -]) +@pytest.fixture( + params=[ + # No flags + [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], + # Row 0 and 1 flagged + [[1, 1, 1, 1], [1, 1, 1, 1], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], + # Row 2 flagged + [[0, 0, 0, 0], [0, 0, 0, 0], [1, 1, 1, 1], [0, 0, 0, 0], [0, 0, 0, 0]], + # Row 0, 2, 4 flagged + [[1, 1, 1, 1], [0, 0, 0, 0], [1, 1, 1, 1], [0, 0, 0, 0], [1, 1, 1, 1]], + # All flagged + [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]], + # Partially flagged + [[0, 1, 0, 1], [0, 1, 0, 0], [0, 0, 0, 0], [1, 1, 1, 1], [1, 0, 0, 0]], + ] +) def flags(request): return np.asarray(request.param) @pytest.fixture def inv_bda_test_map(bda_test_map): - """ Generates a :code:`{out_row: [in_row, in_chan]}` mapping""" + """Generates a :code:`{out_row: [in_row, in_chan]}` mapping""" inv = defaultdict(list) for idx in np.ndindex(*bda_test_map.shape): inv[bda_test_map[idx]].append(idx) - return {ro: tuple(list(i) for i in zip(*v)) - for ro, v in inv.items()} + return {ro: tuple(list(i) for i in zip(*v)) for ro, v in inv.items()} def _effective_row_map(flag_row, inv_row_map): - """ Build an effective row map """ + """Build an effective row map""" emap = [] for _, (rows, counts) in sorted(inv_row_map.items()): @@ -120,7 +88,7 @@ def _calc_sigma(weight, sigma, rows): weight = weight[rows] sigma = sigma[rows] numerator = (sigma**2 * weight**2).sum(axis=0) - denominator = weight.sum(axis=0)**2 + denominator = weight.sum(axis=0) ** 2 denominator[denominator == 0.0] = 1.0 return np.sqrt(numerator / denominator) @@ -132,8 +100,9 @@ def test_bda_avg(bda_test_map, inv_bda_test_map, flags): # Derive flag_row from flags flag_row = flags.all(axis=1) - out_chan = np.array([np.unique(rows).size for rows - in np.unique(bda_test_map, axis=0)]) + out_chan = np.array( + [np.unique(rows).size for rows in np.unique(bda_test_map, axis=0)] + ) # Number of output rows is sum of row channels out_row = out_chan.sum() @@ -146,10 +115,10 @@ def test_bda_avg(bda_test_map, inv_bda_test_map, flags): time = np.linspace(1.0, float(in_row), in_row, dtype=np.float64) # noqa interval = np.full(in_row, 1.0, dtype=np.float64) # noqa - uvw = np.arange(in_row*3).reshape(in_row, 3).astype(np.float64) + uvw = np.arange(in_row * 3).reshape(in_row, 3).astype(np.float64) weight = rs.normal(size=(in_row, in_corr)) sigma = rs.normal(size=(in_row, in_corr)) - chan_width = np.repeat(.856e9 / out_chan, out_chan) + chan_width = np.repeat(0.856e9 / out_chan, out_chan) # Aggregate time and interval, in_row => out_row # first channel in the map. We're only averaging over @@ -170,32 +139,43 @@ def test_bda_avg(bda_test_map, inv_bda_test_map, flags): out_interval[:] = out_interval[copy_idx] out_time /= out_counts - inv_row_map = {ro: np.unique(rows, return_counts=True) - for ro, (rows, _) in inv_bda_test_map.items()} - out_time2 = [time[rows].sum() / len(counts) for _, (rows, counts) - in sorted(inv_row_map.items())] + inv_row_map = { + ro: np.unique(rows, return_counts=True) + for ro, (rows, _) in inv_bda_test_map.items() + } + out_time2 = [ + time[rows].sum() / len(counts) + for _, (rows, counts) in sorted(inv_row_map.items()) + ] assert_array_equal(out_time, out_time2) - out_interval2 = [interval[rows].sum() for _, (rows, _) - in sorted(inv_row_map.items())] + out_interval2 = [ + interval[rows].sum() for _, (rows, _) in sorted(inv_row_map.items()) + ] assert_array_equal(out_interval, out_interval2) - out_flag_row = [flag_row[rows].all() for _, (rows, _) - in sorted(inv_row_map.items())] + out_flag_row = [ + flag_row[rows].all() for _, (rows, _) in sorted(inv_row_map.items()) + ] - meta = RowMapOutput(bda_test_map, offsets, - chan_width, out_time, out_interval, - None, out_flag_row) + meta = RowMapOutput( + bda_test_map, offsets, chan_width, out_time, out_interval, None, out_flag_row + ) ant1 = np.full(in_row, 0, dtype=np.int32) ant2 = np.full(in_row, 1, dtype=np.int32) - row_avg = row_average(meta, ant1, ant2, - time_centroid=time, - exposure=interval, - uvw=uvw, - weight=weight, sigma=sigma, - flag_row=flag_row) + row_avg = row_average( + meta, + ant1, + ant2, + time_centroid=time, + exposure=interval, + uvw=uvw, + weight=weight, + sigma=sigma, + flag_row=flag_row, + ) assert_array_equal(row_avg.antenna1, 0) assert_array_equal(row_avg.antenna2, 1) @@ -215,20 +195,25 @@ def test_bda_avg(bda_test_map, inv_bda_test_map, flags): assert_array_almost_equal(row_avg.sigma, out_sigma) vshape = (in_row, in_chan, in_corr) - vis = rs.normal(size=vshape) + rs.normal(size=vshape)*1j + vis = rs.normal(size=vshape) + rs.normal(size=vshape) * 1j weight_spectrum = rs.normal(size=vshape) sigma_spectrum = rs.normal(size=vshape) flag = np.broadcast_to(flags[:, :, None], vshape) effective_map = _effective_rowchan_map(flags, inv_bda_test_map) - out_ws = np.stack([ - weight_spectrum[r, c, :].sum(axis=0) for r, c in effective_map]) - out_ss = np.stack([ - (sigma_spectrum[r, c, :]**2 * weight_spectrum[r, c, :]**2).sum(axis=0) - for r, c in effective_map]) - out_vis = np.stack([ - (vis[r, c, :]*weight_spectrum[r, c, :]).sum(axis=0) - for r, c in effective_map]) + out_ws = np.stack([weight_spectrum[r, c, :].sum(axis=0) for r, c in effective_map]) + out_ss = np.stack( + [ + (sigma_spectrum[r, c, :] ** 2 * weight_spectrum[r, c, :] ** 2).sum(axis=0) + for r, c in effective_map + ] + ) + out_vis = np.stack( + [ + (vis[r, c, :] * weight_spectrum[r, c, :]).sum(axis=0) + for r, c in effective_map + ] + ) out_flag = np.stack([flag[r, c, :].all(axis=0) for r, c in effective_map]) weight_div = out_ws.copy() @@ -238,10 +223,13 @@ def test_bda_avg(bda_test_map, inv_bda_test_map, flags): # Broadcast flag data up to correlation dimension row_chan_avg = row_chan_average( - meta, flag_row=flag_row, visibilities=vis, + meta, + flag_row=flag_row, + visibilities=vis, weight_spectrum=weight_spectrum, sigma_spectrum=sigma_spectrum, - flag=flag) + flag=flag, + ) assert_array_almost_equal(row_chan_avg.visibilities, out_vis) assert_array_almost_equal(row_chan_avg.flag, out_flag) @@ -251,14 +239,9 @@ def test_bda_avg(bda_test_map, inv_bda_test_map, flags): @pytest.mark.parametrize("vis_format", ["ragged", "flat"]) def test_dask_bda_avg(vis_format): - da = pytest.importorskip('dask.array') + da = pytest.importorskip("dask.array") - dim_chunks = { - "chan": (4,), - "time": (5, 4, 5), - "ant": (7,), - "corr": (4,) - } + dim_chunks = {"chan": (4,), "time": (5, 4, 5), "ant": (7,), "corr": (4,)} ant1, ant2 = np.triu_indices(sum(dim_chunks["ant"]), 1) ant1 = ant1.astype(np.int32) @@ -271,7 +254,7 @@ def test_dask_bda_avg(vis_format): ant2 = np.tile(ant2, ntime) interval = np.full(time.shape, 1.0) - row_chunks = tuple(t*nbl for t in dim_chunks["time"]) + row_chunks = tuple(t * nbl for t in dim_chunks["time"]) nrow = sum(row_chunks) assert nrow == time.shape[0] @@ -283,13 +266,13 @@ def test_dask_bda_avg(vis_format): rs = np.random.RandomState(42) uvw = rs.normal(size=(nrow, 3)) - vis = rs.normal(size=vshape) + rs.normal(size=vshape)*1j + vis = rs.normal(size=vshape) + rs.normal(size=vshape) * 1j flag = rs.randint(0, 2, size=vshape) flag_row = flag.all(axis=(1, 2)) assert flag_row.shape == (nrow,) - chan_freq = np.linspace(.856e9, 2*.856e9, nchan) - chan_width = np.full(nchan, .856e9 / nchan) + chan_freq = np.linspace(0.856e9, 2 * 0.856e9, nchan) + chan_width = np.full(nchan, 0.856e9 / nchan) chan_chunks = dim_chunks["chan"] decorrelation = 0.999 @@ -309,36 +292,57 @@ def test_dask_bda_avg(vis_format): da_vis = da.from_array(vis, chunks=vis_chunks) da_flag = da.from_array(flag, chunks=vis_chunks) - avg = dask_bda(da_time, da_interval, da_ant1, da_ant2, - time_centroid=da_time_centroid, exposure=da_exposure, - flag_row=da_flag_row, uvw=da_uvw, - chan_freq=da_chan_freq, chan_width=da_chan_width, - visibilities=da_vis, flag=da_flag, - decorrelation=decorrelation, - format=vis_format) - - avg = {f: getattr(avg, f) for f in ("time", "interval", - "time_centroid", "exposure", - "visibilities")} - - avg2 = dask_bda(da_time, da_interval, da_ant1, da_ant2, - time_centroid=da_time_centroid, exposure=da_exposure, - flag_row=da_flag_row, uvw=da_uvw, - chan_freq=da_chan_freq, chan_width=da_chan_width, - visibilities=(da_vis, da_vis), flag=da_flag, - decorrelation=decorrelation, - format=vis_format) - - avg2 = {f: getattr(avg2, f) for f in ("time", "interval", - "time_centroid", "exposure", - "visibilities")} + avg = dask_bda( + da_time, + da_interval, + da_ant1, + da_ant2, + time_centroid=da_time_centroid, + exposure=da_exposure, + flag_row=da_flag_row, + uvw=da_uvw, + chan_freq=da_chan_freq, + chan_width=da_chan_width, + visibilities=da_vis, + flag=da_flag, + decorrelation=decorrelation, + format=vis_format, + ) + + avg = { + f: getattr(avg, f) + for f in ("time", "interval", "time_centroid", "exposure", "visibilities") + } + + avg2 = dask_bda( + da_time, + da_interval, + da_ant1, + da_ant2, + time_centroid=da_time_centroid, + exposure=da_exposure, + flag_row=da_flag_row, + uvw=da_uvw, + chan_freq=da_chan_freq, + chan_width=da_chan_width, + visibilities=(da_vis, da_vis), + flag=da_flag, + decorrelation=decorrelation, + format=vis_format, + ) + + avg2 = { + f: getattr(avg2, f) + for f in ("time", "interval", "time_centroid", "exposure", "visibilities") + } import dask - result = dask.persist(avg, scheduler='single-threaded')[0] - result2 = dask.persist(avg2, scheduler='single-threaded')[0] - assert_array_almost_equal(result['interval'], result['exposure']) - assert_array_almost_equal(result['time'], result['time_centroid']) + result = dask.persist(avg, scheduler="single-threaded")[0] + result2 = dask.persist(avg2, scheduler="single-threaded")[0] + + assert_array_almost_equal(result["interval"], result["exposure"]) + assert_array_almost_equal(result["time"], result["time_centroid"]) # Flatten all three visibility graphs dsk1 = dict(result["visibilities"].__dask_graph__()) diff --git a/africanus/averaging/tests/test_bda_mapping.py b/africanus/averaging/tests/test_bda_mapping.py index fc0508695..6683730c3 100644 --- a/africanus/averaging/tests/test_bda_mapping.py +++ b/africanus/averaging/tests/test_bda_mapping.py @@ -14,73 +14,134 @@ def nchan(request): @pytest.fixture(scope="session") def time(): - return np.array([ - 5.03373334e+09, 5.03373334e+09, 5.03373335e+09, - 5.03373336e+09, 5.03373337e+09, 5.03373338e+09, - 5.03373338e+09, 5.03373339e+09, 5.03373340e+09, - 5.03373341e+09, 5.03373342e+09, 5.03373342e+09, - 5.03373343e+09, 5.03373344e+09, 5.03373345e+09, - 5.03373346e+09, 5.03373346e+09, 5.03373347e+09, - 5.03373348e+09, 5.03373349e+09, 5.03373350e+09, - 5.03373350e+09, 5.03373351e+09, 5.03373352e+09, - 5.03373353e+09, 5.03373354e+09, 5.03373354e+09, - 5.03373355e+09, 5.03373356e+09, 5.03373357e+09, - 5.03373358e+09, 5.03373358e+09, 5.03373359e+09, - 5.03373360e+09, 5.03373361e+09, 5.03373362e+09]) + return np.array( + [ + 5.03373334e09, + 5.03373334e09, + 5.03373335e09, + 5.03373336e09, + 5.03373337e09, + 5.03373338e09, + 5.03373338e09, + 5.03373339e09, + 5.03373340e09, + 5.03373341e09, + 5.03373342e09, + 5.03373342e09, + 5.03373343e09, + 5.03373344e09, + 5.03373345e09, + 5.03373346e09, + 5.03373346e09, + 5.03373347e09, + 5.03373348e09, + 5.03373349e09, + 5.03373350e09, + 5.03373350e09, + 5.03373351e09, + 5.03373352e09, + 5.03373353e09, + 5.03373354e09, + 5.03373354e09, + 5.03373355e09, + 5.03373356e09, + 5.03373357e09, + 5.03373358e09, + 5.03373358e09, + 5.03373359e09, + 5.03373360e09, + 5.03373361e09, + 5.03373362e09, + ] + ) @pytest.fixture(scope="session") def interval(): - return np.array([ - 7.99661697, 7.99661697, 7.99661697, 7.99661697, 7.99661697, - 7.99661697, 7.99661697, 7.99661697, 7.99661697, 7.99661697, - 7.99661697, 7.99661697, 7.99661697, 7.99661697, 7.99661697, - 7.99661697, 7.99661697, 7.99661697, 7.99661697, 7.99661697, - 7.99661697, 7.99661697, 7.99661697, 7.99661697, 7.99661697, - 7.99661697, 7.99661697, 7.99661697, 7.99661697, 7.99661697, - 7.99661697, 7.99661697, 7.99661697, 7.99661697, 7.99661697, - 7.99661697]) + return np.array( + [ + 7.99661697, + 7.99661697, + 7.99661697, + 7.99661697, + 7.99661697, + 7.99661697, + 7.99661697, + 7.99661697, + 7.99661697, + 7.99661697, + 7.99661697, + 7.99661697, + 7.99661697, + 7.99661697, + 7.99661697, + 7.99661697, + 7.99661697, + 7.99661697, + 7.99661697, + 7.99661697, + 7.99661697, + 7.99661697, + 7.99661697, + 7.99661697, + 7.99661697, + 7.99661697, + 7.99661697, + 7.99661697, + 7.99661697, + 7.99661697, + 7.99661697, + 7.99661697, + 7.99661697, + 7.99661697, + 7.99661697, + 7.99661697, + ] + ) @pytest.fixture(scope="session") def ants(): - return np.array([ - [5109224.29038545, 2006790.35753831, -3239100.60907827], - [5109247.7157809, 2006736.96831224, -3239096.13639116], - [5109222.76106102, 2006688.94849795, -3239165.94167899], - [5109101.13948279, 2006650.38001812, -3239383.31891167], - [5109132.81491624, 2006798.06346825, -3239242.1849703], - [5109046.33257705, 2006823.98423929, -3239363.78875328], - [5109095.03238529, 2006898.89823927, -3239239.95261248], - [5109082.8918671, 2007045.24176653, -3239169.09131402], - [5109139.53289849, 2006992.25575245, -3239111.37956843], - [5109368.62360157, 2006509.64851116, -3239043.72292735], - [5109490.75061883, 2006708.38364351, -3238726.60664016], - [5109310.2977957, 2007017.0371345, -3238823.74833534], - [5109273.32322089, 2007083.40054198, -3238841.20407917], - [5109233.60247272, 2007298.47483172, -3238770.86653967], - [5109514.1862076, 2007536.98018719, -3238177.03761655], - [5109175.83425585, 2007164.6225741, -3238946.9157957], - [5109093.99046283, 2007162.9306937, -3239078.77530747], - [5108965.29396408, 2007106.07798817, -3239319.10626408], - [5108993.64502175, 2006679.78785901, -3239536.3704696], - [5109111.46526165, 2006445.98820889, -3239491.95574845], - [5109486.39986795, 2006225.48918911, -3239031.01140517], - [5109925.48993011, 2006111.83927162, -3238401.39137192], - [5110109.89167353, 2005177.90721032, -3238688.71487862], - [5110676.49309192, 2005793.15912039, -3237408.15958056], - [5109284.52911273, 2006201.59095546, -3239366.63085706], - [5111608.06713389, 2004721.2262196, -3236602.97648213], - [5110840.88031587, 2003560.05835788, -3238544.12229424], - [5109666.45350777, 2004767.93425934, -3239646.10724868], - [5108767.23563213, 2007556.54497446, -3239354.53798391], - [5108927.44284297, 2007973.80069955, -3238840.15661171], - [5110746.29394702, 2007713.62376395, -3236109.83563026], - [5109561.42891041, 2009946.10154943, -3236606.07622565], - [5108335.37384839, 2010410.68719286, -3238271.56790951], - [5107206.7556267, 2009680.79691055, -3240512.45932645], - [5108231.34344288, 2006391.59690538, -3240926.75417832], - [5108666.77102205, 2005032.4814725, -3241081.69797118]]) + return np.array( + [ + [5109224.29038545, 2006790.35753831, -3239100.60907827], + [5109247.7157809, 2006736.96831224, -3239096.13639116], + [5109222.76106102, 2006688.94849795, -3239165.94167899], + [5109101.13948279, 2006650.38001812, -3239383.31891167], + [5109132.81491624, 2006798.06346825, -3239242.1849703], + [5109046.33257705, 2006823.98423929, -3239363.78875328], + [5109095.03238529, 2006898.89823927, -3239239.95261248], + [5109082.8918671, 2007045.24176653, -3239169.09131402], + [5109139.53289849, 2006992.25575245, -3239111.37956843], + [5109368.62360157, 2006509.64851116, -3239043.72292735], + [5109490.75061883, 2006708.38364351, -3238726.60664016], + [5109310.2977957, 2007017.0371345, -3238823.74833534], + [5109273.32322089, 2007083.40054198, -3238841.20407917], + [5109233.60247272, 2007298.47483172, -3238770.86653967], + [5109514.1862076, 2007536.98018719, -3238177.03761655], + [5109175.83425585, 2007164.6225741, -3238946.9157957], + [5109093.99046283, 2007162.9306937, -3239078.77530747], + [5108965.29396408, 2007106.07798817, -3239319.10626408], + [5108993.64502175, 2006679.78785901, -3239536.3704696], + [5109111.46526165, 2006445.98820889, -3239491.95574845], + [5109486.39986795, 2006225.48918911, -3239031.01140517], + [5109925.48993011, 2006111.83927162, -3238401.39137192], + [5110109.89167353, 2005177.90721032, -3238688.71487862], + [5110676.49309192, 2005793.15912039, -3237408.15958056], + [5109284.52911273, 2006201.59095546, -3239366.63085706], + [5111608.06713389, 2004721.2262196, -3236602.97648213], + [5110840.88031587, 2003560.05835788, -3238544.12229424], + [5109666.45350777, 2004767.93425934, -3239646.10724868], + [5108767.23563213, 2007556.54497446, -3239354.53798391], + [5108927.44284297, 2007973.80069955, -3238840.15661171], + [5110746.29394702, 2007713.62376395, -3236109.83563026], + [5109561.42891041, 2009946.10154943, -3236606.07622565], + [5108335.37384839, 2010410.68719286, -3238271.56790951], + [5107206.7556267, 2009680.79691055, -3240512.45932645], + [5108231.34344288, 2006391.59690538, -3240926.75417832], + [5108666.77102205, 2005032.4814725, -3241081.69797118], + ] + ) @pytest.fixture(scope="session") @@ -90,12 +151,12 @@ def phase_dir(): @pytest.fixture(scope="session") def chan_width(nchan): - return np.full(nchan, (2*.856e9 - .856e9) / nchan) + return np.full(nchan, (2 * 0.856e9 - 0.856e9) / nchan) @pytest.fixture(scope="session") def chan_freq(chan_width): - return .856e9 + np.cumsum(np.concatenate([[0], chan_width[1:]])) + return 0.856e9 + np.cumsum(np.concatenate([[0], chan_width[1:]])) @pytest.fixture(scope="session") @@ -119,16 +180,14 @@ def synthesized_uvw(ants, time, phase_dir, auto_correlations): of these new coordinates may be wrong, depending on whether data timesteps were heavily flagged. """ - pytest.importorskip('pyrap') + pytest.importorskip("pyrap") from pyrap.measures import measures from pyrap.quanta import quantity as q dm = measures() epoch = dm.epoch("UT1", q(time[0], "s")) - ref_dir = dm.direction("j2000", - q(phase_dir[0], "rad"), - q(phase_dir[1], "rad")) + ref_dir = dm.direction("j2000", q(phase_dir[0], "rad"), q(phase_dir[1], "rad")) ox, oy, oz = ants[0] obs = dm.position("ITRF", q(ox, "m"), q(oy, "m"), q(oz, "m")) @@ -137,8 +196,7 @@ def synthesized_uvw(ants, time, phase_dir, auto_correlations): dm.do_frame(ref_dir) dm.do_frame(epoch) - ant1, ant2 = np.triu_indices(ants.shape[0], - 0 if auto_correlations else 1) + ant1, ant2 = np.triu_indices(ants.shape[0], 0 if auto_correlations else 1) ant1 = ant1.astype(np.int32) ant2 = ant2.astype(np.int32) @@ -157,26 +215,23 @@ def synthesized_uvw(ants, time, phase_dir, auto_correlations): # Calculate antenna UVW positions for ai, (x, y, z) in enumerate(ants): - bl = dm.baseline("ITRF", - q([x, ox], "m"), - q([y, oy], "m"), - q([z, oz], "m")) + bl = dm.baseline("ITRF", q([x, ox], "m"), q([y, oy], "m"), q([z, oz], "m")) ant_uvw[ai] = dm.to_uvw(bl)["xyz"].get_value()[0:3] # Now calculate baseline UVW positions # noting that ant1 - ant2 is the CASA convention - base = ti*nbl - uvw[base:base + nbl, :] = ant_uvw[ant1] - ant_uvw[ant2] + base = ti * nbl + uvw[base : base + nbl, :] = ant_uvw[ant1] - ant_uvw[ant2] return ant1, ant2, uvw @pytest.mark.parametrize("decorrelation", [0.95]) @pytest.mark.parametrize("min_nchan", [1]) -def test_bda_mapper(time, synthesized_uvw, interval, - chan_freq, chan_width, - decorrelation, min_nchan): +def test_bda_mapper( + time, synthesized_uvw, interval, chan_freq, chan_width, decorrelation, min_nchan +): time = np.unique(time) ant1, ant2, uvw = synthesized_uvw @@ -191,13 +246,20 @@ def test_bda_mapper(time, synthesized_uvw, interval, max_uvw_dist = np.sqrt(np.sum(uvw**2, axis=1)).max() - row_meta = bda_mapper(time, interval, ant1, ant2, uvw, # noqa :F841 - chan_width, chan_freq, - max_uvw_dist, - flag_row=flag_row, - max_fov=3.0, - decorrelation=decorrelation, - min_nchan=min_nchan) + row_meta = bda_mapper( + time, + interval, + ant1, + ant2, + uvw, # noqa :F841 + chan_width, + chan_freq, + max_uvw_dist, + flag_row=flag_row, + max_fov=3.0, + decorrelation=decorrelation, + min_nchan=min_nchan, + ) offsets = np.unique(row_meta.map[np.arange(time.shape[0]), 0]) assert_array_equal(offsets, row_meta.offsets[:-1]) @@ -211,8 +273,7 @@ def test_bda_mapper(time, synthesized_uvw, interval, assert_array_equal(decorr_cw, row_meta.decorr_chan_width) -def test_bda_binner(time, interval, synthesized_uvw, - ref_freq, chan_freq, chan_width): +def test_bda_binner(time, interval, synthesized_uvw, ref_freq, chan_freq, chan_width): time = np.unique(time) ant1, ant2, uvw = synthesized_uvw diff --git a/africanus/averaging/tests/test_mapping.py b/africanus/averaging/tests/test_mapping.py index 5a8273e31..0c8dd3d88 100644 --- a/africanus/averaging/tests/test_mapping.py +++ b/africanus/averaging/tests/test_mapping.py @@ -6,8 +6,7 @@ import pytest from africanus.averaging.support import unique_time, unique_baselines -from africanus.averaging.time_and_channel_mapping import (row_mapper, - channel_mapper) +from africanus.averaging.time_and_channel_mapping import row_mapper, channel_mapper @pytest.fixture @@ -18,19 +17,23 @@ def time(): @pytest.fixture def interval(): data = np.asarray([1.9, 2.0, 2.1, 1.85, 1.95, 2.0, 2.05, 2.1, 2.05, 1.9]) - return data*0.1 + return data * 0.1 @pytest.fixture def ant1(): - return np.asarray([0, 0, 1, 0, 0, 1, 2, 0, 0, 1], # noqa - dtype=np.int32) + return np.asarray( + [0, 0, 1, 0, 0, 1, 2, 0, 0, 1], # noqa + dtype=np.int32, + ) @pytest.fixture def ant2(): - return np.asarray([1, 2, 2, 0, 1, 2, 3, 0, 1, 2], # noqa - dtype=np.int32) + return np.asarray( + [1, 2, 2, 0, 1, 2, 3, 0, 1, 2], # noqa + dtype=np.int32, + ) def flag_row_factory(nrows, flagged_rows): @@ -44,8 +47,7 @@ def flag_row_factory(nrows, flagged_rows): @pytest.mark.parametrize("time_bin_secs", [0.1, 0.2, 1, 2, 4]) @pytest.mark.parametrize("flagged_rows", [None, [0, 1], [2, 4], range(10)]) -def test_row_mapper(time, interval, ant1, ant2, - flagged_rows, time_bin_secs): +def test_row_mapper(time, interval, ant1, ant2, flagged_rows, time_bin_secs): utime, _, time_inv, _ = unique_time(time) ubl, _, bl_inv, _ = unique_baselines(ant1, ant2) mask = np.full((ubl.shape[0], utime.shape[0]), -1, dtype=np.int32) @@ -54,9 +56,9 @@ def test_row_mapper(time, interval, ant1, ant2, flag_row = flag_row_factory(time.size, flagged_rows) - ret = row_mapper(time, interval, ant1, ant2, - flag_row=flag_row, - time_bin_secs=time_bin_secs) + ret = row_mapper( + time, interval, ant1, ant2, flag_row=flag_row, time_bin_secs=time_bin_secs + ) # For TIME AND INTERVAL, flagged inputs can # contribute to unflagged outputs diff --git a/africanus/averaging/tests/test_splines.py b/africanus/averaging/tests/test_splines.py index 0572bdf8b..43659f087 100644 --- a/africanus/averaging/tests/test_splines.py +++ b/africanus/averaging/tests/test_splines.py @@ -4,8 +4,7 @@ from numpy.testing import assert_almost_equal import pytest -from africanus.averaging.splines import (fit_cubic_spline, - evaluate_spline) +from africanus.averaging.splines import fit_cubic_spline, evaluate_spline # Generate y,z coords from given x coords diff --git a/africanus/averaging/tests/test_support.py b/africanus/averaging/tests/test_support.py index 582874f6a..4dc5b1b36 100644 --- a/africanus/averaging/tests/test_support.py +++ b/africanus/averaging/tests/test_support.py @@ -5,7 +5,7 @@ from numpy.testing import assert_array_equal import pytest -from africanus.averaging.support import (unique_baselines, unique_time) +from africanus.averaging.support import unique_baselines, unique_time @pytest.fixture @@ -15,21 +15,27 @@ def time(): @pytest.fixture def ant1(): - return np.asarray([0, 0, 1, 0, 0, 1, 2, 0, 0, 1], # noqa - dtype=np.int32) + return np.asarray( + [0, 0, 1, 0, 0, 1, 2, 0, 0, 1], # noqa + dtype=np.int32, + ) @pytest.fixture def ant2(): - return np.asarray([1, 2, 2, 0, 1, 2, 3, 0, 1, 2], # noqa - dtype=np.int32) + return np.asarray( + [1, 2, 2, 0, 1, 2, 3, 0, 1, 2], # noqa + dtype=np.int32, + ) @pytest.fixture def vis(): def _vis(row, chan, fcorrs): - return (np.arange(row*chan*fcorrs, dtype=np.float32) + - np.arange(1, row*chan*fcorrs+1, dtype=np.float32)*1j) + return ( + np.arange(row * chan * fcorrs, dtype=np.float32) + + np.arange(1, row * chan * fcorrs + 1, dtype=np.float32) * 1j + ) return _vis diff --git a/africanus/averaging/tests/test_time_and_channel_averaging.py b/africanus/averaging/tests/test_time_and_channel_averaging.py index 7ff744d30..c8ebf53cb 100644 --- a/africanus/averaging/tests/test_time_and_channel_averaging.py +++ b/africanus/averaging/tests/test_time_and_channel_averaging.py @@ -7,8 +7,7 @@ from africanus.averaging.support import unique_time, unique_baselines from africanus.averaging.time_and_channel_avg import time_and_channel -from africanus.averaging.time_and_channel_mapping import (row_mapper, - channel_mapper) +from africanus.averaging.time_and_channel_mapping import row_mapper, channel_mapper nchan = 16 ncorr = 4 @@ -21,28 +20,30 @@ def time(): @pytest.fixture def ant1(): - return np.asarray([0, 0, 1, 0, 0, 1, 2, 0, 0, 1], - dtype=np.int32) + return np.asarray([0, 0, 1, 0, 0, 1, 2, 0, 0, 1], dtype=np.int32) @pytest.fixture def ant2(): - return np.asarray([1, 2, 2, 0, 1, 2, 3, 0, 1, 2], - dtype=np.int32) + return np.asarray([1, 2, 2, 0, 1, 2, 3, 0, 1, 2], dtype=np.int32) @pytest.fixture def uvw(): - return np.asarray([[1.0, 1.0, 1.0], - [2.0, 2.0, 2.0], - [3.0, 3.0, 3.0], - [4.0, 4.0, 4.0], - [5.0, 5.0, 5.0], - [6.0, 6.0, 6.0], - [7.0, 7.0, 7.0], - [8.0, 8.0, 8.0], - [9.0, 9.0, 9.0], - [10.0, 10.0, 10.0]]) + return np.asarray( + [ + [1.0, 1.0, 1.0], + [2.0, 2.0, 2.0], + [3.0, 3.0, 3.0], + [4.0, 4.0, 4.0], + [5.0, 5.0, 5.0], + [6.0, 6.0, 6.0], + [7.0, 7.0, 7.0], + [8.0, 8.0, 8.0], + [9.0, 9.0, 9.0], + [10.0, 10.0, 10.0], + ] + ) @pytest.fixture @@ -77,19 +78,21 @@ def sigma_spectrum(time): @pytest.fixture def frequency(): - return np.linspace(.856, 2*.856e9, nchan) + return np.linspace(0.856, 2 * 0.856e9, nchan) @pytest.fixture def chan_width(): - return np.full(nchan, .856e9/nchan) + return np.full(nchan, 0.856e9 / nchan) @pytest.fixture def vis(): def _vis(row, chan, fcorrs): - flat_vis = (np.arange(row*chan*fcorrs, dtype=np.float64) + - np.arange(1, row*chan*fcorrs+1, dtype=np.float64)*1j) + flat_vis = ( + np.arange(row * chan * fcorrs, dtype=np.float64) + + np.arange(1, row * chan * fcorrs + 1, dtype=np.float64) * 1j + ) return flat_vis.reshape(row, chan, fcorrs) @@ -104,8 +107,7 @@ def _flag(row, chan, fcorrs): return _flag -def _gen_testing_lookup(time, interval, ant1, ant2, flag_row, time_bin_secs, - row_meta): +def _gen_testing_lookup(time, interval, ant1, ant2, flag_row, time_bin_secs, row_meta): """ Generates the same lookup as row_mapper, but different. @@ -118,8 +120,7 @@ def _gen_testing_lookup(time, interval, ant1, ant2, flag_row, time_bin_secs, """ utime, _, time_inv, _ = unique_time(time) ubl, _, bl_inv, _ = unique_baselines(ant1, ant2) - bl_time_lookup = np.full((ubl.shape[0], utime.shape[0]), -1, - dtype=np.int32) + bl_time_lookup = np.full((ubl.shape[0], utime.shape[0]), -1, dtype=np.int32) # Create the row index row_idx = np.arange(time.size) @@ -191,9 +192,10 @@ def _gen_testing_lookup(time, interval, ant1, ant2, flag_row, time_bin_secs, nominal_bin_map.append(nominal_map) # Produce a (avg_time, bl, effective_rows, nominal_rows) tuple - time_bl_row_map.extend((time[nrows].mean(), (a1, a2), erows, nrows) - for erows, nrows - in zip(effective_bin_map, nominal_bin_map)) + time_bl_row_map.extend( + (time[nrows].mean(), (a1, a2), erows, nrows) + for erows, nrows in zip(effective_bin_map, nominal_bin_map) + ) # Sort lookup sorted on averaged times return sorted(time_bl_row_map, key=lambda tup: tup[0]) @@ -204,24 +206,41 @@ def _calc_sigma(sigma, weight, idx): weight = weight[idx] numerator = (sigma**2 * weight**2).sum(axis=0) - denominator = weight.sum(axis=0)**2 + denominator = weight.sum(axis=0) ** 2 denominator[denominator == 0.0] = 1.0 return np.sqrt(numerator / denominator) -@pytest.mark.parametrize("flagged_rows", [ - [], [8, 9], [4], [0, 1], -]) +@pytest.mark.parametrize( + "flagged_rows", + [ + [], + [8, 9], + [4], + [0, 1], + ], +) @pytest.mark.parametrize("time_bin_secs", [1, 2, 3, 4]) @pytest.mark.parametrize("chan_bin_size", [1, 3, 5]) -def test_averager(time, ant1, ant2, flagged_rows, - uvw, interval, weight, sigma, - frequency, chan_width, - vis, flag, - weight_spectrum, sigma_spectrum, - time_bin_secs, chan_bin_size): - +def test_averager( + time, + ant1, + ant2, + flagged_rows, + uvw, + interval, + weight, + sigma, + frequency, + chan_width, + vis, + flag, + weight_spectrum, + sigma_spectrum, + time_bin_secs, + chan_bin_size, +): time_centroid = time exposure = interval @@ -239,30 +258,38 @@ def test_averager(time, ant1, ant2, flagged_rows, row_meta = row_mapper(time, interval, ant1, ant2, flag_row, time_bin_secs) chan_map, chan_bins = channel_mapper(nchan, chan_bin_size) - time_bl_row_map = _gen_testing_lookup(time_centroid, exposure, ant1, ant2, - flag_row, time_bin_secs, - row_meta) + time_bl_row_map = _gen_testing_lookup( + time_centroid, exposure, ant1, ant2, flag_row, time_bin_secs, row_meta + ) # Effective and Nominal rows associated with each output row - eff_idx, nom_idx = zip(*[(nrows, erows) for _, _, nrows, erows - in time_bl_row_map]) + eff_idx, nom_idx = zip(*[(nrows, erows) for _, _, nrows, erows in time_bl_row_map]) eff_idx = [ei for ei in eff_idx if len(ei) > 0] # Check that the averaged times from the test and accelerated lookup match - assert_array_equal([t for t, _, _, _ in time_bl_row_map], - row_meta.time) - - avg = time_and_channel(time, interval, ant1, ant2, - flag_row=flag_row, - time_centroid=time, exposure=exposure, uvw=uvw, - weight=weight, sigma=sigma, - chan_freq=frequency, chan_width=chan_width, - visibilities=vis, flag=flag, - weight_spectrum=weight_spectrum, - sigma_spectrum=sigma_spectrum, - time_bin_secs=time_bin_secs, - chan_bin_size=chan_bin_size) + assert_array_equal([t for t, _, _, _ in time_bl_row_map], row_meta.time) + + avg = time_and_channel( + time, + interval, + ant1, + ant2, + flag_row=flag_row, + time_centroid=time, + exposure=exposure, + uvw=uvw, + weight=weight, + sigma=sigma, + chan_freq=frequency, + chan_width=chan_width, + visibilities=vis, + flag=flag, + weight_spectrum=weight_spectrum, + sigma_spectrum=sigma_spectrum, + time_bin_secs=time_bin_secs, + chan_bin_size=chan_bin_size, + ) # Take mean time, but first ant1 and ant2 expected_time_centroids = [time_centroid[i].mean(axis=0) for i in eff_idx] @@ -326,27 +353,41 @@ def test_averager(time, ant1, ant2, flagged_rows, exp_vis = exp_vis / exp_wts exp_sigma = np.sqrt(exp_sigma / (exp_wts**2)) - assert_array_almost_equal(exp_vis, - avg.visibilities[orow, ch, corr]) - assert_array_almost_equal(exp_wts, - avg.weight_spectrum[orow, ch, corr]) - assert_array_almost_equal(exp_sigma, - avg.sigma_spectrum[orow, ch, corr]) + assert_array_almost_equal(exp_vis, avg.visibilities[orow, ch, corr]) + assert_array_almost_equal(exp_wts, avg.weight_spectrum[orow, ch, corr]) + assert_array_almost_equal(exp_sigma, avg.sigma_spectrum[orow, ch, corr]) -@pytest.mark.parametrize("flagged_rows", [ - [], [8, 9], [4], [0, 1], -]) +@pytest.mark.parametrize( + "flagged_rows", + [ + [], + [8, 9], + [4], + [0, 1], + ], +) @pytest.mark.parametrize("time_bin_secs", [1, 2, 3, 4]) @pytest.mark.parametrize("chan_bin_size", [1, 3, 5]) -def test_dask_averager(time, ant1, ant2, flagged_rows, - uvw, interval, weight, sigma, - frequency, chan_width, - vis, flag, - weight_spectrum, sigma_spectrum, - time_bin_secs, chan_bin_size): - - da = pytest.importorskip('dask.array') +def test_dask_averager( + time, + ant1, + ant2, + flagged_rows, + uvw, + interval, + weight, + sigma, + frequency, + chan_width, + vis, + flag, + weight_spectrum, + sigma_spectrum, + time_bin_secs, + chan_bin_size, +): + da = pytest.importorskip("dask.array") from africanus.averaging.dask import time_and_channel as dask_avg @@ -368,15 +409,23 @@ def test_dask_averager(time, ant1, ant2, flagged_rows, flag[flag_row.astype(np.bool_), :, :] = 1 flag[~flag_row.astype(np.bool_), :, :] = 0 - np_avg = time_and_channel(time_centroid, exposure, ant1, ant2, - flag_row=flag_row, - chan_freq=frequency, chan_width=chan_width, - effective_bw=chan_width, resolution=chan_width, - visibilities=vis, flag=flag, - sigma_spectrum=sigma_spectrum, - weight_spectrum=weight_spectrum, - time_bin_secs=time_bin_secs, - chan_bin_size=chan_bin_size) + np_avg = time_and_channel( + time_centroid, + exposure, + ant1, + ant2, + flag_row=flag_row, + chan_freq=frequency, + chan_width=chan_width, + effective_bw=chan_width, + resolution=chan_width, + visibilities=vis, + flag=flag, + sigma_spectrum=sigma_spectrum, + weight_spectrum=weight_spectrum, + time_bin_secs=time_bin_secs, + chan_bin_size=chan_bin_size, + ) # Using chunks == shape, the dask version should match the numpy version da_time_centroid = da.from_array(time_centroid, chunks=rows) @@ -388,35 +437,51 @@ def test_dask_averager(time, ant1, ant2, flagged_rows, da_ant2 = da.from_array(ant2, chunks=rows) da_chan_freq = da.from_array(frequency, chunks=chans) da_chan_width = da.from_array(chan_width, chunks=chans) - da_weight_spectrum = da.from_array(weight_spectrum, - chunks=(rows, chans, corrs)) - da_sigma_spectrum = da.from_array(sigma_spectrum, - chunks=(rows, chans, corrs)) + da_weight_spectrum = da.from_array(weight_spectrum, chunks=(rows, chans, corrs)) + da_sigma_spectrum = da.from_array(sigma_spectrum, chunks=(rows, chans, corrs)) da_vis = da.from_array(vis, chunks=(rows, chans, corrs)) da_flag = da.from_array(flag, chunks=(rows, chans, corrs)) - avg = dask_avg(da_time_centroid, da_exposure, da_ant1, da_ant2, - flag_row=da_flag_row, - chan_freq=da_chan_freq, chan_width=da_chan_width, - effective_bw=da_chan_width, resolution=da_chan_width, - weight=da_weight, sigma=da_sigma, - visibilities=da_vis, flag=da_flag, - weight_spectrum=da_weight_spectrum, - sigma_spectrum=da_sigma_spectrum, - time_bin_secs=time_bin_secs, - chan_bin_size=chan_bin_size) + avg = dask_avg( + da_time_centroid, + da_exposure, + da_ant1, + da_ant2, + flag_row=da_flag_row, + chan_freq=da_chan_freq, + chan_width=da_chan_width, + effective_bw=da_chan_width, + resolution=da_chan_width, + weight=da_weight, + sigma=da_sigma, + visibilities=da_vis, + flag=da_flag, + weight_spectrum=da_weight_spectrum, + sigma_spectrum=da_sigma_spectrum, + time_bin_secs=time_bin_secs, + chan_bin_size=chan_bin_size, + ) # Compute all the averages in one go - (avg_time_centroid, avg_exposure, avg_flag_row, - avg_chan_freq, avg_chan_width, - avg_resolution, avg_vis, avg_flag) = da.compute( - avg.time_centroid, - avg.exposure, - avg.flag_row, - avg.chan_freq, - avg.chan_width, - avg.resolution, - avg.visibilities, avg.flag) + ( + avg_time_centroid, + avg_exposure, + avg_flag_row, + avg_chan_freq, + avg_chan_width, + avg_resolution, + avg_vis, + avg_flag, + ) = da.compute( + avg.time_centroid, + avg.exposure, + avg.flag_row, + avg.chan_freq, + avg.chan_width, + avg.resolution, + avg.visibilities, + avg.flag, + ) # Should match assert_array_equal(np_avg.time_centroid, avg_time_centroid) @@ -441,24 +506,38 @@ def test_dask_averager(time, ant1, ant2, flagged_rows, da_vis = da.from_array(vis, chunks=(rc, fc, cc)) da_flag = da.from_array(flag, chunks=(rc, fc, cc)) - avg = dask_avg(da_time_centroid, da_exposure, da_ant1, da_ant2, - flag_row=da_flag_row, - chan_freq=da_chan_freq, chan_width=da_chan_width, - visibilities=da_vis, flag=da_flag, - time_bin_secs=time_bin_secs, - chan_bin_size=chan_bin_size) + avg = dask_avg( + da_time_centroid, + da_exposure, + da_ant1, + da_ant2, + flag_row=da_flag_row, + chan_freq=da_chan_freq, + chan_width=da_chan_width, + visibilities=da_vis, + flag=da_flag, + time_bin_secs=time_bin_secs, + chan_bin_size=chan_bin_size, + ) # Compute all the fields fields = [getattr(avg, f) for f in avg._fields] avg = type(avg)(*da.compute(fields)[0]) # Get same result with a visibility tuple - avg2 = dask_avg(da_time_centroid, da_exposure, da_ant1, da_ant2, - flag_row=da_flag_row, - chan_freq=da_chan_freq, chan_width=da_chan_width, - visibilities=(da_vis, da_vis), flag=da_flag, - time_bin_secs=time_bin_secs, - chan_bin_size=chan_bin_size) + avg2 = dask_avg( + da_time_centroid, + da_exposure, + da_ant1, + da_ant2, + flag_row=da_flag_row, + chan_freq=da_chan_freq, + chan_width=da_chan_width, + visibilities=(da_vis, da_vis), + flag=da_flag, + time_bin_secs=time_bin_secs, + chan_bin_size=chan_bin_size, + ) assert_array_equal(avg.visibilities, avg2.visibilities[0]) assert_array_equal(avg.visibilities, avg2.visibilities[1]) diff --git a/africanus/averaging/time_and_channel_avg.py b/africanus/averaging/time_and_channel_avg.py index 208fd07e0..ea88a0179 100644 --- a/africanus/averaging/time_and_channel_avg.py +++ b/africanus/averaging/time_and_channel_avg.py @@ -6,15 +6,17 @@ from numba import types import numpy as np -from africanus.averaging.time_and_channel_mapping import (row_mapper, - channel_mapper) -from africanus.averaging.shared import (chan_corrs, - merge_flags, - vis_output_arrays) +from africanus.averaging.time_and_channel_mapping import row_mapper, channel_mapper +from africanus.averaging.shared import chan_corrs, merge_flags, vis_output_arrays from africanus.util.docs import DocstringTemplate -from africanus.util.numba import (is_numba_type_none, JIT_OPTIONS, - njit, overload, intrinsic) +from africanus.util.numba import ( + is_numba_type_none, + JIT_OPTIONS, + njit, + overload, + intrinsic, +) TUPLE_TYPE = 0 ARRAY_TYPE = 1 @@ -23,69 +25,118 @@ def matching_flag_factory(present): if present: + def impl(flag_row, ri, out_flag_row, ro): return flag_row[ri] == out_flag_row[ro] else: + def impl(flag_row, ri, out_flag_row, ro): return True - return njit(nogil=True, cache=True, inline='always')(impl) + return njit(nogil=True, cache=True, inline="always")(impl) def is_chan_flagged(flag, r, f, c): pass -@overload(is_chan_flagged, inline='always') +@overload(is_chan_flagged, inline="always") def _is_chan_flagged(flag, r, f, c): if is_numba_type_none(flag): + def impl(flag, r, f, c): return True else: + def impl(flag, r, f, c): return flag[r, f, c] return impl -@njit(nogil=True, inline='always') +@njit(nogil=True, inline="always") def chan_add(output, input, orow, ochan, irow, ichan, corr): if input is not None: output[orow, ochan, corr] += input[irow, ichan, corr] -_row_output_fields = ["antenna1", "antenna2", "time_centroid", "exposure", - "uvw", "weight", "sigma"] +_row_output_fields = [ + "antenna1", + "antenna2", + "time_centroid", + "exposure", + "uvw", + "weight", + "sigma", +] RowAverageOutput = namedtuple("RowAverageOutput", _row_output_fields) @njit(**JIT_OPTIONS) -def row_average(meta, ant1, ant2, flag_row=None, - time_centroid=None, exposure=None, uvw=None, - weight=None, sigma=None): - return row_average_impl(meta, ant1, ant2, flag_row=flag_row, - time_centroid=time_centroid, exposure=exposure, - uvw=uvw, weight=weight, sigma=sigma) - - -def row_average_impl(meta, ant1, ant2, flag_row=None, - time_centroid=None, exposure=None, uvw=None, - weight=None, sigma=None): +def row_average( + meta, + ant1, + ant2, + flag_row=None, + time_centroid=None, + exposure=None, + uvw=None, + weight=None, + sigma=None, +): + return row_average_impl( + meta, + ant1, + ant2, + flag_row=flag_row, + time_centroid=time_centroid, + exposure=exposure, + uvw=uvw, + weight=weight, + sigma=sigma, + ) + + +def row_average_impl( + meta, + ant1, + ant2, + flag_row=None, + time_centroid=None, + exposure=None, + uvw=None, + weight=None, + sigma=None, +): return NotImplementedError @overload(row_average_impl, jit_options=JIT_OPTIONS) -def nb_row_average(meta, ant1, ant2, flag_row=None, - time_centroid=None, exposure=None, uvw=None, - weight=None, sigma=None): - +def nb_row_average( + meta, + ant1, + ant2, + flag_row=None, + time_centroid=None, + exposure=None, + uvw=None, + weight=None, + sigma=None, +): have_flag_row = not is_numba_type_none(flag_row) flags_match = matching_flag_factory(have_flag_row) - def impl(meta, ant1, ant2, flag_row=None, - time_centroid=None, exposure=None, uvw=None, - weight=None, sigma=None): - + def impl( + meta, + ant1, + ant2, + flag_row=None, + time_centroid=None, + exposure=None, + uvw=None, + weight=None, + sigma=None, + ): out_rows = meta.time.shape[0] counts = np.zeros(out_rows, dtype=np.uint32) @@ -96,34 +147,42 @@ def impl(meta, ant1, ant2, flag_row=None, # Possibly present outputs for possibly present inputs uvw_avg = ( - None if uvw is None else - np.zeros((out_rows,) + uvw.shape[1:], - dtype=uvw.dtype)) + None + if uvw is None + else np.zeros((out_rows,) + uvw.shape[1:], dtype=uvw.dtype) + ) time_centroid_avg = ( - None if time_centroid is None else - np.zeros((out_rows,) + time_centroid.shape[1:], - dtype=time_centroid.dtype)) + None + if time_centroid is None + else np.zeros( + (out_rows,) + time_centroid.shape[1:], dtype=time_centroid.dtype + ) + ) exposure_avg = ( - None if exposure is None else - np.zeros((out_rows,) + exposure.shape[1:], - dtype=exposure.dtype)) + None + if exposure is None + else np.zeros((out_rows,) + exposure.shape[1:], dtype=exposure.dtype) + ) weight_avg = ( - None if weight is None else - np.zeros((out_rows,) + weight.shape[1:], - dtype=weight.dtype)) + None + if weight is None + else np.zeros((out_rows,) + weight.shape[1:], dtype=weight.dtype) + ) sigma_avg = ( - None if sigma is None else - np.zeros((out_rows,) + sigma.shape[1:], - dtype=sigma.dtype)) + None + if sigma is None + else np.zeros((out_rows,) + sigma.shape[1:], dtype=sigma.dtype) + ) sigma_weight_sum = ( - None if sigma is None else - np.zeros((out_rows,) + sigma.shape[1:], - dtype=sigma.dtype)) + None + if sigma is None + else np.zeros((out_rows,) + sigma.shape[1:], dtype=sigma.dtype) + ) # Iterate over input rows, accumulating into output rows for in_row, out_row in enumerate(meta.map): @@ -147,12 +206,12 @@ def impl(meta, ant1, ant2, flag_row=None, if sigma is not None: for co in range(sigma.shape[1]): - sva = sigma[in_row, co]**2 + sva = sigma[in_row, co] ** 2 # Use provided weights if weight is not None: wt = weight[in_row, co] - sva *= wt ** 2 + sva *= wt**2 sigma_weight_sum[out_row, co] += wt # Natural weights else: @@ -190,25 +249,25 @@ def impl(meta, ant1, ant2, flag_row=None, wt = sigma_weight_sum[out_row, co] if wt != 0.0: - ssva /= (wt**2) + ssva /= wt**2 sigma_avg[out_row, co] = np.sqrt(ssva) - return RowAverageOutput(ant1_avg, ant2_avg, - time_centroid_avg, - exposure_avg, uvw_avg, - weight_avg, sigma_avg) + return RowAverageOutput( + ant1_avg, + ant2_avg, + time_centroid_avg, + exposure_avg, + uvw_avg, + weight_avg, + sigma_avg, + ) return impl -_rowchan_output_fields = [ - "visibilities", - "flag", - "weight_spectrum", - "sigma_spectrum"] -RowChanAverageOutput = namedtuple("RowChanAverageOutput", - _rowchan_output_fields) +_rowchan_output_fields = ["visibilities", "flag", "weight_spectrum", "sigma_spectrum"] +RowChanAverageOutput = namedtuple("RowChanAverageOutput", _rowchan_output_fields) class RowChannelAverageException(Exception): @@ -216,9 +275,9 @@ class RowChannelAverageException(Exception): @intrinsic -def average_visibilities(typingctx, vis, vis_avg, vis_weight_sum, - weight, ri, fi, ro, fo, co): - +def average_visibilities( + typingctx, vis, vis_avg, vis_weight_sum, weight, ri, fi, ro, fo, co +): import numba.core.types as nbtypes have_array = isinstance(vis, nbtypes.Array) @@ -230,8 +289,7 @@ def avg_fn(vis, vis_avg, vis_ws, wt, ri, fi, ro, fo, co): return_type = nbtypes.NoneType("none") - sig = return_type(vis, vis_avg, vis_weight_sum, - weight, ri, fi, ro, fo, co) + sig = return_type(vis, vis_avg, vis_weight_sum, weight, ri, fi, ro, fo, co) def codegen(context, builder, signature, args): vis, vis_type = args[0], signature.args[0] @@ -246,34 +304,48 @@ def codegen(context, builder, signature, args): return_type = signature.return_type if have_array: - avg_sig = return_type(vis_type, - vis_avg_type, - vis_weight_sum_type, - weight_type, - ri_type, fi_type, - ro_type, fo_type, co_type) - avg_args = [vis, vis_avg, vis_weight_sum, - weight, ri, fi, ro, fo, co] + avg_sig = return_type( + vis_type, + vis_avg_type, + vis_weight_sum_type, + weight_type, + ri_type, + fi_type, + ro_type, + fo_type, + co_type, + ) + avg_args = [vis, vis_avg, vis_weight_sum, weight, ri, fi, ro, fo, co] # Compile function and get handle to output - context.compile_internal(builder, avg_fn, - avg_sig, avg_args) + context.compile_internal(builder, avg_fn, avg_sig, avg_args) elif have_tuple: for i in range(len(vis_type)): - avg_sig = return_type(vis_type.types[i], - vis_avg_type.types[i], - vis_weight_sum_type.types[i], - weight_type, - ri_type, fi_type, - ro_type, fo_type, co_type) - avg_args = [builder.extract_value(vis, i), - builder.extract_value(vis_avg, i), - builder.extract_value(vis_weight_sum, i), - weight, ri, fi, ro, fo, co] + avg_sig = return_type( + vis_type.types[i], + vis_avg_type.types[i], + vis_weight_sum_type.types[i], + weight_type, + ri_type, + fi_type, + ro_type, + fo_type, + co_type, + ) + avg_args = [ + builder.extract_value(vis, i), + builder.extract_value(vis_avg, i), + builder.extract_value(vis_weight_sum, i), + weight, + ri, + fi, + ro, + fo, + co, + ] # Compile function and get handle to output - context.compile_internal(builder, avg_fn, - avg_sig, avg_args) + context.compile_internal(builder, avg_fn, avg_sig, avg_args) else: raise TypeError("Unhandled visibility array type") @@ -306,26 +378,32 @@ def codegen(context, builder, signature, args): if have_array: # Normalise single array - norm_sig = return_type(vis_avg_type, - vis_weight_sum_type, - ro_type, fo_type, co_type) + norm_sig = return_type( + vis_avg_type, vis_weight_sum_type, ro_type, fo_type, co_type + ) norm_args = [vis_avg, vis_weight_sum, ro, fo, co] - context.compile_internal(builder, normalise_fn, - norm_sig, norm_args) + context.compile_internal(builder, normalise_fn, norm_sig, norm_args) elif have_tuple: # Normalise each array in the tuple for i in range(len(vis_avg_type)): - norm_sig = return_type(vis_avg_type.types[i], - vis_weight_sum_type.types[i], - ro_type, fo_type, co_type) - norm_args = [builder.extract_value(vis_avg, i), - builder.extract_value(vis_weight_sum, i), - ro, fo, co] + norm_sig = return_type( + vis_avg_type.types[i], + vis_weight_sum_type.types[i], + ro_type, + fo_type, + co_type, + ) + norm_args = [ + builder.extract_value(vis_avg, i), + builder.extract_value(vis_weight_sum, i), + ro, + fo, + co, + ] # Compile function and get handle to output - context.compile_internal(builder, normalise_fn, - norm_sig, norm_args) + context.compile_internal(builder, normalise_fn, norm_sig, norm_args) else: raise TypeError("Unhandled visibility array type") @@ -333,30 +411,52 @@ def codegen(context, builder, signature, args): @njit(**JIT_OPTIONS) -def row_chan_average(row_meta, chan_meta, - flag_row=None, weight=None, - visibilities=None, flag=None, - weight_spectrum=None, sigma_spectrum=None): - return row_chan_average_impl(row_meta, chan_meta, - flag_row=flag_row, weight=weight, - visibilities=visibilities, flag=flag, - weight_spectrum=weight_spectrum, - sigma_spectrum=sigma_spectrum) - - -def row_chan_average_impl(row_meta, chan_meta, - flag_row=None, weight=None, - visibilities=None, flag=None, - weight_spectrum=None, sigma_spectrum=None): +def row_chan_average( + row_meta, + chan_meta, + flag_row=None, + weight=None, + visibilities=None, + flag=None, + weight_spectrum=None, + sigma_spectrum=None, +): + return row_chan_average_impl( + row_meta, + chan_meta, + flag_row=flag_row, + weight=weight, + visibilities=visibilities, + flag=flag, + weight_spectrum=weight_spectrum, + sigma_spectrum=sigma_spectrum, + ) + + +def row_chan_average_impl( + row_meta, + chan_meta, + flag_row=None, + weight=None, + visibilities=None, + flag=None, + weight_spectrum=None, + sigma_spectrum=None, +): return NotImplementedError @overload(row_chan_average_impl, jit_options=JIT_OPTIONS) -def nb_row_chan_average(row_meta, chan_meta, - flag_row=None, weight=None, - visibilities=None, flag=None, - weight_spectrum=None, sigma_spectrum=None): - +def nb_row_chan_average( + row_meta, + chan_meta, + flag_row=None, + weight=None, + visibilities=None, + flag=None, + weight_spectrum=None, + sigma_spectrum=None, +): dummy_chan_freq = None dummy_chan_width = None @@ -369,15 +469,27 @@ def nb_row_chan_average(row_meta, chan_meta, have_weight_spectrum = not is_numba_type_none(weight_spectrum) have_sigma_spectrum = not is_numba_type_none(sigma_spectrum) - def impl(row_meta, chan_meta, flag_row=None, weight=None, - visibilities=None, flag=None, - weight_spectrum=None, sigma_spectrum=None): - + def impl( + row_meta, + chan_meta, + flag_row=None, + weight=None, + visibilities=None, + flag=None, + weight_spectrum=None, + sigma_spectrum=None, + ): out_rows = row_meta.time.shape[0] - nchan, ncorrs = chan_corrs(visibilities, flag, - weight_spectrum, sigma_spectrum, - dummy_chan_freq, dummy_chan_width, - dummy_chan_width, dummy_chan_width) + nchan, ncorrs = chan_corrs( + visibilities, + flag, + weight_spectrum, + sigma_spectrum, + dummy_chan_freq, + dummy_chan_width, + dummy_chan_width, + dummy_chan_width, + ) chan_map, out_chans = chan_meta @@ -405,8 +517,7 @@ def impl(row_meta, chan_meta, flag_row=None, weight=None, row_flagged = have_flag_row and flag_row[ri] != 0 for fi, fo in enumerate(chan_map): for co in range(ncorrs): - flagged = (row_flagged or - (have_flag and flag[ri, fi, co] != 0)) + flagged = row_flagged or (have_flag and flag[ri, fi, co] != 0) if have_flags and flagged: flag_counts[ro, fo, co] += 1 @@ -456,8 +567,7 @@ def impl(row_meta, chan_meta, flag_row=None, weight=None, # unflagged samples never contribute to a # completely flagged bin if have_flags: - in_flag = (row_flagged or - (have_flag and flag[ri, fi, co] != 0)) + in_flag = row_flagged or (have_flag and flag[ri, fi, co] != 0) flags_match[ri, fi, co] = in_flag == out_flag # ------------- @@ -466,8 +576,7 @@ def impl(row_meta, chan_meta, flag_row=None, weight=None, if not have_vis: vis_avg = None else: - vis_avg, vis_weight_sum = vis_output_arrays( - visibilities, out_shape) + vis_avg, vis_weight_sum = vis_output_arrays(visibilities, out_shape) # # Aggregate for ri, ro in enumerate(row_meta.map): @@ -476,21 +585,31 @@ def impl(row_meta, chan_meta, flag_row=None, weight=None, if have_flags and not flags_match[ri, fi, co]: continue - wt = (weight_spectrum[ri, fi, co] - if have_weight_spectrum else - weight[ri, co] if have_weight else 1.0) - - average_visibilities(visibilities, - vis_avg, - vis_weight_sum, - wt, ri, fi, ro, fo, co) + wt = ( + weight_spectrum[ri, fi, co] + if have_weight_spectrum + else weight[ri, co] + if have_weight + else 1.0 + ) + + average_visibilities( + visibilities, + vis_avg, + vis_weight_sum, + wt, + ri, + fi, + ro, + fo, + co, + ) # Normalise for ro in range(out_rows): for fo in range(out_chans): for co in range(ncorrs): - normalise_visibilities( - vis_avg, vis_weight_sum, ro, fo, co) + normalise_visibilities(vis_avg, vis_weight_sum, ro, fo, co) # ---------------- # Weight Spectrum @@ -508,8 +627,7 @@ def impl(row_meta, chan_meta, flag_row=None, weight=None, if have_flags and not flags_match[ri, fi, co]: continue - weight_spectrum_avg[ro, fo, co] += ( - weight_spectrum[ri, fi, co]) + weight_spectrum_avg[ro, fo, co] += weight_spectrum[ri, fi, co] # --------------- # Sigma Spectrum @@ -527,11 +645,15 @@ def impl(row_meta, chan_meta, flag_row=None, weight=None, if have_flags and not flags_match[ri, fi, co]: continue - wt = (weight_spectrum[ri, fi, co] - if have_weight_spectrum else - weight[ri, co] if have_weight else 1.0) + wt = ( + weight_spectrum[ri, fi, co] + if have_weight_spectrum + else weight[ri, co] + if have_weight + else 1.0 + ) - ssv = sigma_spectrum[ri, fi, co]**2 * wt**2 + ssv = sigma_spectrum[ri, fi, co] ** 2 * wt**2 sigma_spectrum_avg[ro, fo, co] += ssv sigma_spectrum_weight_sum[ro, fo, co] += wt @@ -542,12 +664,11 @@ def impl(row_meta, chan_meta, flag_row=None, weight=None, sswsum = sigma_spectrum_weight_sum[ro, fo, co] if sswsum != 0.0: ssv = sigma_spectrum_avg[ro, fo, co] - sigma_spectrum_avg[ro, fo, co] = np.sqrt( - ssv / sswsum**2) + sigma_spectrum_avg[ro, fo, co] = np.sqrt(ssv / sswsum**2) - return RowChanAverageOutput(vis_avg, flag_avg, - weight_spectrum_avg, - sigma_spectrum_avg) + return RowChanAverageOutput( + vis_avg, flag_avg, weight_spectrum_avg, sigma_spectrum_avg + ) return impl @@ -557,43 +678,50 @@ def impl(row_meta, chan_meta, flag_row=None, weight=None, @njit(**JIT_OPTIONS) -def chan_average(chan_meta, chan_freq=None, chan_width=None, - effective_bw=None, resolution=None): - return chan_average_impl(chan_meta, chan_freq=chan_freq, - chan_width=chan_width, - effective_bw=effective_bw, - resolution=resolution) - - -def chan_average_impl(chan_meta, chan_freq=None, chan_width=None, - effective_bw=None, resolution=None): - +def chan_average( + chan_meta, chan_freq=None, chan_width=None, effective_bw=None, resolution=None +): + return chan_average_impl( + chan_meta, + chan_freq=chan_freq, + chan_width=chan_width, + effective_bw=effective_bw, + resolution=resolution, + ) + + +def chan_average_impl( + chan_meta, chan_freq=None, chan_width=None, effective_bw=None, resolution=None +): return NotImplementedError @overload(chan_average_impl, jit_options=JIT_OPTIONS) -def nb_chan_average(chan_meta, chan_freq=None, chan_width=None, - effective_bw=None, resolution=None): - - def impl(chan_meta, chan_freq=None, chan_width=None, - effective_bw=None, resolution=None): +def nb_chan_average( + chan_meta, chan_freq=None, chan_width=None, effective_bw=None, resolution=None +): + def impl( + chan_meta, chan_freq=None, chan_width=None, effective_bw=None, resolution=None + ): chan_map, out_chans = chan_meta chan_freq_avg = ( - None if chan_freq is None else - np.zeros(out_chans, dtype=chan_freq.dtype)) + None if chan_freq is None else np.zeros(out_chans, dtype=chan_freq.dtype) + ) chan_width_avg = ( - None if chan_width is None else - np.zeros(out_chans, dtype=chan_width.dtype)) + None if chan_width is None else np.zeros(out_chans, dtype=chan_width.dtype) + ) effective_bw_avg = ( - None if effective_bw is None else - np.zeros(out_chans, dtype=effective_bw.dtype)) + None + if effective_bw is None + else np.zeros(out_chans, dtype=effective_bw.dtype) + ) resolution_avg = ( - None if resolution is None else - np.zeros(out_chans, dtype=resolution.dtype)) + None if resolution is None else np.zeros(out_chans, dtype=resolution.dtype) + ) counts = np.zeros(out_chans, dtype=np.uint32) @@ -616,144 +744,238 @@ def impl(chan_meta, chan_freq=None, chan_width=None, if chan_freq is not None: chan_freq_avg[out_chan] /= counts[out_chan] - return ChannelAverageOutput(chan_freq_avg, chan_width_avg, - effective_bw_avg, resolution_avg) + return ChannelAverageOutput( + chan_freq_avg, chan_width_avg, effective_bw_avg, resolution_avg + ) return impl -AverageOutput = namedtuple("AverageOutput", - ["time", "interval", "flag_row"] + - _row_output_fields + - _chan_output_fields + - _rowchan_output_fields) +AverageOutput = namedtuple( + "AverageOutput", + ["time", "interval", "flag_row"] + + _row_output_fields + + _chan_output_fields + + _rowchan_output_fields, +) @njit(**JIT_OPTIONS) -def time_and_channel(time, interval, antenna1, antenna2, - time_centroid=None, exposure=None, flag_row=None, - uvw=None, weight=None, sigma=None, - chan_freq=None, chan_width=None, - effective_bw=None, resolution=None, - visibilities=None, flag=None, - weight_spectrum=None, sigma_spectrum=None, - time_bin_secs=1.0, chan_bin_size=1): - return time_and_channel_impl(time, interval, antenna1, antenna2, - time_centroid=time_centroid, - exposure=exposure, - flag_row=flag_row, - uvw=uvw, weight=weight, sigma=sigma, - chan_freq=chan_freq, chan_width=chan_width, - effective_bw=effective_bw, - resolution=resolution, - visibilities=visibilities, flag=flag, - weight_spectrum=weight_spectrum, - sigma_spectrum=sigma_spectrum, - time_bin_secs=time_bin_secs, - chan_bin_size=chan_bin_size) - - -def time_and_channel_impl(time, interval, antenna1, antenna2, - time_centroid=None, exposure=None, flag_row=None, - uvw=None, weight=None, sigma=None, - chan_freq=None, chan_width=None, - effective_bw=None, resolution=None, - visibilities=None, flag=None, - weight_spectrum=None, sigma_spectrum=None, - time_bin_secs=1.0, chan_bin_size=1): +def time_and_channel( + time, + interval, + antenna1, + antenna2, + time_centroid=None, + exposure=None, + flag_row=None, + uvw=None, + weight=None, + sigma=None, + chan_freq=None, + chan_width=None, + effective_bw=None, + resolution=None, + visibilities=None, + flag=None, + weight_spectrum=None, + sigma_spectrum=None, + time_bin_secs=1.0, + chan_bin_size=1, +): + return time_and_channel_impl( + time, + interval, + antenna1, + antenna2, + time_centroid=time_centroid, + exposure=exposure, + flag_row=flag_row, + uvw=uvw, + weight=weight, + sigma=sigma, + chan_freq=chan_freq, + chan_width=chan_width, + effective_bw=effective_bw, + resolution=resolution, + visibilities=visibilities, + flag=flag, + weight_spectrum=weight_spectrum, + sigma_spectrum=sigma_spectrum, + time_bin_secs=time_bin_secs, + chan_bin_size=chan_bin_size, + ) + + +def time_and_channel_impl( + time, + interval, + antenna1, + antenna2, + time_centroid=None, + exposure=None, + flag_row=None, + uvw=None, + weight=None, + sigma=None, + chan_freq=None, + chan_width=None, + effective_bw=None, + resolution=None, + visibilities=None, + flag=None, + weight_spectrum=None, + sigma_spectrum=None, + time_bin_secs=1.0, + chan_bin_size=1, +): return NotImplementedError @overload(time_and_channel_impl, jit_options=JIT_OPTIONS) -def nb_time_and_channel(time, interval, antenna1, antenna2, - time_centroid=None, exposure=None, flag_row=None, - uvw=None, weight=None, sigma=None, - chan_freq=None, chan_width=None, - effective_bw=None, resolution=None, - visibilities=None, flag=None, - weight_spectrum=None, sigma_spectrum=None, - time_bin_secs=1.0, chan_bin_size=1): - - valid_types = (types.misc.Omitted, types.scalars.Float, - types.scalars.Integer) +def nb_time_and_channel( + time, + interval, + antenna1, + antenna2, + time_centroid=None, + exposure=None, + flag_row=None, + uvw=None, + weight=None, + sigma=None, + chan_freq=None, + chan_width=None, + effective_bw=None, + resolution=None, + visibilities=None, + flag=None, + weight_spectrum=None, + sigma_spectrum=None, + time_bin_secs=1.0, + chan_bin_size=1, +): + valid_types = (types.misc.Omitted, types.scalars.Float, types.scalars.Integer) if not isinstance(time_bin_secs, valid_types): - raise TypeError( - f"time_bin_secs ({time_bin_secs}) must be a scalar float") + raise TypeError(f"time_bin_secs ({time_bin_secs}) must be a scalar float") valid_types = (types.misc.Omitted, types.scalars.Integer) if not isinstance(chan_bin_size, valid_types): - raise TypeError( - f"chan_bin_size ({chan_bin_size}) must be a scalar integer") - - def impl(time, interval, antenna1, antenna2, - time_centroid=None, exposure=None, flag_row=None, - uvw=None, weight=None, sigma=None, - chan_freq=None, chan_width=None, - effective_bw=None, resolution=None, - visibilities=None, flag=None, - weight_spectrum=None, sigma_spectrum=None, - time_bin_secs=1.0, chan_bin_size=1): - - nchan, ncorrs = chan_corrs(visibilities, flag, - weight_spectrum, sigma_spectrum, - chan_freq, chan_width, - effective_bw, resolution) + raise TypeError(f"chan_bin_size ({chan_bin_size}) must be a scalar integer") + + def impl( + time, + interval, + antenna1, + antenna2, + time_centroid=None, + exposure=None, + flag_row=None, + uvw=None, + weight=None, + sigma=None, + chan_freq=None, + chan_width=None, + effective_bw=None, + resolution=None, + visibilities=None, + flag=None, + weight_spectrum=None, + sigma_spectrum=None, + time_bin_secs=1.0, + chan_bin_size=1, + ): + nchan, ncorrs = chan_corrs( + visibilities, + flag, + weight_spectrum, + sigma_spectrum, + chan_freq, + chan_width, + effective_bw, + resolution, + ) # Merge flag_row and flag arrays flag_row = merge_flags(flag_row, flag) # Generate row mapping metadata - row_meta = row_mapper(time, interval, antenna1, antenna2, - flag_row=flag_row, time_bin_secs=time_bin_secs) + row_meta = row_mapper( + time, + interval, + antenna1, + antenna2, + flag_row=flag_row, + time_bin_secs=time_bin_secs, + ) # Generate channel mapping metadata chan_meta = channel_mapper(nchan, chan_bin_size) # Average row data - row_data = row_average(row_meta, antenna1, antenna2, flag_row=flag_row, - time_centroid=time_centroid, exposure=exposure, - uvw=uvw, weight=weight, sigma=sigma) + row_data = row_average( + row_meta, + antenna1, + antenna2, + flag_row=flag_row, + time_centroid=time_centroid, + exposure=exposure, + uvw=uvw, + weight=weight, + sigma=sigma, + ) # Average channel data - chan_data = chan_average(chan_meta, chan_freq=chan_freq, - chan_width=chan_width, - effective_bw=effective_bw, - resolution=resolution) + chan_data = chan_average( + chan_meta, + chan_freq=chan_freq, + chan_width=chan_width, + effective_bw=effective_bw, + resolution=resolution, + ) # Average row and channel data - row_chan_data = row_chan_average(row_meta, chan_meta, - flag_row=flag_row, weight=weight, - visibilities=visibilities, flag=flag, - weight_spectrum=weight_spectrum, - sigma_spectrum=sigma_spectrum) + row_chan_data = row_chan_average( + row_meta, + chan_meta, + flag_row=flag_row, + weight=weight, + visibilities=visibilities, + flag=flag, + weight_spectrum=weight_spectrum, + sigma_spectrum=sigma_spectrum, + ) # Have to explicitly write it out because numba tuples # are highly constrained types - return AverageOutput(row_meta.time, - row_meta.interval, - row_meta.flag_row, - row_data.antenna1, - row_data.antenna2, - row_data.time_centroid, - row_data.exposure, - row_data.uvw, - row_data.weight, - row_data.sigma, - chan_data.chan_freq, - chan_data.chan_width, - chan_data.effective_bw, - chan_data.resolution, - row_chan_data.visibilities, - row_chan_data.flag, - row_chan_data.weight_spectrum, - row_chan_data.sigma_spectrum) + return AverageOutput( + row_meta.time, + row_meta.interval, + row_meta.flag_row, + row_data.antenna1, + row_data.antenna2, + row_data.time_centroid, + row_data.exposure, + row_data.uvw, + row_data.weight, + row_data.sigma, + chan_data.chan_freq, + chan_data.chan_width, + chan_data.effective_bw, + chan_data.resolution, + row_chan_data.visibilities, + row_chan_data.flag, + row_chan_data.weight_spectrum, + row_chan_data.sigma_spectrum, + ) return impl -AVERAGING_DOCS = DocstringTemplate(""" +AVERAGING_DOCS = DocstringTemplate( + """ Averages in time and channel. Parameters @@ -816,11 +1038,13 @@ def impl(time, interval, antenna1, antenna2, namedtuple A namedtuple whose entries correspond to the input arrays. Output arrays will be ``None`` if the inputs were ``None``. -""") +""" +) try: time_and_channel.__doc__ = AVERAGING_DOCS.substitute( - array_type=":class:`numpy.ndarray`") + array_type=":class:`numpy.ndarray`" + ) except AttributeError: pass diff --git a/africanus/averaging/time_and_channel_mapping.py b/africanus/averaging/time_and_channel_mapping.py index 31b287b86..dbdca4b9d 100644 --- a/africanus/averaging/time_and_channel_mapping.py +++ b/africanus/averaging/time_and_channel_mapping.py @@ -7,12 +7,7 @@ import numba from africanus.averaging.support import unique_time, unique_baselines -from africanus.util.numba import ( - is_numba_type_none, - njit, - jit, - JIT_OPTIONS, - overload) +from africanus.util.numba import is_numba_type_none, njit, jit, JIT_OPTIONS, overload class RowMapperError(Exception): @@ -21,9 +16,11 @@ class RowMapperError(Exception): def is_flagged_factory(have_flag_row): if have_flag_row: + def impl(flag_row, r): return flag_row[r] != 0 else: + def impl(flag_row, r): return False @@ -32,9 +29,11 @@ def impl(flag_row, r): def output_factory(have_flag_row): if have_flag_row: + def impl(rows, flag_row): return np.zeros(rows, dtype=flag_row.dtype) else: + def impl(rows, flag_row): return None @@ -43,27 +42,29 @@ def impl(rows, flag_row): def set_flag_row_factory(have_flag_row): if have_flag_row: + def impl(flag_row, in_row, out_flag_row, out_row, flagged): if flag_row[in_row] == 0 and flagged: - raise RowMapperError("Unflagged input row contributing " - "to flagged output row. " - "This should never happen!") + raise RowMapperError( + "Unflagged input row contributing " + "to flagged output row. " + "This should never happen!" + ) - out_flag_row[out_row] = (1 if flagged else 0) + out_flag_row[out_row] = 1 if flagged else 0 else: + def impl(flag_row, in_row, out_flag_row, out_row, flagged): pass return njit(nogil=True, cache=True)(impl) -RowMapOutput = namedtuple("RowMapOutput", - ["map", "time", "interval", "flag_row"]) +RowMapOutput = namedtuple("RowMapOutput", ["map", "time", "interval", "flag_row"]) @njit(**JIT_OPTIONS) -def row_mapper(time, interval, antenna1, antenna2, - flag_row=None, time_bin_secs=1): +def row_mapper(time, interval, antenna1, antenna2, flag_row=None, time_bin_secs=1): """ Generates a mapping from a high resolution row index to a low resolution row index in support of time and channel @@ -183,26 +184,29 @@ def row_mapper(time, interval, antenna1, antenna2, Raised if an illegal condition occurs """ - return row_mapper_impl(time, interval, antenna1, antenna2, - flag_row=flag_row, time_bin_secs=time_bin_secs) + return row_mapper_impl( + time, + interval, + antenna1, + antenna2, + flag_row=flag_row, + time_bin_secs=time_bin_secs, + ) -def row_mapper_impl(time, interval, antenna1, antenna2, - flag_row=None, time_bin_secs=1): +def row_mapper_impl(time, interval, antenna1, antenna2, flag_row=None, time_bin_secs=1): return NotImplementedError @overload(row_mapper_impl, jit_options=JIT_OPTIONS) -def nb_row_mapper(time, interval, antenna1, antenna2, - flag_row=None, time_bin_secs=1): +def nb_row_mapper(time, interval, antenna1, antenna2, flag_row=None, time_bin_secs=1): have_flag_row = not is_numba_type_none(flag_row) is_flagged_fn = is_flagged_factory(have_flag_row) output_flag_row = output_factory(have_flag_row) set_flag_row = set_flag_row_factory(have_flag_row) - def impl(time, interval, antenna1, antenna2, - flag_row=None, time_bin_secs=1): + def impl(time, interval, antenna1, antenna2, flag_row=None, time_bin_secs=1): ubl, _, bl_inv, _ = unique_baselines(antenna1, antenna2) utime, _, time_inv, _ = unique_time(time) @@ -212,10 +216,10 @@ def impl(time, interval, antenna1, antenna2, sentinel = np.finfo(time.dtype).max out_rows = numba.uint32(0) - scratch = np.full(3*nbl*ntime, -1, dtype=np.int32) - row_lookup = scratch[:nbl*ntime].reshape(nbl, ntime) - bin_lookup = scratch[nbl*ntime:2*nbl*ntime].reshape(nbl, ntime) - inv_argsort = scratch[2*nbl*ntime:] + scratch = np.full(3 * nbl * ntime, -1, dtype=np.int32) + row_lookup = scratch[: nbl * ntime].reshape(nbl, ntime) + bin_lookup = scratch[nbl * ntime : 2 * nbl * ntime].reshape(nbl, ntime) + inv_argsort = scratch[2 * nbl * ntime :] time_lookup = np.zeros((nbl, ntime), dtype=time.dtype) interval_lookup = np.zeros((nbl, ntime), dtype=interval.dtype) @@ -231,12 +235,14 @@ def impl(time, interval, antenna1, antenna2, if row_lookup[bl, t] == -1: row_lookup[bl, t] = r else: - raise ValueError("Duplicate (TIME, ANTENNA1, ANTENNA2) " - "combinations were discovered in the input " - "data. This is usually caused by not " - "partitioning your data sufficiently " - "by indexing columns, DATA_DESC_ID " - "and SCAN_NUMBER in particular.") + raise ValueError( + "Duplicate (TIME, ANTENNA1, ANTENNA2) " + "combinations were discovered in the input " + "data. This is usually caused by not " + "partitioning your data sufficiently " + "by indexing columns, DATA_DESC_ID " + "and SCAN_NUMBER in particular." + ) # Average times over each baseline and construct the # bin_lookup and time_lookup arrays @@ -312,7 +318,7 @@ def impl(time, interval, antenna1, antenna2, # Flatten the time lookup and argsort it flat_time = time_lookup.ravel() flat_int = interval_lookup.ravel() - argsort = np.argsort(flat_time, kind='mergesort') + argsort = np.argsort(flat_time, kind="mergesort") # Generate lookup from flattened (bl, time) to output row for i, a in enumerate(argsort): @@ -333,15 +339,13 @@ def impl(time, interval, antenna1, antenna2, # lookup time bin and output row tbin = bin_lookup[bl, t] # lookup output row in inv_argsort - out_row = inv_argsort[bl*ntime + tbin] + out_row = inv_argsort[bl * ntime + tbin] if out_row >= out_rows: raise RowMapperError("out_row >= out_rows") # Handle output row flagging - set_flag_row(flag_row, in_row, - out_flag_row, out_row, - bin_flagged[bl, tbin]) + set_flag_row(flag_row, in_row, out_flag_row, out_row, bin_flagged[bl, tbin]) row_map[in_row] = out_row diff --git a/africanus/calibration/phase_only/dask.py b/africanus/calibration/phase_only/dask.py index c1281be7c..ff89ffc45 100644 --- a/africanus/calibration/phase_only/dask.py +++ b/africanus/calibration/phase_only/dask.py @@ -7,6 +7,7 @@ from africanus.calibration.phase_only import compute_jhr as np_compute_jhr from africanus.util.requirements import requires_optional from africanus.calibration.utils.utils import DIAG_DIAG + try: from dask.array.core import blockwise except ImportError as e: @@ -15,61 +16,84 @@ dask_import_error = None -@requires_optional('dask.array', dask_import_error) -def compute_jhj(time_bin_indices, time_bin_counts, antenna1, - antenna2, jones, model, flag): - - mode = check_type(jones, model, vis_type='model') +@requires_optional("dask.array", dask_import_error) +def compute_jhj( + time_bin_indices, time_bin_counts, antenna1, antenna2, jones, model, flag +): + mode = check_type(jones, model, vis_type="model") if mode != DIAG_DIAG: raise NotImplementedError("Only DIAG-DIAG case has been implemented") - jones_shape = ('row', 'ant', 'chan', 'dir', 'corr') - vis_shape = ('row', 'chan', 'corr') - model_shape = ('row', 'chan', 'dir', 'corr') - return blockwise(np_compute_jhj, jones_shape, - time_bin_indices, ('row',), - time_bin_counts, ('row',), - antenna1, ('row',), - antenna2, ('row',), - jones, jones_shape, - model, model_shape, - flag, vis_shape, - adjust_chunks={"row": antenna1.chunks[0]}, - new_axes={"corr2": 2}, # why? - dtype=model.dtype, - align_arrays=False) - + jones_shape = ("row", "ant", "chan", "dir", "corr") + vis_shape = ("row", "chan", "corr") + model_shape = ("row", "chan", "dir", "corr") + return blockwise( + np_compute_jhj, + jones_shape, + time_bin_indices, + ("row",), + time_bin_counts, + ("row",), + antenna1, + ("row",), + antenna2, + ("row",), + jones, + jones_shape, + model, + model_shape, + flag, + vis_shape, + adjust_chunks={"row": antenna1.chunks[0]}, + new_axes={"corr2": 2}, # why? + dtype=model.dtype, + align_arrays=False, + ) -@requires_optional('dask.array', dask_import_error) -def compute_jhr(time_bin_indices, time_bin_counts, antenna1, - antenna2, jones, residual, model, flag): +@requires_optional("dask.array", dask_import_error) +def compute_jhr( + time_bin_indices, time_bin_counts, antenna1, antenna2, jones, residual, model, flag +): mode = check_type(jones, residual) if mode != DIAG_DIAG: raise NotImplementedError("Only DIAG-DIAG case has been implemented") - jones_shape = ('row', 'ant', 'chan', 'dir', 'corr') - vis_shape = ('row', 'chan', 'corr') - model_shape = ('row', 'chan', 'dir', 'corr') - return blockwise(np_compute_jhr, jones_shape, - time_bin_indices, ('row',), - time_bin_counts, ('row',), - antenna1, ('row',), - antenna2, ('row',), - jones, jones_shape, - residual, vis_shape, - model, model_shape, - flag, vis_shape, - adjust_chunks={"row": antenna1.chunks[0]}, - new_axes={"corr2": 2}, # why? - dtype=model.dtype, - align_arrays=False) + jones_shape = ("row", "ant", "chan", "dir", "corr") + vis_shape = ("row", "chan", "corr") + model_shape = ("row", "chan", "dir", "corr") + return blockwise( + np_compute_jhr, + jones_shape, + time_bin_indices, + ("row",), + time_bin_counts, + ("row",), + antenna1, + ("row",), + antenna2, + ("row",), + jones, + jones_shape, + residual, + vis_shape, + model, + model_shape, + flag, + vis_shape, + adjust_chunks={"row": antenna1.chunks[0]}, + new_axes={"corr2": 2}, # why? + dtype=model.dtype, + align_arrays=False, + ) compute_jhj.__doc__ = COMPUTE_JHJ_DOCS.substitute( - array_type=":class:`dask.array.Array`") + array_type=":class:`dask.array.Array`" +) compute_jhr.__doc__ = COMPUTE_JHR_DOCS.substitute( - array_type=":class:`dask.array.Array`") + array_type=":class:`dask.array.Array`" +) diff --git a/africanus/calibration/phase_only/phase_only.py b/africanus/calibration/phase_only/phase_only.py index f0c6d329a..0d2b549c8 100755 --- a/africanus/calibration/phase_only/phase_only.py +++ b/africanus/calibration/phase_only/phase_only.py @@ -9,44 +9,65 @@ def jacobian_factory(mode): if mode == DIAG_DIAG: + def jacobian(a1j, blj, a2j, sign, out): out[...] = sign * a1j * blj * a2j.conjugate() # for c in range(out.shape[-1]): # out[c] = sign * a1j[c] * blj[c] * a2j[c].conjugate() elif mode == DIAG: + def jacobian(a1j, blj, a2j, sign, out): out[...] = 0 elif mode == FULL: + def jacobian(a1j, blj, a2j, sign, out): out[...] = 0 - return njit(nogil=True, inline='always')(jacobian) - -@njit(**JIT_OPTIONS) -def compute_jhj_and_jhr(time_bin_indices, time_bin_counts, antenna1, - antenna2, jones, residual, model, flag): - return compute_jhj_and_jhr_impl(time_bin_indices, time_bin_counts, - antenna1, antenna2, jones, residual, - model, flag) + return njit(nogil=True, inline="always")(jacobian) -def compute_jhj_and_jhr_impl(time_bin_indices, time_bin_counts, antenna1, - antenna2, jones, residual, model, flag): +@njit(**JIT_OPTIONS) +def compute_jhj_and_jhr( + time_bin_indices, time_bin_counts, antenna1, antenna2, jones, residual, model, flag +): + return compute_jhj_and_jhr_impl( + time_bin_indices, + time_bin_counts, + antenna1, + antenna2, + jones, + residual, + model, + flag, + ) + + +def compute_jhj_and_jhr_impl( + time_bin_indices, time_bin_counts, antenna1, antenna2, jones, residual, model, flag +): return NotImplementedError @overload(compute_jhj_and_jhr_impl, jit_options=JIT_OPTIONS) -def nb_compute_jhj_and_jhr(time_bin_indices, time_bin_counts, antenna1, - antenna2, jones, residual, model, flag): - +def nb_compute_jhj_and_jhr( + time_bin_indices, time_bin_counts, antenna1, antenna2, jones, residual, model, flag +): mode = check_type(jones, residual) if mode != DIAG_DIAG: raise NotImplementedError("Only DIAG-DIAG case has been implemented") jacobian = jacobian_factory(mode) - def _jhj_and_jhr_fn(time_bin_indices, time_bin_counts, antenna1, - antenna2, jones, residual, model, flag): + def _jhj_and_jhr_fn( + time_bin_indices, + time_bin_counts, + antenna1, + antenna2, + jones, + residual, + model, + flag, + ): # for chunked dask arrays we need to adjust the chunks to # start counting from zero (see also map_blocks) time_bin_indices -= time_bin_indices.min() @@ -61,8 +82,9 @@ def _jhj_and_jhr_fn(time_bin_indices, time_bin_counts, antenna1, # tmp array the shape of jones_corr jac = np.zeros_like(jones[0, 0, 0, 0], dtype=jones.dtype) for t in range(n_tim): - for row in range(time_bin_indices[t], - time_bin_indices[t] + time_bin_counts[t]): + for row in range( + time_bin_indices[t], time_bin_indices[t] + time_bin_counts[t] + ): p = antenna1[row] q = antenna2[row] for nu in range(n_chan): @@ -74,37 +96,42 @@ def _jhj_and_jhr_fn(time_bin_indices, time_bin_counts, antenna1, # for the derivative w.r.t. antenna p jacobian(gp[s], model[row, nu, s], gq[s], 1.0j, jac) jhj[t, p, nu, s] += (np.conj(jac) * jac).real - jhr[t, p, nu, s] += (np.conj(jac) * residual[row, nu]) + jhr[t, p, nu, s] += np.conj(jac) * residual[row, nu] # for the derivative w.r.t. antenna q jacobian(gp[s], model[row, nu, s], gq[s], -1.0j, jac) jhj[t, q, nu, s] += (np.conj(jac) * jac).real - jhr[t, q, nu, s] += (np.conj(jac) * residual[row, nu]) + jhr[t, q, nu, s] += np.conj(jac) * residual[row, nu] return jhj, jhr + return _jhj_and_jhr_fn @njit(**JIT_OPTIONS) -def compute_jhj(time_bin_indices, time_bin_counts, antenna1, - antenna2, jones, model, flag): - return compute_jhj_impl(time_bin_indices, time_bin_counts, - antenna1, antenna2, jones, model, flag) +def compute_jhj( + time_bin_indices, time_bin_counts, antenna1, antenna2, jones, model, flag +): + return compute_jhj_impl( + time_bin_indices, time_bin_counts, antenna1, antenna2, jones, model, flag + ) -def compute_jhj_impl(time_bin_indices, time_bin_counts, antenna1, - antenna2, jones, model, flag): +def compute_jhj_impl( + time_bin_indices, time_bin_counts, antenna1, antenna2, jones, model, flag +): return NotImplementedError @overload(compute_jhj_impl, jit_options=JIT_OPTIONS) -def nb_compute_jhj(time_bin_indices, time_bin_counts, antenna1, - antenna2, jones, model, flag): - - mode = check_type(jones, model, vis_type='model') +def nb_compute_jhj( + time_bin_indices, time_bin_counts, antenna1, antenna2, jones, model, flag +): + mode = check_type(jones, model, vis_type="model") jacobian = jacobian_factory(mode) - def _compute_jhj_fn(time_bin_indices, time_bin_counts, antenna1, - antenna2, jones, model, flag): + def _compute_jhj_fn( + time_bin_indices, time_bin_counts, antenna1, antenna2, jones, model, flag + ): # for dask arrays we need to adjust the chunks to # start counting from zero time_bin_indices -= time_bin_indices.min() @@ -117,8 +144,9 @@ def _compute_jhj_fn(time_bin_indices, time_bin_counts, antenna1, # tmp array the shape of jones_corr jac = np.zeros_like(jones[0, 0, 0, 0], dtype=jones.dtype) for t in range(n_tim): - for row in range(time_bin_indices[t], - time_bin_indices[t] + time_bin_counts[t]): + for row in range( + time_bin_indices[t], time_bin_indices[t] + time_bin_counts[t] + ): p = antenna1[row] q = antenna2[row] for nu in range(n_chan): @@ -132,32 +160,50 @@ def _compute_jhj_fn(time_bin_indices, time_bin_counts, antenna1, jacobian(gp[s], model[row, nu, s], gq[s], -1.0j, jac) jhj[t, q, nu, s] += (jac.conjugate() * jac).real return jhj + return _compute_jhj_fn @njit(**JIT_OPTIONS) -def compute_jhr(time_bin_indices, time_bin_counts, antenna1, - antenna2, jones, residual, model, flag): - return compute_jhr_impl(time_bin_indices, time_bin_counts, - antenna1, antenna2, jones, residual, - model, flag) - - -def compute_jhr_impl(time_bin_indices, time_bin_counts, antenna1, - antenna2, jones, residual, model, flag): +def compute_jhr( + time_bin_indices, time_bin_counts, antenna1, antenna2, jones, residual, model, flag +): + return compute_jhr_impl( + time_bin_indices, + time_bin_counts, + antenna1, + antenna2, + jones, + residual, + model, + flag, + ) + + +def compute_jhr_impl( + time_bin_indices, time_bin_counts, antenna1, antenna2, jones, residual, model, flag +): return NotImplementedError @overload(compute_jhr_impl, jit_options=JIT_OPTIONS) -def nb_compute_jhr(time_bin_indices, time_bin_counts, antenna1, - antenna2, jones, residual, model, flag): - - mode = check_type(jones, model, vis_type='model') +def nb_compute_jhr( + time_bin_indices, time_bin_counts, antenna1, antenna2, jones, residual, model, flag +): + mode = check_type(jones, model, vis_type="model") jacobian = jacobian_factory(mode) - def _compute_jhr_fn(time_bin_indices, time_bin_counts, antenna1, - antenna2, jones, residual, model, flag): + def _compute_jhr_fn( + time_bin_indices, + time_bin_counts, + antenna1, + antenna2, + jones, + residual, + model, + flag, + ): # for dask arrays we need to adjust the chunks to # start counting from zero time_bin_indices -= time_bin_indices.min() @@ -170,8 +216,9 @@ def _compute_jhr_fn(time_bin_indices, time_bin_counts, antenna1, # tmp array the shape of jones_corr jac = np.zeros_like(jones[0, 0, 0, 0], dtype=jones.dtype) for t in range(n_tim): - for row in range(time_bin_indices[t], - time_bin_indices[t] + time_bin_counts[t]): + for row in range( + time_bin_indices[t], time_bin_indices[t] + time_bin_counts[t] + ): p = antenna1[row] q = antenna2[row] for nu in range(n_chan): @@ -185,16 +232,27 @@ def _compute_jhr_fn(time_bin_indices, time_bin_counts, antenna1, jacobian(gp[s], model[row, nu, s], gq[s], -1.0j, jac) jhr[t, q, nu, s] += jac.conjugate() * residual[row, nu] return jhr + return _compute_jhr_fn + # LB - TODO somehow this generated_jit causes tests to fail # @generated_jit(nopython=True, nogil=True, cache=True, fastmath=True) -def gauss_newton(time_bin_indices, time_bin_counts, antenna1, - antenna2, jones, vis, flag, model, - weight, tol=1e-4, maxiter=100): - +def gauss_newton( + time_bin_indices, + time_bin_counts, + antenna1, + antenna2, + jones, + vis, + flag, + model, + weight, + tol=1e-4, + maxiter=100, +): # whiten data sqrtweights = np.sqrt(weight) vis *= sqrtweights @@ -204,8 +262,9 @@ def gauss_newton(time_bin_indices, time_bin_counts, antenna1, # can avoid recomputing JHJ in DIAG_DIAG mode if mode == DIAG_DIAG: - jhj = compute_jhj(time_bin_indices, time_bin_counts, - antenna1, antenna2, jones, model, flag) + jhj = compute_jhj( + time_bin_indices, time_bin_counts, antenna1, antenna2, jones, model, flag + ) else: raise NotImplementedError("Only DIAG_DIAG mode implemented") @@ -216,15 +275,30 @@ def gauss_newton(time_bin_indices, time_bin_counts, antenna1, phases = np.angle(jones) # get residual TODO - we can avoid this in DIE case - residual = residual_vis(time_bin_indices, time_bin_counts, antenna1, - antenna2, jones, vis, flag, model) - - jhr = compute_jhr(time_bin_indices, time_bin_counts, - antenna1, antenna2, - jones, residual, model, flag) + residual = residual_vis( + time_bin_indices, + time_bin_counts, + antenna1, + antenna2, + jones, + vis, + flag, + model, + ) + + jhr = compute_jhr( + time_bin_indices, + time_bin_counts, + antenna1, + antenna2, + jones, + residual, + model, + flag, + ) # implement update - phases_new = phases + 0.5 * (jhr/jhj).real + phases_new = phases + 0.5 * (jhr / jhj).real jones = np.exp(1.0j * phases_new) # check convergence/iteration control @@ -234,7 +308,8 @@ def gauss_newton(time_bin_indices, time_bin_counts, antenna1, return jones, jhj, jhr, k -GAUSS_NEWTON_DOCS = DocstringTemplate(""" +GAUSS_NEWTON_DOCS = DocstringTemplate( + """ Performs phase-only maximum likelihood calibration using a Gauss-Newton optimisation algorithm. Currently only DIAG mode is supported. @@ -288,16 +363,19 @@ def gauss_newton(time_bin_indices, time_bin_counts, antenna1, k: int Number of iterations (will equal maxiter if not converged) -""") +""" +) try: gauss_newton.__doc__ = GAUSS_NEWTON_DOCS.substitute( - array_type=":class:`numpy.ndarray`") + array_type=":class:`numpy.ndarray`" + ) except AttributeError: pass -JHJ_AND_JHR_DOCS = DocstringTemplate(""" +JHJ_AND_JHR_DOCS = DocstringTemplate( + """ Computes the diagonal of the Hessian and the residual locally projected in to gain space. @@ -336,16 +414,19 @@ def gauss_newton(time_bin_indices, time_bin_counts, antenna1, Residuals projected into signal space of shape :code:`(time, ant, chan, dir, corr)` or :code:`(time, ant, chan, dir, corr, corr)`. -""") +""" +) try: compute_jhj_and_jhr.__doc__ = JHJ_AND_JHR_DOCS.substitute( - array_type=":class:`numpy.ndarray`") + array_type=":class:`numpy.ndarray`" + ) except AttributeError: pass -COMPUTE_JHJ_DOCS = DocstringTemplate(""" +COMPUTE_JHJ_DOCS = DocstringTemplate( + """ Computes the diagonal of the Hessian required to perform phase-only maximum likelihood calibration. Currently assumes @@ -379,15 +460,18 @@ def gauss_newton(time_bin_indices, time_bin_counts, antenna1, The diagonal of the Hessian of shape :code:`(time, ant, chan, dir, corr)` or :code:`(time, ant, chan, dir, corr, corr)`. -""") +""" +) try: compute_jhj.__doc__ = COMPUTE_JHJ_DOCS.substitute( - array_type=":class:`numpy.ndarray`") + array_type=":class:`numpy.ndarray`" + ) except AttributeError: pass -COMPUTE_JHR_DOCS = DocstringTemplate(""" +COMPUTE_JHR_DOCS = DocstringTemplate( + """ Computes the residual projected in to gain space. Parameters @@ -421,10 +505,12 @@ def gauss_newton(time_bin_indices, time_bin_counts, antenna1, The residual projected into gain space shape :code:`(time, ant, chan, dir, corr)` or :code:`(time, ant, chan, dir, corr, corr)`. -""") +""" +) try: compute_jhr.__doc__ = COMPUTE_JHR_DOCS.substitute( - array_type=":class:`numpy.ndarray`") + array_type=":class:`numpy.ndarray`" + ) except AttributeError: pass diff --git a/africanus/calibration/phase_only/tests/conftest.py b/africanus/calibration/phase_only/tests/conftest.py index d068208be..bc2f15441 100644 --- a/africanus/calibration/phase_only/tests/conftest.py +++ b/africanus/calibration/phase_only/tests/conftest.py @@ -8,29 +8,38 @@ def lm_factory(n_dir, rs): - ls = 0.1*rs.randn(n_dir) - ms = 0.1*rs.randn(n_dir) + ls = 0.1 * rs.randn(n_dir) + ms = 0.1 * rs.randn(n_dir) lm = np.vstack((ls, ms)).T return lm def flux_factory(n_dir, n_chan, corr_shape, alpha, freq, freq0, rs): - w = freq/freq0 + w = freq / freq0 flux = np.zeros((n_dir, n_chan) + corr_shape, dtype=np.float64) for d in range(n_dir): tmp_flux = np.abs(rs.normal(size=corr_shape)) for v in range(n_chan): - flux[d, v] = tmp_flux * w[v]**alpha + flux[d, v] = tmp_flux * w[v] ** alpha return flux @pytest.fixture def data_factory(): - def impl(sigma_n, sigma_f, n_time, n_chan, n_ant, - n_dir, corr_shape, jones_shape, phase_only_gains=False): + def impl( + sigma_n, + sigma_f, + n_time, + n_chan, + n_ant, + n_dir, + corr_shape, + jones_shape, + phase_only_gains=False, + ): rs = np.random.RandomState(42) - n_bl = n_ant*(n_ant-1)//2 - n_row = n_bl*n_time + n_bl = n_ant * (n_ant - 1) // 2 + n_row = n_bl * n_time # make aux data antenna1 = np.zeros(n_row, dtype=np.int16) antenna2 = np.zeros(n_row, dtype=np.int16) @@ -42,19 +51,18 @@ def impl(sigma_n, sigma_f, n_time, n_chan, n_ant, row = 0 for p in range(n_ant): for q in range(p): - time[i*n_bl + row] = time_values[i] - antenna1[i*n_bl + row] = p - antenna2[i*n_bl + row] = q - uvw[i*n_bl + row] = np.random.randn(3) + time[i * n_bl + row] = time_values[i] + antenna1[i * n_bl + row] = p + antenna2[i * n_bl + row] = q + uvw[i * n_bl + row] = np.random.randn(3) row += 1 assert time.size == n_row # simulate visibilities - model_data = np.zeros((n_row, n_chan, n_dir) + - corr_shape, dtype=np.complex128) + model_data = np.zeros((n_row, n_chan, n_dir) + corr_shape, dtype=np.complex128) # make up some sources lm = lm_factory(n_dir, rs) alpha = -0.7 - freq0 = freq[n_chan//2] + freq0 = freq[n_chan // 2] flux = flux_factory(n_dir, n_chan, corr_shape, alpha, freq, freq0, rs) # simulate model data for dir in range(n_dir): @@ -65,28 +73,31 @@ def impl(sigma_n, sigma_f, n_time, n_chan, n_ant, model_data[:, :, dir] = tmp.reshape((n_row, n_chan) + corr_shape) assert not np.isnan(model_data).any() # simulate gains (just randomly scattered around 1 for now) - jones = np.ones((n_time, n_ant, n_chan, n_dir) + - jones_shape, dtype=np.complex128) + jones = np.ones( + (n_time, n_ant, n_chan, n_dir) + jones_shape, dtype=np.complex128 + ) if sigma_f: if phase_only_gains: - jones = np.exp(1.0j*rs.normal(loc=0.0, scale=sigma_f, - size=jones.shape)) + jones = np.exp( + 1.0j * rs.normal(loc=0.0, scale=sigma_f, size=jones.shape) + ) else: - jones += (rs.normal(loc=0.0, scale=sigma_f, - size=jones.shape) + - 1.0j*rs.normal(loc=0.0, scale=sigma_f, - size=jones.shape)) + jones += rs.normal( + loc=0.0, scale=sigma_f, size=jones.shape + ) + 1.0j * rs.normal(loc=0.0, scale=sigma_f, size=jones.shape) assert (np.abs(jones) > 1e-5).all() assert not np.isnan(jones).any() # get vis _, time_bin_indices, time_bin_counts = chunkify_rows(time, n_time) - vis = corrupt_vis(time_bin_indices, time_bin_counts, - antenna1, antenna2, jones, model_data) + vis = corrupt_vis( + time_bin_indices, time_bin_counts, antenna1, antenna2, jones, model_data + ) assert not np.isnan(vis).any() # add noise if sigma_n: - vis += (rs.normal(loc=0.0, scale=sigma_n, size=vis.shape) + - 1.0j*rs.normal(loc=0.0, scale=sigma_n, size=vis.shape)) + vis += rs.normal(loc=0.0, scale=sigma_n, size=vis.shape) + 1.0j * rs.normal( + loc=0.0, scale=sigma_n, size=vis.shape + ) weights = np.ones(vis.shape, dtype=np.float64) if sigma_n: weights /= sigma_n**2 @@ -99,6 +110,7 @@ def impl(sigma_n, sigma_f, n_time, n_chan, n_ant, data_dict["ANTENNA1"] = antenna1 data_dict["ANTENNA2"] = antenna2 data_dict["FLAG"] = flag - data_dict['JONES'] = jones + data_dict["JONES"] = jones return data_dict + return impl diff --git a/africanus/calibration/phase_only/tests/test_phase_only.py b/africanus/calibration/phase_only/tests/test_phase_only.py index 6cb23a729..1e0b4d4aa 100644 --- a/africanus/calibration/phase_only/tests/test_phase_only.py +++ b/africanus/calibration/phase_only/tests/test_phase_only.py @@ -20,26 +20,31 @@ def test_compute_jhj_and_jhr(data_factory): sigma_f = 0.05 corr_shape = (2,) jones_shape = (2,) - data_dict = data_factory(sigma_n, sigma_f, n_time, n_chan, - n_ant, n_dir, corr_shape, jones_shape) - time = data_dict['TIME'] + data_dict = data_factory( + sigma_n, sigma_f, n_time, n_chan, n_ant, n_dir, corr_shape, jones_shape + ) + time = data_dict["TIME"] _, time_bin_indices, time_bin_counts = chunkify_rows(time, n_time) - ant1 = data_dict['ANTENNA1'] - ant2 = data_dict['ANTENNA2'] - vis = data_dict['DATA'] - model = data_dict['MODEL_DATA'] - jones = data_dict['JONES'] - flag = data_dict['FLAG'] + ant1 = data_dict["ANTENNA1"] + ant2 = data_dict["ANTENNA2"] + vis = data_dict["DATA"] + model = data_dict["MODEL_DATA"] + jones = data_dict["JONES"] + flag = data_dict["FLAG"] from africanus.calibration.phase_only.phase_only import compute_jhj from africanus.calibration.phase_only.phase_only import compute_jhr from africanus.calibration.phase_only.phase_only import compute_jhj_and_jhr - jhj1, jhr1 = compute_jhj_and_jhr(time_bin_indices, time_bin_counts, - ant1, ant2, jones, vis, model, flag) - jhj2 = compute_jhj(time_bin_indices, time_bin_counts, - ant1, ant2, jones, model, flag) - jhr2 = compute_jhr(time_bin_indices, time_bin_counts, - ant1, ant2, jones, vis, model, flag) + + jhj1, jhr1 = compute_jhj_and_jhr( + time_bin_indices, time_bin_counts, ant1, ant2, jones, vis, model, flag + ) + jhj2 = compute_jhj( + time_bin_indices, time_bin_counts, ant1, ant2, jones, model, flag + ) + jhr2 = compute_jhr( + time_bin_indices, time_bin_counts, ant1, ant2, jones, vis, model, flag + ) assert_array_almost_equal(jhj1, jhj2, decimal=10) assert_array_almost_equal(jhr1, jhr2, decimal=10) @@ -55,39 +60,45 @@ def test_compute_jhj_dask(data_factory): sigma_f = 0.05 corr_shape = (2,) jones_shape = (2,) - data_dict = data_factory(sigma_n, sigma_f, n_time, n_chan, - n_ant, n_dir, corr_shape, jones_shape) - time = data_dict['TIME'] + data_dict = data_factory( + sigma_n, sigma_f, n_time, n_chan, n_ant, n_dir, corr_shape, jones_shape + ) + time = data_dict["TIME"] ncpu = 8 - utimes_per_chunk = n_time//ncpu - row_chunks, time_bin_idx, time_bin_counts = chunkify_rows( - time, utimes_per_chunk) - ant1 = data_dict['ANTENNA1'] - ant2 = data_dict['ANTENNA2'] - model = data_dict['MODEL_DATA'] - jones = data_dict['JONES'] - flag = data_dict['FLAG'] + utimes_per_chunk = n_time // ncpu + row_chunks, time_bin_idx, time_bin_counts = chunkify_rows(time, utimes_per_chunk) + ant1 = data_dict["ANTENNA1"] + ant2 = data_dict["ANTENNA2"] + model = data_dict["MODEL_DATA"] + jones = data_dict["JONES"] + flag = data_dict["FLAG"] # get the numpy result - jhj = np_compute_jhj(time_bin_idx, time_bin_counts, ant1, ant2, - jones, model, flag) + jhj = np_compute_jhj(time_bin_idx, time_bin_counts, ant1, ant2, jones, model, flag) - da_time_bin_idx = da.from_array(time_bin_idx, - chunks=(utimes_per_chunk)) - da_time_bin_counts = da.from_array(time_bin_counts, - chunks=(utimes_per_chunk)) + da_time_bin_idx = da.from_array(time_bin_idx, chunks=(utimes_per_chunk)) + da_time_bin_counts = da.from_array(time_bin_counts, chunks=(utimes_per_chunk)) da_ant1 = da.from_array(ant1, chunks=row_chunks) da_ant2 = da.from_array(ant2, chunks=row_chunks) - da_model = da.from_array(model, chunks=( - row_chunks, (n_chan,), (n_dir,)) + (corr_shape)) - da_jones = da.from_array(jones, chunks=( - utimes_per_chunk, n_ant, n_chan, n_dir)+jones_shape) + da_model = da.from_array( + model, chunks=(row_chunks, (n_chan,), (n_dir,)) + (corr_shape) + ) + da_jones = da.from_array( + jones, chunks=(utimes_per_chunk, n_ant, n_chan, n_dir) + jones_shape + ) da_flag = da.from_array(flag, chunks=(row_chunks, (n_chan,)) + corr_shape) from africanus.calibration.phase_only.dask import compute_jhj - da_jhj = compute_jhj(da_time_bin_idx, da_time_bin_counts, - da_ant1, da_ant2, da_jones, da_model, da_flag) + da_jhj = compute_jhj( + da_time_bin_idx, + da_time_bin_counts, + da_ant1, + da_ant2, + da_jones, + da_model, + da_flag, + ) jhj2 = da_jhj.compute() @@ -104,41 +115,50 @@ def test_compute_jhr_dask(data_factory): sigma_f = 0.05 corr_shape = (2,) jones_shape = (2,) - data_dict = data_factory(sigma_n, sigma_f, n_time, n_chan, - n_ant, n_dir, corr_shape, jones_shape) - time = data_dict['TIME'] + data_dict = data_factory( + sigma_n, sigma_f, n_time, n_chan, n_ant, n_dir, corr_shape, jones_shape + ) + time = data_dict["TIME"] ncpu = 8 - utimes_per_chunk = n_time//ncpu - row_chunks, time_bin_idx, time_bin_counts = chunkify_rows( - time, utimes_per_chunk) - ant1 = data_dict['ANTENNA1'] - ant2 = data_dict['ANTENNA2'] - model = data_dict['MODEL_DATA'] - jones = data_dict['JONES'] - vis = data_dict['DATA'] - flag = data_dict['FLAG'] + utimes_per_chunk = n_time // ncpu + row_chunks, time_bin_idx, time_bin_counts = chunkify_rows(time, utimes_per_chunk) + ant1 = data_dict["ANTENNA1"] + ant2 = data_dict["ANTENNA2"] + model = data_dict["MODEL_DATA"] + jones = data_dict["JONES"] + vis = data_dict["DATA"] + flag = data_dict["FLAG"] # get the numpy result - jhr = np_compute_jhr(time_bin_idx, time_bin_counts, ant1, ant2, - jones, vis, model, flag) + jhr = np_compute_jhr( + time_bin_idx, time_bin_counts, ant1, ant2, jones, vis, model, flag + ) - da_time_bin_idx = da.from_array(time_bin_idx, - chunks=(utimes_per_chunk)) - da_time_bin_counts = da.from_array(time_bin_counts, - chunks=(utimes_per_chunk)) + da_time_bin_idx = da.from_array(time_bin_idx, chunks=(utimes_per_chunk)) + da_time_bin_counts = da.from_array(time_bin_counts, chunks=(utimes_per_chunk)) da_ant1 = da.from_array(ant1, chunks=row_chunks) da_ant2 = da.from_array(ant2, chunks=row_chunks) - da_model = da.from_array(model, chunks=( - row_chunks, (n_chan,), (n_dir,)) + (corr_shape)) - da_jones = da.from_array(jones, chunks=( - utimes_per_chunk, n_ant, n_chan, n_dir)+jones_shape) + da_model = da.from_array( + model, chunks=(row_chunks, (n_chan,), (n_dir,)) + (corr_shape) + ) + da_jones = da.from_array( + jones, chunks=(utimes_per_chunk, n_ant, n_chan, n_dir) + jones_shape + ) da_flag = da.from_array(flag, chunks=(row_chunks, (n_chan,)) + corr_shape) da_vis = da.from_array(vis, chunks=(row_chunks, (n_chan,)) + corr_shape) from africanus.calibration.phase_only.dask import compute_jhr - da_jhr = compute_jhr(da_time_bin_idx, da_time_bin_counts, - da_ant1, da_ant2, da_jones, da_vis, da_model, da_flag) + da_jhr = compute_jhr( + da_time_bin_idx, + da_time_bin_counts, + da_ant1, + da_ant2, + da_jones, + da_vis, + da_model, + da_flag, + ) jhr2 = da_jhr.compute() @@ -161,31 +181,47 @@ def test_phase_only_diag_diag(data_factory): sigma_f = 0.1 corr_shape = (2,) jones_shape = (2,) - data_dict = data_factory(sigma_n, sigma_f, n_time, n_chan, - n_ant, n_dir, corr_shape, jones_shape, - phase_only_gains=True) - time = data_dict['TIME'] + data_dict = data_factory( + sigma_n, + sigma_f, + n_time, + n_chan, + n_ant, + n_dir, + corr_shape, + jones_shape, + phase_only_gains=True, + ) + time = data_dict["TIME"] _, time_bin_indices, time_bin_counts = chunkify_rows(time, n_time) - ant1 = data_dict['ANTENNA1'] - ant2 = data_dict['ANTENNA2'] - vis = data_dict['DATA'] - model = data_dict['MODEL_DATA'] - jones = data_dict['JONES'] - flag = data_dict['FLAG'] - weight = data_dict['WEIGHT_SPECTRUM'] + ant1 = data_dict["ANTENNA1"] + ant2 = data_dict["ANTENNA2"] + vis = data_dict["DATA"] + model = data_dict["MODEL_DATA"] + jones = data_dict["JONES"] + flag = data_dict["FLAG"] + weight = data_dict["WEIGHT_SPECTRUM"] # calibrate the data - jones0 = np.ones((n_time, n_ant, n_chan, n_dir) + jones_shape, - dtype=np.complex128) + jones0 = np.ones((n_time, n_ant, n_chan, n_dir) + jones_shape, dtype=np.complex128) precision = 5 gains, jhj, jhr, k = gauss_newton( - time_bin_indices, time_bin_counts, - ant1, ant2, jones0, vis, - flag, model, weight, - tol=10**(-precision), maxiter=250) + time_bin_indices, + time_bin_counts, + ant1, + ant2, + jones0, + vis, + flag, + model, + weight, + tol=10 ** (-precision), + maxiter=250, + ) # check that phase differences are correct for p in range(n_ant): for q in range(p): phase_diff_true = np.angle(jones[:, p]) - np.angle(jones[:, q]) phase_diff = np.angle(gains[:, p]) - np.angle(gains[:, q]) assert_array_almost_equal( - phase_diff_true, phase_diff, decimal=precision-3) + phase_diff_true, phase_diff, decimal=precision - 3 + ) diff --git a/africanus/calibration/tests/conftest.py b/africanus/calibration/tests/conftest.py index 3bf67a115..8d10bfe19 100644 --- a/africanus/calibration/tests/conftest.py +++ b/africanus/calibration/tests/conftest.py @@ -12,29 +12,38 @@ def lm_factory(n_dir, rs): - ls = 0.1*rs.randn(n_dir) - ms = 0.1*rs.randn(n_dir) + ls = 0.1 * rs.randn(n_dir) + ms = 0.1 * rs.randn(n_dir) lm = np.vstack((ls, ms)).T return lm def flux_factory(n_dir, n_chan, corr_shape, alpha, freq, freq0, rs): - w = freq/freq0 + w = freq / freq0 flux = np.zeros((n_dir, n_chan) + corr_shape, dtype=np.float64) for d in range(n_dir): tmp_flux = np.abs(rs.normal(size=corr_shape)) for v in range(n_chan): - flux[d, v] = tmp_flux * w[v]**alpha + flux[d, v] = tmp_flux * w[v] ** alpha return flux @pytest.fixture def data_factory(): - def impl(sigma_n, sigma_f, n_time, n_chan, n_ant, - n_dir, corr_shape, jones_shape, phase_only_gains=False): + def impl( + sigma_n, + sigma_f, + n_time, + n_chan, + n_ant, + n_dir, + corr_shape, + jones_shape, + phase_only_gains=False, + ): rs = np.random.RandomState(42) - n_bl = n_ant*(n_ant-1)//2 - n_row = n_bl*n_time + n_bl = n_ant * (n_ant - 1) // 2 + n_row = n_bl * n_time # make aux data antenna1 = np.zeros(n_row, dtype=np.int16) antenna2 = np.zeros(n_row, dtype=np.int16) @@ -46,19 +55,18 @@ def impl(sigma_n, sigma_f, n_time, n_chan, n_ant, row = 0 for p in range(n_ant): for q in range(p): - time[i*n_bl + row] = time_values[i] - antenna1[i*n_bl + row] = p - antenna2[i*n_bl + row] = q - uvw[i*n_bl + row] = np.random.randn(3) + time[i * n_bl + row] = time_values[i] + antenna1[i * n_bl + row] = p + antenna2[i * n_bl + row] = q + uvw[i * n_bl + row] = np.random.randn(3) row += 1 assert time.size == n_row # simulate visibilities - model_data = np.zeros((n_row, n_chan, n_dir) + - corr_shape, dtype=np.complex128) + model_data = np.zeros((n_row, n_chan, n_dir) + corr_shape, dtype=np.complex128) # make up some sources lm = lm_factory(n_dir, rs) alpha = -0.7 - freq0 = freq[n_chan//2] + freq0 = freq[n_chan // 2] flux = flux_factory(n_dir, n_chan, corr_shape, alpha, freq, freq0, rs) # simulate model data for dir in range(n_dir): @@ -69,28 +77,31 @@ def impl(sigma_n, sigma_f, n_time, n_chan, n_ant, model_data[:, :, dir] = tmp.reshape((n_row, n_chan) + corr_shape) assert not np.isnan(model_data).any() # simulate gains (just randomly scattered around 1 for now) - jones = np.ones((n_time, n_ant, n_chan, n_dir) + - jones_shape, dtype=np.complex128) + jones = np.ones( + (n_time, n_ant, n_chan, n_dir) + jones_shape, dtype=np.complex128 + ) if sigma_f: if phase_only_gains: - jones = np.exp(1.0j*rs.normal(loc=0.0, scale=sigma_f, - size=jones.shape)) + jones = np.exp( + 1.0j * rs.normal(loc=0.0, scale=sigma_f, size=jones.shape) + ) else: - jones += (rs.normal(loc=0.0, scale=sigma_f, - size=jones.shape) + - 1.0j*rs.normal(loc=0.0, scale=sigma_f, - size=jones.shape)) + jones += rs.normal( + loc=0.0, scale=sigma_f, size=jones.shape + ) + 1.0j * rs.normal(loc=0.0, scale=sigma_f, size=jones.shape) assert (np.abs(jones) > 1e-5).all() assert not np.isnan(jones).any() # get vis _, time_bin_indices, _, time_bin_counts = unique_time(time) - vis = corrupt_vis(time_bin_indices, time_bin_counts, - antenna1, antenna2, jones, model_data) + vis = corrupt_vis( + time_bin_indices, time_bin_counts, antenna1, antenna2, jones, model_data + ) assert not np.isnan(vis).any() # add noise if sigma_n: - vis += (rs.normal(loc=0.0, scale=sigma_n, size=vis.shape) + - 1.0j*rs.normal(loc=0.0, scale=sigma_n, size=vis.shape)) + vis += rs.normal(loc=0.0, scale=sigma_n, size=vis.shape) + 1.0j * rs.normal( + loc=0.0, scale=sigma_n, size=vis.shape + ) weights = np.ones(vis.shape, dtype=np.float64) if sigma_n: weights /= sigma_n**2 @@ -103,6 +114,7 @@ def impl(sigma_n, sigma_f, n_time, n_chan, n_ant, data_dict["ANTENNA1"] = antenna1 data_dict["ANTENNA2"] = antenna2 data_dict["FLAG"] = flag - data_dict['JONES'] = jones + data_dict["JONES"] = jones return data_dict + return impl diff --git a/africanus/calibration/utils/compute_and_corrupt_vis.py b/africanus/calibration/utils/compute_and_corrupt_vis.py index 33ad85de0..f4d8ed6c7 100644 --- a/africanus/calibration/utils/compute_and_corrupt_vis.py +++ b/africanus/calibration/utils/compute_and_corrupt_vis.py @@ -10,96 +10,113 @@ def jones_mul_factory(mode): if mode == DIAG_DIAG: + def jones_mul(a1j, model, a2j, uvw, freq, lm, out): n_dir = np.shape(model)[0] u, v, w = uvw for s in range(n_dir): l, m = lm[s] n = np.sqrt(1 - l**2 - m**2) - real_phase = m2pioc * freq * (u*l + v*m + w*(n-1)) - source_vis = model[s] * np.exp(1.0j*real_phase)/n + real_phase = m2pioc * freq * (u * l + v * m + w * (n - 1)) + source_vis = model[s] * np.exp(1.0j * real_phase) / n for c in range(out.shape[-1]): - out[c] += a1j[s, c]*source_vis[c]*np.conj(a2j[s, c]) + out[c] += a1j[s, c] * source_vis[c] * np.conj(a2j[s, c]) elif mode == DIAG: + def jones_mul(a1j, model, a2j, uvw, freq, lm, out): n_dir = np.shape(model)[0] u, v, w = uvw for s in range(n_dir): l, m = lm[s] n = np.sqrt(1 - l**2 - m**2) - real_phase = m2pioc * freq * (u*l + v*m + w*(n-1)) - source_vis = model[s] * np.exp(1.0j*real_phase)/n - out[0, 0] += a1j[s, 0]*source_vis[0, 0] * np.conj(a2j[s, 0]) - out[0, 1] += a1j[s, 0]*source_vis[0, 1] * np.conj(a2j[s, 1]) - out[1, 0] += a1j[s, 1]*source_vis[1, 0] * np.conj(a2j[s, 0]) - out[1, 1] += a1j[s, 1]*source_vis[1, 1] * np.conj(a2j[s, 1]) + real_phase = m2pioc * freq * (u * l + v * m + w * (n - 1)) + source_vis = model[s] * np.exp(1.0j * real_phase) / n + out[0, 0] += a1j[s, 0] * source_vis[0, 0] * np.conj(a2j[s, 0]) + out[0, 1] += a1j[s, 0] * source_vis[0, 1] * np.conj(a2j[s, 1]) + out[1, 0] += a1j[s, 1] * source_vis[1, 0] * np.conj(a2j[s, 0]) + out[1, 1] += a1j[s, 1] * source_vis[1, 1] * np.conj(a2j[s, 1]) elif mode == FULL: + def jones_mul(a1j, model, a2j, uvw, freq, lm, out): n_dir = np.shape(model)[0] u, v, w = uvw for s in range(n_dir): l, m = lm[s] n = np.sqrt(1 - l**2 - m**2) - real_phase = m2pioc * freq * (u*l + v*m + w*(n-1)) - source_vis = model[s] * np.exp(1.0j*real_phase)/n + real_phase = m2pioc * freq * (u * l + v * m + w * (n - 1)) + source_vis = model[s] * np.exp(1.0j * real_phase) / n # precompute resuable terms - t1 = a1j[s, 0, 0]*source_vis[0, 0] - t2 = a1j[s, 0, 1]*source_vis[1, 0] - t3 = a1j[s, 0, 0]*source_vis[0, 1] - t4 = a1j[s, 0, 1]*source_vis[1, 1] + t1 = a1j[s, 0, 0] * source_vis[0, 0] + t2 = a1j[s, 0, 1] * source_vis[1, 0] + t3 = a1j[s, 0, 0] * source_vis[0, 1] + t4 = a1j[s, 0, 1] * source_vis[1, 1] tmp = np.conj(a2j[s].T) # overwrite with result - out[0, 0] += t1*tmp[0, 0] +\ - t2*tmp[0, 0] +\ - t3*tmp[1, 0] +\ - t4*tmp[1, 0] - out[0, 1] += t1*tmp[0, 1] +\ - t2*tmp[0, 1] +\ - t3*tmp[1, 1] +\ - t4*tmp[1, 1] - t1 = a1j[s, 1, 0]*source_vis[0, 0] - t2 = a1j[s, 1, 1]*source_vis[1, 0] - t3 = a1j[s, 1, 0]*source_vis[0, 1] - t4 = a1j[s, 1, 1]*source_vis[1, 1] - out[1, 0] += t1*tmp[0, 0] +\ - t2*tmp[0, 0] +\ - t3*tmp[1, 0] +\ - t4*tmp[1, 0] - out[1, 1] += t1*tmp[0, 1] +\ - t2*tmp[0, 1] +\ - t3*tmp[1, 1] +\ - t4*tmp[1, 1] - - return njit(nogil=True, inline='always')(jones_mul) + out[0, 0] += ( + t1 * tmp[0, 0] + t2 * tmp[0, 0] + t3 * tmp[1, 0] + t4 * tmp[1, 0] + ) + out[0, 1] += ( + t1 * tmp[0, 1] + t2 * tmp[0, 1] + t3 * tmp[1, 1] + t4 * tmp[1, 1] + ) + t1 = a1j[s, 1, 0] * source_vis[0, 0] + t2 = a1j[s, 1, 1] * source_vis[1, 0] + t3 = a1j[s, 1, 0] * source_vis[0, 1] + t4 = a1j[s, 1, 1] * source_vis[1, 1] + out[1, 0] += ( + t1 * tmp[0, 0] + t2 * tmp[0, 0] + t3 * tmp[1, 0] + t4 * tmp[1, 0] + ) + out[1, 1] += ( + t1 * tmp[0, 1] + t2 * tmp[0, 1] + t3 * tmp[1, 1] + t4 * tmp[1, 1] + ) + + return njit(nogil=True, inline="always")(jones_mul) @njit(**JIT_OPTIONS) -def compute_and_corrupt_vis(time_bin_indices, time_bin_counts, antenna1, - antenna2, jones, model, uvw, freq, lm): - return compute_and_corrupt_vis_impl(time_bin_indices, time_bin_counts, - antenna1, antenna2, jones, model, - uvw, freq, lm) - - -def compute_and_corrupt_vis_impl(time_bin_indices, time_bin_counts, antenna1, - antenna2, jones, model, uvw, freq, lm): +def compute_and_corrupt_vis( + time_bin_indices, time_bin_counts, antenna1, antenna2, jones, model, uvw, freq, lm +): + return compute_and_corrupt_vis_impl( + time_bin_indices, + time_bin_counts, + antenna1, + antenna2, + jones, + model, + uvw, + freq, + lm, + ) + + +def compute_and_corrupt_vis_impl( + time_bin_indices, time_bin_counts, antenna1, antenna2, jones, model, uvw, freq, lm +): return NotImplementedError @overload(compute_and_corrupt_vis_impl, jit_options=JIT_OPTIONS) -def mb_compute_and_corrupt_vis(time_bin_indices, time_bin_counts, antenna1, - antenna2, jones, model, uvw, freq, lm): - - mode = check_type(jones, model, vis_type='model') +def mb_compute_and_corrupt_vis( + time_bin_indices, time_bin_counts, antenna1, antenna2, jones, model, uvw, freq, lm +): + mode = check_type(jones, model, vis_type="model") jones_mul = jones_mul_factory(mode) - def _compute_and_corrupt_vis_fn(time_bin_indices, time_bin_counts, - antenna1, antenna2, jones, model, - uvw, freq, lm): + def _compute_and_corrupt_vis_fn( + time_bin_indices, + time_bin_counts, + antenna1, + antenna2, + jones, + model, + uvw, + freq, + lm, + ): if model.shape[-1] > 2: - raise ValueError('ncorr cant be larger than 2') + raise ValueError("ncorr cant be larger than 2") if jones.shape[-1] > 2: - raise ValueError('ncorr cant be larger than 2') + raise ValueError("ncorr cant be larger than 2") # for dask arrays we need to adjust the chunks to # start counting from zero time_bin_indices -= time_bin_indices.min() @@ -109,21 +126,30 @@ def _compute_and_corrupt_vis_fn(time_bin_indices, time_bin_counts, vis = np.zeros(vis_shape, dtype=jones.dtype) n_chan = model_shape[1] for t in range(n_tim): - for row in range(time_bin_indices[t], - time_bin_indices[t] + time_bin_counts[t]): + for row in range( + time_bin_indices[t], time_bin_indices[t] + time_bin_counts[t] + ): p = int(antenna1[row]) q = int(antenna2[row]) gp = jones[t, p] gq = jones[t, q] for nu in range(n_chan): - jones_mul(gp[nu], model[t, nu], gq[nu], uvw[row], - freq[nu], lm[t], vis[row, nu]) + jones_mul( + gp[nu], + model[t, nu], + gq[nu], + uvw[row], + freq[nu], + lm[t], + vis[row, nu], + ) return vis return _compute_and_corrupt_vis_fn -COMPUTE_AND_CORRUPT_VIS_DOCS = DocstringTemplate(""" +COMPUTE_AND_CORRUPT_VIS_DOCS = DocstringTemplate( + """ Corrupts time variable component model with arbitrary Jones terms. Currrently only time variable point source models are supported. @@ -159,11 +185,13 @@ def _compute_and_corrupt_vis_fn(time_bin_indices, time_bin_counts, visibilities of shape :code:`(row, chan, corr)` or :code:`(row, chan, corr, corr)`. -""") +""" +) try: compute_and_corrupt_vis.__doc__ = COMPUTE_AND_CORRUPT_VIS_DOCS.substitute( - array_type=":class:`numpy.ndarray`") + array_type=":class:`numpy.ndarray`" + ) except AttributeError: pass diff --git a/africanus/calibration/utils/correct_vis.py b/africanus/calibration/utils/correct_vis.py index 6e99cdea6..db4f58110 100644 --- a/africanus/calibration/utils/correct_vis.py +++ b/africanus/calibration/utils/correct_vis.py @@ -9,116 +9,116 @@ def jones_inverse_mul_factory(mode): if mode == DIAG_DIAG: + def jones_inverse_mul(a1j, blj, a2j, out): for c in range(out.shape[-1]): - out[c] = blj[c]/(a1j[c]*np.conj(a2j[c])) + out[c] = blj[c] / (a1j[c] * np.conj(a2j[c])) elif mode == DIAG: + def jones_inverse_mul(a1j, blj, a2j, out): - out[0, 0] = blj[0, 0]/(a1j[0]*np.conj(a2j[0])) - out[0, 1] = blj[0, 1]/(a1j[0]*np.conj(a2j[1])) - out[1, 0] = blj[1, 0]/(a1j[1]*np.conj(a2j[0])) - out[1, 1] = blj[1, 1]/(a1j[1]*np.conj(a2j[1])) + out[0, 0] = blj[0, 0] / (a1j[0] * np.conj(a2j[0])) + out[0, 1] = blj[0, 1] / (a1j[0] * np.conj(a2j[1])) + out[1, 0] = blj[1, 0] / (a1j[1] * np.conj(a2j[0])) + out[1, 1] = blj[1, 1] / (a1j[1] * np.conj(a2j[1])) elif mode == FULL: + def jones_inverse_mul(a1j, blj, a2j, out): # get determinant - deta1j = a1j[0, 0]*a1j[1, 1]-a1j[0, 1]*a1j[1, 0] + deta1j = a1j[0, 0] * a1j[1, 1] - a1j[0, 1] * a1j[1, 0] # compute inverse - a00 = a1j[1, 1]/deta1j - a01 = -a1j[0, 1]/deta1j - a10 = -a1j[1, 0]/deta1j - a11 = a1j[0, 0]/deta1j + a00 = a1j[1, 1] / deta1j + a01 = -a1j[0, 1] / deta1j + a10 = -a1j[1, 0] / deta1j + a11 = a1j[0, 0] / deta1j # get determinant a2j = np.conj(a2j) - deta2j = a2j[0, 0]*a2j[1, 1]-a2j[0, 1]*a2j[1, 0] + deta2j = a2j[0, 0] * a2j[1, 1] - a2j[0, 1] * a2j[1, 0] # get conjugate transpose inverse - b00 = a2j[1, 1]/deta2j - b01 = -a2j[1, 0]/deta2j - b10 = -a2j[0, 1]/deta2j - b11 = a2j[0, 0]/deta2j + b00 = a2j[1, 1] / deta2j + b01 = -a2j[1, 0] / deta2j + b10 = -a2j[0, 1] / deta2j + b11 = a2j[0, 0] / deta2j # precompute resuable terms - t1 = a00*blj[0, 0] - t2 = a01*blj[1, 0] - t3 = a00*blj[0, 1] - t4 = a01*blj[1, 1] + t1 = a00 * blj[0, 0] + t2 = a01 * blj[1, 0] + t3 = a00 * blj[0, 1] + t4 = a01 * blj[1, 1] # overwrite with result - out[0, 0] = t1*b00 +\ - t2*b00 +\ - t3*b10 +\ - t4*b10 - out[0, 1] = t1*b01 +\ - t2*b01 +\ - t3*b11 +\ - t4*b11 - t1 = a10*blj[0, 0] - t2 = a11*blj[1, 0] - t3 = a10*blj[0, 1] - t4 = a11*blj[1, 1] - out[1, 0] = t1*b00 +\ - t2*b00 +\ - t3*b10 +\ - t4*b10 - out[1, 1] = t1*b01 +\ - t2*b01 +\ - t3*b11 +\ - t4*b11 - return njit(nogil=True, inline='always')(jones_inverse_mul) + out[0, 0] = t1 * b00 + t2 * b00 + t3 * b10 + t4 * b10 + out[0, 1] = t1 * b01 + t2 * b01 + t3 * b11 + t4 * b11 + t1 = a10 * blj[0, 0] + t2 = a11 * blj[1, 0] + t3 = a10 * blj[0, 1] + t4 = a11 * blj[1, 1] + out[1, 0] = t1 * b00 + t2 * b00 + t3 * b10 + t4 * b10 + out[1, 1] = t1 * b01 + t2 * b01 + t3 * b11 + t4 * b11 + + return njit(nogil=True, inline="always")(jones_inverse_mul) @njit(**JIT_OPTIONS) -def correct_vis(time_bin_indices, time_bin_counts, - antenna1, antenna2, jones, vis, flag): - return correct_vis_impl(time_bin_indices, time_bin_counts, - antenna1, antenna2, jones, vis, flag) +def correct_vis( + time_bin_indices, time_bin_counts, antenna1, antenna2, jones, vis, flag +): + return correct_vis_impl( + time_bin_indices, time_bin_counts, antenna1, antenna2, jones, vis, flag + ) -def correct_vis_impl(time_bin_indices, time_bin_counts, - antenna1, antenna2, jones, vis, flag): +def correct_vis_impl( + time_bin_indices, time_bin_counts, antenna1, antenna2, jones, vis, flag +): return NotImplementedError @overload(correct_vis_impl, jit_options=JIT_OPTIONS) -def nb_correct_vis(time_bin_indices, time_bin_counts, - antenna1, antenna2, jones, vis, flag): - +def nb_correct_vis( + time_bin_indices, time_bin_counts, antenna1, antenna2, jones, vis, flag +): mode = check_type(jones, vis) jones_inverse_mul = jones_inverse_mul_factory(mode) - def _correct_vis_fn(time_bin_indices, time_bin_counts, - antenna1, antenna2, jones, vis, flag): + def _correct_vis_fn( + time_bin_indices, time_bin_counts, antenna1, antenna2, jones, vis, flag + ): # for dask arrays we need to adjust the chunks to # start counting from zero if vis.shape[-1] > 2: - raise ValueError('ncorr cant be larger than 2') + raise ValueError("ncorr cant be larger than 2") if jones.shape[-1] > 2: - raise ValueError('ncorr cant be larger than 2') + raise ValueError("ncorr cant be larger than 2") time_bin_indices -= time_bin_indices.min() jones_shape = np.shape(jones) n_tim = jones_shape[0] n_dir = jones_shape[3] if n_dir > 1: - raise ValueError("Jones has n_dir > 1. Cannot correct " - "for direction dependent gains") + raise ValueError( + "Jones has n_dir > 1. Cannot correct " "for direction dependent gains" + ) n_chan = jones_shape[2] corrected_vis = np.zeros_like(vis, dtype=vis.dtype) for t in range(n_tim): - for row in range(time_bin_indices[t], - time_bin_indices[t] + time_bin_counts[t]): + for row in range( + time_bin_indices[t], time_bin_indices[t] + time_bin_counts[t] + ): p = int(antenna1[row]) q = int(antenna2[row]) gp = jones[t, p] gq = jones[t, q] for nu in range(n_chan): if not np.any(flag[row, nu]): - jones_inverse_mul(gp[nu, 0], vis[row, nu], gq[nu, 0], - corrected_vis[row, nu]) + jones_inverse_mul( + gp[nu, 0], vis[row, nu], gq[nu, 0], corrected_vis[row, nu] + ) return corrected_vis return _correct_vis_fn -CORRECT_VIS_DOCS = DocstringTemplate(""" +CORRECT_VIS_DOCS = DocstringTemplate( + """ Apply inverse of direction independent gains to visibilities to generate corrected visibilities. For a measurement model of the form @@ -168,10 +168,12 @@ def _correct_vis_fn(time_bin_indices, time_bin_counts, ------- corrected_vis : $(array_type) True visibilities of shape :code:`(row,chan,corr_1,corr_2)` -""") +""" +) try: correct_vis.__doc__ = CORRECT_VIS_DOCS.substitute( - array_type=":class:`numpy.ndarray`") + array_type=":class:`numpy.ndarray`" + ) except AttributeError: pass diff --git a/africanus/calibration/utils/corrupt_vis.py b/africanus/calibration/utils/corrupt_vis.py index 191069f00..68ad83f13 100644 --- a/africanus/calibration/utils/corrupt_vis.py +++ b/africanus/calibration/utils/corrupt_vis.py @@ -9,78 +9,77 @@ def jones_mul_factory(mode): if mode == DIAG_DIAG: + def jones_mul(a1j, model, a2j, out): n_dir = np.shape(model)[0] for s in range(n_dir): - out += a1j[s]*model[s]*np.conj(a2j[s]) + out += a1j[s] * model[s] * np.conj(a2j[s]) elif mode == DIAG: + def jones_mul(a1j, model, a2j, out): n_dir = np.shape(model)[0] for s in range(n_dir): - out[0, 0] += a1j[s, 0]*model[s, 0, 0] * np.conj(a2j[s, 0]) - out[0, 1] += a1j[s, 0]*model[s, 0, 1] * np.conj(a2j[s, 1]) - out[1, 0] += a1j[s, 1]*model[s, 1, 0] * np.conj(a2j[s, 0]) - out[1, 1] += a1j[s, 1]*model[s, 1, 1] * np.conj(a2j[s, 1]) + out[0, 0] += a1j[s, 0] * model[s, 0, 0] * np.conj(a2j[s, 0]) + out[0, 1] += a1j[s, 0] * model[s, 0, 1] * np.conj(a2j[s, 1]) + out[1, 0] += a1j[s, 1] * model[s, 1, 0] * np.conj(a2j[s, 0]) + out[1, 1] += a1j[s, 1] * model[s, 1, 1] * np.conj(a2j[s, 1]) elif mode == FULL: + def jones_mul(a1j, model, a2j, out): n_dir = np.shape(model)[0] for s in range(n_dir): # precompute resuable terms - t1 = a1j[s, 0, 0]*model[s, 0, 0] - t2 = a1j[s, 0, 1]*model[s, 1, 0] - t3 = a1j[s, 0, 0]*model[s, 0, 1] - t4 = a1j[s, 0, 1]*model[s, 1, 1] + t1 = a1j[s, 0, 0] * model[s, 0, 0] + t2 = a1j[s, 0, 1] * model[s, 1, 0] + t3 = a1j[s, 0, 0] * model[s, 0, 1] + t4 = a1j[s, 0, 1] * model[s, 1, 1] tmp = np.conj(a2j[s].T) # overwrite with result - out[0, 0] += t1*tmp[0, 0] +\ - t2*tmp[0, 0] +\ - t3*tmp[1, 0] +\ - t4*tmp[1, 0] - out[0, 1] += t1*tmp[0, 1] +\ - t2*tmp[0, 1] +\ - t3*tmp[1, 1] +\ - t4*tmp[1, 1] - t1 = a1j[s, 1, 0]*model[s, 0, 0] - t2 = a1j[s, 1, 1]*model[s, 1, 0] - t3 = a1j[s, 1, 0]*model[s, 0, 1] - t4 = a1j[s, 1, 1]*model[s, 1, 1] - out[1, 0] += t1*tmp[0, 0] +\ - t2*tmp[0, 0] +\ - t3*tmp[1, 0] +\ - t4*tmp[1, 0] - out[1, 1] += t1*tmp[0, 1] +\ - t2*tmp[0, 1] +\ - t3*tmp[1, 1] +\ - t4*tmp[1, 1] - - return njit(nogil=True, inline='always')(jones_mul) + out[0, 0] += ( + t1 * tmp[0, 0] + t2 * tmp[0, 0] + t3 * tmp[1, 0] + t4 * tmp[1, 0] + ) + out[0, 1] += ( + t1 * tmp[0, 1] + t2 * tmp[0, 1] + t3 * tmp[1, 1] + t4 * tmp[1, 1] + ) + t1 = a1j[s, 1, 0] * model[s, 0, 0] + t2 = a1j[s, 1, 1] * model[s, 1, 0] + t3 = a1j[s, 1, 0] * model[s, 0, 1] + t4 = a1j[s, 1, 1] * model[s, 1, 1] + out[1, 0] += ( + t1 * tmp[0, 0] + t2 * tmp[0, 0] + t3 * tmp[1, 0] + t4 * tmp[1, 0] + ) + out[1, 1] += ( + t1 * tmp[0, 1] + t2 * tmp[0, 1] + t3 * tmp[1, 1] + t4 * tmp[1, 1] + ) + + return njit(nogil=True, inline="always")(jones_mul) @njit(**JIT_OPTIONS) -def corrupt_vis(time_bin_indices, time_bin_counts, antenna1, - antenna2, jones, model): - return corrupt_vis_impl(time_bin_indices, time_bin_counts, antenna1, - antenna2, jones, model) +def corrupt_vis(time_bin_indices, time_bin_counts, antenna1, antenna2, jones, model): + return corrupt_vis_impl( + time_bin_indices, time_bin_counts, antenna1, antenna2, jones, model + ) -def corrupt_vis_impl(time_bin_indices, time_bin_counts, antenna1, - antenna2, jones, model): +def corrupt_vis_impl( + time_bin_indices, time_bin_counts, antenna1, antenna2, jones, model +): return NotImplementedError @overload(corrupt_vis_impl, jit_options=JIT_OPTIONS) -def nb_corrupt_vis(time_bin_indices, time_bin_counts, antenna1, - antenna2, jones, model): - - mode = check_type(jones, model, vis_type='model') +def nb_corrupt_vis(time_bin_indices, time_bin_counts, antenna1, antenna2, jones, model): + mode = check_type(jones, model, vis_type="model") jones_mul = jones_mul_factory(mode) - def _corrupt_vis_fn(time_bin_indices, time_bin_counts, antenna1, - antenna2, jones, model): + def _corrupt_vis_fn( + time_bin_indices, time_bin_counts, antenna1, antenna2, jones, model + ): if model.shape[-1] > 2: - raise ValueError('ncorr cant be larger than 2') + raise ValueError("ncorr cant be larger than 2") if jones.shape[-1] > 2: - raise ValueError('ncorr cant be larger than 2') + raise ValueError("ncorr cant be larger than 2") # for dask arrays we need to adjust the chunks to # start counting from zero time_bin_indices -= time_bin_indices.min() @@ -90,8 +89,9 @@ def _corrupt_vis_fn(time_bin_indices, time_bin_counts, antenna1, vis = np.zeros(vis_shape, dtype=model.dtype) n_chan = model_shape[1] for t in range(n_tim): - for row in range(time_bin_indices[t], - time_bin_indices[t] + time_bin_counts[t]): + for row in range( + time_bin_indices[t], time_bin_indices[t] + time_bin_counts[t] + ): p = int(antenna1[row]) q = int(antenna2[row]) gp = jones[t, p] @@ -103,7 +103,8 @@ def _corrupt_vis_fn(time_bin_indices, time_bin_counts, antenna1, return _corrupt_vis_fn -CORRUPT_VIS_DOCS = DocstringTemplate(""" +CORRUPT_VIS_DOCS = DocstringTemplate( + """ Corrupts model visibilities with arbitrary Jones terms. @@ -132,11 +133,13 @@ def _corrupt_vis_fn(time_bin_indices, time_bin_counts, antenna1, visibilities of shape :code:`(time, ant, chan, dir, corr)` or :code:`(time, ant, chan, dir, corr, corr)`. -""") +""" +) try: corrupt_vis.__doc__ = CORRUPT_VIS_DOCS.substitute( - array_type=":class:`numpy.ndarray`") + array_type=":class:`numpy.ndarray`" + ) except AttributeError: pass diff --git a/africanus/calibration/utils/dask.py b/africanus/calibration/utils/dask.py index 073d84ff7..962edd020 100644 --- a/africanus/calibration/utils/dask.py +++ b/africanus/calibration/utils/dask.py @@ -4,10 +4,12 @@ from africanus.calibration.utils.corrupt_vis import CORRUPT_VIS_DOCS from africanus.calibration.utils.residual_vis import RESIDUAL_VIS_DOCS from africanus.calibration.utils.compute_and_corrupt_vis import ( - COMPUTE_AND_CORRUPT_VIS_DOCS) + COMPUTE_AND_CORRUPT_VIS_DOCS, +) from africanus.calibration.utils import correct_vis as np_correct_vis -from africanus.calibration.utils import (compute_and_corrupt_vis as - np_compute_and_corrupt_vis) +from africanus.calibration.utils import ( + compute_and_corrupt_vis as np_compute_and_corrupt_vis, +) from africanus.calibration.utils import corrupt_vis as np_corrupt_vis from africanus.calibration.utils import residual_vis as np_residual_vis from africanus.calibration.utils import check_type @@ -22,17 +24,17 @@ dask_import_error = None -def _corrupt_vis_wrapper(time_bin_indices, time_bin_counts, antenna1, - antenna2, jones, model): - return np_corrupt_vis(time_bin_indices, time_bin_counts, antenna1, - antenna2, jones[0][0], model[0]) +def _corrupt_vis_wrapper( + time_bin_indices, time_bin_counts, antenna1, antenna2, jones, model +): + return np_corrupt_vis( + time_bin_indices, time_bin_counts, antenna1, antenna2, jones[0][0], model[0] + ) -@requires_optional('dask.array', dask_import_error) -def corrupt_vis(time_bin_indices, time_bin_counts, antenna1, - antenna2, jones, model): - - mode = check_type(jones, model, vis_type='model') +@requires_optional("dask.array", dask_import_error) +def corrupt_vis(time_bin_indices, time_bin_counts, antenna1, antenna2, jones, model): + mode = check_type(jones, model, vis_type="model") if jones.chunks[1][0] != jones.shape[1]: raise ValueError("Cannot chunk jones over antenna") @@ -56,31 +58,47 @@ def corrupt_vis(time_bin_indices, time_bin_counts, antenna1, else: raise ValueError("Unknown mode argument of %s" % mode) - return blockwise(_corrupt_vis_wrapper, out_shape, - time_bin_indices, ("row",), - time_bin_counts, ("row",), - antenna1, ("row",), - antenna2, ("row",), - jones, jones_shape, - model, model_shape, - adjust_chunks={"row": antenna1.chunks[0]}, - dtype=model.dtype, - align_arrays=False) - - -def _compute_and_corrupt_vis_wrapper(time_bin_indices, time_bin_counts, - antenna1, antenna2, jones, model, - uvw, freq, lm): - return np_compute_and_corrupt_vis(time_bin_indices, time_bin_counts, - antenna1, antenna2, jones[0][0], - model[0], uvw[0], freq, lm[0][0]) - - -@requires_optional('dask.array', dask_import_error) -def compute_and_corrupt_vis(time_bin_indices, time_bin_counts, - antenna1, antenna2, jones, model, - uvw, freq, lm): - + return blockwise( + _corrupt_vis_wrapper, + out_shape, + time_bin_indices, + ("row",), + time_bin_counts, + ("row",), + antenna1, + ("row",), + antenna2, + ("row",), + jones, + jones_shape, + model, + model_shape, + adjust_chunks={"row": antenna1.chunks[0]}, + dtype=model.dtype, + align_arrays=False, + ) + + +def _compute_and_corrupt_vis_wrapper( + time_bin_indices, time_bin_counts, antenna1, antenna2, jones, model, uvw, freq, lm +): + return np_compute_and_corrupt_vis( + time_bin_indices, + time_bin_counts, + antenna1, + antenna2, + jones[0][0], + model[0], + uvw[0], + freq, + lm[0][0], + ) + + +@requires_optional("dask.array", dask_import_error) +def compute_and_corrupt_vis( + time_bin_indices, time_bin_counts, antenna1, antenna2, jones, model, uvw, freq, lm +): if jones.chunks[1][0] != jones.shape[1]: raise ValueError("Cannot chunk jones over antenna") if jones.chunks[3][0] != jones.shape[3]: @@ -94,7 +112,7 @@ def compute_and_corrupt_vis(time_bin_indices, time_bin_counts, if lm.chunks[2][0] != lm.shape[2]: raise ValueError("Cannot chunks lm over last axis") - mode = check_type(jones, model, vis_type='model') + mode = check_type(jones, model, vis_type="model") if mode == DIAG_DIAG: out_shape = ("row", "chan", "corr1") @@ -111,31 +129,45 @@ def compute_and_corrupt_vis(time_bin_indices, time_bin_counts, else: raise ValueError("Unknown mode argument of %s" % mode) - return blockwise(_compute_and_corrupt_vis_wrapper, out_shape, - time_bin_indices, ("row",), - time_bin_counts, ("row",), - antenna1, ("row",), - antenna2, ("row",), - jones, jones_shape, - model, model_shape, - uvw, ("row", "three"), - freq, ("chan",), - lm, ("row", "dir", "two"), - adjust_chunks={"row": antenna1.chunks[0]}, - dtype=model.dtype, - align_arrays=False) - - -def _correct_vis_wrapper(time_bin_indices, time_bin_counts, antenna1, - antenna2, jones, vis, flag): - return np_correct_vis(time_bin_indices, time_bin_counts, antenna1, - antenna2, jones[0][0], vis, flag) - - -@requires_optional('dask.array', dask_import_error) -def correct_vis(time_bin_indices, time_bin_counts, antenna1, - antenna2, jones, vis, flag): - + return blockwise( + _compute_and_corrupt_vis_wrapper, + out_shape, + time_bin_indices, + ("row",), + time_bin_counts, + ("row",), + antenna1, + ("row",), + antenna2, + ("row",), + jones, + jones_shape, + model, + model_shape, + uvw, + ("row", "three"), + freq, + ("chan",), + lm, + ("row", "dir", "two"), + adjust_chunks={"row": antenna1.chunks[0]}, + dtype=model.dtype, + align_arrays=False, + ) + + +def _correct_vis_wrapper( + time_bin_indices, time_bin_counts, antenna1, antenna2, jones, vis, flag +): + return np_correct_vis( + time_bin_indices, time_bin_counts, antenna1, antenna2, jones[0][0], vis, flag + ) + + +@requires_optional("dask.array", dask_import_error) +def correct_vis( + time_bin_indices, time_bin_counts, antenna1, antenna2, jones, vis, flag +): if jones.chunks[1][0] != jones.shape[1]: raise ValueError("Cannot chunk jones over antenna") if jones.chunks[3][0] != jones.shape[3]: @@ -155,29 +187,48 @@ def correct_vis(time_bin_indices, time_bin_counts, antenna1, else: raise ValueError("Unknown mode argument of %s" % mode) - return blockwise(_correct_vis_wrapper, out_shape, - time_bin_indices, ("row",), - time_bin_counts, ("row",), - antenna1, ("row",), - antenna2, ("row",), - jones, jones_shape, - vis, out_shape, - flag, out_shape, - adjust_chunks={"row": antenna1.chunks[0]}, - dtype=vis.dtype, - align_arrays=False) - - -def _residual_vis_wrapper(time_bin_indices, time_bin_counts, antenna1, - antenna2, jones, vis, flag, model): - return np_residual_vis(time_bin_indices, time_bin_counts, antenna1, - antenna2, jones[0][0], vis, flag, model[0]) - - -@requires_optional('dask.array', dask_import_error) -def residual_vis(time_bin_indices, time_bin_counts, antenna1, - antenna2, jones, vis, flag, model): - + return blockwise( + _correct_vis_wrapper, + out_shape, + time_bin_indices, + ("row",), + time_bin_counts, + ("row",), + antenna1, + ("row",), + antenna2, + ("row",), + jones, + jones_shape, + vis, + out_shape, + flag, + out_shape, + adjust_chunks={"row": antenna1.chunks[0]}, + dtype=vis.dtype, + align_arrays=False, + ) + + +def _residual_vis_wrapper( + time_bin_indices, time_bin_counts, antenna1, antenna2, jones, vis, flag, model +): + return np_residual_vis( + time_bin_indices, + time_bin_counts, + antenna1, + antenna2, + jones[0][0], + vis, + flag, + model[0], + ) + + +@requires_optional("dask.array", dask_import_error) +def residual_vis( + time_bin_indices, time_bin_counts, antenna1, antenna2, jones, vis, flag, model +): if jones.chunks[1][0] != jones.shape[1]: raise ValueError("Cannot chunk jones over antenna") if jones.chunks[3][0] != jones.shape[3]: @@ -202,28 +253,43 @@ def residual_vis(time_bin_indices, time_bin_counts, antenna1, else: raise ValueError("Unknown mode argument of %s" % mode) - return blockwise(_residual_vis_wrapper, out_shape, - time_bin_indices, ("row",), - time_bin_counts, ("row",), - antenna1, ("row",), - antenna2, ("row",), - jones, jones_shape, - vis, out_shape, - flag, out_shape, - model, model_shape, - adjust_chunks={"row": antenna1.chunks[0]}, - dtype=vis.dtype, - align_arrays=False) + return blockwise( + _residual_vis_wrapper, + out_shape, + time_bin_indices, + ("row",), + time_bin_counts, + ("row",), + antenna1, + ("row",), + antenna2, + ("row",), + jones, + jones_shape, + vis, + out_shape, + flag, + out_shape, + model, + model_shape, + adjust_chunks={"row": antenna1.chunks[0]}, + dtype=vis.dtype, + align_arrays=False, + ) compute_and_corrupt_vis.__doc__ = COMPUTE_AND_CORRUPT_VIS_DOCS.substitute( - array_type=":class:`dask.array.Array`") + array_type=":class:`dask.array.Array`" +) corrupt_vis.__doc__ = CORRUPT_VIS_DOCS.substitute( - array_type=":class:`dask.array.Array`") + array_type=":class:`dask.array.Array`" +) correct_vis.__doc__ = CORRECT_VIS_DOCS.substitute( - array_type=":class:`dask.array.Array`") + array_type=":class:`dask.array.Array`" +) residual_vis.__doc__ = RESIDUAL_VIS_DOCS.substitute( - array_type=":class:`dask.array.Array`") + array_type=":class:`dask.array.Array`" +) diff --git a/africanus/calibration/utils/examples/apply_gains_time_varying_sources.py b/africanus/calibration/utils/examples/apply_gains_time_varying_sources.py index 9e8be3d2c..d51c9cbe8 100755 --- a/africanus/calibration/utils/examples/apply_gains_time_varying_sources.py +++ b/africanus/calibration/utils/examples/apply_gains_time_varying_sources.py @@ -24,29 +24,47 @@ def create_parser(): p = argparse.ArgumentParser() p.add_argument("--ms", help="Name of measurement set", type=str) p.add_argument("--sky_model", type=str, help="Tigger lsm file") - p.add_argument("--data_col", help="Column where data lives. " - "Only used to get shape of data at this stage", - default='DATA', type=str) - p.add_argument("--out_col", help="Where to write the corrupted data to. " - "Must exist in MS before writing to it.", - default='DATA', type=str) - p.add_argument("--gain_file", help=".npy file containing gains in format " - "(time, antenna, freq, source, corr). " - "See corrupt_vis docs.", type=str) - p.add_argument("--utimes_per_chunk", default=32, type=int, - help="Number of unique times in each chunk.") - p.add_argument("--ncpu", help="The number of threads to use. " - "Default of zero means all", default=10, type=int) - p.add_argument('--field', default=0, type=int) + p.add_argument( + "--data_col", + help="Column where data lives. " "Only used to get shape of data at this stage", + default="DATA", + type=str, + ) + p.add_argument( + "--out_col", + help="Where to write the corrupted data to. " + "Must exist in MS before writing to it.", + default="DATA", + type=str, + ) + p.add_argument( + "--gain_file", + help=".npy file containing gains in format " + "(time, antenna, freq, source, corr). " + "See corrupt_vis docs.", + type=str, + ) + p.add_argument( + "--utimes_per_chunk", + default=32, + type=int, + help="Number of unique times in each chunk.", + ) + p.add_argument( + "--ncpu", + help="The number of threads to use. " "Default of zero means all", + default=10, + type=int, + ) + p.add_argument("--field", default=0, type=int) return p def main(args): # get full time column and compute row chunks ms = table(args.ms) - time = ms.getcol('TIME') - row_chunks, tbin_idx, tbin_counts = chunkify_rows( - time, args.utimes_per_chunk) + time = ms.getcol("TIME") + row_chunks, tbin_idx, tbin_counts = chunkify_rows(time, args.utimes_per_chunk) # convert to dask arrays tbin_idx = da.from_array(tbin_idx, chunks=(args.utimes_per_chunk)) tbin_counts = da.from_array(tbin_counts, chunks=(args.utimes_per_chunk)) @@ -54,14 +72,15 @@ def main(args): ms.close() # get phase dir - fld = table(args.ms+'::FIELD') - radec0 = fld.getcol('PHASE_DIR').squeeze().reshape(1, 2) + fld = table(args.ms + "::FIELD") + radec0 = fld.getcol("PHASE_DIR").squeeze().reshape(1, 2) radec0 = np.tile(radec0, (n_time, 1)) fld.close() # get freqs - freqs = table( - args.ms+'::SPECTRAL_WINDOW').getcol('CHAN_FREQ')[0].astype(np.float64) + freqs = ( + table(args.ms + "::SPECTRAL_WINDOW").getcol("CHAN_FREQ")[0].astype(np.float64) + ) n_freq = freqs.size freqs = da.from_array(freqs, chunks=(n_freq)) @@ -94,24 +113,22 @@ def main(args): spi = np.asarray(spi) for t in range(n_time): for d in range(n_dir): - model[t, :, d, 0] = stokes[d] * (freqs/ref_freqs[d])**spi[d] + model[t, :, d, 0] = stokes[d] * (freqs / ref_freqs[d]) ** spi[d] # append antenna columns cols = [] - cols.append('ANTENNA1') - cols.append('ANTENNA2') - cols.append('UVW') + cols.append("ANTENNA1") + cols.append("ANTENNA2") + cols.append("UVW") # load in gains jones = np.load(args.gain_file) jones = jones.astype(np.complex128) jones_shape = jones.shape - jones = da.from_array(jones, chunks=(args.utimes_per_chunk,) - + jones_shape[1::]) + jones = da.from_array(jones, chunks=(args.utimes_per_chunk,) + jones_shape[1::]) # change model to dask array - model = da.from_array(model, chunks=(args.utimes_per_chunk,) - + model.shape[1::]) + model = da.from_array(model, chunks=(args.utimes_per_chunk,) + model.shape[1::]) # load data in in chunks and apply gains to each chunk xds = xds_from_ms(args.ms, columns=cols, chunks={"row": row_chunks})[0] @@ -120,8 +137,9 @@ def main(args): uvw = xds.UVW.data # apply gains - data = compute_and_corrupt_vis(tbin_idx, tbin_counts, ant1, ant2, - jones, model, uvw, freqs, lm) + data = compute_and_corrupt_vis( + tbin_idx, tbin_counts, ant1, ant2, jones, model, uvw, freqs, lm + ) # Assign visibilities to args.out_col and write to ms xds = xds.assign(**{args.out_col: (("row", "chan", "corr"), data)}) @@ -139,9 +157,11 @@ def main(args): if args.ncpu: from multiprocessing.pool import ThreadPool import dask + dask.config.set(pool=ThreadPool(args.ncpu)) else: import multiprocessing + args.ncpu = multiprocessing.cpu_count() print("Using %i threads" % args.ncpu) diff --git a/africanus/calibration/utils/examples/apply_gains_to_ms.py b/africanus/calibration/utils/examples/apply_gains_to_ms.py index 79f1f4c3b..cf8c6a926 100755 --- a/africanus/calibration/utils/examples/apply_gains_to_ms.py +++ b/africanus/calibration/utils/examples/apply_gains_to_ms.py @@ -22,23 +22,47 @@ def create_parser(): p = argparse.ArgumentParser() p.add_argument("--ms", help="Name of measurement set", type=str) - p.add_argument("--model_cols", help="Comma separated string of " - "merasuturement set columns containing data " - "for each source", default='MODEL_DATA', type=str) - p.add_argument("--data_col", help="Column where data lives. " - "Only used to get shape of data at this stage", - default='DATA', type=str) - p.add_argument("--out_col", help="Where to write the corrupted data to. " - "Must exist in MS before writing to it.", - default='CORRECTED_DATA', type=str) - p.add_argument("--gain_file", help=".npy file containing gains in format " - "(time, antenna, freq, source, corr). " - "See corrupt_vis docs.", type=str) - p.add_argument("--utimes_per_chunk", default=32, type=int, - help="Number of unique times in each chunk.") - p.add_argument("--ncpu", help="The number of threads to use. " - "Default of zero means all", default=0, type=int) - p.add_argument('--field', default=0, type=int) + p.add_argument( + "--model_cols", + help="Comma separated string of " + "merasuturement set columns containing data " + "for each source", + default="MODEL_DATA", + type=str, + ) + p.add_argument( + "--data_col", + help="Column where data lives. " "Only used to get shape of data at this stage", + default="DATA", + type=str, + ) + p.add_argument( + "--out_col", + help="Where to write the corrupted data to. " + "Must exist in MS before writing to it.", + default="CORRECTED_DATA", + type=str, + ) + p.add_argument( + "--gain_file", + help=".npy file containing gains in format " + "(time, antenna, freq, source, corr). " + "See corrupt_vis docs.", + type=str, + ) + p.add_argument( + "--utimes_per_chunk", + default=32, + type=int, + help="Number of unique times in each chunk.", + ) + p.add_argument( + "--ncpu", + help="The number of threads to use. " "Default of zero means all", + default=0, + type=int, + ) + p.add_argument("--field", default=0, type=int) return p @@ -48,28 +72,30 @@ def create_parser(): ncpu = args.ncpu from multiprocessing.pool import ThreadPool import dask + dask.config.set(pool=ThreadPool(ncpu)) else: import multiprocessing + ncpu = multiprocessing.cpu_count() print("Using %i threads" % ncpu) # get full time column and compute row chunks -time = table(args.ms).getcol('TIME') +time = table(args.ms).getcol("TIME") row_chunks, tbin_idx, tbin_counts = chunkify_rows(time, args.utimes_per_chunk) # convert to dask arrays tbin_idx = da.from_array(tbin_idx, chunks=(args.utimes_per_chunk)) tbin_counts = da.from_array(tbin_counts, chunks=(args.utimes_per_chunk)) # get model column names -model_cols = args.model_cols.split(',') +model_cols = args.model_cols.split(",") n_dir = len(model_cols) # append antenna columns cols = [] -cols.append('ANTENNA1') -cols.append('ANTENNA2') +cols.append("ANTENNA1") +cols.append("ANTENNA2") cols.append(args.data_col) for col in model_cols: cols.append(col) @@ -79,8 +105,7 @@ def create_parser(): jones = jones.astype(np.complex64) jones_shape = jones.shape ndims = len(jones_shape) -jones = da.from_array(jones, chunks=(args.utimes_per_chunk,) - + jones_shape[1:]) +jones = da.from_array(jones, chunks=(args.utimes_per_chunk,) + jones_shape[1:]) # load data in in chunks and apply gains to each chunk xds = xds_from_ms(args.ms, columns=cols, chunks={"row": row_chunks})[0] @@ -102,8 +127,7 @@ def create_parser(): reshape_vis = False # apply gains -corrupted_data = corrupt_vis(tbin_idx, tbin_counts, ant1, ant2, - jones, model) +corrupted_data = corrupt_vis(tbin_idx, tbin_counts, ant1, ant2, jones, model) if reshape_vis: corrupted_data = corrupted_data.reshape(n_row, n_chan, n_corr) diff --git a/africanus/calibration/utils/examples/apply_phase_screen_to_ms.py b/africanus/calibration/utils/examples/apply_phase_screen_to_ms.py index b03efe30a..937349661 100755 --- a/africanus/calibration/utils/examples/apply_phase_screen_to_ms.py +++ b/africanus/calibration/utils/examples/apply_phase_screen_to_ms.py @@ -21,20 +21,33 @@ from africanus.calibration.utils.dask import compute_and_corrupt_vis import numpy as np import matplotlib as mpl -mpl.use('TkAgg') + +mpl.use("TkAgg") def create_parser(): p = argparse.ArgumentParser() p.add_argument("--ms", help="Name of measurement set", type=str) p.add_argument("--sky_model", type=str, help="Tigger lsm file") - p.add_argument("--out_col", help="Where to write the corrupted data to. " - "Must exist in MS before writing to it.", - default='DATA', type=str) - p.add_argument("--utimes_per_chunk", default=64, type=int, - help="Number of unique times in each chunk.") - p.add_argument("--ncpu", help="The number of threads to use. " - "Default of zero means all", default=10, type=int) + p.add_argument( + "--out_col", + help="Where to write the corrupted data to. " + "Must exist in MS before writing to it.", + default="DATA", + type=str, + ) + p.add_argument( + "--utimes_per_chunk", + default=64, + type=int, + help="Number of unique times in each chunk.", + ) + p.add_argument( + "--ncpu", + help="The number of threads to use. " "Default of zero means all", + default=10, + type=int, + ) return p @@ -45,37 +58,36 @@ def make_screen(lm, freq, n_time, n_ant, n_corr): n_coeff = 3 l_coord = lm[:, 0] m_coord = lm[:, 1] - basis = np.hstack((np.ones((n_dir, 1), dtype=np.float64), - l_coord[:, None], m_coord[:, None])) + basis = np.hstack( + (np.ones((n_dir, 1), dtype=np.float64), l_coord[:, None], m_coord[:, None]) + ) # get coeffs alphas = 0.05 * np.random.randn(n_time, n_ant, n_coeff, n_corr) # normalise freqs - freq_norm = freq/freq.max() + freq_norm = freq / freq.max() # simulate phases - phases = np.zeros((n_time, n_ant, n_freq, n_dir, n_corr), - dtype=np.float64) + phases = np.zeros((n_time, n_ant, n_freq, n_dir, n_corr), dtype=np.float64) for t in range(n_time): for p in range(n_ant): for c in range(n_corr): # get screen at source locations screen = basis.dot(alphas[t, p, :, c]) # apply frequency scaling - phases[t, p, :, :, c] = screen[None, :]/freq_norm[:, None] - return np.exp(1.0j*phases), alphas + phases[t, p, :, :, c] = screen[None, :] / freq_norm[:, None] + return np.exp(1.0j * phases), alphas def simulate(args): # get full time column and compute row chunks ms = table(args.ms) - time = ms.getcol('TIME') - row_chunks, tbin_idx, tbin_counts = chunkify_rows( - time, args.utimes_per_chunk) + time = ms.getcol("TIME") + row_chunks, tbin_idx, tbin_counts = chunkify_rows(time, args.utimes_per_chunk) # convert to dask arrays tbin_idx = da.from_array(tbin_idx, chunks=(args.utimes_per_chunk)) tbin_counts = da.from_array(tbin_counts, chunks=(args.utimes_per_chunk)) n_time = tbin_idx.size - ant1 = ms.getcol('ANTENNA1') - ant2 = ms.getcol('ANTENNA2') + ant1 = ms.getcol("ANTENNA1") + ant2 = ms.getcol("ANTENNA2") n_ant = np.maximum(ant1.max(), ant2.max()) + 1 flag = ms.getcol("FLAG") n_row, n_freq, n_corr = flag.shape @@ -93,11 +105,12 @@ def simulate(args): ms.close() # get phase dir - radec0 = table(args.ms+'::FIELD').getcol('PHASE_DIR').squeeze() + radec0 = table(args.ms + "::FIELD").getcol("PHASE_DIR").squeeze() # get freqs - freq = table( - args.ms+'::SPECTRAL_WINDOW').getcol('CHAN_FREQ')[0].astype(np.float64) + freq = ( + table(args.ms + "::SPECTRAL_WINDOW").getcol("CHAN_FREQ")[0].astype(np.float64) + ) assert freq.size == n_freq # get source coordinates from lsm @@ -124,7 +137,7 @@ def simulate(args): ref_freqs = np.asarray(ref_freqs) spi = np.asarray(spi) for d in range(n_dir): - Stokes_I = stokes[d] * (freq/ref_freqs[d])**spi[d] + Stokes_I = stokes[d] * (freq / ref_freqs[d]) ** spi[d] if n_corr == 4: model[:, d, 0, 0] = Stokes_I model[:, d, 1, 1] = Stokes_I @@ -136,26 +149,27 @@ def simulate(args): # append antenna columns cols = [] - cols.append('ANTENNA1') - cols.append('ANTENNA2') - cols.append('UVW') + cols.append("ANTENNA1") + cols.append("ANTENNA2") + cols.append("UVW") # load in gains jones, alphas = make_screen(lm, freq, n_time, n_ant, jones_corr[0]) jones = jones.astype(np.complex128) jones_shape = jones.shape - jones_da = da.from_array(jones, chunks=(args.utimes_per_chunk,) - + jones_shape[1::]) + jones_da = da.from_array(jones, chunks=(args.utimes_per_chunk,) + jones_shape[1::]) freqs = da.from_array(freq, chunks=(n_freq)) - lm = da.from_array(np.tile(lm[None], (n_time, 1, 1)), chunks=( - args.utimes_per_chunk, n_dir, 2)) + lm = da.from_array( + np.tile(lm[None], (n_time, 1, 1)), chunks=(args.utimes_per_chunk, n_dir, 2) + ) # change model to dask array tmp_shape = (n_time,) for i in range(len(model.shape)): tmp_shape += (1,) - model = da.from_array(np.tile(model[None], tmp_shape), - chunks=(args.utimes_per_chunk,) + model.shape) + model = da.from_array( + np.tile(model[None], tmp_shape), chunks=(args.utimes_per_chunk,) + model.shape + ) # load data in in chunks and apply gains to each chunk xds = xds_from_ms(args.ms, columns=cols, chunks={"row": row_chunks})[0] @@ -164,12 +178,14 @@ def simulate(args): uvw = xds.UVW.data # apply gains - data = compute_and_corrupt_vis(tbin_idx, tbin_counts, ant1, ant2, - jones_da, model, uvw, freqs, lm) + data = compute_and_corrupt_vis( + tbin_idx, tbin_counts, ant1, ant2, jones_da, model, uvw, freqs, lm + ) # Assign visibilities to args.out_col and write to ms - xds = xds.assign(**{args.out_col: (("row", "chan", "corr"), - data.reshape(n_row, n_freq, n_corr))}) + xds = xds.assign( + **{args.out_col: (("row", "chan", "corr"), data.reshape(n_row, n_freq, n_corr))} + ) # Create a write to the table write = xds_to_table(xds, args.ms, [args.out_col]) @@ -186,27 +202,27 @@ def calibrate(args, jones, alphas): # load data ms = table(args.ms) - time = ms.getcol('TIME') + time = ms.getcol("TIME") _, tbin_idx, tbin_counts = chunkify_rows(time, args.utimes_per_chunk) n_time = tbin_idx.size - ant1 = ms.getcol('ANTENNA1') - ant2 = ms.getcol('ANTENNA2') + ant1 = ms.getcol("ANTENNA1") + ant2 = ms.getcol("ANTENNA2") n_ant = np.maximum(ant1.max(), ant2.max()) + 1 - uvw = ms.getcol('UVW').astype(np.float64) + uvw = ms.getcol("UVW").astype(np.float64) data = ms.getcol(args.out_col) # this is where we put the data # we know it is pure Stokes I so we can solve using diagonals only data = data[:, :, (0, 3)].astype(np.complex128) n_row, n_freq, n_corr = data.shape - flag = ms.getcol('FLAG') + flag = ms.getcol("FLAG") flag = flag[:, :, (0, 3)] # get phase dir - radec0 = table( - args.ms+'::FIELD').getcol('PHASE_DIR').squeeze().astype(np.float64) + radec0 = table(args.ms + "::FIELD").getcol("PHASE_DIR").squeeze().astype(np.float64) # get freqs - freq = table( - args.ms+'::SPECTRAL_WINDOW').getcol('CHAN_FREQ')[0].astype(np.float64) + freq = ( + table(args.ms + "::SPECTRAL_WINDOW").getcol("CHAN_FREQ")[0].astype(np.float64) + ) assert freq.size == n_freq # now get the model @@ -234,23 +250,33 @@ def calibrate(args, jones, alphas): ref_freqs = np.asarray(ref_freqs) spi = np.asarray(spi) for d in range(n_dir): - Stokes_I = stokes[d] * (freq/ref_freqs[d])**spi[d] + Stokes_I = stokes[d] * (freq / ref_freqs[d]) ** spi[d] model[:, :, d, 0:1] = im_to_vis( - Stokes_I[None, :, None], uvw, lm[d:d+1], freq) + Stokes_I[None, :, None], uvw, lm[d : d + 1], freq + ) model[:, :, d, 1] = model[:, :, d, 0] # set weights to unity weight = np.ones_like(data, dtype=np.float64) # initialise gains - jones0 = np.ones((n_time, n_ant, n_freq, n_dir, n_corr), - dtype=np.complex128) + jones0 = np.ones((n_time, n_ant, n_freq, n_dir, n_corr), dtype=np.complex128) # calibrate ti = timeit() jones_hat, jhj, jhr, k = gauss_newton( - tbin_idx, tbin_counts, ant1, ant2, jones0, data, flag, model, - weight, tol=1e-5, maxiter=100) + tbin_idx, + tbin_counts, + ant1, + ant2, + jones0, + data, + flag, + model, + weight, + tol=1e-5, + maxiter=100, + ) print("%i iterations took %fs" % (k, timeit() - ti)) # verify result @@ -270,9 +296,11 @@ def calibrate(args, jones, alphas): if args.ncpu: from multiprocessing.pool import ThreadPool import dask + dask.config.set(pool=ThreadPool(args.ncpu)) else: import multiprocessing + args.ncpu = multiprocessing.cpu_count() print("Using %i threads" % args.ncpu) diff --git a/africanus/calibration/utils/residual_vis.py b/africanus/calibration/utils/residual_vis.py index 6f2d0c961..907e1802f 100644 --- a/africanus/calibration/utils/residual_vis.py +++ b/africanus/calibration/utils/residual_vis.py @@ -10,83 +10,87 @@ def subtract_model_factory(mode): if mode == DIAG_DIAG: + def subtract_model(a1j, blj, a2j, model, out): n_dir = np.shape(model)[0] out[...] = blj for s in range(n_dir): - out -= a1j[s]*model[s]*np.conj(a2j[s]) + out -= a1j[s] * model[s] * np.conj(a2j[s]) elif mode == DIAG: + def subtract_model(a1j, blj, a2j, model, out): n_dir = np.shape(model)[0] out[...] = blj for s in range(n_dir): - out[0, 0] -= a1j[s, 0]*model[s, 0, 0] * np.conj(a2j[s, 0]) - out[0, 1] -= a1j[s, 0]*model[s, 0, 1] * np.conj(a2j[s, 1]) - out[1, 0] -= a1j[s, 1]*model[s, 1, 0] * np.conj(a2j[s, 0]) - out[1, 1] -= a1j[s, 1]*model[s, 1, 1] * np.conj(a2j[s, 1]) + out[0, 0] -= a1j[s, 0] * model[s, 0, 0] * np.conj(a2j[s, 0]) + out[0, 1] -= a1j[s, 0] * model[s, 0, 1] * np.conj(a2j[s, 1]) + out[1, 0] -= a1j[s, 1] * model[s, 1, 0] * np.conj(a2j[s, 0]) + out[1, 1] -= a1j[s, 1] * model[s, 1, 1] * np.conj(a2j[s, 1]) elif mode == FULL: + def subtract_model(a1j, blj, a2j, model, out): n_dir = np.shape(model)[0] out[...] = blj for s in range(n_dir): # precompute resuable terms - t1 = a1j[s, 0, 0]*model[s, 0, 0] - t2 = a1j[s, 0, 1]*model[s, 1, 0] - t3 = a1j[s, 0, 0]*model[s, 0, 1] - t4 = a1j[s, 0, 1]*model[s, 1, 1] + t1 = a1j[s, 0, 0] * model[s, 0, 0] + t2 = a1j[s, 0, 1] * model[s, 1, 0] + t3 = a1j[s, 0, 0] * model[s, 0, 1] + t4 = a1j[s, 0, 1] * model[s, 1, 1] tmp = np.conj(a2j[s].T) # overwrite with result - out[0, 0] -= t1*tmp[0, 0] +\ - t2*tmp[0, 0] +\ - t3*tmp[1, 0] +\ - t4*tmp[1, 0] - out[0, 1] -= t1*tmp[0, 1] +\ - t2*tmp[0, 1] +\ - t3*tmp[1, 1] +\ - t4*tmp[1, 1] - t1 = a1j[s, 1, 0]*model[s, 0, 0] - t2 = a1j[s, 1, 1]*model[s, 1, 0] - t3 = a1j[s, 1, 0]*model[s, 0, 1] - t4 = a1j[s, 1, 1]*model[s, 1, 1] - out[1, 0] -= t1*tmp[0, 0] +\ - t2*tmp[0, 0] +\ - t3*tmp[1, 0] +\ - t4*tmp[1, 0] - out[1, 1] -= t1*tmp[0, 1] +\ - t2*tmp[0, 1] +\ - t3*tmp[1, 1] +\ - t4*tmp[1, 1] - return njit(nogil=True, inline='always')(subtract_model) + out[0, 0] -= ( + t1 * tmp[0, 0] + t2 * tmp[0, 0] + t3 * tmp[1, 0] + t4 * tmp[1, 0] + ) + out[0, 1] -= ( + t1 * tmp[0, 1] + t2 * tmp[0, 1] + t3 * tmp[1, 1] + t4 * tmp[1, 1] + ) + t1 = a1j[s, 1, 0] * model[s, 0, 0] + t2 = a1j[s, 1, 1] * model[s, 1, 0] + t3 = a1j[s, 1, 0] * model[s, 0, 1] + t4 = a1j[s, 1, 1] * model[s, 1, 1] + out[1, 0] -= ( + t1 * tmp[0, 0] + t2 * tmp[0, 0] + t3 * tmp[1, 0] + t4 * tmp[1, 0] + ) + out[1, 1] -= ( + t1 * tmp[0, 1] + t2 * tmp[0, 1] + t3 * tmp[1, 1] + t4 * tmp[1, 1] + ) + + return njit(nogil=True, inline="always")(subtract_model) @njit(**JIT_OPTIONS) -def residual_vis(time_bin_indices, time_bin_counts, antenna1, - antenna2, jones, vis, flag, model): - return residual_vis_impl(time_bin_indices, time_bin_counts, antenna1, - antenna2, jones, vis, flag, model) +def residual_vis( + time_bin_indices, time_bin_counts, antenna1, antenna2, jones, vis, flag, model +): + return residual_vis_impl( + time_bin_indices, time_bin_counts, antenna1, antenna2, jones, vis, flag, model + ) -def residual_vis_impl(time_bin_indices, time_bin_counts, antenna1, - antenna2, jones, vis, flag, model): +def residual_vis_impl( + time_bin_indices, time_bin_counts, antenna1, antenna2, jones, vis, flag, model +): return NotImplementedError @overload(residual_vis_impl, jit_options=JIT_OPTIONS) -def nb_residual_vis(time_bin_indices, time_bin_counts, antenna1, - antenna2, jones, vis, flag, model): - +def nb_residual_vis( + time_bin_indices, time_bin_counts, antenna1, antenna2, jones, vis, flag, model +): mode = check_type(jones, vis) subtract_model = subtract_model_factory(mode) @wraps(residual_vis) - def _residual_vis_fn(time_bin_indices, time_bin_counts, antenna1, - antenna2, jones, vis, flag, model): + def _residual_vis_fn( + time_bin_indices, time_bin_counts, antenna1, antenna2, jones, vis, flag, model + ): if vis.shape[-1] > 2: - raise ValueError('ncorr cant be larger than 2') + raise ValueError("ncorr cant be larger than 2") if jones.shape[-1] > 2: - raise ValueError('ncorr cant be larger than 2') + raise ValueError("ncorr cant be larger than 2") if model.shape[-1] > 2: - raise ValueError('ncorr cant be larger than 2') + raise ValueError("ncorr cant be larger than 2") # for dask arrays we need to adjust the chunks to # start counting from zero time_bin_indices -= time_bin_indices.min() @@ -95,8 +99,9 @@ def _residual_vis_fn(time_bin_indices, time_bin_counts, antenna1, n_chan = vis_shape[1] residual = np.zeros(vis_shape, dtype=vis.dtype) for t in range(n_tim): - for row in range(time_bin_indices[t], - time_bin_indices[t] + time_bin_counts[t]): + for row in range( + time_bin_indices[t], time_bin_indices[t] + time_bin_counts[t] + ): p = int(antenna1[row]) q = int(antenna2[row]) gp = jones[t, p] @@ -104,14 +109,19 @@ def _residual_vis_fn(time_bin_indices, time_bin_counts, antenna1, for nu in range(n_chan): if not np.any(flag[row, nu]): subtract_model( - gp[nu], vis[row, nu], gq[nu], - model[row, nu], residual[row, nu]) + gp[nu], + vis[row, nu], + gq[nu], + model[row, nu], + residual[row, nu], + ) return residual return _residual_vis_fn -RESIDUAL_VIS_DOCS = DocstringTemplate(""" +RESIDUAL_VIS_DOCS = DocstringTemplate( + """ Computes residual visibilities given model visibilities and gains solutions. @@ -146,11 +156,13 @@ def _residual_vis_fn(time_bin_indices, time_bin_counts, antenna1, Residual visibilities of shape :code:`(time, ant, chan, dir, corr)` or :code:`(time, ant, chan, dir, corr, corr)`. -""") +""" +) try: residual_vis.__doc__ = RESIDUAL_VIS_DOCS.substitute( - array_type=":class:`numpy.ndarray`") + array_type=":class:`numpy.ndarray`" + ) except AttributeError: pass diff --git a/africanus/calibration/utils/tests/conftest.py b/africanus/calibration/utils/tests/conftest.py index ec239e665..db9eab4a2 100644 --- a/africanus/calibration/utils/tests/conftest.py +++ b/africanus/calibration/utils/tests/conftest.py @@ -8,29 +8,38 @@ def lm_factory(n_dir, rs): - ls = 0.1*rs.randn(n_dir) - ms = 0.1*rs.randn(n_dir) + ls = 0.1 * rs.randn(n_dir) + ms = 0.1 * rs.randn(n_dir) lm = np.vstack((ls, ms)).T return lm def flux_factory(n_dir, n_chan, corr_shape, alpha, freq, freq0, rs): - w = freq/freq0 + w = freq / freq0 flux = np.zeros((n_dir, n_chan) + corr_shape, dtype=np.float64) for d in range(n_dir): tmp_flux = np.abs(rs.normal(size=corr_shape)) for v in range(n_chan): - flux[d, v] = tmp_flux * w[v]**alpha + flux[d, v] = tmp_flux * w[v] ** alpha return flux @pytest.fixture def data_factory(): - def impl(sigma_n, sigma_f, n_time, n_chan, n_ant, - n_dir, corr_shape, jones_shape, phase_only_gains=False): + def impl( + sigma_n, + sigma_f, + n_time, + n_chan, + n_ant, + n_dir, + corr_shape, + jones_shape, + phase_only_gains=False, + ): rs = np.random.RandomState(42) - n_bl = n_ant*(n_ant-1)//2 - n_row = n_bl*n_time + n_bl = n_ant * (n_ant - 1) // 2 + n_row = n_bl * n_time # make aux data antenna1 = np.zeros(n_row, dtype=np.int16) antenna2 = np.zeros(n_row, dtype=np.int16) @@ -42,19 +51,18 @@ def impl(sigma_n, sigma_f, n_time, n_chan, n_ant, row = 0 for p in range(n_ant): for q in range(p): - time[i*n_bl + row] = time_values[i] - antenna1[i*n_bl + row] = p - antenna2[i*n_bl + row] = q - uvw[i*n_bl + row] = np.random.randn(3) + time[i * n_bl + row] = time_values[i] + antenna1[i * n_bl + row] = p + antenna2[i * n_bl + row] = q + uvw[i * n_bl + row] = np.random.randn(3) row += 1 assert time.size == n_row # simulate visibilities - model_data = np.zeros((n_row, n_chan, n_dir) + - corr_shape, dtype=np.complex128) + model_data = np.zeros((n_row, n_chan, n_dir) + corr_shape, dtype=np.complex128) # make up some sources lm = lm_factory(n_dir, rs) alpha = -0.7 - freq0 = freq[n_chan//2] + freq0 = freq[n_chan // 2] flux = flux_factory(n_dir, n_chan, corr_shape, alpha, freq, freq0, rs) # simulate model data for dir in range(n_dir): @@ -65,28 +73,31 @@ def impl(sigma_n, sigma_f, n_time, n_chan, n_ant, model_data[:, :, dir] = tmp.reshape((n_row, n_chan) + corr_shape) assert not np.isnan(model_data).any() # simulate gains (just randomly scattered around 1 for now) - jones = np.ones((n_time, n_ant, n_chan, n_dir) + - jones_shape, dtype=np.complex128) + jones = np.ones( + (n_time, n_ant, n_chan, n_dir) + jones_shape, dtype=np.complex128 + ) if sigma_f: if phase_only_gains: - jones = np.exp(1.0j*rs.normal(loc=0.0, scale=sigma_f, - size=jones.shape)) + jones = np.exp( + 1.0j * rs.normal(loc=0.0, scale=sigma_f, size=jones.shape) + ) else: - jones += (rs.normal(loc=0.0, scale=sigma_f, - size=jones.shape) + - 1.0j*rs.normal(loc=0.0, scale=sigma_f, - size=jones.shape)) + jones += rs.normal( + loc=0.0, scale=sigma_f, size=jones.shape + ) + 1.0j * rs.normal(loc=0.0, scale=sigma_f, size=jones.shape) assert (np.abs(jones) > 1e-5).all() assert not np.isnan(jones).any() # get vis _, time_bin_indices, _, time_bin_counts = unique_time(time) - vis = corrupt_vis(time_bin_indices, time_bin_counts, - antenna1, antenna2, jones, model_data) + vis = corrupt_vis( + time_bin_indices, time_bin_counts, antenna1, antenna2, jones, model_data + ) assert not np.isnan(vis).any() # add noise if sigma_n: - vis += (rs.normal(loc=0.0, scale=sigma_n, size=vis.shape) + - 1.0j*rs.normal(loc=0.0, scale=sigma_n, size=vis.shape)) + vis += rs.normal(loc=0.0, scale=sigma_n, size=vis.shape) + 1.0j * rs.normal( + loc=0.0, scale=sigma_n, size=vis.shape + ) weights = np.ones(vis.shape, dtype=np.float64) if sigma_n: weights /= sigma_n**2 @@ -99,6 +110,7 @@ def impl(sigma_n, sigma_f, n_time, n_chan, n_ant, data_dict["ANTENNA1"] = antenna1 data_dict["ANTENNA2"] = antenna2 data_dict["FLAG"] = flag - data_dict['JONES'] = jones + data_dict["JONES"] = jones return data_dict + return impl diff --git a/africanus/calibration/utils/tests/test_utils.py b/africanus/calibration/utils/tests/test_utils.py index 2484bb8da..1ff4e7741 100755 --- a/africanus/calibration/utils/tests/test_utils.py +++ b/africanus/calibration/utils/tests/test_utils.py @@ -8,12 +8,14 @@ from africanus.rime.predict import predict_vis corr_shape_parametrization = pytest.mark.parametrize( - 'corr_shape, jones_shape', - [((1,), (1,)), # DIAG_DIAG - ((2,), (2,)), # DIAG_DIAG - ((2, 2), (2,)), # DIAG - ((2, 2), (2, 2)), # FULL - ]) + "corr_shape, jones_shape", + [ + ((1,), (1,)), # DIAG_DIAG + ((2,), (2,)), # DIAG_DIAG + ((2, 2), (2,)), # DIAG + ((2, 2), (2, 2)), # FULL + ], +) @corr_shape_parametrization @@ -30,24 +32,24 @@ def test_corrupt_vis(data_factory, corr_shape, jones_shape): n_ant = 7 sigma_n = 0.0 sigma_f = 0.05 - data_dict = data_factory(sigma_n, sigma_f, n_time, n_chan, - n_ant, n_dir, corr_shape, jones_shape) + data_dict = data_factory( + sigma_n, sigma_f, n_time, n_chan, n_ant, n_dir, corr_shape, jones_shape + ) # make_data uses corrupt_vis to produce the data so we only need to test # that predict vis gives the same thing on the reshaped arrays - ant1 = data_dict['ANTENNA1'] - ant2 = data_dict['ANTENNA2'] - vis = data_dict['DATA'] - model = data_dict['MODEL_DATA'] - jones = data_dict['JONES'] - time = data_dict['TIME'] + ant1 = data_dict["ANTENNA1"] + ant2 = data_dict["ANTENNA2"] + vis = data_dict["DATA"] + model = data_dict["MODEL_DATA"] + jones = data_dict["JONES"] + time = data_dict["TIME"] # predict_vis expects (source, time, ant, chan, corr1, corr2) so # we need to transpose the axes while preserving corr_shape and jones_shape if jones_shape != corr_shape: # This only happens in DIAG mode and we need to broadcast jones_shape # to match corr_shape - tmp = np.zeros((n_time, n_ant, n_chan, n_dir) + - corr_shape, dtype=np.complex128) + tmp = np.zeros((n_time, n_ant, n_chan, n_dir) + corr_shape, dtype=np.complex128) tmp[:, :, :, :, 0, 0] = jones[:, :, :, :, 0] tmp[:, :, :, :, 1, 1] = jones[:, :, :, :, 1] jones = tmp @@ -63,10 +65,9 @@ def test_corrupt_vis(data_factory, corr_shape, jones_shape): # get vis time_index = np.unique(time, return_inverse=True)[1] - test_vis = predict_vis(time_index, ant1, ant2, - source_coh=model, - dde1_jones=jones, - dde2_jones=jones) + test_vis = predict_vis( + time_index, ant1, ant2, source_coh=model, dde1_jones=jones, dde2_jones=jones + ) assert_array_almost_equal(test_vis, vis, decimal=10) @@ -79,6 +80,7 @@ def test_residual_vis(data_factory, corr_shape, jones_shape): the output to the unsubtracted direction. """ from africanus.calibration.utils import residual_vis, corrupt_vis + # simulate noise free data with random DDE's n_dir = 3 n_time = 32 @@ -86,31 +88,42 @@ def test_residual_vis(data_factory, corr_shape, jones_shape): n_ant = 7 sigma_n = 0.0 sigma_f = 0.05 - data_dict = data_factory(sigma_n, sigma_f, n_time, n_chan, - n_ant, n_dir, corr_shape, jones_shape) - time = data_dict['TIME'] + data_dict = data_factory( + sigma_n, sigma_f, n_time, n_chan, n_ant, n_dir, corr_shape, jones_shape + ) + time = data_dict["TIME"] _, time_bin_indices, _, time_bin_counts = unique_time(time) - ant1 = data_dict['ANTENNA1'] - ant2 = data_dict['ANTENNA2'] - vis = data_dict['DATA'] - model = data_dict['MODEL_DATA'] - jones = data_dict['JONES'] - flag = data_dict['FLAG'] + ant1 = data_dict["ANTENNA1"] + ant2 = data_dict["ANTENNA2"] + vis = data_dict["DATA"] + model = data_dict["MODEL_DATA"] + jones = data_dict["JONES"] + flag = data_dict["FLAG"] # split the model and jones terms model_unsubtracted = model[:, :, 0:1] model_subtract = model[:, :, 1::] jones_unsubtracted = jones[:, :, :, 0:1] jones_subtract = jones[:, :, :, 1::] # subtract all but one direction - residual = residual_vis(time_bin_indices, time_bin_counts, - ant1, ant2, jones_subtract, vis, - flag, model_subtract) + residual = residual_vis( + time_bin_indices, + time_bin_counts, + ant1, + ant2, + jones_subtract, + vis, + flag, + model_subtract, + ) # apply gains to the unsubtracted direction - vis_unsubtracted = corrupt_vis(time_bin_indices, - time_bin_counts, - ant1, ant2, - jones_unsubtracted, - model_unsubtracted) + vis_unsubtracted = corrupt_vis( + time_bin_indices, + time_bin_counts, + ant1, + ant2, + jones_unsubtracted, + model_unsubtracted, + ) # residual should now be equal to unsubtracted vis assert_array_almost_equal(residual, vis_unsubtracted, decimal=10) @@ -122,6 +135,7 @@ def test_correct_vis(data_factory, corr_shape, jones_shape): with random DIE gains """ from africanus.calibration.utils import correct_vis + # simulate noise free data with only DIE's n_dir = 1 n_time = 32 @@ -129,21 +143,21 @@ def test_correct_vis(data_factory, corr_shape, jones_shape): n_ant = 7 sigma_n = 0.0 sigma_f = 0.05 - data_dict = data_factory(sigma_n, sigma_f, n_time, - n_chan, n_ant, n_dir, corr_shape, - jones_shape) - time = data_dict['TIME'] + data_dict = data_factory( + sigma_n, sigma_f, n_time, n_chan, n_ant, n_dir, corr_shape, jones_shape + ) + time = data_dict["TIME"] _, time_bin_indices, _, time_bin_counts = unique_time(time) - ant1 = data_dict['ANTENNA1'] - ant2 = data_dict['ANTENNA2'] - vis = data_dict['DATA'] - model = data_dict['MODEL_DATA'] - jones = data_dict['JONES'] - flag = data_dict['FLAG'] + ant1 = data_dict["ANTENNA1"] + ant2 = data_dict["ANTENNA2"] + vis = data_dict["DATA"] + model = data_dict["MODEL_DATA"] + jones = data_dict["JONES"] + flag = data_dict["FLAG"] # correct vis corrected_vis = correct_vis( - time_bin_indices, time_bin_counts, - ant1, ant2, jones, vis, flag) + time_bin_indices, time_bin_counts, ant1, ant2, jones, vis, flag + ) # squeeze out dir axis to get expected model data model = model.reshape(vis.shape) assert_array_almost_equal(corrected_vis, model, decimal=10) @@ -159,36 +173,38 @@ def test_corrupt_vis_dask(data_factory, corr_shape, jones_shape): n_ant = 4 sigma_n = 0.0 sigma_f = 0.05 - data_dict = data_factory(sigma_n, sigma_f, n_time, - n_chan, n_ant, n_dir, corr_shape, - jones_shape) - vis = data_dict['DATA'] # what we need to compare to - ant1 = data_dict['ANTENNA1'] - ant2 = data_dict['ANTENNA2'] - model = data_dict['MODEL_DATA'] - jones = data_dict['JONES'] - time = data_dict['TIME'] + data_dict = data_factory( + sigma_n, sigma_f, n_time, n_chan, n_ant, n_dir, corr_shape, jones_shape + ) + vis = data_dict["DATA"] # what we need to compare to + ant1 = data_dict["ANTENNA1"] + ant2 = data_dict["ANTENNA2"] + model = data_dict["MODEL_DATA"] + jones = data_dict["JONES"] + time = data_dict["TIME"] # get chunking scheme ncpu = 8 - utimes_per_chunk = n_time//ncpu - row_chunks, time_bin_idx, time_bin_counts = chunkify_rows( - time, utimes_per_chunk) + utimes_per_chunk = n_time // ncpu + row_chunks, time_bin_idx, time_bin_counts = chunkify_rows(time, utimes_per_chunk) # set up dask arrays da_time_bin_idx = da.from_array(time_bin_idx, chunks=(utimes_per_chunk)) - da_time_bin_counts = da.from_array( - time_bin_counts, chunks=(utimes_per_chunk)) + da_time_bin_counts = da.from_array(time_bin_counts, chunks=(utimes_per_chunk)) da_ant1 = da.from_array(ant1, chunks=row_chunks) da_ant2 = da.from_array(ant2, chunks=row_chunks) - da_model = da.from_array(model, chunks=( - row_chunks, (n_chan,), (n_dir,)) + (corr_shape)) - da_jones = da.from_array(jones, chunks=( - utimes_per_chunk, n_ant, n_chan, n_dir)+jones_shape) + da_model = da.from_array( + model, chunks=(row_chunks, (n_chan,), (n_dir,)) + (corr_shape) + ) + da_jones = da.from_array( + jones, chunks=(utimes_per_chunk, n_ant, n_chan, n_dir) + jones_shape + ) from africanus.calibration.utils.dask import corrupt_vis - da_vis = corrupt_vis(da_time_bin_idx, da_time_bin_counts, - da_ant1, da_ant2, da_jones, da_model) + + da_vis = corrupt_vis( + da_time_bin_idx, da_time_bin_counts, da_ant1, da_ant2, da_jones, da_model + ) vis2 = da_vis.compute() assert_array_almost_equal(vis, vis2, decimal=10) @@ -203,38 +219,38 @@ def test_correct_vis_dask(data_factory, corr_shape, jones_shape): n_ant = 4 sigma_n = 0.0 sigma_f = 0.05 - data_dict = data_factory(sigma_n, sigma_f, n_time, - n_chan, n_ant, n_dir, corr_shape, - jones_shape) - vis = data_dict['DATA'] - ant1 = data_dict['ANTENNA1'] - ant2 = data_dict['ANTENNA2'] - model = data_dict['MODEL_DATA'] # what we need to compare to - jones = data_dict['JONES'] - time = data_dict['TIME'] - flag = data_dict['FLAG'] + data_dict = data_factory( + sigma_n, sigma_f, n_time, n_chan, n_ant, n_dir, corr_shape, jones_shape + ) + vis = data_dict["DATA"] + ant1 = data_dict["ANTENNA1"] + ant2 = data_dict["ANTENNA2"] + model = data_dict["MODEL_DATA"] # what we need to compare to + jones = data_dict["JONES"] + time = data_dict["TIME"] + flag = data_dict["FLAG"] # get chunking scheme ncpu = 8 - utimes_per_chunk = n_time//ncpu - row_chunks, time_bin_idx, time_bin_counts = chunkify_rows( - time, utimes_per_chunk) + utimes_per_chunk = n_time // ncpu + row_chunks, time_bin_idx, time_bin_counts = chunkify_rows(time, utimes_per_chunk) # set up dask arrays da_time_bin_idx = da.from_array(time_bin_idx, chunks=(utimes_per_chunk)) - da_time_bin_counts = da.from_array( - time_bin_counts, chunks=(utimes_per_chunk)) + da_time_bin_counts = da.from_array(time_bin_counts, chunks=(utimes_per_chunk)) da_ant1 = da.from_array(ant1, chunks=row_chunks) da_ant2 = da.from_array(ant2, chunks=row_chunks) da_vis = da.from_array(vis, chunks=(row_chunks, (n_chan,)) + (corr_shape)) - da_jones = da.from_array(jones, chunks=( - utimes_per_chunk, n_ant, n_chan, n_dir)+jones_shape) - da_flag = da.from_array(flag, chunks=( - row_chunks, (n_chan,)) + (corr_shape)) + da_jones = da.from_array( + jones, chunks=(utimes_per_chunk, n_ant, n_chan, n_dir) + jones_shape + ) + da_flag = da.from_array(flag, chunks=(row_chunks, (n_chan,)) + (corr_shape)) from africanus.calibration.utils.dask import correct_vis - da_model = correct_vis(da_time_bin_idx, da_time_bin_counts, da_ant1, - da_ant2, da_jones, da_vis, da_flag) + + da_model = correct_vis( + da_time_bin_idx, da_time_bin_counts, da_ant1, da_ant2, da_jones, da_vis, da_flag + ) model2 = da_model.compute() assert_array_almost_equal(model.reshape(model2.shape), model2, decimal=10) @@ -249,44 +265,53 @@ def test_residual_vis_dask(data_factory, corr_shape, jones_shape): n_ant = 4 sigma_n = 0.0 sigma_f = 0.05 - data_dict = data_factory(sigma_n, sigma_f, n_time, - n_chan, n_ant, n_dir, corr_shape, - jones_shape) - vis = data_dict['DATA'] - ant1 = data_dict['ANTENNA1'] - ant2 = data_dict['ANTENNA2'] - model = data_dict['MODEL_DATA'] # what we need to compare to - jones = data_dict['JONES'] - time = data_dict['TIME'] - flag = data_dict['FLAG'] + data_dict = data_factory( + sigma_n, sigma_f, n_time, n_chan, n_ant, n_dir, corr_shape, jones_shape + ) + vis = data_dict["DATA"] + ant1 = data_dict["ANTENNA1"] + ant2 = data_dict["ANTENNA2"] + model = data_dict["MODEL_DATA"] # what we need to compare to + jones = data_dict["JONES"] + time = data_dict["TIME"] + flag = data_dict["FLAG"] # get chunking scheme ncpu = 8 - utimes_per_chunk = n_time//ncpu - row_chunks, time_bin_idx, time_bin_counts = chunkify_rows( - time, utimes_per_chunk) + utimes_per_chunk = n_time // ncpu + row_chunks, time_bin_idx, time_bin_counts = chunkify_rows(time, utimes_per_chunk) # set up dask arrays da_time_bin_idx = da.from_array(time_bin_idx, chunks=(utimes_per_chunk)) - da_time_bin_counts = da.from_array( - time_bin_counts, chunks=(utimes_per_chunk)) + da_time_bin_counts = da.from_array(time_bin_counts, chunks=(utimes_per_chunk)) da_ant1 = da.from_array(ant1, chunks=row_chunks) da_ant2 = da.from_array(ant2, chunks=row_chunks) da_vis = da.from_array(vis, chunks=(row_chunks, (n_chan,)) + (corr_shape)) - da_model = da.from_array(model, chunks=( - row_chunks, (n_chan,), (n_dir,)) + (corr_shape)) - da_jones = da.from_array(jones, chunks=( - utimes_per_chunk, n_ant, n_chan, n_dir)+jones_shape) - da_flag = da.from_array(flag, chunks=( - row_chunks, (n_chan,)) + (corr_shape)) + da_model = da.from_array( + model, chunks=(row_chunks, (n_chan,), (n_dir,)) + (corr_shape) + ) + da_jones = da.from_array( + jones, chunks=(utimes_per_chunk, n_ant, n_chan, n_dir) + jones_shape + ) + da_flag = da.from_array(flag, chunks=(row_chunks, (n_chan,)) + (corr_shape)) from africanus.calibration.utils import residual_vis as residual_vis_np - residual = residual_vis_np(time_bin_idx, time_bin_counts, ant1, ant2, - jones, vis, flag, model) + + residual = residual_vis_np( + time_bin_idx, time_bin_counts, ant1, ant2, jones, vis, flag, model + ) from africanus.calibration.utils.dask import residual_vis - da_residual = residual_vis(da_time_bin_idx, da_time_bin_counts, - da_ant1, da_ant2, da_jones, da_vis, - da_flag, da_model) + + da_residual = residual_vis( + da_time_bin_idx, + da_time_bin_counts, + da_ant1, + da_ant2, + da_jones, + da_vis, + da_flag, + da_model, + ) residual2 = da_residual.compute() assert_array_almost_equal(residual, residual2, decimal=10) diff --git a/africanus/calibration/utils/utils.py b/africanus/calibration/utils/utils.py index 38c296a53..fac7b971f 100644 --- a/africanus/calibration/utils/utils.py +++ b/africanus/calibration/utils/utils.py @@ -8,10 +8,10 @@ FULL = 2 -def check_type(jones, vis, vis_type='vis'): - if vis_type == 'vis': +def check_type(jones, vis, vis_type="vis"): + if vis_type == "vis": vis_ndim = (3, 4) - elif vis_type == 'model': + elif vis_type == "model": vis_ndim = (4, 5) else: raise ValueError("Unknown vis_type") @@ -21,9 +21,12 @@ def check_type(jones, vis, vis_type='vis'): if vis_axes_count == vis_ndim[0]: mode = DIAG_DIAG if jones_axes_count != 5: - raise RuntimeError("Jones axes not compatible with \ + raise RuntimeError( + "Jones axes not compatible with \ visibility axes. Expected length \ - 5 but got length %d" % jones_axes_count) + 5 but got length %d" + % jones_axes_count + ) elif vis_axes_count == vis_ndim[1]: if jones_axes_count == 5: @@ -44,15 +47,18 @@ def chunkify_rows(time, utimes_per_chunk): n_time = len(utimes) if utimes_per_chunk <= 0: utimes_per_chunk = n_time - row_chunks = [np.sum(time_bin_counts[i:i+utimes_per_chunk]) - for i in range(0, n_time, utimes_per_chunk)] + row_chunks = [ + np.sum(time_bin_counts[i : i + utimes_per_chunk]) + for i in range(0, n_time, utimes_per_chunk) + ] time_bin_indices = np.zeros(n_time, dtype=np.int32) time_bin_indices[1::] = np.cumsum(time_bin_counts)[0:-1] time_bin_counts = time_bin_counts.astype(np.int32) return tuple(row_chunks), time_bin_indices, time_bin_counts -CHECK_TYPE_DOCS = DocstringTemplate(""" +CHECK_TYPE_DOCS = DocstringTemplate( + """ Determines which calibration scenario to apply i.e. DIAG_DIAG, DIAG or COMPLEX2x2. @@ -74,15 +80,16 @@ def chunkify_rows(time, utimes_per_chunk): An integer representing the calibration mode. Options are 0 -> DIAG_DIAG, 1 -> DIAG, 2 -> FULL -""") +""" +) try: - check_type.__doc__ = CHECK_TYPE_DOCS.substitute( - array_type=":class:`numpy.ndarray`") + check_type.__doc__ = CHECK_TYPE_DOCS.substitute(array_type=":class:`numpy.ndarray`") except AttributeError: pass -CHUNKIFY_ROWS_DOCS = DocstringTemplate(""" +CHUNKIFY_ROWS_DOCS = DocstringTemplate( + """ Divides rows into chunks containing integer numbers of times keeping track of the indices at which the unique time changes and the number @@ -106,10 +113,12 @@ def chunkify_rows(time, utimes_per_chunk): changes times_bin_counts : $(array_type) Array containing the number of times per unique time. -""") +""" +) try: chunkify_rows.__doc__ = CHUNKIFY_ROWS_DOCS.substitute( - array_type=":class:`numpy.ndarray`") + array_type=":class:`numpy.ndarray`" + ) except AttributeError: pass diff --git a/africanus/conftest.py b/africanus/conftest.py index 6941c23fd..dd0de4a84 100644 --- a/africanus/conftest.py +++ b/africanus/conftest.py @@ -9,7 +9,7 @@ @pytest.fixture(scope="function", autouse=bool(numba.config.NRT_STATS)) def check_allocations(): - """ Check allocations match frees """ + """Check allocations match frees""" try: yield start = rtsys.get_allocation_stats() diff --git a/africanus/coordinates/coordinates.py b/africanus/coordinates/coordinates.py index 475691d36..4c0b27e01 100644 --- a/africanus/coordinates/coordinates.py +++ b/africanus/coordinates/coordinates.py @@ -4,9 +4,7 @@ import numpy as np from africanus.util.docs import DocstringTemplate -from africanus.util.numba import (is_numba_type_none, - jit, JIT_OPTIONS, - njit, overload) +from africanus.util.numba import is_numba_type_none, jit, JIT_OPTIONS, njit, overload from africanus.util.requirements import requires_optional try: @@ -65,9 +63,8 @@ def _radec_to_lmn_impl(radec, phase_centre=None): sin_dec = np.sin(radec[s, 1]) cos_dec = np.cos(radec[s, 1]) - lmn[s, 0] = l = cos_dec*sin_ra_delta # noqa - lmn[s, 1] = m = (sin_dec*cos_pc_dec - - cos_dec*sin_pc_dec*cos_ra_delta) + lmn[s, 0] = l = cos_dec * sin_ra_delta # noqa + lmn[s, 1] = m = sin_dec * cos_pc_dec - cos_dec * sin_pc_dec * cos_ra_delta lmn[s, 2] = np.sqrt(1.0 - l**2 - m**2) return lmn @@ -113,8 +110,8 @@ def _radec_to_lm_impl(radec, phase_centre=None): sin_dec = np.sin(radec[s, 1]) cos_dec = np.cos(radec[s, 1]) - lm[s, 0] = cos_dec*sin_ra_delta - lm[s, 1] = sin_dec*cos_pc_dec - cos_dec*sin_pc_dec*cos_ra_delta + lm[s, 0] = cos_dec * sin_ra_delta + lm[s, 1] = sin_dec * cos_pc_dec - cos_dec * sin_pc_dec * cos_ra_delta return lm @@ -152,8 +149,8 @@ def _lmn_to_radec_impl(lmn, phase_centre=None): for s in range(radec.shape[0]): l, m, n = lmn[s] - radec[s, 1] = np.arcsin(m*cos_pc_dec + n*sin_pc_dec) - radec[s, 0] = pc_ra + np.arctan(l / (n*cos_pc_dec - m*sin_pc_dec)) + radec[s, 1] = np.arcsin(m * cos_pc_dec + n * sin_pc_dec) + radec[s, 0] = pc_ra + np.arctan(l / (n * cos_pc_dec - m * sin_pc_dec)) return radec @@ -192,8 +189,8 @@ def _lm_to_radec_impl(lm, phase_centre=None): l, m = lm[s] n = np.sqrt(1.0 - l**2 - m**2) - radec[s, 1] = np.arcsin(m*cos_pc_dec + n*sin_pc_dec) - radec[s, 0] = pc_ra + np.arctan(l / (n*cos_pc_dec - m*sin_pc_dec)) + radec[s, 1] = np.arcsin(m * cos_pc_dec + n * sin_pc_dec) + radec[s, 0] = pc_ra + np.arctan(l / (n * cos_pc_dec - m * sin_pc_dec)) return radec @@ -230,7 +227,8 @@ def astropy_radec_to_lmn(radec, phase_centre): return result -RADEC_TO_LMN_DOCS = DocstringTemplate(r""" +RADEC_TO_LMN_DOCS = DocstringTemplate( + r""" Converts Right-Ascension/Declination coordinates in radians to a Direction Cosine lm coordinates, relative to the Phase Centre. @@ -262,10 +260,12 @@ def astropy_radec_to_lmn(radec, phase_centre): ------- $(array_type) lm Direction Cosines of shape :code:`(coord, $(lm_components))` -""") +""" +) -LMN_TO_RADEC_DOCS = DocstringTemplate(r""" +LMN_TO_RADEC_DOCS = DocstringTemplate( + r""" Convert Direction Cosine lm coordinates to Right Ascension/Declination coordinates in radians, relative to the Phase Centre. @@ -298,21 +298,22 @@ def astropy_radec_to_lmn(radec, phase_centre): where Right-Ascension and Declination are in the last 2 components, respectively. -""") +""" +) try: radec_to_lmn.__doc__ = RADEC_TO_LMN_DOCS.substitute( - lm_components="3", - array_type=":class:`numpy.ndarray`") + lm_components="3", array_type=":class:`numpy.ndarray`" + ) radec_to_lm.__doc__ = RADEC_TO_LMN_DOCS.substitute( - lm_components="2", - array_type=":class:`numpy.ndarray`") + lm_components="2", array_type=":class:`numpy.ndarray`" + ) lmn_to_radec.__doc__ = LMN_TO_RADEC_DOCS.substitute( - lm_name="lmn", lm_components="3", - array_type=":class:`numpy.ndarray`") + lm_name="lmn", lm_components="3", array_type=":class:`numpy.ndarray`" + ) lm_to_radec.__doc__ = LMN_TO_RADEC_DOCS.substitute( - lm_name="lm", lm_components="2", - array_type=":class:`numpy.ndarray`") + lm_name="lm", lm_components="2", array_type=":class:`numpy.ndarray`" + ) except AttributeError: pass diff --git a/africanus/coordinates/dask.py b/africanus/coordinates/dask.py index 841d8f60a..3124af875 100644 --- a/africanus/coordinates/dask.py +++ b/africanus/coordinates/dask.py @@ -10,87 +10,109 @@ from africanus.util.requirements import requires_optional -from africanus.coordinates.coordinates import (radec_to_lmn as np_radec_to_lmn, - radec_to_lm as np_radec_to_lm, - lmn_to_radec as np_lmn_to_radec, - lm_to_radec as np_lm_to_radec, - RADEC_TO_LMN_DOCS, - LMN_TO_RADEC_DOCS) +from africanus.coordinates.coordinates import ( + radec_to_lmn as np_radec_to_lmn, + radec_to_lm as np_radec_to_lm, + lmn_to_radec as np_lmn_to_radec, + lm_to_radec as np_lm_to_radec, + RADEC_TO_LMN_DOCS, + LMN_TO_RADEC_DOCS, +) def _radec_to_lmn(radec, phase_centre): return np_radec_to_lmn(radec[0], phase_centre[0] if phase_centre else None) -@requires_optional('dask.array', dask_import_error) +@requires_optional("dask.array", dask_import_error) def radec_to_lmn(radec, phase_centre=None): phase_centre_dims = ("radec",) if phase_centre is not None else None - return da.core.blockwise(_radec_to_lmn, ("source", "lmn"), - radec, ("source", "radec"), - phase_centre, phase_centre_dims, - new_axes={"lmn": 3}, - dtype=radec.dtype) + return da.core.blockwise( + _radec_to_lmn, + ("source", "lmn"), + radec, + ("source", "radec"), + phase_centre, + phase_centre_dims, + new_axes={"lmn": 3}, + dtype=radec.dtype, + ) def _lmn_to_radec(lmn, phase_centre): return np_lmn_to_radec(lmn[0], phase_centre) -@requires_optional('dask.array', dask_import_error) +@requires_optional("dask.array", dask_import_error) def lmn_to_radec(lmn, phase_centre=None): phase_centre_dims = ("radec",) if phase_centre is not None else None - return da.core.blockwise(_lmn_to_radec, ("source", "radec"), - lmn, ("source", "lmn"), - phase_centre, phase_centre_dims, - new_axes={"radec": 2}, - dtype=lmn.dtype) + return da.core.blockwise( + _lmn_to_radec, + ("source", "radec"), + lmn, + ("source", "lmn"), + phase_centre, + phase_centre_dims, + new_axes={"radec": 2}, + dtype=lmn.dtype, + ) def _radec_to_lm(radec, phase_centre): return np_radec_to_lm(radec[0], phase_centre[0] if phase_centre else None) -@requires_optional('dask.array', dask_import_error) +@requires_optional("dask.array", dask_import_error) def radec_to_lm(radec, phase_centre=None): phase_centre_dims = ("radec",) if phase_centre is not None else None - return da.core.blockwise(_radec_to_lm, ("source", "lm"), - radec, ("source", "radec"), - phase_centre, phase_centre_dims, - new_axes={"lm": 2}, - dtype=radec.dtype) + return da.core.blockwise( + _radec_to_lm, + ("source", "lm"), + radec, + ("source", "radec"), + phase_centre, + phase_centre_dims, + new_axes={"lm": 2}, + dtype=radec.dtype, + ) def _lm_to_radec(lm, phase_centre): return np_lm_to_radec(lm[0], phase_centre) -@requires_optional('dask.array', dask_import_error) +@requires_optional("dask.array", dask_import_error) def lm_to_radec(lm, phase_centre=None): phase_centre_dims = ("radec",) if phase_centre is not None else None - return da.core.blockwise(_lm_to_radec, ("source", "radec"), - lm, ("source", "lm"), - phase_centre, phase_centre_dims, - new_axes={"radec": 2}, - dtype=lm.dtype) + return da.core.blockwise( + _lm_to_radec, + ("source", "radec"), + lm, + ("source", "lm"), + phase_centre, + phase_centre_dims, + new_axes={"radec": 2}, + dtype=lm.dtype, + ) try: radec_to_lmn.__doc__ = RADEC_TO_LMN_DOCS.substitute( - lm_components="3", - array_type=":class:`dask.array.Array`") + lm_components="3", array_type=":class:`dask.array.Array`" + ) radec_to_lm.__doc__ = RADEC_TO_LMN_DOCS.substitute( - lm_components="2", - array_type=":class:`dask.array.Array`") + lm_components="2", array_type=":class:`dask.array.Array`" + ) lmn_to_radec.__doc__ = LMN_TO_RADEC_DOCS.substitute( - lm_name="lmn", lm_components="3", - array_type=":class:`dask.array.Array`") + lm_name="lmn", lm_components="3", array_type=":class:`dask.array.Array`" + ) lm_to_radec.__doc__ = LMN_TO_RADEC_DOCS.substitute( - lm_name="lm", lm_components="2", - array_type=":class:`dask.array.Array`") + lm_name="lm", lm_components="2", array_type=":class:`dask.array.Array`" + ) except AttributeError: diff --git a/africanus/coordinates/tests/test_coordinates.py b/africanus/coordinates/tests/test_coordinates.py index 74a056e6d..1cd229b83 100644 --- a/africanus/coordinates/tests/test_coordinates.py +++ b/africanus/coordinates/tests/test_coordinates.py @@ -1,18 +1,19 @@ - import numpy as np from numpy.testing import assert_array_equal, assert_array_almost_equal import pytest -from africanus.coordinates import (radec_to_lmn as np_radec_to_lmn, - radec_to_lm as np_radec_to_lm, - lmn_to_radec as np_lmn_to_radec, - lm_to_radec as np_lm_to_radec) +from africanus.coordinates import ( + radec_to_lmn as np_radec_to_lmn, + radec_to_lm as np_radec_to_lm, + lmn_to_radec as np_lmn_to_radec, + lm_to_radec as np_lm_to_radec, +) from africanus.coordinates.coordinates import astropy_radec_to_lmn def test_radec_to_lmn(): - """ Tests that basics run """ + """Tests that basics run""" np.random.seed(42) @@ -37,11 +38,11 @@ def test_radec_to_lmn(): def test_radec_to_lmn_astropy(): - """ Check that our code agrees with astropy """ + """Check that our code agrees with astropy""" np.random.seed(42) - astropy = pytest.importorskip('astropy') + astropy = pytest.importorskip("astropy") SkyCoord = astropy.coordinates.SkyCoord units = astropy.units @@ -51,15 +52,14 @@ def test_radec_to_lmn_astropy(): lmn = np_radec_to_lmn(radec, phase_centre) ast_radec = SkyCoord(radec[:, 0], radec[:, 1], unit=units.rad) - ast_phase_centre = SkyCoord(phase_centre[0], phase_centre[1], - unit=units.rad) + ast_phase_centre = SkyCoord(phase_centre[0], phase_centre[1], unit=units.rad) ast_lmn = astropy_radec_to_lmn(ast_radec, ast_phase_centre) assert_array_almost_equal(ast_lmn, lmn) def test_radec_to_lmn_wraps(): - """ Test that the radec can be recovered exactly """ + """Test that the radec can be recovered exactly""" np.random.seed(42) @@ -74,13 +74,15 @@ def test_radec_to_lmn_wraps(): def test_dask_radec_to_lmn(): - """ Test that dask version matches numpy version """ + """Test that dask version matches numpy version""" da = pytest.importorskip("dask.array") - from africanus.coordinates.dask import (radec_to_lmn as da_radec_to_lmn, - radec_to_lm as da_radec_to_lm, - lmn_to_radec as da_lmn_to_radec, - lm_to_radec as da_lm_to_radec) + from africanus.coordinates.dask import ( + radec_to_lmn as da_radec_to_lmn, + radec_to_lm as da_radec_to_lm, + lmn_to_radec as da_lmn_to_radec, + lm_to_radec as da_lm_to_radec, + ) np.random.seed(42) @@ -90,7 +92,7 @@ def test_dask_radec_to_lmn(): source = sum(source_chunks) coords = sum(coord_chunks) - radec = np.random.random((source, coords))*10 + radec = np.random.random((source, coords)) * 10 da_radec = da.from_array(radec, chunks=(source_chunks, coord_chunks)) phase_centre = np.random.random(coord_chunks) @@ -123,6 +125,6 @@ def test_dask_radec_to_lmn(): zpc = da.zeros((2,), dtype=radec.dtype, chunks=(2,)) assert_array_equal(da_radec_to_lmn(da_radec), da_radec_to_lmn(da_radec, zpc)) # noqa - assert_array_equal(da_radec_to_lm(da_radec), da_radec_to_lm(da_radec, zpc)) # noqa - assert_array_equal(da_lmn_to_radec(da_lmn), da_lmn_to_radec(da_lmn, zpc)) # noqa - assert_array_equal(da_lm_to_radec(da_lm), da_lm_to_radec(da_lm, zpc)) # noqa + assert_array_equal(da_radec_to_lm(da_radec), da_radec_to_lm(da_radec, zpc)) # noqa + assert_array_equal(da_lmn_to_radec(da_lmn), da_lmn_to_radec(da_lmn, zpc)) # noqa + assert_array_equal(da_lm_to_radec(da_lm), da_lm_to_radec(da_lm, zpc)) # noqa diff --git a/africanus/deconv/hogbom/clean.py b/africanus/deconv/hogbom/clean.py index eaf2ec82a..9ac20217e 100644 --- a/africanus/deconv/hogbom/clean.py +++ b/africanus/deconv/hogbom/clean.py @@ -23,17 +23,20 @@ def twod_gaussian(coords, amplitude, xo, yo, sigma_x, sigma_y, theta, offset): y = coords[1] xo = float(xo) yo = float(yo) - a = (np.cos(theta)**2)/(2*sigma_x**2) + (np.sin(theta)**2)/(2*sigma_y**2) - b = -(np.sin(2*theta))/(4*sigma_x**2) + (np.sin(2*theta))/(4*sigma_y**2) - c = (np.sin(theta)**2)/(2*sigma_x**2) + (np.cos(theta)**2)/(2*sigma_y**2) - g = (offset + amplitude * - np.exp(- (a*((x-xo)**2) + - 2*b*(x-xo)*(y-yo) + - c*((y-yo)**2)))) + a = (np.cos(theta) ** 2) / (2 * sigma_x**2) + (np.sin(theta) ** 2) / ( + 2 * sigma_y**2 + ) + b = -(np.sin(2 * theta)) / (4 * sigma_x**2) + (np.sin(2 * theta)) / (4 * sigma_y**2) + c = (np.sin(theta) ** 2) / (2 * sigma_x**2) + (np.cos(theta) ** 2) / ( + 2 * sigma_y**2 + ) + g = offset + amplitude * np.exp( + -(a * ((x - xo) ** 2) + 2 * b * (x - xo) * (y - yo) + c * ((y - yo) ** 2)) + ) return g.flatten() -@requires_optional('scipy', opt_import_err) +@requires_optional("scipy", opt_import_err) def fit_2d_gaussian(psf): """ Fit an elliptical Gaussian to the primary lobe of the psf @@ -45,17 +48,17 @@ def fit_2d_gaussian(psf): # implementation # I = np.stack((psf>=0.5*psf.max()).nonzero()).transpose() - loc = np.argwhere(psf >= 0.5*psf.max()) + loc = np.argwhere(psf >= 0.5 * psf.max()) # Create an array with these values at the same indices and zeros otherwise lk, mk = psf.shape psf_fit = np.zeros_like(psf) psf_fit[loc[:, 0], loc[:, 1]] = psf[loc[:, 0], loc[:, 1]] # Create x and y indices - x = np.linspace(0, psf.shape[0]-1, psf.shape[0]) - y = np.linspace(0, psf.shape[1]-1, psf.shape[1]) + x = np.linspace(0, psf.shape[0] - 1, psf.shape[0]) + y = np.linspace(0, psf.shape[1] - 1, psf.shape[1]) x, y = np.meshgrid(x, y) # Set starting point of optimiser - initial_guess = (0.5, lk/2, mk/2, 1.75, 1.4, -4.0, 0) + initial_guess = (0.5, lk / 2, mk / 2, 1.75, 1.4, -4.0, 0) # Flatten the data data = psf_fit.ravel() # Fit the function (Gaussian for now) @@ -63,7 +66,7 @@ def fit_2d_gaussian(psf): # Get function with fitted params data_fitted = twod_gaussian((x, y), *popt) # Normalise the psf to have a max value of one - data_fitted = data_fitted/data_fitted.max() + data_fitted = data_fitted / data_fitted.max() return data_fitted.reshape(lk, mk) @@ -103,20 +106,20 @@ def find_peak(residuals): @numba.jit(nopython=True, nogil=True, cache=True) def build_cleanmap(clean, intensity, gamma, p, q): - clean[p, q] += intensity*gamma + clean[p, q] += intensity * gamma @numba.jit(nopython=True, nogil=True, cache=True) def update_residual(residual, intensity, gamma, p, q, npix, psf): npix = residual.shape[0] # Assuming square image - residual -= gamma*intensity*psf[npix - 1 - p:2*npix - 1 - p, - npix - 1 - q:2*npix - 1 - q] + residual -= ( + gamma + * intensity + * psf[npix - 1 - p : 2 * npix - 1 - p, npix - 1 - q : 2 * npix - 1 - q] + ) -def hogbom_clean(dirty, psf, - gamma=0.1, - threshold="default", - niter="default"): +def hogbom_clean(dirty, psf, gamma=0.1, threshold="default", niter="default"): """ Performs Hogbom Clean on the ``dirty`` image given the ``psf``. @@ -145,8 +148,7 @@ def hogbom_clean(dirty, psf, residuals = dirty.copy() # Check that psf is twice the size of residuals - if (psf.shape[0] != 2*residuals.shape[0] or - psf.shape[1] != 2*residuals.shape[1]): + if psf.shape[0] != 2 * residuals.shape[0] or psf.shape[1] != 2 * residuals.shape[1]: raise ValueError("Warning psf not right size") # Initialise array to store cleaned image @@ -156,25 +158,27 @@ def hogbom_clean(dirty, psf, npix = clean.shape[0] if niter == "default": - niter = 3*npix + niter = 3 * npix p, q, pmin, qmin, intensity = find_peak(residuals) if threshold == "default": # Imin + 0.001*(intensity - Imin) - threshold = 0.2*np.abs(intensity) + threshold = 0.2 * np.abs(intensity) logging.info("Threshold set at %s", threshold) else: # Imin + 0.001*(intensity - Imin) - threshold = threshold*np.abs(intensity) + threshold = threshold * np.abs(intensity) logging.info("Assuming user set threshold at %s", threshold) # CLEAN the image i = 0 while np.abs(intensity) > threshold and i <= niter: - logging.info("min %f max %f peak %f threshold %f" % - (residuals.min(), residuals.max(), intensity, threshold)) + logging.info( + "min %f max %f peak %f threshold %f" + % (residuals.min(), residuals.max(), intensity, threshold) + ) # First we set the build_cleanmap(clean, intensity, gamma, p, q) @@ -222,7 +226,7 @@ def restore(clean, psf, residuals): logging.info("Convolving") # cval=0.0) #Fast using fft - iconv_model = scipy.signal.fftconvolve(clean, clean_beam, mode='same') + iconv_model = scipy.signal.fftconvolve(clean, clean_beam, mode="same") logging.info("Convolving done") diff --git a/africanus/dft/dask.py b/africanus/dft/dask.py index eff1415d5..5aec449ee 100644 --- a/africanus/dft/dask.py +++ b/africanus/dft/dask.py @@ -18,78 +18,82 @@ def _im_to_vis_wrapper(image, uvw, lm, frequency, convention, dtype_): - return np_im_to_vis(image[0], uvw[0], lm[0][0], frequency, - convention=convention, dtype=dtype_) + return np_im_to_vis( + image[0], uvw[0], lm[0][0], frequency, convention=convention, dtype=dtype_ + ) -@requires_optional('dask.array', dask_import_error) -def im_to_vis(image, uvw, lm, frequency, - convention='fourier', dtype=np.complex128): - """ Dask wrapper for im_to_vis function """ +@requires_optional("dask.array", dask_import_error) +def im_to_vis(image, uvw, lm, frequency, convention="fourier", dtype=np.complex128): + """Dask wrapper for im_to_vis function""" if lm.chunks[0][0] != lm.shape[0]: - raise ValueError("lm chunks must match lm shape " - "on first axis") + raise ValueError("lm chunks must match lm shape " "on first axis") if image.chunks[0][0] != image.shape[0]: - raise ValueError("Image chunks must match image " - "shape on first axis") + raise ValueError("Image chunks must match image " "shape on first axis") if image.chunks[0][0] != lm.chunks[0][0]: - raise ValueError("Image chunks and lm chunks must " - "match on first axis") + raise ValueError("Image chunks and lm chunks must " "match on first axis") if image.chunks[1] != frequency.chunks[0]: - raise ValueError("Image chunks must match frequency " - "chunks on second axis") - return da.core.blockwise(_im_to_vis_wrapper, ("row", "chan", "corr"), - image, ("source", "chan", "corr"), - uvw, ("row", "(u,v,w)"), - lm, ("source", "(l,m)"), - frequency, ("chan",), - convention=convention, - dtype=dtype, - dtype_=dtype) - - -def _vis_to_im_wrapper(vis, uvw, lm, frequency, flags, - convention, dtype_): - return np_vis_to_im(vis, uvw[0], lm[0], - frequency, flags, - convention=convention, - dtype=dtype_)[None, :] - - -@requires_optional('dask.array', dask_import_error) -def vis_to_im(vis, uvw, lm, frequency, flags, - convention='fourier', dtype=np.float64): - """ Dask wrapper for vis_to_im function """ + raise ValueError("Image chunks must match frequency " "chunks on second axis") + return da.core.blockwise( + _im_to_vis_wrapper, + ("row", "chan", "corr"), + image, + ("source", "chan", "corr"), + uvw, + ("row", "(u,v,w)"), + lm, + ("source", "(l,m)"), + frequency, + ("chan",), + convention=convention, + dtype=dtype, + dtype_=dtype, + ) + + +def _vis_to_im_wrapper(vis, uvw, lm, frequency, flags, convention, dtype_): + return np_vis_to_im( + vis, uvw[0], lm[0], frequency, flags, convention=convention, dtype=dtype_ + )[None, :] + + +@requires_optional("dask.array", dask_import_error) +def vis_to_im(vis, uvw, lm, frequency, flags, convention="fourier", dtype=np.float64): + """Dask wrapper for vis_to_im function""" if vis.chunks[0] != uvw.chunks[0]: - raise ValueError("Vis chunks and uvw chunks must " - "match on first axis") + raise ValueError("Vis chunks and uvw chunks must " "match on first axis") if vis.chunks[1] != frequency.chunks[0]: - raise ValueError("Vis chunks must match frequency " - "chunks on second axis") + raise ValueError("Vis chunks must match frequency " "chunks on second axis") if vis.chunks != flags.chunks: - raise ValueError("Vis chunks must match flags " - "chunks on all axes") - - ims = da.core.blockwise(_vis_to_im_wrapper, - ("row", "source", "chan", "corr"), - vis, ("row", "chan", "corr"), - uvw, ("row", "(u,v,w)"), - lm, ("source", "(l,m)"), - frequency, ("chan",), - flags, ("row", "chan", "corr"), - adjust_chunks={"row": 1}, - convention=convention, - dtype=dtype, - dtype_=dtype) + raise ValueError("Vis chunks must match flags " "chunks on all axes") + + ims = da.core.blockwise( + _vis_to_im_wrapper, + ("row", "source", "chan", "corr"), + vis, + ("row", "chan", "corr"), + uvw, + ("row", "(u,v,w)"), + lm, + ("source", "(l,m)"), + frequency, + ("chan",), + flags, + ("row", "chan", "corr"), + adjust_chunks={"row": 1}, + convention=convention, + dtype=dtype, + dtype_=dtype, + ) return ims.sum(axis=0) -im_to_vis.__doc__ = doc_tuple_to_str(im_to_vis_docs, - [(":class:`numpy.ndarray`", - ":class:`dask.array.Array`")]) +im_to_vis.__doc__ = doc_tuple_to_str( + im_to_vis_docs, [(":class:`numpy.ndarray`", ":class:`dask.array.Array`")] +) -vis_to_im.__doc__ = doc_tuple_to_str(vis_to_im_docs, - [(":class:`numpy.ndarray`", - ":class:`dask.array.Array`")]) +vis_to_im.__doc__ = doc_tuple_to_str( + vis_to_im_docs, [(":class:`numpy.ndarray`", ":class:`dask.array.Array`")] +) diff --git a/africanus/dft/examples/predict_from_fits.py b/africanus/dft/examples/predict_from_fits.py index a4353af10..646ec3e38 100755 --- a/africanus/dft/examples/predict_from_fits.py +++ b/africanus/dft/examples/predict_from_fits.py @@ -17,14 +17,16 @@ def create_parser(): p = argparse.ArgumentParser() p.add_argument("ms", help="Name of MS") p.add_argument("--fitsmodel", help="Fits file to predict from") - p.add_argument("--row_chunks", default=30000, type=int, - help="How to chunks up row dimension.") - p.add_argument("--ncpu", default=0, type=int, - help="Number of threads to use for predict") - p.add_argument("--colname", default="MODEL_DATA", - help="Name of column to write data to.") - p.add_argument('--field', default=0, type=int, - help="Field ID to predict to.") + p.add_argument( + "--row_chunks", default=30000, type=int, help="How to chunks up row dimension." + ) + p.add_argument( + "--ncpu", default=0, type=int, help="Number of threads to use for predict" + ) + p.add_argument( + "--colname", default="MODEL_DATA", help="Name of column to write data to." + ) + p.add_argument("--field", default=0, type=int, help="Field ID to predict to.") return p @@ -33,16 +35,19 @@ def create_parser(): if args.ncpu: ncpu = args.ncpu from multiprocessing.pool import ThreadPool + dask.config.set(pool=ThreadPool(ncpu)) else: import multiprocessing + ncpu = multiprocessing.cpu_count() print("Using %i threads" % ncpu) # Get MS frequencies -spw_ds = list(xds_from_table("::".join((args.ms, "SPECTRAL_WINDOW")), - group_cols="__row__"))[0] +spw_ds = list( + xds_from_table("::".join((args.ms, "SPECTRAL_WINDOW")), group_cols="__row__") +)[0] # Get frequencies in the measurement set # If these do not match those in the fits @@ -57,38 +62,38 @@ def create_parser(): # TODO - check that PHASE_DIR in MS matches that in fits # get image coordinates -if hdr['CUNIT1'] != "DEG" and hdr['CUNIT1'] != "deg": +if hdr["CUNIT1"] != "DEG" and hdr["CUNIT1"] != "deg": raise ValueError("Image units must be in degrees") -npix_l = hdr['NAXIS1'] -refpix_l = hdr['CRPIX1'] -delta_l = hdr['CDELT1'] * np.pi/180 # assumes untis are deg -l0 = hdr['CRVAL1'] * np.pi/180 -l_coord = np.sort(np.arange(1 - refpix_l, 1 + npix_l - refpix_l)*delta_l) +npix_l = hdr["NAXIS1"] +refpix_l = hdr["CRPIX1"] +delta_l = hdr["CDELT1"] * np.pi / 180 # assumes untis are deg +l0 = hdr["CRVAL1"] * np.pi / 180 +l_coord = np.sort(np.arange(1 - refpix_l, 1 + npix_l - refpix_l) * delta_l) -if hdr['CUNIT2'] != "DEG" and hdr['CUNIT2'] != "deg": +if hdr["CUNIT2"] != "DEG" and hdr["CUNIT2"] != "deg": raise ValueError("Image units must be in degrees") -npix_m = hdr['NAXIS2'] -refpix_m = hdr['CRPIX2'] -delta_m = hdr['CDELT2'] * np.pi/180 # assumes untis are deg -m0 = hdr['CRVAL2'] * np.pi/180 -m_coord = np.arange(1 - refpix_m, 1 + npix_m - refpix_m)*delta_m +npix_m = hdr["NAXIS2"] +refpix_m = hdr["CRPIX2"] +delta_m = hdr["CDELT2"] * np.pi / 180 # assumes untis are deg +m0 = hdr["CRVAL2"] * np.pi / 180 +m_coord = np.arange(1 - refpix_m, 1 + npix_m - refpix_m) * delta_m npix_tot = npix_l * npix_m # get frequencies -if hdr["CTYPE4"] == 'FREQ': - nband = hdr['NAXIS4'] - refpix_nu = hdr['CRPIX4'] - delta_nu = hdr['CDELT4'] # assumes units are Hz - ref_freq = hdr['CRVAL4'] - ncorr = hdr['NAXIS3'] +if hdr["CTYPE4"] == "FREQ": + nband = hdr["NAXIS4"] + refpix_nu = hdr["CRPIX4"] + delta_nu = hdr["CDELT4"] # assumes units are Hz + ref_freq = hdr["CRVAL4"] + ncorr = hdr["NAXIS3"] freq_axis = str(4) -elif hdr["CTYPE3"] == 'FREQ': - nband = hdr['NAXIS3'] - refpix_nu = hdr['CRPIX3'] - delta_nu = hdr['CDELT3'] # assumes units are Hz - ref_freq = hdr['CRVAL3'] - ncorr = hdr['NAXIS4'] +elif hdr["CTYPE3"] == "FREQ": + nband = hdr["NAXIS3"] + refpix_nu = hdr["CRPIX3"] + delta_nu = hdr["CDELT3"] # assumes units are Hz + ref_freq = hdr["CRVAL3"] + ncorr = hdr["NAXIS4"] freq_axis = str(3) else: raise ValueError("Freq axis must be 3rd or 4th") @@ -103,17 +108,17 @@ def create_parser(): # if frequencies do not match we need to reprojects fits cube if np.any(ms_freqs != freqs): - print("Warning - reprojecting fits cube to MS freqs. " - "This uses a lot of memory. ") + print( + "Warning - reprojecting fits cube to MS freqs. " "This uses a lot of memory. " + ) from scipy.interpolate import RegularGridInterpolator + # interpolate fits cube - fits_interp = RegularGridInterpolator((freqs, l_coord, m_coord), - model.squeeze(), - bounds_error=False, - fill_value=None) + fits_interp = RegularGridInterpolator( + (freqs, l_coord, m_coord), model.squeeze(), bounds_error=False, fill_value=None + ) # reevaluate at ms freqs - vv, ll, mm = np.meshgrid(ms_freqs, l_coord, m_coord, - indexing='ij') + vv, ll, mm = np.meshgrid(ms_freqs, l_coord, m_coord, indexing="ij") vlm = np.vstack((vv.flatten(), ll.flatten(), mm.flatten())).T model_cube = fits_interp(vlm).reshape(nchan, npix_l, npix_m) else: @@ -127,15 +132,15 @@ def create_parser(): model_cube = model_cube.reshape(nchan, npix_tot) model_max = np.amax(np.abs(model_cube), axis=0) idx_nz = np.argwhere(model_max > 0.0).squeeze() -model_predict = np.transpose(model_cube[:, None, idx_nz], - [2, 0, 1]) +model_predict = np.transpose(model_cube[:, None, idx_nz], [2, 0, 1]) ncomps = idx_nz.size model_predict = da.from_array(model_predict, chunks=(ncomps, nchan, ncorr)) lm = da.from_array(lm[idx_nz, :], chunks=(ncomps, 2)) ms_freqs = spw_ds.CHAN_FREQ.data -xds = xds_from_ms(args.ms, columns=["UVW", args.colname], - chunks={"row": args.row_chunks})[0] +xds = xds_from_ms( + args.ms, columns=["UVW", args.colname], chunks={"row": args.row_chunks} +)[0] uvw = xds.UVW.data vis = im_to_vis(model_predict, uvw, lm, ms_freqs) diff --git a/africanus/dft/kernels.py b/africanus/dft/kernels.py index 000cbba1d..cbe4d7576 100644 --- a/africanus/dft/kernels.py +++ b/africanus/dft/kernels.py @@ -1,8 +1,7 @@ # -*- coding: utf-8 -*- -from africanus.util.numba import (is_numba_type_none, njit, - overload, JIT_OPTIONS) +from africanus.util.numba import is_numba_type_none, njit, overload, JIT_OPTIONS from africanus.util.docs import doc_tuple_to_str from collections import namedtuple @@ -13,33 +12,28 @@ @njit(**JIT_OPTIONS) -def im_to_vis(image, uvw, lm, frequency, - convention='fourier', dtype=None): - return im_to_vis_impl(image, uvw, lm, frequency, - convention=convention, dtype=dtype) +def im_to_vis(image, uvw, lm, frequency, convention="fourier", dtype=None): + return im_to_vis_impl(image, uvw, lm, frequency, convention=convention, dtype=dtype) -def im_to_vis_impl(image, uvw, lm, frequency, - convention='fourier', dtype=None): +def im_to_vis_impl(image, uvw, lm, frequency, convention="fourier", dtype=None): raise NotImplementedError @overload(im_to_vis_impl, jit_options=JIT_OPTIONS) -def nb_im_to_vis(image, uvw, lm, frequency, - convention='fourier', dtype=None): +def nb_im_to_vis(image, uvw, lm, frequency, convention="fourier", dtype=None): # Infer complex output dtype if none provided if is_numba_type_none(dtype): - out_dtype = np.result_type(np.complex64, - *(np.dtype(a.dtype.name) for a in - (image, uvw, lm, frequency))) + out_dtype = np.result_type( + np.complex64, *(np.dtype(a.dtype.name) for a in (image, uvw, lm, frequency)) + ) else: out_dtype = dtype.dtype - def impl(image, uvw, lm, frequency, - convention='fourier', dtype=None): - if convention == 'fourier': + def impl(image, uvw, lm, frequency, convention="fourier", dtype=None): + if convention == "fourier": constant = minus_two_pi_over_c - elif convention == 'casa': + elif convention == "casa": constant = two_pi_over_c else: raise ValueError("convention not in ('fourier', 'casa')") @@ -68,7 +62,7 @@ def impl(image, uvw, lm, frequency, for c in range(ncorr): if image[s, nu, c]: - vis_of_im[r, nu, c] += np.exp(p)*image[s, nu, c] + vis_of_im[r, nu, c] += np.exp(p) * image[s, nu, c] return vis_of_im @@ -76,20 +70,18 @@ def impl(image, uvw, lm, frequency, @njit(**JIT_OPTIONS) -def vis_to_im(vis, uvw, lm, frequency, flags, - convention='fourier', dtype=None): - return vis_to_im_impl(vis, uvw, lm, frequency, flags, - convention=convention, dtype=dtype) +def vis_to_im(vis, uvw, lm, frequency, flags, convention="fourier", dtype=None): + return vis_to_im_impl( + vis, uvw, lm, frequency, flags, convention=convention, dtype=dtype + ) -def vis_to_im_impl(vis, uvw, lm, frequency, flags, - convention='fourier', dtype=None): +def vis_to_im_impl(vis, uvw, lm, frequency, flags, convention="fourier", dtype=None): raise NotImplementedError @overload(vis_to_im_impl, jit_options=JIT_OPTIONS) -def nb_vis_to_im(vis, uvw, lm, frequency, flags, - convention='fourier', dtype=None): +def nb_vis_to_im(vis, uvw, lm, frequency, flags, convention="fourier", dtype=None): # Infer output dtype if none provided if is_numba_type_none(dtype): # Support both real and complex visibilities... @@ -98,9 +90,9 @@ def nb_vis_to_im(vis, uvw, lm, frequency, flags, else: vis_comp_dtype = np.dtype(vis.dtype.name) - out_dtype = np.result_type(vis_comp_dtype, - *(np.dtype(a.dtype.name) for a in - (uvw, lm, frequency))) + out_dtype = np.result_type( + vis_comp_dtype, *(np.dtype(a.dtype.name) for a in (uvw, lm, frequency)) + ) else: if isinstance(dtype, numba.types.scalars.Complex): raise TypeError("dtype must be complex") @@ -109,16 +101,15 @@ def nb_vis_to_im(vis, uvw, lm, frequency, flags, assert np.shape(vis) == np.shape(flags) - def impl(vis, uvw, lm, frequency, flags, - convention='fourier', dtype=None): + def impl(vis, uvw, lm, frequency, flags, convention="fourier", dtype=None): nrows = uvw.shape[0] nsrc = lm.shape[0] nchan = frequency.shape[0] ncorr = vis.shape[-1] - if convention == 'fourier': + if convention == "fourier": constant = two_pi_over_c - elif convention == 'casa': + elif convention == "casa": constant = minus_two_pi_over_c else: raise ValueError("convention not in ('fourier', 'casa')") @@ -128,7 +119,7 @@ def impl(vis, uvw, lm, frequency, flags, # For each source for s in range(nsrc): l, m = lm[s] - n = np.sqrt(1.0 - l ** 2 - m ** 2) - 1.0 + n = np.sqrt(1.0 - l**2 - m**2) - 1.0 # For each uvw coordinate for r in range(nrows): u, v, w = uvw[r] @@ -147,18 +138,17 @@ def impl(vis, uvw, lm, frequency, flags, for c in range(ncorr): # elide the call to exp since result is real - im_of_vis[s, nu, c] += (np.cos(p) * - vis[r, nu, c].real - - np.sin(p) * - vis[r, nu, c].imag) + im_of_vis[s, nu, c] += ( + np.cos(p) * vis[r, nu, c].real + - np.sin(p) * vis[r, nu, c].imag + ) return im_of_vis return impl -_DFT_DOCSTRING = namedtuple( - "_DFTDOCSTRING", ["preamble", "parameters", "returns"]) +_DFT_DOCSTRING = namedtuple("_DFTDOCSTRING", ["preamble", "parameters", "returns"]) im_to_vis_docs = _DFT_DOCSTRING( preamble=""" @@ -170,7 +160,6 @@ def impl(vis, uvw, lm, frequency, flags, {\\Large \\sum_s e^{-2 \\pi i (u l_s + v m_s + w (n_s - 1))} \\cdot I_s } """, # noqa - parameters=r""" Parameters ---------- @@ -196,13 +185,12 @@ def impl(vis, uvw, lm, frequency, flags, If ``None``, :func:`numpy.result_type` is used to infer the data type from the inputs. """, - returns=""" Returns ------- visibilties : :class:`numpy.ndarray` complex of shape :code:`(row, chan, corr)` - """ + """, ) @@ -218,7 +206,6 @@ def impl(vis, uvw, lm, frequency, flags, {\\Large \\sum_k e^{ 2 \\pi i (u_k l + v_k m + w_k (n - 1))} \\cdot V_k} """, # noqa - parameters=r""" Parameters ---------- @@ -250,13 +237,12 @@ def impl(vis, uvw, lm, frequency, flags, If ``None``, :func:`numpy.result_type` is used to infer the data type from the inputs. """, - returns=""" Returns ------- image : :class:`numpy.ndarray` float of shape :code:`(source, chan, corr)` - """ + """, ) diff --git a/africanus/dft/tests/test_dft.py b/africanus/dft/tests/test_dft.py index ac445d82f..5cc73131c 100644 --- a/africanus/dft/tests/test_dft.py +++ b/africanus/dft/tests/test_dft.py @@ -27,10 +27,10 @@ def test_im_to_vis_phase_centre(): ncorr = 2 image = np.zeros((npix, npix, nchan, ncorr), dtype=np.float64) I0 = 1.0 - ref_freq = frequency[nchan//2] - Inu = I0*(frequency/ref_freq)**(-0.7) + ref_freq = frequency[nchan // 2] + Inu = I0 * (frequency / ref_freq) ** (-0.7) for corr in range(ncorr): - image[npix//2, npix//2, :, corr] = Inu + image[npix // 2, npix // 2, :, corr] = Inu image = image.reshape(npix**2, nchan, ncorr) vis = im_to_vis(image, uvw, lm, frequency) @@ -49,19 +49,20 @@ def test_im_to_vis_simple(): """ from africanus.dft.kernels import im_to_vis from africanus.constants import minus_two_pi_over_c + np.random.seed(123) nrow = 100 uvw = np.random.random(size=(nrow, 3)) nchan = 3 - frequency = np.linspace(1.e9, 2.e9, nchan, endpoint=True) + frequency = np.linspace(1.0e9, 2.0e9, nchan, endpoint=True) nsource = 5 I0 = np.random.randn(nsource) - ref_freq = frequency[nchan//2] - image = I0[:, None] * (frequency/ref_freq)**(-0.7) + ref_freq = frequency[nchan // 2] + image = I0[:, None] * (frequency / ref_freq) ** (-0.7) # add correlation axis image = image[:, :, None] - l = 0.001 + 0.1*np.random.random(nsource) # noqa - m = 0.001 + 0.1*np.random.random(nsource) + l = 0.001 + 0.1 * np.random.random(nsource) # noqa + m = 0.001 + 0.1 * np.random.random(nsource) lm = np.vstack((l, m)).T vis = im_to_vis(image, uvw, lm, frequency).squeeze() @@ -71,19 +72,24 @@ def test_im_to_vis_simple(): for source in range(nsource): l, m = lm[source] n = np.sqrt(1.0 - l**2 - m**2) - phase = minus_two_pi_over_c*frequency[ch] * 1.0j * \ - (uvw[:, 0]*l + uvw[:, 1]*m + uvw[:, 2]*(n-1)) - - vis_true[:, ch] += image[source, ch, 0]*np.exp(phase) + phase = ( + minus_two_pi_over_c + * frequency[ch] + * 1.0j + * (uvw[:, 0] * l + uvw[:, 1] * m + uvw[:, 2] * (n - 1)) + ) + + vis_true[:, ch] += image[source, ch, 0] * np.exp(phase) assert_array_almost_equal(vis, vis_true, decimal=14) -@pytest.mark.parametrize("convention", ['fourier', 'casa']) +@pytest.mark.parametrize("convention", ["fourier", "casa"]) def test_im_to_vis_fft(convention): """ Test against the fft when uv on regular and w is zero. """ from africanus.dft.kernels import im_to_vis + np.random.seed(123) Fs = np.fft.fftshift iFs = np.fft.ifftshift @@ -95,15 +101,15 @@ def test_im_to_vis_fft(convention): fft_image = np.zeros((npix, npix, ncorr), dtype=np.complex128) nsource = 25 for corr in range(ncorr): - Ix = np.random.randint(5, npix-5, nsource) - Iy = np.random.randint(5, npix-5, nsource) + Ix = np.random.randint(5, npix - 5, nsource) + Iy = np.random.randint(5, npix - 5, nsource) image[Ix, Iy, corr] = np.random.randn(nsource) fft_image[:, :, corr] = Fs(np.fft.fft2(iFs(image[:, :, corr]))) # image space coords deltal = 0.001 # this assumes npix is odd - l_coord = np.arange(-(npix//2), npix//2+1) * deltal + l_coord = np.arange(-(npix // 2), npix // 2 + 1) * deltal ll, mm = np.meshgrid(l_coord, l_coord) lm = np.vstack((ll.flatten(), mm.flatten())).T # uv-space coords @@ -116,12 +122,13 @@ def test_im_to_vis_fft(convention): image = image.reshape(npix**2, nchan, ncorr) frequency = np.ones(nchan, dtype=np.float64) from africanus.constants import c as lightspeed + frequency *= lightspeed # makes result independent of frequency # take DFT and compare vis = im_to_vis(image, uvw, lm, frequency, convention=convention) fft_image = fft_image.reshape(npix**2, nchan, ncorr) - fft_image = np.conj(fft_image) if convention == 'casa' else fft_image + fft_image = np.conj(fft_image) if convention == "casa" else fft_image assert_array_almost_equal(vis, fft_image, decimal=13) @@ -143,10 +150,10 @@ def test_adjointness(): ncorr = 4 uvw = 100 * np.random.random(size=(nrow, 3)) - ll = 0.01*np.random.randn(nsource) - mm = 0.01*np.random.randn(nsource) + ll = 0.01 * np.random.randn(nsource) + mm = 0.01 * np.random.randn(nsource) lm = np.vstack((ll, mm)).T - frequency = np.arange(1, nchan+1) * lightspeed # avoid overflow + frequency = np.arange(1, nchan + 1) * lightspeed # avoid overflow shape_im = (nsource, nchan, ncorr) size_im = np.prod(shape_im) @@ -156,10 +163,16 @@ def test_adjointness(): gamma_vis = np.random.randn(nrow, nchan, ncorr) flag = np.zeros(shape_vis, dtype=bool) - LHS = (gamma_vis.reshape(size_vis, 1).T.dot( - R(gamma_im, uvw, lm, frequency).reshape(size_vis, 1))).real - RHS = (RH(gamma_vis, uvw, lm, frequency, flag).reshape( - size_im, 1).T.dot(gamma_im.reshape(size_im, 1))).real + LHS = ( + gamma_vis.reshape(size_vis, 1).T.dot( + R(gamma_im, uvw, lm, frequency).reshape(size_vis, 1) + ) + ).real + RHS = ( + RH(gamma_vis, uvw, lm, frequency, flag) + .reshape(size_im, 1) + .T.dot(gamma_im.reshape(size_im, 1)) + ).real assert np.abs(LHS - RHS) < 1e-13 @@ -181,13 +194,14 @@ def test_vis_to_im_flagged(): uvw = 100 * np.random.random(size=(nrow, 3)) uvw[0, :] = 0.0 - ll = 0.01*np.random.randn(nsource) - mm = 0.01*np.random.randn(nsource) + ll = 0.01 * np.random.randn(nsource) + mm = 0.01 * np.random.randn(nsource) lm = np.vstack((ll, mm)).T - frequency = np.arange(1, nchan+1) * lightspeed # avoid overflow + frequency = np.arange(1, nchan + 1) * lightspeed # avoid overflow - vis = np.random.randn(nrow, nchan, ncorr) + \ - 1.0j*np.random.randn(nrow, nchan, ncorr) + vis = np.random.randn(nrow, nchan, ncorr) + 1.0j * np.random.randn( + nrow, nchan, ncorr + ) vis[0, :, :] = 1.0 flags = np.ones((nrow, nchan, ncorr), dtype=bool) @@ -196,13 +210,12 @@ def test_vis_to_im_flagged(): frequency = np.ones(nchan, dtype=np.float64) * lightspeed im_of_vis = vis_to_im(vis, uvw, lm, frequency, flags) - assert_array_almost_equal(im_of_vis, - np.ones((nsource, nchan, ncorr), - dtype=np.float64), - decimal=13) + assert_array_almost_equal( + im_of_vis, np.ones((nsource, nchan, ncorr), dtype=np.float64), decimal=13 + ) -@pytest.mark.parametrize("convention", ['fourier', 'casa']) +@pytest.mark.parametrize("convention", ["fourier", "casa"]) def test_im_to_vis_dask(convention): """ Tests against numpy version @@ -215,8 +228,8 @@ def test_im_to_vis_dask(convention): nrow = 8000 uvw = 100 * np.random.random(size=(nrow, 3)) nsource = 800 # must be odd for this test to work - ll = 0.01*np.random.randn(nsource) - mm = 0.01*np.random.randn(nsource) + ll = 0.01 * np.random.randn(nsource) + mm = 0.01 * np.random.randn(nsource) lm = np.vstack((ll, mm)).T nchan = 11 frequency = np.linspace(1.0, 2.0, nchan) * lightspeed @@ -224,15 +237,15 @@ def test_im_to_vis_dask(convention): image = np.random.randn(nsource, nchan, ncorr) # set up dask arrays - uvw_dask = da.from_array(uvw, chunks=(nrow//8, 3)) + uvw_dask = da.from_array(uvw, chunks=(nrow // 8, 3)) lm_dask = da.from_array(lm, chunks=(nsource, 2)) - frequency_dask = da.from_array(frequency, chunks=nchan//2) - image_dask = da.from_array(image, chunks=(nsource, nchan//2, ncorr)) + frequency_dask = da.from_array(frequency, chunks=nchan // 2) + image_dask = da.from_array(image, chunks=(nsource, nchan // 2, ncorr)) vis = np_im_to_vis(image, uvw, lm, frequency, convention=convention) - vis_dask = dask_im_to_vis(image_dask, uvw_dask, - lm_dask, frequency_dask, - convention=convention).compute() + vis_dask = dask_im_to_vis( + image_dask, uvw_dask, lm_dask, frequency_dask, convention=convention + ).compute() assert_array_almost_equal(vis, vis_dask, decimal=13) @@ -254,27 +267,29 @@ def test_vis_to_im_dask(): vis = np.random.randn(nrow, nchan, ncorr) uvw = np.random.randn(nrow, 3) - ll = 0.01*np.random.randn(nsource) - mm = 0.01*np.random.randn(nsource) + ll = 0.01 * np.random.randn(nsource) + mm = 0.01 * np.random.randn(nsource) lm = np.vstack((ll, mm)).T nchan = 11 frequency = np.linspace(1.0, 2.0, nchan) * lightspeed flagged_frac = 0.45 - flags = np.random.choice(a=[False, True], size=(nrow, nchan, ncorr), - p=[flagged_frac, 1-flagged_frac]) + flags = np.random.choice( + a=[False, True], size=(nrow, nchan, ncorr), p=[flagged_frac, 1 - flagged_frac] + ) image = np_vis_to_im(vis, uvw, lm, frequency, flags) # set up dask arrays - uvw_dask = da.from_array(uvw, chunks=(nrow//8, 3)) + uvw_dask = da.from_array(uvw, chunks=(nrow // 8, 3)) lm_dask = da.from_array(lm, chunks=(nsource, 2)) - frequency_dask = da.from_array(frequency, chunks=nchan//2) - vis_dask = da.from_array(vis, chunks=(nrow//8, nchan//2, ncorr)) - flags_dask = da.from_array(flags, chunks=(nrow//8, nchan//2, ncorr)) + frequency_dask = da.from_array(frequency, chunks=nchan // 2) + vis_dask = da.from_array(vis, chunks=(nrow // 8, nchan // 2, ncorr)) + flags_dask = da.from_array(flags, chunks=(nrow // 8, nchan // 2, ncorr)) image_dask = dask_vis_to_im( - vis_dask, uvw_dask, lm_dask, frequency_dask, flags_dask).compute() + vis_dask, uvw_dask, lm_dask, frequency_dask, flags_dask + ).compute() assert_array_almost_equal(image, image_dask, decimal=13) @@ -285,6 +300,7 @@ def test_symmetric_covariance(): (symmetric since its real). """ from africanus.dft.kernels import vis_to_im, im_to_vis + np.random.seed(123) nsource = 25 @@ -292,8 +308,8 @@ def test_symmetric_covariance(): nchan = 1 lmmax = 0.05 - ll = -lmmax + 2*lmmax*np.random.random(nsource) # noqa - mm = -lmmax + 2*lmmax*np.random.random(nsource) + ll = -lmmax + 2 * lmmax * np.random.random(nsource) # noqa + mm = -lmmax + 2 * lmmax * np.random.random(nsource) lm = np.vstack((ll, mm)).T nrows = 1000 diff --git a/africanus/experimental/rime/fused/arguments.py b/africanus/experimental/rime/fused/arguments.py index 95e38221e..b7cfc52ea 100644 --- a/africanus/experimental/rime/fused/arguments.py +++ b/africanus/experimental/rime/fused/arguments.py @@ -4,12 +4,10 @@ class ArgumentPack(Mapping): def __init__(self, names, types, index): - self.pack = OrderedDict((n, (t, i)) for n, t, i - in zip(names, types, index)) + self.pack = OrderedDict((n, (t, i)) for n, t, i in zip(names, types, index)) def copy(self): - names, types, index = zip(*((k, t, i) for k, (t, i) - in self.pack.items())) + names, types, index = zip(*((k, t, i) for k, (t, i) in self.pack.items())) return ArgumentPack(names, types, index) def pop(self, key): @@ -45,15 +43,22 @@ def __len__(self): class ArgumentDependencies: REQUIRED_ARGS = ("time", "antenna1", "antenna2", "feed1", "feed2") - KEY_ARGS = ("utime", "time_index", - "uantenna", "antenna1_index", "antenna2_index", - "ufeed", "feed1_index", "feed2_index") + KEY_ARGS = ( + "utime", + "time_index", + "uantenna", + "antenna1_index", + "antenna2_index", + "ufeed", + "feed1_index", + "feed2_index", + ) def __init__(self, arg_names, terms, transformers): if not set(self.REQUIRED_ARGS).issubset(arg_names): raise ValueError( - f"{set(self.REQUIRED_ARGS) - set(arg_names)} " - f"missing from arg_names") + f"{set(self.REQUIRED_ARGS) - set(arg_names)} " f"missing from arg_names" + ) self.names = arg_names self.terms = terms @@ -79,9 +84,9 @@ def __init__(self, arg_names, terms, transformers): # Determine a canonical set of valid inputs # We start with the desired and required arguments - self.valid_inputs = (set(desired.keys()) | - set(self.REQUIRED_ARGS) | - set(optional.keys())) + self.valid_inputs = ( + set(desired.keys()) | set(self.REQUIRED_ARGS) | set(optional.keys()) + ) # Then, for each argument than can be created # we add the transformer arguments and remove @@ -111,8 +116,7 @@ def _resolve_arg_dependencies(self): for transformer in self.maybe_create[arg]: # We didn't have the arguments, make a note of this if not set(transformer.ARGS).issubset(available_args): - failed_transforms[arg].append( - (transformer, set(transformer.ARGS))) + failed_transforms[arg].append((transformer, set(transformer.ARGS))) continue # The transformer can create arg @@ -128,10 +132,12 @@ def _resolve_arg_dependencies(self): if arg in failed_transforms: for transformer, needed in failed_transforms[arg]: - err_msgs.append(f"{transformer} can create {arg} " - f"but needs {needed}, of which " - f"{needed - supplied_args} is missing " - f"from the input arguments.") + err_msgs.append( + f"{transformer} can create {arg} " + f"but needs {needed}, of which " + f"{needed - supplied_args} is missing " + f"from the input arguments." + ) raise ValueError("\n".join(err_msgs)) @@ -146,9 +152,11 @@ def _resolve_arg_dependencies(self): defaults = set(defaults) if len(defaults) != 1: - raise ValueError(f"Multiple terms: {self.terms} have " - f"contradicting definitions for " - f"{k}: {defaults}") + raise ValueError( + f"Multiple terms: {self.terms} have " + f"contradicting definitions for " + f"{k}: {defaults}" + ) opt_defaults[k] = defaults.pop() diff --git a/africanus/experimental/rime/fused/core.py b/africanus/experimental/rime/fused/core.py index af4f98feb..ed4b9c1ca 100644 --- a/africanus/experimental/rime/fused/core.py +++ b/africanus/experimental/rime/fused/core.py @@ -40,15 +40,17 @@ def rime_impl(*args): @overload(rime_impl, jit_options=JIT_OPTIONS) def nb_rime(*args): if not len(args) > 0: - raise TypeError("rime must be at least be called " - "with the signature argument") + raise TypeError( + "rime must be at least be called " "with the signature argument" + ) if not isinstance(args[0], types.Literal): raise TypeError(f"Signature hash ({args[0]}) must be a literal") if not len(args) % 2 == 1: - raise TypeError(f"Length of named arguments {len(args)} " - f"is not divisible by 2") + raise TypeError( + f"Length of named arguments {len(args)} " f"is not divisible by 2" + ) argstart = 1 + (len(args) - 1) // 2 names = args[1:argstart] @@ -84,7 +86,7 @@ def impl(*args): nsrc, _ = args[lm_i].shape nrow, _ = args[uvw_i].shape - nchan, = args[chan_freq_i].shape + (nchan,) = args[chan_freq_i].shape vis = np.zeros((nrow, nchan, ncorr), np.complex128) # Kahan summation compensation @@ -128,8 +130,7 @@ def __hash__(self): return hash(self.rime_spec) def __eq__(self, rhs): - return (isinstance(rhs, RimeFactory) and - self.rime_spec == rhs.rime_spec) + return isinstance(rhs, RimeFactory) and self.rime_spec == rhs.rime_spec def __init__(self, rime_spec=DEFAULT_SPEC): if isinstance(rime_spec, RimeSpecification): @@ -141,16 +142,14 @@ def __init__(self, rime_spec=DEFAULT_SPEC): self.rime_spec = rime_spec self.impl = rime_impl_factory( - rime_spec.terms, - rime_spec.transformers, - len(rime_spec.corrs)) + rime_spec.terms, rime_spec.transformers, len(rime_spec.corrs) + ) def dask_blockwise_args(self, **kwargs): - """ Get the dask schema """ + """Get the dask schema""" argdeps = ArgumentDependencies( - tuple(kwargs.keys()), - self.rime_spec.terms, - self.rime_spec.transformers) + tuple(kwargs.keys()), self.rime_spec.terms, self.rime_spec.transformers + ) # Holds kwargs + any dummy outputs from transformations dummy_kw = kwargs.copy() @@ -190,23 +189,23 @@ def dask_blockwise_args(self, **kwargs): if len(dims) != 1: raise ValueError( f"Multiple candidates provided conflicting " - f"dimension definitions for {a}: {candidates}.") + f"dimension definitions for {a}: {candidates}." + ) merged_schema[a] = dims.pop() names = list(sorted(argdeps.valid_inputs & set(kwargs.keys()))) - blockwise_args = [e for n in names - for e in (kwargs[n], merged_schema.get(n, None))] + blockwise_args = [ + e for n in names for e in (kwargs[n], merged_schema.get(n, None)) + ] assert 2 * len(names) == len(blockwise_args) return names, blockwise_args def __call__(self, time, antenna1, antenna2, feed1, feed2, **kwargs): - keys = (self.REQUIRED_ARGS_LITERAL + - tuple(map(types.literal, kwargs.keys()))) + keys = self.REQUIRED_ARGS_LITERAL + tuple(map(types.literal, kwargs.keys())) - args = keys + (time, antenna1, antenna2, feed1, - feed2) + tuple(kwargs.values()) + args = keys + (time, antenna1, antenna2, feed1, feed2) + tuple(kwargs.values()) return self.impl(types.literal(self.rime_spec.spec_hash), *args) diff --git a/africanus/experimental/rime/fused/dask.py b/africanus/experimental/rime/fused/dask.py index 3bc2d2e61..4a63a1708 100644 --- a/africanus/experimental/rime/fused/dask.py +++ b/africanus/experimental/rime/fused/dask.py @@ -1,8 +1,7 @@ import numpy as np from africanus.util.requirements import requires_optional -from africanus.experimental.rime.fused.core import ( - RimeFactory, consolidate_args) +from africanus.experimental.rime.fused.core import RimeFactory, consolidate_args try: import dask.array as da @@ -20,7 +19,7 @@ def rime_dask_wrapper(factory, names, nconcat_dims, *args): # (2) slice the existing dimensions # (3) expand by the contraction dims which will # be removed in the later dask reduction - return out[(None,) + (slice(None),)*out.ndim + (None,)*nconcat_dims] + return out[(None,) + (slice(None),) * out.ndim + (None,) * nconcat_dims] @requires_optional("dask.array", opt_import_err) @@ -44,18 +43,24 @@ def rime(rime_spec, *args, **kw): # This incurs memory allocations within numba, as well as # exceptions, leading to memory leaks as described # in https://github.com/numba/numba/issues/3263 - meta = np.empty((0,)*len(out_dims), dtype=np.complex128) + meta = np.empty((0,) * len(out_dims), dtype=np.complex128) # Construct the wrapper call from given arguments - out = da.blockwise(rime_dask_wrapper, out_dims, - factory, None, - names, None, - len(contract_dims), None, - *args, - concatenate=False, - adjust_chunks=adjust_chunks, - new_axes=new_axes, - meta=meta) + out = da.blockwise( + rime_dask_wrapper, + out_dims, + factory, + None, + names, + None, + len(contract_dims), + None, + *args, + concatenate=False, + adjust_chunks=adjust_chunks, + new_axes=new_axes, + meta=meta, + ) # Contract over source and concatenation dims axes = (0,) + tuple(range(len(dims), len(dims) + len(contract_dims))) diff --git a/africanus/experimental/rime/fused/intrinsics.py b/africanus/experimental/rime/fused/intrinsics.py index c7bc5e26c..8fb568ab3 100644 --- a/africanus/experimental/rime/fused/intrinsics.py +++ b/africanus/experimental/rime/fused/intrinsics.py @@ -24,55 +24,44 @@ def scalar_scalar(lhs, rhs): - return lhs*rhs + return lhs * rhs def scalar_diag(lhs, rhs): - return lhs*rhs[0], lhs*rhs[1] + return lhs * rhs[0], lhs * rhs[1] def scalar_full(lhs, rhs): - return lhs*rhs[0], lhs*rhs[1], lhs*rhs[2], lhs*rhs[3] + return lhs * rhs[0], lhs * rhs[1], lhs * rhs[2], lhs * rhs[3] def diag_scalar(lhs, rhs): - return lhs[0]*rhs, lhs[1]*rhs + return lhs[0] * rhs, lhs[1] * rhs def diag_diag(lhs, rhs): - return lhs[0]*rhs[0], lhs[1]*rhs[1] + return lhs[0] * rhs[0], lhs[1] * rhs[1] def diag_full(lhs, rhs): - return ( - lhs[0]*rhs[0], - lhs[0]*rhs[1], - lhs[1]*rhs[2], - lhs[1]*rhs[3]) + return (lhs[0] * rhs[0], lhs[0] * rhs[1], lhs[1] * rhs[2], lhs[1] * rhs[3]) def full_scalar(lhs, rhs): - return ( - lhs[0]*rhs, - lhs[1]*rhs, - lhs[2]*rhs, - lhs[3]*rhs) + return (lhs[0] * rhs, lhs[1] * rhs, lhs[2] * rhs, lhs[3] * rhs) def full_diag(lhs, rhs): - return ( - lhs[0]*rhs[0], - lhs[1]*rhs[1], - lhs[2]*rhs[0], - lhs[3]*rhs[1]) + return (lhs[0] * rhs[0], lhs[1] * rhs[1], lhs[2] * rhs[0], lhs[3] * rhs[1]) def full_full(lhs, rhs): return ( - lhs[0]*rhs[0] + lhs[1]*rhs[2], - lhs[0]*rhs[1] + lhs[1]*rhs[3], - lhs[2]*rhs[0] + lhs[3]*rhs[2], - lhs[2]*rhs[1] + lhs[3]*rhs[3]) + lhs[0] * rhs[0] + lhs[1] * rhs[2], + lhs[0] * rhs[1] + lhs[1] * rhs[3], + lhs[2] * rhs[0] + lhs[3] * rhs[2], + lhs[2] * rhs[1] + lhs[3] * rhs[3], + ) def hermitian_scalar(jones): @@ -84,10 +73,7 @@ def hermitian_diag(jones): def hermitian_full(jones): - return (np.conj(jones[0]), - np.conj(jones[2]), - np.conj(jones[1]), - np.conj(jones[3])) + return (np.conj(jones[0]), np.conj(jones[2]), np.conj(jones[1]), np.conj(jones[3])) _jones_typ_map = { @@ -99,7 +85,7 @@ def hermitian_full(jones): ("diag", "full"): diag_full, ("full", "scalar"): full_scalar, ("full", "diag"): full_diag, - ("full", "full"): full_full + ("full", "full"): full_full, } @@ -140,14 +126,13 @@ def term_mul(lhs, rhs): try: return _jones_typ_map[(lhs_type, rhs_type)] except KeyError: - raise TypingError(f"No known multiplication " - f"function for {lhs} and {rhs}") + raise TypingError(f"No known multiplication " f"function for {lhs} and {rhs}") _hermitian_map = { "scalar": hermitian_scalar, "diag": hermitian_diag, - "full": hermitian_full + "full": hermitian_full, } @@ -157,8 +142,7 @@ def hermitian(jones): try: return _hermitian_map[jones_type] except KeyError: - raise TypingError(f"No known hermitian function " - f"for {jones}: {jones_type}.") + raise TypingError(f"No known hermitian function " f"for {jones}: {jones_type}.") def unify_jones_terms(typingctx, lhs, rhs): @@ -174,9 +158,9 @@ def unify_jones_terms(typingctx, lhs, rhs): lhs_corrs = corr_map[lhs_type] rhs_corrs = corr_map[rhs_type] except KeyError: - raise TypingError(f"{lhs} or {rhs} has no " - f"entry in the {corr_map} " - f"mapping") + raise TypingError( + f"{lhs} or {rhs} has no " f"entry in the {corr_map} " f"mapping" + ) lhs_types = (lhs,) if lhs_corrs == 1 else tuple(lhs) rhs_types = (rhs,) if rhs_corrs == 1 else tuple(rhs) @@ -184,7 +168,7 @@ def unify_jones_terms(typingctx, lhs, rhs): out_type = typingctx.unify_types(*lhs_types, *rhs_types) out_corrs = max(lhs_corrs, rhs_corrs) - return out_type if out_corrs == 1 else types.Tuple((out_type,)*out_corrs) + return out_type if out_corrs == 1 else types.Tuple((out_type,) * out_corrs) @intrinsic @@ -216,8 +200,7 @@ def _add(x, y): v2 = builder.extract_value(t2, i) vr = typingctx.unify_types(t1e, t2e) - data = context.compile_internal(builder, _add, - vr(t1e, t2e), [v1, v2]) + data = context.compile_internal(builder, _add, vr(t1e, t2e), [v1, v2]) ret_tuple = builder.insert_value(ret_tuple, data, i) @@ -227,9 +210,16 @@ def _add(x, y): class IntrinsicFactory: - KEY_ARGS = ("utime", "time_index", - "uantenna", "antenna1_index", "antenna2_index", - "ufeed", "feed1_index", "feed2_index") + KEY_ARGS = ( + "utime", + "time_index", + "uantenna", + "antenna1_index", + "antenna2_index", + "ufeed", + "feed1_index", + "feed2_index", + ) def __init__(self, arg_dependencies): self.argdeps = arg_dependencies @@ -257,8 +247,7 @@ def _resolve_arg_dependencies(self): for transformer in argdeps.maybe_create[arg]: # We didn't have the arguments, make a note of this if not set(transformer.ARGS).issubset(available_args): - failed_transforms[arg].append( - (transformer, set(transformer.ARGS))) + failed_transforms[arg].append((transformer, set(transformer.ARGS))) continue # The transformer can create arg @@ -274,10 +263,12 @@ def _resolve_arg_dependencies(self): if arg in failed_transforms: for transformer, needed in failed_transforms[arg]: - err_msgs.append(f"{transformer} can create {arg} " - f"but needs {needed}, of which " - f"{needed - set(argdeps.names)} is " - f"missing from the input arguments.") + err_msgs.append( + f"{transformer} can create {arg} " + f"but needs {needed}, of which " + f"{needed - set(argdeps.names)} is " + f"missing from the input arguments." + ) raise ValueError("\n".join(err_msgs)) @@ -292,9 +283,11 @@ def _resolve_arg_dependencies(self): defaults = set(defaults) if len(defaults) != 1: - raise ValueError(f"Multiple terms: {argdeps.terms} have " - f"contradicting definitions for " - f"{k}: {defaults}") + raise ValueError( + f"Multiple terms: {argdeps.terms} have " + f"contradicting definitions for " + f"{k}: {defaults}" + ) opt_defaults[k] = defaults.pop() @@ -305,9 +298,11 @@ def _resolve_arg_dependencies(self): def pack_optionals_and_indices_fn(self): argdeps = self.argdeps - out_names = (argdeps.names + - tuple(argdeps.optional_defaults.keys()) + - tuple(argdeps.KEY_ARGS)) + out_names = ( + argdeps.names + + tuple(argdeps.optional_defaults.keys()) + + tuple(argdeps.KEY_ARGS) + ) @intrinsic def pack_index(typingctx, args): @@ -323,20 +318,19 @@ def pack_index(typingctx, args): "antenna2_index": types.int64[:], "ufeed": arg_info["feed1"][0], "feed1_index": types.int64[:], - "feed2_index": types.int64[:] + "feed2_index": types.int64[:], } if tuple(key_types.keys()) != argdeps.KEY_ARGS: - raise RuntimeError( - f"{tuple(key_types.keys())} != {argdeps.KEY_ARGS}") + raise RuntimeError(f"{tuple(key_types.keys())} != {argdeps.KEY_ARGS}") rvt = typingctx.resolve_value_type_prefer_literal - optionals = [(n, rvt(d), d) for n, d - in argdeps.optional_defaults.items()] + optionals = [(n, rvt(d), d) for n, d in argdeps.optional_defaults.items()] optional_types = tuple(p[1] for p in optionals) - return_type = types.Tuple(args.types + optional_types + - tuple(key_types.values())) + return_type = types.Tuple( + args.types + optional_types + tuple(key_types.values()) + ) sig = return_type(args) def codegen(context, builder, signature, args): @@ -363,10 +357,11 @@ def codegen(context, builder, signature, args): # Compute indexing arguments and insert into # the new tuple - fn_args = [builder.extract_value(args[0], arg_info[a][1]) - for a in argdeps.REQUIRED_ARGS] - fn_arg_types = tuple(arg_info[k][0] for k - in argdeps.REQUIRED_ARGS) + fn_args = [ + builder.extract_value(args[0], arg_info[a][1]) + for a in argdeps.REQUIRED_ARGS + ] + fn_arg_types = tuple(arg_info[k][0] for k in argdeps.REQUIRED_ARGS) fn_sig = types.Tuple(list(key_types.values()))(*fn_arg_types) def _indices(time, antenna1, antenna2, feed1, feed2): @@ -378,12 +373,18 @@ def _indices(time, antenna1, antenna2, feed1, feed2): feed1_index = np.searchsorted(ufeeds, feed1) feed2_index = np.searchsorted(ufeeds, feed2) - return (utime, time_index, - uants, antenna1_index, antenna2_index, - ufeeds, feed1_index, feed2_index) + return ( + utime, + time_index, + uants, + antenna1_index, + antenna2_index, + ufeeds, + feed1_index, + feed2_index, + ) - index = context.compile_internal(builder, _indices, - fn_sig, fn_args) + index = context.compile_internal(builder, _indices, fn_sig, fn_args) n += len(optionals) @@ -403,8 +404,7 @@ def _indices(time, antenna1, antenna2, feed1, feed2): def pack_transformed_fn(self, arg_names): argdeps = self.argdeps transformers = list(set(t for _, t in argdeps.can_create.items())) - out_names = arg_names + tuple(o for t in transformers - for o in t.OUTPUTS) + out_names = arg_names + tuple(o for t in transformers for o in t.OUTPUTS) @intrinsic def pack_transformed(typingctx, args): @@ -437,19 +437,21 @@ def pack_transformed(typingctx, args): raise TypingError( f"{transformer} produces {transformer.OUTPUTS} " f"but {transformer}.init_fields does not return " - f"a tuple of the same length, but {fields}") + f"a tuple of the same length, but {fields}" + ) transform_output_types.extend(t for _, t in fields) # Create a return tuple containing the existing arguments # with the transformed outputs added to the end - return_type = types.Tuple(args.types + - tuple(transform_output_types)) + return_type = types.Tuple(args.types + tuple(transform_output_types)) # Sanity check if len(return_type) != len(out_names): - raise TypingError(f"len(return_type): {len(return_type)} != " - f"len(out_names): {len(out_names)}") + raise TypingError( + f"len(return_type): {len(return_type)} != " + f"len(out_names): {len(out_names)}" + ) sig = return_type(args) @@ -484,8 +486,7 @@ def codegen(context, builder, signature, args): try: typ, j = arg_info[name] except KeyError: - raise TypingError( - f"{name} is not present in arg_types") + raise TypingError(f"{name} is not present in arg_types") value = builder.extract_value(args[0], j) transform_args.append(value) @@ -495,16 +496,16 @@ def codegen(context, builder, signature, args): for name, default in transformer.KWARGS.items(): default_typ = rvt(default) default_value = context.get_constant_generic( - builder, - default_typ, - default) + builder, default_typ, default + ) transform_types.append(default_typ) transform_args.append(default_value) # Get the transformer fields and function transform_fields, transform_fn = transformer.init_fields( - typingctx, *transform_types) + typingctx, *transform_types + ) single_return = len(transform_fields) == 1 @@ -517,23 +518,22 @@ def codegen(context, builder, signature, args): # Call the transform function transform_sig = ret_type(*transform_types) - value = context.compile_internal(builder, # noqa - transform_fn, - transform_sig, - transform_args) + value = context.compile_internal( + builder, # noqa + transform_fn, + transform_sig, + transform_args, + ) # Unpack the returned value and insert into # return_tuple if single_return: - ret_tuple = builder.insert_value(ret_tuple, value, - i + n) + ret_tuple = builder.insert_value(ret_tuple, value, i + n) i += 1 else: for j, o in enumerate(transformer.OUTPUTS): element = builder.extract_value(value, j) - ret_tuple = builder.insert_value(ret_tuple, - element, - i + n) + ret_tuple = builder.insert_value(ret_tuple, element, i + n) i += 1 return ret_tuple @@ -551,8 +551,9 @@ def term_state(typingctx, args): raise TypingError(f"args must be a Tuple but is {args}") if len(arg_names) != len(args): - raise TypingError(f"len(arg_names): {len(arg_names)} != " - f"len(args): {len(args)}") + raise TypingError( + f"len(arg_names): {len(arg_names)} != " f"len(args): {len(args)}" + ) arg_pack = ArgumentPack(arg_names, args, tuple(range(len(args)))) @@ -584,11 +585,10 @@ def codegen(context, builder, signature, args): rvt = typingctx.resolve_value_type_prefer_literal def make_struct(): - """ Allocate the structure """ + """Allocate the structure""" return structref.new(state_type) - state = context.compile_internal(builder, make_struct, - state_type(), []) + state = context.compile_internal(builder, make_struct, state_type(), []) U = structref._Utils(context, builder, state_type) data_struct = U.get_data_struct(state) @@ -600,8 +600,7 @@ def make_struct(): # the args tuple and placing it on the structref context.nrt.incref(builder, value_type, value) field_type = state_type.field_dict[arg_name] - casted = context.cast(builder, value, - value_type, field_type) + casted = context.cast(builder, value, value_type, field_type) context.nrt.incref(builder, value_type, casted) # The old value on the structref is being replaced, @@ -627,7 +626,8 @@ def make_struct(): if isinstance(typ, types.Omitted): const_type = rvt(typ.value) const = context.get_constant_generic( - builder, const_type, typ.value) + builder, const_type, typ.value + ) cargs.append(const) ctypes.append(const_type) else: @@ -653,8 +653,8 @@ def make_struct(): constructor_sig = return_type(*constructor_types[ti]) return_value = context.compile_internal( - builder, constructors[ti], - constructor_sig, constructor_args[ti]) + builder, constructors[ti], constructor_sig, constructor_args[ti] + ) if nfields == 0: pass @@ -706,19 +706,24 @@ def term_sampler(typingctx, state, s, r, t, f1, f2, a1, a2, c): # the ability to figure out the current target context manager # in future releases in order to find a better solution here. from numba.core.registry import cpu_target + if cpu_target.typing_context != typingctx: raise TypingError("typingctx's don't match") - tis = partial(type_inference_stage, - typingctx=typingctx, - targetctx=cpu_target.target_context, - args=ir_args, - return_type=None) + tis = partial( + type_inference_stage, + typingctx=typingctx, + targetctx=cpu_target.target_context, + args=ir_args, + return_type=None, + ) else: - tis = partial(type_inference_stage, - typingctx=typingctx, - args=ir_args, - return_type=None) + tis = partial( + type_inference_stage, + typingctx=typingctx, + args=ir_args, + return_type=None, + ) type_infer = [tis(interp=ir) for ir in sampler_ir] sampler_return_types = [ti.return_type for ti in type_infer] @@ -733,7 +738,8 @@ def term_sampler(typingctx, state, s, r, t, f1, f2, a1, a2, c): f"(1) a single scalar correlation\n" f"(2) a Tuple containing 2 scalar correlations\n" f"(3) a Tuple containing 4 scalar correlations\n" - f"but instead got a {typ}") + f"but instead got a {typ}" + ) if isinstance(typ, types.BaseTuple): if len(typ) not in (2, 4): @@ -749,8 +755,7 @@ def term_sampler(typingctx, state, s, r, t, f1, f2, a1, a2, c): sampler_ret_type = sampler_return_types[0] for typ in sampler_return_types[1:]: - sampler_ret_type = unify_jones_terms(typingctx, - sampler_ret_type, typ) + sampler_ret_type = unify_jones_terms(typingctx, sampler_ret_type, typ) sig = sampler_ret_type(state, s, r, t, f1, f2, a1, a2, c) @@ -771,17 +776,21 @@ def codegen(context, builder, signature, args): sampler_args = [state, s, r, t, f1, f2, a1, a2, c] # Call the sampling function - data = context.compile_internal(builder, # noqa - sampling_fn, - sampler_sig, - sampler_args) + data = context.compile_internal( + builder, # noqa + sampling_fn, + sampler_sig, + sampler_args, + ) # Apply hermitian transform if this is a right term if terms[ti].configuration == "right": - data = context.compile_internal(builder, # noqa - hermitian(ret_type), - ret_type(ret_type), - [data]) + data = context.compile_internal( + builder, # noqa + hermitian(ret_type), + ret_type(ret_type), + [data], + ) jones.append(data) @@ -790,12 +799,13 @@ def codegen(context, builder, signature, args): for jrt, j in zip(sampler_return_types[1:], jones[1:]): jones_mul = term_mul(prev_t, jrt) - jones_mul_typ = unify_jones_terms(context.typing_context, - prev_t, jrt) + jones_mul_typ = unify_jones_terms( + context.typing_context, prev_t, jrt + ) jones_sig = jones_mul_typ(prev_t, jrt) - prev = context.compile_internal(builder, jones_mul, - jones_sig, - [prev, j]) + prev = context.compile_internal( + builder, jones_mul, jones_sig, [prev, j] + ) prev_t = jones_mul_typ diff --git a/africanus/experimental/rime/fused/specification.py b/africanus/experimental/rime/fused/specification.py index 4966a7bf4..c84afea28 100644 --- a/africanus/experimental/rime/fused/specification.py +++ b/africanus/experimental/rime/fused/specification.py @@ -54,12 +54,10 @@ class RimeSpecificationError(ValueError): def parse_stokes(stokes_string): stokes = parse_str_list(stokes_string) - if (not isinstance(stokes, list) or - not all(isinstance(s, str) for s in stokes)): - + if not isinstance(stokes, list) or not all(isinstance(s, str) for s in stokes): raise RimeParseError( - f"Stokes specification must be of the form " - f"[I,Q,U,V]. Got {stokes}.") + f"Stokes specification must be of the form " f"[I,Q,U,V]. Got {stokes}." + ) return [s.upper() for s in stokes] @@ -67,12 +65,11 @@ def parse_stokes(stokes_string): def parse_corrs(corrs_string): corrs = parse_str_list(corrs_string) - if (not isinstance(corrs, list) or - not all(isinstance(c, str) for c in corrs)): - + if not isinstance(corrs, list) or not all(isinstance(c, str) for c in corrs): raise RimeParseError( f"Correlation specification must be of the form " - f"[XX,XY,YX,YY]. Got {corrs}.") + f"[XX,XY,YX,YY]. Got {corrs}." + ) return [c.upper() for c in corrs] @@ -86,7 +83,8 @@ def parse_rime(rime: str): raise RimeParseError( f"RIME must be of the form " f"[Gp, (Kpq, Bpq), Gq]: [I,Q,U,V] -> [XX,XY,YX,YY]. " - f"Got {rime}.") + f"Got {rime}." + ) bits = [s.strip() for s in polarisation_bits.split("->")] @@ -95,7 +93,8 @@ def parse_rime(rime: str): except (ValueError, TypeError): raise RimeParseError( f"Polarisation specification must be of the form " - f"[I,Q,U,V] -> [XX,XY,YX,YY]. Got {polarisation_bits}.") + f"[I,Q,U,V] -> [XX,XY,YX,YY]. Got {polarisation_bits}." + ) stokes_bits, corr_bits = bits @@ -103,12 +102,14 @@ def parse_rime(rime: str): corrs = parse_corrs(corr_bits) equation = parse_str_list(rime_bits) - if (not isinstance(equation, (tuple, list)) or - any(isinstance(e, (tuple, list)) for e in equation) or - not all(isinstance(t, str) for t in equation)): + if ( + not isinstance(equation, (tuple, list)) + or any(isinstance(e, (tuple, list)) for e in equation) + or not all(isinstance(t, str) for t in equation) + ): raise RimeParseError( - f"RIME must be a tuple/list of Terms " - f"(Kpq, Bpq). Got {equation}.") + f"RIME must be a tuple/list of Terms " f"(Kpq, Bpq). Got {equation}." + ) return equation, stokes, corrs @@ -148,8 +149,12 @@ def search_types(module, typ, exclude=("__init__.py", "core.py")): mod = import_module(f"{module.__package__}.{py_file.stem}") for k, v in vars(mod).items(): - if (k.startswith("_") or not isinstance(v, type) or - not issubclass(v, typ) or v in typ): + if ( + k.startswith("_") + or not isinstance(v, type) + or not issubclass(v, typ) + or v in typ + ): continue typs[k] = v @@ -162,7 +167,8 @@ def _decompose_term_str(term_str): if not match: raise RimeSpecificationError( - f"{term_str} does not match {TERM_STRING_REGEX.pattern}") + f"{term_str} does not match {TERM_STRING_REGEX.pattern}" + ) return tuple(match.groups()) @@ -255,7 +261,8 @@ class CustomJones(Term): "K": "Phase", "B": "Brightness", "L": "FeedRotation", - "E": "BeamCubeDDE"} + "E": "BeamCubeDDE", + } def __reduce__(self): return (RimeSpecification, self._saved_args) @@ -264,8 +271,9 @@ def __hash__(self): return hash(self._saved_args) def __eq__(self, rhs): - return (isinstance(rhs, RimeSpecification) and - self._saved_args == rhs._saved_args) + return ( + isinstance(rhs, RimeSpecification) and self._saved_args == rhs._saved_args + ) def __init__(self, specification, terms=None, transformers=None): # Argument Handling @@ -282,7 +290,8 @@ def __init__(self, specification, terms=None, transformers=None): else: raise TypeError( f"terms: {terms} must be a dictionary or " - f"an iterable of (key, value) pairs") + f"an iterable of (key, value) pairs" + ) if not transformers: saved_transforms = transformers @@ -290,8 +299,8 @@ def __init__(self, specification, terms=None, transformers=None): saved_transforms = frozenset(transformers) else: raise TypeError( - f"transformers: {transformers} must be " - f"an iterable of Transformers") + f"transformers: {transformers} must be " f"an iterable of Transformers" + ) # Parse the specification equation, stokes, corrs = parse_rime(specification) @@ -299,7 +308,8 @@ def __init__(self, specification, terms=None, transformers=None): if not set(stokes).issubset(self.VALID_STOKES): raise RimeSpecificationError( f"{stokes} contains invalid stokes parameters. " - f"Only {self.VALID_STOKES} are accepted") + f"Only {self.VALID_STOKES} are accepted" + ) self._saved_args = (specification, saved_terms, saved_transforms) self.equation = equation @@ -319,8 +329,10 @@ def __init__(self, specification, terms=None, transformers=None): raise RimeSpecificationError(f"Unknown term {str(e)}") try: - term_types = tuple(t if isinstance(t, type) and issubclass(t, Term) - else term_types[t] for t in terms_wanted) + term_types = tuple( + t if isinstance(t, type) and issubclass(t, Term) else term_types[t] + for t in terms_wanted + ) except KeyError as e: raise RimeSpecificationError(f"Can't find a type for {str(e)}") @@ -333,11 +345,10 @@ def __init__(self, specification, terms=None, transformers=None): "corrs": corrs, "stokes": stokes, "feed_type": feed_type, - "process_pool": pool + "process_pool": pool, } - hash_elements = list(v for k, v in global_kw.items() - if k != "process_pool") + hash_elements = list(v for k, v in global_kw.items() if k != "process_pool") for cls, cfg in zip(term_types, term_cfgs): if cfg == "pq": @@ -357,14 +368,15 @@ def __init__(self, specification, terms=None, transformers=None): raise RimeSpecificationError( f"{cls}.__init__{init_sig} must take a " f"'configuration' argument and call " - f"super().__init__(configuration)") + f"super().__init__(configuration)" + ) for a, p in list(init_sig.parameters.items())[1:]: - if p.kind not in {p.POSITIONAL_ONLY, - p.POSITIONAL_OR_KEYWORD}: + if p.kind not in {p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD}: raise RimeSpecificationError( f"{cls}.__init__{init_sig} may not contain " - f"*args or **kwargs") + f"*args or **kwargs" + ) try: cls_kw[a] = available_kw[a] @@ -372,7 +384,8 @@ def __init__(self, specification, terms=None, transformers=None): raise RimeSpecificationError( f"{cls}.__init__{init_sig} wants argument {a} " f"but it is not available. " - f"Available args: {available_kw}") + f"Available args: {available_kw}" + ) term = cls(**cls_kw) hash_elements.append(".".join((cls.__module__, cls.__name__))) @@ -382,12 +395,10 @@ def __init__(self, specification, terms=None, transformers=None): term_type_set = set(term_types) if Phase not in term_type_set: - raise RimeSpecificationError( - "RIME must at least contain a Phase term") + raise RimeSpecificationError("RIME must at least contain a Phase term") if Brightness not in term_type_set: - raise RimeSpecificationError( - "RIME must at least contain a Brightness term") + raise RimeSpecificationError("RIME must at least contain a Brightness term") transformers = [] @@ -396,11 +407,11 @@ def __init__(self, specification, terms=None, transformers=None): cls_kw = {} for a, p in list(init_sig.parameters.items())[1:]: - if p.kind not in {p.POSITIONAL_ONLY, - p.POSITIONAL_OR_KEYWORD}: + if p.kind not in {p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD}: raise RimeSpecification( f"{cls}.__init__{init_sig} may not contain " - f"*args or **kwargs") + f"*args or **kwargs" + ) try: cls_kw[a] = available_kw[a] @@ -408,7 +419,8 @@ def __init__(self, specification, terms=None, transformers=None): raise RimeSpecificationError( f"{cls}.__init__{init_sig} wants argument {a} " f"but it is not available. " - f"Available args: {available_kw}") + f"Available args: {available_kw}" + ) transformer = cls(**cls_kw) hash_elements.append(".".join((cls.__module__, cls.__name__))) @@ -436,8 +448,8 @@ def _feed_type(corrs): return "circular" raise RimeSpecificationError( - f"Correlations must be purely linear or circular. " - f"Got {corrs}") + f"Correlations must be purely linear or circular. " f"Got {corrs}" + ) @staticmethod def flatten_eqn(equation): @@ -447,22 +459,24 @@ def flatten_eqn(equation): elif isinstance(equation, str): return equation else: - raise TypeError(f"equation: {equation} must " - f"be a string or sequence") + raise TypeError(f"equation: {equation} must " f"be a string or sequence") def equation_bits(self): return self.flatten_eqn(self.equation) def __repr__(self): - return "".join((self.__class__.__name__, "(\"", str(self), "\")")) + return "".join((self.__class__.__name__, '("', str(self), '")')) def __str__(self): - return "".join(( - self.equation_bits(), - ": ", - "".join(("[", ",".join(self.stokes), "]")), - " -> ", - "".join(("[", ",".join(self.corrs), "]")))) + return "".join( + ( + self.equation_bits(), + ": ", + "".join(("[", ",".join(self.stokes), "]")), + " -> ", + "".join(("[", ",".join(self.corrs), "]")), + ) + ) def parse_str_list(str_list): diff --git a/africanus/experimental/rime/fused/terms/brightness.py b/africanus/experimental/rime/fused/terms/brightness.py index aad514afe..f48ad490f 100644 --- a/africanus/experimental/rime/fused/terms/brightness.py +++ b/africanus/experimental/rime/fused/terms/brightness.py @@ -8,33 +8,34 @@ STOKES_CONVERSION = { - 'RR': {('I', 'V'): lambda i, v: i + v}, - 'RL': {('Q', 'U'): lambda q, u: q + u*1j}, - 'LR': {('Q', 'U'): lambda q, u: q - u*1j}, - 'LL': {('I', 'V'): lambda i, v: i - v}, - - 'XX': {('I', 'Q'): lambda i, q: i + q}, - 'XY': {('U', 'V'): lambda u, v: u + v*1j}, - 'YX': {('U', 'V'): lambda u, v: u - v*1j}, - 'YY': {('I', 'Q'): lambda i, q: i - q}, + "RR": {("I", "V"): lambda i, v: i + v}, + "RL": {("Q", "U"): lambda q, u: q + u * 1j}, + "LR": {("Q", "U"): lambda q, u: q - u * 1j}, + "LL": {("I", "V"): lambda i, v: i - v}, + "XX": {("I", "Q"): lambda i, q: i + q}, + "XY": {("U", "V"): lambda u, v: u + v * 1j}, + "YX": {("U", "V"): lambda u, v: u - v * 1j}, + "YY": {("I", "Q"): lambda i, q: i - q}, } def conversion_factory(stokes_schema, corr_schema): @intrinsic def corr_convert(typingctx, spectral_model, source_index, chan_index): - if (not isinstance(spectral_model, types.Array) or - spectral_model.ndim != 3): - raise errors.TypingError(f"'spectral_model' should be 3D array. " - f"Got {spectral_model}") + if not isinstance(spectral_model, types.Array) or spectral_model.ndim != 3: + raise errors.TypingError( + f"'spectral_model' should be 3D array. " f"Got {spectral_model}" + ) if not isinstance(source_index, types.Integer): - raise errors.TypingError(f"'source_index' should be an integer. " - f"Got {source_index}") + raise errors.TypingError( + f"'source_index' should be an integer. " f"Got {source_index}" + ) if not isinstance(chan_index, types.Integer): - raise errors.TypingError(f"'chan_index' should be an integer. " - f"Got {chan_index}") + raise errors.TypingError( + f"'chan_index' should be an integer. " f"Got {chan_index}" + ) spectral_model_map = {s: i for i, s in enumerate(stokes_schema)} conv_map = {} @@ -43,8 +44,9 @@ def corr_convert(typingctx, spectral_model, source_index, chan_index): try: conv_schema = STOKES_CONVERSION[corr] except KeyError: - raise ValueError(f"No conversion schema " - f"registered for correlation {corr}") + raise ValueError( + f"No conversion schema " f"registered for correlation {corr}" + ) i1 = -1 i2 = -1 @@ -57,16 +59,17 @@ def corr_convert(typingctx, spectral_model, source_index, chan_index): continue if i1 == -1 or i2 == -1: - raise ValueError(f"No conversion found for correlation {corr}." - f" {stokes_schema} are available, but one " - f"of the following combinations " - f"{set(conv_schema.values())} is needed " - f"for conversion to {corr}") + raise ValueError( + f"No conversion found for correlation {corr}." + f" {stokes_schema} are available, but one " + f"of the following combinations " + f"{set(conv_schema.values())} is needed " + f"for conversion to {corr}" + ) conv_map[corr] = (fn, i1, i2) - cplx_type = typingctx.unify_types( - spectral_model.dtype, types.complex64) + cplx_type = typingctx.unify_types(spectral_model.dtype, types.complex64) ret_type = types.Tuple([cplx_type] * len(corr_schema)) sig = ret_type(spectral_model, source_index, chan_index) @@ -75,6 +78,7 @@ def indexer_factory(stokes_index): Extracts a stokes parameter from a 2D stokes array at a variable source_index and constant stokes_index """ + def indexer(stokes_array, source_index, chan_index): return stokes_array[source_index, chan_index, stokes_index] @@ -88,22 +92,19 @@ def codegen(context, builder, signature, args): for c, (conv_fn, i1, i2) in enumerate(conv_map.values()): # Extract the first stokes parameter from the stokes array - sig = array_type.dtype( - array_type, source_index_type, chan_index_type) + sig = array_type.dtype(array_type, source_index_type, chan_index_type) s1 = context.compile_internal( - builder, indexer_factory(i1), - sig, [array, source_index, chan_index]) + builder, indexer_factory(i1), sig, [array, source_index, chan_index] + ) # Extract the second stokes parameter from the stokes array s2 = context.compile_internal( - builder, indexer_factory(i2), - sig, [array, source_index, chan_index]) + builder, indexer_factory(i2), sig, [array, source_index, chan_index] + ) # Compute correlation from stokes parameters - sig = signature.return_type[c]( - array_type.dtype, array_type.dtype) - corr = context.compile_internal( - builder, conv_fn, sig, [s1, s2]) + sig = signature.return_type[c](array_type.dtype, array_type.dtype) + corr = context.compile_internal(builder, conv_fn, sig, [s1, s2]) # Insert result of tuple_getter into the tuple corrs = builder.insert_value(corrs, corr, c) @@ -123,8 +124,7 @@ def __init__(self, configuration, stokes, corrs): self.stokes = stokes self.corrs = corrs - def dask_schema(self, stokes, spi, ref_freq, - chan_freq, spi_base="standard"): + def dask_schema(self, stokes, spi, ref_freq, chan_freq, spi_base="standard"): assert stokes.ndim == 2 assert spi.ndim == 3 assert ref_freq.ndim == 1 @@ -136,37 +136,39 @@ def dask_schema(self, stokes, spi, ref_freq, "spi": ("source", "spi", "stokes"), "ref_freq": ("source",), "chan_freq": ("chan",), - "spi_base": None + "spi_base": None, } STANDARD = 0 LOG = 1 LOG10 = 2 - def init_fields(self, typingctx, stokes, spi, ref_freq, - chan_freq, spi_base="standard"): + def init_fields( + self, typingctx, stokes, spi, ref_freq, chan_freq, spi_base="standard" + ): expected_nstokes = len(self.stokes) fields = [("spectral_model", stokes.dtype[:, :, :])] - def brightness(stokes, spi, ref_freq, - chan_freq, spi_base="standard"): + def brightness(stokes, spi, ref_freq, chan_freq, spi_base="standard"): nsrc, nstokes = stokes.shape - nchan, = chan_freq.shape + (nchan,) = chan_freq.shape nspi = spi.shape[1] if nstokes != expected_nstokes: - raise ValueError("corr_schema stokes don't match " - "provided number of stokes") - - if ((spi_base.startswith("[") and spi_base.endswith("]")) or - (spi_base.startswith("(") and spi_base.endswith(")"))): + raise ValueError( + "corr_schema stokes don't match " "provided number of stokes" + ) - list_spi_base = [s.strip().lower() - for s in spi_base.split(",")] + if (spi_base.startswith("[") and spi_base.endswith("]")) or ( + spi_base.startswith("(") and spi_base.endswith(")") + ): + list_spi_base = [s.strip().lower() for s in spi_base.split(",")] if len(list_spi_base) != nstokes: - raise ValueError("List of spectral bases must equal " - "number of stokes parameters") + raise ValueError( + "List of spectral bases must equal " + "number of stokes parameters" + ) else: list_spi_base = [spi_base.lower()] * nstokes @@ -195,12 +197,10 @@ def brightness(stokes, spi, ref_freq, spec_model = 0 for si in range(0, nspi): - term = spi[s, si, p] * freq_ratio**(si + 1) + term = spi[s, si, p] * freq_ratio ** (si + 1) spec_model += term - spectral_model[s, f, p] = ( - stokes[s, p] * np.exp(spec_model) - ) + spectral_model[s, f, p] = stokes[s, p] * np.exp(spec_model) elif b == "log10": for s in range(nsrc): rf = ref_freq[s] @@ -210,15 +210,12 @@ def brightness(stokes, spi, ref_freq, spec_model = 0 for si in range(0, nspi): - term = spi[s, si, p] * freq_ratio**(si + 1) + term = spi[s, si, p] * freq_ratio ** (si + 1) spec_model += term - spectral_model[s, f, p] = ( - stokes[s, p] * 10**spec_model - ) + spectral_model[s, f, p] = stokes[s, p] * 10**spec_model else: - raise ValueError( - "spi_base not in (\"standard\", \"log\", \"log10\")") + raise ValueError('spi_base not in ("standard", "log", "log10")') return spectral_model diff --git a/africanus/experimental/rime/fused/terms/core.py b/africanus/experimental/rime/fused/terms/core.py index 845e4adcf..a4c6bae17 100644 --- a/africanus/experimental/rime/fused/terms/core.py +++ b/africanus/experimental/rime/fused/terms/core.py @@ -10,7 +10,7 @@ @structref.register class StateStructRef(types.StructRef): def preprocess_fields(self, fields): - """ Disallow literal types in field definitions """ + """Disallow literal types in field definitions""" return tuple((n, types.unliteral(t)) for n, t in fields) @@ -18,8 +18,9 @@ def sigcheck_factory(expected_sig): def check_constructor_signature(self, fn): sig = inspect.signature(fn) if sig != expected_sig: - raise ValueError(f"{fn.__name__}{sig} should be " - f"{fn.__name__}{expected_sig}") + raise ValueError( + f"{fn.__name__}{sig} should be " f"{fn.__name__}{expected_sig}" + ) return check_constructor_signature @@ -65,42 +66,50 @@ def _expand_namespace(cls, name, namespace): field_params = list(init_fields_sig.parameters.values()) if len(init_fields_sig.parameters) < 2: - raise InvalidSignature(f"{name}.init_fields{init_fields_sig} " - f"should be " - f"{name}.init_fields(self, typingctx, ...)") + raise InvalidSignature( + f"{name}.init_fields{init_fields_sig} " + f"should be " + f"{name}.init_fields(self, typingctx, ...)" + ) it = iter(init_fields_sig.parameters.items()) first, second = next(it), next(it) if first[0] != "self" or second[0] != "typingctx": - raise InvalidSignature(f"{name}.init_fields{init_fields_sig} " - f"should be " - f"{name}.init_fields(self, typingctx, ...)") + raise InvalidSignature( + f"{name}.init_fields{init_fields_sig} " + f"should be " + f"{name}.init_fields(self, typingctx, ...)" + ) for n, p in it: if p.kind == p.VAR_POSITIONAL: - raise InvalidSignature(f"*{n} in " - f"{name}.init_fields{init_fields_sig} " - f"is not supported") + raise InvalidSignature( + f"*{n} in " + f"{name}.init_fields{init_fields_sig} " + f"is not supported" + ) if p.kind == p.VAR_KEYWORD: - raise InvalidSignature(f"**{n} in " - f"{name}.init_fields{init_fields_sig} " - f"is not supported") + raise InvalidSignature( + f"**{n} in " + f"{name}.init_fields{init_fields_sig} " + f"is not supported" + ) dask_schema_sig = inspect.signature(methods["dask_schema"]) expected_dask_params = field_params[0:1] + field_params[2:] - expected_dask_sig = init_fields_sig.replace( - parameters=expected_dask_params) + expected_dask_sig = init_fields_sig.replace(parameters=expected_dask_params) if dask_schema_sig != expected_dask_sig: - raise InvalidSignature(f"{name}.dask_schema{dask_schema_sig} " - f"should be " - f"{name}.dask_schema{expected_dask_sig}") + raise InvalidSignature( + f"{name}.dask_schema{dask_schema_sig} " + f"should be " + f"{name}.dask_schema{expected_dask_sig}" + ) Parameter = inspect.Parameter - expected_init_sig = init_fields_sig.replace( - parameters=field_params[2:]) + expected_init_sig = init_fields_sig.replace(parameters=field_params[2:]) validator = sigcheck_factory(expected_init_sig) sampler_sig = inspect.signature(methods["sampler"]) @@ -108,19 +117,27 @@ def _expand_namespace(cls, name, namespace): expected_sampler_sig = inspect.Signature(parameters=params) if sampler_sig != expected_sampler_sig: - raise InvalidSignature(f"{name}.sampler{sampler_sig} " - f"should be " - f"{name}.sampler{expected_sampler_sig}") - - args = tuple(n for n, p in init_fields_sig.parameters.items() - if p.kind in {p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD} - and n not in {"self", "typingctx"} - and p.default is p.empty) - - kw = [(n, p.default) for n, p in init_fields_sig.parameters.items() - if p.kind in {p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY} - and n not in {"self", "typingctx"} - and p.default is not p.empty] + raise InvalidSignature( + f"{name}.sampler{sampler_sig} " + f"should be " + f"{name}.sampler{expected_sampler_sig}" + ) + + args = tuple( + n + for n, p in init_fields_sig.parameters.items() + if p.kind in {p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD} + and n not in {"self", "typingctx"} + and p.default is p.empty + ) + + kw = [ + (n, p.default) + for n, p in init_fields_sig.parameters.items() + if p.kind in {p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY} + and n not in {"self", "typingctx"} + and p.default is not p.empty + ] namespace = namespace.copy() namespace["ARGS"] = args @@ -132,7 +149,7 @@ def _expand_namespace(cls, name, namespace): @classmethod def term_in_bases(cls, bases): - """ Is `Term` in bases? """ + """Is `Term` in bases?""" for base in bases: if base is Term or cls.term_in_bases(base.__bases__): return True @@ -157,8 +174,7 @@ def configuration(self): return self._configuration def __eq__(self, rhs): - return (isinstance(rhs, Term) and - self._configuration == rhs._configuration) + return isinstance(rhs, Term) and self._configuration == rhs._configuration def __repr__(self): return self.__class__.__name__ @@ -168,7 +184,7 @@ def __str__(self): @classmethod def validate_sampler(cls, sampler): - """ Validate the sampler implementation """ + """Validate the sampler implementation""" sampler_sig = inspect.signature(sampler) Parameter = inspect.Parameter P = partial(Parameter, kind=Parameter.POSITIONAL_OR_KEYWORD) @@ -176,6 +192,8 @@ def validate_sampler(cls, sampler): expected_sig = inspect.Signature(params) if sampler_sig != expected_sig: - raise InvalidSignature(f"{sampler.__name__}{sampler_sig}" - f"should be " - f"{sampler.__name__}{expected_sig}") + raise InvalidSignature( + f"{sampler.__name__}{sampler_sig}" + f"should be " + f"{sampler.__name__}{expected_sig}" + ) diff --git a/africanus/experimental/rime/fused/terms/cube_dde.py b/africanus/experimental/rime/fused/terms/cube_dde.py index ecb4ee3b9..870270b1f 100644 --- a/africanus/experimental/rime/fused/terms/cube_dde.py +++ b/africanus/experimental/rime/fused/terms/cube_dde.py @@ -1,4 +1,3 @@ - from collections import namedtuple from numba.core import cgutils, types @@ -12,7 +11,7 @@ def zero_vis_factory(ncorr): @intrinsic def zero_vis(typingctx, value): - sig = types.Tuple([value]*ncorr)(value) + sig = types.Tuple([value] * ncorr)(value) def codegen(context, builder, signature, args): llvm_ret_type = context.get_value_type(signature.return_type) @@ -28,9 +27,9 @@ def codegen(context, builder, signature, args): return zero_vis -BeamInfo = namedtuple("BeamInfo", [ - "lscale", "mscale", - "lmaxi", "mmaxi", "lmaxf", "mmaxf"]) +BeamInfo = namedtuple( + "BeamInfo", ["lscale", "mscale", "lmaxi", "mmaxi", "lmaxf", "mmaxf"] +) class BeamCubeDDE(Term): @@ -38,17 +37,26 @@ class BeamCubeDDE(Term): def __init__(self, configuration, corrs): if configuration not in {"left", "right"}: - raise ValueError(f"BeamCubeDDE configuration must be" - f"either 'left' or 'right'. " - f"Got {configuration}") + raise ValueError( + f"BeamCubeDDE configuration must be" + f"either 'left' or 'right'. " + f"Got {configuration}" + ) super().__init__(configuration) self.corrs = corrs - def dask_schema(self, beam, beam_lm_extents, beam_freq_map, - lm, beam_parangle, chan_freq, - beam_point_errors=None, - beam_antenna_scaling=None): + def dask_schema( + self, + beam, + beam_lm_extents, + beam_freq_map, + lm, + beam_parangle, + chan_freq, + beam_point_errors=None, + beam_antenna_scaling=None, + ): return { "beam": ("beam_lw", "beam_mh", "beam_nud", "corr"), "beam_lm_extents": ("lm_ext", "lm_ext_comp"), @@ -57,28 +65,40 @@ def dask_schema(self, beam, beam_lm_extents, beam_freq_map, "chan_freq": ("chan",), } - def init_fields(self, typingctx, - beam, beam_lm_extents, beam_freq_map, - lm, beam_parangle, chan_freq, - beam_point_errors=None, - beam_antenna_scaling=None): - + def init_fields( + self, + typingctx, + beam, + beam_lm_extents, + beam_freq_map, + lm, + beam_parangle, + chan_freq, + beam_point_errors=None, + beam_antenna_scaling=None, + ): ncorr = len(self.corrs) ex_dtype = beam_lm_extents.dtype - beam_info_types = [ex_dtype]*2 + [types.int64]*2 + [types.float64]*2 + beam_info_types = [ex_dtype] * 2 + [types.int64] * 2 + [types.float64] * 2 beam_info_type = types.NamedTuple(beam_info_types, BeamInfo) - fields = [("beam_freq_data", chan_freq.copy(ndim=2)), - ("beam_info", beam_info_type)] - - def beam(beam, beam_lm_extents, beam_freq_map, - lm, beam_parangle, chan_freq, - beam_point_errors=None, - beam_antenna_scaling=None): - + fields = [ + ("beam_freq_data", chan_freq.copy(ndim=2)), + ("beam_info", beam_info_type), + ] + + def beam( + beam, + beam_lm_extents, + beam_freq_map, + lm, + beam_parangle, + chan_freq, + beam_point_errors=None, + beam_antenna_scaling=None, + ): if beam.shape[3] != ncorr: - raise ValueError( - "Beam correlations don't match specification corrs") + raise ValueError("Beam correlations don't match specification corrs") freq_data = np.empty((chan_freq.shape[0], 3), chan_freq.dtype) beam_nud = beam_freq_map.shape[0] @@ -177,8 +197,8 @@ def cube_dde(state, s, r, t, f1, f2, a1, a2, c): tm = sm # Rotate lm coordinate angle - vl = tl*cos_pa - tm*sin_pa - vm = tl*sin_pa + tm*cos_pa + vl = tl * cos_pa - tm * sin_pa + vm = tl * sin_pa + tm * cos_pa # Scale by antenna scaling # vl *= antenna_scaling[a, f, 0] @@ -189,8 +209,8 @@ def cube_dde(state, s, r, t, f1, f2, a1, a2, c): lower_m, upper_m = state.beam_lm_extents[1] # Shift into the cube coordinate system - vl = state.beam_info.lscale*(vl - lower_l) - vm = state.beam_info.mscale*(vm - lower_m) + vl = state.beam_info.lscale * (vl - lower_l) + vm = state.beam_info.mscale * (vm - lower_m) # Clamp the coordinates to the edges of the cube vl = max(0.0, min(vl, state.beam_info.lmaxf)) @@ -212,82 +232,82 @@ def cube_dde(state, s, r, t, f1, f2, a1, a2, c): absc_sum = zero_vis(state.beam.real.dtype.type(0)) # Lower cube - weight = (1.0 - ld)*(1.0 - md)*nud + weight = (1.0 - ld) * (1.0 - md) * nud for co in range(ncorr): value = state.beam[gl0, gm0, gc0, co] - absc_sum = tuple_setitem(absc_sum, co, - weight*np.abs(value) + absc_sum[co]) - corr_sum = tuple_setitem(corr_sum, co, - weight*value + corr_sum[co]) + absc_sum = tuple_setitem( + absc_sum, co, weight * np.abs(value) + absc_sum[co] + ) + corr_sum = tuple_setitem(corr_sum, co, weight * value + corr_sum[co]) - weight = ld*(1.0 - md)*nud + weight = ld * (1.0 - md) * nud for co in range(ncorr): value = state.beam[gl1, gm0, gc0, co] - absc_sum = tuple_setitem(absc_sum, co, - weight*np.abs(value) + absc_sum[co]) - corr_sum = tuple_setitem(corr_sum, co, - weight*value + corr_sum[co]) + absc_sum = tuple_setitem( + absc_sum, co, weight * np.abs(value) + absc_sum[co] + ) + corr_sum = tuple_setitem(corr_sum, co, weight * value + corr_sum[co]) - weight = (1.0 - ld)*md*nud + weight = (1.0 - ld) * md * nud for co in range(ncorr): value = state.beam[gl0, gm1, gc0, co] - absc_sum = tuple_setitem(absc_sum, co, - weight*np.abs(value) + absc_sum[co]) - corr_sum = tuple_setitem(corr_sum, co, - weight*value + corr_sum[co]) + absc_sum = tuple_setitem( + absc_sum, co, weight * np.abs(value) + absc_sum[co] + ) + corr_sum = tuple_setitem(corr_sum, co, weight * value + corr_sum[co]) - weight = ld*md*nud + weight = ld * md * nud for co in range(ncorr): value = state.beam[gl1, gm1, gc0, co] - absc_sum = tuple_setitem(absc_sum, co, - weight*np.abs(value) + absc_sum[co]) - corr_sum = tuple_setitem(corr_sum, co, - weight*value + corr_sum[co]) + absc_sum = tuple_setitem( + absc_sum, co, weight * np.abs(value) + absc_sum[co] + ) + corr_sum = tuple_setitem(corr_sum, co, weight * value + corr_sum[co]) # Upper cube - weight = (1.0 - ld)*(1.0 - md)*inv_nud + weight = (1.0 - ld) * (1.0 - md) * inv_nud for co in range(ncorr): value = state.beam[gl0, gm0, gc1, co] - absc_sum = tuple_setitem(absc_sum, co, - weight*np.abs(value) + absc_sum[co]) - corr_sum = tuple_setitem(corr_sum, co, - weight*value + corr_sum[co]) + absc_sum = tuple_setitem( + absc_sum, co, weight * np.abs(value) + absc_sum[co] + ) + corr_sum = tuple_setitem(corr_sum, co, weight * value + corr_sum[co]) - weight = ld*(1.0 - md)*inv_nud + weight = ld * (1.0 - md) * inv_nud for co in range(ncorr): value = state.beam[gl1, gm0, gc1, co] - absc_sum = tuple_setitem(absc_sum, co, - weight*np.abs(value) + absc_sum[co]) - corr_sum = tuple_setitem(corr_sum, co, - weight*value + corr_sum[co]) + absc_sum = tuple_setitem( + absc_sum, co, weight * np.abs(value) + absc_sum[co] + ) + corr_sum = tuple_setitem(corr_sum, co, weight * value + corr_sum[co]) - weight = (1.0 - ld)*md*inv_nud + weight = (1.0 - ld) * md * inv_nud for co in range(ncorr): value = state.beam[gl0, gm1, gc1, co] - absc_sum = tuple_setitem(absc_sum, co, - weight*np.abs(value) + absc_sum[co]) - corr_sum = tuple_setitem(corr_sum, co, - weight*value + corr_sum[co]) + absc_sum = tuple_setitem( + absc_sum, co, weight * np.abs(value) + absc_sum[co] + ) + corr_sum = tuple_setitem(corr_sum, co, weight * value + corr_sum[co]) - weight = ld*md*inv_nud + weight = ld * md * inv_nud for co in range(ncorr): value = state.beam[gl1, gm1, gc1, co] - absc_sum = tuple_setitem(absc_sum, co, - weight*np.abs(value) + absc_sum[co]) - corr_sum = tuple_setitem(corr_sum, co, - weight*value + corr_sum[co]) + absc_sum = tuple_setitem( + absc_sum, co, weight * np.abs(value) + absc_sum[co] + ) + corr_sum = tuple_setitem(corr_sum, co, weight * value + corr_sum[co]) for co in range(ncorr): div = np.abs(corr_sum[co]) - value = corr_sum[co]*absc_sum[co] + value = corr_sum[co] * absc_sum[co] if div != 0.0: value /= div diff --git a/africanus/experimental/rime/fused/terms/feed_rotation.py b/africanus/experimental/rime/fused/terms/feed_rotation.py index f5557ece3..9b18ca0ae 100644 --- a/africanus/experimental/rime/fused/terms/feed_rotation.py +++ b/africanus/experimental/rime/fused/terms/feed_rotation.py @@ -6,19 +6,25 @@ class FeedRotation(Term): def __init__(self, configuration, feed_type, corrs): if configuration not in {"left", "right"}: - raise ValueError(f"FeedRotation configuration must " - f"be either 'left' or 'right'. " - f"Got {configuration}") + raise ValueError( + f"FeedRotation configuration must " + f"be either 'left' or 'right'. " + f"Got {configuration}" + ) if feed_type not in {"linear", "circular"}: - raise ValueError(f"FeedRotation feed_type must be " - f"either 'linear' or 'circular'. " - f"Got {feed_type}") + raise ValueError( + f"FeedRotation feed_type must be " + f"either 'linear' or 'circular'. " + f"Got {feed_type}" + ) if len(corrs) != 4: - raise ValueError(f"Four correlations required for " - f"feed rotation but {corrs} were " - f"specified") + raise ValueError( + f"Four correlations required for " + f"feed rotation but {corrs} were " + f"specified" + ) super().__init__(configuration) self.feed_type = feed_type @@ -50,9 +56,10 @@ def feed_rotation(state, s, r, t, f1, f2, a1, a2, c): else: # e^{ix} = cos(x) + i.sin(x) return ( - 0.5*((cos_a + cos_b) - (sin_a + sin_b)*1j), - 0.5*((cos_a - cos_b) + (sin_a - sin_b)*1j), - 0.5*((cos_a - cos_b) - (sin_a - sin_b)*1j), - 0.5*((cos_a + cos_b) + (sin_a + sin_b)*1j)) + 0.5 * ((cos_a + cos_b) - (sin_a + sin_b) * 1j), + 0.5 * ((cos_a - cos_b) + (sin_a - sin_b) * 1j), + 0.5 * ((cos_a - cos_b) - (sin_a - sin_b) * 1j), + 0.5 * ((cos_a + cos_b) + (sin_a + sin_b) * 1j), + ) return feed_rotation diff --git a/africanus/experimental/rime/fused/terms/gaussian.py b/africanus/experimental/rime/fused/terms/gaussian.py index 58e467bc2..8222a731a 100644 --- a/africanus/experimental/rime/fused/terms/gaussian.py +++ b/africanus/experimental/rime/fused/terms/gaussian.py @@ -12,17 +12,15 @@ def dask_schema(self, uvw, chan_freq, gauss_shape): assert chan_freq.ndim == 1 assert gauss_shape.ndim == 2 - return {"uvw": ("row", "uvw"), - "chan_freq": ("chan",), - "gauss_shape": ("source", "gauss_shape_params")} + return { + "uvw": ("row", "uvw"), + "chan_freq": ("chan",), + "gauss_shape": ("source", "gauss_shape_params"), + } def init_fields(self, typingctx, uvw, chan_freq, gauss_shape): - guv_dtype = typingctx.unify_types( - uvw.dtype, - chan_freq.dtype, - gauss_shape.dtype) - fields = [("gauss_uv", guv_dtype[:, :, :]), - ("scaled_freq", chan_freq)] + guv_dtype = typingctx.unify_types(uvw.dtype, chan_freq.dtype, gauss_shape.dtype) + fields = [("gauss_uv", guv_dtype[:, :, :]), ("scaled_freq", chan_freq)] fwhm = 2.0 * np.sqrt(2.0 * np.log(2.0)) fwhminv = 1.0 / fwhm @@ -33,7 +31,7 @@ def gaussian_init(uvw, chan_freq, gauss_shape): nrow, _ = uvw.shape gauss_uv = np.empty((nsrc, nrow, 2), dtype=guv_dtype) - scaled_freq = chan_freq*gauss_scale + scaled_freq = chan_freq * gauss_scale for s in range(nsrc): emaj, emin, angle = gauss_shape[s] @@ -47,8 +45,8 @@ def gaussian_init(uvw, chan_freq, gauss_shape): u = uvw[r, 0] v = uvw[r, 1] - gauss_uv[s, r, 0] = (u*em - v*el)*er - gauss_uv[s, r, 1] = u*el + v*em + gauss_uv[s, r, 0] = (u * em - v * el) * er + gauss_uv[s, r, 1] = u * el + v * em return gauss_uv, scaled_freq @@ -58,6 +56,6 @@ def sampler(self): def gaussian_sample(state, s, r, t, f1, f2, a1, a2, c): fu1 = state.gauss_uv[s, r, 0] * state.scaled_freq[c] fv1 = state.gauss_uv[s, r, 1] * state.scaled_freq[c] - return np.exp(-(fu1*fu1 + fv1*fv1)) + return np.exp(-(fu1 * fu1 + fv1 * fv1)) return gaussian_sample diff --git a/africanus/experimental/rime/fused/terms/phase.py b/africanus/experimental/rime/fused/terms/phase.py index e6198d95e..a0757aff1 100644 --- a/africanus/experimental/rime/fused/terms/phase.py +++ b/africanus/experimental/rime/fused/terms/phase.py @@ -13,10 +13,12 @@ def dask_schema(self, lm, uvw, chan_freq, convention="fourier"): assert chan_freq.ndim == 1 assert isinstance(convention, str) - return {"lm": ("source", "lm"), - "uvw": ("row", "uvw"), - "chan_freq": ("chan",), - "convention": None} + return { + "lm": ("source", "lm"), + "uvw": ("row", "uvw"), + "chan_freq": ("chan",), + "convention": None, + } def init_fields(self, typingctx, lm, uvw, chan_freq, convention="fourier"): phase_dt = typingctx.unify_types(lm.dtype, uvw.dtype, chan_freq.dtype) @@ -25,7 +27,7 @@ def init_fields(self, typingctx, lm, uvw, chan_freq, convention="fourier"): def phase(lm, uvw, chan_freq, convention="fourier"): nsrc, _ = lm.shape nrow, _ = uvw.shape - nchan, = chan_freq.shape + (nchan,) = chan_freq.shape phase_dot = np.empty((nsrc, nrow), dtype=phase_dt) @@ -33,11 +35,11 @@ def phase(lm, uvw, chan_freq, convention="fourier"): one = lm.dtype.type(1.0) if convention == "fourier": - C = phase_dt(-2.0*np.pi/lightspeed) + C = phase_dt(-2.0 * np.pi / lightspeed) elif convention == "casa": - C = phase_dt(2.0*np.pi/lightspeed) + C = phase_dt(2.0 * np.pi / lightspeed) else: - raise ValueError("convention not in (\"fourier\", \"casa\")") + raise ValueError('convention not in ("fourier", "casa")') for s in range(nsrc): l = lm[s, 0] # noqa @@ -50,7 +52,7 @@ def phase(lm, uvw, chan_freq, convention="fourier"): v = uvw[r, 1] w = uvw[r, 2] - phase_dot[s, r] = C*(l*u + m*v + n*w) + phase_dot[s, r] = C * (l * u + m * v + n * w) return phase_dot @@ -59,6 +61,6 @@ def phase(lm, uvw, chan_freq, convention="fourier"): def sampler(self): def phase_sample(state, s, r, t, f1, f2, a1, a2, c): p = state.phase_dot[s, r] * state.chan_freq[c] - return np.cos(p) + np.sin(p)*1j + return np.cos(p) + np.sin(p) * 1j return phase_sample diff --git a/africanus/experimental/rime/fused/tests/test_rime.py b/africanus/experimental/rime/fused/tests/test_rime.py index 4087e1f51..b3f1358d1 100644 --- a/africanus/experimental/rime/fused/tests/test_rime.py +++ b/africanus/experimental/rime/fused/tests/test_rime.py @@ -11,18 +11,23 @@ from africanus.model.shape import gaussian from africanus.experimental.rime.fused.specification import ( - RimeSpecification, parse_rime) + RimeSpecification, + parse_rime, +) from africanus.experimental.rime.fused.core import rime from africanus.experimental.rime.fused.dask import rime as dask_rime @pytest.mark.skip -@pytest.mark.parametrize("rime_spec", [ - # "G_{p}[E_{stpf}L_{tpf}K_{stpqf}B_{spq}L_{tqf}E_{q}]G_{q}", - # "Gp[EpLpKpqBpqLqEq]sGq", - "[Gp, (Ep, Lp, Kpq, Bpq, Lq, Eq), Gq]: [I, Q, U, V] -> [XX, XY, YX, YY]", - # "[Gp x (Ep x Lp x Kpq x Bpq x Lq x Eq) x Gq] -> [XX, XY, YX, YY]", -]) +@pytest.mark.parametrize( + "rime_spec", + [ + # "G_{p}[E_{stpf}L_{tpf}K_{stpqf}B_{spq}L_{tqf}E_{q}]G_{q}", + # "Gp[EpLpKpqBpqLqEq]sGq", + "[Gp, (Ep, Lp, Kpq, Bpq, Lq, Eq), Gq]: [I, Q, U, V] -> [XX, XY, YX, YY]", + # "[Gp x (Ep x Lp x Kpq x Bpq x Lq x Eq) x Gq] -> [XX, XY, YX, YY]", + ], +) def test_rime_specification(rime_spec): # custom_mapping = {"Kpq": MyCustomPhase} print(parse_rime(rime_spec)) @@ -61,9 +66,7 @@ def unity_vis_dataset(request, stokes_schema, corr_schema): time = np.linspace(start, end, 2) antenna1 = np.array([0, 0]) antenna2 = np.array([1, 2]) - antenna_position = np.array([[1, 1, 1], - [1, 1, 1], - [1, 1, 1]]) + antenna_position = np.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]]) antenna_position = np.random.random((3, 3)) feed1 = feed2 = np.array([0, 0]) radec = np.zeros((1, 2)) @@ -92,23 +95,20 @@ def unity_vis_dataset(request, stokes_schema, corr_schema): @pytest.mark.parametrize("stokes_schema", [["I", "Q", "U", "V"]], ids=str) -@pytest.mark.parametrize("corr_schema", [ - ["XX", "XY", "YX", "YY"], - ["RR", "RL", "LR", "LL"] -], ids=str) -def test_fused_rime_feed_rotation(unity_vis_dataset, - stokes_schema, - corr_schema): - stokes_to_corr = "".join(("[", ",".join(stokes_schema), - "] -> [", - ",".join(corr_schema), "]")) +@pytest.mark.parametrize( + "corr_schema", [["XX", "XY", "YX", "YY"], ["RR", "RL", "LR", "LL"]], ids=str +) +def test_fused_rime_feed_rotation(unity_vis_dataset, stokes_schema, corr_schema): + stokes_to_corr = "".join( + ("[", ",".join(stokes_schema), "] -> [", ",".join(corr_schema), "]") + ) ds = unity_vis_dataset.copy() utime, time_inv = np.unique(ds["time"], return_inverse=True) from africanus.rime.parangles_casa import casa_parallactic_angles from africanus.rime.feeds import feed_rotation - pa = casa_parallactic_angles( - utime, ds["antenna_position"], ds["phase_dir"]) + + pa = casa_parallactic_angles(utime, ds["antenna_position"], ds["phase_dir"]) def pa_feed_rotation(left): row_pa = pa[time_inv, ds["antenna1"] if left else ds["antenna2"]] @@ -125,68 +125,81 @@ def pa_feed_rotation(left): FL, FR = (pa_feed_rotation(v) for v in (True, False)) lm = radec_to_lm(ds["radec"], ds["phase_dir"]) P = phase_delay(lm, ds["uvw"], ds["chan_freq"], convention="casa") - SM = spectral_model(ds["stokes"], ds["spi"], - ds["ref_freq"], ds["chan_freq"], base="std") + SM = spectral_model( + ds["stokes"], ds["spi"], ds["ref_freq"], ds["chan_freq"], base="std" + ) B = convert(SM, stokes_schema, corr_schema) B = B.reshape(B.shape[:2] + (2, 2)) result = np.einsum("rij,srf,sfjk,rkl->srfil", FL, P, B, FR).sum(axis=0) expected = result.reshape(result.shape[:2] + (4,)) - out = rime(f"(Lp, Kpq, Bpq, Lq): {stokes_to_corr}", - ds, convention="casa", spi_base="standard") + out = rime( + f"(Lp, Kpq, Bpq, Lq): {stokes_to_corr}", + ds, + convention="casa", + spi_base="standard", + ) assert_array_almost_equal(expected, out) @pytest.mark.parametrize("stokes_schema", [["I", "Q", "U", "V"]], ids=str) -@pytest.mark.parametrize("corr_schema", [ - ["XX", "XY", "YX", "YY"], - ["RR", "RL", "LR", "LL"] -], ids=str) +@pytest.mark.parametrize( + "corr_schema", [["XX", "XY", "YX", "YY"], ["RR", "RL", "LR", "LL"]], ids=str +) def test_fused_rime_cube_dde(unity_vis_dataset, stokes_schema, corr_schema): - stokes_to_corr = "".join(("[", ",".join(stokes_schema), - "] -> [", - ",".join(corr_schema), "]")) + stokes_to_corr = "".join( + ("[", ",".join(stokes_schema), "] -> [", ",".join(corr_schema), "]") + ) lw = mh = nud = 10 - chan_freq = np.array([.856e9, 2*.856e9]) + chan_freq = np.array([0.856e9, 2 * 0.856e9]) beam = np.random.random(size=(lw, mh, nud, len(corr_schema))) beam_lm_extents = np.array([[-1.0, 1.0], [-1.0, 1.0]]) - beam_freq_map = np.random.uniform( - low=chan_freq[0], high=chan_freq[-1], size=nud) + beam_freq_map = np.random.uniform(low=chan_freq[0], high=chan_freq[-1], size=nud) beam_freq_map.sort() - ds = {**unity_vis_dataset, - "chan_freq": chan_freq, - "beam": beam, - "beam_lm_extents": beam_lm_extents, - "beam_freq_map": beam_freq_map} + ds = { + **unity_vis_dataset, + "chan_freq": chan_freq, + "beam": beam, + "beam_lm_extents": beam_lm_extents, + "beam_freq_map": beam_freq_map, + } utime, time_inv = np.unique(ds["time"], return_inverse=True) - nchan, = chan_freq.shape + (nchan,) = chan_freq.shape lm = radec_to_lm(ds["radec"], ds["phase_dir"]) P = phase_delay(lm, ds["uvw"], ds["chan_freq"], convention="casa") - SM = spectral_model(ds["stokes"], ds["spi"], - ds["ref_freq"], ds["chan_freq"], base="std") + SM = spectral_model( + ds["stokes"], ds["spi"], ds["ref_freq"], ds["chan_freq"], base="std" + ) B = convert(SM, stokes_schema, corr_schema) B = B.reshape(B.shape[:2] + (2, 2)) from africanus.rime.parangles_casa import casa_parallactic_angles from africanus.rime.fast_beam_cubes import beam_cube_dde - beam_pa = casa_parallactic_angles( - utime, ds["antenna_position"], ds["phase_dir"]) + beam_pa = casa_parallactic_angles(utime, ds["antenna_position"], ds["phase_dir"]) def dde(left): ntime, na = beam_pa.shape point_errors = np.zeros((ntime, na, nchan, 2)) ant_scale = np.zeros((na, nchan, 2)) - ddes = beam_cube_dde(beam, beam_lm_extents, beam_freq_map, lm, beam_pa, - point_errors, ant_scale, chan_freq) + ddes = beam_cube_dde( + beam, + beam_lm_extents, + beam_freq_map, + lm, + beam_pa, + point_errors, + ant_scale, + chan_freq, + ) ant_inv = ds["antenna1"] if left else ds["antenna2"] row_ddes = ddes[:, time_inv, ant_inv, :, :] row_ddes = row_ddes.reshape(row_ddes.shape[:3] + (2, 2)) @@ -196,8 +209,12 @@ def dde(left): result = np.einsum("srfij,srf,sfjk,srfkl->srfil", EL, P, B, ER).sum(axis=0) expected = result.reshape(result.shape[:2] + (4,)) - out = rime(f"(Ep, Kpq, Bpq, Eq): {stokes_to_corr}", - ds, convention="casa", spi_base="standard") + out = rime( + f"(Ep, Kpq, Bpq, Eq): {stokes_to_corr}", + ds, + convention="casa", + spi_base="standard", + ) assert_array_almost_equal(expected, out) @@ -206,32 +223,30 @@ def dde(left): @pytest.mark.parametrize("stokes_schema", [["I", "Q", "U", "V"]], ids=str) @pytest.mark.parametrize("corr_schema", [["XX", "XY", "YX", "YY"]], ids=str) def test_fused_rime(chunks, stokes_schema, corr_schema): - chunks = {**chunks, - "stokes": (len(stokes_schema),), - "corr": (len(corr_schema),)} + chunks = {**chunks, "stokes": (len(stokes_schema),), "corr": (len(corr_schema),)} nsrc = sum(chunks["source"]) nrow = sum(chunks["row"]) nspi = sum(chunks["spi"]) nchan = sum(chunks["chan"]) nstokes = sum(chunks["stokes"]) - stokes_to_corr = "".join(("[", ",".join(stokes_schema), - "] -> [", - ",".join(corr_schema), "]")) + stokes_to_corr = "".join( + ("[", ",".join(stokes_schema), "] -> [", ",".join(corr_schema), "]") + ) time = np.linspace(0.1, 1.0, nrow) antenna1 = np.zeros(nrow, dtype=np.int32) antenna2 = np.arange(nrow, dtype=np.int32) feed1 = feed2 = antenna1 - radec = np.random.random(size=(nsrc, 2))*1e-5 - phase_dir = np.random.random(2)*1e-5 + radec = np.random.random(size=(nsrc, 2)) * 1e-5 + phase_dir = np.random.random(2) * 1e-5 uvw = np.random.random(size=(nrow, 3)) - chan_freq = np.linspace(.856e9, 2*.856e9, nchan) + chan_freq = np.linspace(0.856e9, 2 * 0.856e9, nchan) # Make perfect stokes paramters i.e. I**2 = Q**2 + U**2 + V**2. stokes = np.random.normal(size=(nsrc, nstokes)) stokes[:, 0] = np.sqrt((stokes[:, 1:] ** 2).sum(axis=-1)) spi = np.random.random(size=(nsrc, nspi, nstokes)) - ref_freq = np.random.uniform(low=.5*.856e9, high=4*.856e9, size=nsrc) + ref_freq = np.random.uniform(low=0.5 * 0.856e9, high=4 * 0.856e9, size=nsrc) lm = radec_to_lm(radec, phase_dir) dataset = { @@ -249,62 +264,62 @@ def test_fused_rime(chunks, stokes_schema, corr_schema): "ref_freq": ref_freq, } - out = rime(f"(Kpq, Bpq): {stokes_to_corr}", - dataset, convention="casa", spi_base="standard") + out = rime( + f"(Kpq, Bpq): {stokes_to_corr}", dataset, convention="casa", spi_base="standard" + ) P = phase_delay(lm, uvw, chan_freq, convention="casa") SM = spectral_model(stokes, spi, ref_freq, chan_freq, base="std") B = convert(SM, stokes_schema, corr_schema) - expected = (P[:, :, :, None]*B[:, None, :, :]).sum(axis=0) + expected = (P[:, :, :, None] * B[:, None, :, :]).sum(axis=0) assert_array_almost_equal(expected, out) - assert np.count_nonzero(out) > .8 * out.size + assert np.count_nonzero(out) > 0.8 * out.size - out = rime(f"(Kpq, Bpq): {stokes_to_corr}", - dataset, convention="fourier", spi_base="log") + out = rime( + f"(Kpq, Bpq): {stokes_to_corr}", dataset, convention="fourier", spi_base="log" + ) P = phase_delay(lm, uvw, chan_freq, convention="fourier") SM = spectral_model(stokes, spi, ref_freq, chan_freq, base="log") B = convert(SM, stokes_schema, corr_schema) - expected = (P[:, :, :, None]*B[:, None, :, :]).sum(axis=0) + expected = (P[:, :, :, None] * B[:, None, :, :]).sum(axis=0) assert_array_almost_equal(expected, out) - assert np.count_nonzero(out) > .8 * out.size + assert np.count_nonzero(out) > 0.8 * out.size out = rime(f"(Kpq, Bpq): {stokes_to_corr}", dataset, spi_base="log10") P = phase_delay(lm, uvw, chan_freq, convention="fourier") SM = spectral_model(stokes, spi, ref_freq, chan_freq, base="log10") B = convert(SM, stokes_schema, corr_schema) - expected = (P[:, :, :, None]*B[:, None, :, :]).sum(axis=0) + expected = (P[:, :, :, None] * B[:, None, :, :]).sum(axis=0) assert_array_almost_equal(expected, out) - assert np.count_nonzero(out) > .8 * out.size + assert np.count_nonzero(out) > 0.8 * out.size gauss_shape = np.random.random((nsrc, 3)) gauss_shape[:, :2] *= 1e-5 - rime_spec = RimeSpecification(f"(Cpq, Kpq, Bpq): {stokes_to_corr}", - terms={"C": "Gaussian"}) - out = rime(rime_spec, - {**dataset, "gauss_shape": gauss_shape}) + rime_spec = RimeSpecification( + f"(Cpq, Kpq, Bpq): {stokes_to_corr}", terms={"C": "Gaussian"} + ) + out = rime(rime_spec, {**dataset, "gauss_shape": gauss_shape}) P = phase_delay(lm, uvw, chan_freq, convention="fourier") SM = spectral_model(stokes, spi, ref_freq, chan_freq, base="std") B = convert(SM, stokes_schema, corr_schema) G = gaussian(uvw, chan_freq, gauss_shape) - expected = (G[:, :, :, None]*P[:, :, :, None]*B[:, None, :, :]).sum(axis=0) + expected = (G[:, :, :, None] * P[:, :, :, None] * B[:, None, :, :]).sum(axis=0) assert_array_almost_equal(expected, out) - assert np.count_nonzero(out) > .8 * out.size + assert np.count_nonzero(out) > 0.8 * out.size @pytest.mark.parametrize("chunks", chunks) @pytest.mark.parametrize("stokes_schema", [["I", "Q", "U", "V"]], ids=str) -@pytest.mark.parametrize("corr_schema", [ - ["XX", "XY", "YX", "YY"], - ["RR", "RL", "LR", "LL"], - ["RR", "LL"] -], ids=str) +@pytest.mark.parametrize( + "corr_schema", + [["XX", "XY", "YX", "YY"], ["RR", "RL", "LR", "LL"], ["RR", "LL"]], + ids=str, +) def test_fused_dask_rime(chunks, stokes_schema, corr_schema): da = pytest.importorskip("dask.array") - chunks = {**chunks, - "stokes": (len(stokes_schema),), - "corr": (len(corr_schema),)} + chunks = {**chunks, "stokes": (len(stokes_schema),), "corr": (len(corr_schema),)} nsrc = sum(chunks["source"]) nrow = sum(chunks["row"]) nspi = sum(chunks["spi"]) @@ -315,26 +330,25 @@ def test_fused_dask_rime(chunks, stokes_schema, corr_schema): mh = sum(chunks["mh"]) nud = sum(chunks["nud"]) - stokes_to_corr = "".join(("[", ",".join(stokes_schema), - "] -> [", - ",".join(corr_schema), "]")) + stokes_to_corr = "".join( + ("[", ",".join(stokes_schema), "] -> [", ",".join(corr_schema), "]") + ) start, end = _observation_endpoints(2021, 10, 9, 8) time = np.linspace(start, end, nrow) antenna1 = np.zeros(nrow, dtype=np.int32) antenna2 = np.arange(nrow, dtype=np.int32) feed1 = feed2 = antenna1 - radec = np.random.random(size=(nsrc, 2))*1e-5 - phase_dir = np.random.random(size=(2,))*1e-5 - uvw = np.random.random(size=(nrow, 3))*1e5 - chan_freq = np.linspace(.856e9, 2*.859e9, nchan) + radec = np.random.random(size=(nsrc, 2)) * 1e-5 + phase_dir = np.random.random(size=(2,)) * 1e-5 + uvw = np.random.random(size=(nrow, 3)) * 1e5 + chan_freq = np.linspace(0.856e9, 2 * 0.859e9, nchan) stokes = np.random.random(size=(nsrc, nstokes)) spi = np.random.random(size=(nsrc, nspi, nstokes)) - ref_freq = np.random.random(size=nsrc)*.856e9 + .856e9 + ref_freq = np.random.random(size=nsrc) * 0.856e9 + 0.856e9 beam = np.random.random(size=(lw, mh, nud, ncorr)) beam_lm_extents = np.array([[-1.0, 1.0], [-1.0, 1.0]]) - beam_freq_map = np.random.uniform( - low=chan_freq[0], high=chan_freq[-1], size=nud) + beam_freq_map = np.random.uniform(low=chan_freq[0], high=chan_freq[-1], size=nud) beam_freq_map.sort() gauss_shape = np.random.random((nsrc, 3)) gauss_shape[:, :2] *= 1e-5 @@ -350,7 +364,7 @@ def darray(array, dims): dask_dataset = { "time": darray(time, ("row",)), "antenna1": darray(antenna1, ("row",)), - "antenna2": darray(antenna2, ("row",)), + "antenna2": darray(antenna2, ("row",)), "feed1": darray(feed1, ("row",)), "feed2": darray(feed2, ("row",)), "radec": darray(radec, ("source", "radec")), @@ -369,16 +383,14 @@ def darray(array, dims): } # Feed rotations only make sense if we have four correlations - equation_str = ("(Lp, Ep, Kpq, Bpq, Eq, Lq)" - if ncorr == 4 - else "(Ep, Kpq, Bpq, Eq)") + equation_str = "(Lp, Ep, Kpq, Bpq, Eq, Lq)" if ncorr == 4 else "(Ep, Kpq, Bpq, Eq)" rime_spec = RimeSpecification(f"{equation_str}: {stokes_to_corr}") dask_out = dask_rime(rime_spec, dask_dataset, convention="casa") dataset = { "time": time, "antenna1": antenna1, - "antenna2": antenna2, + "antenna2": antenna2, "feed1": feed1, "feed2": feed2, "radec": radec, @@ -399,4 +411,4 @@ def darray(array, dims): out = rime(rime_spec, dataset, convention="casa") dout = dask_out.compute() assert_array_almost_equal(dout, out) - assert np.count_nonzero(out) > .8 * out.size + assert np.count_nonzero(out) > 0.8 * out.size diff --git a/africanus/experimental/rime/fused/tests/test_structref_setter.py b/africanus/experimental/rime/fused/tests/test_structref_setter.py index 23a84e8dc..66ba88e96 100644 --- a/africanus/experimental/rime/fused/tests/test_structref_setter.py +++ b/africanus/experimental/rime/fused/tests/test_structref_setter.py @@ -9,7 +9,7 @@ @structref.register class StateStructRef(types.StructRef): def preprocess_fields(self, fields): - """ Disallow literal types in field definitions """ + """Disallow literal types in field definitions""" return tuple((n, types.unliteral(t)) for n, t in fields) @@ -28,11 +28,10 @@ def constructor(typingctx, arg_tuple): def codegen(context, builder, signature, args): def make_struct(): - """ Allocate the structure """ + """Allocate the structure""" return structref.new(state_type) - state = context.compile_internal(builder, make_struct, - state_type(), []) + state = context.compile_internal(builder, make_struct, state_type(), []) # Now assign each argument U = structref._Utils(context, builder, state_type) @@ -42,8 +41,7 @@ def make_struct(): value = builder.extract_value(args[0], i) value_type = signature.args[0][i] field_type = state_type.field_dict[name] - casted = context.cast(builder, value, - value_type, field_type) + casted = context.cast(builder, value, value_type, field_type) old_value = getattr(data_struct, name) context.nrt.incref(builder, value_type, casted) context.nrt.decref(builder, value_type, old_value) diff --git a/africanus/experimental/rime/fused/transformers/core.py b/africanus/experimental/rime/fused/transformers/core.py index f8c6b0f26..ad975f7d8 100644 --- a/africanus/experimental/rime/fused/transformers/core.py +++ b/africanus/experimental/rime/fused/transformers/core.py @@ -8,8 +8,9 @@ def sigcheck_factory(expected_sig): def check_transformer_sig(self, fn): sig = inspect.signature(fn) if sig != expected_sig: - raise ValueError(f"{fn.__name__}{sig} should be " - f"{fn.__name__}{expected_sig}") + raise ValueError( + f"{fn.__name__}{sig} should be " f"{fn.__name__}{expected_sig}" + ) return check_transformer_sig @@ -24,6 +25,7 @@ class TransformerMetaClass(type): class members on the subclass based on the above signatures """ + REQUIRED = ("dask_schema", "init_fields") @classmethod @@ -43,60 +45,78 @@ def _expand_namespace(cls, name, namespace): field_params = list(init_fields_sig.parameters.values()) if len(init_fields_sig.parameters) < 2: - raise InvalidSignature(f"{name}.init_fields{init_fields_sig} " - f"should be " - f"{name}.init_fields(self, typingctx, ...)") + raise InvalidSignature( + f"{name}.init_fields{init_fields_sig} " + f"should be " + f"{name}.init_fields(self, typingctx, ...)" + ) it = iter(init_fields_sig.parameters.items()) first, second = next(it), next(it) if first[0] != "self" or second[0] != "typingctx": - raise InvalidSignature(f"{name}.init_fields{init_fields_sig} " - f"should be " - f"{name}.init_fields(self, typingctx, ...)") + raise InvalidSignature( + f"{name}.init_fields{init_fields_sig} " + f"should be " + f"{name}.init_fields(self, typingctx, ...)" + ) for n, p in it: if p.kind == p.VAR_POSITIONAL: - raise InvalidSignature(f"*{n} in " - f"{name}.init_fields{init_fields_sig} " - f"is not supported") + raise InvalidSignature( + f"*{n} in " + f"{name}.init_fields{init_fields_sig} " + f"is not supported" + ) if p.kind == p.VAR_KEYWORD: - raise InvalidSignature(f"**{n} in " - f"{name}.init_fields{init_fields_sig} " - f"is not supported") + raise InvalidSignature( + f"**{n} in " + f"{name}.init_fields{init_fields_sig} " + f"is not supported" + ) dask_schema_sig = inspect.signature(methods["dask_schema"]) expected_dask_params = field_params[0:1] + field_params[2:] - expected_dask_sig = init_fields_sig.replace( - parameters=expected_dask_params) + expected_dask_sig = init_fields_sig.replace(parameters=expected_dask_params) if dask_schema_sig != expected_dask_sig: - raise InvalidSignature(f"{name}.dask_schema{dask_schema_sig} " - f"should be " - f"{name}.dask_schema{expected_dask_sig}") - - if not ("OUTPUTS" in namespace and - isinstance(namespace["OUTPUTS"], (tuple, list)) and - all(isinstance(o, str) for o in namespace["OUTPUTS"])): - - raise InvalidSignature(f"{name}.OUTPUTS should be a tuple " - f"of the names of the outputs produced " - f"by this transformer") + raise InvalidSignature( + f"{name}.dask_schema{dask_schema_sig} " + f"should be " + f"{name}.dask_schema{expected_dask_sig}" + ) + + if not ( + "OUTPUTS" in namespace + and isinstance(namespace["OUTPUTS"], (tuple, list)) + and all(isinstance(o, str) for o in namespace["OUTPUTS"]) + ): + raise InvalidSignature( + f"{name}.OUTPUTS should be a tuple " + f"of the names of the outputs produced " + f"by this transformer" + ) transform_sig = init_fields_sig.replace(parameters=field_params[2:]) namespace["OUTPUTS"] = tuple(namespace["OUTPUTS"]) - args = tuple(n for n, p in init_fields_sig.parameters.items() - if p.kind in {p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD} - and n not in {"self", "typingctx"} - and p.default is p.empty) - - kw = ((n, p.default) for n, p in init_fields_sig.parameters.items() - if p.kind in {p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY} - and n not in {"self", "typingctx"} - and p.default is not p.empty) + args = tuple( + n + for n, p in init_fields_sig.parameters.items() + if p.kind in {p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD} + and n not in {"self", "typingctx"} + and p.default is p.empty + ) + + kw = ( + (n, p.default) + for n, p in init_fields_sig.parameters.items() + if p.kind in {p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY} + and n not in {"self", "typingctx"} + and p.default is not p.empty + ) namespace = namespace.copy() namespace["ARGS"] = args @@ -108,7 +128,7 @@ def _expand_namespace(cls, name, namespace): @classmethod def transformer_in_bases(cls, bases): - """ Is `Transformer` in bases? """ + """Is `Transformer` in bases?""" for base in bases: if base is Transformer or cls.transformer_in_bases(base.__bases__): return True @@ -121,8 +141,7 @@ def __new__(mcls, name, bases, namespace): if mcls.transformer_in_bases(bases): namespace = mcls._expand_namespace(name, namespace) - return super(TransformerMetaClass, mcls).__new__( - mcls, name, bases, namespace) + return super(TransformerMetaClass, mcls).__new__(mcls, name, bases, namespace) class Transformer(metaclass=TransformerMetaClass): diff --git a/africanus/experimental/rime/fused/transformers/lm.py b/africanus/experimental/rime/fused/transformers/lm.py index 3fa0764d8..2b67c9a62 100644 --- a/africanus/experimental/rime/fused/transformers/lm.py +++ b/africanus/experimental/rime/fused/transformers/lm.py @@ -26,8 +26,8 @@ def lm(radec, phase_dir): sin_dec = np.sin(radec[s, 1]) cos_dec = np.cos(radec[s, 1]) - lm[s, 0] = cos_dec*sin_ra_delta - lm[s, 1] = sin_dec*cos_pc_dec - cos_dec*sin_pc_dec*cos_ra_delta + lm[s, 0] = cos_dec * sin_ra_delta + lm[s, 1] = sin_dec * cos_pc_dec - cos_dec * sin_pc_dec * cos_ra_delta return lm @@ -37,10 +37,7 @@ def dask_schema(self, radec, phase_dir): assert radec.ndim == 2 assert phase_dir.ndim == 1 - inputs = { - "radec": ("source", "radec"), - "phase_dir": ("radec",) - } + inputs = {"radec": ("source", "radec"), "phase_dir": ("radec",)} outputs = {"lm": np.empty((0, 0), dtype=radec.dtype)} diff --git a/africanus/experimental/rime/fused/transformers/parangle.py b/africanus/experimental/rime/fused/transformers/parangle.py index 44ced1918..64cb14489 100644 --- a/africanus/experimental/rime/fused/transformers/parangle.py +++ b/africanus/experimental/rime/fused/transformers/parangle.py @@ -12,38 +12,46 @@ class ParallacticTransformer(Transformer): def __init__(self, process_pool): self.pool = process_pool - def init_fields(self, typingctx, - utime, ufeed, uantenna, - antenna_position, phase_dir, - receptor_angle=None): - dt = typingctx.unify_types(utime.dtype, ufeed.dtype, - antenna_position.dtype, - phase_dir.dtype) + def init_fields( + self, + typingctx, + utime, + ufeed, + uantenna, + antenna_position, + phase_dir, + receptor_angle=None, + ): + dt = typingctx.unify_types( + utime.dtype, ufeed.dtype, antenna_position.dtype, phase_dir.dtype + ) fields = [ ("feed_parangle", dt[:, :, :, :, :]), - ("beam_parangle", dt[:, :, :, :])] + ("beam_parangle", dt[:, :, :, :]), + ] parangle_dt = types.Array(types.float64, 2, "C") have_ra = not cgutils.is_nonelike(receptor_angle) - if have_ra and (not isinstance(receptor_angle, types.Array) or - receptor_angle.ndim != 2): + if have_ra and ( + not isinstance(receptor_angle, types.Array) or receptor_angle.ndim != 2 + ): raise errors.TypingError("receptor_angle must be a 2D array") @njit(inline="never") def parangle_stub(time, antenna, phase_dir): with objmode(out=parangle_dt): - out = self.pool.apply(casa_parallactic_angles, - (time, antenna, phase_dir)) + out = self.pool.apply( + casa_parallactic_angles, (time, antenna, phase_dir) + ) return out - def parangles(utime, ufeed, uantenna, - antenna_position, phase_dir, - receptor_angle=None): - - ntime, = utime.shape - nant, = uantenna.shape - nfeed, = ufeed.shape + def parangles( + utime, ufeed, uantenna, antenna_position, phase_dir, receptor_angle=None + ): + (ntime,) = utime.shape + (nant,) = uantenna.shape + (nfeed,) = ufeed.shape # Select out the antennae we're interested in antenna_position = antenna_position[uantenna] @@ -57,8 +65,7 @@ def parangles(utime, ufeed, uantenna, raise ValueError("receptor_angle.ndim != 2") if receptor_angle.shape[1] != 2: - raise ValueError("Only 2 receptor angles " - "currently supported") + raise ValueError("Only 2 receptor angles " "currently supported") # Select out the feeds we're interested in receptor_angle = receptor_angle[ufeed, :] @@ -87,13 +94,11 @@ def parangles(utime, ufeed, uantenna, return fields, parangles - def dask_schema(self, utime, ufeed, uantenna, - antenna_position, phase_dir, - receptor_angle=None): - dt = np.result_type(utime, ufeed, antenna_position, - phase_dir, receptor_angle) - inputs = {"antenna_position": ("antenna", "ant-comp"), - "phase_dir": ("radec",)} + def dask_schema( + self, utime, ufeed, uantenna, antenna_position, phase_dir, receptor_angle=None + ): + dt = np.result_type(utime, ufeed, antenna_position, phase_dir, receptor_angle) + inputs = {"antenna_position": ("antenna", "ant-comp"), "phase_dir": ("radec",)} if receptor_angle is not None: inputs["receptor_angle"] = ("feed", "receptor_angle") @@ -101,8 +106,8 @@ def dask_schema(self, utime, ufeed, uantenna, inputs["receptor_angle"] = None outputs = { - "feed_parangle": np.empty((0,)*5, dt), - "beam_parangle": np.empty((0,)*4, dt) + "feed_parangle": np.empty((0,) * 5, dt), + "beam_parangle": np.empty((0,) * 4, dt), } return inputs, outputs diff --git a/africanus/gps/examples/generate_phase_only_gains.py b/africanus/gps/examples/generate_phase_only_gains.py index b2f2020b4..7180685c9 100755 --- a/africanus/gps/examples/generate_phase_only_gains.py +++ b/africanus/gps/examples/generate_phase_only_gains.py @@ -28,9 +28,9 @@ def create_parser(): # get times and normalise ms = table(args.ms) -time = ms.getcol('TIME') -ant1 = ms.getcol('ANTENNA1') -ant2 = ms.getcol('ANTENNA2') +time = ms.getcol("TIME") +ant1 = ms.getcol("ANTENNA1") +ant2 = ms.getcol("ANTENNA2") n_ant = int(np.maximum(ant1.max(), ant2.max()) + 1) ms.close() time = np.unique(time) @@ -42,8 +42,8 @@ def create_parser(): # get freqs and normalise -spw = table(args.ms + '::SPECTRAL_WINDOW') -freq = spw.getcol('CHAN_FREQ').squeeze() +spw = table(args.ms + "::SPECTRAL_WINDOW") +freq = spw.getcol("CHAN_FREQ").squeeze() spw.close() freq -= freq.min() freq /= freq.max() @@ -73,7 +73,7 @@ def create_parser(): gains[:, p, :, :, 1] = samp # convert to gains -gains = np.exp(1.0j*gains) +gains = np.exp(1.0j * gains) # save result np.save(args.gain_file, gains) diff --git a/africanus/gps/kernels.py b/africanus/gps/kernels.py index 11191b83b..6e3ec1291 100644 --- a/africanus/gps/kernels.py +++ b/africanus/gps/kernels.py @@ -41,7 +41,7 @@ def exponential_squared(x, xp, sigmaf, l, pspec=False): # noqa: E741 if (x[1::] - x[0:-1] != delx).any(): raise ValueError("pspec only defined on regular grid") s = np.fft.fftshift(np.fft.fftfreq(N, d=delx)) - return np.sqrt(2*np.pi*l)*sigmaf**2.0*np.exp(-l**2*s**2/2.0) + return np.sqrt(2 * np.pi * l) * sigmaf**2.0 * np.exp(-(l**2) * s**2 / 2.0) else: xxp = abs_diff(x, xp) - return sigmaf**2*np.exp(-xxp**2/(2*l**2)) + return sigmaf**2 * np.exp(-(xxp**2) / (2 * l**2)) diff --git a/africanus/gridding/nifty/dask.py b/africanus/gridding/nifty/dask.py index a3e0ba8f9..8bec69b04 100644 --- a/africanus/gridding/nifty/dask.py +++ b/africanus/gridding/nifty/dask.py @@ -21,9 +21,11 @@ try: import nifty_gridder as ng except ImportError: - nifty_import_err = ImportError("Please manually install nifty_gridder " - "from https://gitlab.mpcdf.mpg.de/ift/" - "nifty_gridder.git") + nifty_import_err = ImportError( + "Please manually install nifty_gridder " + "from https://gitlab.mpcdf.mpg.de/ift/" + "nifty_gridder.git" + ) else: nifty_import_err = None @@ -35,27 +37,24 @@ class GridderConfigWrapper(object): Wraps a nifty GridderConfiguration for pickling purposes. """ - def __init__(self, nx=1024, ny=1024, eps=2e-13, - cell_size_x=2.0, cell_size_y=2.0): + def __init__(self, nx=1024, ny=1024, eps=2e-13, cell_size_x=2.0, cell_size_y=2.0): self.nx = nx self.ny = ny self.csx = cell_size_x self.csy = cell_size_y self.eps = eps - self.grid_config = ng.GridderConfig(nx, ny, eps, - cell_size_x, - cell_size_y) + self.grid_config = ng.GridderConfig(nx, ny, eps, cell_size_x, cell_size_y) @property def object(self): return self.grid_config def __reduce__(self): - return (GridderConfigWrapper, - (self.nx, self.ny, self.eps, self.csx, self.csy)) + return (GridderConfigWrapper, (self.nx, self.ny, self.eps, self.csx, self.csy)) if import_error is None: + @normalize_token.register(GridderConfigWrapper) def normalize_gridder_config_wrapper(gc): return normalize_token((gc.nx, gc.ny, gc.csx, gc.csy, gc.eps)) @@ -91,30 +90,29 @@ def grid_config(nx=1024, ny=1024, eps=2e-13, cell_size_x=2.0, cell_size_y=2.0): def _nifty_baselines(uvw, chan_freq): - """ Wrapper function for creating baseline mappings per row chunk """ + """Wrapper function for creating baseline mappings per row chunk""" assert len(chan_freq) == 1, "Handle multiple channel chunks" return ng.Baselines(uvw[0], chan_freq[0]) -def _nifty_indices(baselines, grid_config, flag, - chan_begin, chan_end, wmin, wmax): - """ Wrapper function for creating indices per row chunk """ - return ng.getIndices(baselines, grid_config, flag[0], - chan_begin, chan_end, wmin, wmax) +def _nifty_indices(baselines, grid_config, flag, chan_begin, chan_end, wmin, wmax): + """Wrapper function for creating indices per row chunk""" + return ng.getIndices( + baselines, grid_config, flag[0], chan_begin, chan_end, wmin, wmax + ) def _nifty_grid(baselines, grid_config, indices, vis, weights): - """ Wrapper function for creating a grid of visibilities per row chunk """ + """Wrapper function for creating a grid of visibilities per row chunk""" assert len(vis) == 1 and type(vis) is list - return ng.ms2grid_c(baselines, grid_config, indices, - vis[0], None, weights[0])[None, :, :] + return ng.ms2grid_c(baselines, grid_config, indices, vis[0], None, weights[0])[ + None, :, : + ] -def _nifty_grid_streams(baselines, grid_config, indices, - vis, weights, grid_in=None): - """ Wrapper function for creating a grid of visibilities per row chunk """ - return ng.ms2grid_c(baselines, grid_config, indices, - vis, grid_in, weights) +def _nifty_grid_streams(baselines, grid_config, indices, vis, weights, grid_in=None): + """Wrapper function for creating a grid of visibilities per row chunk""" + return ng.ms2grid_c(baselines, grid_config, indices, vis, grid_in, weights) class GridStreamReduction(Mapping): @@ -129,12 +127,10 @@ class GridStreamReduction(Mapping): ``stream`` parallel streams. """ - def __init__(self, baselines, indices, gc, - corr_vis, corr_weights, - corr, streams): - token = dask.base.tokenize(baselines, indices, gc, - corr_vis, corr_weights, - corr, streams) + def __init__(self, baselines, indices, gc, corr_vis, corr_weights, corr, streams): + token = dask.base.tokenize( + baselines, indices, gc, corr_vis, corr_weights, corr, streams + ) self.name = "-".join(("nifty-grid-stream", str(corr), token)) self.bl_name = baselines.name self.idx_name = indices.name @@ -189,14 +185,16 @@ def _create_dict(self): last_key = None for rb in range(rb_start, rb_end): - fn = (_nifty_grid_streams, - (baselines_name, rb), - gc, - (indices_name, rb), - (corr_vis_name, rb, cb), - (corr_wgt_name, rb, cb), - # Re-use grid from last operation if present - last_key) + fn = ( + _nifty_grid_streams, + (baselines_name, rb), + gc, + (indices_name, rb), + (corr_vis_name, rb, cb), + (corr_wgt_name, rb, cb), + # Re-use grid from last operation if present + last_key, + ) key = (name, rb, cb) layers[key] = fn @@ -269,8 +267,17 @@ def _create_dict(self): @requires_optional("dask.array", import_error) @requires_optional("nifty_gridder", nifty_import_err) -def grid(vis, uvw, flags, weights, frequencies, grid_config, - wmin=-1e30, wmax=1e30, streams=None): +def grid( + vis, + uvw, + flags, + weights, + frequencies, + grid_config, + wmin=-1e30, + wmax=1e30, + streams=None, +): """ Grids the supplied visibilities in parallel. Note that a grid is create for each visibility chunk. @@ -307,10 +314,15 @@ def grid(vis, uvw, flags, weights, frequencies, grid_config, raise ValueError("Chunking in channel currently unsupported") # Create a baseline object per row chunk - baselines = da.blockwise(_nifty_baselines, ("row",), - uvw, ("row", "uvw"), - frequencies, ("chan",), - dtype=object) + baselines = da.blockwise( + _nifty_baselines, + ("row",), + uvw, + ("row", "uvw"), + frequencies, + ("chan",), + dtype=object, + ) if len(frequencies.chunks[0]) != 1: raise ValueError("Chunking in channel unsupported") @@ -323,36 +335,54 @@ def grid(vis, uvw, flags, weights, frequencies, grid_config, corr_vis = vis[:, :, corr] corr_weights = weights[:, :, corr] - indices = da.blockwise(_nifty_indices, ("row",), - baselines, ("row",), - gc, None, - corr_flags, ("row", "chan"), - -1, None, # channel begin - -1, None, # channel end - wmin, None, - wmax, None, - dtype=np.int32) + indices = da.blockwise( + _nifty_indices, + ("row",), + baselines, + ("row",), + gc, + None, + corr_flags, + ("row", "chan"), + -1, + None, # channel begin + -1, + None, # channel end + wmin, + None, + wmax, + None, + dtype=np.int32, + ) if streams is None: # Standard parallel reduction, possibly memory hungry # if many threads (and thus grids) are gridding # parallel - grid = da.blockwise(_nifty_grid, ("row", "nu", "nv"), - baselines, ("row",), - gc, None, - indices, ("row",), - corr_vis, ("row", "chan"), - corr_weights, ("row", "chan"), - new_axes={"nu": gc.Nu(), "nv": gc.Nv()}, - adjust_chunks={"row": 1}, - dtype=np.complex128) + grid = da.blockwise( + _nifty_grid, + ("row", "nu", "nv"), + baselines, + ("row",), + gc, + None, + indices, + ("row",), + corr_vis, + ("row", "chan"), + corr_weights, + ("row", "chan"), + new_axes={"nu": gc.Nu(), "nv": gc.Nv()}, + adjust_chunks={"row": 1}, + dtype=np.complex128, + ) grids.append(grid.sum(axis=0)) else: # Stream reduction - layers = GridStreamReduction(baselines, indices, gc, - corr_vis, corr_weights, - corr, streams) + layers = GridStreamReduction( + baselines, indices, gc, corr_vis, corr_weights, corr, streams + ) deps = [baselines, indices, corr_vis, corr_weights] graph = HighLevelGraph.from_collections(layers.name, layers, deps) chunks = corr_vis.chunks @@ -370,9 +400,8 @@ def grid(vis, uvw, flags, weights, frequencies, grid_config, def _nifty_dirty(grid, grid_config): - """ Wrapper function for creating a dirty image """ - grids = [grid_config.grid2dirty_c(grid[:, :, c]).real - for c in range(grid.shape[2])] + """Wrapper function for creating a dirty image""" + grids = [grid_config.grid2dirty_c(grid[:, :, c]).real for c in range(grid.shape[2])] return np.stack(grids, axis=2) @@ -401,17 +430,21 @@ def dirty(grid, grid_config): nx = gc.Nxdirty() ny = gc.Nydirty() - return da.blockwise(_nifty_dirty, ("nx", "ny", "corr"), - grid, ("nx", "ny", "corr"), - gc, None, - adjust_chunks={"nx": nx, "ny": ny}, - dtype=grid.real.dtype) + return da.blockwise( + _nifty_dirty, + ("nx", "ny", "corr"), + grid, + ("nx", "ny", "corr"), + gc, + None, + adjust_chunks={"nx": nx, "ny": ny}, + dtype=grid.real.dtype, + ) def _nifty_model(image, grid_config): - """ Wrapper function for creating a dirty image """ - images = [grid_config.dirty2grid_c(image[:, :, c]) - for c in range(image.shape[2])] + """Wrapper function for creating a dirty image""" + images = [grid_config.dirty2grid_c(image[:, :, c]) for c in range(image.shape[2])] return np.stack(images, axis=2) @@ -440,11 +473,16 @@ def model(image, grid_config): nu = gc.Nu() nv = gc.Nv() - return da.blockwise(_nifty_model, ("nu", "nv", "corr"), - image, ("nu", "nv", "corr"), - gc, None, - adjust_chunks={"nu": nu, "nv": nv}, - dtype=image.dtype) + return da.blockwise( + _nifty_model, + ("nu", "nv", "corr"), + image, + ("nu", "nv", "corr"), + gc, + None, + adjust_chunks={"nu": nu, "nv": nv}, + dtype=image.dtype, + ) def _nifty_degrid(grid, baselines, indices, grid_config): @@ -454,8 +492,7 @@ def _nifty_degrid(grid, baselines, indices, grid_config): @requires_optional("dask.array", import_error) @requires_optional("nifty_gridder", nifty_import_err) -def degrid(grid, uvw, flags, weights, frequencies, - grid_config, wmin=-1e30, wmax=1e30): +def degrid(grid, uvw, flags, weights, frequencies, grid_config, wmin=-1e30, wmax=1e30): """ Degrids the visibilities from the supplied grid in parallel. @@ -489,10 +526,15 @@ def degrid(grid, uvw, flags, weights, frequencies, raise ValueError("Chunking in channel currently unsupported") # Create a baseline object per row chunk - baselines = da.blockwise(_nifty_baselines, ("row",), - uvw, ("row", "uvw"), - frequencies, ("chan",), - dtype=object) + baselines = da.blockwise( + _nifty_baselines, + ("row",), + uvw, + ("row", "uvw"), + frequencies, + ("chan",), + dtype=object, + ) gc = grid_config.object vis_chunks = [] @@ -501,23 +543,40 @@ def degrid(grid, uvw, flags, weights, frequencies, corr_flags = flags[:, :, corr].map_blocks(np.require, requirements="C") corr_grid = grid[:, :, corr].map_blocks(np.require, requirements="C") - indices = da.blockwise(_nifty_indices, ("row",), - baselines, ("row",), - gc, None, - corr_flags, ("row", "chan"), - -1, None, # channel begin - -1, None, # channel end - wmin, None, - wmax, None, - dtype=np.int32) - - vis = da.blockwise(_nifty_degrid, ("row", "chan"), - corr_grid, ("ny", "nx"), - baselines, ("row",), - indices, ("row",), - grid_config, None, - new_axes={"chan": frequencies.shape[0]}, - dtype=grid.dtype) + indices = da.blockwise( + _nifty_indices, + ("row",), + baselines, + ("row",), + gc, + None, + corr_flags, + ("row", "chan"), + -1, + None, # channel begin + -1, + None, # channel end + wmin, + None, + wmax, + None, + dtype=np.int32, + ) + + vis = da.blockwise( + _nifty_degrid, + ("row", "chan"), + corr_grid, + ("ny", "nx"), + baselines, + ("row",), + indices, + ("row",), + grid_config, + None, + new_axes={"chan": frequencies.shape[0]}, + dtype=grid.dtype, + ) vis_chunks.append(vis) diff --git a/africanus/gridding/nifty/tests/test_nifty_gridder.py b/africanus/gridding/nifty/tests/test_nifty_gridder.py index 097696136..997d69156 100644 --- a/africanus/gridding/nifty/tests/test_nifty_gridder.py +++ b/africanus/gridding/nifty/tests/test_nifty_gridder.py @@ -6,8 +6,7 @@ import pickle import pytest -from africanus.gridding.nifty.dask import (grid, degrid, dirty, model, - grid_config) +from africanus.gridding.nifty.dask import grid, degrid, dirty, model, grid_config def rf(*a, **kw): @@ -15,16 +14,16 @@ def rf(*a, **kw): def rc(*a, **kw): - return rf(*a, **kw) + 1j*rf(*a, **kw) + return rf(*a, **kw) + 1j * rf(*a, **kw) def test_dask_nifty_gridder(): - """ Only tests that we can call it and create a dirty image """ - dask = pytest.importorskip('dask') - da = pytest.importorskip('dask.array') - _ = pytest.importorskip('nifty_gridder') + """Only tests that we can call it and create a dirty image""" + dask = pytest.importorskip("dask") + da = pytest.importorskip("dask.array") + _ = pytest.importorskip("nifty_gridder") - row = (16,)*8 + row = (16,) * 8 chan = (32,) corr = (4,) nx = 1026 @@ -35,9 +34,9 @@ def test_dask_nifty_gridder(): ncorr = sum(corr) # Random UV data - uvw = rf(size=(nrow, 3)).astype(np.float64)*128 + uvw = rf(size=(nrow, 3)).astype(np.float64) * 128 vis = rf(size=(nrow, nchan, ncorr)).astype(np.complex128) - freq = np.linspace(.856e9, 2*.856e9, nchan) + freq = np.linspace(0.856e9, 2 * 0.856e9, nchan) flag = np.zeros(vis.shape, dtype=np.uint8) flag = np.random.randint(0, 2, vis.shape, dtype=np.uint8).astype(np.bool_) weight = rf(vis.shape).astype(np.float64) @@ -81,9 +80,9 @@ def test_dask_nifty_gridder(): def test_dask_nifty_degridder(): - """ Only tests that we can call it and create some visibilities """ - da = pytest.importorskip('dask.array') - _ = pytest.importorskip('nifty_gridder') + """Only tests that we can call it and create some visibilities""" + da = pytest.importorskip("dask.array") + _ = pytest.importorskip("nifty_gridder") row = (16, 16, 16, 16) chan = (32,) @@ -98,8 +97,8 @@ def test_dask_nifty_degridder(): gc = grid_config(nx, ny, 2e-13, 2.0, 2.0) # Random UV data - uvw = rf(size=(nrow, 3)).astype(np.float64)*128 - freq = np.linspace(.856e9, 2*.856e9, nchan) + uvw = rf(size=(nrow, 3)).astype(np.float64) * 128 + freq = np.linspace(0.856e9, 2 * 0.856e9, nchan) flag = np.zeros((nrow, nchan, ncorr), dtype=np.bool_) weight = np.ones((nrow, nchan, ncorr), dtype=np.float64) image = rc(size=(nx, ny, ncorr)).astype(np.complex128) diff --git a/africanus/gridding/perleypolyhedron/dask.py b/africanus/gridding/perleypolyhedron/dask.py index 52a3116db..2721788ab 100644 --- a/africanus/gridding/perleypolyhedron/dask.py +++ b/africanus/gridding/perleypolyhedron/dask.py @@ -7,33 +7,33 @@ else: opt_import_err = None -from africanus.gridding.perleypolyhedron.gridder import ( - gridder as np_gridder) +from africanus.gridding.perleypolyhedron.gridder import gridder as np_gridder +from africanus.gridding.perleypolyhedron.degridder import degridder as np_degridder from africanus.gridding.perleypolyhedron.degridder import ( - degridder as np_degridder) -from africanus.gridding.perleypolyhedron.degridder import ( - degridder_serial as np_degridder_serial) -from africanus.gridding.perleypolyhedron.policies import ( - stokes_conversion_policies) + degridder_serial as np_degridder_serial, +) +from africanus.gridding.perleypolyhedron.policies import stokes_conversion_policies from africanus.util.requirements import requires_optional -def __degrid(uvw, - gridstack, - lambdas, - chanmap, - image_centres, - phase_centre, - cell=None, - convolution_kernel=None, - convolution_kernel_width=None, - convolution_kernel_oversampling=None, - baseline_transform_policy=None, - phase_transform_policy=None, - stokes_conversion_policy=None, - convolution_policy=None, - vis_dtype=np.complex128, - rowparallel=False): +def __degrid( + uvw, + gridstack, + lambdas, + chanmap, + image_centres, + phase_centre, + cell=None, + convolution_kernel=None, + convolution_kernel_width=None, + convolution_kernel_oversampling=None, + baseline_transform_policy=None, + phase_transform_policy=None, + stokes_conversion_policy=None, + convolution_policy=None, + vis_dtype=np.complex128, + rowparallel=False, +): image_centres = image_centres[0][0] if image_centres.ndim != 2: raise ValueError( @@ -51,52 +51,55 @@ def __degrid(uvw, lambdas = lambdas chanmap = chanmap if chanmap.size != lambdas.size: - raise ValueError( - "Chanmap and corresponding lambdas must match in shape") + raise ValueError("Chanmap and corresponding lambdas must match in shape") nchan = lambdas.size nrow = uvw.shape[0] ncorr = stokes_conversion_policies.ncorr_outpy( - policy_type=stokes_conversion_policy)() + policy_type=stokes_conversion_policy + )() vis = np.zeros((nrow, nchan, ncorr), dtype=vis_dtype) degridcall = np_degridder_serial if not rowparallel else np_degridder for fi, f in enumerate(image_centres): # add contributions from all facets - vis[:, :, :] += \ - degridcall(uvw, - gridstack[fi, :, :, :], - lambdas, - chanmap, - cell, - image_centres, - phase_centre, - convolution_kernel, - convolution_kernel_width, - convolution_kernel_oversampling, - baseline_transform_policy, - phase_transform_policy, - stokes_conversion_policy, - convolution_policy, - vis_dtype=vis_dtype) + vis[:, :, :] += degridcall( + uvw, + gridstack[fi, :, :, :], + lambdas, + chanmap, + cell, + image_centres, + phase_centre, + convolution_kernel, + convolution_kernel_width, + convolution_kernel_oversampling, + baseline_transform_policy, + phase_transform_policy, + stokes_conversion_policy, + convolution_policy, + vis_dtype=vis_dtype, + ) return vis @requires_optional("dask", opt_import_err) -def degridder(uvw, - gridstack, - lambdas, - chanmap, - cell, - image_centres, - phase_centre, - convolution_kernel, - convolution_kernel_width, - convolution_kernel_oversampling, - baseline_transform_policy, - phase_transform_policy, - stokes_conversion_policy, - convolution_policy, - vis_dtype=np.complex128, - rowparallel=False): +def degridder( + uvw, + gridstack, + lambdas, + chanmap, + cell, + image_centres, + phase_centre, + convolution_kernel, + convolution_kernel_width, + convolution_kernel_oversampling, + baseline_transform_policy, + phase_transform_policy, + stokes_conversion_policy, + convolution_policy, + vis_dtype=np.complex128, + rowparallel=False, +): """ 2D Convolutional degridder, discrete to contiguous @uvw: value coordinates, (nrow, 3) @@ -147,12 +150,18 @@ def degridder(uvw, "of image centres" ) vis = da.blockwise( - __degrid, ("row", "chan", "corr"), - uvw, ("row", "uvw"), - gridstack, ("nfacet", "nband", "y", "x"), - lambdas, ("chan", ), - chanmap, ("chan", ), - image_centres, ("nfacet", "coord"), + __degrid, + ("row", "chan", "corr"), + uvw, + ("row", "uvw"), + gridstack, + ("nfacet", "nband", "y", "x"), + lambdas, + ("chan",), + chanmap, + ("chan",), + image_centres, + ("nfacet", "coord"), convolution_kernel=convolution_kernel, convolution_kernel_width=convolution_kernel_width, convolution_kernel_oversampling=convolution_kernel_oversampling, @@ -164,35 +173,37 @@ def degridder(uvw, phase_centre=phase_centre, vis_dtype=vis_dtype, new_axes={ - "corr": - stokes_conversion_policies.ncorr_outpy( - policy_type=stokes_conversion_policy)() + "corr": stokes_conversion_policies.ncorr_outpy( + policy_type=stokes_conversion_policy + )() }, dtype=vis_dtype, meta=np.empty( - (0, 0, 0), - dtype=vis_dtype) # row, chan, correlation product as per MSv2 spec + (0, 0, 0), dtype=vis_dtype + ), # row, chan, correlation product as per MSv2 spec ) return vis -def __grid(uvw, - vis, - image_centres, - lambdas=None, - chanmap=None, - convolution_kernel=None, - convolution_kernel_width=None, - convolution_kernel_oversampling=None, - baseline_transform_policy=None, - phase_transform_policy=None, - stokes_conversion_policy=None, - convolution_policy=None, - npix=None, - cell=None, - phase_centre=None, - grid_dtype=np.complex128, - do_normalize=False): +def __grid( + uvw, + vis, + image_centres, + lambdas=None, + chanmap=None, + convolution_kernel=None, + convolution_kernel_width=None, + convolution_kernel_oversampling=None, + baseline_transform_policy=None, + phase_transform_policy=None, + stokes_conversion_policy=None, + convolution_policy=None, + npix=None, + cell=None, + phase_centre=None, + grid_dtype=np.complex128, + do_normalize=False, +): image_centres = image_centres[0] if image_centres.ndim != 2: raise ValueError( @@ -208,26 +219,17 @@ def __grid(uvw, chanmap = chanmap[0] grid_stack = np.zeros( (1, image_centres.shape[0], 1, np.max(chanmap) + 1, npix, npix), - dtype=grid_dtype) + dtype=grid_dtype, + ) for fi, f in enumerate(image_centres): - grid_stack[0, fi, 0, :, :, :] = \ - np_gridder(uvw, vis, lambdas, chanmap, npix, cell, f, phase_centre, - convolution_kernel, convolution_kernel_width, - convolution_kernel_oversampling, - baseline_transform_policy, phase_transform_policy, - stokes_conversion_policy, - convolution_policy, grid_dtype, do_normalize) - return grid_stack - - -@requires_optional("dask", opt_import_err) -def gridder(uvw, + grid_stack[0, fi, 0, :, :, :] = np_gridder( + uvw, vis, lambdas, chanmap, npix, cell, - image_centres, + f, phase_centre, convolution_kernel, convolution_kernel_width, @@ -236,8 +238,32 @@ def gridder(uvw, phase_transform_policy, stokes_conversion_policy, convolution_policy, - grid_dtype=np.complex128, - do_normalize=False): + grid_dtype, + do_normalize, + ) + return grid_stack + + +@requires_optional("dask", opt_import_err) +def gridder( + uvw, + vis, + lambdas, + chanmap, + npix, + cell, + image_centres, + phase_centre, + convolution_kernel, + convolution_kernel_width, + convolution_kernel_oversampling, + baseline_transform_policy, + phase_transform_policy, + stokes_conversion_policy, + convolution_policy, + grid_dtype=np.complex128, + do_normalize=False, +): """ 2D Convolutional gridder, contiguous to discrete @uvw: value coordinates, (nrow, 3) @@ -269,25 +295,27 @@ def gridder(uvw, """ if len(vis.chunks) != 3 or lambdas.chunks[0] != vis.chunks[1]: raise ValueError( - "Visibility frequency chunking does not match " - "lambda frequency chunking" + "Visibility frequency chunking does not match " "lambda frequency chunking" ) if len(vis.chunks) != 3 or chanmap.chunks[0] != vis.chunks[1]: raise ValueError( - "Visibility frequency chunking does not match chanmap " - "frequency chunking" + "Visibility frequency chunking does not match chanmap " "frequency chunking" ) - if len(vis.chunks) != 3 or len( - uvw.chunks) != 2 or vis.chunks[0] != uvw.chunks[0]: - raise ValueError( - "Visibility row chunking does not match uvw row chunking") + if len(vis.chunks) != 3 or len(uvw.chunks) != 2 or vis.chunks[0] != uvw.chunks[0]: + raise ValueError("Visibility row chunking does not match uvw row chunking") grids = da.blockwise( - __grid, ("row", "nfacet", "nstokes", "nband", "y", "x"), - uvw, ("row", "uvw"), - vis, ("row", "chan", "corr"), - image_centres, ("nfacet", "coord"), - lambdas, ("chan", ), - chanmap, ("chan", ), + __grid, + ("row", "nfacet", "nstokes", "nband", "y", "x"), + uvw, + ("row", "uvw"), + vis, + ("row", "chan", "corr"), + image_centres, + ("nfacet", "coord"), + lambdas, + ("chan",), + chanmap, + ("chan",), convolution_kernel=convolution_kernel, convolution_kernel_width=convolution_kernel_width, convolution_kernel_oversampling=convolution_kernel_oversampling, @@ -308,10 +336,11 @@ def gridder(uvw, # multi-stokes cubes are supported "nstokes": 1, "y": npix, - "x": npix + "x": npix, }, dtype=grid_dtype, - meta=np.empty((0, 0, 0, 0, 0, 0), dtype=grid_dtype)) + meta=np.empty((0, 0, 0, 0, 0, 0), dtype=grid_dtype), + ) # Parallel reduction over row dimension return grids.mean(axis=0, split_every=2) diff --git a/africanus/gridding/perleypolyhedron/degridder.py b/africanus/gridding/perleypolyhedron/degridder.py index d8d35cbf1..4c1703271 100644 --- a/africanus/gridding/perleypolyhedron/degridder.py +++ b/africanus/gridding/perleypolyhedron/degridder.py @@ -3,88 +3,95 @@ from africanus.util.numba import jit from africanus.gridding.perleypolyhedron.policies import ( - baseline_transform_policies as btp) + baseline_transform_policies as btp, +) +from africanus.gridding.perleypolyhedron.policies import phase_transform_policies as ptp +from africanus.gridding.perleypolyhedron.policies import convolution_policies as cp from africanus.gridding.perleypolyhedron.policies import ( - phase_transform_policies as ptp) -from africanus.gridding.perleypolyhedron.policies import ( - convolution_policies as cp) -from africanus.gridding.perleypolyhedron.policies import ( - stokes_conversion_policies as scp) + stokes_conversion_policies as scp, +) @jit(nopython=True, nogil=True, fastmath=True, inline="always") -def degridder_row_kernel(uvw, - gridstack, - wavelengths, - chanmap, - cell, - image_centre, - phase_centre, - convolution_kernel, - convolution_kernel_width, - convolution_kernel_oversampling, - baseline_transform_policy, - phase_transform_policy, - stokes_conversion_policy, - convolution_policy, - vis_dtype=np.complex128, - nband=0, - nrow=0, - npix=0, - nvischan=0, - ncorr=0, - vis=None, - scale_factor=0, - r=0): +def degridder_row_kernel( + uvw, + gridstack, + wavelengths, + chanmap, + cell, + image_centre, + phase_centre, + convolution_kernel, + convolution_kernel_width, + convolution_kernel_oversampling, + baseline_transform_policy, + phase_transform_policy, + stokes_conversion_policy, + convolution_policy, + vis_dtype=np.complex128, + nband=0, + nrow=0, + npix=0, + nvischan=0, + ncorr=0, + vis=None, + scale_factor=0, + r=0, +): ra0, dec0 = phase_centre ra, dec = image_centre - btp.policy(uvw[r, :], ra, dec, ra0, dec0, - baseline_transform_policy) + btp.policy(uvw[r, :], ra, dec, ra0, dec0, baseline_transform_policy) for c in range(nvischan): scaled_u = uvw[r, 0] * scale_factor / wavelengths[c] scaled_v = uvw[r, 1] * scale_factor / wavelengths[c] scaled_w = uvw[r, 2] * scale_factor / wavelengths[c] grid = gridstack[chanmap[c], :, :] - cp.policy(scaled_u, - scaled_v, - scaled_w, - npix, - grid, - vis, - r, - c, - convolution_kernel, - convolution_kernel_width, - convolution_kernel_oversampling, - stokes_conversion_policy, - policy_type=convolution_policy) - ptp.policy(vis[r, :, :], - uvw[r, :], - wavelengths, - ra0, - dec0, - ra, - dec, - policy_type=phase_transform_policy, - phasesign=-1.0) + cp.policy( + scaled_u, + scaled_v, + scaled_w, + npix, + grid, + vis, + r, + c, + convolution_kernel, + convolution_kernel_width, + convolution_kernel_oversampling, + stokes_conversion_policy, + policy_type=convolution_policy, + ) + ptp.policy( + vis[r, :, :], + uvw[r, :], + wavelengths, + ra0, + dec0, + ra, + dec, + policy_type=phase_transform_policy, + phasesign=-1.0, + ) @jit(nopython=True, nogil=True, fastmath=True, parallel=True) -def degridder(uvw, - gridstack, - wavelengths, - chanmap, - cell, - image_centre, - phase_centre, - convolution_kernel, - convolution_kernel_width, - convolution_kernel_oversampling, - baseline_transform_policy, - phase_transform_policy, - stokes_conversion_policy, - convolution_policy, - vis_dtype=np.complex128): +def degridder( + uvw, + gridstack, + wavelengths, + chanmap, + cell, + image_centre, + phase_centre, + convolution_kernel, + convolution_kernel_width, + convolution_kernel_oversampling, + baseline_transform_policy, + phase_transform_policy, + stokes_conversion_policy, + convolution_policy, + vis_dtype=np.complex128, +): """ 2D Convolutional degridder, discrete to contiguous @uvw: value coordinates, (nrow, 3) @@ -114,8 +121,7 @@ def degridder(uvw, """ if chanmap.size != wavelengths.size: - raise ValueError( - "Chanmap and corresponding wavelengths must match in shape") + raise ValueError("Chanmap and corresponding wavelengths must match in shape") chanmap = chanmap.ravel() wavelengths = wavelengths.ravel() nband = np.max(chanmap) + 1 @@ -127,12 +133,12 @@ def degridder(uvw, ncorr = scp.ncorr_out(policy_type=literally(stokes_conversion_policy)) if gridstack.shape[0] < nband: raise ValueError( - "Not enough channel bands in grid stack to match mfs band mapping") + "Not enough channel bands in grid stack to match mfs band mapping" + ) if uvw.shape[1] != 3: raise ValueError("UVW array must be array of tripples") if uvw.shape[0] != nrow: - raise ValueError( - "UVW array must have same number of rows as vis array") + raise ValueError("UVW array must have same number of rows as vis array") if nvischan != wavelengths.size: raise ValueError("Chanmap must correspond to visibility channels") @@ -141,48 +147,52 @@ def degridder(uvw, # scale the FOV using the simularity theorem scale_factor = npix * cell / 3600.0 * np.pi / 180.0 for r in prange(nrow): - degridder_row_kernel(uvw, - gridstack, - wavelengths, - chanmap, - cell, - image_centre, - phase_centre, - convolution_kernel, - convolution_kernel_width, - convolution_kernel_oversampling, - literally(baseline_transform_policy), - literally(phase_transform_policy), - literally(stokes_conversion_policy), - literally(convolution_policy), - vis_dtype=vis_dtype, - nband=nband, - nrow=nrow, - npix=npix, - nvischan=nvischan, - ncorr=ncorr, - vis=vis, - scale_factor=scale_factor, - r=r) + degridder_row_kernel( + uvw, + gridstack, + wavelengths, + chanmap, + cell, + image_centre, + phase_centre, + convolution_kernel, + convolution_kernel_width, + convolution_kernel_oversampling, + literally(baseline_transform_policy), + literally(phase_transform_policy), + literally(stokes_conversion_policy), + literally(convolution_policy), + vis_dtype=vis_dtype, + nband=nband, + nrow=nrow, + npix=npix, + nvischan=nvischan, + ncorr=ncorr, + vis=vis, + scale_factor=scale_factor, + r=r, + ) return vis @jit(nopython=True, nogil=True, fastmath=True, parallel=False) -def degridder_serial(uvw, - gridstack, - wavelengths, - chanmap, - cell, - image_centre, - phase_centre, - convolution_kernel, - convolution_kernel_width, - convolution_kernel_oversampling, - baseline_transform_policy, - phase_transform_policy, - stokes_conversion_policy, - convolution_policy, - vis_dtype=np.complex128): +def degridder_serial( + uvw, + gridstack, + wavelengths, + chanmap, + cell, + image_centre, + phase_centre, + convolution_kernel, + convolution_kernel_width, + convolution_kernel_oversampling, + baseline_transform_policy, + phase_transform_policy, + stokes_conversion_policy, + convolution_policy, + vis_dtype=np.complex128, +): """ 2D Convolutional degridder, discrete to contiguous @uvw: value coordinates, (nrow, 3) @@ -212,8 +222,7 @@ def degridder_serial(uvw, """ if chanmap.size != wavelengths.size: - raise ValueError( - "Chanmap and corresponding wavelengths must match in shape") + raise ValueError("Chanmap and corresponding wavelengths must match in shape") chanmap = chanmap.ravel() wavelengths = wavelengths.ravel() nband = np.max(chanmap) + 1 @@ -225,12 +234,12 @@ def degridder_serial(uvw, ncorr = scp.ncorr_out(policy_type=literally(stokes_conversion_policy)) if gridstack.shape[0] < nband: raise ValueError( - "Not enough channel bands in grid stack to match mfs band mapping") + "Not enough channel bands in grid stack to match mfs band mapping" + ) if uvw.shape[1] != 3: raise ValueError("UVW array must be array of tripples") if uvw.shape[0] != nrow: - raise ValueError( - "UVW array must have same number of rows as vis array") + raise ValueError("UVW array must have same number of rows as vis array") if nvischan != wavelengths.size: raise ValueError("Chanmap must correspond to visibility channels") @@ -239,27 +248,29 @@ def degridder_serial(uvw, # scale the FOV using the simularity theorem scale_factor = npix * cell / 3600.0 * np.pi / 180.0 for r in range(nrow): - degridder_row_kernel(uvw, - gridstack, - wavelengths, - chanmap, - cell, - image_centre, - phase_centre, - convolution_kernel, - convolution_kernel_width, - convolution_kernel_oversampling, - literally(baseline_transform_policy), - literally(phase_transform_policy), - literally(stokes_conversion_policy), - literally(convolution_policy), - vis_dtype=vis_dtype, - nband=nband, - nrow=nrow, - npix=npix, - nvischan=nvischan, - ncorr=ncorr, - vis=vis, - scale_factor=scale_factor, - r=r) + degridder_row_kernel( + uvw, + gridstack, + wavelengths, + chanmap, + cell, + image_centre, + phase_centre, + convolution_kernel, + convolution_kernel_width, + convolution_kernel_oversampling, + literally(baseline_transform_policy), + literally(phase_transform_policy), + literally(stokes_conversion_policy), + literally(convolution_policy), + vis_dtype=vis_dtype, + nband=nband, + nrow=nrow, + npix=npix, + nvischan=nvischan, + ncorr=ncorr, + vis=vis, + scale_factor=scale_factor, + r=r, + ) return vis diff --git a/africanus/gridding/perleypolyhedron/gridder.py b/africanus/gridding/perleypolyhedron/gridder.py index 0acbeef6b..e25f42fd9 100644 --- a/africanus/gridding/perleypolyhedron/gridder.py +++ b/africanus/gridding/perleypolyhedron/gridder.py @@ -3,31 +3,32 @@ from africanus.util.numba import jit from africanus.gridding.perleypolyhedron.policies import ( - baseline_transform_policies as btp) -from africanus.gridding.perleypolyhedron.policies import ( - phase_transform_policies as ptp) -from africanus.gridding.perleypolyhedron.policies import ( - convolution_policies as cp) + baseline_transform_policies as btp, +) +from africanus.gridding.perleypolyhedron.policies import phase_transform_policies as ptp +from africanus.gridding.perleypolyhedron.policies import convolution_policies as cp @jit(nopython=True, nogil=True, fastmath=True, parallel=False) -def gridder(uvw, - vis, - wavelengths, - chanmap, - npix, - cell, - image_centre, - phase_centre, - convolution_kernel, - convolution_kernel_width, - convolution_kernel_oversampling, - baseline_transform_policy, - phase_transform_policy, - stokes_conversion_policy, - convolution_policy, - grid_dtype=np.complex128, - do_normalize=False): +def gridder( + uvw, + vis, + wavelengths, + chanmap, + npix, + cell, + image_centre, + phase_centre, + convolution_kernel, + convolution_kernel_width, + convolution_kernel_oversampling, + baseline_transform_policy, + phase_transform_policy, + stokes_conversion_policy, + convolution_policy, + grid_dtype=np.complex128, + do_normalize=False, +): """ 2D Convolutional gridder, contiguous to discrete @uvw: value coordinates, (nrow, 3) @@ -58,8 +59,7 @@ def gridder(uvw, @do_normalize: normalize grid by convolution weights """ if chanmap.size != wavelengths.size: - raise ValueError( - "Chanmap and corresponding wavelengths must match in shape") + raise ValueError("Chanmap and corresponding wavelengths must match in shape") chanmap = chanmap.ravel() wavelengths = wavelengths.ravel() nband = np.max(chanmap) + 1 @@ -67,8 +67,7 @@ def gridder(uvw, if uvw.shape[1] != 3: raise ValueError("UVW array must be array of tripples") if uvw.shape[0] != nrow: - raise ValueError( - "UVW array must have same number of rows as vis array") + raise ValueError("UVW array must have same number of rows as vis array") if nvischan != wavelengths.size: raise ValueError("Chanmap must correspond to visibility channels") @@ -80,17 +79,18 @@ def gridder(uvw, for r in range(nrow): ra0, dec0 = phase_centre ra, dec = image_centre - ptp.policy(vis[r, :, :], - uvw[r, :], - wavelengths, - ra0, - dec0, - ra, - dec, - policy_type=literally(phase_transform_policy), - phasesign=1.0) - btp.policy(uvw[r, :], ra0, dec0, ra, dec, - literally(baseline_transform_policy)) + ptp.policy( + vis[r, :, :], + uvw[r, :], + wavelengths, + ra0, + dec0, + ra, + dec, + policy_type=literally(phase_transform_policy), + phasesign=1.0, + ) + btp.policy(uvw[r, :], ra0, dec0, ra, dec, literally(baseline_transform_policy)) for c in range(nvischan): scaled_u = uvw[r, 0] * scale_factor / wavelengths[c] scaled_v = uvw[r, 1] * scale_factor / wavelengths[c] @@ -109,7 +109,8 @@ def gridder(uvw, convolution_kernel_width, convolution_kernel_oversampling, literally(stokes_conversion_policy), - policy_type=literally(convolution_policy)) + policy_type=literally(convolution_policy), + ) if do_normalize: for c in range(nband): gridstack[c, :, :] /= wt_ch[c] + 1.0e-8 diff --git a/africanus/gridding/perleypolyhedron/kernels.py b/africanus/gridding/perleypolyhedron/kernels.py index 7fec5d18b..e503ca45f 100644 --- a/africanus/gridding/perleypolyhedron/kernels.py +++ b/africanus/gridding/perleypolyhedron/kernels.py @@ -28,9 +28,7 @@ def uspace(W, oversample): """ # must be odd so that the taps can be centred at the origin assert W % 2 == 1 - taps = np.arange( - oversample * - (W + 2)) / float(oversample) - (W + 2) // 2 + taps = np.arange(oversample * (W + 2)) / float(oversample) - (W + 2) // 2 # (|+.) * W centred at 0 return taps @@ -45,11 +43,13 @@ def sinc(W, oversample=5, a=1.0): _KBSINC_AUTOCOEFFS = np.polyfit( - [1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0], - [1.9980, 2.3934, 3.3800, 4.2054, 4.9107, 5.7567, 6.6291, 7.4302], 1) + [1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0], + [1.9980, 2.3934, 3.3800, 4.2054, 4.9107, 5.7567, 6.6291, 7.4302], + 1, +) -@requires_optional('scipy', scipy_import_error) +@requires_optional("scipy", scipy_import_error) def kbsinc(W, b=None, oversample=5, order=15): """ Modified keiser bessel windowed sinc (Jackson et al., @@ -61,15 +61,14 @@ def kbsinc(W, b=None, oversample=5, order=15): b = np.poly1d(_KBSINC_AUTOCOEFFS)((W + 2)) u = uspace(W, oversample) - wnd = jn(order, b * np.sqrt(1 - (2 * u / - ((W + 2) + 1))**2)) * 1 / ((W + 2) + 1) + wnd = jn(order, b * np.sqrt(1 - (2 * u / ((W + 2) + 1)) ** 2)) * 1 / ((W + 2) + 1) res = sinc(W, oversample=oversample) * wnd * np.sum(wnd) return res / np.sum(res) _HANNING_AUTOCOEFFS = np.polyfit( - [1.5, 2.0, 2.5, 3.0, 3.5], - [0.7600, 0.7146, 0.6185, 0.5534, 0.5185], 3) + [1.5, 2.0, 2.5, 3.0, 3.5], [0.7600, 0.7146, 0.6185, 0.5534, 0.5185], 3 +) def hanningsinc(W, a=None, oversample=5): @@ -96,7 +95,7 @@ def pack_kernel(K, W, oversample=5): """ pkern = np.empty(oversample * (W + 2), dtype=K.dtype) for t in range(oversample): - pkern[t * (W + 2):(t + 1) * (W + 2)] = K[t::oversample] + pkern[t * (W + 2) : (t + 1) * (W + 2)] = K[t::oversample] return pkern @@ -112,7 +111,7 @@ def unpack_kernel(K, W, oversample=5): """ upkern = np.empty(oversample * (W + 2), dtype=K.dtype) for t in range(oversample): - upkern[t::oversample] = K[t * (W + 2):(t + 1) * (W + 2)] + upkern[t::oversample] = K[t * (W + 2) : (t + 1) * (W + 2)] return upkern @@ -123,14 +122,19 @@ def compute_detaper(npix, K, W, oversample=5): Assumes a 2D square kernel to be passed as argument K """ pk = np.zeros((npix * oversample, npix * oversample)) - pk[npix * oversample // 2 - K.shape[0] // 2:npix * oversample // 2 - - K.shape[0] // 2 + K.shape[0], - npix * oversample // 2 - K.shape[1] // 2:npix * oversample // 2 - - K.shape[1] // 2 + K.shape[1]] = K + pk[ + npix * oversample // 2 - K.shape[0] // 2 : npix * oversample // 2 + - K.shape[0] // 2 + + K.shape[0], + npix * oversample // 2 - K.shape[1] // 2 : npix * oversample // 2 + - K.shape[1] // 2 + + K.shape[1], + ] = K fpk = np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(pk))) - fk = fpk[npix * oversample // 2 - npix // 2:npix * oversample // 2 - - npix // 2 + npix, npix * oversample // 2 - - npix // 2:npix * oversample // 2 - npix // 2 + npix] + fk = fpk[ + npix * oversample // 2 - npix // 2 : npix * oversample // 2 - npix // 2 + npix, + npix * oversample // 2 - npix // 2 : npix * oversample // 2 - npix // 2 + npix, + ] return np.abs(fk) @@ -154,8 +158,7 @@ def compute_detaper_dft(npix, K, W, oversample=5): for x in range(K.size): xx = ksample[x % K.shape[1]] yy = ksample[x // K.shape[1]] - pk[mm, ll] += rK[x] * np.exp(-2.0j * np.pi * - (llN * xx + mmN * yy)) + pk[mm, ll] += rK[x] * np.exp(-2.0j * np.pi * (llN * xx + mmN * yy)) return np.abs(pk) diff --git a/africanus/gridding/perleypolyhedron/policies/baseline_transform_policies.py b/africanus/gridding/perleypolyhedron/policies/baseline_transform_policies.py index e09b9cc29..0ea9d1642 100644 --- a/africanus/gridding/perleypolyhedron/policies/baseline_transform_policies.py +++ b/africanus/gridding/perleypolyhedron/policies/baseline_transform_policies.py @@ -7,7 +7,7 @@ def uvw_norotate(uvw, ra0, dec0, ra, dec, policy_type): def uvw_rotate(uvw, ra0, dec0, ra, dec, policy_type): - ''' + """ Compute the following 3x3 coordinate transformation matrix: Z_rot(facet_new_rotation) * \\ T(new_phase_centre_ra,new_phase_centre_dec) * \\ @@ -27,7 +27,7 @@ def uvw_rotate(uvw, ra0, dec0, ra, dec, policy_type): centre here, so the last rotation matrix is ignored! This transformation will let the image be tangent to the celestial sphere at the new delay centre - ''' + """ d_ra = ra - ra0 c_d_ra = cos(d_ra) s_d_ra = sin(d_ra) @@ -51,18 +51,18 @@ def uvw_rotate(uvw, ra0, dec0, ra, dec, policy_type): @jit(nopython=True, nogil=True, fastmath=True, parallel=False) def uvw_planarwapprox(uvw, ra0, dec0, ra, dec, policy_type): - ''' - Implements the coordinate uv transform associated with taking a planar - approximation to w(n-1) as described in Kogan & Greisen's AIPS Memo 113 - This is essentially equivalent to rotating the facet to be tangent to - the celestial sphere as Perley suggested to limit error, but it instead - takes w into account in a linear approximation to the phase error near - the facet centre. This keeps the facets parallel to the original facet - plane. Of course this 2D taylor expansion of the first order is only - valid over a small field of view, but that true of normal tilted - faceting as well. Only a convolution can get rid of the (n-1) - factor in the ME. - ''' + """ + Implements the coordinate uv transform associated with taking a planar + approximation to w(n-1) as described in Kogan & Greisen's AIPS Memo 113 + This is essentially equivalent to rotating the facet to be tangent to + the celestial sphere as Perley suggested to limit error, but it instead + takes w into account in a linear approximation to the phase error near + the facet centre. This keeps the facets parallel to the original facet + plane. Of course this 2D taylor expansion of the first order is only + valid over a small field of view, but that true of normal tilted + faceting as well. Only a convolution can get rid of the (n-1) + factor in the ME. + """ d_ra = ra - ra0 n_dec = dec o_dec = dec0 diff --git a/africanus/gridding/perleypolyhedron/policies/convolution_policies.py b/africanus/gridding/perleypolyhedron/policies/convolution_policies.py index b0308a95e..2093be11c 100644 --- a/africanus/gridding/perleypolyhedron/policies/convolution_policies.py +++ b/africanus/gridding/perleypolyhedron/policies/convolution_policies.py @@ -4,11 +4,21 @@ def convolve_1d_axisymmetric_unpacked_scatter( - scaled_u, scaled_v, scaled_w, npix, grid, vis, r, c, - convolution_kernel, convolution_kernel_width, - convolution_kernel_oversampling, stokes_conversion_policy, - policy_type): - ''' + scaled_u, + scaled_v, + scaled_w, + npix, + grid, + vis, + r, + c, + convolution_kernel, + convolution_kernel_width, + convolution_kernel_oversampling, + stokes_conversion_policy, + policy_type, +): + """ Convolution policy for a 1D axisymmetric unpacked AA kernel (gridding kernel) @scaled_u: simularity theorem and lambda scaled u @@ -26,7 +36,7 @@ def convolve_1d_axisymmetric_unpacked_scatter( @stokes_conversion_policy: any accepted correlation to stokes conversion policy in .policies.stokes_conversion_policies - ''' + """ offset_u = scaled_u + npix // 2 offset_v = scaled_v + npix // 2 disc_u = int(np.round(offset_u)) @@ -35,30 +45,44 @@ def convolve_1d_axisymmetric_unpacked_scatter( frac_v = int((-offset_v + disc_v) * convolution_kernel_oversampling) cw = 0.0 for tv in range(convolution_kernel_width): - conv_v = convolution_kernel[(tv + 1) * convolution_kernel_oversampling - + frac_v] + conv_v = convolution_kernel[(tv + 1) * convolution_kernel_oversampling + frac_v] grid_v_lookup = disc_v + tv - convolution_kernel_width // 2 for tu in range(convolution_kernel_width): - conv_u = convolution_kernel[(tu + 1) * - convolution_kernel_oversampling + - frac_u] + conv_u = convolution_kernel[ + (tu + 1) * convolution_kernel_oversampling + frac_u + ] grid_u_lookup = disc_u + tu - convolution_kernel_width // 2 - if (grid_v_lookup >= 0 and grid_v_lookup < npix - and grid_u_lookup >= 0 and grid_u_lookup < npix): - grid[grid_v_lookup, grid_u_lookup] += \ - conv_v * conv_u * \ - scp.corr2stokes(vis[r, c, :], - stokes_conversion_policy) + if ( + grid_v_lookup >= 0 + and grid_v_lookup < npix + and grid_u_lookup >= 0 + and grid_u_lookup < npix + ): + grid[grid_v_lookup, grid_u_lookup] += ( + conv_v + * conv_u + * scp.corr2stokes(vis[r, c, :], stokes_conversion_policy) + ) cw += conv_v * conv_u return cw def convolve_1d_axisymmetric_packed_scatter( - scaled_u, scaled_v, scaled_w, npix, grid, vis, r, c, - convolution_kernel, convolution_kernel_width, - convolution_kernel_oversampling, stokes_conversion_policy, - policy_type): - ''' + scaled_u, + scaled_v, + scaled_w, + npix, + grid, + vis, + r, + c, + convolution_kernel, + convolution_kernel_width, + convolution_kernel_oversampling, + stokes_conversion_policy, + policy_type, +): + """ Convolution policy for a 1D axisymmetric packed AA kernel (gridding kernel) @scaled_u: simularity theorem and lambda scaled u @scaled_v: simularity theorem and lambda scaled v @@ -73,7 +97,7 @@ def convolve_1d_axisymmetric_packed_scatter( @stokes_conversion_policy: any accepted correlation to stokes conversion policy in .policies.stokes_conversion_policies - ''' + """ offset_u = scaled_u + npix // 2 offset_v = scaled_v + npix // 2 disc_u = int(np.round(offset_u)) @@ -95,28 +119,46 @@ def convolve_1d_axisymmetric_packed_scatter( # where frac wraps around to negative indexing cw = 0.0 for tv in range(convolution_kernel_width): - conv_v = convolution_kernel[tv + frac_offset_v + frac_v * - (convolution_kernel_width + 2)] + conv_v = convolution_kernel[ + tv + frac_offset_v + frac_v * (convolution_kernel_width + 2) + ] grid_v_lookup = disc_v + tv - convolution_kernel_width // 2 for tu in range(convolution_kernel_width): - conv_u = convolution_kernel[tu + frac_offset_u + frac_u * - (convolution_kernel_width + 2)] + conv_u = convolution_kernel[ + tu + frac_offset_u + frac_u * (convolution_kernel_width + 2) + ] grid_u_lookup = disc_u + tu - convolution_kernel_width // 2 - if (grid_v_lookup >= 0 and grid_v_lookup < npix - and grid_u_lookup >= 0 and grid_u_lookup < npix): - grid[grid_v_lookup, grid_u_lookup] += \ - conv_v * conv_u * \ - scp.corr2stokes(vis[r, c, :], - stokes_conversion_policy) + if ( + grid_v_lookup >= 0 + and grid_v_lookup < npix + and grid_u_lookup >= 0 + and grid_u_lookup < npix + ): + grid[grid_v_lookup, grid_u_lookup] += ( + conv_v + * conv_u + * scp.corr2stokes(vis[r, c, :], stokes_conversion_policy) + ) cw += conv_v * conv_u return cw -def convolve_nn_scatter(scaled_u, scaled_v, scaled_w, npix, grid, vis, r, c, - convolution_kernel, convolution_kernel_width, - convolution_kernel_oversampling, - stokes_conversion_policy, policy_type): - ''' +def convolve_nn_scatter( + scaled_u, + scaled_v, + scaled_w, + npix, + grid, + vis, + r, + c, + convolution_kernel, + convolution_kernel_width, + convolution_kernel_oversampling, + stokes_conversion_policy, + policy_type, +): + """ Convolution policy for a nn scatter kernel (gridding kernel) @scaled_u: simularity theorem and lambda scaled u @scaled_v: simularity theorem and lambda scaled v @@ -133,25 +175,32 @@ def convolve_nn_scatter(scaled_u, scaled_v, scaled_w, npix, grid, vis, r, c, @stokes_conversion_policy: any accepted correlation to stokes conversion policy in .policies.stokes_conversion_policies - ''' + """ offset_u = scaled_u + npix // 2 offset_v = scaled_v + npix // 2 disc_u = int(np.round(offset_u)) disc_v = int(np.round(offset_v)) cw = 1.0 - grid[disc_v, disc_u] += \ - scp.corr2stokes(vis[r, c, :], - stokes_conversion_policy) + grid[disc_v, disc_u] += scp.corr2stokes(vis[r, c, :], stokes_conversion_policy) return cw -def convolve_1d_axisymmetric_packed_gather(scaled_u, scaled_v, scaled_w, npix, - grid, vis, r, c, convolution_kernel, - convolution_kernel_width, - convolution_kernel_oversampling, - stokes_conversion_policy, - policy_type): - ''' +def convolve_1d_axisymmetric_packed_gather( + scaled_u, + scaled_v, + scaled_w, + npix, + grid, + vis, + r, + c, + convolution_kernel, + convolution_kernel_width, + convolution_kernel_oversampling, + stokes_conversion_policy, + policy_type, +): + """ Convolution policy for a 1D axisymmetric packed AA kernel (degridding kernel) @scaled_u: simularity theorem and lambda scaled u @@ -168,7 +217,7 @@ def convolve_1d_axisymmetric_packed_gather(scaled_u, scaled_v, scaled_w, npix, @stokes_conversion_policy: any accepted correlation to stokes conversion policy in .policies.stokes_conversion_policies - ''' + """ offset_u = scaled_u + npix // 2 offset_v = scaled_v + npix // 2 disc_u = int(np.round(offset_u)) @@ -190,30 +239,51 @@ def convolve_1d_axisymmetric_packed_gather(scaled_u, scaled_v, scaled_w, npix, # where frac wraps around to negative indexing cw = 0 for tv in range(convolution_kernel_width): - conv_v = convolution_kernel[tv + frac_offset_v + frac_v * - (convolution_kernel_width + 2)] + conv_v = convolution_kernel[ + tv + frac_offset_v + frac_v * (convolution_kernel_width + 2) + ] grid_v_lookup = disc_v + tv - convolution_kernel_width // 2 for tu in range(convolution_kernel_width): - conv_u = convolution_kernel[tu + frac_offset_u + frac_u * - (convolution_kernel_width + 2)] + conv_u = convolution_kernel[ + tu + frac_offset_u + frac_u * (convolution_kernel_width + 2) + ] grid_u_lookup = disc_u + tu - convolution_kernel_width // 2 - if (grid_v_lookup >= 0 and grid_v_lookup < npix - and grid_u_lookup >= 0 and grid_u_lookup < npix): + if ( + grid_v_lookup >= 0 + and grid_v_lookup < npix + and grid_u_lookup >= 0 + and grid_u_lookup < npix + ): scp.stokes2corr( - grid[disc_v + tv - convolution_kernel_width // 2, disc_u + - tu - convolution_kernel_width // 2] * conv_v * conv_u, + grid[ + disc_v + tv - convolution_kernel_width // 2, + disc_u + tu - convolution_kernel_width // 2, + ] + * conv_v + * conv_u, vis[r, c, :], - policy_type=stokes_conversion_policy) + policy_type=stokes_conversion_policy, + ) cw += conv_v * conv_u vis[r, c, :] /= cw + 1.0e-8 def convolve_1d_axisymmetric_unpacked_gather( - scaled_u, scaled_v, scaled_w, npix, grid, vis, r, c, - convolution_kernel, convolution_kernel_width, - convolution_kernel_oversampling, stokes_conversion_policy, - policy_type): - ''' + scaled_u, + scaled_v, + scaled_w, + npix, + grid, + vis, + r, + c, + convolution_kernel, + convolution_kernel_width, + convolution_kernel_oversampling, + stokes_conversion_policy, + policy_type, +): + """ Convolution policy for a 1D axisymmetric unpacked AA kernel (degridding kernel) @scaled_u: simularity theorem and lambda scaled u @@ -230,7 +300,7 @@ def convolve_1d_axisymmetric_unpacked_gather( @stokes_conversion_policy: any accepted correlation to stokes conversion policy in .policies.stokes_conversion_policies - ''' + """ offset_u = scaled_u + npix // 2 offset_v = scaled_v + npix // 2 disc_u = int(np.round(offset_u)) @@ -239,36 +309,62 @@ def convolve_1d_axisymmetric_unpacked_gather( frac_v = int((-offset_v + disc_v) * convolution_kernel_oversampling) cw = 0 for tv in range(convolution_kernel_width): - conv_v = convolution_kernel[(tv + 1) * convolution_kernel_oversampling - + frac_v] + conv_v = convolution_kernel[(tv + 1) * convolution_kernel_oversampling + frac_v] grid_v_lookup = disc_v + tv - convolution_kernel_width // 2 for tu in range(convolution_kernel_width): - conv_u = convolution_kernel[(tu + 1) * - convolution_kernel_oversampling + - frac_u] + conv_u = convolution_kernel[ + (tu + 1) * convolution_kernel_oversampling + frac_u + ] grid_u_lookup = disc_u + tu - convolution_kernel_width // 2 - if (grid_v_lookup >= 0 and grid_v_lookup < npix - and grid_u_lookup >= 0 and grid_u_lookup < npix): + if ( + grid_v_lookup >= 0 + and grid_v_lookup < npix + and grid_u_lookup >= 0 + and grid_u_lookup < npix + ): scp.stokes2corr( grid[grid_v_lookup, grid_u_lookup] * conv_v * conv_u, vis[r, c, :], - policy_type=stokes_conversion_policy) + policy_type=stokes_conversion_policy, + ) cw += conv_v * conv_u vis[r, c, :] /= cw + 1.0e-8 -def policy(scaled_u, scaled_v, scaled_w, npix, grid, vis, r, c, - convolution_kernel, convolution_kernel_width, - convolution_kernel_oversampling, stokes_conversion_policy, - policy_type): +def policy( + scaled_u, + scaled_v, + scaled_w, + npix, + grid, + vis, + r, + c, + convolution_kernel, + convolution_kernel_width, + convolution_kernel_oversampling, + stokes_conversion_policy, + policy_type, +): pass @overload(policy, inline="always") -def policy_impl(scaled_u, scaled_v, scaled_w, npix, grid, vis, r, c, - convolution_kernel, convolution_kernel_width, - convolution_kernel_oversampling, stokes_conversion_policy, - policy_type): +def policy_impl( + scaled_u, + scaled_v, + scaled_w, + npix, + grid, + vis, + r, + c, + convolution_kernel, + convolution_kernel_width, + convolution_kernel_oversampling, + stokes_conversion_policy, + policy_type, +): if policy_type.literal_value == "conv_1d_axisymmetric_packed_scatter": return convolve_1d_axisymmetric_packed_scatter elif policy_type.literal_value == "conv_nn_scatter": diff --git a/africanus/gridding/perleypolyhedron/policies/phase_transform_policies.py b/africanus/gridding/perleypolyhedron/policies/phase_transform_policies.py index 8afc3aec3..b42e490d1 100644 --- a/africanus/gridding/perleypolyhedron/policies/phase_transform_policies.py +++ b/africanus/gridding/perleypolyhedron/policies/phase_transform_policies.py @@ -2,37 +2,21 @@ from numpy import pi, cos, sin, sqrt -def phase_norotate(vis, - uvw, - lambdas, - ra0, - dec0, - ra, - dec, - policy_type, - phasesign=1.0): +def phase_norotate(vis, uvw, lambdas, ra0, dec0, ra, dec, policy_type, phasesign=1.0): pass -def phase_rotate(vis, - uvw, - lambdas, - ra0, - dec0, - ra, - dec, - policy_type, - phasesign=1.0): - ''' - Convert ra,dec to l,m,n based on Synthesis Imaging II, Pg. 388 - The phase term (as documented in Perley & Cornwell (1992)) - calculation requires the delta l,m,n coordinates. - Through simplification l0,m0,n0 = (0,0,1) (assume dec == dec0 and - ra == ra0, and the simplification follows) - l,m,n is then calculated using the new and original phase centres - as per the relation on Pg. 388 - lambdas has the same shape as vis - ''' +def phase_rotate(vis, uvw, lambdas, ra0, dec0, ra, dec, policy_type, phasesign=1.0): + """ + Convert ra,dec to l,m,n based on Synthesis Imaging II, Pg. 388 + The phase term (as documented in Perley & Cornwell (1992)) + calculation requires the delta l,m,n coordinates. + Through simplification l0,m0,n0 = (0,0,1) (assume dec == dec0 and + ra == ra0, and the simplification follows) + l,m,n is then calculated using the new and original phase centres + as per the relation on Pg. 388 + lambdas has the same shape as vis + """ d_ra = ra - ra0 d_dec = dec d_decp = dec0 @@ -43,11 +27,10 @@ def phase_rotate(vis, c_d_decp = cos(d_decp) s_d_decp = sin(d_decp) ll = c_d_dec * s_d_ra - mm = (s_d_dec * c_d_decp - c_d_dec * s_d_decp * c_d_ra) + mm = s_d_dec * c_d_decp - c_d_dec * s_d_decp * c_d_ra nn = -(1 - sqrt(1 - ll * ll - mm * mm)) for c in range(lambdas.size): - x = phasesign * 2 * pi * (uvw[0] * ll + uvw[1] * mm + - uvw[2] * nn) / lambdas[c] + x = phasesign * 2 * pi * (uvw[0] * ll + uvw[1] * mm + uvw[2] * nn) / lambdas[c] vis[c, :] *= cos(x) + 1.0j * sin(x) @@ -56,17 +39,8 @@ def policy(vis, uvw, lambdas, ra0, dec0, ra, dec, policy_type, phasesign=1.0): @overload(policy, inline="always") -def policy_impl(vis, - uvw, - lambdas, - ra0, - dec0, - ra, - dec, - policy_type, - phasesign=1.0): - if policy_type.literal_value == "None" or \ - policy_type.literal_value is None: +def policy_impl(vis, uvw, lambdas, ra0, dec0, ra, dec, policy_type, phasesign=1.0): + if policy_type.literal_value == "None" or policy_type.literal_value is None: return phase_norotate elif policy_type.literal_value == "phase_rotate": return phase_rotate diff --git a/africanus/gridding/perleypolyhedron/policies/stokes_conversion_policies.py b/africanus/gridding/perleypolyhedron/policies/stokes_conversion_policies.py index 7cc5811ad..873fd744d 100644 --- a/africanus/gridding/perleypolyhedron/policies/stokes_conversion_policies.py +++ b/africanus/gridding/perleypolyhedron/policies/stokes_conversion_policies.py @@ -164,21 +164,17 @@ def corr2stokesimpl(vis_in, policy_type): elif policy_type.literal_value == "U_FROM_XXXYYXYY": return lambda vis_in, policy_type: (vis_in[1] + vis_in[2]) * 0.5 elif policy_type.literal_value == "U_FROM_RLLR": - return lambda vis_in, policy_type: -1.0j * (vis_in[0] - vis_in[1] - ) * 0.5 + return lambda vis_in, policy_type: -1.0j * (vis_in[0] - vis_in[1]) * 0.5 elif policy_type.literal_value == "U_FROM_RRRLLRLL": - return lambda vis_in, policy_type: -1.0j * (vis_in[1] - vis_in[2] - ) * 0.5 + return lambda vis_in, policy_type: -1.0j * (vis_in[1] - vis_in[2]) * 0.5 elif policy_type.literal_value == "V_FROM_RRLL": return lambda vis_in, policy_type: (vis_in[0] - vis_in[1]) * 0.5 elif policy_type.literal_value == "V_FROM_RRRLLRLL": return lambda vis_in, policy_type: (vis_in[0] - vis_in[3]) * 0.5 elif policy_type.literal_value == "V_FROM_XYYX": - return lambda vis_in, policy_type: -1.0j * (vis_in[0] - vis_in[1] - ) * 0.5 + return lambda vis_in, policy_type: -1.0j * (vis_in[0] - vis_in[1]) * 0.5 elif policy_type.literal_value == "V_FROM_XXXYYXYY": - return lambda vis_in, policy_type: -1.0j * (vis_in[1] - vis_in[2] - ) * 0.5 + return lambda vis_in, policy_type: -1.0j * (vis_in[1] - vis_in[2]) * 0.5 else: raise ValueError("Invalid stokes conversion") diff --git a/africanus/gridding/perleypolyhedron/tests/test_daskintrf.py b/africanus/gridding/perleypolyhedron/tests/test_daskintrf.py index 7781b5c07..c82dc44bd 100644 --- a/africanus/gridding/perleypolyhedron/tests/test_daskintrf.py +++ b/africanus/gridding/perleypolyhedron/tests/test_daskintrf.py @@ -4,9 +4,7 @@ import numpy as np import pytest -from africanus.gridding.perleypolyhedron import (kernels, - gridder, - degridder) +from africanus.gridding.perleypolyhedron import kernels, gridder, degridder from africanus.gridding.perleypolyhedron import dask as dwrap from africanus.dft.kernels import im_to_vis from africanus.constants import c as lightspeed @@ -33,8 +31,8 @@ def elapsed(self): def __str__(self): res = "{0:s}: Walltime {1:.0f}m{2:.2f}s elapsed".format( - self._id, self.elapsed // 60, - self.elapsed - (self.elapsed // 60) * 60) + self._id, self.elapsed // 60, self.elapsed - (self.elapsed // 60) * 60 + ) return res __repr__ = __str__ @@ -58,38 +56,41 @@ def test_gridder_dask(): d0 = np.pi / 4.0 for n in range(25): for ih0, h0 in enumerate( - np.linspace(np.deg2rad(-20), np.deg2rad(20), ntime)): + np.linspace(np.deg2rad(-20), np.deg2rad(20), ntime) + ): s = np.sin c = np.cos - R = np.array([[s(h0), c(h0), 0], - [-s(d0) * c(h0), - s(d0) * s(h0), - c(d0)], - [c(d0) * c(h0), -c(d0) * s(h0), - s(d0)]]) + R = np.array( + [ + [s(h0), c(h0), 0], + [-s(d0) * c(h0), s(d0) * s(h0), c(d0)], + [c(d0) * c(h0), -c(d0) * s(h0), s(d0)], + ] + ) uvw[n * ntime + ih0, :] = np.dot(R, blpos[n, :].T) uvw = da.from_array(uvw, chunks=(row_chunks, 3)) pxacrossbeam = 5 nchan = 128 - frequency = da.from_array(np.linspace(1.0e9, 1.4e9, nchan), - chunks=(nchan, )) + frequency = da.from_array(np.linspace(1.0e9, 1.4e9, nchan), chunks=(nchan,)) wavelength = lightspeed / frequency cell = da.rad2deg( - wavelength[0] / - (max(da.max(da.absolute(uvw[:, 0])), - da.max(da.absolute(uvw[:, 1]))) * pxacrossbeam)) + wavelength[0] + / ( + max(da.max(da.absolute(uvw[:, 0])), da.max(da.absolute(uvw[:, 1]))) + * pxacrossbeam + ) + ) npixfacet = 100 fftpad = 1.1 image_centres = da.from_array(np.array([[0, d0]]), chunks=(1, 2)) - chanmap = da.from_array(np.zeros(nchan, dtype=np.int64), - chunks=(nchan, )) + chanmap = da.from_array(np.zeros(nchan, dtype=np.int64), chunks=(nchan,)) detaper_facet = kernels.compute_detaper_dft_seperable( - int(npixfacet * fftpad), kernels.unpack_kernel(kern, W, OS), W, - OS) - vis_dft = da.ones(shape=(nrow, nchan, 2), - chunks=(row_chunks, nchan, 2), - dtype=np.complex64) + int(npixfacet * fftpad), kernels.unpack_kernel(kern, W, OS), W, OS + ) + vis_dft = da.ones( + shape=(nrow, nchan, 2), chunks=(row_chunks, nchan, 2), dtype=np.complex64 + ) vis_grid_facet = dwrap.gridder( uvw, vis_dft, @@ -97,7 +98,8 @@ def test_gridder_dask(): chanmap, int(npixfacet * fftpad), cell * 3600.0, - image_centres, (0, d0), + image_centres, + (0, d0), kern, W, OS, @@ -105,25 +107,31 @@ def test_gridder_dask(): "None", "I_FROM_XXYY", "conv_1d_axisymmetric_packed_scatter", - do_normalize=True) + do_normalize=True, + ) vis_grid_facet = vis_grid_facet.compute() - ftvisfacet = (np.fft.fftshift( - np.fft.ifft2(np.fft.ifftshift( - vis_grid_facet[0, :, :]))).reshape( - (1, int(npixfacet * fftpad), int( - npixfacet * fftpad)))).real / detaper_facet * int( - npixfacet * fftpad)**2 - ftvisfacet = ftvisfacet[:, - int(npixfacet * fftpad) // 2 - npixfacet // - 2:int(npixfacet * fftpad) // 2 - - npixfacet // 2 + npixfacet, - int(npixfacet * fftpad) // 2 - npixfacet // - 2:int(npixfacet * fftpad) // 2 - - npixfacet // 2 + npixfacet] + ftvisfacet = ( + ( + np.fft.fftshift( + np.fft.ifft2(np.fft.ifftshift(vis_grid_facet[0, :, :])) + ).reshape((1, int(npixfacet * fftpad), int(npixfacet * fftpad))) + ).real + / detaper_facet + * int(npixfacet * fftpad) ** 2 + ) + ftvisfacet = ftvisfacet[ + :, + int(npixfacet * fftpad) // 2 - npixfacet // 2 : int(npixfacet * fftpad) // 2 + - npixfacet // 2 + + npixfacet, + int(npixfacet * fftpad) // 2 - npixfacet // 2 : int(npixfacet * fftpad) // 2 + - npixfacet // 2 + + npixfacet, + ] print(tictoc) - assert (np.abs(np.max(ftvisfacet[0, :, :]) - 1.0) < 1.0e-6) + assert np.abs(np.max(ftvisfacet[0, :, :]) - 1.0) < 1.0e-6 def test_gridder_nondask(): @@ -141,32 +149,37 @@ def test_gridder_nondask(): d0 = np.pi / 4.0 for n in range(25): for ih0, h0 in enumerate( - np.linspace(np.deg2rad(-20), np.deg2rad(20), ntime)): + np.linspace(np.deg2rad(-20), np.deg2rad(20), ntime) + ): s = np.sin c = np.cos - R = np.array([[s(h0), c(h0), 0], - [-s(d0) * c(h0), - s(d0) * s(h0), - c(d0)], - [c(d0) * c(h0), -c(d0) * s(h0), - s(d0)]]) + R = np.array( + [ + [s(h0), c(h0), 0], + [-s(d0) * c(h0), s(d0) * s(h0), c(d0)], + [c(d0) * c(h0), -c(d0) * s(h0), s(d0)], + ] + ) uvw[n * ntime + ih0, :] = np.dot(R, blpos[n, :].T) pxacrossbeam = 5 nchan = 128 frequency = np.linspace(1.0e9, 1.4e9, nchan) wavelength = lightspeed / frequency cell = np.rad2deg( - wavelength[0] / - (max(np.max(np.absolute(uvw[:, 0])), - np.max(np.absolute(uvw[:, 1]))) * pxacrossbeam)) + wavelength[0] + / ( + max(np.max(np.absolute(uvw[:, 0])), np.max(np.absolute(uvw[:, 1]))) + * pxacrossbeam + ) + ) npixfacet = 100 fftpad = 1.1 image_centres = np.array([[0, d0]]) chanmap = np.zeros(nchan, dtype=np.int64) detaper_facet = kernels.compute_detaper_dft_seperable( - int(npixfacet * fftpad), kernels.unpack_kernel(kern, W, OS), W, - OS) + int(npixfacet * fftpad), kernels.unpack_kernel(kern, W, OS), W, OS + ) vis_dft = np.ones((nrow, nchan, 2), dtype=np.complex64) vis_grid_facet = gridder.gridder( uvw, @@ -175,7 +188,8 @@ def test_gridder_nondask(): chanmap, int(npixfacet * fftpad), cell * 3600.0, - image_centres[0, :], (0, d0), + image_centres[0, :], + (0, d0), kern, W, OS, @@ -183,35 +197,43 @@ def test_gridder_nondask(): "None", "I_FROM_XXYY", "conv_1d_axisymmetric_packed_scatter", - do_normalize=True) - ftvisfacet = (np.fft.fftshift( - np.fft.ifft2(np.fft.ifftshift( - vis_grid_facet[0, :, :]))).reshape( - (1, int(npixfacet * fftpad), int( - npixfacet * fftpad)))).real / detaper_facet * int( - npixfacet * fftpad)**2 - ftvisfacet = ftvisfacet[:, - int(npixfacet * fftpad) // 2 - npixfacet // - 2:int(npixfacet * fftpad) // 2 - - npixfacet // 2 + npixfacet, - int(npixfacet * fftpad) // 2 - npixfacet // - 2:int(npixfacet * fftpad) // 2 - - npixfacet // 2 + npixfacet] + do_normalize=True, + ) + ftvisfacet = ( + ( + np.fft.fftshift( + np.fft.ifft2(np.fft.ifftshift(vis_grid_facet[0, :, :])) + ).reshape((1, int(npixfacet * fftpad), int(npixfacet * fftpad))) + ).real + / detaper_facet + * int(npixfacet * fftpad) ** 2 + ) + ftvisfacet = ftvisfacet[ + :, + int(npixfacet * fftpad) // 2 - npixfacet // 2 : int(npixfacet * fftpad) // 2 + - npixfacet // 2 + + npixfacet, + int(npixfacet * fftpad) // 2 - npixfacet // 2 : int(npixfacet * fftpad) // 2 + - npixfacet // 2 + + npixfacet, + ] print(tictoc) - assert (np.abs(np.max(ftvisfacet[0, :, :]) - 1.0) < 1.0e-6) + assert np.abs(np.max(ftvisfacet[0, :, :]) - 1.0) < 1.0e-6 def test_degrid_dft_packed_nondask(): # construct kernel W = 5 OS = 3 - kern = kernels.pack_kernel(kernels.kbsinc(W, oversample=OS), - W, - oversample=OS) + kern = kernels.pack_kernel(kernels.kbsinc(W, oversample=OS), W, oversample=OS) nrow = int(5e4) uvw = np.column_stack( - (5000.0 * np.cos(np.linspace(0, 2 * np.pi, nrow)), - 5000.0 * np.sin(np.linspace(0, 2 * np.pi, nrow)), np.zeros(nrow))) + ( + 5000.0 * np.cos(np.linspace(0, 2 * np.pi, nrow)), + 5000.0 * np.sin(np.linspace(0, 2 * np.pi, nrow)), + np.zeros(nrow), + ) + ) pxacrossbeam = 10 nchan = 1024 @@ -219,15 +241,16 @@ def test_degrid_dft_packed_nondask(): wavelength = lightspeed / frequency cell = np.rad2deg( - wavelength[0] / - (2 * max(np.max(np.abs(uvw[:, 0])), np.max(np.abs(uvw[:, 1]))) * - pxacrossbeam)) + wavelength[0] + / (2 * max(np.max(np.abs(uvw[:, 0])), np.max(np.abs(uvw[:, 1]))) * pxacrossbeam) + ) npix = 512 mod = np.zeros((1, npix, npix), dtype=np.complex64) mod[0, npix // 2 - 5, npix // 2 - 5] = 1.0 - ftmod = np.fft.ifftshift(np.fft.fft2(np.fft.fftshift( - mod[0, :, :]))).reshape((1, npix, npix)) + ftmod = np.fft.ifftshift(np.fft.fft2(np.fft.fftshift(mod[0, :, :]))).reshape( + (1, npix, npix) + ) chanmap = np.zeros(nchan, dtype=np.int64) with clock("Non-DASK degridding") as tictoc: @@ -245,7 +268,8 @@ def test_degrid_dft_packed_nondask(): "None", # no faceting "None", # no faceting "XXYY_FROM_I", - "conv_1d_axisymmetric_packed_gather") + "conv_1d_axisymmetric_packed_gather", + ) print(tictoc) @@ -256,14 +280,16 @@ def test_degrid_dft_packed_dask(): # construct kernel W = 5 OS = 3 - kern = kernels.pack_kernel(kernels.kbsinc(W, oversample=OS), - W, - oversample=OS) + kern = kernels.pack_kernel(kernels.kbsinc(W, oversample=OS), W, oversample=OS) nrow = int(5e4) nrow_chunk = nrow // 32 uvw = np.column_stack( - (5000.0 * np.cos(np.linspace(0, 2 * np.pi, nrow)), - 5000.0 * np.sin(np.linspace(0, 2 * np.pi, nrow)), np.zeros(nrow))) + ( + 5000.0 * np.cos(np.linspace(0, 2 * np.pi, nrow)), + 5000.0 * np.sin(np.linspace(0, 2 * np.pi, nrow)), + np.zeros(nrow), + ) + ) pxacrossbeam = 10 nchan = 1024 @@ -271,23 +297,24 @@ def test_degrid_dft_packed_dask(): wavelength = lightspeed / frequency cell = np.rad2deg( - wavelength[0] / - (2 * max(np.max(np.abs(uvw[:, 0])), np.max(np.abs(uvw[:, 1]))) * - pxacrossbeam)) + wavelength[0] + / (2 * max(np.max(np.abs(uvw[:, 0])), np.max(np.abs(uvw[:, 1]))) * pxacrossbeam) + ) npix = 512 mod = np.zeros((1, npix, npix), dtype=np.complex64) mod[0, npix // 2 - 5, npix // 2 - 5] = 1.0 - ftmod = np.fft.ifftshift(np.fft.fft2(np.fft.fftshift( - mod[0, :, :]))).reshape((1, 1, npix, npix)) + ftmod = np.fft.ifftshift(np.fft.fft2(np.fft.fftshift(mod[0, :, :]))).reshape( + (1, 1, npix, npix) + ) chanmap = np.zeros(nchan, dtype=np.int64) with clock("DASK degridding") as tictoc: vis_degrid = dwrap.degridder( da.from_array(uvw, chunks=(nrow_chunk, 3)), da.from_array(ftmod, chunks=(1, 1, npix, npix)), - da.from_array(wavelength, chunks=(nchan, )), - da.from_array(chanmap, chunks=(nchan, )), + da.from_array(wavelength, chunks=(nchan,)), + da.from_array(chanmap, chunks=(nchan,)), cell * 3600.0, da.from_array(np.array([[0, np.pi / 4.0]]), chunks=(1, 2)), (0, np.pi / 4.0), @@ -297,7 +324,8 @@ def test_degrid_dft_packed_dask(): "None", # no faceting "None", # no faceting "XXYY_FROM_I", - "conv_1d_axisymmetric_packed_gather") + "conv_1d_axisymmetric_packed_gather", + ) vis_degrid = vis_degrid.compute() @@ -310,14 +338,16 @@ def test_degrid_dft_packed_dask_dft_check(): # construct kernel W = 5 OS = 3 - kern = kernels.pack_kernel(kernels.kbsinc(W, oversample=OS), - W, - oversample=OS) + kern = kernels.pack_kernel(kernels.kbsinc(W, oversample=OS), W, oversample=OS) nrow = 100 nrow_chunk = nrow // 8 uvw = np.column_stack( - (5000.0 * np.cos(np.linspace(0, 2 * np.pi, nrow)), - 5000.0 * np.sin(np.linspace(0, 2 * np.pi, nrow)), np.zeros(nrow))) + ( + 5000.0 * np.cos(np.linspace(0, 2 * np.pi, nrow)), + 5000.0 * np.sin(np.linspace(0, 2 * np.pi, nrow)), + np.zeros(nrow), + ) + ) pxacrossbeam = 10 nchan = 16 @@ -325,28 +355,31 @@ def test_degrid_dft_packed_dask_dft_check(): wavelength = lightspeed / frequency cell = np.rad2deg( - wavelength[0] / - (2 * max(np.max(np.abs(uvw[:, 0])), np.max(np.abs(uvw[:, 1]))) * - pxacrossbeam)) + wavelength[0] + / (2 * max(np.max(np.abs(uvw[:, 0])), np.max(np.abs(uvw[:, 1]))) * pxacrossbeam) + ) npix = 512 mod = np.zeros((1, npix, npix), dtype=np.complex64) mod[0, npix // 2 - 5, npix // 2 - 5] = 1.0 - ftmod = np.fft.ifftshift(np.fft.fft2(np.fft.fftshift( - mod[0, :, :]))).reshape((1, 1, npix, npix)) + ftmod = np.fft.ifftshift(np.fft.fft2(np.fft.fftshift(mod[0, :, :]))).reshape( + (1, 1, npix, npix) + ) chanmap = np.zeros(nchan, dtype=np.int64) dec, ra = np.meshgrid( np.arange(-npix // 2, npix // 2) * np.deg2rad(cell), - np.arange(-npix // 2, npix // 2) * np.deg2rad(cell)) + np.arange(-npix // 2, npix // 2) * np.deg2rad(cell), + ) radec = np.column_stack((ra.flatten(), dec.flatten())) - vis_dft = im_to_vis(mod[0, :, :].reshape(1, 1, npix * npix).T.copy(), - uvw, radec, frequency) + vis_dft = im_to_vis( + mod[0, :, :].reshape(1, 1, npix * npix).T.copy(), uvw, radec, frequency + ) vis_degrid = dwrap.degridder( da.from_array(uvw, chunks=(nrow_chunk, 3)), da.from_array(ftmod, chunks=(1, 1, npix, npix)), - da.from_array(wavelength, chunks=(nchan, )), - da.from_array(chanmap, chunks=(nchan, )), + da.from_array(wavelength, chunks=(nchan,)), + da.from_array(chanmap, chunks=(nchan,)), cell * 3600.0, da.from_array(np.array([[0, np.pi / 4.0]]), chunks=(1, 2)), (0, np.pi / 4.0), @@ -356,13 +389,16 @@ def test_degrid_dft_packed_dask_dft_check(): "None", # no faceting "None", # no faceting "XXYY_FROM_I", - "conv_1d_axisymmetric_packed_gather") + "conv_1d_axisymmetric_packed_gather", + ) vis_degrid = vis_degrid.compute() - assert np.percentile( - np.abs(vis_dft[:, 0, 0].real - vis_degrid[:, 0, 0].real), - 99.0) < 0.05 - assert np.percentile( - np.abs(vis_dft[:, 0, 0].imag - vis_degrid[:, 0, 0].imag), - 99.0) < 0.05 + assert ( + np.percentile(np.abs(vis_dft[:, 0, 0].real - vis_degrid[:, 0, 0].real), 99.0) + < 0.05 + ) + assert ( + np.percentile(np.abs(vis_dft[:, 0, 0].imag - vis_degrid[:, 0, 0].imag), 99.0) + < 0.05 + ) diff --git a/africanus/gridding/perleypolyhedron/tests/test_ppgridder.py b/africanus/gridding/perleypolyhedron/tests/test_ppgridder.py index 13c794fe3..54a744be1 100644 --- a/africanus/gridding/perleypolyhedron/tests/test_ppgridder.py +++ b/africanus/gridding/perleypolyhedron/tests/test_ppgridder.py @@ -5,9 +5,7 @@ import numpy as np import pytest -from africanus.gridding.perleypolyhedron import (kernels, - gridder, - degridder) +from africanus.gridding.perleypolyhedron import kernels, gridder, degridder from africanus.dft.kernels import im_to_vis, vis_to_im from africanus.coordinates import radec_to_lmn from africanus.constants import c as lightspeed @@ -27,38 +25,54 @@ def test_construct_kernels(tmp_path_factory): plt.axvline(-0.5, 0, 1, ls="--", c="k") plt.plot( ll[sel] * OVERSAMP / 2 / np.pi, - 10 * np.log10( + 10 + * np.log10( np.abs( np.fft.fftshift( - np.fft.fft( - kernels.kbsinc(WIDTH, oversample=OVERSAMP, - order=0)[sel])))), - label="kbsinc order 0") + np.fft.fft(kernels.kbsinc(WIDTH, oversample=OVERSAMP, order=0)[sel]) + ) + ) + ), + label="kbsinc order 0", + ) plt.plot( ll[sel] * OVERSAMP / 2 / np.pi, - 10 * np.log10( + 10 + * np.log10( np.abs( np.fft.fftshift( np.fft.fft( - kernels.kbsinc( - WIDTH, oversample=OVERSAMP, order=15)[sel])))), - label="kbsinc order 15") - plt.plot(ll[sel] * OVERSAMP / 2 / np.pi, - 10 * np.log10( - np.abs( - np.fft.fftshift( - np.fft.fft( - kernels.hanningsinc( - WIDTH, oversample=OVERSAMP)[sel])))), - label="hanning sinc") + kernels.kbsinc(WIDTH, oversample=OVERSAMP, order=15)[sel] + ) + ) + ) + ), + label="kbsinc order 15", + ) plt.plot( ll[sel] * OVERSAMP / 2 / np.pi, - 10 * np.log10( + 10 + * np.log10( np.abs( np.fft.fftshift( - np.fft.fft( - kernels.sinc(WIDTH, oversample=OVERSAMP)[sel])))), - label="sinc") + np.fft.fft(kernels.hanningsinc(WIDTH, oversample=OVERSAMP)[sel]) + ) + ) + ), + label="hanning sinc", + ) + plt.plot( + ll[sel] * OVERSAMP / 2 / np.pi, + 10 + * np.log10( + np.abs( + np.fft.fftshift( + np.fft.fft(kernels.sinc(WIDTH, oversample=OVERSAMP)[sel]) + ) + ) + ), + label="sinc", + ) plt.xlim(-10, 10) plt.legend() plt.ylabel("Response [dB]") @@ -83,31 +97,84 @@ def test_packunpack(): Kp = kernels.pack_kernel(K, W, oversample=oversample) Kup = kernels.unpack_kernel(Kp, W, oversample=oversample) assert np.all(K == Kup) - assert np.allclose(K, [ - -2.0, -1.75, -1.5, -1.25, -1.0, -0.75, -0.5, -0.25, 0, 0.25, 0.5, - 0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.25, 2.5, 2.75 - ]) - assert np.allclose(Kp, [ - -2.0, -1.0, 0, 1.0, 2.0, -1.75, -0.75, 0.25, 1.25, 2.25, -1.5, - -0.5, 0.5, 1.5, 2.5, -1.25, -0.25, 0.75, 1.75, 2.75 - ]) + assert np.allclose( + K, + [ + -2.0, + -1.75, + -1.5, + -1.25, + -1.0, + -0.75, + -0.5, + -0.25, + 0, + 0.25, + 0.5, + 0.75, + 1.0, + 1.25, + 1.5, + 1.75, + 2.0, + 2.25, + 2.5, + 2.75, + ], + ) + assert np.allclose( + Kp, + [ + -2.0, + -1.0, + 0, + 1.0, + 2.0, + -1.75, + -0.75, + 0.25, + 1.25, + 2.25, + -1.5, + -0.5, + 0.5, + 1.5, + 2.5, + -1.25, + -0.25, + 0.75, + 1.75, + 2.75, + ], + ) def test_facetcodepath(): # construct kernel W = 5 OS = 3 - kern = kernels.pack_kernel(kernels.kbsinc(W, oversample=OS), - W, - oversample=OS) + kern = kernels.pack_kernel(kernels.kbsinc(W, oversample=OS), W, oversample=OS) # offset 0 uvw = np.array([[0, 0, 0]]) vis = np.array([[[1.0 + 0j, 1.0 + 0j]]]) - gridder.gridder(uvw, vis, np.array([1.0]), np.array([0]), 64, - 30, (0, 0), (0, 0), kern, W, OS, "rotate", - "phase_rotate", "I_FROM_XXYY", - "conv_1d_axisymmetric_packed_scatter") + gridder.gridder( + uvw, + vis, + np.array([1.0]), + np.array([0]), + 64, + 30, + (0, 0), + (0, 0), + kern, + W, + OS, + "rotate", + "phase_rotate", + "I_FROM_XXYY", + "conv_1d_axisymmetric_packed_scatter", + ) def test_degrid_dft(tmp_path_factory): @@ -116,23 +183,28 @@ def test_degrid_dft(tmp_path_factory): OS = 3 kern = kernels.kbsinc(W, oversample=OS) uvw = np.column_stack( - (5000.0 * np.cos(np.linspace(0, 2 * np.pi, 1000)), - 5000.0 * np.sin(np.linspace(0, 2 * np.pi, 1000)), np.zeros(1000))) + ( + 5000.0 * np.cos(np.linspace(0, 2 * np.pi, 1000)), + 5000.0 * np.sin(np.linspace(0, 2 * np.pi, 1000)), + np.zeros(1000), + ) + ) pxacrossbeam = 10 frequency = np.array([1.4e9]) wavelength = lightspeed / frequency cell = np.rad2deg( - wavelength[0] / - (2 * max(np.max(np.abs(uvw[:, 0])), np.max(np.abs(uvw[:, 1]))) * - pxacrossbeam)) + wavelength[0] + / (2 * max(np.max(np.abs(uvw[:, 0])), np.max(np.abs(uvw[:, 1]))) * pxacrossbeam) + ) npix = 512 mod = np.zeros((1, npix, npix), dtype=np.complex64) mod[0, npix // 2 - 5, npix // 2 - 5] = 1.0 - ftmod = np.fft.ifftshift(np.fft.fft2(np.fft.fftshift( - mod[0, :, :]))).reshape((1, npix, npix)) + ftmod = np.fft.ifftshift(np.fft.fft2(np.fft.fftshift(mod[0, :, :]))).reshape( + (1, npix, npix) + ) chanmap = np.array([0]) vis_degrid = degridder.degridder( uvw, @@ -148,15 +220,18 @@ def test_degrid_dft(tmp_path_factory): "None", # no faceting "None", # no faceting "XXYY_FROM_I", - "conv_1d_axisymmetric_unpacked_gather") + "conv_1d_axisymmetric_unpacked_gather", + ) dec, ra = np.meshgrid( np.arange(-npix // 2, npix // 2) * np.deg2rad(cell), - np.arange(-npix // 2, npix // 2) * np.deg2rad(cell)) + np.arange(-npix // 2, npix // 2) * np.deg2rad(cell), + ) radec = np.column_stack((ra.flatten(), dec.flatten())) - vis_dft = im_to_vis(mod[0, :, :].reshape(1, 1, npix * npix).T.copy(), - uvw, radec, frequency) + vis_dft = im_to_vis( + mod[0, :, :].reshape(1, 1, npix * npix).T.copy(), uvw, radec, frequency + ) try: import matplotlib @@ -165,61 +240,68 @@ def test_degrid_dft(tmp_path_factory): else: matplotlib.use("agg") from matplotlib import pyplot as plt + plt.figure() plt.plot(vis_degrid[:, 0, 0].real, label=r"$\Re(\mathtt{degrid})$") plt.plot(vis_dft[:, 0, 0].real, label=r"$\Re(\mathtt{dft})$") - plt.plot(np.abs(vis_dft[:, 0, 0].real - vis_degrid[:, 0, 0].real), - label="Error") + plt.plot( + np.abs(vis_dft[:, 0, 0].real - vis_degrid[:, 0, 0].real), label="Error" + ) plt.legend() plt.xlabel("sample") plt.ylabel("Real of predicted") plt.savefig( - os.path.join(os.environ.get("TMPDIR", "/tmp"), - "degrid_vs_dft_re.png")) + os.path.join(os.environ.get("TMPDIR", "/tmp"), "degrid_vs_dft_re.png") + ) plt.figure() plt.plot(vis_degrid[:, 0, 0].imag, label=r"$\Im(\mathtt{degrid})$") plt.plot(vis_dft[:, 0, 0].imag, label=r"$\Im(\mathtt{dft})$") - plt.plot(np.abs(vis_dft[:, 0, 0].imag - vis_degrid[:, 0, 0].imag), - label="Error") + plt.plot( + np.abs(vis_dft[:, 0, 0].imag - vis_degrid[:, 0, 0].imag), label="Error" + ) plt.legend() plt.xlabel("sample") plt.ylabel("Imag of predicted") - plt.savefig(tmp_path_factory.mktemp("degrid_dft") / - "degrid_vs_dft_im.png") + plt.savefig(tmp_path_factory.mktemp("degrid_dft") / "degrid_vs_dft_im.png") - assert np.percentile( - np.abs(vis_dft[:, 0, 0].real - vis_degrid[:, 0, 0].real), - 99.0) < 0.05 - assert np.percentile( - np.abs(vis_dft[:, 0, 0].imag - vis_degrid[:, 0, 0].imag), - 99.0) < 0.05 + assert ( + np.percentile(np.abs(vis_dft[:, 0, 0].real - vis_degrid[:, 0, 0].real), 99.0) + < 0.05 + ) + assert ( + np.percentile(np.abs(vis_dft[:, 0, 0].imag - vis_degrid[:, 0, 0].imag), 99.0) + < 0.05 + ) def test_degrid_dft_packed(tmp_path_factory): # construct kernel W = 5 OS = 3 - kern = kernels.pack_kernel(kernels.kbsinc(W, oversample=OS), - W, - oversample=OS) + kern = kernels.pack_kernel(kernels.kbsinc(W, oversample=OS), W, oversample=OS) uvw = np.column_stack( - (5000.0 * np.cos(np.linspace(0, 2 * np.pi, 1000)), - 5000.0 * np.sin(np.linspace(0, 2 * np.pi, 1000)), np.zeros(1000))) + ( + 5000.0 * np.cos(np.linspace(0, 2 * np.pi, 1000)), + 5000.0 * np.sin(np.linspace(0, 2 * np.pi, 1000)), + np.zeros(1000), + ) + ) pxacrossbeam = 10 frequency = np.array([1.4e9]) wavelength = lightspeed / frequency cell = np.rad2deg( - wavelength[0] / - (2 * max(np.max(np.abs(uvw[:, 0])), np.max(np.abs(uvw[:, 1]))) * - pxacrossbeam)) + wavelength[0] + / (2 * max(np.max(np.abs(uvw[:, 0])), np.max(np.abs(uvw[:, 1]))) * pxacrossbeam) + ) npix = 512 mod = np.zeros((1, npix, npix), dtype=np.complex64) mod[0, npix // 2 - 5, npix // 2 - 5] = 1.0 - ftmod = np.fft.ifftshift(np.fft.fft2(np.fft.fftshift( - mod[0, :, :]))).reshape((1, npix, npix)) + ftmod = np.fft.ifftshift(np.fft.fft2(np.fft.fftshift(mod[0, :, :]))).reshape( + (1, npix, npix) + ) chanmap = np.array([0]) vis_degrid = degridder.degridder( uvw, @@ -235,15 +317,18 @@ def test_degrid_dft_packed(tmp_path_factory): "None", # no faceting "None", # no faceting "XXYY_FROM_I", - "conv_1d_axisymmetric_packed_gather") + "conv_1d_axisymmetric_packed_gather", + ) dec, ra = np.meshgrid( np.arange(-npix // 2, npix // 2) * np.deg2rad(cell), - np.arange(-npix // 2, npix // 2) * np.deg2rad(cell)) + np.arange(-npix // 2, npix // 2) * np.deg2rad(cell), + ) radec = np.column_stack((ra.flatten(), dec.flatten())) - vis_dft = im_to_vis(mod[0, :, :].reshape(1, 1, npix * npix).T.copy(), - uvw, radec, frequency) + vis_dft = im_to_vis( + mod[0, :, :].reshape(1, 1, npix * npix).T.copy(), uvw, radec, frequency + ) try: import matplotlib @@ -252,34 +337,42 @@ def test_degrid_dft_packed(tmp_path_factory): else: matplotlib.use("agg") from matplotlib import pyplot as plt + plt.figure() plt.plot(vis_degrid[:, 0, 0].real, label=r"$\Re(\mathtt{degrid})$") plt.plot(vis_dft[:, 0, 0].real, label=r"$\Re(\mathtt{dft})$") - plt.plot(np.abs(vis_dft[:, 0, 0].real - vis_degrid[:, 0, 0].real), - label="Error") + plt.plot( + np.abs(vis_dft[:, 0, 0].real - vis_degrid[:, 0, 0].real), label="Error" + ) plt.legend() plt.xlabel("sample") plt.ylabel("Real of predicted") plt.savefig( - os.path.join(os.environ.get("TMPDIR", "/tmp"), - "degrid_vs_dft_re_packed.png")) + os.path.join( + os.environ.get("TMPDIR", "/tmp"), "degrid_vs_dft_re_packed.png" + ) + ) plt.figure() plt.plot(vis_degrid[:, 0, 0].imag, label=r"$\Im(\mathtt{degrid})$") plt.plot(vis_dft[:, 0, 0].imag, label=r"$\Im(\mathtt{dft})$") - plt.plot(np.abs(vis_dft[:, 0, 0].imag - vis_degrid[:, 0, 0].imag), - label="Error") + plt.plot( + np.abs(vis_dft[:, 0, 0].imag - vis_degrid[:, 0, 0].imag), label="Error" + ) plt.legend() plt.xlabel("sample") plt.ylabel("Imag of predicted") - plt.savefig(tmp_path_factory.mktemp("degrid_dft_packed") / - "degrid_vs_dft_im_packed.png") + plt.savefig( + tmp_path_factory.mktemp("degrid_dft_packed") / "degrid_vs_dft_im_packed.png" + ) - assert np.percentile( - np.abs(vis_dft[:, 0, 0].real - vis_degrid[:, 0, 0].real), - 99.0) < 0.05 - assert np.percentile( - np.abs(vis_dft[:, 0, 0].imag - vis_degrid[:, 0, 0].imag), - 99.0) < 0.05 + assert ( + np.percentile(np.abs(vis_dft[:, 0, 0].real - vis_degrid[:, 0, 0].real), 99.0) + < 0.05 + ) + assert ( + np.percentile(np.abs(vis_dft[:, 0, 0].imag - vis_degrid[:, 0, 0].imag), 99.0) + < 0.05 + ) def test_detaper(tmp_path_factory): @@ -298,6 +391,7 @@ def test_detaper(tmp_path_factory): else: matplotlib.use("agg") from matplotlib import pyplot as plt + plt.figure() plt.subplot(131) plt.title("FFT detaper") @@ -313,8 +407,8 @@ def test_detaper(tmp_path_factory): plt.colorbar() plt.savefig(tmp_path_factory.mktemp("detaper") / "detaper.png") - assert (np.percentile(np.abs(detaper - detaperdft), 99.0) < 1.0e-14) - assert (np.max(np.abs(detaperdft - detaperdftsep)) < 1.0e-14) + assert np.percentile(np.abs(detaper - detaperdft), 99.0) < 1.0e-14 + assert np.max(np.abs(detaperdft - detaperdftsep)) < 1.0e-14 def test_grid_dft(tmp_path_factory): @@ -332,9 +426,9 @@ def test_grid_dft(tmp_path_factory): wavelength = lightspeed / frequency cell = np.rad2deg( - wavelength[0] / - (2 * max(np.max(np.abs(uvw[:, 0])), np.max(np.abs(uvw[:, 1]))) * - pxacrossbeam)) + wavelength[0] + / (2 * max(np.max(np.abs(uvw[:, 0])), np.max(np.abs(uvw[:, 1]))) * pxacrossbeam) + ) npix = 256 fftpad = 1.25 mod = np.zeros((1, npix, npix), dtype=np.complex64) @@ -350,16 +444,20 @@ def test_grid_dft(tmp_path_factory): dec, ra = np.meshgrid( np.arange(-npix // 2, npix // 2) * np.deg2rad(cell), - np.arange(-npix // 2, npix // 2) * np.deg2rad(cell)) + np.arange(-npix // 2, npix // 2) * np.deg2rad(cell), + ) radec = np.column_stack((ra.flatten(), dec.flatten())) - vis_dft = im_to_vis(mod[0, :, :].reshape(1, 1, - npix * npix).T.copy(), uvw, - radec, frequency).repeat(2).reshape(nrow, 1, 2) + vis_dft = ( + im_to_vis( + mod[0, :, :].reshape(1, 1, npix * npix).T.copy(), uvw, radec, frequency + ) + .repeat(2) + .reshape(nrow, 1, 2) + ) chanmap = np.array([0]) - detaper = kernels.compute_detaper(int(npix * fftpad), - np.outer(kern, kern), W, OS) + detaper = kernels.compute_detaper(int(npix * fftpad), np.outer(kern, kern), W, OS) vis_grid = gridder.gridder( uvw, vis_dft, @@ -376,21 +474,35 @@ def test_grid_dft(tmp_path_factory): "None", # no faceting "I_FROM_XXYY", "conv_1d_axisymmetric_unpacked_scatter", - do_normalize=True) - - ftvis = (np.fft.fftshift( - np.fft.ifft2(np.fft.ifftshift(vis_grid[0, :, :]))).reshape( - (1, int(npix * fftpad), int( - npix * fftpad)))).real / detaper * int(npix * fftpad)**2 - ftvis = ftvis[:, - int(npix * fftpad) // 2 - - npix // 2:int(npix * fftpad) // 2 - npix // 2 + npix, - int(npix * fftpad) // 2 - - npix // 2:int(npix * fftpad) // 2 - npix // 2 + npix] - dftvis = vis_to_im(vis_dft, uvw, radec, frequency, - np.zeros(vis_dft.shape, - dtype=np.bool_)).T.copy().reshape( - 2, 1, npix, npix) / nrow + do_normalize=True, + ) + + ftvis = ( + ( + np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(vis_grid[0, :, :]))).reshape( + (1, int(npix * fftpad), int(npix * fftpad)) + ) + ).real + / detaper + * int(npix * fftpad) ** 2 + ) + ftvis = ftvis[ + :, + int(npix * fftpad) // 2 - npix // 2 : int(npix * fftpad) // 2 + - npix // 2 + + npix, + int(npix * fftpad) // 2 - npix // 2 : int(npix * fftpad) // 2 + - npix // 2 + + npix, + ] + dftvis = ( + vis_to_im( + vis_dft, uvw, radec, frequency, np.zeros(vis_dft.shape, dtype=np.bool_) + ) + .T.copy() + .reshape(2, 1, npix, npix) + / nrow + ) try: import matplotlib @@ -399,6 +511,7 @@ def test_grid_dft(tmp_path_factory): else: matplotlib.use("agg") from matplotlib import pyplot as plt + plt.figure() plt.subplot(131) plt.title("FFT") @@ -412,11 +525,9 @@ def test_grid_dft(tmp_path_factory): plt.title("ABS diff") plt.imshow(np.abs(ftvis[0, :, :] - dftvis[0, 0, :, :])) plt.colorbar() - plt.savefig(tmp_path_factory.mktemp("grid_dft") / - "grid_diff_dft.png") + plt.savefig(tmp_path_factory.mktemp("grid_dft") / "grid_diff_dft.png") - assert (np.percentile(np.abs(ftvis[0, :, :] - dftvis[0, 0, :, :]), - 95.0) < 0.15) + assert np.percentile(np.abs(ftvis[0, :, :] - dftvis[0, 0, :, :]), 95.0) < 0.15 def test_grid_dft_packed(tmp_path_factory): @@ -434,9 +545,9 @@ def test_grid_dft_packed(tmp_path_factory): wavelength = lightspeed / frequency cell = np.rad2deg( - wavelength[0] / - (2 * max(np.max(np.abs(uvw[:, 0])), np.max(np.abs(uvw[:, 1]))) * - pxacrossbeam)) + wavelength[0] + / (2 * max(np.max(np.abs(uvw[:, 0])), np.max(np.abs(uvw[:, 1]))) * pxacrossbeam) + ) npix = 256 fftpad = 1.25 mod = np.zeros((1, npix, npix), dtype=np.complex64) @@ -452,15 +563,21 @@ def test_grid_dft_packed(tmp_path_factory): dec, ra = np.meshgrid( np.arange(-npix // 2, npix // 2) * np.deg2rad(cell), - np.arange(-npix // 2, npix // 2) * np.deg2rad(cell)) + np.arange(-npix // 2, npix // 2) * np.deg2rad(cell), + ) radec = np.column_stack((ra.flatten(), dec.flatten())) - vis_dft = im_to_vis(mod[0, :, :].reshape(1, 1, - npix * npix).T.copy(), uvw, - radec, frequency).repeat(2).reshape(nrow, 1, 2) + vis_dft = ( + im_to_vis( + mod[0, :, :].reshape(1, 1, npix * npix).T.copy(), uvw, radec, frequency + ) + .repeat(2) + .reshape(nrow, 1, 2) + ) chanmap = np.array([0]) detaper = kernels.compute_detaper_dft_seperable( - int(npix * fftpad), kernels.unpack_kernel(kern, W, OS), W, OS) + int(npix * fftpad), kernels.unpack_kernel(kern, W, OS), W, OS + ) vis_grid = gridder.gridder( uvw, vis_dft, @@ -477,21 +594,35 @@ def test_grid_dft_packed(tmp_path_factory): "None", # no faceting "I_FROM_XXYY", "conv_1d_axisymmetric_packed_scatter", - do_normalize=True) - - ftvis = (np.fft.fftshift( - np.fft.ifft2(np.fft.ifftshift(vis_grid[0, :, :]))).reshape( - (1, int(npix * fftpad), int( - npix * fftpad)))).real / detaper * int(npix * fftpad)**2 - ftvis = ftvis[:, - int(npix * fftpad) // 2 - - npix // 2:int(npix * fftpad) // 2 - npix // 2 + npix, - int(npix * fftpad) // 2 - - npix // 2:int(npix * fftpad) // 2 - npix // 2 + npix] - dftvis = vis_to_im(vis_dft, uvw, radec, frequency, - np.zeros(vis_dft.shape, - dtype=np.bool_)).T.copy().reshape( - 2, 1, npix, npix) / nrow + do_normalize=True, + ) + + ftvis = ( + ( + np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(vis_grid[0, :, :]))).reshape( + (1, int(npix * fftpad), int(npix * fftpad)) + ) + ).real + / detaper + * int(npix * fftpad) ** 2 + ) + ftvis = ftvis[ + :, + int(npix * fftpad) // 2 - npix // 2 : int(npix * fftpad) // 2 + - npix // 2 + + npix, + int(npix * fftpad) // 2 - npix // 2 : int(npix * fftpad) // 2 + - npix // 2 + + npix, + ] + dftvis = ( + vis_to_im( + vis_dft, uvw, radec, frequency, np.zeros(vis_dft.shape, dtype=np.bool_) + ) + .T.copy() + .reshape(2, 1, npix, npix) + / nrow + ) try: import matplotlib @@ -500,6 +631,7 @@ def test_grid_dft_packed(tmp_path_factory): else: matplotlib.use("agg") from matplotlib import pyplot as plt + plt.figure() plt.subplot(131) plt.title("FFT") @@ -513,11 +645,11 @@ def test_grid_dft_packed(tmp_path_factory): plt.title("ABS diff") plt.imshow(np.abs(ftvis[0, :, :] - dftvis[0, 0, :, :])) plt.colorbar() - plt.savefig(tmp_path_factory.mktemp("grid_dft_packed") / - "grid_diff_dft_packed.png") + plt.savefig( + tmp_path_factory.mktemp("grid_dft_packed") / "grid_diff_dft_packed.png" + ) - assert (np.percentile(np.abs(ftvis[0, :, :] - dftvis[0, 0, :, :]), - 95.0) < 0.15) + assert np.percentile(np.abs(ftvis[0, :, :] - dftvis[0, 0, :, :]), 95.0) < 0.15 def test_wcorrection_faceting_backward(tmp_path_factory): @@ -533,15 +665,16 @@ def test_wcorrection_faceting_backward(tmp_path_factory): ntime = int(nrow / 25.0) d0 = np.pi / 4.0 for n in range(25): - for ih0, h0 in enumerate( - np.linspace(np.deg2rad(-20), np.deg2rad(20), ntime)): + for ih0, h0 in enumerate(np.linspace(np.deg2rad(-20), np.deg2rad(20), ntime)): s = np.sin c = np.cos - R = np.array([[s(h0), c(h0), 0], - [-s(d0) * c(h0), - s(d0) * s(h0), - c(d0)], [c(d0) * c(h0), -c(d0) * s(h0), - s(d0)]]) + R = np.array( + [ + [s(h0), c(h0), 0], + [-s(d0) * c(h0), s(d0) * s(h0), c(d0)], + [c(d0) * c(h0), -c(d0) * s(h0), s(d0)], + ] + ) uvw[n * ntime + ih0, :] = np.dot(R, blpos[n, :].T) pxacrossbeam = 5 @@ -549,24 +682,22 @@ def test_wcorrection_faceting_backward(tmp_path_factory): wavelength = lightspeed / frequency cell = np.rad2deg( - wavelength[0] / - (max(np.max(np.abs(uvw[:, 0])), np.max(np.abs(uvw[:, 1]))) * - pxacrossbeam)) + wavelength[0] + / (max(np.max(np.abs(uvw[:, 0])), np.max(np.abs(uvw[:, 1]))) * pxacrossbeam) + ) npix = 2048 npixfacet = 100 fftpad = 1.1 mod = np.ones((1, 1, 1), dtype=np.complex64) - deltaradec = np.array( - [[600 * np.deg2rad(cell), 600 * np.deg2rad(cell)]]) - lm = radec_to_lmn(deltaradec + np.array([[0, d0]]), - phase_centre=np.array([0, d0])) + deltaradec = np.array([[600 * np.deg2rad(cell), 600 * np.deg2rad(cell)]]) + lm = radec_to_lmn(deltaradec + np.array([[0, d0]]), phase_centre=np.array([0, d0])) - vis_dft = im_to_vis(mod, uvw, lm[:, 0:2], - frequency).repeat(2).reshape(nrow, 1, 2) + vis_dft = im_to_vis(mod, uvw, lm[:, 0:2], frequency).repeat(2).reshape(nrow, 1, 2) chanmap = np.array([0]) detaper = kernels.compute_detaper_dft_seperable( - int(npix * fftpad), kernels.unpack_kernel(kern, W, OS), W, OS) + int(npix * fftpad), kernels.unpack_kernel(kern, W, OS), W, OS + ) vis_grid_nofacet = gridder.gridder( uvw, vis_dft, @@ -583,26 +714,39 @@ def test_wcorrection_faceting_backward(tmp_path_factory): "None", # no faceting "I_FROM_XXYY", "conv_1d_axisymmetric_packed_scatter", - do_normalize=True) - ftvis = (np.fft.fftshift( - np.fft.ifft2(np.fft.ifftshift(vis_grid_nofacet[0, :, :]))).reshape( - (1, int(npix * fftpad), int( - npix * fftpad)))).real / detaper * int(npix * fftpad)**2 - ftvis = ftvis[:, - int(npix * fftpad) // 2 - - npix // 2:int(npix * fftpad) // 2 - npix // 2 + npix, - int(npix * fftpad) // 2 - - npix // 2:int(npix * fftpad) // 2 - npix // 2 + npix] + do_normalize=True, + ) + ftvis = ( + ( + np.fft.fftshift( + np.fft.ifft2(np.fft.ifftshift(vis_grid_nofacet[0, :, :])) + ).reshape((1, int(npix * fftpad), int(npix * fftpad))) + ).real + / detaper + * int(npix * fftpad) ** 2 + ) + ftvis = ftvis[ + :, + int(npix * fftpad) // 2 - npix // 2 : int(npix * fftpad) // 2 + - npix // 2 + + npix, + int(npix * fftpad) // 2 - npix // 2 : int(npix * fftpad) // 2 + - npix // 2 + + npix, + ] detaper_facet = kernels.compute_detaper_dft_seperable( - int(npixfacet * fftpad), kernels.unpack_kernel(kern, W, OS), W, OS) + int(npixfacet * fftpad), kernels.unpack_kernel(kern, W, OS), W, OS + ) vis_grid_facet = gridder.gridder( uvw, vis_dft, wavelength, chanmap, int(npixfacet * fftpad), - cell * 3600.0, (deltaradec + np.array([[0, d0]]))[0, :], (0, d0), + cell * 3600.0, + (deltaradec + np.array([[0, d0]]))[0, :], + (0, d0), kern, W, OS, @@ -610,19 +754,26 @@ def test_wcorrection_faceting_backward(tmp_path_factory): "phase_rotate", "I_FROM_XXYY", "conv_1d_axisymmetric_packed_scatter", - do_normalize=True) - ftvisfacet = (np.fft.fftshift( - np.fft.ifft2(np.fft.ifftshift(vis_grid_facet[0, :, :]))).reshape( - (1, int(npixfacet * fftpad), int( - npixfacet * fftpad)))).real / detaper_facet * int( - npixfacet * fftpad)**2 - ftvisfacet = ftvisfacet[:, - int(npixfacet * fftpad) // 2 - - npixfacet // 2:int(npixfacet * fftpad) // 2 - - npixfacet // 2 + npixfacet, - int(npixfacet * fftpad) // 2 - - npixfacet // 2:int(npixfacet * fftpad) // 2 - - npixfacet // 2 + npixfacet] + do_normalize=True, + ) + ftvisfacet = ( + ( + np.fft.fftshift( + np.fft.ifft2(np.fft.ifftshift(vis_grid_facet[0, :, :])) + ).reshape((1, int(npixfacet * fftpad), int(npixfacet * fftpad))) + ).real + / detaper_facet + * int(npixfacet * fftpad) ** 2 + ) + ftvisfacet = ftvisfacet[ + :, + int(npixfacet * fftpad) // 2 - npixfacet // 2 : int(npixfacet * fftpad) // 2 + - npixfacet // 2 + + npixfacet, + int(npixfacet * fftpad) // 2 - npixfacet // 2 : int(npixfacet * fftpad) // 2 + - npixfacet // 2 + + npixfacet, + ] try: import matplotlib @@ -631,11 +782,12 @@ def test_wcorrection_faceting_backward(tmp_path_factory): else: matplotlib.use("agg") from matplotlib import pyplot as plt + plot_dir = tmp_path_factory.mktemp("wcorrection_backward") plt.figure() plt.subplot(121) - plt.imshow(ftvis[0, 1624 - 50:1624 + 50, 1447 - 50:1447 + 50]) + plt.imshow(ftvis[0, 1624 - 50 : 1624 + 50, 1447 - 50 : 1447 + 50]) plt.colorbar() plt.title("Offset FFT (peak={0:.1f})".format(np.max(ftvis))) plt.subplot(122) @@ -644,7 +796,7 @@ def test_wcorrection_faceting_backward(tmp_path_factory): plt.title("Faceted FFT (peak={0:.1f})".format(np.max(ftvisfacet))) plt.savefig(plot_dir / "facet_imaging.png") - assert (np.abs(np.max(ftvisfacet[0, :, :]) - 1.0) < 1.0e-6) + assert np.abs(np.max(ftvisfacet[0, :, :]) - 1.0) < 1.0e-6 def test_wcorrection_faceting_forward(tmp_path_factory): @@ -660,15 +812,16 @@ def test_wcorrection_faceting_forward(tmp_path_factory): ntime = int(nrow / 25.0) d0 = np.pi / 4.0 for n in range(25): - for ih0, h0 in enumerate( - np.linspace(np.deg2rad(-20), np.deg2rad(20), ntime)): + for ih0, h0 in enumerate(np.linspace(np.deg2rad(-20), np.deg2rad(20), ntime)): s = np.sin c = np.cos - R = np.array([[s(h0), c(h0), 0], - [-s(d0) * c(h0), - s(d0) * s(h0), - c(d0)], [c(d0) * c(h0), -c(d0) * s(h0), - s(d0)]]) + R = np.array( + [ + [s(h0), c(h0), 0], + [-s(d0) * c(h0), s(d0) * s(h0), c(d0)], + [c(d0) * c(h0), -c(d0) * s(h0), s(d0)], + ] + ) uvw[n * ntime + ih0, :] = np.dot(R, blpos[n, :].T) pxacrossbeam = 5 @@ -676,20 +829,19 @@ def test_wcorrection_faceting_forward(tmp_path_factory): wavelength = lightspeed / frequency cell = np.rad2deg( - wavelength[0] / - (max(np.max(np.abs(uvw[:, 0])), np.max(np.abs(uvw[:, 1]))) * - pxacrossbeam)) + wavelength[0] + / (max(np.max(np.abs(uvw[:, 0])), np.max(np.abs(uvw[:, 1]))) * pxacrossbeam) + ) npixfacet = 100 mod = np.ones((1, 1, 1), dtype=np.complex64) deltaradec = np.array([[20 * np.deg2rad(cell), 20 * np.deg2rad(cell)]]) - lm = radec_to_lmn(deltaradec + np.array([[0, d0]]), - phase_centre=np.array([0, d0])) + lm = radec_to_lmn(deltaradec + np.array([[0, d0]]), phase_centre=np.array([0, d0])) - vis_dft = im_to_vis(mod, uvw, lm[:, 0:2], - frequency).repeat(2).reshape(nrow, 1, 2) + vis_dft = im_to_vis(mod, uvw, lm[:, 0:2], frequency).repeat(2).reshape(nrow, 1, 2) chanmap = np.array([0]) - ftmod = np.ones((1, npixfacet, npixfacet), - dtype=np.complex64) # point source at centre of facet + ftmod = np.ones( + (1, npixfacet, npixfacet), dtype=np.complex64 + ) # point source at centre of facet vis_degrid = degridder.degridder( uvw, ftmod, @@ -704,7 +856,8 @@ def test_wcorrection_faceting_forward(tmp_path_factory): "rotate", # no faceting "phase_rotate", # no faceting "XXYY_FROM_I", - "conv_1d_axisymmetric_packed_gather") + "conv_1d_axisymmetric_packed_gather", + ) try: import matplotlib @@ -713,32 +866,35 @@ def test_wcorrection_faceting_forward(tmp_path_factory): else: matplotlib.use("agg") from matplotlib import pyplot as plt + plot_dir = tmp_path_factory.mktemp("wcorrection_forward") plt.figure() - plt.plot(vis_degrid[:, 0, 0].real, - label=r"$\Re(\mathtt{degrid facet})$") + plt.plot(vis_degrid[:, 0, 0].real, label=r"$\Re(\mathtt{degrid facet})$") plt.plot(vis_dft[:, 0, 0].real, label=r"$\Re(\mathtt{dft})$") - plt.plot(np.abs(vis_dft[:, 0, 0].real - vis_degrid[:, 0, 0].real), - label="Error") + plt.plot( + np.abs(vis_dft[:, 0, 0].real - vis_degrid[:, 0, 0].real), label="Error" + ) plt.legend() plt.xlabel("sample") plt.ylabel("Real of predicted") plt.savefig(plot_dir / "facet_degrid_vs_dft_re_packed.png") plt.figure() - plt.plot(vis_degrid[:, 0, 0].imag, - label=r"$\Im(\mathtt{degrid facet})$") + plt.plot(vis_degrid[:, 0, 0].imag, label=r"$\Im(\mathtt{degrid facet})$") plt.plot(vis_dft[:, 0, 0].imag, label=r"$\Im(\mathtt{dft})$") - plt.plot(np.abs(vis_dft[:, 0, 0].imag - vis_degrid[:, 0, 0].imag), - label="Error") + plt.plot( + np.abs(vis_dft[:, 0, 0].imag - vis_degrid[:, 0, 0].imag), label="Error" + ) plt.legend() plt.xlabel("sample") plt.ylabel("Imag of predicted") plt.savefig(plot_dir / "facet_degrid_vs_dft_im_packed.png") - assert np.percentile( - np.abs(vis_dft[:, 0, 0].real - vis_degrid[:, 0, 0].real), - 99.0) < 0.05 - assert np.percentile( - np.abs(vis_dft[:, 0, 0].imag - vis_degrid[:, 0, 0].imag), - 99.0) < 0.05 + assert ( + np.percentile(np.abs(vis_dft[:, 0, 0].real - vis_degrid[:, 0, 0].real), 99.0) + < 0.05 + ) + assert ( + np.percentile(np.abs(vis_dft[:, 0, 0].imag - vis_degrid[:, 0, 0].imag), 99.0) + < 0.05 + ) diff --git a/africanus/gridding/util.py b/africanus/gridding/util.py index 9a71d43af..40ca9d7e9 100644 --- a/africanus/gridding/util.py +++ b/africanus/gridding/util.py @@ -1,4 +1,3 @@ - import numpy as np @@ -88,13 +87,15 @@ def estimate_cell_size(u, v, wavelength, factor=3.0, ny=None, nx=None): u_cell_size = 1.0 / (2.0 * factor * umax) v_cell_size = 1.0 / (2.0 * factor * vmax) - if ny is not None and u_cell_size*ny < (1.0 / umin): - raise ValueError("v_cell_size*ny [%f] < (1.0 / umin) [%f]" % - (u_cell_size*ny, 1.0 / umin)) + if ny is not None and u_cell_size * ny < (1.0 / umin): + raise ValueError( + "v_cell_size*ny [%f] < (1.0 / umin) [%f]" % (u_cell_size * ny, 1.0 / umin) + ) - if nx is not None and v_cell_size*nx < (1.0 / vmin): - raise ValueError("v_cell_size*nx [%f] < (1.0 / vmin) [%f]" % - (v_cell_size*nx, 1.0 / vmin)) + if nx is not None and v_cell_size * nx < (1.0 / vmin): + raise ValueError( + "v_cell_size*nx [%f] < (1.0 / vmin) [%f]" % (v_cell_size * nx, 1.0 / vmin) + ) # Convert radians to arcseconds - return np.rad2deg([u_cell_size, v_cell_size])*(60*60) + return np.rad2deg([u_cell_size, v_cell_size]) * (60 * 60) diff --git a/africanus/gridding/wgridder/dask.py b/africanus/gridding/wgridder/dask.py index c10ffa891..29fac3dd8 100644 --- a/africanus/gridding/wgridder/dask.py +++ b/africanus/gridding/wgridder/dask.py @@ -14,25 +14,56 @@ from africanus.gridding.wgridder.hessian import HESSIAN_DOCS from africanus.gridding.wgridder.im2vis import _model_internal as model_np from africanus.gridding.wgridder.vis2im import _dirty_internal as dirty_np -from africanus.gridding.wgridder.im2residim import (_residual_internal - as residual_np) -from africanus.gridding.wgridder.hessian import (_hessian_internal - as hessian_np) +from africanus.gridding.wgridder.im2residim import _residual_internal as residual_np +from africanus.gridding.wgridder.hessian import _hessian_internal as hessian_np from africanus.util.requirements import requires_optional -def _model_wrapper(uvw, freq, model, freq_bin_idx, freq_bin_counts, cell, - weights, flag, celly, epsilon, nthreads, do_wstacking): - - return model_np(uvw[0], freq, model[0][0], freq_bin_idx, freq_bin_counts, - cell, weights, flag, celly, epsilon, nthreads, - do_wstacking) - - -@requires_optional('dask.array', dask_import_error) -def model(uvw, freq, image, freq_bin_idx, freq_bin_counts, cell, - weights=None, flag=None, celly=None, epsilon=1e-5, nthreads=1, - do_wstacking=True): +def _model_wrapper( + uvw, + freq, + model, + freq_bin_idx, + freq_bin_counts, + cell, + weights, + flag, + celly, + epsilon, + nthreads, + do_wstacking, +): + return model_np( + uvw[0], + freq, + model[0][0], + freq_bin_idx, + freq_bin_counts, + cell, + weights, + flag, + celly, + epsilon, + nthreads, + do_wstacking, + ) + + +@requires_optional("dask.array", dask_import_error) +def model( + uvw, + freq, + image, + freq_bin_idx, + freq_bin_counts, + cell, + weights=None, + flag=None, + celly=None, + epsilon=1e-5, + nthreads=1, + do_wstacking=True, +): # determine output type complex_type = da.result_type(image, np.complex64) @@ -41,51 +72,107 @@ def model(uvw, freq, image, freq_bin_idx, freq_bin_counts, cell, if not nthreads: import multiprocessing + nthreads = multiprocessing.cpu_count() if weights is None: weight_out = None else: - weight_out = ('row', 'chan') + weight_out = ("row", "chan") if flag is None: flag_out = None else: - flag_out = ('row', 'chan') - - vis = da.blockwise(_model_wrapper, ('row', 'chan'), - uvw, ('row', 'three'), - freq, ('chan',), - image, ('chan', 'nx', 'ny'), - freq_bin_idx, ('chan',), - freq_bin_counts, ('chan',), - cell, None, - weights, weight_out, - flag, flag_out, - celly, None, - epsilon, None, - nthreads, None, - do_wstacking, None, - adjust_chunks={'chan': freq.chunks[0]}, - dtype=complex_type, - align_arrays=False) + flag_out = ("row", "chan") + + vis = da.blockwise( + _model_wrapper, + ("row", "chan"), + uvw, + ("row", "three"), + freq, + ("chan",), + image, + ("chan", "nx", "ny"), + freq_bin_idx, + ("chan",), + freq_bin_counts, + ("chan",), + cell, + None, + weights, + weight_out, + flag, + flag_out, + celly, + None, + epsilon, + None, + nthreads, + None, + do_wstacking, + None, + adjust_chunks={"chan": freq.chunks[0]}, + dtype=complex_type, + align_arrays=False, + ) return vis -def _dirty_wrapper(uvw, freq, vis, freq_bin_idx, freq_bin_counts, nx, ny, - cell, weights, flag, celly, epsilon, nthreads, - do_wstacking, double_accum): - - return dirty_np(uvw[0], freq, vis, freq_bin_idx, freq_bin_counts, - nx, ny, cell, weights, flag, celly, epsilon, - nthreads, do_wstacking, double_accum) - - -@requires_optional('dask.array', dask_import_error) -def dirty(uvw, freq, vis, freq_bin_idx, freq_bin_counts, nx, ny, cell, - weights=None, flag=None, celly=None, epsilon=1e-5, nthreads=1, - do_wstacking=True, double_accum=False): - +def _dirty_wrapper( + uvw, + freq, + vis, + freq_bin_idx, + freq_bin_counts, + nx, + ny, + cell, + weights, + flag, + celly, + epsilon, + nthreads, + do_wstacking, + double_accum, +): + return dirty_np( + uvw[0], + freq, + vis, + freq_bin_idx, + freq_bin_counts, + nx, + ny, + cell, + weights, + flag, + celly, + epsilon, + nthreads, + do_wstacking, + double_accum, + ) + + +@requires_optional("dask.array", dask_import_error) +def dirty( + uvw, + freq, + vis, + freq_bin_idx, + freq_bin_counts, + nx, + ny, + cell, + weights=None, + flag=None, + celly=None, + epsilon=1e-5, + nthreads=1, + do_wstacking=True, + double_accum=False, +): # get real data type (not available from inputs) if vis.dtype == np.complex128: real_type = np.float64 @@ -97,153 +184,280 @@ def dirty(uvw, freq, vis, freq_bin_idx, freq_bin_counts, nx, ny, cell, if not nthreads: import multiprocessing + nthreads = multiprocessing.cpu_count() if weights is None: weight_out = None else: - weight_out = ('row', 'chan') + weight_out = ("row", "chan") if flag is None: flag_out = None else: - flag_out = ('row', 'chan') - - img = da.blockwise(_dirty_wrapper, ('row', 'chan', 'nx', 'ny'), - uvw, ('row', 'three'), - freq, ('chan',), - vis, ('row', 'chan'), - freq_bin_idx, ('chan',), - freq_bin_counts, ('chan',), - nx, None, - ny, None, - cell, None, - weights, weight_out, - flag, flag_out, - celly, None, - epsilon, None, - nthreads, None, - do_wstacking, None, - double_accum, None, - adjust_chunks={'chan': freq_bin_idx.chunks[0], - 'row': (1,)*len(vis.chunks[0])}, - new_axes={"nx": nx, "ny": ny}, - dtype=real_type, - align_arrays=False) + flag_out = ("row", "chan") + + img = da.blockwise( + _dirty_wrapper, + ("row", "chan", "nx", "ny"), + uvw, + ("row", "three"), + freq, + ("chan",), + vis, + ("row", "chan"), + freq_bin_idx, + ("chan",), + freq_bin_counts, + ("chan",), + nx, + None, + ny, + None, + cell, + None, + weights, + weight_out, + flag, + flag_out, + celly, + None, + epsilon, + None, + nthreads, + None, + do_wstacking, + None, + double_accum, + None, + adjust_chunks={ + "chan": freq_bin_idx.chunks[0], + "row": (1,) * len(vis.chunks[0]), + }, + new_axes={"nx": nx, "ny": ny}, + dtype=real_type, + align_arrays=False, + ) return img.sum(axis=0) -def _residual_wrapper(uvw, freq, model, vis, freq_bin_idx, freq_bin_counts, - cell, weights, flag, celly, epsilon, nthreads, - do_wstacking, double_accum): - - return residual_np(uvw[0], freq, model, vis, freq_bin_idx, - freq_bin_counts, cell, weights, flag, celly, epsilon, - nthreads, do_wstacking, double_accum) - - -@requires_optional('dask.array', dask_import_error) -def residual(uvw, freq, image, vis, freq_bin_idx, freq_bin_counts, cell, - weights=None, flag=None, celly=None, epsilon=1e-5, - nthreads=1, do_wstacking=True, double_accum=False): - +def _residual_wrapper( + uvw, + freq, + model, + vis, + freq_bin_idx, + freq_bin_counts, + cell, + weights, + flag, + celly, + epsilon, + nthreads, + do_wstacking, + double_accum, +): + return residual_np( + uvw[0], + freq, + model, + vis, + freq_bin_idx, + freq_bin_counts, + cell, + weights, + flag, + celly, + epsilon, + nthreads, + do_wstacking, + double_accum, + ) + + +@requires_optional("dask.array", dask_import_error) +def residual( + uvw, + freq, + image, + vis, + freq_bin_idx, + freq_bin_counts, + cell, + weights=None, + flag=None, + celly=None, + epsilon=1e-5, + nthreads=1, + do_wstacking=True, + double_accum=False, +): if celly is None: celly = cell if not nthreads: import multiprocessing + nthreads = multiprocessing.cpu_count() if weights is None: weight_out = None else: - weight_out = ('row', 'chan') + weight_out = ("row", "chan") if flag is None: flag_out = None else: - flag_out = ('row', 'chan') - - img = da.blockwise(_residual_wrapper, ('row', 'chan', 'nx', 'ny'), - uvw, ('row', 'three'), - freq, ('chan',), - image, ('chan', 'nx', 'ny'), - vis, ('row', 'chan'), - freq_bin_idx, ('chan',), - freq_bin_counts, ('chan',), - cell, None, - weights, weight_out, - flag, flag_out, - celly, None, - epsilon, None, - nthreads, None, - do_wstacking, None, - double_accum, None, - adjust_chunks={'chan': freq_bin_idx.chunks[0], - 'row': (1,)*len(vis.chunks[0])}, - dtype=image.dtype, - align_arrays=False) + flag_out = ("row", "chan") + + img = da.blockwise( + _residual_wrapper, + ("row", "chan", "nx", "ny"), + uvw, + ("row", "three"), + freq, + ("chan",), + image, + ("chan", "nx", "ny"), + vis, + ("row", "chan"), + freq_bin_idx, + ("chan",), + freq_bin_counts, + ("chan",), + cell, + None, + weights, + weight_out, + flag, + flag_out, + celly, + None, + epsilon, + None, + nthreads, + None, + do_wstacking, + None, + double_accum, + None, + adjust_chunks={ + "chan": freq_bin_idx.chunks[0], + "row": (1,) * len(vis.chunks[0]), + }, + dtype=image.dtype, + align_arrays=False, + ) return img.sum(axis=0) -def _hessian_wrapper(uvw, freq, model, freq_bin_idx, freq_bin_counts, - cell, weights, flag, celly, epsilon, nthreads, - do_wstacking, double_accum): - - return hessian_np(uvw[0], freq, model, freq_bin_idx, - freq_bin_counts, cell, weights, flag, celly, epsilon, - nthreads, do_wstacking, double_accum) - - -@requires_optional('dask.array', dask_import_error) -def hessian(uvw, freq, image, freq_bin_idx, freq_bin_counts, cell, - weights=None, flag=None, celly=None, epsilon=1e-5, - nthreads=1, do_wstacking=True, double_accum=False): - +def _hessian_wrapper( + uvw, + freq, + model, + freq_bin_idx, + freq_bin_counts, + cell, + weights, + flag, + celly, + epsilon, + nthreads, + do_wstacking, + double_accum, +): + return hessian_np( + uvw[0], + freq, + model, + freq_bin_idx, + freq_bin_counts, + cell, + weights, + flag, + celly, + epsilon, + nthreads, + do_wstacking, + double_accum, + ) + + +@requires_optional("dask.array", dask_import_error) +def hessian( + uvw, + freq, + image, + freq_bin_idx, + freq_bin_counts, + cell, + weights=None, + flag=None, + celly=None, + epsilon=1e-5, + nthreads=1, + do_wstacking=True, + double_accum=False, +): if celly is None: celly = cell if not nthreads: import multiprocessing + nthreads = multiprocessing.cpu_count() if weights is None: weight_out = None else: - weight_out = ('row', 'chan') + weight_out = ("row", "chan") if flag is None: flag_out = None else: - flag_out = ('row', 'chan') - - img = da.blockwise(_hessian_wrapper, ('row', 'chan', 'nx', 'ny'), - uvw, ('row', 'three'), - freq, ('chan',), - image, ('chan', 'nx', 'ny'), - freq_bin_idx, ('chan',), - freq_bin_counts, ('chan',), - cell, None, - weights, weight_out, - flag, flag_out, - celly, None, - epsilon, None, - nthreads, None, - do_wstacking, None, - double_accum, None, - adjust_chunks={'chan': freq_bin_idx.chunks[0], - 'row': (1,)*len(uvw.chunks[0])}, - dtype=image.dtype, - align_arrays=False) + flag_out = ("row", "chan") + + img = da.blockwise( + _hessian_wrapper, + ("row", "chan", "nx", "ny"), + uvw, + ("row", "three"), + freq, + ("chan",), + image, + ("chan", "nx", "ny"), + freq_bin_idx, + ("chan",), + freq_bin_counts, + ("chan",), + cell, + None, + weights, + weight_out, + flag, + flag_out, + celly, + None, + epsilon, + None, + nthreads, + None, + do_wstacking, + None, + double_accum, + None, + adjust_chunks={ + "chan": freq_bin_idx.chunks[0], + "row": (1,) * len(uvw.chunks[0]), + }, + dtype=image.dtype, + align_arrays=False, + ) return img.sum(axis=0) -model.__doc__ = MODEL_DOCS.substitute( - array_type=":class:`dask.array.Array`") -dirty.__doc__ = DIRTY_DOCS.substitute( - array_type=":class:`dask.array.Array`") -residual.__doc__ = RESIDUAL_DOCS.substitute( - array_type=":class:`dask.array.Array`") -hessian.__doc__ = HESSIAN_DOCS.substitute( - array_type=":class:`dask.array.Array`") +model.__doc__ = MODEL_DOCS.substitute(array_type=":class:`dask.array.Array`") +dirty.__doc__ = DIRTY_DOCS.substitute(array_type=":class:`dask.array.Array`") +residual.__doc__ = RESIDUAL_DOCS.substitute(array_type=":class:`dask.array.Array`") +hessian.__doc__ = HESSIAN_DOCS.substitute(array_type=":class:`dask.array.Array`") diff --git a/africanus/gridding/wgridder/hessian.py b/africanus/gridding/wgridder/hessian.py index fd367813f..37aebaa8e 100644 --- a/africanus/gridding/wgridder/hessian.py +++ b/africanus/gridding/wgridder/hessian.py @@ -12,11 +12,22 @@ from africanus.util.requirements import requires_optional -@requires_optional('ducc0.wgridder', ducc_import_error) -def _hessian_internal(uvw, freq, image, freq_bin_idx, freq_bin_counts, - cell, weights, flag, celly, epsilon, nthreads, - do_wstacking, double_accum): - +@requires_optional("ducc0.wgridder", ducc_import_error) +def _hessian_internal( + uvw, + freq, + image, + freq_bin_idx, + freq_bin_counts, + cell, + weights, + flag, + celly, + epsilon, + nthreads, + do_wstacking, + double_accum, +): # adjust for chunking # need a copy here if using multiple row chunks freq_bin_idx2 = freq_bin_idx - freq_bin_idx.min() @@ -34,41 +45,81 @@ def _hessian_internal(uvw, freq, image, freq_bin_idx, freq_bin_counts, mask = flag[:, ind] else: mask = None - modelvis = dirty2ms(uvw=uvw, freq=freq[ind], - dirty=image[i], wgt=None, - pixsize_x=cell, pixsize_y=celly, - nu=0, nv=0, epsilon=epsilon, - nthreads=nthreads, mask=mask, - do_wstacking=do_wstacking) + modelvis = dirty2ms( + uvw=uvw, + freq=freq[ind], + dirty=image[i], + wgt=None, + pixsize_x=cell, + pixsize_y=celly, + nu=0, + nv=0, + epsilon=epsilon, + nthreads=nthreads, + mask=mask, + do_wstacking=do_wstacking, + ) convolvedim[0, i] = ms2dirty( - uvw=uvw, freq=freq[ind], ms=modelvis, - wgt=wgt, npix_x=nx, npix_y=ny, - pixsize_x=cell, pixsize_y=celly, - nu=0, nv=0, epsilon=epsilon, - nthreads=nthreads, mask=mask, - do_wstacking=do_wstacking, - double_precision_accumulation=double_accum) + uvw=uvw, + freq=freq[ind], + ms=modelvis, + wgt=wgt, + npix_x=nx, + npix_y=ny, + pixsize_x=cell, + pixsize_y=celly, + nu=0, + nv=0, + epsilon=epsilon, + nthreads=nthreads, + mask=mask, + do_wstacking=do_wstacking, + double_precision_accumulation=double_accum, + ) return convolvedim # This additional wrapper is required to allow the dask wrappers # to chunk over row -@requires_optional('ducc0.wgridder', ducc_import_error) -def hessian(uvw, freq, image, freq_bin_idx, freq_bin_counts, cell, - weights=None, flag=None, celly=None, epsilon=1e-5, nthreads=1, - do_wstacking=True, double_accum=False): - +@requires_optional("ducc0.wgridder", ducc_import_error) +def hessian( + uvw, + freq, + image, + freq_bin_idx, + freq_bin_counts, + cell, + weights=None, + flag=None, + celly=None, + epsilon=1e-5, + nthreads=1, + do_wstacking=True, + double_accum=False, +): if celly is None: celly = cell if not nthreads: import multiprocessing + nthreads = multiprocessing.cpu_count() - residim = _hessian_internal(uvw, freq, image, freq_bin_idx, - freq_bin_counts, cell, weights, flag, - celly, epsilon, nthreads, do_wstacking, - double_accum) + residim = _hessian_internal( + uvw, + freq, + image, + freq_bin_idx, + freq_bin_counts, + cell, + weights, + flag, + celly, + epsilon, + nthreads, + do_wstacking, + double_accum, + ) return residim[0] @@ -139,10 +190,10 @@ def hessian(uvw, freq, image, freq_bin_idx, freq_bin_counts, cell, residual : $(array_type) Residual image corresponding to :code:`model` of shape :code:`(band, nx, ny)`. - """) + """ +) try: - hessian.__doc__ = HESSIAN_DOCS.substitute( - array_type=":class:`numpy.ndarray`") + hessian.__doc__ = HESSIAN_DOCS.substitute(array_type=":class:`numpy.ndarray`") except AttributeError: pass diff --git a/africanus/gridding/wgridder/im2residim.py b/africanus/gridding/wgridder/im2residim.py index bda894572..9038eab9b 100644 --- a/africanus/gridding/wgridder/im2residim.py +++ b/africanus/gridding/wgridder/im2residim.py @@ -12,11 +12,23 @@ from africanus.util.requirements import requires_optional -@requires_optional('ducc0.wgridder', ducc_import_error) -def _residual_internal(uvw, freq, image, vis, freq_bin_idx, freq_bin_counts, - cell, weights, flag, celly, epsilon, nthreads, - do_wstacking, double_accum): - +@requires_optional("ducc0.wgridder", ducc_import_error) +def _residual_internal( + uvw, + freq, + image, + vis, + freq_bin_idx, + freq_bin_counts, + cell, + weights, + flag, + celly, + epsilon, + nthreads, + do_wstacking, + double_accum, +): # adjust for chunking # need a copy here if using multiple row chunks freq_bin_idx2 = freq_bin_idx - freq_bin_idx.min() @@ -36,40 +48,82 @@ def _residual_internal(uvw, freq, image, vis, freq_bin_idx, freq_bin_counts, mask = None tvis = vis[:, ind] residvis = tvis - dirty2ms( - uvw=uvw, freq=freq[ind], - dirty=image[i], wgt=None, - pixsize_x=cell, pixsize_y=celly, - nu=0, nv=0, epsilon=epsilon, - nthreads=nthreads, mask=mask, - do_wstacking=do_wstacking) - residim[0, i] = ms2dirty(uvw=uvw, freq=freq[ind], ms=residvis, - wgt=wgt, npix_x=nx, npix_y=ny, - pixsize_x=cell, pixsize_y=celly, - nu=0, nv=0, epsilon=epsilon, - nthreads=nthreads, mask=mask, - do_wstacking=do_wstacking, - double_precision_accumulation=double_accum) + uvw=uvw, + freq=freq[ind], + dirty=image[i], + wgt=None, + pixsize_x=cell, + pixsize_y=celly, + nu=0, + nv=0, + epsilon=epsilon, + nthreads=nthreads, + mask=mask, + do_wstacking=do_wstacking, + ) + residim[0, i] = ms2dirty( + uvw=uvw, + freq=freq[ind], + ms=residvis, + wgt=wgt, + npix_x=nx, + npix_y=ny, + pixsize_x=cell, + pixsize_y=celly, + nu=0, + nv=0, + epsilon=epsilon, + nthreads=nthreads, + mask=mask, + do_wstacking=do_wstacking, + double_precision_accumulation=double_accum, + ) return residim # This additional wrapper is required to allow the dask wrappers # to chunk over row -@requires_optional('ducc0.wgridder', ducc_import_error) -def residual(uvw, freq, image, vis, freq_bin_idx, freq_bin_counts, cell, - weights=None, flag=None, celly=None, epsilon=1e-5, nthreads=1, - do_wstacking=True, double_accum=False): - +@requires_optional("ducc0.wgridder", ducc_import_error) +def residual( + uvw, + freq, + image, + vis, + freq_bin_idx, + freq_bin_counts, + cell, + weights=None, + flag=None, + celly=None, + epsilon=1e-5, + nthreads=1, + do_wstacking=True, + double_accum=False, +): if celly is None: celly = cell if not nthreads: import multiprocessing + nthreads = multiprocessing.cpu_count() - residim = _residual_internal(uvw, freq, image, vis, freq_bin_idx, - freq_bin_counts, cell, weights, flag, - celly, epsilon, nthreads, do_wstacking, - double_accum) + residim = _residual_internal( + uvw, + freq, + image, + vis, + freq_bin_idx, + freq_bin_counts, + cell, + weights, + flag, + celly, + epsilon, + nthreads, + do_wstacking, + double_accum, + ) return residim[0] @@ -157,10 +211,10 @@ def residual(uvw, freq, image, vis, freq_bin_idx, freq_bin_counts, cell, residual : $(array_type) Residual image corresponding to :code:`model` of shape :code:`(band, nx, ny)`. - """) + """ +) try: - residual.__doc__ = RESIDUAL_DOCS.substitute( - array_type=":class:`numpy.ndarray`") + residual.__doc__ = RESIDUAL_DOCS.substitute(array_type=":class:`numpy.ndarray`") except AttributeError: pass diff --git a/africanus/gridding/wgridder/im2vis.py b/africanus/gridding/wgridder/im2vis.py index 977893a5d..c82259063 100644 --- a/africanus/gridding/wgridder/im2vis.py +++ b/africanus/gridding/wgridder/im2vis.py @@ -12,9 +12,21 @@ from africanus.util.requirements import requires_optional -@requires_optional('ducc0.wgridder', ducc_import_error) -def _model_internal(uvw, freq, image, freq_bin_idx, freq_bin_counts, cell, - weights, flag, celly, epsilon, nthreads, do_wstacking): +@requires_optional("ducc0.wgridder", ducc_import_error) +def _model_internal( + uvw, + freq, + image, + freq_bin_idx, + freq_bin_counts, + cell, + weights, + flag, + celly, + epsilon, + nthreads, + do_wstacking, +): # adjust for chunking # need a copy here if using multiple row chunks freq_bin_idx2 = freq_bin_idx - freq_bin_idx.min() @@ -32,27 +44,60 @@ def _model_internal(uvw, freq, image, freq_bin_idx, freq_bin_counts, cell, mask = flag[:, ind] else: mask = None - vis[:, ind] = dirty2ms(uvw=uvw, freq=freq[ind], dirty=image[i], - wgt=wgt, pixsize_x=cell, pixsize_y=celly, - nu=0, nv=0, epsilon=epsilon, mask=mask, - nthreads=nthreads, do_wstacking=do_wstacking) + vis[:, ind] = dirty2ms( + uvw=uvw, + freq=freq[ind], + dirty=image[i], + wgt=wgt, + pixsize_x=cell, + pixsize_y=celly, + nu=0, + nv=0, + epsilon=epsilon, + mask=mask, + nthreads=nthreads, + do_wstacking=do_wstacking, + ) return vis -@requires_optional('ducc0.wgridder', ducc_import_error) -def model(uvw, freq, image, freq_bin_idx, freq_bin_counts, cell, weights=None, - flag=None, celly=None, epsilon=1e-5, nthreads=1, do_wstacking=True): - +@requires_optional("ducc0.wgridder", ducc_import_error) +def model( + uvw, + freq, + image, + freq_bin_idx, + freq_bin_counts, + cell, + weights=None, + flag=None, + celly=None, + epsilon=1e-5, + nthreads=1, + do_wstacking=True, +): if celly is None: celly = cell if not nthreads: import multiprocessing + nthreads = multiprocessing.cpu_count() - return _model_internal(uvw, freq, image, freq_bin_idx, freq_bin_counts, - cell, weights, flag, celly, epsilon, nthreads, - do_wstacking) + return _model_internal( + uvw, + freq, + image, + freq_bin_idx, + freq_bin_counts, + cell, + weights, + flag, + celly, + epsilon, + nthreads, + do_wstacking, + ) MODEL_DOCS = DocstringTemplate( @@ -129,10 +174,10 @@ def model(uvw, freq, image, freq_bin_idx, freq_bin_counts, cell, weights=None, vis : $(array_type) Visibilities corresponding to :code:`model` of shape :code:`(row,chan)`. - """) + """ +) try: - model.__doc__ = MODEL_DOCS.substitute( - array_type=":class:`numpy.ndarray`") + model.__doc__ = MODEL_DOCS.substitute(array_type=":class:`numpy.ndarray`") except AttributeError: pass diff --git a/africanus/gridding/wgridder/tests/test_wgridder.py b/africanus/gridding/wgridder/tests/test_wgridder.py index de204dc5e..3dc5bcbb5 100644 --- a/africanus/gridding/wgridder/tests/test_wgridder.py +++ b/africanus/gridding/wgridder/tests/test_wgridder.py @@ -9,34 +9,40 @@ def _l2error(a, b): - return np.sqrt(np.sum(np.abs(a-b)**2)/np.maximum(np.sum(np.abs(a)**2), - np.sum(np.abs(b)**2))) + return np.sqrt( + np.sum(np.abs(a - b) ** 2) + / np.maximum(np.sum(np.abs(a) ** 2), np.sum(np.abs(b) ** 2)) + ) -def explicit_gridder(uvw, freq, ms, wgt, nxdirty, nydirty, xpixsize, ypixsize, - apply_w): - x, y = np.meshgrid(*[-ss/2 + np.arange(ss) for ss in [nxdirty, nydirty]], - indexing='ij') +def explicit_gridder(uvw, freq, ms, wgt, nxdirty, nydirty, xpixsize, ypixsize, apply_w): + x, y = np.meshgrid( + *[-ss / 2 + np.arange(ss) for ss in [nxdirty, nydirty]], indexing="ij" + ) x *= xpixsize y *= ypixsize res = np.zeros((nxdirty, nydirty)) - eps = x**2+y**2 + eps = x**2 + y**2 if apply_w: - nm1 = -eps/(np.sqrt(1.-eps)+1.) - n = nm1+1 + nm1 = -eps / (np.sqrt(1.0 - eps) + 1.0) + n = nm1 + 1 else: - nm1 = 0. - n = 1. + nm1 = 0.0 + n = 1.0 for row in range(ms.shape[0]): for chan in range(ms.shape[1]): - phase = (freq[chan]/lightspeed * - (x*uvw[row, 0] + y*uvw[row, 1] - uvw[row, 2]*nm1)) + phase = ( + freq[chan] + / lightspeed + * (x * uvw[row, 0] + y * uvw[row, 1] - uvw[row, 2] * nm1) + ) if wgt is None: - res += (ms[row, chan]*np.exp(2j*np.pi*phase)).real + res += (ms[row, chan] * np.exp(2j * np.pi * phase)).real else: - res += (ms[row, chan]*wgt[row, chan] - * np.exp(2j*np.pi*phase)).real - return res/n + res += ( + ms[row, chan] * wgt[row, chan] * np.exp(2j * np.pi * phase) + ).real + return res / n @pmp("nx", (16,)) @@ -45,16 +51,16 @@ def explicit_gridder(uvw, freq, ms, wgt, nxdirty, nydirty, xpixsize, ypixsize, @pmp("nrow", (1000,)) @pmp("nchan", (1, 7)) @pmp("nband", (1, 3)) -@pmp("precision", ('single', 'double')) +@pmp("precision", ("single", "double")) @pmp("epsilon", (1e-3, 1e-4)) @pmp("nthreads", (1, 6)) -def test_gridder(nx, ny, fov, nrow, nchan, nband, - precision, epsilon, nthreads): +def test_gridder(nx, ny, fov, nrow, nchan, nband, precision, epsilon, nthreads): # run comparison against dft with a frequency mapping imposed if nband > nchan: return from africanus.gridding.wgridder import dirty - if precision == 'single': + + if precision == "single": real_type = "f4" complex_type = "c8" else: @@ -62,15 +68,15 @@ def test_gridder(nx, ny, fov, nrow, nchan, nband, complex_type = "c16" np.random.seed(420) - cell = fov*np.pi/180/nx + cell = fov * np.pi / 180 / nx f0 = 1e9 - freq = (f0 + np.arange(nchan)*(f0/nchan)) - uvw = ((np.random.rand(nrow, 3)-0.5) / - (cell*freq[-1]/lightspeed)) - vis = (np.random.rand(nrow, nchan)-0.5 + 1j * - (np.random.rand(nrow, nchan)-0.5)).astype(complex_type) + freq = f0 + np.arange(nchan) * (f0 / nchan) + uvw = (np.random.rand(nrow, 3) - 0.5) / (cell * freq[-1] / lightspeed) + vis = ( + np.random.rand(nrow, nchan) - 0.5 + 1j * (np.random.rand(nrow, nchan) - 0.5) + ).astype(complex_type) wgt = np.random.rand(nrow, nchan).astype(real_type) - step = nchan//nband + step = nchan // nband if step: freq_bin_idx = np.arange(0, nchan, step) freq_mapping = np.append(freq_bin_idx, nchan) @@ -78,14 +84,25 @@ def test_gridder(nx, ny, fov, nrow, nchan, nband, else: freq_bin_idx = np.array([0], dtype=np.int16) freq_bin_counts = np.array([1], dtype=np.int16) - image = dirty(uvw, freq, vis, freq_bin_idx, freq_bin_counts, nx, ny, cell, - weights=wgt, nthreads=nthreads) + image = dirty( + uvw, + freq, + vis, + freq_bin_idx, + freq_bin_counts, + nx, + ny, + cell, + weights=wgt, + nthreads=nthreads, + ) nband = freq_bin_idx.size ref = np.zeros((nband, nx, ny), dtype=np.float64) for i in range(nband): ind = slice(freq_bin_idx[i], freq_bin_idx[i] + freq_bin_counts[i]) - ref[i] = explicit_gridder(uvw, freq[ind], vis[:, ind], wgt[:, ind], - nx, ny, cell, cell, True) + ref[i] = explicit_gridder( + uvw, freq[ind], vis[:, ind], wgt[:, ind], nx, ny, cell, cell, True + ) # l2 error should be within epsilon of zero assert_allclose(_l2error(image, ref), 0, atol=epsilon) @@ -94,13 +111,18 @@ def test_gridder(nx, ny, fov, nrow, nchan, nband, @pmp("nx", (30,)) @pmp("ny", (50, 128)) @pmp("fov", (0.5, 2.5)) -@pmp("nrow", (333, 5000,)) +@pmp( + "nrow", + ( + 333, + 5000, + ), +) @pmp("nchan", (1, 4)) @pmp("nband", (1, 2)) -@pmp("precision", ('single', 'double')) +@pmp("precision", ("single", "double")) @pmp("nthreads", (6,)) -def test_adjointness(nx, ny, fov, nrow, nchan, nband, - precision, nthreads): +def test_adjointness(nx, ny, fov, nrow, nchan, nband, precision, nthreads): # instead of explicitly testing the degridder we can just check that # it is consistent with the gridder i.e. # @@ -111,7 +133,8 @@ def test_adjointness(nx, ny, fov, nrow, nchan, nband, if nband > nchan: return from africanus.gridding.wgridder import dirty, model - if precision == 'single': + + if precision == "single": real_type = np.float32 complex_type = np.complex64 tol = 1e-4 @@ -120,15 +143,15 @@ def test_adjointness(nx, ny, fov, nrow, nchan, nband, complex_type = np.complex128 tol = 1e-12 np.random.seed(420) - cell = fov*np.pi/180/nx + cell = fov * np.pi / 180 / nx f0 = 1e9 - freq = (f0 + np.arange(nchan)*(f0/nchan)) - uvw = ((np.random.rand(nrow, 3)-0.5) / - (cell*freq[-1]/lightspeed)) - vis = (np.random.rand(nrow, nchan)-0.5 + 1j * - (np.random.rand(nrow, nchan)-0.5)).astype(complex_type) + freq = f0 + np.arange(nchan) * (f0 / nchan) + uvw = (np.random.rand(nrow, 3) - 0.5) / (cell * freq[-1] / lightspeed) + vis = ( + np.random.rand(nrow, nchan) - 0.5 + 1j * (np.random.rand(nrow, nchan) - 0.5) + ).astype(complex_type) wgt = np.random.rand(nrow, nchan).astype(real_type) - step = nchan//nband + step = nchan // nband if step: freq_bin_idx = np.arange(0, nchan, step) freq_mapping = np.append(freq_bin_idx, nchan) @@ -137,33 +160,56 @@ def test_adjointness(nx, ny, fov, nrow, nchan, nband, freq_bin_idx = np.array([0], dtype=np.int8) freq_bin_counts = np.array([1], dtype=np.int8) nband = freq_bin_idx.size - image = dirty(uvw, freq, vis, freq_bin_idx, freq_bin_counts, nx, ny, cell, - weights=wgt, nthreads=nthreads) + image = dirty( + uvw, + freq, + vis, + freq_bin_idx, + freq_bin_counts, + nx, + ny, + cell, + weights=wgt, + nthreads=nthreads, + ) model_im = np.random.randn(nband, nx, ny).astype(real_type) - modelvis = model(uvw, freq, model_im, freq_bin_idx, freq_bin_counts, - cell, weights=wgt, nthreads=nthreads) + modelvis = model( + uvw, + freq, + model_im, + freq_bin_idx, + freq_bin_counts, + cell, + weights=wgt, + nthreads=nthreads, + ) # should have relative tolerance close to machine precision - assert_allclose(np.vdot(vis, modelvis).real, np.vdot(image, model_im), - rtol=tol) + assert_allclose(np.vdot(vis, modelvis).real, np.vdot(image, model_im), rtol=tol) -@pmp("nx", (20, )) +@pmp("nx", (20,)) @pmp("ny", (32, 70)) @pmp("fov", (1.5, 3.5)) -@pmp("nrow", (222, 777,)) +@pmp( + "nrow", + ( + 222, + 777, + ), +) @pmp("nchan", (1, 5)) @pmp("nband", (1, 3)) -@pmp("precision", ('single', 'double')) +@pmp("precision", ("single", "double")) @pmp("nthreads", (3,)) -def test_residual(nx, ny, fov, nrow, nchan, nband, - precision, nthreads): +def test_residual(nx, ny, fov, nrow, nchan, nband, precision, nthreads): # Compare the result of im2residim to # VR = V - Rx - computed with im2vis # IR = R.H VR - computed with vis2im from africanus.gridding.wgridder import dirty, model, residual + np.random.seed(420) - if precision == 'single': + if precision == "single": real_type = np.float32 complex_type = np.complex64 decimal = 4 @@ -171,15 +217,15 @@ def test_residual(nx, ny, fov, nrow, nchan, nband, real_type = np.float64 complex_type = np.complex128 decimal = 12 - cell = fov*np.pi/180/nx + cell = fov * np.pi / 180 / nx f0 = 1e9 - freq = (f0 + np.arange(nchan)*(f0/nchan)) - uvw = ((np.random.rand(nrow, 3)-0.5) / - (cell*freq[-1]/lightspeed)) - vis = (np.random.rand(nrow, nchan)-0.5 + 1j * - (np.random.rand(nrow, nchan)-0.5)).astype(complex_type) + freq = f0 + np.arange(nchan) * (f0 / nchan) + uvw = (np.random.rand(nrow, 3) - 0.5) / (cell * freq[-1] / lightspeed) + vis = ( + np.random.rand(nrow, nchan) - 0.5 + 1j * (np.random.rand(nrow, nchan) - 0.5) + ).astype(complex_type) wgt = np.random.rand(nrow, nchan).astype(real_type) - step = nchan//nband + step = nchan // nband if step: freq_bin_idx = np.arange(0, nchan, step) freq_mapping = np.append(freq_bin_idx, nchan) @@ -189,39 +235,57 @@ def test_residual(nx, ny, fov, nrow, nchan, nband, freq_bin_counts = np.array([1], dtype=np.int8) nband = freq_bin_idx.size model_im = np.random.randn(nband, nx, ny).astype(real_type) - modelvis = model(uvw, freq, model_im, freq_bin_idx, freq_bin_counts, cell, - nthreads=nthreads) + modelvis = model( + uvw, freq, model_im, freq_bin_idx, freq_bin_counts, cell, nthreads=nthreads + ) residualvis = vis - modelvis - residim1 = dirty(uvw, freq, residualvis, freq_bin_idx, freq_bin_counts, - nx, ny, cell, weights=wgt, nthreads=nthreads) - - residim2 = residual(uvw, freq, model_im, vis, freq_bin_idx, - freq_bin_counts, cell, weights=wgt, - nthreads=nthreads) + residim1 = dirty( + uvw, + freq, + residualvis, + freq_bin_idx, + freq_bin_counts, + nx, + ny, + cell, + weights=wgt, + nthreads=nthreads, + ) + + residim2 = residual( + uvw, + freq, + model_im, + vis, + freq_bin_idx, + freq_bin_counts, + cell, + weights=wgt, + nthreads=nthreads, + ) # These are essentially computing the same thing just in a different # order so should be close to machine precision rmax = np.maximum(np.abs(residim1).max(), np.abs(residim2).max()) - assert_array_almost_equal( - residim1/rmax, residim2/rmax, decimal=decimal) + assert_array_almost_equal(residim1 / rmax, residim2 / rmax, decimal=decimal) -@pmp("nx", (128, )) +@pmp("nx", (128,)) @pmp("ny", (256,)) @pmp("fov", (0.5,)) @pmp("nrow", (10000000,)) @pmp("nchan", (2,)) @pmp("nband", (2,)) -@pmp("precision", ('single',)) +@pmp("precision", ("single",)) @pmp("nthreads", (4,)) -def test_hessian(nx, ny, fov, nrow, nchan, nband, - precision, nthreads): +def test_hessian(nx, ny, fov, nrow, nchan, nband, precision, nthreads): # Compare the result of dirty computed with Hessian # ID = hessian(x) # to that computed using dirty. from africanus.gridding.wgridder import dirty, hessian + np.random.seed(420) - if precision == 'single': + if precision == "single": real_type = np.float32 complex_type = np.complex64 atol = 1e-5 @@ -230,21 +294,21 @@ def test_hessian(nx, ny, fov, nrow, nchan, nband, complex_type = np.complex128 atol = 1e-5 - uvw = 1000*np.random.randn(nrow, 3) + uvw = 1000 * np.random.randn(nrow, 3) uvw[:, 2] = 0 u_max = np.abs(uvw[:, 0]).max() v_max = np.abs(uvw[:, 1]).max() uv_max = np.maximum(u_max, v_max) f0 = 1e9 - freq = (f0 + np.arange(nchan)*(f0/nchan)) + freq = f0 + np.arange(nchan) * (f0 / nchan) - cell_N = 0.1/(2*uv_max*freq.max()/lightspeed) - cell = cell_N/2.0 # super_resolution_factor of 2 + cell_N = 0.1 / (2 * uv_max * freq.max() / lightspeed) + cell = cell_N / 2.0 # super_resolution_factor of 2 vis = np.ones((nrow, nchan), dtype=complex_type) - step = nchan//nband + step = nchan // nband if step: freq_bin_idx = np.arange(0, nchan, step) freq_mapping = np.append(freq_bin_idx, nchan) @@ -254,22 +318,40 @@ def test_hessian(nx, ny, fov, nrow, nchan, nband, freq_bin_counts = np.array([1], dtype=np.int8) nband = freq_bin_idx.size model_im = np.zeros((nband, nx, ny), dtype=real_type) - model_im[:, nx//2, ny//2] = 1.0 - - dirty_im1 = dirty(uvw, freq, vis, freq_bin_idx, freq_bin_counts, - nx, ny, cell, nthreads=nthreads, do_wstacking=False, - double_accum=True) + model_im[:, nx // 2, ny // 2] = 1.0 + + dirty_im1 = dirty( + uvw, + freq, + vis, + freq_bin_idx, + freq_bin_counts, + nx, + ny, + cell, + nthreads=nthreads, + do_wstacking=False, + double_accum=True, + ) # test accumulation - assert_allclose(dirty_im1.max()/nrow, 1.0, rtol=atol) - - dirty_im2 = hessian(uvw, freq, model_im, freq_bin_idx, - freq_bin_counts, cell, nthreads=nthreads, - do_wstacking=False, double_accum=True) + assert_allclose(dirty_im1.max() / nrow, 1.0, rtol=atol) + + dirty_im2 = hessian( + uvw, + freq, + model_im, + freq_bin_idx, + freq_bin_counts, + cell, + nthreads=nthreads, + do_wstacking=False, + double_accum=True, + ) # rtol not reliable since there will be values close to zero in the # dirty images - assert_allclose(dirty_im1/nrow, dirty_im2/nrow, atol=atol, rtol=1e-2) + assert_allclose(dirty_im1 / nrow, dirty_im2 / nrow, atol=atol, rtol=1e-2) @pmp("nx", (30, 250)) @@ -278,16 +360,16 @@ def test_hessian(nx, ny, fov, nrow, nchan, nband, @pmp("nrow", (3333, 10000)) @pmp("nchan", (1, 8)) @pmp("nband", (1, 2)) -@pmp("precision", ('single', 'double')) +@pmp("precision", ("single", "double")) @pmp("nthreads", (1, 4)) @pmp("nchunks", (1, 3)) -def test_dask_dirty(nx, ny, fov, nrow, nchan, nband, - precision, nthreads, nchunks): +def test_dask_dirty(nx, ny, fov, nrow, nchan, nband, precision, nthreads, nchunks): da = pytest.importorskip("dask.array") from africanus.gridding.wgridder import dirty as dirty_np from africanus.gridding.wgridder.dask import dirty + np.random.seed(420) - if precision == 'single': + if precision == "single": real_type = np.float32 complex_type = np.complex64 decimal = 4 # sometimes fails at 5 @@ -295,15 +377,15 @@ def test_dask_dirty(nx, ny, fov, nrow, nchan, nband, real_type = np.float64 complex_type = np.complex128 decimal = 5 - cell = fov*np.pi/180/nx + cell = fov * np.pi / 180 / nx f0 = 1e9 - freq = (f0 + np.arange(nchan)*(f0/nchan)) - uvw = ((np.random.rand(nrow, 3)-0.5) / - (cell*freq[-1]/lightspeed)) - vis = (np.random.rand(nrow, nchan)-0.5 + 1j * - (np.random.rand(nrow, nchan)-0.5)).astype(complex_type) + freq = f0 + np.arange(nchan) * (f0 / nchan) + uvw = (np.random.rand(nrow, 3) - 0.5) / (cell * freq[-1] / lightspeed) + vis = ( + np.random.rand(nrow, nchan) - 0.5 + 1j * (np.random.rand(nrow, nchan) - 0.5) + ).astype(complex_type) wgt = np.random.rand(nrow, nchan).astype(real_type) - step = np.maximum(1, nchan//nband) + step = np.maximum(1, nchan // nband) if step: freq_bin_idx = np.arange(0, nchan, step) freq_mapping = np.append(freq_bin_idx, nchan) @@ -312,12 +394,22 @@ def test_dask_dirty(nx, ny, fov, nrow, nchan, nband, freq_bin_idx = np.array([0], dtype=np.int8) freq_bin_counts = np.array([1], dtype=np.int8) nband = freq_bin_idx.size - image = dirty_np(uvw, freq, vis, freq_bin_idx, freq_bin_counts, nx, ny, - cell, weights=wgt, nthreads=nthreads) + image = dirty_np( + uvw, + freq, + vis, + freq_bin_idx, + freq_bin_counts, + nx, + ny, + cell, + weights=wgt, + nthreads=nthreads, + ) # now get result using dask - rows_per_task = int(np.ceil(nrow/nchunks)) - row_chunks = (nchunks-1) * (rows_per_task,) + rows_per_task = int(np.ceil(nrow / nchunks)) + row_chunks = (nchunks - 1) * (rows_per_task,) row_chunks += (nrow - np.sum(row_chunks),) freq_da = da.from_array(freq, chunks=step) uvw_da = da.from_array(uvw, chunks=(row_chunks, -1)) @@ -326,14 +418,22 @@ def test_dask_dirty(nx, ny, fov, nrow, nchan, nband, freq_bin_idx_da = da.from_array(freq_bin_idx, chunks=1) freq_bin_counts_da = da.from_array(freq_bin_counts, chunks=1) - image_da = dirty(uvw_da, freq_da, vis_da, freq_bin_idx_da, - freq_bin_counts_da, nx, ny, cell, weights=wgt_da, - nthreads=nthreads).compute() + image_da = dirty( + uvw_da, + freq_da, + vis_da, + freq_bin_idx_da, + freq_bin_counts_da, + nx, + ny, + cell, + weights=wgt_da, + nthreads=nthreads, + ).compute() # relative error should agree to within epsilon dmax = np.maximum(np.abs(image).max(), np.abs(image_da).max()) - assert_array_almost_equal(image/dmax, image_da/dmax, - decimal=decimal) + assert_array_almost_equal(image / dmax, image_da / dmax, decimal=decimal) @pmp("nx", (30, 250)) @@ -342,16 +442,16 @@ def test_dask_dirty(nx, ny, fov, nrow, nchan, nband, @pmp("nrow", (3333, 10000)) @pmp("nchan", (1, 8)) @pmp("nband", (1, 2)) -@pmp("precision", ('single', 'double')) +@pmp("precision", ("single", "double")) @pmp("nthreads", (1, 4)) @pmp("nchunks", (1, 3)) -def test_dask_model(nx, ny, fov, nrow, nchan, nband, - precision, nthreads, nchunks): +def test_dask_model(nx, ny, fov, nrow, nchan, nband, precision, nthreads, nchunks): da = pytest.importorskip("dask.array") from africanus.gridding.wgridder import model as model_np from africanus.gridding.wgridder.dask import model + np.random.seed(420) - if precision == 'single': + if precision == "single": real_type = np.float32 complex_type = np.complex64 decimal = 4 # sometimes fails at 5 @@ -359,16 +459,16 @@ def test_dask_model(nx, ny, fov, nrow, nchan, nband, real_type = np.float64 complex_type = np.complex128 decimal = 5 - cell = fov*np.pi/180/nx + cell = fov * np.pi / 180 / nx f0 = 1e9 - freq = (f0 + np.arange(nchan)*(f0/nchan)) - uvw = ((np.random.rand(nrow, 3)-0.5) / - (cell*freq[-1]/lightspeed)) - vis = (np.random.rand(nrow, nchan)-0.5 + 1j * - (np.random.rand(nrow, nchan)-0.5)).astype(complex_type) + freq = f0 + np.arange(nchan) * (f0 / nchan) + uvw = (np.random.rand(nrow, 3) - 0.5) / (cell * freq[-1] / lightspeed) + vis = ( + np.random.rand(nrow, nchan) - 0.5 + 1j * (np.random.rand(nrow, nchan) - 0.5) + ).astype(complex_type) wgt = np.random.rand(nrow, nchan).astype(real_type) - step = np.maximum(1, nchan//nband) + step = np.maximum(1, nchan // nband) if step: freq_bin_idx = np.arange(0, nchan, step) freq_mapping = np.append(freq_bin_idx, nchan) @@ -379,12 +479,20 @@ def test_dask_model(nx, ny, fov, nrow, nchan, nband, nband = freq_bin_idx.size image = np.random.randn(nband, nx, ny).astype(real_type) - vis = model_np(uvw, freq, image, freq_bin_idx, freq_bin_counts, cell, - weights=wgt, nthreads=nthreads) + vis = model_np( + uvw, + freq, + image, + freq_bin_idx, + freq_bin_counts, + cell, + weights=wgt, + nthreads=nthreads, + ) # now get result using dask - rows_per_task = int(np.ceil(nrow/nchunks)) - row_chunks = (nchunks-1) * (rows_per_task,) + rows_per_task = int(np.ceil(nrow / nchunks)) + row_chunks = (nchunks - 1) * (rows_per_task,) row_chunks += (nrow - np.sum(row_chunks),) freq_da = da.from_array(freq, chunks=step) uvw_da = da.from_array(uvw, chunks=(row_chunks, -1)) @@ -393,14 +501,20 @@ def test_dask_model(nx, ny, fov, nrow, nchan, nband, freq_bin_idx_da = da.from_array(freq_bin_idx, chunks=1) freq_bin_counts_da = da.from_array(freq_bin_counts, chunks=1) - vis_da = model(uvw_da, freq_da, image_da, freq_bin_idx_da, - freq_bin_counts_da, cell, weights=wgt_da, - nthreads=nthreads).compute() + vis_da = model( + uvw_da, + freq_da, + image_da, + freq_bin_idx_da, + freq_bin_counts_da, + cell, + weights=wgt_da, + nthreads=nthreads, + ).compute() # relative error should agree to within epsilon vmax = np.maximum(np.abs(vis).max(), np.abs(vis_da).max()) - assert_array_almost_equal(vis/vmax, vis_da/vmax, - decimal=decimal) + assert_array_almost_equal(vis / vmax, vis_da / vmax, decimal=decimal) @pmp("nx", (30, 250)) @@ -409,16 +523,16 @@ def test_dask_model(nx, ny, fov, nrow, nchan, nband, @pmp("nrow", (3333, 10000)) @pmp("nchan", (1, 8)) @pmp("nband", (1, 2)) -@pmp("precision", ('single', 'double')) +@pmp("precision", ("single", "double")) @pmp("nthreads", (1, 4)) @pmp("nchunks", (1, 3)) -def test_dask_residual(nx, ny, fov, nrow, nchan, nband, - precision, nthreads, nchunks): +def test_dask_residual(nx, ny, fov, nrow, nchan, nband, precision, nthreads, nchunks): da = pytest.importorskip("dask.array") from africanus.gridding.wgridder import residual as residual_np from africanus.gridding.wgridder.dask import residual + np.random.seed(420) - if precision == 'single': + if precision == "single": real_type = np.float32 complex_type = np.complex64 decimal = 4 # sometimes fails at 5 @@ -426,15 +540,15 @@ def test_dask_residual(nx, ny, fov, nrow, nchan, nband, real_type = np.float64 complex_type = np.complex128 decimal = 5 - cell = fov*np.pi/180/nx + cell = fov * np.pi / 180 / nx f0 = 1e9 - freq = (f0 + np.arange(nchan)*(f0/nchan)) - uvw = ((np.random.rand(nrow, 3)-0.5) / - (cell*freq[-1]/lightspeed)) - vis = (np.random.rand(nrow, nchan)-0.5 + 1j * - (np.random.rand(nrow, nchan)-0.5)).astype(complex_type) + freq = f0 + np.arange(nchan) * (f0 / nchan) + uvw = (np.random.rand(nrow, 3) - 0.5) / (cell * freq[-1] / lightspeed) + vis = ( + np.random.rand(nrow, nchan) - 0.5 + 1j * (np.random.rand(nrow, nchan) - 0.5) + ).astype(complex_type) wgt = np.random.rand(nrow, nchan).astype(real_type) - step = np.maximum(1, nchan//nband) + step = np.maximum(1, nchan // nband) if step: freq_bin_idx = np.arange(0, nchan, step) freq_mapping = np.append(freq_bin_idx, nchan) @@ -444,12 +558,20 @@ def test_dask_residual(nx, ny, fov, nrow, nchan, nband, freq_bin_counts = np.array([1], dtype=np.int8) nband = freq_bin_idx.size image = np.random.randn(nband, nx, ny).astype(real_type) - residim_np = residual_np(uvw, freq, image, vis, freq_bin_idx, - freq_bin_counts, cell, weights=wgt, - nthreads=nthreads) - - rows_per_task = int(np.ceil(nrow/nchunks)) - row_chunks = (nchunks-1) * (rows_per_task,) + residim_np = residual_np( + uvw, + freq, + image, + vis, + freq_bin_idx, + freq_bin_counts, + cell, + weights=wgt, + nthreads=nthreads, + ) + + rows_per_task = int(np.ceil(nrow / nchunks)) + row_chunks = (nchunks - 1) * (rows_per_task,) row_chunks += (nrow - np.sum(row_chunks),) freq_da = da.from_array(freq, chunks=step) uvw_da = da.from_array(uvw, chunks=(row_chunks, -1)) @@ -459,14 +581,21 @@ def test_dask_residual(nx, ny, fov, nrow, nchan, nband, freq_bin_idx_da = da.from_array(freq_bin_idx, chunks=1) freq_bin_counts_da = da.from_array(freq_bin_counts, chunks=1) - residim_da = residual(uvw_da, freq_da, image_da, vis_da, - freq_bin_idx_da, freq_bin_counts_da, - cell, weights=wgt_da, nthreads=nthreads).compute() + residim_da = residual( + uvw_da, + freq_da, + image_da, + vis_da, + freq_bin_idx_da, + freq_bin_counts_da, + cell, + weights=wgt_da, + nthreads=nthreads, + ).compute() # should agree to within epsilon rmax = np.maximum(np.abs(residim_np).max(), np.abs(residim_da).max()) - assert_array_almost_equal( - residim_np/rmax, residim_da/rmax, decimal=decimal) + assert_array_almost_equal(residim_np / rmax, residim_da / rmax, decimal=decimal) @pmp("nx", (64,)) @@ -475,28 +604,27 @@ def test_dask_residual(nx, ny, fov, nrow, nchan, nband, @pmp("nrow", (3333, 10000)) @pmp("nchan", (4,)) @pmp("nband", (2,)) -@pmp("precision", ('single', 'double')) +@pmp("precision", ("single", "double")) @pmp("nthreads", (1,)) @pmp("nchunks", (1, 3)) -def test_dask_hessian(nx, ny, fov, nrow, nchan, nband, - precision, nthreads, nchunks): +def test_dask_hessian(nx, ny, fov, nrow, nchan, nband, precision, nthreads, nchunks): da = pytest.importorskip("dask.array") from africanus.gridding.wgridder import hessian as hessian_np from africanus.gridding.wgridder.dask import hessian + np.random.seed(420) - if precision == 'single': + if precision == "single": real_type = np.float32 decimal = 4 # sometimes fails at 5 else: real_type = np.float64 decimal = 5 - cell = fov*np.pi/180/nx + cell = fov * np.pi / 180 / nx f0 = 1e9 - freq = (f0 + np.arange(nchan)*(f0/nchan)) - uvw = ((np.random.rand(nrow, 3)-0.5) / - (cell*freq[-1]/lightspeed)) + freq = f0 + np.arange(nchan) * (f0 / nchan) + uvw = (np.random.rand(nrow, 3) - 0.5) / (cell * freq[-1] / lightspeed) wgt = np.random.rand(nrow, nchan).astype(real_type) - step = np.maximum(1, nchan//nband) + step = np.maximum(1, nchan // nband) if step: freq_bin_idx = np.arange(0, nchan, step) freq_mapping = np.append(freq_bin_idx, nchan) @@ -506,12 +634,19 @@ def test_dask_hessian(nx, ny, fov, nrow, nchan, nband, freq_bin_counts = np.array([1], dtype=np.int8) nband = freq_bin_idx.size image = np.random.randn(nband, nx, ny).astype(real_type) - convim_np = hessian_np(uvw, freq, image, freq_bin_idx, - freq_bin_counts, cell, weights=wgt, - nthreads=nthreads) - - rows_per_task = int(np.ceil(nrow/nchunks)) - row_chunks = (nchunks-1) * (rows_per_task,) + convim_np = hessian_np( + uvw, + freq, + image, + freq_bin_idx, + freq_bin_counts, + cell, + weights=wgt, + nthreads=nthreads, + ) + + rows_per_task = int(np.ceil(nrow / nchunks)) + row_chunks = (nchunks - 1) * (rows_per_task,) row_chunks += (nrow - np.sum(row_chunks),) freq_da = da.from_array(freq, chunks=step) uvw_da = da.from_array(uvw, chunks=(row_chunks, -1)) @@ -520,11 +655,17 @@ def test_dask_hessian(nx, ny, fov, nrow, nchan, nband, freq_bin_idx_da = da.from_array(freq_bin_idx, chunks=1) freq_bin_counts_da = da.from_array(freq_bin_counts, chunks=1) - convim_da = hessian(uvw_da, freq_da, image_da, - freq_bin_idx_da, freq_bin_counts_da, - cell, weights=wgt_da, nthreads=nthreads).compute() + convim_da = hessian( + uvw_da, + freq_da, + image_da, + freq_bin_idx_da, + freq_bin_counts_da, + cell, + weights=wgt_da, + nthreads=nthreads, + ).compute() # should agree to within epsilon rmax = np.maximum(np.abs(convim_np).max(), np.abs(convim_da).max()) - assert_array_almost_equal( - convim_np/rmax, convim_da/rmax, decimal=decimal) + assert_array_almost_equal(convim_np / rmax, convim_da / rmax, decimal=decimal) diff --git a/africanus/gridding/wgridder/vis2im.py b/africanus/gridding/wgridder/vis2im.py index 752d376b0..7d3be9132 100644 --- a/africanus/gridding/wgridder/vis2im.py +++ b/africanus/gridding/wgridder/vis2im.py @@ -12,10 +12,24 @@ from africanus.util.requirements import requires_optional -@requires_optional('ducc0.wgridder', ducc_import_error) -def _dirty_internal(uvw, freq, vis, freq_bin_idx, freq_bin_counts, nx, ny, - cell, weights, flag, celly, epsilon, nthreads, - do_wstacking, double_accum): +@requires_optional("ducc0.wgridder", ducc_import_error) +def _dirty_internal( + uvw, + freq, + vis, + freq_bin_idx, + freq_bin_counts, + nx, + ny, + cell, + weights, + flag, + celly, + epsilon, + nthreads, + do_wstacking, + double_accum, +): # adjust for chunking # need a copy here if using multiple row chunks freq_bin_idx2 = freq_bin_idx - freq_bin_idx.min() @@ -38,33 +52,71 @@ def _dirty_internal(uvw, freq, vis, freq_bin_idx, freq_bin_counts, nx, ny, mask = flag[:, ind] else: mask = None - dirty[0, i] = ms2dirty(uvw=uvw, freq=freq[ind], ms=vis[:, ind], - wgt=wgt, npix_x=nx, npix_y=ny, - pixsize_x=cell, pixsize_y=celly, - nu=0, nv=0, epsilon=epsilon, - nthreads=nthreads, mask=mask, - do_wstacking=do_wstacking, - double_precision_accumulation=double_accum) + dirty[0, i] = ms2dirty( + uvw=uvw, + freq=freq[ind], + ms=vis[:, ind], + wgt=wgt, + npix_x=nx, + npix_y=ny, + pixsize_x=cell, + pixsize_y=celly, + nu=0, + nv=0, + epsilon=epsilon, + nthreads=nthreads, + mask=mask, + do_wstacking=do_wstacking, + double_precision_accumulation=double_accum, + ) return dirty # This additional wrapper is required to allow the dask wrappers # to chunk over row -@requires_optional('ducc0.wgridder', ducc_import_error) -def dirty(uvw, freq, vis, freq_bin_idx, freq_bin_counts, nx, ny, cell, - weights=None, flag=None, celly=None, epsilon=1e-5, nthreads=1, - do_wstacking=True, double_accum=False): - +@requires_optional("ducc0.wgridder", ducc_import_error) +def dirty( + uvw, + freq, + vis, + freq_bin_idx, + freq_bin_counts, + nx, + ny, + cell, + weights=None, + flag=None, + celly=None, + epsilon=1e-5, + nthreads=1, + do_wstacking=True, + double_accum=False, +): if celly is None: celly = cell if not nthreads: import multiprocessing + nthreads = multiprocessing.cpu_count() - dirty = _dirty_internal(uvw, freq, vis, freq_bin_idx, freq_bin_counts, - nx, ny, cell, weights, flag, celly, - epsilon, nthreads, do_wstacking, double_accum) + dirty = _dirty_internal( + uvw, + freq, + vis, + freq_bin_idx, + freq_bin_counts, + nx, + ny, + cell, + weights, + flag, + celly, + epsilon, + nthreads, + do_wstacking, + double_accum, + ) return dirty[0] @@ -136,10 +188,10 @@ def dirty(uvw, freq, vis, freq_bin_idx, freq_bin_counts, nx, ny, cell, model : $(array_type) Dirty image corresponding to visibilities of shape :code:`(nband, nx, ny)`. - """) + """ +) try: - dirty.__doc__ = DIRTY_DOCS.substitute( - array_type=":class:`numpy.ndarray`") + dirty.__doc__ = DIRTY_DOCS.substitute(array_type=":class:`numpy.ndarray`") except AttributeError: pass diff --git a/africanus/linalg/geometry.py b/africanus/linalg/geometry.py index c4523160a..f3e7235b2 100644 --- a/africanus/linalg/geometry.py +++ b/africanus/linalg/geometry.py @@ -32,8 +32,9 @@ def timed(*args, **kw): class BoundingConvexHull(object): @requires_optional("scipy.stats", opt_import_err) - def __init__(self, list_hulls, name="unnamed", - mask=None, check_mask_outofbounds=True): + def __init__( + self, list_hulls, name="unnamed", mask=None, check_mask_outofbounds=True + ): """ Initializes a bounding convex hull around a list of bounding convex hulls or series of points. @@ -52,10 +53,7 @@ def __init__(self, list_hulls, name="unnamed", self._check_mask_outofbounds = check_mask_outofbounds self._cached_filled_mask = None self._vertices = points = np.vstack( - [ - b.corners if hasattr(b, "corners") else [b[0], b[1]] - for b in list_hulls - ] + [b.corners if hasattr(b, "corners") else [b[0], b[1]] for b in list_hulls] ) self._hull = spat.ConvexHull(points) if mask is None: @@ -64,14 +62,12 @@ def __init__(self, list_hulls, name="unnamed", self.sparse_mask = mask def invalidate_cached_masks(self): - """ Invalidates the cached masks (sparse or regular) """ + """Invalidates the cached masks (sparse or regular)""" self._cached_filled_mask = None self._mask, self._mask_weights = self.init_mask() def __str__(self): - return ",".join( - ["({0:d},{1:d})".format(x, y) for (x, y) in self.corners] - ) + return ",".join(["({0:d},{1:d})".format(x, y) for (x, y) in self.corners]) def init_mask(self): """ @@ -98,7 +94,7 @@ def init_mask(self): @property def sprase_mask_weights(self): - """ returns sparse mask weights """ + """returns sparse mask weights""" return self._mask_weights @property @@ -118,23 +114,18 @@ def sparse_mask(self, mask): raise TypeError("Mask must be list") if not ( hasattr(mask, "__len__") - and ( - len(mask) == 0 - or (hasattr(mask[0], "__len__") and len(mask[0]) == 2) - ) + and (len(mask) == 0 or (hasattr(mask[0], "__len__") and len(mask[0]) == 2)) ): raise TypeError("Mask must be a sparse mask of 2 element values") if self._check_mask_outofbounds: - self._mask = copy.deepcopy( - [c for c in mask if (c[1], c[0]) in self] - ) + self._mask = copy.deepcopy([c for c in mask if (c[1], c[0]) in self]) else: self._mask = copy.deepcopy(mask) self._mask_weights = np.ones(len(self._mask)) @property def mask(self, dtype=np.float64): - """ Creates a filled rectangular mask grid of size y, x """ + """Creates a filled rectangular mask grid of size y, x""" if self._cached_filled_mask is not None: return self._cached_filled_mask @@ -151,12 +142,8 @@ def mask(self, dtype=np.float64): else: sparse_mask = np.array(self.sparse_mask) sel = np.logical_and( - np.logical_and( - sparse_mask[:, 1] >= minx, sparse_mask[:, 1] <= maxx - ), - np.logical_and( - sparse_mask[:, 0] >= miny, sparse_mask[:, 0] <= maxy - ), + np.logical_and(sparse_mask[:, 1] >= minx, sparse_mask[:, 1] <= maxx), + np.logical_and(sparse_mask[:, 0] >= miny, sparse_mask[:, 0] <= maxy), ) flat_index = (sparse_mask[sel][:, 0] - miny) * nx + ( sparse_mask[sel][:, 1] - minx @@ -167,25 +154,21 @@ def mask(self, dtype=np.float64): @classmethod def regional_data(cls, sel_region, data_cube, axes=(2, 3), oob_value=0): - """ 2D array containing all values within convex hull - sliced out along axes provided as argument. Portions of sel_region - that are outside of the data_cube is set to oob_value + """2D array containing all values within convex hull + sliced out along axes provided as argument. Portions of sel_region + that are outside of the data_cube is set to oob_value - assumes the last value of axes is the fastest varying axis + assumes the last value of axes is the fastest varying axis """ if not isinstance(sel_region, BoundingConvexHull): - raise TypeError( - "Object passed in is not of type BoundingConvexHull" - ) + raise TypeError("Object passed in is not of type BoundingConvexHull") if not (hasattr(axes, "__len__") and len(axes) == 2): raise ValueError( "Expected a tupple of axes along which to slice out a region" ) axes = sorted(axes) - lines = np.hstack( - [sel_region.corners, np.roll(sel_region.corners, -1, axis=0)] - ) + lines = np.hstack([sel_region.corners, np.roll(sel_region.corners, -1, axis=0)]) minx = np.min(lines[:, 0:4:2]) maxx = np.max(lines[:, 0:4:2]) miny = np.min(lines[:, 1:4:2]) @@ -203,8 +186,10 @@ def regional_data(cls, sel_region, data_cube, axes=(2, 3), oob_value=0): or maxy < 0 or maxx < 0 ): - raise ValueError("Expected a bounding hull that is " - "at least partially within the image") + raise ValueError( + "Expected a bounding hull that is " + "at least partially within the image" + ) # extract data, pad if necessary slc_data = [slice(None)] * len(data_cube.shape) @@ -233,8 +218,7 @@ def regional_data(cls, sel_region, data_cube, axes=(2, 3), oob_value=0): if any(np.array([pad_left, pad_bottom, pad_right, pad_top]) > 0): padded_data = ( - np.zeros(tuple(new_shape), dtype=selected_data.dtype) - * oob_value + np.zeros(tuple(new_shape), dtype=selected_data.dtype) * oob_value ) padded_data[tuple(slc_padded)] = selected_data.copy() else: @@ -312,9 +296,7 @@ def normalize_masks(cls, regions, only_overlapped_regions=True): paint_count = paint_count[sel] # with the reduced number of overlap pixels unflatten - unique_pxls = np.vstack( - [unique_pxls_flatten // nx, unique_pxls_flatten % nx] - ).T + unique_pxls = np.vstack([unique_pxls_flatten // nx, unique_pxls_flatten % nx]).T unique_pxls = list(map(tuple, unique_pxls)) paint_count[...] = 1.0 / paint_count @@ -322,9 +304,7 @@ def normalize_masks(cls, regions, only_overlapped_regions=True): for reg in regions: reg._cached_filled_mask = None # invalidate overlap = [ - x - for x in zip(paint_count, unique_pxls) - if x[1] in reg.sparse_mask + x for x in zip(paint_count, unique_pxls) if x[1] in reg.sparse_mask ] for px_pc, px in overlap: sel = ( @@ -336,21 +316,16 @@ def normalize_masks(cls, regions, only_overlapped_regions=True): @property def circumference(self): - """ area contained in hull """ + """area contained in hull""" lines = self.edges - return np.sum( - np.linalg.norm(lines[:, 1, :] - lines[:, 0, :], axis=1) + 1 - ) + return np.sum(np.linalg.norm(lines[:, 1, :] - lines[:, 0, :], axis=1) + 1) @property def area(self): - """ area contained in hull """ + """area contained in hull""" lines = np.hstack([self.corners, np.roll(self.corners, -1, axis=0)]) return ( - 0.5 - * np.abs( - np.sum([x1 * (y2) - (x2) * y1 for x1, y1, x2, y2 in lines]) - ) + 0.5 * np.abs(np.sum([x1 * (y2) - (x2) * y1 for x1, y1, x2, y2 in lines])) + 0.5 * self.circumference - 1 ) @@ -365,11 +340,11 @@ def name(self, v): @property def corners(self): - """ Returns vertices and guarentees clockwise winding """ + """Returns vertices and guarentees clockwise winding""" return self._vertices[self._hull.vertices][::-1] def normals(self, left=True): - """ return a list of left normals to the hull """ + """return a list of left normals to the hull""" normals = [] for i in range(self.corners.shape[0]): # assuming clockwise winding @@ -383,7 +358,7 @@ def normals(self, left=True): @property def edges(self): - """ return edge segments of the hull (clockwise wound) """ + """return edge segments of the hull (clockwise wound)""" edges = [] for i in range(self.corners.shape[0]): # assuming clockwise winding @@ -394,18 +369,18 @@ def edges(self): @property def edge_midpoints(self): - """ return edge midpoints of the hull (clockwise wound) """ + """return edge midpoints of the hull (clockwise wound)""" edges = self.edges return np.mean(edges, axis=1) @property def lnormals(self): - """ left normals to the edges of the hull """ + """left normals to the edges of the hull""" return self.normals(left=True) @property def rnormals(self): - """ right normals to the edges of the hull """ + """right normals to the edges of the hull""" return self.normals(left=False) def overlaps_with(self, other, min_sep_dist=0.5): @@ -443,7 +418,7 @@ def overlaps_with(self, other, min_sep_dist=0.5): @property def centre(self, integral=True): - """ Barycentre of hull """ + """Barycentre of hull""" if integral: def rnd(x): @@ -454,7 +429,7 @@ def rnd(x): return np.mean(self._vertices, axis=0) def __contains__(self, s, tolerance=0.5): # less than half a pixel away - """ tests whether a point s(x,y) is in the convex hull """ + """tests whether a point s(x,y) is in the convex hull""" # there are three cases to consider # CASE 1: # scalar projection between all @@ -487,11 +462,7 @@ def __init__(self, xl, xu, yl, yu, name="unnamed", mask=None, **kwargs): self.__xnpx = abs(xu - xl + 1) # inclusive of the upper pixel self.__ynpx = abs(yu - yl + 1) BoundingConvexHull.__init__( - self, - [[xl, yl], [xl, yu], [xu, yu], [xu, yl]], - name, - mask=mask, - **kwargs + self, [[xl, yl], [xl, yu], [xu, yu], [xu, yl]], name, mask=mask, **kwargs ) def init_mask(self): @@ -516,7 +487,7 @@ def init_mask(self): return sparse_mask, mask_weights def __contains__(self, s): - """ tests whether a point s(x,y) is in the box""" + """tests whether a point s(x,y) is in the box""" lines = np.hstack([self.corners, np.roll(self.corners, -1, axis=0)]) minx = np.min(lines[:, 0:4:2]) maxx = np.max(lines[:, 0:4:2]) @@ -545,63 +516,46 @@ def sparse_mask(self, mask): raise TypeError("Mask must be list") if not ( hasattr(mask, "__len__") - and ( - len(mask) == 0 - or (hasattr(mask[0], "__len__") and len(mask[0]) == 2) - ) + and (len(mask) == 0 or (hasattr(mask[0], "__len__") and len(mask[0]) == 2)) ): raise TypeError("Mask must be a sparse mask of 2 element values") if len(mask) == 0: self._mask = [] else: - lines = np.hstack( - [self.corners, np.roll(self.corners, -1, axis=0)]) + lines = np.hstack([self.corners, np.roll(self.corners, -1, axis=0)]) minx = np.min(lines[:, 0:4:2]) maxx = np.max(lines[:, 0:4:2]) miny = np.min(lines[:, 1:4:2]) maxy = np.max(lines[:, 1:4:2]) sparse_mask = np.asarray(mask) sel = np.logical_and( - np.logical_and( - sparse_mask[:, 1] >= minx, sparse_mask[:, 1] <= maxx - ), - np.logical_and( - sparse_mask[:, 0] >= miny, sparse_mask[:, 0] <= maxy - ), + np.logical_and(sparse_mask[:, 1] >= minx, sparse_mask[:, 1] <= maxx), + np.logical_and(sparse_mask[:, 0] >= miny, sparse_mask[:, 0] <= maxy), ) self._mask = sparse_mask[sel] self._mask_weights = np.ones(len(self._mask)) @classmethod def project_regions( - cls, - regional_data_list, - regions_list, - axes=(2, 3), - dtype=np.float64, - **kwargs + cls, regional_data_list, regions_list, axes=(2, 3), dtype=np.float64, **kwargs ): - """ Projects individial regions back onto a single contiguous cube """ + """Projects individial regions back onto a single contiguous cube""" if not ( hasattr(regional_data_list, "__len__") and hasattr(regions_list, "__len__") and len(regions_list) == len(regional_data_list) ): - raise TypeError("Region data list and regions lists " - "must be lists of equal length") + raise TypeError( + "Region data list and regions lists " "must be lists of equal length" + ) if not all([isinstance(x, np.ndarray) for x in regional_data_list]): raise TypeError("Region data list must be a list of ndarrays") if not all([isinstance(x, BoundingBox) for x in regions_list]): - raise TypeError( - "Region list must be a list of Axis Aligned Bounding Boxes" - ) + raise TypeError("Region list must be a list of Axis Aligned Bounding Boxes") if regions_list == []: return np.empty((0)) if not all( - [ - reg.ndim == regional_data_list[0].ndim - for reg in regional_data_list - ] + [reg.ndim == regional_data_list[0].ndim for reg in regional_data_list] ): raise ValueError("All data cubes must be of equal dimension") axes = tuple(sorted(axes)) @@ -630,9 +584,11 @@ def project_regions( fnx = xu - xl + 1 # inclusive fny = yu - yl + 1 # inclusive if f.shape[axes[0]] != fny - 1 or f.shape[axes[1]] != fnx - 1: - raise ValueError("One or more bounding box descriptors " - "does not match shape of corresponding " - "data cubes") + raise ValueError( + "One or more bounding box descriptors " + "does not match shape of corresponding " + "data cubes" + ) slc_data = [slice(None)] * len(stitched_img.shape) for (start, end), axis in zip([(yl, yu), (xl, xu)], axes): slc_data[axis] = slice(start, end) @@ -656,10 +612,12 @@ class BoundingBoxFactory(object): def AxisAlignedBoundingBox( cls, convex_hull_object, square=False, enforce_odd=True, **kwargs ): - """ Constructs an axis aligned bounding box around convex hull """ + """Constructs an axis aligned bounding box around convex hull""" if not isinstance(convex_hull_object, BoundingConvexHull): - raise TypeError("Convex hull object passed in constructor " - "is not of type BoundingConvexHull") + raise TypeError( + "Convex hull object passed in constructor " + "is not of type BoundingConvexHull" + ) if square: nx = ( np.max(convex_hull_object.corners[:, 0]) @@ -694,17 +652,16 @@ def AxisAlignedBoundingBox( yu, convex_hull_object.name, mask=convex_hull_object.sparse_mask, - **kwargs + **kwargs, ) @classmethod def SplitBox(cls, bounding_box_object, nsubboxes=1, **kwargs): - """ Split a axis-aligned bounding box into smaller boxes """ + """Split a axis-aligned bounding box into smaller boxes""" if not isinstance(bounding_box_object, BoundingBox): raise TypeError("Expected bounding box object") if not (isinstance(nsubboxes, int) and nsubboxes >= 1): - raise ValueError( - "nsubboxes must be integral type and be 1 or more") + raise ValueError("nsubboxes must be integral type and be 1 or more") xl = np.min(bounding_box_object.corners[:, 0]) xu = np.max(bounding_box_object.corners[:, 0]) yl = np.min(bounding_box_object.corners[:, 1]) @@ -755,7 +712,7 @@ def SplitBox(cls, bounding_box_object, nsubboxes=1, **kwargs): ul[1], bounding_box_object.name, mask=bounding_box_object.sparse_mask, - **kwargs + **kwargs, ) for bl, br, ur, ul in contained_boxes ] @@ -789,5 +746,5 @@ def PadBox(cls, bounding_box_object, desired_nx, desired_ny, **kwargs): yu, bounding_box_object.name, mask=bounding_box_object.sparse_mask, - **kwargs + **kwargs, ) # mask unchanged in the new shape, border frame discarded diff --git a/africanus/linalg/kronecker_tools.py b/africanus/linalg/kronecker_tools.py index 811d75aff..6722b9144 100644 --- a/africanus/linalg/kronecker_tools.py +++ b/africanus/linalg/kronecker_tools.py @@ -52,7 +52,7 @@ def kron_matvec(A, b): x = b for d in range(D): Gd = A[d].shape[0] - X = np.reshape(x, (Gd, N//Gd)) + X = np.reshape(x, (Gd, N // Gd)) Z = np.einsum("ab,bc->ac", A[d], X) Z = np.einsum("ab -> ba", Z) x = Z.flatten() @@ -181,5 +181,5 @@ def kron_cholesky(A): try: L[i] = np.linalg.cholesky(A[i]) except Exception: # add jitter - L[i] = np.linalg.cholesky(A[i] + 1e-13*np.eye(A[i].shape[0])) + L[i] = np.linalg.cholesky(A[i] + 1e-13 * np.eye(A[i].shape[0])) return L diff --git a/africanus/linalg/test/test_geometry.py b/africanus/linalg/test/test_geometry.py index a50cd2831..fb4fcf7d4 100644 --- a/africanus/linalg/test/test_geometry.py +++ b/africanus/linalg/test/test_geometry.py @@ -20,9 +20,7 @@ def test_hull_construction(debug): ) # integral mask area needs to be close to true area assert np.abs(mask.sum() - bh.area) / bh.area < 0.05 - normalized_normals = ( - bh.rnormals / np.linalg.norm(bh.rnormals, axis=1)[:, None] - ) + normalized_normals = bh.rnormals / np.linalg.norm(bh.rnormals, axis=1)[:, None] # test case 2 for e, n in zip(bh.edges, normalized_normals): edge_vec = e[1] - e[0] @@ -34,8 +32,7 @@ def test_hull_construction(debug): sinc_npx = 255 sinc = np.sinc(np.linspace(-7, 7, sinc_npx)) sinc2d = np.outer(sinc, sinc).reshape((1, 1, sinc_npx, sinc_npx)) - (extracted_data, - extracted_window_extents) = BoundingConvexHull.regional_data( + (extracted_data, extracted_window_extents) = BoundingConvexHull.regional_data( bh_extract, sinc2d, oob_value=np.nan ) assert extracted_window_extents == [-10, 293, -30, 268] @@ -94,9 +91,7 @@ def test_hull_construction(debug): assert (-15, 35) not in bb2 assert (0, 35) in bb2 - bb3 = BoundingBoxFactory.AxisAlignedBoundingBox( - bb, square=True - ) # enforce odd + bb3 = BoundingBoxFactory.AxisAlignedBoundingBox(bb, square=True) # enforce odd assert bb3.box_npx[0] == bb3.box_npx[1] assert bb3.box_npx[0] % 2 == 1 # enforce odd assert bb3.area == bb3.box_npx[0] ** 2 @@ -107,12 +102,8 @@ def test_hull_construction(debug): # test case 7 bb4s = BoundingBoxFactory.SplitBox(bb, nsubboxes=3) assert len(bb4s) == 9 - xlims = [(np.min(c.corners[:, 0]), np.max(c.corners[:, 0])) for c in bb4s][ - 0:3 - ] - ylims = [(np.min(c.corners[:, 1]), np.max(c.corners[:, 1])) for c in bb4s][ - 0::3 - ] + xlims = [(np.min(c.corners[:, 0]), np.max(c.corners[:, 0])) for c in bb4s][0:3] + ylims = [(np.min(c.corners[:, 1]), np.max(c.corners[:, 1])) for c in bb4s][0::3] assert np.all(xlims == np.array([(-14, -3), (-2, 9), (10, 20)])) assert np.all(ylims == np.array([(30, 36), (37, 43), (44, 49)])) assert np.sum([b.area for b in bb4s]) == bb.area @@ -145,19 +136,14 @@ def test_hull_construction(debug): ) facets = list( map( - lambda pf: BoundingConvexHull.regional_data( - pf, sinc2d, oob_value=np.nan - ), + lambda pf: BoundingConvexHull.regional_data(pf, sinc2d, oob_value=np.nan), facet_regions, ) ) stitched_image, stitched_region = BoundingBox.project_regions( [f[0] for f in facets], facet_regions ) - assert ( - np.abs(sinc_integral - np.nansum([np.nansum(f[0]) for f in facets])) - < 1.0e-8 - ) + assert np.abs(sinc_integral - np.nansum([np.nansum(f[0]) for f in facets])) < 1.0e-8 assert np.abs(sinc_integral - np.sum(stitched_image)) < 1.0e-8 v = np.argmax(stitched_image) vx = v % stitched_image.shape[3] diff --git a/africanus/model/coherency/conversion.py b/africanus/model/coherency/conversion.py index c9ee34b60..ce73dc554 100644 --- a/africanus/model/coherency/conversion.py +++ b/africanus/model/coherency/conversion.py @@ -7,32 +7,34 @@ import numpy as np -from africanus.util.casa_types import (STOKES_TYPES, - STOKES_ID_MAP) +from africanus.util.casa_types import STOKES_TYPES, STOKES_ID_MAP from africanus.util.docs import DocstringTemplate stokes_conv = { - 'RR': {('I', 'V'): lambda i, v: i + v + 0j}, - 'RL': {('Q', 'U'): lambda q, u: q + u*1j}, - 'LR': {('Q', 'U'): lambda q, u: q - u*1j}, - 'LL': {('I', 'V'): lambda i, v: i - v + 0j}, - - 'XX': {('I', 'Q'): lambda i, q: i + q + 0j}, - 'XY': {('U', 'V'): lambda u, v: u + v*1j}, - 'YX': {('U', 'V'): lambda u, v: u - v*1j}, - 'YY': {('I', 'Q'): lambda i, q: i - q + 0j}, - - 'I': {('XX', 'YY'): lambda xx, yy: (xx + yy).real / 2, - ('RR', 'LL'): lambda rr, ll: (rr + ll).real / 2}, - - 'Q': {('XX', 'YY'): lambda xx, yy: (xx - yy).real / 2, - ('RL', 'LR'): lambda rl, lr: (rl + lr).real / 2}, - - 'U': {('XY', 'YX'): lambda xy, yx: (xy + yx).real / 2, - ('RL', 'LR'): lambda rl, lr: (rl - lr).imag / 2}, - - 'V': {('XY', 'YX'): lambda xy, yx: (xy - yx).imag / 2, - ('RR', 'LL'): lambda rr, ll: (rr - ll).real / 2}, + "RR": {("I", "V"): lambda i, v: i + v + 0j}, + "RL": {("Q", "U"): lambda q, u: q + u * 1j}, + "LR": {("Q", "U"): lambda q, u: q - u * 1j}, + "LL": {("I", "V"): lambda i, v: i - v + 0j}, + "XX": {("I", "Q"): lambda i, q: i + q + 0j}, + "XY": {("U", "V"): lambda u, v: u + v * 1j}, + "YX": {("U", "V"): lambda u, v: u - v * 1j}, + "YY": {("I", "Q"): lambda i, q: i - q + 0j}, + "I": { + ("XX", "YY"): lambda xx, yy: (xx + yy).real / 2, + ("RR", "LL"): lambda rr, ll: (rr + ll).real / 2, + }, + "Q": { + ("XX", "YY"): lambda xx, yy: (xx - yy).real / 2, + ("RL", "LR"): lambda rl, lr: (rl + lr).real / 2, + }, + "U": { + ("XY", "YX"): lambda xy, yx: (xy + yx).real / 2, + ("RL", "LR"): lambda rl, lr: (rl - lr).imag / 2, + }, + "V": { + ("XY", "YX"): lambda xy, yx: (xy - yx).imag / 2, + ("RR", "LL"): lambda rr, ll: (rr - ll).real / 2, + }, } @@ -62,36 +64,38 @@ def _element_indices_and_shape(data): if len(shape) <= depth: shape.append(len(current)) elif shape[depth] != len(current): - raise DimensionMismatch("Dimension mismatch %d != %d at depth %d" - % (shape[depth], len(current), depth)) + raise DimensionMismatch( + "Dimension mismatch %d != %d at depth %d" + % (shape[depth], len(current), depth) + ) # Handle each sequence element for i, e in enumerate(current): # Found a list, recurse if isinstance(e, (tuple, list)): - queue.append((e, current_idx + (i, ), depth + 1)) + queue.append((e, current_idx + (i,), depth + 1)) # String elif isinstance(e, str): if e in result: raise ValueError("'%s' defined multiple times" % e) - result[e] = current_idx + (i, ) + result[e] = current_idx + (i,) # We have a CASA integer Stokes ID, convert to string elif np.issubdtype(type(e), np.integer): try: e = STOKES_ID_MAP[e] except KeyError: - raise ValueError("Invalid id '%d'. " - "Valid id's '%s'" - % (e, pformat(STOKES_ID_MAP))) + raise ValueError( + "Invalid id '%d'. " + "Valid id's '%s'" % (e, pformat(STOKES_ID_MAP)) + ) if e in result: raise ValueError("'%s' defined multiple times" % e) - result[e] = current_idx + (i, ) + result[e] = current_idx + (i,) else: - raise TypeError("Invalid type '%s' for element '%s'" - % (type(e), e)) + raise TypeError("Invalid type '%s' for element '%s'" % (type(e), e)) return result, tuple(shape) @@ -100,7 +104,7 @@ def convert_setup(input, input_schema, output_schema): input_indices, input_shape = _element_indices_and_shape(input_schema) output_indices, output_shape = _element_indices_and_shape(output_schema) - if input.shape[-len(input_shape):] != input_shape: + if input.shape[-len(input_shape) :] != input_shape: raise ValueError("Last dimension of input doesn't match input schema") mapping = [] @@ -111,8 +115,9 @@ def convert_setup(input, input_schema, output_schema): try: deps = stokes_conv[okey] except KeyError: - raise ValueError("Unknown output '%s'. Known types '%s'" - % (deps, STOKES_TYPES)) + raise ValueError( + "Unknown output '%s'. Known types '%s'" % (deps, STOKES_TYPES) + ) found_conv = False @@ -138,12 +143,12 @@ def convert_setup(input, input_schema, output_schema): # We must find a conversion if not found_conv: - raise MissingConversionInputs("None of the supplied inputs '%s' " - "can produce output '%s'. It can be " - "produced by the following " - "combinations '%s'." % ( - input_schema, - okey, deps.keys())) + raise MissingConversionInputs( + "None of the supplied inputs '%s' " + "can produce output '%s'. It can be " + "produced by the following " + "combinations '%s'." % (input_schema, okey, deps.keys()) + ) out_dtype = np.result_type(*[dt for _, _, _, _, dt in mapping]) @@ -152,7 +157,7 @@ def convert_setup(input, input_schema, output_schema): def convert_impl(input, mapping, in_shape, out_shape, dtype): # Make the output array - out_shape = input.shape[:-len(in_shape)] + out_shape + out_shape = input.shape[: -len(in_shape)] + out_shape output = np.empty(out_shape, dtype=dtype) for c1_idx, c2_idx, out_idx, fn, _ in mapping: @@ -162,12 +167,12 @@ def convert_impl(input, mapping, in_shape, out_shape, dtype): def convert(input, input_schema, output_schema): - """ See STOKES_DOCS below """ + """See STOKES_DOCS below""" # Do the conversion - mapping, in_shape, out_shape, dtype = convert_setup(input, - input_schema, - output_schema) + mapping, in_shape, out_shape, dtype = convert_setup( + input, input_schema, output_schema + ) return convert_impl(input, mapping, in_shape, out_shape, dtype) @@ -246,12 +251,11 @@ def convert(input, input_schema, output_schema): _map_str = ", ".join(["%s: %d" % (t, i) for i, t in enumerate(STOKES_TYPES)]) _map_str = "{{ " + _map_str + " }}" # Indent must match docstrings -_map_str = fill(_map_str, initial_indent='', subsequent_indent=' '*8) +_map_str = fill(_map_str, initial_indent="", subsequent_indent=" " * 8) CONVERT_DOCS = DocstringTemplate(CONVERT_DOCS.format(stokes_type_map=_map_str)) del _map_str try: - convert.__doc__ = CONVERT_DOCS.substitute( - array_type=":class:`numpy.ndarray`") + convert.__doc__ = CONVERT_DOCS.substitute(array_type=":class:`numpy.ndarray`") except AttributeError: pass diff --git a/africanus/model/coherency/cuda/conversion.py b/africanus/model/coherency/cuda/conversion.py index f888700ae..ae06203aa 100644 --- a/africanus/model/coherency/cuda/conversion.py +++ b/africanus/model/coherency/cuda/conversion.py @@ -8,9 +8,11 @@ import numpy as np -from africanus.model.coherency.conversion import (_element_indices_and_shape, - CONVERT_DOCS, - MissingConversionInputs) +from africanus.model.coherency.conversion import ( + _element_indices_and_shape, + CONVERT_DOCS, + MissingConversionInputs, +) from africanus.util.code import memoize_on_key, format_code from africanus.util.cuda import cuda_type, grids from africanus.util.jinja2 import jinja_env @@ -27,27 +29,30 @@ log = logging.getLogger(__name__) stokes_conv = { - 'RR': {('I', 'V'): ("complex", "make_{{out_type}}({{I}} + {{V}}, 0)")}, - 'RL': {('Q', 'U'): ("complex", "make_{{out_type}}({{Q}}, {{U}})")}, - 'LR': {('Q', 'U'): ("complex", "make_{{out_type}}({{Q}}, -{{U}})")}, - 'LL': {('I', 'V'): ("complex", "make_{{out_type}}({{I}} - {{V}}, 0)")}, - - 'XX': {('I', 'Q'): ("complex", "make_{{out_type}}({{I}} + {{Q}}, 0)")}, - 'XY': {('U', 'V'): ("complex", "make_{{out_type}}({{U}}, {{V}})")}, - 'YX': {('U', 'V'): ("complex", "make_{{out_type}}({{U}}, -{{V}})")}, - 'YY': {('I', 'Q'): ("complex", "make_{{out_type}}({{I}} - {{Q}}, 0)")}, - - 'I': {('XX', 'YY'): ("real", "(({{XX}}.x + {{YY}}.x) / 2)"), - ('RR', 'LL'): ("real", "(({{RR}}.x + {{LL}}.x) / 2)")}, - - 'Q': {('XX', 'YY'): ("real", "(({{XX}}.x - {{YY}}.x) / 2)"), - ('RL', 'LR'): ("real", "(({{RL}}.x + {{LR}}.x) / 2)")}, - - 'U': {('XY', 'YX'): ("real", "(({{XY}}.x + {{YX}}.x) / 2)"), - ('RL', 'LR'): ("real", "(({{RL}}.y - {{LR}}.y) / 2)")}, - - 'V': {('XY', 'YX'): ("real", "(({{XY}}.y - {{YX}}.y) / 2)"), - ('RR', 'LL'): ("real", "(({{RR}}.x - {{LL}}.x) / 2)")}, + "RR": {("I", "V"): ("complex", "make_{{out_type}}({{I}} + {{V}}, 0)")}, + "RL": {("Q", "U"): ("complex", "make_{{out_type}}({{Q}}, {{U}})")}, + "LR": {("Q", "U"): ("complex", "make_{{out_type}}({{Q}}, -{{U}})")}, + "LL": {("I", "V"): ("complex", "make_{{out_type}}({{I}} - {{V}}, 0)")}, + "XX": {("I", "Q"): ("complex", "make_{{out_type}}({{I}} + {{Q}}, 0)")}, + "XY": {("U", "V"): ("complex", "make_{{out_type}}({{U}}, {{V}})")}, + "YX": {("U", "V"): ("complex", "make_{{out_type}}({{U}}, -{{V}})")}, + "YY": {("I", "Q"): ("complex", "make_{{out_type}}({{I}} - {{Q}}, 0)")}, + "I": { + ("XX", "YY"): ("real", "(({{XX}}.x + {{YY}}.x) / 2)"), + ("RR", "LL"): ("real", "(({{RR}}.x + {{LL}}.x) / 2)"), + }, + "Q": { + ("XX", "YY"): ("real", "(({{XX}}.x - {{YY}}.x) / 2)"), + ("RL", "LR"): ("real", "(({{RL}}.x + {{LR}}.x) / 2)"), + }, + "U": { + ("XY", "YX"): ("real", "(({{XY}}.x + {{YX}}.x) / 2)"), + ("RL", "LR"): ("real", "(({{RL}}.y - {{LR}}.y) / 2)"), + }, + "V": { + ("XY", "YX"): ("real", "(({{XY}}.y - {{YX}}.y) / 2)"), + ("RR", "LL"): ("real", "(({{RR}}.x - {{LL}}.x) / 2)"), + }, } @@ -55,7 +60,7 @@ def stokes_convert_setup(input, input_schema, output_schema): input_indices, input_shape = _element_indices_and_shape(input_schema) output_indices, output_shape = _element_indices_and_shape(output_schema) - if input.shape[-len(input_shape):] != input_shape: + if input.shape[-len(input_shape) :] != input_shape: raise ValueError("Last dimension of input doesn't match input schema") mapping = [] @@ -66,8 +71,9 @@ def stokes_convert_setup(input, input_schema, output_schema): try: deps = stokes_conv[okey] except KeyError: - raise ValueError("Unknown output '%s'. Known types '%s'" - % (okey, stokes_conv.keys())) + raise ValueError( + "Unknown output '%s'. Known types '%s'" % (okey, stokes_conv.keys()) + ) found_conv = False @@ -92,12 +98,12 @@ def stokes_convert_setup(input, input_schema, output_schema): # We must find a conversion if not found_conv: - raise MissingConversionInputs("None of the supplied inputs '%s' " - "can produce output '%s'. It can be " - "produced by the following " - "combinations '%s'." % ( - input_schema, - okey, deps.keys())) + raise MissingConversionInputs( + "None of the supplied inputs '%s' " + "can produce output '%s'. It can be " + "produced by the following " + "combinations '%s'." % (input_schema, okey, deps.keys()) + ) # Output types must be all "real" or all "complex" if not all(dtypes[0] == dt for dt in dtypes[1:]): @@ -114,9 +120,7 @@ def schema_to_tuple(schema): def _key_fn(inputs, input_schema, output_schema): - return (inputs.dtype, - schema_to_tuple(input_schema), - schema_to_tuple(output_schema)) + return (inputs.dtype, schema_to_tuple(input_schema), schema_to_tuple(output_schema)) _TEMPLATE_PATH = pjoin("model", "coherency", "cuda", "conversion.cu.j2") @@ -125,9 +129,8 @@ def _key_fn(inputs, input_schema, output_schema): @memoize_on_key(_key_fn) def _generate_kernel(inputs, input_schema, output_schema): mapping, in_shape, out_shape, out_dtype = stokes_convert_setup( - inputs, - input_schema, - output_schema) + inputs, input_schema, output_schema + ) # Flatten input and output shapes # Check that number elements are the same @@ -135,10 +138,11 @@ def _generate_kernel(inputs, input_schema, output_schema): out_elems = reduce(mul, out_shape, 1) if in_elems != out_elems: - raise ValueError("Number of input_schema elements %s " - "and output schema elements %s " - "must match for CUDA kernel." % - (in_shape, out_shape)) + raise ValueError( + "Number of input_schema elements %s " + "and output schema elements %s " + "must match for CUDA kernel." % (in_shape, out_shape) + ) # Infer the output data type if out_dtype == "real": @@ -162,9 +166,11 @@ def _generate_kernel(inputs, input_schema, output_schema): # Flattened indices flat_outi = np.ravel_multi_index(outi, out_shape) render = jinja_env.from_string(template_fn).render - kwargs = {c1: "in[%d]" % np.ravel_multi_index(c1i, in_shape), - c2: "in[%d]" % np.ravel_multi_index(c2i, in_shape), - "out_type": cuda_out_dtype} + kwargs = { + c1: "in[%d]" % np.ravel_multi_index(c1i, in_shape), + c2: "in[%d]" % np.ravel_multi_index(c2i, in_shape), + "out_type": cuda_out_dtype, + } expr_str = render(**kwargs) assign_exprs.append("out[%d] = %s;" % (flat_outi, expr_str)) @@ -172,11 +178,13 @@ def _generate_kernel(inputs, input_schema, output_schema): # Now render the main template render = jinja_env.get_template(_TEMPLATE_PATH).render name = "stokes_convert" - code = render(kernel_name=name, - input_type=cuda_type(inputs.dtype), - output_type=cuda_type(out_dtype), - assign_exprs=assign_exprs, - elements=in_elems) + code = render( + kernel_name=name, + input_type=cuda_type(inputs.dtype), + output_type=cuda_type(out_dtype), + assign_exprs=assign_exprs, + elements=in_elems, + ) # cuda block, flatten non-schema dims into a single source dim blockdimx = 512 @@ -185,17 +193,16 @@ def _generate_kernel(inputs, input_schema, output_schema): return (cp.RawKernel(code, name), block, in_shape, out_shape, out_dtype) -@requires_optional('cupy', opt_import_error) +@requires_optional("cupy", opt_import_error) def convert(inputs, input_schema, output_schema): - (kernel, block, - in_shape, out_shape, dtype) = _generate_kernel(inputs, - input_schema, - output_schema) + (kernel, block, in_shape, out_shape, dtype) = _generate_kernel( + inputs, input_schema, output_schema + ) # Flatten non-schema input dimensions, # from inspection of the cupy reshape code, # this incurs a copy when inputs is non-contiguous - nsrc = reduce(mul, inputs.shape[:-len(in_shape)], 1) + nsrc = reduce(mul, inputs.shape[: -len(in_shape)], 1) nelems = reduce(mul, in_shape, 1) rinputs = inputs.reshape(nsrc, nelems) @@ -210,14 +217,13 @@ def convert(inputs, input_schema, output_schema): log.exception(format_code(kernel.code)) raise - shape = inputs.shape[:-len(in_shape)] + out_shape + shape = inputs.shape[: -len(in_shape)] + out_shape outputs = outputs.reshape(shape) assert outputs.flags.c_contiguous return outputs try: - convert.__doc__ = CONVERT_DOCS.substitute( - array_type=":class:`cupy.ndarray`") + convert.__doc__ = CONVERT_DOCS.substitute(array_type=":class:`cupy.ndarray`") except AttributeError: pass diff --git a/africanus/model/coherency/cuda/tests/test_convert.py b/africanus/model/coherency/cuda/tests/test_convert.py index acf4f0e91..0d3da3de5 100644 --- a/africanus/model/coherency/cuda/tests/test_convert.py +++ b/africanus/model/coherency/cuda/tests/test_convert.py @@ -7,16 +7,15 @@ from africanus.model.coherency import convert as np_convert from africanus.model.coherency.cuda import convert from africanus.model.coherency.tests.test_convert import ( - stokes_corr_cases, - stokes_corr_int_cases, - visibility_factory, - vis_shape) + stokes_corr_cases, + stokes_corr_int_cases, + visibility_factory, + vis_shape, +) @pytest.mark.skip -def test_stokes_schemas(in_type, input_schema, - out_type, output_schema, - vis_shape): +def test_stokes_schemas(in_type, input_schema, out_type, output_schema, vis_shape): input_shape = np.asarray(input_schema).shape output_shape = np.asarray(output_schema).shape @@ -25,13 +24,13 @@ def test_stokes_schemas(in_type, input_schema, assert xformed_vis.shape == vis_shape + output_shape -@pytest.mark.parametrize("in_type, input_schema, out_type, output_schema", - stokes_corr_cases + stokes_corr_int_cases) +@pytest.mark.parametrize( + "in_type, input_schema, out_type, output_schema", + stokes_corr_cases + stokes_corr_int_cases, +) @pytest.mark.parametrize("vis_shape", vis_shape) -def test_cuda_convert(in_type, input_schema, - out_type, output_schema, - vis_shape): - cp = pytest.importorskip('cupy') +def test_cuda_convert(in_type, input_schema, out_type, output_schema, vis_shape): + cp = pytest.importorskip("cupy") input_shape = np.asarray(input_schema).shape vis = visibility_factory(vis_shape, input_shape, in_type) diff --git a/africanus/model/coherency/dask.py b/africanus/model/coherency/dask.py index dcbe4675d..fd028fd69 100644 --- a/africanus/model/coherency/dask.py +++ b/africanus/model/coherency/dask.py @@ -1,9 +1,11 @@ # -*- coding: utf-8 -*- -from africanus.model.coherency.conversion import (convert_setup, - convert_impl, - CONVERT_DOCS) +from africanus.model.coherency.conversion import ( + convert_setup, + convert_impl, + CONVERT_DOCS, +) from africanus.util.requirements import requires_optional @@ -15,10 +17,8 @@ da_import_error = None -def convert_wrapper(np_input, mapping=None, in_shape=None, - out_shape=None, dtype_=None): - result = convert_impl(np_input, mapping, in_shape, - out_shape, dtype_) +def convert_wrapper(np_input, mapping=None, in_shape=None, out_shape=None, dtype_=None): + result = convert_impl(np_input, mapping, in_shape, out_shape, dtype_) # Introduce extra singleton dimension at the end of our shape return result.reshape(result.shape + (1,) * len(in_shape)) @@ -26,9 +26,9 @@ def convert_wrapper(np_input, mapping=None, in_shape=None, @requires_optional("dask.array", da_import_error) def convert(input, input_schema, output_schema): - mapping, in_shape, out_shape, dtype = convert_setup(input, - input_schema, - output_schema) + mapping, in_shape, out_shape, dtype = convert_setup( + input, input_schema, output_schema + ) n_free_dims = len(input.shape) - len(in_shape) free_dims = tuple("dim-%d" % i for i in range(n_free_dims)) @@ -41,15 +41,18 @@ def convert(input, input_schema, output_schema): # Note the dummy in_corr_dims introduced at the end of our output, # We do this to prevent a contraction over the input dimensions # (which can be arbitrary) within the wrapper class - res = da.core.blockwise(convert_wrapper, - free_dims + out_corr_dims + in_corr_dims, - input, free_dims + in_corr_dims, - mapping=mapping, - in_shape=in_shape, - out_shape=out_shape, - new_axes=new_axes, - dtype_=dtype, - dtype=dtype) + res = da.core.blockwise( + convert_wrapper, + free_dims + out_corr_dims + in_corr_dims, + input, + free_dims + in_corr_dims, + mapping=mapping, + in_shape=in_shape, + out_shape=out_shape, + new_axes=new_axes, + dtype_=dtype, + dtype=dtype, + ) # Now contract over the dummy dimensions start = len(free_dims) + len(out_corr_dims) @@ -58,7 +61,6 @@ def convert(input, input_schema, output_schema): try: - convert.__doc__ = CONVERT_DOCS.substitute( - array_type=":class:`dask.array.Array`") + convert.__doc__ = CONVERT_DOCS.substitute(array_type=":class:`dask.array.Array`") except AttributeError: pass diff --git a/africanus/model/coherency/tests/test_convert.py b/africanus/model/coherency/tests/test_convert.py index 1230a1c7c..490d45231 100644 --- a/africanus/model/coherency/tests/test_convert.py +++ b/africanus/model/coherency/tests/test_convert.py @@ -6,36 +6,24 @@ import numpy as np import pytest -from africanus.model.coherency.conversion import ( - convert as np_convert) +from africanus.model.coherency.conversion import convert as np_convert from africanus.util.casa_types import STOKES_TYPE_MAP as smap stokes_corr_cases = [ - ("complex", [['XX'], ['YY']], - "real", ['I', 'Q']), - ("complex", ['XX', 'YY'], - "real", ['I', 'Q']), - ("complex", ['XX', 'XY', 'YX', 'YY'], - "real", ['I', 'Q', 'U', 'V']), - ("complex", [['XX', 'XY'], ['YX', 'YY']], - "real", [['I', 'Q'], ['U', 'V']]), - ("real", ['I', 'Q', 'U', 'V'], - "complex", ['XX', 'XY', 'YX', 'YY']), - ("real", [['I', 'Q'], ['U', 'V']], - "complex", [['XX', 'XY'], ['YX', 'YY']]), - ("real", [['I', 'Q'], ['U', 'V']], - "complex", [['XX', 'XY', 'YX', 'YY']]), - ("real", [['I', 'Q'], ['U', 'V']], - "complex", [['RR', 'RL', 'LR', 'LL']]), - ("real", ['I', 'V'], - "complex", ['RR', 'LL']), - ("real", ['I', 'Q'], - "complex", ['XX', 'YY']), + ("complex", [["XX"], ["YY"]], "real", ["I", "Q"]), + ("complex", ["XX", "YY"], "real", ["I", "Q"]), + ("complex", ["XX", "XY", "YX", "YY"], "real", ["I", "Q", "U", "V"]), + ("complex", [["XX", "XY"], ["YX", "YY"]], "real", [["I", "Q"], ["U", "V"]]), + ("real", ["I", "Q", "U", "V"], "complex", ["XX", "XY", "YX", "YY"]), + ("real", [["I", "Q"], ["U", "V"]], "complex", [["XX", "XY"], ["YX", "YY"]]), + ("real", [["I", "Q"], ["U", "V"]], "complex", [["XX", "XY", "YX", "YY"]]), + ("real", [["I", "Q"], ["U", "V"]], "complex", [["RR", "RL", "LR", "LL"]]), + ("real", ["I", "V"], "complex", ["RR", "LL"]), + ("real", ["I", "Q"], "complex", ["XX", "YY"]), ] stokes_corr_int_cases = [ - ("complex", [smap['XX'], smap['YY']], - "real", [smap['I'], smap['Q']]) + ("complex", [smap["XX"], smap["YY"]], "real", [smap["I"], smap["Q"]]) ] @@ -46,41 +34,39 @@ ] -vis_shape = [tuple(sum(dim_chunks) for dim_chunks in case) - for case in vis_chunks] +vis_shape = [tuple(sum(dim_chunks) for dim_chunks in case) for case in vis_chunks] -def visibility_factory(vis_shape, input_shape, in_type, - backend="numpy", **kwargs): +def visibility_factory(vis_shape, input_shape, in_type, backend="numpy", **kwargs): shape = vis_shape + input_shape if backend == "numpy": vis = np.arange(1.0, np.product(shape) + 1.0) vis = vis.reshape(shape) elif backend == "dask": - da = pytest.importorskip('dask.array') + da = pytest.importorskip("dask.array") vis = da.arange(1.0, np.product(shape) + 1.0, chunks=np.product(shape)) vis = vis.reshape(shape) - vis = vis.rechunk(kwargs['vis_chunks'] + input_shape) + vis = vis.rechunk(kwargs["vis_chunks"] + input_shape) else: raise ValueError("Invalid backend %s" % backend) if in_type == "real": pass elif in_type == "complex": - vis = vis + 1j*vis + vis = vis + 1j * vis else: raise ValueError("Invalid in_type %s" % in_type) return vis -@pytest.mark.parametrize("in_type, input_schema, out_type, output_schema", - stokes_corr_cases + stokes_corr_int_cases) +@pytest.mark.parametrize( + "in_type, input_schema, out_type, output_schema", + stokes_corr_cases + stokes_corr_int_cases, +) @pytest.mark.parametrize("vis_shape", vis_shape) -def test_conversion_schemas(in_type, input_schema, - out_type, output_schema, - vis_shape): +def test_conversion_schemas(in_type, input_schema, out_type, output_schema, vis_shape): input_shape = np.asarray(input_schema).shape output_shape = np.asarray(output_schema).shape vis = visibility_factory(vis_shape, input_shape, in_type) @@ -92,76 +78,85 @@ def test_conversion(): I, Q, U, V = [1.0, 2.0, 3.0, 4.0] # Check conversion to linear (string) - vis = np_convert(np.asarray([[I, Q, U, V]]), - ['I', 'Q', 'U', 'V'], - ['XX', 'XY', 'YX', 'YY']) + vis = np_convert( + np.asarray([[I, Q, U, V]]), ["I", "Q", "U", "V"], ["XX", "XY", "YX", "YY"] + ) XX, XY, YX, YY = vis[0] - assert np.all(vis == [[I + Q, U + V*1j, U - V*1j, I - Q]]) + assert np.all(vis == [[I + Q, U + V * 1j, U - V * 1j, I - Q]]) # Check conversion to linear (integer) - vis = np_convert(np.asarray([[I, Q, U, V]]), - [smap[x] for x in ('I', 'Q', 'U', 'V')], - [smap[x] for x in ('XX', 'XY', 'YX', 'YY')]) + vis = np_convert( + np.asarray([[I, Q, U, V]]), + [smap[x] for x in ("I", "Q", "U", "V")], + [smap[x] for x in ("XX", "XY", "YX", "YY")], + ) - assert np.all(vis == [[I + Q, U + V*1j, U - V*1j, I - Q]]) + assert np.all(vis == [[I + Q, U + V * 1j, U - V * 1j, I - Q]]) # Check conversion to circular (string) - vis = np_convert(np.asarray([[I, Q, U, V]]), - ['I', 'Q', 'U', 'V'], - ['RR', 'RL', 'LR', 'LL']) + vis = np_convert( + np.asarray([[I, Q, U, V]]), ["I", "Q", "U", "V"], ["RR", "RL", "LR", "LL"] + ) RR, RL, LR, LL = vis[0] - assert np.all(vis == [[I + V, Q + U*1j, Q - U*1j, I - V]]) + assert np.all(vis == [[I + V, Q + U * 1j, Q - U * 1j, I - V]]) # Check conversion to circular (integer) - vis = np_convert(np.asarray([[I, Q, U, V]]), - [smap[x] for x in ('I', 'Q', 'U', 'V')], - [smap[x] for x in ('RR', 'RL', 'LR', 'LL')]) + vis = np_convert( + np.asarray([[I, Q, U, V]]), + [smap[x] for x in ("I", "Q", "U", "V")], + [smap[x] for x in ("RR", "RL", "LR", "LL")], + ) - assert np.all(vis == [[I + V, Q + U*1j, Q - U*1j, I - V]]) + assert np.all(vis == [[I + V, Q + U * 1j, Q - U * 1j, I - V]]) # linear to stokes (string) - stokes = np_convert(np.asarray([[XX, XY, YX, YY]]), - ['XX', 'XY', 'YX', 'YY'], - ['I', 'Q', 'U', 'V']) + stokes = np_convert( + np.asarray([[XX, XY, YX, YY]]), ["XX", "XY", "YX", "YY"], ["I", "Q", "U", "V"] + ) assert np.all(stokes == [[I, Q, U, V]]) # linear to stokes (integer) - stokes = np_convert(np.asarray([[XX, XY, YX, YY]]), - [smap[x] for x in ('XX', 'XY', 'YX', 'YY')], - [smap[x] for x in ('I', 'Q', 'U', 'V')]) + stokes = np_convert( + np.asarray([[XX, XY, YX, YY]]), + [smap[x] for x in ("XX", "XY", "YX", "YY")], + [smap[x] for x in ("I", "Q", "U", "V")], + ) assert np.all(stokes == [[I, Q, U, V]]) # circular to stokes (string) - stokes = np_convert(np.asarray([[RR, RL, LR, LL]]), - ['RR', 'RL', 'LR', 'LL'], - ['I', 'Q', 'U', 'V']) + stokes = np_convert( + np.asarray([[RR, RL, LR, LL]]), ["RR", "RL", "LR", "LL"], ["I", "Q", "U", "V"] + ) assert np.all(stokes == [[I, Q, U, V]]) # circular to stokes (intger) - stokes = np_convert(np.asarray([[RR, RL, LR, LL]]), - [smap[x] for x in ('RR', 'RL', 'LR', 'LL')], - [smap[x] for x in ('I', 'Q', 'U', 'V')]) + stokes = np_convert( + np.asarray([[RR, RL, LR, LL]]), + [smap[x] for x in ("RR", "RL", "LR", "LL")], + [smap[x] for x in ("I", "Q", "U", "V")], + ) assert np.all(stokes == [[I, Q, U, V]]) -@pytest.mark.parametrize("in_type, input_schema, out_type, output_schema", - stokes_corr_cases + stokes_corr_int_cases) +@pytest.mark.parametrize( + "in_type, input_schema, out_type, output_schema", + stokes_corr_cases + stokes_corr_int_cases, +) @pytest.mark.parametrize("vis_chunks", vis_chunks) -def test_dask_conversion(in_type, input_schema, - out_type, output_schema, - vis_chunks): +def test_dask_conversion(in_type, input_schema, out_type, output_schema, vis_chunks): from africanus.model.coherency.dask import convert as da_convert vis_shape = tuple(sum(dim_chunks) for dim_chunks in vis_chunks) input_shape = np.asarray(input_schema).shape - vis = visibility_factory(vis_shape, input_shape, in_type, - backend="dask", vis_chunks=vis_chunks) + vis = visibility_factory( + vis_shape, input_shape, in_type, backend="dask", vis_chunks=vis_chunks + ) da_vis = da_convert(vis, input_schema, output_schema) np_vis = np_convert(vis.compute(), input_schema, output_schema) diff --git a/africanus/model/shape/dask.py b/africanus/model/shape/dask.py index 48acdaad5..f9f80d008 100644 --- a/africanus/model/shape/dask.py +++ b/africanus/model/shape/dask.py @@ -44,9 +44,7 @@ def gaussian(uvw, frequency, shape_params): def _shapelet_wrapper(coords, frequency, coeffs, beta, delta_lm): - return nb_shapelet( - coords[0], frequency, coeffs[0][0], beta[0], delta_lm[0] - ) + return nb_shapelet(coords[0], frequency, coeffs[0][0], beta[0], delta_lm[0]) @requires_optional("dask.array", opt_import_error) @@ -69,9 +67,7 @@ def shapelet(coords, frequency, coeffs, beta, delta_lm): ) -def _shapelet_with_w_term_wrapper( - coords, frequency, coeffs, beta, delta_lm, lm -): +def _shapelet_with_w_term_wrapper(coords, frequency, coeffs, beta, delta_lm, lm): return nb_shapelet_with_w_term( coords[0], frequency, coeffs[0][0], beta[0], delta_lm[0], lm[0] ) @@ -100,8 +96,6 @@ def shapelet_with_w_term(coords, frequency, coeffs, beta, delta_lm, lm): try: - gaussian.__doc__ = GAUSSIAN_DOCS.substitute( - array_type=":class:`dask.array.Array`" - ) + gaussian.__doc__ = GAUSSIAN_DOCS.substitute(array_type=":class:`dask.array.Array`") except AttributeError: pass diff --git a/africanus/model/shape/gaussian_shape.py b/africanus/model/shape/gaussian_shape.py index de341536c..bb9114fe4 100644 --- a/africanus/model/shape/gaussian_shape.py +++ b/africanus/model/shape/gaussian_shape.py @@ -24,8 +24,9 @@ def nb_gaussian(uvw, frequency, shape_params): fwhminv = 1.0 / fwhm gauss_scale = fwhminv * np.sqrt(2.0) * np.pi / lightspeed - dtype = np.result_type(*(np.dtype(a.dtype.name) for - a in (uvw, frequency, shape_params))) + dtype = np.result_type( + *(np.dtype(a.dtype.name) for a in (uvw, frequency, shape_params)) + ) def impl(uvw, frequency, shape_params): nsrc = shape_params.shape[0] @@ -50,21 +51,22 @@ def impl(uvw, frequency, shape_params): for r in range(uvw.shape[0]): u, v, w = uvw[r] - u1 = (u*em - v*el)*er - v1 = u*el + v*em + u1 = (u * em - v * el) * er + v1 = u * el + v * em for f in range(scaled_freq.shape[0]): - fu1 = u1*scaled_freq[f] - fv1 = v1*scaled_freq[f] + fu1 = u1 * scaled_freq[f] + fv1 = v1 * scaled_freq[f] - shape[s, r, f] = np.exp(-(fu1*fu1 + fv1*fv1)) + shape[s, r, f] = np.exp(-(fu1 * fu1 + fv1 * fv1)) return shape return impl -GAUSSIAN_DOCS = DocstringTemplate(r""" +GAUSSIAN_DOCS = DocstringTemplate( + r""" Computes the Gaussian Shape Function. .. math:: @@ -100,10 +102,10 @@ def impl(uvw, frequency, shape_params): ------- gauss_shape : $(array_type) Shape parameters of shape :code:`(source, row, chan)` -""") +""" +) try: - gaussian.__doc__ = GAUSSIAN_DOCS.substitute( - array_type=":class:`numpy.ndarray`") + gaussian.__doc__ = GAUSSIAN_DOCS.substitute(array_type=":class:`numpy.ndarray`") except KeyError: pass diff --git a/africanus/model/shape/shapelets.py b/africanus/model/shape/shapelets.py index 8add546b7..3a7881f2e 100644 --- a/africanus/model/shape/shapelets.py +++ b/africanus/model/shape/shapelets.py @@ -34,15 +34,11 @@ def basis_function(n, xx, beta, fourier=False, delta_x=-1): else: x = xx scale = beta - basis_component = 1.0 / np.sqrt( - 2.0 ** n * np.sqrt(np.pi) * factorial(n) * scale - ) - exponential_component = hermite(n, x / scale) * np.exp( - -(x ** 2) / (2.0 * scale ** 2) - ) + basis_component = 1.0 / np.sqrt(2.0**n * np.sqrt(np.pi) * factorial(n) * scale) + exponential_component = hermite(n, x / scale) * np.exp(-(x**2) / (2.0 * scale**2)) if fourier: return ( - 1.0j ** n + 1.0j**n * basis_component * exponential_component * np.sqrt(2 * np.pi) @@ -55,11 +51,9 @@ def basis_function(n, xx, beta, fourier=False, delta_x=-1): @numba.jit(nogil=True, nopython=True, cache=True) def phase_steer_and_w_correct(uvw, lm_source_center, frequency): l0, m0 = lm_source_center - n0 = np.sqrt(1.0 - l0 ** 2 - m0 ** 2) + n0 = np.sqrt(1.0 - l0**2 - m0**2) u, v, w = uvw - real_phase = ( - minus_two_pi_over_c * frequency * (u * l0 + v * m0 + w * (n0 - 1)) - ) + real_phase = minus_two_pi_over_c * frequency * (u * l0 + v * m0 + w * (n0 - 1)) return np.exp(1.0j * real_phase) @@ -102,12 +96,8 @@ def shapelet(coords, frequency, coeffs, beta, delta_lm, dtype=np.complex128): 0 if coeffs[src][n1, n2] == 0 else coeffs[src][n1, n2] - * basis_function( - n1, fu, beta_u, True, delta_x=delta_l - ) - * basis_function( - n2, fv, beta_v, True, delta_x=delta_m - ) + * basis_function(n1, fu, beta_u, True, delta_x=delta_l) + * basis_function(n2, fv, beta_v, True, delta_x=delta_m) ) out_shapelets[row, chan, src] = tmp_shapelet return out_shapelets @@ -155,16 +145,10 @@ def shapelet_with_w_term( 0 if coeffs[src][n1, n2] == 0 else coeffs[src][n1, n2] - * basis_function( - n1, fu, beta_u, True, delta_x=delta_l - ) - * basis_function( - n2, fv, beta_v, True, delta_x=delta_m - ) + * basis_function(n1, fu, beta_u, True, delta_x=delta_l) + * basis_function(n2, fv, beta_v, True, delta_x=delta_m) ) - w_term = phase_steer_and_w_correct( - (u, v, w), (l, m), frequency[chan] - ) + w_term = phase_steer_and_w_correct((u, v, w), (l, m), frequency[chan]) out_shapelets[row, chan, src] = tmp_shapelet * w_term return out_shapelets @@ -197,9 +181,7 @@ def shapelet_1d(u, coeffs, fourier, delta_x=1, beta=1.0): nrow = u.size if fourier: if delta_x is None: - raise ValueError( - "You have to pass in a value for delta_x in Fourier mode" - ) + raise ValueError("You have to pass in a value for delta_x in Fourier mode") out = np.zeros(nrow, dtype=np.complex128) else: out = np.zeros(nrow, dtype=np.float64) @@ -233,11 +215,7 @@ def shapelet_2d(u, v, coeffs_l, fourier, delta_x=None, delta_y=None, beta=1.0): c = coeffs_l[n1, n2] out[i, j] += ( c - * basis_function( - n1, ui, beta, fourier=fourier, delta_x=delta_x - ) - * basis_function( - n2, vj, beta, fourier=fourier, delta_x=delta_y - ) + * basis_function(n1, ui, beta, fourier=fourier, delta_x=delta_x) + * basis_function(n2, vj, beta, fourier=fourier, delta_x=delta_y) ) return out diff --git a/africanus/model/shape/tests/test_gaussian_shape.py b/africanus/model/shape/tests/test_gaussian_shape.py index a3e7d0271..1fdfe216a 100644 --- a/africanus/model/shape/tests/test_gaussian_shape.py +++ b/africanus/model/shape/tests/test_gaussian_shape.py @@ -12,10 +12,9 @@ def test_gauss_shape(): row = 10 chan = 16 - shape_params = np.array([[.4, .3, .2], - [.4, .3, .2]]) + shape_params = np.array([[0.4, 0.3, 0.2], [0.4, 0.3, 0.2]]) uvw = np.random.random((row, 3)) - freq = np.linspace(.856e9, 2*.856e9, chan) + freq = np.linspace(0.856e9, 2 * 0.856e9, chan) gauss_shape = np_gaussian(uvw, freq, shape_params) @@ -23,7 +22,7 @@ def test_gauss_shape(): def test_dask_gauss_shape(): - da = pytest.importorskip('dask.array') + da = pytest.importorskip("dask.array") from africanus.model.shape.dask import gaussian as da_gaussian row_chunks = (5, 5) @@ -32,13 +31,10 @@ def test_dask_gauss_shape(): row = sum(row_chunks) chan = sum(chan_chunks) - shape_params = da.asarray([[.4, .3, .2], - [.4, .3, .2]]) + shape_params = da.asarray([[0.4, 0.3, 0.2], [0.4, 0.3, 0.2]]) uvw = da.random.random((row, 3), chunks=(row_chunks, 3)) - freq = da.linspace(.856e9, 2*.856e9, chan, chunks=chan_chunks) + freq = da.linspace(0.856e9, 2 * 0.856e9, chan, chunks=chan_chunks) da_gauss_shape = da_gaussian(uvw, freq, shape_params).compute() - np_gauss_shape = np_gaussian(uvw.compute(), - freq.compute(), - shape_params.compute()) + np_gauss_shape = np_gaussian(uvw.compute(), freq.compute(), shape_params.compute()) assert_array_almost_equal(da_gauss_shape, np_gauss_shape) diff --git a/africanus/model/shape/tests/test_shapelets.py b/africanus/model/shape/tests/test_shapelets.py index 7ec47069b..cb55e509a 100644 --- a/africanus/model/shape/tests/test_shapelets.py +++ b/africanus/model/shape/tests/test_shapelets.py @@ -20,11 +20,11 @@ def test_1d_shapelet(): coeffs = np.ones(1, dtype=np.float64) l_min = -15.0 * beta l_max = 15.0 * beta - delta_l = (l_max - l_min)/(npix-1) + delta_l = (l_max - l_min) / (npix - 1) if npix % 2: l_coords = l_min + np.arange(npix) * delta_l else: - l_coords = l_min + np.arange(-0.5, npix-0.5) * delta_l + l_coords = l_min + np.arange(-0.5, npix - 0.5) * delta_l img_shape = shapelet_1d(l_coords, coeffs, False, beta=beta) # get Fourier space coords and take fft @@ -45,15 +45,13 @@ def test_2d_shapelet(gf_shapelets): nsrc = 1 # Define the range of uv values - u_range = [-2 * np.sqrt(2) * (beta[0] ** (-1)), - 2 * np.sqrt(2) * (beta[0] ** (-1))] - v_range = [-2 * np.sqrt(2) * (beta[1] ** (-1)), - 2 * np.sqrt(2) * (beta[1] ** (-1))] + u_range = [-2 * np.sqrt(2) * (beta[0] ** (-1)), 2 * np.sqrt(2) * (beta[0] ** (-1))] + v_range = [-2 * np.sqrt(2) * (beta[1] ** (-1)), 2 * np.sqrt(2) * (beta[1] ** (-1))] # Create an lm grid from the regular uv grid max_u = u_range[1] max_v = v_range[1] - delta_x = 1/(2 * max_u) if max_u > max_v else 1/(2 * max_v) + delta_x = 1 / (2 * max_u) if max_u > max_v else 1 / (2 * max_v) x_range = [-2 * np.sqrt(2) * beta[0], 2 * np.sqrt(2) * beta[0]] y_range = [-2 * np.sqrt(2) * beta[1], 2 * np.sqrt(2) * beta[1]] npix_x = int((x_range[1] - x_range[0]) / delta_x) @@ -76,10 +74,12 @@ def test_2d_shapelet(gf_shapelets): frequency[:] = 1 img_coeffs[:, :, :] = 1 - l_shapelets = shapelet_1d(img_coords[:, 0], img_coeffs[0, 0, :], False, - beta=img_beta[0, 0]) - m_shapelets = shapelet_1d(img_coords[:, 1], img_coeffs[0, 0, :], False, - beta=img_beta[0, 1]) + l_shapelets = shapelet_1d( + img_coords[:, 0], img_coeffs[0, 0, :], False, beta=img_beta[0, 0] + ) + m_shapelets = shapelet_1d( + img_coords[:, 1], img_coeffs[0, 0, :], False, beta=img_beta[0, 1] + ) ca_shapelets = l_shapelets * m_shapelets # Compare griffinfoster (gf) shapelets to codex-africanus (ca) shapelets @@ -110,8 +110,8 @@ def test_fourier_space_shapelets(): npix = 257 # create image space coordinate grid - delta_l = (l_max - l_min)/(npix-1) - delta_m = (m_max - m_min)/(npix-1) + delta_l = (l_max - l_min) / (npix - 1) + delta_m = (m_max - m_min) / (npix - 1) lvals = l_min + np.arange(npix) * delta_l mvals = m_min + np.arange(npix) * delta_m assert lvals[-1] == l_max @@ -148,8 +148,9 @@ def test_fourier_space_shapelets(): # Call the shapelet implementation coeffs_l = coeffs_l.reshape(coeffs_l.shape + (1,)) - uv_space_shapelet = shapelet(uvw, frequency, coeffs_l, beta, - (delta_l, delta_l)).reshape(npix, npix) + uv_space_shapelet = shapelet( + uvw, frequency, coeffs_l, beta, (delta_l, delta_l) + ).reshape(npix, npix) uv_space_shapelet_max = uv_space_shapelet.real.max() uv_space_shapelet /= uv_space_shapelet_max @@ -157,7 +158,7 @@ def test_fourier_space_shapelets(): def test_dask_shapelets(): - da = pytest.importorskip('dask.array') + da = pytest.importorskip("dask.array") from africanus.model.shape.dask import shapelet as da_shapelet from africanus.model.shape import shapelet as nb_shapelet @@ -167,7 +168,7 @@ def test_dask_shapelets(): row = sum(row_chunks) source = sum(source_chunks) nmax = [5, 5] - beta_vals = [1., 1.] + beta_vals = [1.0, 1.0] nchan = 1 np_coords = np.random.randn(row, 3) @@ -175,537 +176,1050 @@ def test_dask_shapelets(): np_frequency = np.random.randn(nchan) np_beta = np.empty((source, 2)) np_beta[:, 0], np_beta[:, 1] = beta_vals[0], beta_vals[1] - np_delta_lm = np.array([1/(10 * np.max(np_coords[:, 0])), - 1/(10 * np.max(np_coords[:, 1]))]) + np_delta_lm = np.array( + [1 / (10 * np.max(np_coords[:, 0])), 1 / (10 * np.max(np_coords[:, 1]))] + ) da_coords = da.from_array(np_coords, chunks=(row_chunks, 3)) - da_coeffs = da.from_array(np_coeffs, chunks=(source_chunks, - nmax[0], nmax[1])) + da_coeffs = da.from_array(np_coeffs, chunks=(source_chunks, nmax[0], nmax[1])) da_frequency = da.from_array(np_frequency, chunks=(nchan,)) da_beta = da.from_array(np_beta, chunks=(source_chunks, 2)) delta_lm = da.from_array(np_delta_lm, chunks=(2)) - np_shapelets = nb_shapelet(np_coords, - np_frequency, - np_coeffs, - np_beta, - np_delta_lm) - da_shapelets = da_shapelet(da_coords, da_frequency, da_coeffs, - da_beta, delta_lm).compute() + np_shapelets = nb_shapelet(np_coords, np_frequency, np_coeffs, np_beta, np_delta_lm) + da_shapelets = da_shapelet( + da_coords, da_frequency, da_coeffs, da_beta, delta_lm + ).compute() assert_array_almost_equal(da_shapelets, np_shapelets) @pytest.fixture def gf_shapelets(): - return np.array([0.018926452033215378, 0.03118821286615593, - 0.04971075420435816, 0.07663881765058349, - 0.1142841022428579, 0.16483955267999187, - 0.22997234561316462, 0.3103333028594875, - 0.40506035000019336, 0.5113869779741611, - 0.624479487346518, 0.7376074103159139, - 0.8426960266006467, 0.9312262345416489, - 0.9953550860502569, 1.0290570650923165, - 1.0290570650923165, 0.9953550860502569, - 0.9312262345416489, 0.8426960266006465, - 0.7376074103159138, 0.624479487346518, - 0.5113869779741611, 0.40506035000019325, - 0.31033330285948746, 0.22997234561316438, - 0.16483955267999187, 0.1142841022428579, - 0.07663881765058349, 0.04971075420435811, - 0.03118821286615593, 0.018926452033215378, - 0.03118821286615593, 0.05139392317575348, - 0.08191654627828761, 0.12629032396046155, - 0.1883246211023866, 0.2716331116219307, - 0.3789630753680171, 0.5113869779741613, - 0.6674842383176157, 0.8426960266006469, - 1.029057065092317, 1.2154764603642687, - 1.3886481741511985, 1.5345338882566935, - 1.6402094933940174, 1.6957457605469852, - 1.6957457605469852, 1.6402094933940174, - 1.5345338882566935, 1.388648174151198, - 1.2154764603642685, 1.029057065092317, - 0.8426960266006469, 0.6674842383176155, - 0.5113869779741612, 0.37896307536801666, - 0.2716331116219307, 0.1883246211023866, - 0.12629032396046155, 0.08191654627828754, - 0.05139392317575348, 0.03118821286615593, - 0.04971075420435816, 0.08191654627828761, - 0.130566419909516, 0.201293587411668, - 0.3001697785771043, 0.43295481224112503, - 0.6040275655739769, 0.815097436793798, - 1.0639001679476456, 1.3431694604339557, - 1.6402094933940163, 1.937342540967424, - 2.2133601677597268, 2.4458867606410766, - 2.6143226391225554, 2.702841649099706, - 2.702841649099706, 2.6143226391225554, - 2.4458867606410766, 2.213360167759726, - 1.9373425409674239, 1.6402094933940163, - 1.3431694604339557, 1.0639001679476452, - 0.8150974367937979, 0.6040275655739762, - 0.43295481224112503, 0.3001697785771043, - 0.201293587411668, 0.13056641990951587, - 0.08191654627828761, 0.04971075420435816, - 0.07663881765058349, 0.12629032396046155, - 0.201293587411668, 0.31033330285948735, - 0.462770225332246, 0.6674842383176153, - 0.9312262345416487, 1.2566315845680993, - 1.6402094933940163, 2.0707575453161375, - 2.528702657702978, 2.9867911702474954, - 3.412326145659869, 3.770811214654918, - 4.030487954293403, 4.166957263054249, - 4.166957263054249, 4.030487954293403, - 3.770811214654918, 3.412326145659868, - 2.986791170247495, 2.528702657702978, - 2.0707575453161375, 1.640209493394016, - 1.256631584568099, 0.9312262345416477, - 0.6674842383176153, 0.462770225332246, - 0.31033330285948735, 0.2012935874116678, - 0.12629032396046155, 0.07663881765058349, - 0.1142841022428579, 0.1883246211023866, - 0.3001697785771043, 0.462770225332246, - 0.6900847555862331, 0.9953550860502568, - 1.3886481741511976, 1.8738938946990888, - 2.445886760641078, 3.0879216862145458, - 3.7708112146549193, 4.453914581966821, - 5.088473988397974, 5.6230483142227445, - 6.010279275929925, 6.2137828386615945, - 6.2137828386615945, 6.010279275929925, - 5.6230483142227445, 5.088473988397973, - 4.45391458196682, 3.7708112146549193, - 3.0879216862145458, 2.445886760641077, - 1.8738938946990886, 1.3886481741511962, - 0.9953550860502568, 0.6900847555862331, - 0.462770225332246, 0.30016977857710403, - 0.1883246211023866, 0.1142841022428579, - 0.16483955267999187, 0.2716331116219307, - 0.43295481224112503, 0.6674842383176153, - 0.9953550860502568, 1.4356667631130016, - 2.002939510961386, 2.702841649099707, - 3.527864957745564, 4.4539145819668216, - 5.438891513917972, 6.424176879879071, - 7.3394440662346385, 8.110496128715798, - 8.669025068952923, 8.962551776440094, - 8.962551776440094, 8.669025068952923, - 8.110496128715798, 7.339444066234637, - 6.42417687987907, 5.438891513917972, - 4.4539145819668216, 3.527864957745563, - 2.7028416490997067, 2.002939510961384, - 1.4356667631130016, 0.9953550860502568, - 0.6674842383176153, 0.4329548122411246, - 0.2716331116219307, 0.16483955267999187, - 0.22997234561316462, 0.3789630753680171, - 0.6040275655739769, 0.9312262345416487, - 1.3886481741511976, 2.002939510961386, - 2.794357846573947, 3.7708112146549193, - 4.921824684359944, 6.213782838661596, - 7.587952155023488, 8.962551776440092, - 10.239467045177655, 11.315183695191251, - 12.094403296257042, 12.503910749556066, - 12.503910749556066, 12.094403296257042, - 11.315183695191251, 10.239467045177651, - 8.96255177644009, 7.587952155023488, - 6.213782838661596, 4.921824684359942, - 3.7708112146549184, 2.7943578465739445, - 2.002939510961386, 1.3886481741511976, - 0.9312262345416487, 0.6040275655739763, - 0.3789630753680171, 0.22997234561316462, - 0.3103333028594875, 0.5113869779741613, - 0.815097436793798, 1.2566315845680993, - 1.8738938946990888, 2.702841649099707, - 3.7708112146549193, 5.088473988397975, - 6.641694706032257, 8.385111463867542, - 10.239467045177655, 12.094403296257038, - 13.817520620483494, 15.269132987394766, - 16.320641123326727, 16.873245829727512, - 16.873245829727512, 16.320641123326727, - 15.269132987394766, 13.81752062048349, - 12.094403296257036, 10.239467045177655, - 8.385111463867542, 6.641694706032255, - 5.088473988397973, 3.7708112146549153, - 2.702841649099707, 1.8738938946990888, - 1.2566315845680993, 0.8150974367937973, - 0.5113869779741613, 0.3103333028594875, - 0.40506035000019336, 0.6674842383176157, - 1.0639001679476456, 1.6402094933940163, - 2.445886760641078, 3.527864957745564, - 4.921824684359944, 6.641694706032257, - 8.669025068952925, 10.944607468965954, - 13.36499198416051, 15.786134414467007, - 18.035221122246448, 19.929927903593466, - 21.302401465547916, 22.023684852551362, - 22.023684852551362, 21.302401465547916, - 19.929927903593466, 18.03522112224644, - 15.786134414467005, 13.36499198416051, - 10.944607468965954, 8.669025068952921, - 6.641694706032256, 4.9218246843599385, - 3.527864957745564, 2.445886760641078, - 1.6402094933940163, 1.0639001679476445, - 0.6674842383176157, 0.40506035000019336, - 0.5113869779741611, 0.8426960266006469, - 1.3431694604339557, 2.0707575453161375, - 3.0879216862145458, 4.4539145819668216, - 6.213782838661596, 8.385111463867542, - 10.944607468965954, 13.81752062048349, - 16.873245829727516, 19.92992790359346, - 22.7693903557753, 25.16144965029706, - 26.894191715021453, 27.804808939201713, - 27.804808939201713, 26.894191715021453, - 25.16144965029706, 22.769390355775293, - 19.929927903593455, 16.873245829727516, - 13.81752062048349, 10.94460746896595, - 8.38511146386754, 6.213782838661589, - 4.4539145819668216, 3.0879216862145458, - 2.0707575453161375, 1.3431694604339544, - 0.8426960266006469, 0.5113869779741611, - 0.624479487346518, 1.029057065092317, - 1.6402094933940163, 2.528702657702978, - 3.7708112146549193, 5.438891513917972, - 7.587952155023488, 10.239467045177655, - 13.36499198416051, 16.873245829727516, - 20.60474036191124, 24.337403367979306, - 27.80480893920172, 30.725868775068133, - 32.84180430508374, 33.95380324486458, - 33.95380324486458, 32.84180430508374, - 30.725868775068133, 27.804808939201713, - 24.337403367979302, 20.60474036191124, - 16.873245829727516, 13.364991984160506, - 10.239467045177653, 7.587952155023479, - 5.438891513917972, 3.7708112146549193, - 2.528702657702978, 1.6402094933940148, - 1.029057065092317, 0.624479487346518, - 0.7376074103159139, 1.2154764603642687, - 1.937342540967424, 2.9867911702474954, - 4.453914581966821, 6.424176879879071, - 8.962551776440092, 12.094403296257038, - 15.786134414467007, 19.92992790359346, - 24.337403367979306, 28.746258981774883, - 32.841804305083734, 36.292030332629274, - 38.79127932048947, 40.1047230362006, - 40.1047230362006, 38.79127932048947, - 36.292030332629274, 32.84180430508372, - 28.74625898177488, 24.337403367979306, - 19.92992790359346, 15.786134414467003, - 12.094403296257036, 8.962551776440083, - 6.424176879879071, 4.453914581966821, - 2.9867911702474954, 1.9373425409674223, - 1.2154764603642687, 0.7376074103159139, - 0.8426960266006467, 1.3886481741511985, - 2.2133601677597268, 3.412326145659869, - 5.088473988397974, 7.3394440662346385, - 10.239467045177655, 13.817520620483494, - 18.035221122246448, 22.7693903557753, - 27.80480893920172, 32.841804305083734, - 37.52085134616084, 41.462638974136944, - 44.317961686599176, 45.81853473523395, - 45.81853473523395, 44.317961686599176, - 41.462638974136944, 37.52085134616083, - 32.84180430508373, 27.80480893920172, - 22.7693903557753, 18.03522112224644, - 13.81752062048349, 10.239467045177644, - 7.3394440662346385, 5.088473988397974, - 3.412326145659869, 2.2133601677597246, - 1.3886481741511985, 0.8426960266006467, - 0.9312262345416489, 1.5345338882566935, - 2.4458867606410766, 3.770811214654918, - 5.6230483142227445, 8.110496128715798, - 11.315183695191251, 15.269132987394766, - 19.929927903593466, 25.16144965029706, - 30.725868775068133, 36.292030332629274, - 41.462638974136944, 45.818534735233946, - 48.9738260075251, 50.632043141135796, - 50.632043141135796, 48.9738260075251, - 45.818534735233946, 41.46263897413693, - 36.29203033262927, 30.725868775068133, - 25.16144965029706, 19.92992790359346, - 15.269132987394762, 11.315183695191239, - 8.110496128715798, 5.6230483142227445, - 3.770811214654918, 2.4458867606410744, - 1.5345338882566935, 0.9312262345416489, - 0.9953550860502569, 1.6402094933940174, - 2.6143226391225554, 4.030487954293403, - 6.010279275929925, 8.669025068952923, - 12.094403296257042, 16.320641123326727, - 21.302401465547916, 26.894191715021453, - 32.84180430508374, 38.79127932048947, - 44.317961686599176, 48.9738260075251, - 52.34640626713388, 54.11881644684437, - 54.11881644684437, 52.34640626713388, - 48.9738260075251, 44.31796168659916, - 38.791279320489465, 32.84180430508374, - 26.894191715021453, 21.30240146554791, - 16.320641123326723, 12.09440329625703, - 8.669025068952923, 6.010279275929925, - 4.030487954293403, 2.6143226391225527, - 1.6402094933940174, 0.9953550860502569, - 1.0290570650923165, 1.6957457605469852, - 2.702841649099706, 4.166957263054249, - 6.2137828386615945, 8.962551776440094, - 12.503910749556066, 16.873245829727512, - 22.023684852551362, 27.804808939201713, - 33.95380324486458, 40.1047230362006, - 45.81853473523395, 50.632043141135796, - 54.11881644684437, 55.9512391101074, - 55.9512391101074, 54.11881644684437, - 50.632043141135796, 45.81853473523394, - 40.104723036200596, 33.95380324486458, - 27.804808939201713, 22.023684852551355, - 16.873245829727512, 12.503910749556052, - 8.962551776440094, 6.2137828386615945, - 4.166957263054249, 2.7028416490997036, - 1.6957457605469852, 1.0290570650923165, - 1.0290570650923165, 1.6957457605469852, - 2.702841649099706, 4.166957263054249, - 6.2137828386615945, 8.962551776440094, - 12.503910749556066, 16.873245829727512, - 22.023684852551362, 27.804808939201713, - 33.95380324486458, 40.1047230362006, - 45.81853473523395, 50.632043141135796, - 54.11881644684437, 55.9512391101074, - 55.9512391101074, 54.11881644684437, - 50.632043141135796, 45.81853473523394, - 40.104723036200596, 33.95380324486458, - 27.804808939201713, 22.023684852551355, - 16.873245829727512, 12.503910749556052, - 8.962551776440094, 6.2137828386615945, - 4.166957263054249, 2.7028416490997036, - 1.6957457605469852, 1.0290570650923165, - 0.9953550860502569, 1.6402094933940174, - 2.6143226391225554, 4.030487954293403, - 6.010279275929925, 8.669025068952923, - 12.094403296257042, 16.320641123326727, - 21.302401465547916, 26.894191715021453, - 32.84180430508374, 38.79127932048947, - 44.317961686599176, 48.9738260075251, - 52.34640626713388, 54.11881644684437, - 54.11881644684437, 52.34640626713388, - 48.9738260075251, 44.31796168659916, - 38.791279320489465, 32.84180430508374, - 26.894191715021453, 21.30240146554791, - 16.320641123326723, 12.09440329625703, - 8.669025068952923, 6.010279275929925, - 4.030487954293403, 2.6143226391225527, - 1.6402094933940174, 0.9953550860502569, - 0.9312262345416489, 1.5345338882566935, - 2.4458867606410766, 3.770811214654918, - 5.6230483142227445, 8.110496128715798, - 11.315183695191251, 15.269132987394766, - 19.929927903593466, 25.16144965029706, - 30.725868775068133, 36.292030332629274, - 41.462638974136944, 45.818534735233946, - 48.9738260075251, 50.632043141135796, - 50.632043141135796, 48.9738260075251, - 45.818534735233946, 41.46263897413693, - 36.29203033262927, 30.725868775068133, - 25.16144965029706, 19.92992790359346, - 15.269132987394762, 11.315183695191239, - 8.110496128715798, 5.6230483142227445, - 3.770811214654918, 2.4458867606410744, - 1.5345338882566935, 0.9312262345416489, - 0.8426960266006465, 1.388648174151198, - 2.213360167759726, 3.412326145659868, - 5.088473988397973, 7.339444066234637, - 10.239467045177651, 13.81752062048349, - 18.03522112224644, 22.769390355775293, - 27.804808939201713, 32.84180430508372, - 37.52085134616083, 41.46263897413693, - 44.31796168659916, 45.81853473523394, - 45.81853473523394, 44.31796168659916, - 41.46263897413693, 37.52085134616082, - 32.84180430508372, 27.804808939201713, - 22.769390355775293, 18.035221122246437, - 13.817520620483487, 10.23946704517764, - 7.339444066234637, 5.088473988397973, - 3.412326145659868, 2.213360167759724, - 1.388648174151198, 0.8426960266006465, - 0.7376074103159138, 1.2154764603642685, - 1.9373425409674239, 2.986791170247495, - 4.45391458196682, 6.42417687987907, - 8.96255177644009, 12.094403296257036, - 15.786134414467005, 19.929927903593455, - 24.337403367979302, 28.74625898177488, - 32.84180430508373, 36.29203033262927, - 38.791279320489465, 40.104723036200596, - 40.104723036200596, 38.791279320489465, - 36.29203033262927, 32.84180430508372, - 28.746258981774876, 24.337403367979302, - 19.929927903593455, 15.786134414467, - 12.094403296257035, 8.962551776440081, - 6.42417687987907, 4.45391458196682, - 2.986791170247495, 1.937342540967422, - 1.2154764603642685, 0.7376074103159138, - 0.624479487346518, 1.029057065092317, - 1.6402094933940163, 2.528702657702978, - 3.7708112146549193, 5.438891513917972, - 7.587952155023488, 10.239467045177655, - 13.36499198416051, 16.873245829727516, - 20.60474036191124, 24.337403367979306, - 27.80480893920172, 30.725868775068133, - 32.84180430508374, 33.95380324486458, - 33.95380324486458, 32.84180430508374, - 30.725868775068133, 27.804808939201713, - 24.337403367979302, 20.60474036191124, - 16.873245829727516, 13.364991984160506, - 10.239467045177653, 7.587952155023479, - 5.438891513917972, 3.7708112146549193, - 2.528702657702978, 1.6402094933940148, - 1.029057065092317, 0.624479487346518, - 0.5113869779741611, 0.8426960266006469, - 1.3431694604339557, 2.0707575453161375, - 3.0879216862145458, 4.4539145819668216, - 6.213782838661596, 8.385111463867542, - 10.944607468965954, 13.81752062048349, - 16.873245829727516, 19.92992790359346, - 22.7693903557753, 25.16144965029706, - 26.894191715021453, 27.804808939201713, - 27.804808939201713, 26.894191715021453, - 25.16144965029706, 22.769390355775293, - 19.929927903593455, 16.873245829727516, - 13.81752062048349, 10.94460746896595, - 8.38511146386754, 6.213782838661589, - 4.4539145819668216, 3.0879216862145458, - 2.0707575453161375, 1.3431694604339544, - 0.8426960266006469, 0.5113869779741611, - 0.40506035000019325, 0.6674842383176155, - 1.0639001679476452, 1.640209493394016, - 2.445886760641077, 3.527864957745563, - 4.921824684359942, 6.641694706032255, - 8.669025068952921, 10.94460746896595, - 13.364991984160506, 15.786134414467003, - 18.03522112224644, 19.92992790359346, - 21.30240146554791, 22.023684852551355, - 22.023684852551355, 21.30240146554791, - 19.92992790359346, 18.035221122246437, - 15.786134414467, 13.364991984160506, - 10.94460746896595, 8.66902506895292, - 6.641694706032253, 4.921824684359937, - 3.527864957745563, 2.445886760641077, - 1.640209493394016, 1.0639001679476443, - 0.6674842383176155, 0.40506035000019325, - 0.31033330285948746, 0.5113869779741612, - 0.8150974367937979, 1.256631584568099, - 1.8738938946990886, 2.7028416490997067, - 3.7708112146549184, 5.088473988397973, - 6.641694706032256, 8.38511146386754, - 10.239467045177653, 12.094403296257036, - 13.81752062048349, 15.269132987394762, - 16.320641123326723, 16.873245829727512, - 16.873245829727512, 16.320641123326723, - 15.269132987394762, 13.817520620483487, - 12.094403296257035, 10.239467045177653, - 8.38511146386754, 6.641694706032253, - 5.088473988397972, 3.7708112146549144, - 2.7028416490997067, 1.8738938946990886, - 1.256631584568099, 0.8150974367937971, - 0.5113869779741612, 0.31033330285948746, - 0.22997234561316438, 0.37896307536801666, - 0.6040275655739762, 0.9312262345416477, - 1.3886481741511962, 2.002939510961384, - 2.7943578465739445, 3.7708112146549153, - 4.9218246843599385, 6.213782838661589, - 7.587952155023479, 8.962551776440083, - 10.239467045177644, 11.315183695191239, - 12.09440329625703, 12.503910749556052, - 12.503910749556052, 12.09440329625703, - 11.315183695191239, 10.23946704517764, - 8.962551776440081, 7.587952155023479, - 6.213782838661589, 4.921824684359937, - 3.7708112146549144, 2.7943578465739414, - 2.002939510961384, 1.3886481741511962, - 0.9312262345416477, 0.6040275655739756, - 0.37896307536801666, 0.22997234561316438, - 0.16483955267999187, 0.2716331116219307, - 0.43295481224112503, 0.6674842383176153, - 0.9953550860502568, 1.4356667631130016, - 2.002939510961386, 2.702841649099707, - 3.527864957745564, 4.4539145819668216, - 5.438891513917972, 6.424176879879071, - 7.3394440662346385, 8.110496128715798, - 8.669025068952923, 8.962551776440094, - 8.962551776440094, 8.669025068952923, - 8.110496128715798, 7.339444066234637, - 6.42417687987907, 5.438891513917972, - 4.4539145819668216, 3.527864957745563, - 2.7028416490997067, 2.002939510961384, - 1.4356667631130016, 0.9953550860502568, - 0.6674842383176153, 0.4329548122411246, - 0.2716331116219307, 0.16483955267999187, - 0.1142841022428579, 0.1883246211023866, - 0.3001697785771043, 0.462770225332246, - 0.6900847555862331, 0.9953550860502568, - 1.3886481741511976, 1.8738938946990888, - 2.445886760641078, 3.0879216862145458, - 3.7708112146549193, 4.453914581966821, - 5.088473988397974, 5.6230483142227445, - 6.010279275929925, 6.2137828386615945, - 6.2137828386615945, 6.010279275929925, - 5.6230483142227445, 5.088473988397973, - 4.45391458196682, 3.7708112146549193, - 3.0879216862145458, 2.445886760641077, - 1.8738938946990886, 1.3886481741511962, - 0.9953550860502568, 0.6900847555862331, - 0.462770225332246, 0.30016977857710403, - 0.1883246211023866, 0.1142841022428579, - 0.07663881765058349, 0.12629032396046155, - 0.201293587411668, 0.31033330285948735, - 0.462770225332246, 0.6674842383176153, - 0.9312262345416487, 1.2566315845680993, - 1.6402094933940163, 2.0707575453161375, - 2.528702657702978, 2.9867911702474954, - 3.412326145659869, 3.770811214654918, - 4.030487954293403, 4.166957263054249, - 4.166957263054249, 4.030487954293403, - 3.770811214654918, 3.412326145659868, - 2.986791170247495, 2.528702657702978, - 2.0707575453161375, 1.640209493394016, - 1.256631584568099, 0.9312262345416477, - 0.6674842383176153, 0.462770225332246, - 0.31033330285948735, 0.2012935874116678, - 0.12629032396046155, 0.07663881765058349, - 0.04971075420435811, 0.08191654627828754, - 0.13056641990951587, 0.2012935874116678, - 0.30016977857710403, 0.4329548122411246, - 0.6040275655739763, 0.8150974367937973, - 1.0639001679476445, 1.3431694604339544, - 1.6402094933940148, 1.9373425409674223, - 2.2133601677597246, 2.4458867606410744, - 2.6143226391225527, 2.7028416490997036, - 2.7028416490997036, 2.6143226391225527, - 2.4458867606410744, 2.213360167759724, - 1.937342540967422, 1.6402094933940148, - 1.3431694604339544, 1.0639001679476443, - 0.8150974367937971, 0.6040275655739756, - 0.4329548122411246, 0.30016977857710403, - 0.2012935874116678, 0.13056641990951576, - 0.08191654627828754, 0.04971075420435811, - 0.03118821286615593, 0.05139392317575348, - 0.08191654627828761, 0.12629032396046155, - 0.1883246211023866, 0.2716331116219307, - 0.3789630753680171, 0.5113869779741613, - 0.6674842383176157, 0.8426960266006469, - 1.029057065092317, 1.2154764603642687, - 1.3886481741511985, 1.5345338882566935, - 1.6402094933940174, 1.6957457605469852, - 1.6957457605469852, 1.6402094933940174, - 1.5345338882566935, 1.388648174151198, - 1.2154764603642685, 1.029057065092317, - 0.8426960266006469, 0.6674842383176155, - 0.5113869779741612, 0.37896307536801666, - 0.2716331116219307, 0.1883246211023866, - 0.12629032396046155, 0.08191654627828754, - 0.05139392317575348, 0.03118821286615593, - 0.018926452033215378, 0.03118821286615593, - 0.04971075420435816, 0.07663881765058349, - 0.1142841022428579, 0.16483955267999187, - 0.22997234561316462, 0.3103333028594875, - 0.40506035000019336, 0.5113869779741611, - 0.624479487346518, 0.7376074103159139, - 0.8426960266006467, 0.9312262345416489, - 0.9953550860502569, 1.0290570650923165, - 1.0290570650923165, 0.9953550860502569, - 0.9312262345416489, 0.8426960266006465, - 0.7376074103159138, 0.624479487346518, - 0.5113869779741611, 0.40506035000019325, - 0.31033330285948746, 0.22997234561316438, - 0.16483955267999187, 0.1142841022428579, - 0.07663881765058349, 0.04971075420435811, - 0.03118821286615593, 0.018926452033215378]) + return np.array( + [ + 0.018926452033215378, + 0.03118821286615593, + 0.04971075420435816, + 0.07663881765058349, + 0.1142841022428579, + 0.16483955267999187, + 0.22997234561316462, + 0.3103333028594875, + 0.40506035000019336, + 0.5113869779741611, + 0.624479487346518, + 0.7376074103159139, + 0.8426960266006467, + 0.9312262345416489, + 0.9953550860502569, + 1.0290570650923165, + 1.0290570650923165, + 0.9953550860502569, + 0.9312262345416489, + 0.8426960266006465, + 0.7376074103159138, + 0.624479487346518, + 0.5113869779741611, + 0.40506035000019325, + 0.31033330285948746, + 0.22997234561316438, + 0.16483955267999187, + 0.1142841022428579, + 0.07663881765058349, + 0.04971075420435811, + 0.03118821286615593, + 0.018926452033215378, + 0.03118821286615593, + 0.05139392317575348, + 0.08191654627828761, + 0.12629032396046155, + 0.1883246211023866, + 0.2716331116219307, + 0.3789630753680171, + 0.5113869779741613, + 0.6674842383176157, + 0.8426960266006469, + 1.029057065092317, + 1.2154764603642687, + 1.3886481741511985, + 1.5345338882566935, + 1.6402094933940174, + 1.6957457605469852, + 1.6957457605469852, + 1.6402094933940174, + 1.5345338882566935, + 1.388648174151198, + 1.2154764603642685, + 1.029057065092317, + 0.8426960266006469, + 0.6674842383176155, + 0.5113869779741612, + 0.37896307536801666, + 0.2716331116219307, + 0.1883246211023866, + 0.12629032396046155, + 0.08191654627828754, + 0.05139392317575348, + 0.03118821286615593, + 0.04971075420435816, + 0.08191654627828761, + 0.130566419909516, + 0.201293587411668, + 0.3001697785771043, + 0.43295481224112503, + 0.6040275655739769, + 0.815097436793798, + 1.0639001679476456, + 1.3431694604339557, + 1.6402094933940163, + 1.937342540967424, + 2.2133601677597268, + 2.4458867606410766, + 2.6143226391225554, + 2.702841649099706, + 2.702841649099706, + 2.6143226391225554, + 2.4458867606410766, + 2.213360167759726, + 1.9373425409674239, + 1.6402094933940163, + 1.3431694604339557, + 1.0639001679476452, + 0.8150974367937979, + 0.6040275655739762, + 0.43295481224112503, + 0.3001697785771043, + 0.201293587411668, + 0.13056641990951587, + 0.08191654627828761, + 0.04971075420435816, + 0.07663881765058349, + 0.12629032396046155, + 0.201293587411668, + 0.31033330285948735, + 0.462770225332246, + 0.6674842383176153, + 0.9312262345416487, + 1.2566315845680993, + 1.6402094933940163, + 2.0707575453161375, + 2.528702657702978, + 2.9867911702474954, + 3.412326145659869, + 3.770811214654918, + 4.030487954293403, + 4.166957263054249, + 4.166957263054249, + 4.030487954293403, + 3.770811214654918, + 3.412326145659868, + 2.986791170247495, + 2.528702657702978, + 2.0707575453161375, + 1.640209493394016, + 1.256631584568099, + 0.9312262345416477, + 0.6674842383176153, + 0.462770225332246, + 0.31033330285948735, + 0.2012935874116678, + 0.12629032396046155, + 0.07663881765058349, + 0.1142841022428579, + 0.1883246211023866, + 0.3001697785771043, + 0.462770225332246, + 0.6900847555862331, + 0.9953550860502568, + 1.3886481741511976, + 1.8738938946990888, + 2.445886760641078, + 3.0879216862145458, + 3.7708112146549193, + 4.453914581966821, + 5.088473988397974, + 5.6230483142227445, + 6.010279275929925, + 6.2137828386615945, + 6.2137828386615945, + 6.010279275929925, + 5.6230483142227445, + 5.088473988397973, + 4.45391458196682, + 3.7708112146549193, + 3.0879216862145458, + 2.445886760641077, + 1.8738938946990886, + 1.3886481741511962, + 0.9953550860502568, + 0.6900847555862331, + 0.462770225332246, + 0.30016977857710403, + 0.1883246211023866, + 0.1142841022428579, + 0.16483955267999187, + 0.2716331116219307, + 0.43295481224112503, + 0.6674842383176153, + 0.9953550860502568, + 1.4356667631130016, + 2.002939510961386, + 2.702841649099707, + 3.527864957745564, + 4.4539145819668216, + 5.438891513917972, + 6.424176879879071, + 7.3394440662346385, + 8.110496128715798, + 8.669025068952923, + 8.962551776440094, + 8.962551776440094, + 8.669025068952923, + 8.110496128715798, + 7.339444066234637, + 6.42417687987907, + 5.438891513917972, + 4.4539145819668216, + 3.527864957745563, + 2.7028416490997067, + 2.002939510961384, + 1.4356667631130016, + 0.9953550860502568, + 0.6674842383176153, + 0.4329548122411246, + 0.2716331116219307, + 0.16483955267999187, + 0.22997234561316462, + 0.3789630753680171, + 0.6040275655739769, + 0.9312262345416487, + 1.3886481741511976, + 2.002939510961386, + 2.794357846573947, + 3.7708112146549193, + 4.921824684359944, + 6.213782838661596, + 7.587952155023488, + 8.962551776440092, + 10.239467045177655, + 11.315183695191251, + 12.094403296257042, + 12.503910749556066, + 12.503910749556066, + 12.094403296257042, + 11.315183695191251, + 10.239467045177651, + 8.96255177644009, + 7.587952155023488, + 6.213782838661596, + 4.921824684359942, + 3.7708112146549184, + 2.7943578465739445, + 2.002939510961386, + 1.3886481741511976, + 0.9312262345416487, + 0.6040275655739763, + 0.3789630753680171, + 0.22997234561316462, + 0.3103333028594875, + 0.5113869779741613, + 0.815097436793798, + 1.2566315845680993, + 1.8738938946990888, + 2.702841649099707, + 3.7708112146549193, + 5.088473988397975, + 6.641694706032257, + 8.385111463867542, + 10.239467045177655, + 12.094403296257038, + 13.817520620483494, + 15.269132987394766, + 16.320641123326727, + 16.873245829727512, + 16.873245829727512, + 16.320641123326727, + 15.269132987394766, + 13.81752062048349, + 12.094403296257036, + 10.239467045177655, + 8.385111463867542, + 6.641694706032255, + 5.088473988397973, + 3.7708112146549153, + 2.702841649099707, + 1.8738938946990888, + 1.2566315845680993, + 0.8150974367937973, + 0.5113869779741613, + 0.3103333028594875, + 0.40506035000019336, + 0.6674842383176157, + 1.0639001679476456, + 1.6402094933940163, + 2.445886760641078, + 3.527864957745564, + 4.921824684359944, + 6.641694706032257, + 8.669025068952925, + 10.944607468965954, + 13.36499198416051, + 15.786134414467007, + 18.035221122246448, + 19.929927903593466, + 21.302401465547916, + 22.023684852551362, + 22.023684852551362, + 21.302401465547916, + 19.929927903593466, + 18.03522112224644, + 15.786134414467005, + 13.36499198416051, + 10.944607468965954, + 8.669025068952921, + 6.641694706032256, + 4.9218246843599385, + 3.527864957745564, + 2.445886760641078, + 1.6402094933940163, + 1.0639001679476445, + 0.6674842383176157, + 0.40506035000019336, + 0.5113869779741611, + 0.8426960266006469, + 1.3431694604339557, + 2.0707575453161375, + 3.0879216862145458, + 4.4539145819668216, + 6.213782838661596, + 8.385111463867542, + 10.944607468965954, + 13.81752062048349, + 16.873245829727516, + 19.92992790359346, + 22.7693903557753, + 25.16144965029706, + 26.894191715021453, + 27.804808939201713, + 27.804808939201713, + 26.894191715021453, + 25.16144965029706, + 22.769390355775293, + 19.929927903593455, + 16.873245829727516, + 13.81752062048349, + 10.94460746896595, + 8.38511146386754, + 6.213782838661589, + 4.4539145819668216, + 3.0879216862145458, + 2.0707575453161375, + 1.3431694604339544, + 0.8426960266006469, + 0.5113869779741611, + 0.624479487346518, + 1.029057065092317, + 1.6402094933940163, + 2.528702657702978, + 3.7708112146549193, + 5.438891513917972, + 7.587952155023488, + 10.239467045177655, + 13.36499198416051, + 16.873245829727516, + 20.60474036191124, + 24.337403367979306, + 27.80480893920172, + 30.725868775068133, + 32.84180430508374, + 33.95380324486458, + 33.95380324486458, + 32.84180430508374, + 30.725868775068133, + 27.804808939201713, + 24.337403367979302, + 20.60474036191124, + 16.873245829727516, + 13.364991984160506, + 10.239467045177653, + 7.587952155023479, + 5.438891513917972, + 3.7708112146549193, + 2.528702657702978, + 1.6402094933940148, + 1.029057065092317, + 0.624479487346518, + 0.7376074103159139, + 1.2154764603642687, + 1.937342540967424, + 2.9867911702474954, + 4.453914581966821, + 6.424176879879071, + 8.962551776440092, + 12.094403296257038, + 15.786134414467007, + 19.92992790359346, + 24.337403367979306, + 28.746258981774883, + 32.841804305083734, + 36.292030332629274, + 38.79127932048947, + 40.1047230362006, + 40.1047230362006, + 38.79127932048947, + 36.292030332629274, + 32.84180430508372, + 28.74625898177488, + 24.337403367979306, + 19.92992790359346, + 15.786134414467003, + 12.094403296257036, + 8.962551776440083, + 6.424176879879071, + 4.453914581966821, + 2.9867911702474954, + 1.9373425409674223, + 1.2154764603642687, + 0.7376074103159139, + 0.8426960266006467, + 1.3886481741511985, + 2.2133601677597268, + 3.412326145659869, + 5.088473988397974, + 7.3394440662346385, + 10.239467045177655, + 13.817520620483494, + 18.035221122246448, + 22.7693903557753, + 27.80480893920172, + 32.841804305083734, + 37.52085134616084, + 41.462638974136944, + 44.317961686599176, + 45.81853473523395, + 45.81853473523395, + 44.317961686599176, + 41.462638974136944, + 37.52085134616083, + 32.84180430508373, + 27.80480893920172, + 22.7693903557753, + 18.03522112224644, + 13.81752062048349, + 10.239467045177644, + 7.3394440662346385, + 5.088473988397974, + 3.412326145659869, + 2.2133601677597246, + 1.3886481741511985, + 0.8426960266006467, + 0.9312262345416489, + 1.5345338882566935, + 2.4458867606410766, + 3.770811214654918, + 5.6230483142227445, + 8.110496128715798, + 11.315183695191251, + 15.269132987394766, + 19.929927903593466, + 25.16144965029706, + 30.725868775068133, + 36.292030332629274, + 41.462638974136944, + 45.818534735233946, + 48.9738260075251, + 50.632043141135796, + 50.632043141135796, + 48.9738260075251, + 45.818534735233946, + 41.46263897413693, + 36.29203033262927, + 30.725868775068133, + 25.16144965029706, + 19.92992790359346, + 15.269132987394762, + 11.315183695191239, + 8.110496128715798, + 5.6230483142227445, + 3.770811214654918, + 2.4458867606410744, + 1.5345338882566935, + 0.9312262345416489, + 0.9953550860502569, + 1.6402094933940174, + 2.6143226391225554, + 4.030487954293403, + 6.010279275929925, + 8.669025068952923, + 12.094403296257042, + 16.320641123326727, + 21.302401465547916, + 26.894191715021453, + 32.84180430508374, + 38.79127932048947, + 44.317961686599176, + 48.9738260075251, + 52.34640626713388, + 54.11881644684437, + 54.11881644684437, + 52.34640626713388, + 48.9738260075251, + 44.31796168659916, + 38.791279320489465, + 32.84180430508374, + 26.894191715021453, + 21.30240146554791, + 16.320641123326723, + 12.09440329625703, + 8.669025068952923, + 6.010279275929925, + 4.030487954293403, + 2.6143226391225527, + 1.6402094933940174, + 0.9953550860502569, + 1.0290570650923165, + 1.6957457605469852, + 2.702841649099706, + 4.166957263054249, + 6.2137828386615945, + 8.962551776440094, + 12.503910749556066, + 16.873245829727512, + 22.023684852551362, + 27.804808939201713, + 33.95380324486458, + 40.1047230362006, + 45.81853473523395, + 50.632043141135796, + 54.11881644684437, + 55.9512391101074, + 55.9512391101074, + 54.11881644684437, + 50.632043141135796, + 45.81853473523394, + 40.104723036200596, + 33.95380324486458, + 27.804808939201713, + 22.023684852551355, + 16.873245829727512, + 12.503910749556052, + 8.962551776440094, + 6.2137828386615945, + 4.166957263054249, + 2.7028416490997036, + 1.6957457605469852, + 1.0290570650923165, + 1.0290570650923165, + 1.6957457605469852, + 2.702841649099706, + 4.166957263054249, + 6.2137828386615945, + 8.962551776440094, + 12.503910749556066, + 16.873245829727512, + 22.023684852551362, + 27.804808939201713, + 33.95380324486458, + 40.1047230362006, + 45.81853473523395, + 50.632043141135796, + 54.11881644684437, + 55.9512391101074, + 55.9512391101074, + 54.11881644684437, + 50.632043141135796, + 45.81853473523394, + 40.104723036200596, + 33.95380324486458, + 27.804808939201713, + 22.023684852551355, + 16.873245829727512, + 12.503910749556052, + 8.962551776440094, + 6.2137828386615945, + 4.166957263054249, + 2.7028416490997036, + 1.6957457605469852, + 1.0290570650923165, + 0.9953550860502569, + 1.6402094933940174, + 2.6143226391225554, + 4.030487954293403, + 6.010279275929925, + 8.669025068952923, + 12.094403296257042, + 16.320641123326727, + 21.302401465547916, + 26.894191715021453, + 32.84180430508374, + 38.79127932048947, + 44.317961686599176, + 48.9738260075251, + 52.34640626713388, + 54.11881644684437, + 54.11881644684437, + 52.34640626713388, + 48.9738260075251, + 44.31796168659916, + 38.791279320489465, + 32.84180430508374, + 26.894191715021453, + 21.30240146554791, + 16.320641123326723, + 12.09440329625703, + 8.669025068952923, + 6.010279275929925, + 4.030487954293403, + 2.6143226391225527, + 1.6402094933940174, + 0.9953550860502569, + 0.9312262345416489, + 1.5345338882566935, + 2.4458867606410766, + 3.770811214654918, + 5.6230483142227445, + 8.110496128715798, + 11.315183695191251, + 15.269132987394766, + 19.929927903593466, + 25.16144965029706, + 30.725868775068133, + 36.292030332629274, + 41.462638974136944, + 45.818534735233946, + 48.9738260075251, + 50.632043141135796, + 50.632043141135796, + 48.9738260075251, + 45.818534735233946, + 41.46263897413693, + 36.29203033262927, + 30.725868775068133, + 25.16144965029706, + 19.92992790359346, + 15.269132987394762, + 11.315183695191239, + 8.110496128715798, + 5.6230483142227445, + 3.770811214654918, + 2.4458867606410744, + 1.5345338882566935, + 0.9312262345416489, + 0.8426960266006465, + 1.388648174151198, + 2.213360167759726, + 3.412326145659868, + 5.088473988397973, + 7.339444066234637, + 10.239467045177651, + 13.81752062048349, + 18.03522112224644, + 22.769390355775293, + 27.804808939201713, + 32.84180430508372, + 37.52085134616083, + 41.46263897413693, + 44.31796168659916, + 45.81853473523394, + 45.81853473523394, + 44.31796168659916, + 41.46263897413693, + 37.52085134616082, + 32.84180430508372, + 27.804808939201713, + 22.769390355775293, + 18.035221122246437, + 13.817520620483487, + 10.23946704517764, + 7.339444066234637, + 5.088473988397973, + 3.412326145659868, + 2.213360167759724, + 1.388648174151198, + 0.8426960266006465, + 0.7376074103159138, + 1.2154764603642685, + 1.9373425409674239, + 2.986791170247495, + 4.45391458196682, + 6.42417687987907, + 8.96255177644009, + 12.094403296257036, + 15.786134414467005, + 19.929927903593455, + 24.337403367979302, + 28.74625898177488, + 32.84180430508373, + 36.29203033262927, + 38.791279320489465, + 40.104723036200596, + 40.104723036200596, + 38.791279320489465, + 36.29203033262927, + 32.84180430508372, + 28.746258981774876, + 24.337403367979302, + 19.929927903593455, + 15.786134414467, + 12.094403296257035, + 8.962551776440081, + 6.42417687987907, + 4.45391458196682, + 2.986791170247495, + 1.937342540967422, + 1.2154764603642685, + 0.7376074103159138, + 0.624479487346518, + 1.029057065092317, + 1.6402094933940163, + 2.528702657702978, + 3.7708112146549193, + 5.438891513917972, + 7.587952155023488, + 10.239467045177655, + 13.36499198416051, + 16.873245829727516, + 20.60474036191124, + 24.337403367979306, + 27.80480893920172, + 30.725868775068133, + 32.84180430508374, + 33.95380324486458, + 33.95380324486458, + 32.84180430508374, + 30.725868775068133, + 27.804808939201713, + 24.337403367979302, + 20.60474036191124, + 16.873245829727516, + 13.364991984160506, + 10.239467045177653, + 7.587952155023479, + 5.438891513917972, + 3.7708112146549193, + 2.528702657702978, + 1.6402094933940148, + 1.029057065092317, + 0.624479487346518, + 0.5113869779741611, + 0.8426960266006469, + 1.3431694604339557, + 2.0707575453161375, + 3.0879216862145458, + 4.4539145819668216, + 6.213782838661596, + 8.385111463867542, + 10.944607468965954, + 13.81752062048349, + 16.873245829727516, + 19.92992790359346, + 22.7693903557753, + 25.16144965029706, + 26.894191715021453, + 27.804808939201713, + 27.804808939201713, + 26.894191715021453, + 25.16144965029706, + 22.769390355775293, + 19.929927903593455, + 16.873245829727516, + 13.81752062048349, + 10.94460746896595, + 8.38511146386754, + 6.213782838661589, + 4.4539145819668216, + 3.0879216862145458, + 2.0707575453161375, + 1.3431694604339544, + 0.8426960266006469, + 0.5113869779741611, + 0.40506035000019325, + 0.6674842383176155, + 1.0639001679476452, + 1.640209493394016, + 2.445886760641077, + 3.527864957745563, + 4.921824684359942, + 6.641694706032255, + 8.669025068952921, + 10.94460746896595, + 13.364991984160506, + 15.786134414467003, + 18.03522112224644, + 19.92992790359346, + 21.30240146554791, + 22.023684852551355, + 22.023684852551355, + 21.30240146554791, + 19.92992790359346, + 18.035221122246437, + 15.786134414467, + 13.364991984160506, + 10.94460746896595, + 8.66902506895292, + 6.641694706032253, + 4.921824684359937, + 3.527864957745563, + 2.445886760641077, + 1.640209493394016, + 1.0639001679476443, + 0.6674842383176155, + 0.40506035000019325, + 0.31033330285948746, + 0.5113869779741612, + 0.8150974367937979, + 1.256631584568099, + 1.8738938946990886, + 2.7028416490997067, + 3.7708112146549184, + 5.088473988397973, + 6.641694706032256, + 8.38511146386754, + 10.239467045177653, + 12.094403296257036, + 13.81752062048349, + 15.269132987394762, + 16.320641123326723, + 16.873245829727512, + 16.873245829727512, + 16.320641123326723, + 15.269132987394762, + 13.817520620483487, + 12.094403296257035, + 10.239467045177653, + 8.38511146386754, + 6.641694706032253, + 5.088473988397972, + 3.7708112146549144, + 2.7028416490997067, + 1.8738938946990886, + 1.256631584568099, + 0.8150974367937971, + 0.5113869779741612, + 0.31033330285948746, + 0.22997234561316438, + 0.37896307536801666, + 0.6040275655739762, + 0.9312262345416477, + 1.3886481741511962, + 2.002939510961384, + 2.7943578465739445, + 3.7708112146549153, + 4.9218246843599385, + 6.213782838661589, + 7.587952155023479, + 8.962551776440083, + 10.239467045177644, + 11.315183695191239, + 12.09440329625703, + 12.503910749556052, + 12.503910749556052, + 12.09440329625703, + 11.315183695191239, + 10.23946704517764, + 8.962551776440081, + 7.587952155023479, + 6.213782838661589, + 4.921824684359937, + 3.7708112146549144, + 2.7943578465739414, + 2.002939510961384, + 1.3886481741511962, + 0.9312262345416477, + 0.6040275655739756, + 0.37896307536801666, + 0.22997234561316438, + 0.16483955267999187, + 0.2716331116219307, + 0.43295481224112503, + 0.6674842383176153, + 0.9953550860502568, + 1.4356667631130016, + 2.002939510961386, + 2.702841649099707, + 3.527864957745564, + 4.4539145819668216, + 5.438891513917972, + 6.424176879879071, + 7.3394440662346385, + 8.110496128715798, + 8.669025068952923, + 8.962551776440094, + 8.962551776440094, + 8.669025068952923, + 8.110496128715798, + 7.339444066234637, + 6.42417687987907, + 5.438891513917972, + 4.4539145819668216, + 3.527864957745563, + 2.7028416490997067, + 2.002939510961384, + 1.4356667631130016, + 0.9953550860502568, + 0.6674842383176153, + 0.4329548122411246, + 0.2716331116219307, + 0.16483955267999187, + 0.1142841022428579, + 0.1883246211023866, + 0.3001697785771043, + 0.462770225332246, + 0.6900847555862331, + 0.9953550860502568, + 1.3886481741511976, + 1.8738938946990888, + 2.445886760641078, + 3.0879216862145458, + 3.7708112146549193, + 4.453914581966821, + 5.088473988397974, + 5.6230483142227445, + 6.010279275929925, + 6.2137828386615945, + 6.2137828386615945, + 6.010279275929925, + 5.6230483142227445, + 5.088473988397973, + 4.45391458196682, + 3.7708112146549193, + 3.0879216862145458, + 2.445886760641077, + 1.8738938946990886, + 1.3886481741511962, + 0.9953550860502568, + 0.6900847555862331, + 0.462770225332246, + 0.30016977857710403, + 0.1883246211023866, + 0.1142841022428579, + 0.07663881765058349, + 0.12629032396046155, + 0.201293587411668, + 0.31033330285948735, + 0.462770225332246, + 0.6674842383176153, + 0.9312262345416487, + 1.2566315845680993, + 1.6402094933940163, + 2.0707575453161375, + 2.528702657702978, + 2.9867911702474954, + 3.412326145659869, + 3.770811214654918, + 4.030487954293403, + 4.166957263054249, + 4.166957263054249, + 4.030487954293403, + 3.770811214654918, + 3.412326145659868, + 2.986791170247495, + 2.528702657702978, + 2.0707575453161375, + 1.640209493394016, + 1.256631584568099, + 0.9312262345416477, + 0.6674842383176153, + 0.462770225332246, + 0.31033330285948735, + 0.2012935874116678, + 0.12629032396046155, + 0.07663881765058349, + 0.04971075420435811, + 0.08191654627828754, + 0.13056641990951587, + 0.2012935874116678, + 0.30016977857710403, + 0.4329548122411246, + 0.6040275655739763, + 0.8150974367937973, + 1.0639001679476445, + 1.3431694604339544, + 1.6402094933940148, + 1.9373425409674223, + 2.2133601677597246, + 2.4458867606410744, + 2.6143226391225527, + 2.7028416490997036, + 2.7028416490997036, + 2.6143226391225527, + 2.4458867606410744, + 2.213360167759724, + 1.937342540967422, + 1.6402094933940148, + 1.3431694604339544, + 1.0639001679476443, + 0.8150974367937971, + 0.6040275655739756, + 0.4329548122411246, + 0.30016977857710403, + 0.2012935874116678, + 0.13056641990951576, + 0.08191654627828754, + 0.04971075420435811, + 0.03118821286615593, + 0.05139392317575348, + 0.08191654627828761, + 0.12629032396046155, + 0.1883246211023866, + 0.2716331116219307, + 0.3789630753680171, + 0.5113869779741613, + 0.6674842383176157, + 0.8426960266006469, + 1.029057065092317, + 1.2154764603642687, + 1.3886481741511985, + 1.5345338882566935, + 1.6402094933940174, + 1.6957457605469852, + 1.6957457605469852, + 1.6402094933940174, + 1.5345338882566935, + 1.388648174151198, + 1.2154764603642685, + 1.029057065092317, + 0.8426960266006469, + 0.6674842383176155, + 0.5113869779741612, + 0.37896307536801666, + 0.2716331116219307, + 0.1883246211023866, + 0.12629032396046155, + 0.08191654627828754, + 0.05139392317575348, + 0.03118821286615593, + 0.018926452033215378, + 0.03118821286615593, + 0.04971075420435816, + 0.07663881765058349, + 0.1142841022428579, + 0.16483955267999187, + 0.22997234561316462, + 0.3103333028594875, + 0.40506035000019336, + 0.5113869779741611, + 0.624479487346518, + 0.7376074103159139, + 0.8426960266006467, + 0.9312262345416489, + 0.9953550860502569, + 1.0290570650923165, + 1.0290570650923165, + 0.9953550860502569, + 0.9312262345416489, + 0.8426960266006465, + 0.7376074103159138, + 0.624479487346518, + 0.5113869779741611, + 0.40506035000019325, + 0.31033330285948746, + 0.22997234561316438, + 0.16483955267999187, + 0.1142841022428579, + 0.07663881765058349, + 0.04971075420435811, + 0.03118821286615593, + 0.018926452033215378, + ] + ) diff --git a/africanus/model/spectral/dask.py b/africanus/model/spectral/dask.py index 3ca1c8747..5f3e06e0c 100644 --- a/africanus/model/spectral/dask.py +++ b/africanus/model/spectral/dask.py @@ -2,8 +2,9 @@ from africanus.model.spectral.spec_model import ( - spectral_model as np_spectral_model, - SPECTRAL_MODEL_DOC) + spectral_model as np_spectral_model, + SPECTRAL_MODEL_DOC, +) from africanus.util.requirements import requires_optional try: @@ -25,17 +26,29 @@ def spectral_model(stokes, spi, ref_freq, frequencies, base=0): pol_dim = () if stokes.ndim == 1 else ("pol",) - return da.blockwise(spectral_model_wrapper, ("source", "chan",) + pol_dim, - stokes, ("source",) + pol_dim, - spi, ("source", "spi") + pol_dim, - ref_freq, ("source",), - frequencies, ("chan",), - base=base, - dtype=stokes.dtype) + return da.blockwise( + spectral_model_wrapper, + ( + "source", + "chan", + ) + + pol_dim, + stokes, + ("source",) + pol_dim, + spi, + ("source", "spi") + pol_dim, + ref_freq, + ("source",), + frequencies, + ("chan",), + base=base, + dtype=stokes.dtype, + ) try: spectral_model.__doc__ = SPECTRAL_MODEL_DOC.substitute( - array_type=":class:`dask.array.Array`") + array_type=":class:`dask.array.Array`" + ) except AttributeError: pass diff --git a/africanus/model/spectral/spec_model.py b/africanus/model/spectral/spec_model.py index 42d6b3cf8..30852970b 100644 --- a/africanus/model/spectral/spec_model.py +++ b/africanus/model/spectral/spec_model.py @@ -23,34 +23,31 @@ def numpy_spectral_model(stokes, spi, ref_freq, frequency, base): if isinstance(base, list): base = base + [base[-1]] * (npol - len(base)) else: - base = [base]*npol + base = [base] * npol spi_exps = np.arange(1, spi.shape[1] + 1) - spectral_model = np.empty((stokes.shape[0], frequency.shape[0], npol), - dtype=stokes.dtype) + spectral_model = np.empty( + (stokes.shape[0], frequency.shape[0], npol), dtype=stokes.dtype + ) spectral_model[:, :, :] = stokes[:, None, :] for p, b in enumerate(base): if b in ("std", 0): - freq_ratio = (frequency[None, :] / ref_freq[:, None]) - term = freq_ratio[:, None, :]**spi[:, :, p, None] + freq_ratio = frequency[None, :] / ref_freq[:, None] + term = freq_ratio[:, None, :] ** spi[:, :, p, None] spectral_model[:, :, p] *= term.prod(axis=1) elif b in ("log", 1): freq_ratio = np.log(frequency[None, :] / ref_freq[:, None]) - term = freq_ratio[:, None, :]**spi_exps[None, :, None] + term = freq_ratio[:, None, :] ** spi_exps[None, :, None] term = spi[:, :, p, None] * term - spectral_model[:, :, p] = ( - stokes[:, p, None] * np.exp(term.sum(axis=1)) - ) + spectral_model[:, :, p] = stokes[:, p, None] * np.exp(term.sum(axis=1)) elif b in ("log10", 2): freq_ratio = np.log10(frequency[None, :] / ref_freq[:, None]) - term = freq_ratio[:, None, :]**spi_exps[None, :, None] + term = freq_ratio[:, None, :] ** spi_exps[None, :, None] term = spi[:, :, p, None] * term - spectral_model[:, :, p] = ( - stokes[:, p, None] * 10**(term.sum(axis=1)) - ) + spectral_model[:, :, p] = stokes[:, p, None] * 10 ** (term.sum(axis=1)) else: raise ValueError("Invalid base %s" % base) @@ -59,9 +56,11 @@ def numpy_spectral_model(stokes, spi, ref_freq, frequency, base): def pol_getter_factory(npoldims): if npoldims == 0: + def impl(pol_shape): return 1 else: + def impl(pol_shape): npols = 1 @@ -75,9 +74,11 @@ def impl(pol_shape): def promote_base_factory(is_base_list): if is_base_list: + def impl(base, npol): return base + [base[-1]] * (npol - len(base)) else: + def impl(base, npol): return [base] * npol @@ -86,9 +87,11 @@ def impl(base, npol): def add_pol_dim_factory(have_pol_dim): if have_pol_dim: + def impl(array): return array else: + def impl(array): return array.reshape(array.shape + (1,)) @@ -106,8 +109,9 @@ def spectral_model_impl(stokes, spi, ref_freq, frequency, base=0): @overload(spectral_model_impl, jit_options=JIT_OPTIONS) def nb_spectral_model(stokes, spi, ref_freq, frequency, base=0): - arg_dtypes = tuple(np.dtype(a.dtype.name) for a - in (stokes, spi, ref_freq, frequency)) + arg_dtypes = tuple( + np.dtype(a.dtype.name) for a in (stokes, spi, ref_freq, frequency) + ) dtype = np.result_type(*arg_dtypes) if isinstance(base, types.containers.List): @@ -119,6 +123,7 @@ def nb_spectral_model(stokes, spi, ref_freq, frequency, base=0): promote_base = promote_base_factory(is_base_list) if isinstance(base, types.scalars.Integer): + def is_std(base): return base == 0 @@ -129,6 +134,7 @@ def is_log10(base): return base == 2 elif isinstance(base, types.misc.UnicodeType): + def is_std(base): return base == "std" @@ -196,12 +202,10 @@ def impl(stokes, spi, ref_freq, frequency, base=0): spec_model = 0 for si in range(0, nspi): - term = espi[s, si, p] * freq_ratio**(si + 1) + term = espi[s, si, p] * freq_ratio ** (si + 1) spec_model += term - spectral_model[s, f, p] = ( - estokes[s, p] * np.exp(spec_model) - ) + spectral_model[s, f, p] = estokes[s, p] * np.exp(spec_model) elif is_log10(b): for s in range(nsrc): @@ -212,12 +216,10 @@ def impl(stokes, spi, ref_freq, frequency, base=0): spec_model = 0 for si in range(0, nspi): - term = espi[s, si, p] * freq_ratio**(si + 1) + term = espi[s, si, p] * freq_ratio ** (si + 1) spec_model += term - spectral_model[s, f, p] = ( - estokes[s, p] * 10**spec_model - ) + spectral_model[s, f, p] = estokes[s, p] * 10**spec_model else: raise ValueError("Invalid base") @@ -228,7 +230,8 @@ def impl(stokes, spi, ref_freq, frequency, base=0): return impl -SPECTRAL_MODEL_DOC = DocstringTemplate(r""" +SPECTRAL_MODEL_DOC = DocstringTemplate( + r""" Compute a spectral model, per polarisation. .. math:: @@ -272,10 +275,12 @@ def impl(stokes, spi, ref_freq, frequency, base=0): spectral_model : $(array_type) Spectral Model of shape :code:`(source, chan)` or :code:`(source, chan, pol)`. -""") +""" +) try: spectral_model.__doc__ = SPECTRAL_MODEL_DOC.substitute( - array_type=":class:`numpy.ndarray`") + array_type=":class:`numpy.ndarray`" + ) except AttributeError: pass diff --git a/africanus/model/spectral/tests/test_spectral_model.py b/africanus/model/spectral/tests/test_spectral_model.py index 42b36916a..5ef5c5281 100644 --- a/africanus/model/spectral/tests/test_spectral_model.py +++ b/africanus/model/spectral/tests/test_spectral_model.py @@ -5,8 +5,7 @@ from numpy.testing import assert_array_almost_equal import pytest -from africanus.model.spectral.spec_model import (spectral_model, - numpy_spectral_model) +from africanus.model.spectral.spec_model import spectral_model, numpy_spectral_model @pytest.fixture @@ -20,7 +19,7 @@ def impl(nsrc): @pytest.fixture def frequency(): def impl(nchan): - return np.linspace(.856e9, 2*.856e9, nchan) + return np.linspace(0.856e9, 2 * 0.856e9, nchan) return impl @@ -28,7 +27,7 @@ def impl(nchan): @pytest.fixture def ref_freq(): def impl(shape): - return np.full(shape, 3*.856e9/2) + return np.full(shape, 3 * 0.856e9 / 2) return impl @@ -41,8 +40,9 @@ def impl(shape): return impl -@pytest.mark.parametrize("base", [0, 1, 2, "std", "log", "log10", - ["log", "std", "std", "std"]]) +@pytest.mark.parametrize( + "base", [0, 1, 2, "std", "log", "log10", ["log", "std", "std", "std"]] +) @pytest.mark.parametrize("npol", [0, 1, 2, 4]) def test_spectral_model_multiple_spi(flux, ref_freq, frequency, base, npol): nsrc = 10 @@ -75,13 +75,13 @@ def test_spectral_model_multiple_spi(flux, ref_freq, frequency, base, npol): assert model.flags.c_contiguous is True -@pytest.mark.parametrize("base", [0, 1, 2, "std", "log", "log10", - ["log", "std", "std", "std"]]) +@pytest.mark.parametrize( + "base", [0, 1, 2, "std", "log", "log10", ["log", "std", "std", "std"]] +) @pytest.mark.parametrize("npol", [0, 1, 2, 4]) def test_dask_spectral_model(flux, ref_freq, frequency, base, npol): da = pytest.importorskip("dask.array") - from africanus.model.spectral.spec_model import ( - spectral_model as np_spectral_model) + from africanus.model.spectral.spec_model import spectral_model as np_spectral_model from africanus.model.spectral.dask import spectral_model sc = (5, 5) @@ -118,8 +118,7 @@ def test_dask_spectral_model(flux, ref_freq, frequency, base, npol): da_ref_freq = da.from_array(ref_freq, chunks=sc) da_freq = da.from_array(freq, chunks=fc) - da_model = spectral_model(da_stokes, da_spi, - da_ref_freq, da_freq, base=base) + da_model = spectral_model(da_stokes, da_spi, da_ref_freq, da_freq, base=base) np_model = np_spectral_model(stokes, spi, ref_freq, freq, base=base) assert_array_almost_equal(da_model, np_model) diff --git a/africanus/model/spi/component_spi.py b/africanus/model/spi/component_spi.py index 55bd76afc..dae7beb98 100644 --- a/africanus/model/spi/component_spi.py +++ b/africanus/model/spi/component_spi.py @@ -9,10 +9,10 @@ @jit(nopython=True, nogil=True, cache=True) -def _fit_spi_components_impl(data, weights, freqs, freq0, out, - jac, beam, ncomps, nfreqs, - tol, maxiter, mindet): - w = freqs/freq0 +def _fit_spi_components_impl( + data, weights, freqs, freq0, out, jac, beam, ncomps, nfreqs, tol, maxiter, mindet +): + w = freqs / freq0 dof = np.maximum(w.size - 2, 1) for comp in range(ncomps): eps = 1.0 @@ -23,8 +23,8 @@ def _fit_spi_components_impl(data, weights, freqs, freq0, out, while eps > tol and k < maxiter: alphap = alphak i0p = i0k - jac[1, :] = b*w**alphak - model = i0k*jac[1, :] + jac[1, :] = b * w**alphak + model = i0k * jac[1, :] jac[0, :] = model * np.log(w) residual = data[comp] - model lik = 0.0 @@ -41,22 +41,22 @@ def _fit_spi_components_impl(data, weights, freqs, freq0, out, hess01 += jac[0, v] * weights[v] * jac[1, v] hess11 += jac[1, v] * weights[v] * jac[1, v] det = np.maximum(hess00 * hess11 - hess01**2, mindet) - alphak = alphap + (hess11 * jr0 - hess01 * jr1)/det - i0k = i0p + (-hess01 * jr0 + hess00 * jr1)/det + alphak = alphap + (hess11 * jr0 - hess01 * jr1) / det + i0k = i0p + (-hess01 * jr0 + hess00 * jr1) / det eps = np.maximum(np.abs(alphak - alphap), np.abs(i0k - i0p)) k += 1 if k == maxiter: print("Warning - max iterations exceeded for component ", comp) out[0, comp] = alphak - out[1, comp] = hess11/det * lik/dof + out[1, comp] = hess11 / det * lik / dof out[2, comp] = i0k - out[3, comp] = hess00/det * lik/dof + out[3, comp] = hess00 / det * lik / dof return out -def fit_spi_components(data, weights, freqs, freq0, - alphai=None, I0i=None, beam=None, - tol=1e-4, maxiter=100): +def fit_spi_components( + data, weights, freqs, freq0, alphai=None, I0i=None, beam=None, tol=1e-4, maxiter=100 +): ncomps, nfreqs = data.shape if beam is None: beam = np.ones(data.shape, data.dtype) @@ -73,7 +73,7 @@ def fit_spi_components(data, weights, freqs, freq0, ref_freq_idx = np.argwhere(tmp == tmp.min()).squeeze() if np.size(ref_freq_idx) > 1: ref_freq_idx = ref_freq_idx.min() - out[2, :] = data[:, ref_freq_idx]/beam[:, ref_freq_idx] + out[2, :] = data[:, ref_freq_idx] / beam[:, ref_freq_idx] if data.dtype == np.float64: mindet = 1e-12 elif data.dtype == np.float32: @@ -81,9 +81,20 @@ def fit_spi_components(data, weights, freqs, freq0, else: raise ValueError("Unsupported data type. Must be float32 of float64.") - return _fit_spi_components_impl(data, weights, freqs, freq0, out, - jac, beam, ncomps, nfreqs, - tol, maxiter, mindet) + return _fit_spi_components_impl( + data, + weights, + freqs, + freq0, + out, + jac, + beam, + ncomps, + nfreqs, + tol, + maxiter, + mindet, + ) SPI_DOCSTRING = DocstringTemplate( @@ -139,10 +150,12 @@ def fit_spi_components(data, weights, freqs, freq0, array of shape :code:`(4, comps)` The fitted components arranged as [alphas, alphavars, I0s, I0vars] - """) + """ +) try: fit_spi_components.__doc__ = SPI_DOCSTRING.substitute( - array_type=":class:`numpy.ndarray`") + array_type=":class:`numpy.ndarray`" + ) except AttributeError: pass diff --git a/africanus/model/spi/dask.py b/africanus/model/spi/dask.py index 55bd1945e..e1fb43f45 100644 --- a/africanus/model/spi/dask.py +++ b/africanus/model/spi/dask.py @@ -3,7 +3,8 @@ from africanus.model.spi.component_spi import SPI_DOCSTRING from africanus.model.spi.component_spi import ( - fit_spi_components as np_fit_spi_components) + fit_spi_components as np_fit_spi_components, +) from africanus.util.requirements import requires_optional @@ -15,37 +16,53 @@ opt_import_error = None -def _fit_spi_components_wrapper(data, weights, freqs, freq0, - alphai, I0i, beam, tol, maxiter): - return np_fit_spi_components(data[0], - weights[0], - freqs[0], - freq0, - alphai, - I0i, - beam[0] if beam is not None else beam, - tol=tol, - maxiter=maxiter) - - -@requires_optional('dask.array', opt_import_error) -def fit_spi_components(data, weights, freqs, freq0, - alphai=None, I0i=None, beam=None, - tol=1e-5, maxiter=100): - """ Dask wrapper fit_spi_components function """ - return blockwise(_fit_spi_components_wrapper, ("vars", "comps"), - data, ("comps", "chan"), - weights, ("chan",), - freqs, ("chan",), - freq0, None, - alphai, ("comps",) if alphai is not None else None, - I0i, ("comps",) if I0i is not None else None, - beam, ("comps", "chan") if beam is not None else None, - tol, None, - maxiter, None, - new_axes={"vars": 4}, - dtype=data.dtype) +def _fit_spi_components_wrapper( + data, weights, freqs, freq0, alphai, I0i, beam, tol, maxiter +): + return np_fit_spi_components( + data[0], + weights[0], + freqs[0], + freq0, + alphai, + I0i, + beam[0] if beam is not None else beam, + tol=tol, + maxiter=maxiter, + ) + + +@requires_optional("dask.array", opt_import_error) +def fit_spi_components( + data, weights, freqs, freq0, alphai=None, I0i=None, beam=None, tol=1e-5, maxiter=100 +): + """Dask wrapper fit_spi_components function""" + return blockwise( + _fit_spi_components_wrapper, + ("vars", "comps"), + data, + ("comps", "chan"), + weights, + ("chan",), + freqs, + ("chan",), + freq0, + None, + alphai, + ("comps",) if alphai is not None else None, + I0i, + ("comps",) if I0i is not None else None, + beam, + ("comps", "chan") if beam is not None else None, + tol, + None, + maxiter, + None, + new_axes={"vars": 4}, + dtype=data.dtype, + ) fit_spi_components.__doc__ = SPI_DOCSTRING.substitute( - array_type=":class:`dask.array.Array`") + array_type=":class:`dask.array.Array`" +) diff --git a/africanus/model/spi/examples/README.rst b/africanus/model/spi/examples/README.rst index e90654b72..b429b25c3 100644 --- a/africanus/model/spi/examples/README.rst +++ b/africanus/model/spi/examples/README.rst @@ -5,7 +5,7 @@ Fits a simple spectral index model to image cubes. Usage is as follows .. code-block:: bash - $ ./simple_spi_fitter.py --fitsmodel=/path/to/model.fits + $ ./simple_spi_fitter.py --fitsmodel=/path/to/model.fits Run @@ -18,7 +18,7 @@ only compulsary input if the beam parameters are specified. If they are not supplied the residual image cube needs to be provided as input so that these can be taken from the header. This means you either have to specify the beam parameters manually or pass in a residual cube with a header -which contains beam parameters. +which contains beam parameters. The residual is also used to determine the weights in the different imaging bands. The weights will be set as 1/rms**2 in each imaging band, given that @@ -40,4 +40,4 @@ found at git+https://gitlab.mpcdf.mpg.de/mtr/pypocketfft. Install via .. code-block:: bash - $ pip3 install git+https://gitlab.mpcdf.mpg.de/mtr/pypocketfft \ No newline at end of file + $ pip3 install git+https://gitlab.mpcdf.mpg.de/mtr/pypocketfft diff --git a/africanus/model/spi/examples/simple_spi_fitter.py b/africanus/model/spi/examples/simple_spi_fitter.py index 2ac402dd2..39f05e8ed 100755 --- a/africanus/model/spi/examples/simple_spi_fitter.py +++ b/africanus/model/spi/examples/simple_spi_fitter.py @@ -9,6 +9,7 @@ from astropy.io import fits import warnings from africanus.model.spi.dask import fit_spi_components + iFs = np.fft.ifftshift Fs = np.fft.fftshift @@ -18,19 +19,21 @@ from pypocketfft import r2c, c2r def fft(x, ax, ncpu): - return r2c(x, axes=ax, forward=True, - nthreads=ncpu, inorm=0) + return r2c(x, axes=ax, forward=True, nthreads=ncpu, inorm=0) def ifft(y, ax, ncpu, lastsize): - return c2r(y, axes=ax, forward=False, lastsize=lastsize, - nthreads=args.ncpu, inorm=2) + return c2r( + y, axes=ax, forward=False, lastsize=lastsize, nthreads=args.ncpu, inorm=2 + ) except BaseException: - warnings.warn("No pypocketfft installation found. " - "FFT's will be performed in serial. " - "Install pypocketfft from " - "https://gitlab.mpcdf.mpg.de/mtr/pypocketfft " - "for optimal performance.", - ImportWarning) + warnings.warn( + "No pypocketfft installation found. " + "FFT's will be performed in serial. " + "Install pypocketfft from " + "https://gitlab.mpcdf.mpg.de/mtr/pypocketfft " + "for optimal performance.", + ImportWarning, + ) from numpy.fft import rfftn, irfftn # additional arguments will have no effect @@ -41,60 +44,60 @@ def ifft(y, ax, ncpu, lastsize): return irfftn(y, axes=ax) -def Gaussian2D(xin, yin, GaussPar=(1., 1., 0.)): +def Gaussian2D(xin, yin, GaussPar=(1.0, 1.0, 0.0)): S0, S1, PA = GaussPar PA = 90 + PA SMaj = np.max([S0, S1]) SMin = np.min([S0, S1]) - A = np.array([[1. / SMaj ** 2, 0], - [0, 1. / SMin ** 2]]) + A = np.array([[1.0 / SMaj**2, 0], [0, 1.0 / SMin**2]]) c, s, t = np.cos, np.sin, PA - R = np.array([[c(t), -s(t)], - [s(t), c(t)]]) + R = np.array([[c(t), -s(t)], [s(t), c(t)]]) A = np.dot(np.dot(R.T, A), R) sOut = xin.shape # only compute the result where necessary - extent = (5 * SMaj)**2 + extent = (5 * SMaj) ** 2 xflat = xin.squeeze() yflat = yin.squeeze() ind = np.argwhere(xflat**2 + yflat**2 <= extent).squeeze() idx = ind[:, 0] idy = ind[:, 1] x = np.array([xflat[idx, idy].ravel(), yflat[idx, idy].ravel()]) - R = np.einsum('nb,bc,cn->n', x.T, A, x) + R = np.einsum("nb,bc,cn->n", x.T, A, x) # need to adjust for the fact that GaussPar corresponds to FWHM - fwhm_conv = 2*np.sqrt(2*np.log(2)) - tmp = np.exp(-fwhm_conv*R) + fwhm_conv = 2 * np.sqrt(2 * np.log(2)) + tmp = np.exp(-fwhm_conv * R) gausskern = np.zeros_like(xflat, dtype=np.float64) gausskern[idx, idy] = tmp - return np.ascontiguousarray(gausskern.reshape(sOut), - dtype=np.float64) + return np.ascontiguousarray(gausskern.reshape(sOut), dtype=np.float64) def convolve_model(model, gausskern, args): print("Doing FFT's") # get padding _, npix_l, npix_m = model.shape - pfrac = args.padding_frac/2.0 - npad_l = int(pfrac*npix_l) - npad_m = int(pfrac*npix_m) + pfrac = args.padding_frac / 2.0 + npad_l = int(pfrac * npix_l) + npad_m = int(pfrac * npix_m) # get fast FFT sizes try: from scipy.fftpack import next_fast_len - nfft = next_fast_len(npix_l + 2*npad_l) - npad_ll = (nfft - npix_l)//2 + + nfft = next_fast_len(npix_l + 2 * npad_l) + npad_ll = (nfft - npix_l) // 2 npad_lr = nfft - npix_l - npad_ll - nfft = next_fast_len(npix_m + 2*npad_m) - npad_ml = (nfft - npix_m)//2 + nfft = next_fast_len(npix_m + 2 * npad_m) + npad_ml = (nfft - npix_m) // 2 npad_mr = nfft - npix_m - npad_ml padding = ((0, 0), (npad_ll, npad_lr), (npad_ml, npad_mr)) unpad_l = slice(npad_ll, -npad_lr) unpad_m = slice(npad_ml, -npad_mr) except BaseException: - warnings.warn("Could not determine fast fft size. " - "Install scipy for optimal performance.", - ImportWarning) + warnings.warn( + "Could not determine fast fft size. " + "Install scipy for optimal performance.", + ImportWarning, + ) padding = ((0, 0), (npad_l, npad_l), (npad_m, npad_m)) unpad_l = slice(npad_l, -npad_l) unpad_m = slice(npad_m, -npad_m) @@ -102,15 +105,16 @@ def convolve_model(model, gausskern, args): lastsize = npix_m + np.sum(padding[-1]) # get FT of convolution kernel - gausskernhat = fft(iFs(np.pad(gausskern[None], padding, mode='constant'), - axes=ax), ax, args.ncpu) + gausskernhat = fft( + iFs(np.pad(gausskern[None], padding, mode="constant"), axes=ax), ax, args.ncpu + ) # Convolve model with Gaussian kernel - convmodel = fft(iFs(np.pad(model, padding, mode='constant'), axes=ax), - ax, args.ncpu) + convmodel = fft( + iFs(np.pad(model, padding, mode="constant"), axes=ax), ax, args.ncpu + ) convmodel *= gausskernhat - return Fs(ifft(convmodel, ax, args.ncpu, lastsize), - axes=ax)[:, unpad_l, unpad_m] + return Fs(ifft(convmodel, ax, args.ncpu, lastsize), axes=ax)[:, unpad_l, unpad_m] def interpolate_beam(xx, yy, maskindices, freqs, args): @@ -123,7 +127,13 @@ def interpolate_beam(xx, yy, maskindices, freqs, args): ntime = 1 nant = 1 nband = freqs.size - parangles = np.zeros((ntime, nant,), dtype=np.float64) + parangles = np.zeros( + ( + ntime, + nant, + ), + dtype=np.float64, + ) ant_scale = np.ones((nant, nband, 2), dtype=np.float64) point_errs = np.zeros((ntime, nant, nband, 2), dtype=np.float64) @@ -132,132 +142,175 @@ def interpolate_beam(xx, yy, maskindices, freqs, args): else: print("Loading fits beam patterns from %s" % args.beammodel) from glob import glob - paths = glob(args.beammodel + '_**_**.fits') + + paths = glob(args.beammodel + "_**_**.fits") beam_hdr = None for path in paths: - if 'xx' in path or 'XX' in path or 'rr' in path or 'RR' in path: - if 're' in path: + if "xx" in path or "XX" in path or "rr" in path or "RR" in path: + if "re" in path: corr1_re = fits.getdata(path) if beam_hdr is None: beam_hdr = fits.getheader(path) - elif 'im' in path: + elif "im" in path: corr1_im = fits.getdata(path) else: raise NotImplementedError("Only re/im patterns supported") - elif 'yy' in path or 'YY' in path or 'll' in path or 'LL' in path: - if 're' in path: + elif "yy" in path or "YY" in path or "ll" in path or "LL" in path: + if "re" in path: corr2_re = fits.getdata(path) - elif 'im' in path: + elif "im" in path: corr2_im = fits.getdata(path) else: raise NotImplementedError("Only re/im patterns supported") # get Stokes I amplitude - beam_amp = (corr1_re**2 + corr1_im**2 + corr2_re**2 + corr2_im**2)/2.0 + beam_amp = (corr1_re**2 + corr1_im**2 + corr2_re**2 + corr2_im**2) / 2.0 # get cube in correct shape for interpolation code - beam_amp = np.ascontiguousarray(np.transpose(beam_amp, (1, 2, 0)) - [:, :, :, None, None]) + beam_amp = np.ascontiguousarray( + np.transpose(beam_amp, (1, 2, 0))[:, :, :, None, None] + ) # get cube info - if beam_hdr['CUNIT1'] != "DEG" and beam_hdr['CUNIT1'] != "deg": + if beam_hdr["CUNIT1"] != "DEG" and beam_hdr["CUNIT1"] != "deg": raise ValueError("Beam image units must be in degrees") - npix_l = beam_hdr['NAXIS1'] - refpix_l = beam_hdr['CRPIX1'] - delta_l = beam_hdr['CDELT1'] - l_min = (1 - refpix_l)*delta_l - l_max = (1 + npix_l - refpix_l)*delta_l + npix_l = beam_hdr["NAXIS1"] + refpix_l = beam_hdr["CRPIX1"] + delta_l = beam_hdr["CDELT1"] + l_min = (1 - refpix_l) * delta_l + l_max = (1 + npix_l - refpix_l) * delta_l - if beam_hdr['CUNIT2'] != "DEG" and beam_hdr['CUNIT2'] != "deg": + if beam_hdr["CUNIT2"] != "DEG" and beam_hdr["CUNIT2"] != "deg": raise ValueError("Beam image units must be in degrees") - npix_m = beam_hdr['NAXIS2'] - refpix_m = beam_hdr['CRPIX2'] - delta_m = beam_hdr['CDELT2'] - m_min = (1 - refpix_m)*delta_m - m_max = (1 + npix_m - refpix_m)*delta_m - - if (l_min > l_source.min() or m_min > m_source.min() or - l_max < l_source.max() or m_max < m_source.max()): + npix_m = beam_hdr["NAXIS2"] + refpix_m = beam_hdr["CRPIX2"] + delta_m = beam_hdr["CDELT2"] + m_min = (1 - refpix_m) * delta_m + m_max = (1 + npix_m - refpix_m) * delta_m + + if ( + l_min > l_source.min() + or m_min > m_source.min() + or l_max < l_source.max() + or m_max < m_source.max() + ): raise ValueError("The supplied beam is not large enough") beam_extents = np.array([[l_min, l_max], [m_min, m_max]]) # get frequencies - if beam_hdr["CTYPE3"] != 'FREQ': - raise ValueError( - "Cubes are assumed to be in format [nchan, nx, ny]") - nchan = beam_hdr['NAXIS3'] - refpix = beam_hdr['CRPIX3'] - delta = beam_hdr['CDELT3'] # assumes units are Hz - freq0 = beam_hdr['CRVAL3'] + if beam_hdr["CTYPE3"] != "FREQ": + raise ValueError("Cubes are assumed to be in format [nchan, nx, ny]") + nchan = beam_hdr["NAXIS3"] + refpix = beam_hdr["CRPIX3"] + delta = beam_hdr["CDELT3"] # assumes units are Hz + freq0 = beam_hdr["CRVAL3"] bfreqs = freq0 + np.arange(1 - refpix, 1 + nchan - refpix) * delta if bfreqs[0] > freqs[0] or bfreqs[-1] < freqs[-1]: - warnings.warn("The supplied beam does not have sufficient " - "bandwidth. Beam frequencies:") + warnings.warn( + "The supplied beam does not have sufficient " + "bandwidth. Beam frequencies:" + ) with np.printoptions(precision=2): print(bfreqs) # LB - dask probably not necessary for small problem from africanus.rime.fast_beam_cubes import beam_cube_dde - beam_source = beam_cube_dde(beam_amp, beam_extents, bfreqs, - lm_source, parangles, point_errs, - ant_scale, freqs).squeeze() + + beam_source = beam_cube_dde( + beam_amp, + beam_extents, + bfreqs, + lm_source, + parangles, + point_errs, + ant_scale, + freqs, + ).squeeze() return beam_source def create_parser(): - p = argparse.ArgumentParser(description='Simple spectral index fitting' - 'tool.', - formatter_class=argparse.RawTextHelpFormatter) + p = argparse.ArgumentParser( + description="Simple spectral index fitting" "tool.", + formatter_class=argparse.RawTextHelpFormatter, + ) p.add_argument("--fitsmodel", type=str, required=True) p.add_argument("--fitsresidual", type=str) - p.add_argument('--outfile', type=str, - help="Path to output directory. \n" - "Placed next to input model if outfile not provided.") - p.add_argument('--beampars', default=None, nargs='+', type=float, - help="Beam parameters matching FWHM of restoring beam " - "specified as emaj emin pa. \n" - "By default these are taken from the fits header " - "of the residual image.") - p.add_argument('--threshold', default=5, type=float, - help="Multiple of the rms in the residual to threshold " - "on. \n" - "Only components above threshold*rms will be fit.") - p.add_argument('--maxDR', default=100, type=float, - help="Maximum dynamic range used to determine the " - "threshold above which components need to be fit. \n" - "Only used if residual is not passed in.") - p.add_argument('--ncpu', default=0, type=int, - help="Number of threads to use. \n" - "Default of zero means use all threads") - p.add_argument('--beammodel', default=None, type=str, - help="Fits beam model to use. \n" - "It is assumed that the pattern is path_to_beam/" - "name_corr_re/im.fits. \n" - "Provide only the path up to name " - "e.g. /home/user/beams/meerkat_lband. \n" - "Patterns mathing corr are determined " - "automatically. \n" - "Only real and imaginary beam models currently " - "supported.") - p.add_argument('--output', default='aiIbc', type=str, - help="Outputs to write. Letter correspond to: \n" - "a - alpha map \n" - "i - I0 map \n" - "I - reconstructed cube form alpha and I0 \n" - "b - interpolated beam \n" - "c - restoring beam used for convolution \n" - "Default is to write all of them") - p.add_argument("--padding_frac", default=0.2, type=float, - help="Padding factor for FFT's.") + p.add_argument( + "--outfile", + type=str, + help="Path to output directory. \n" + "Placed next to input model if outfile not provided.", + ) + p.add_argument( + "--beampars", + default=None, + nargs="+", + type=float, + help="Beam parameters matching FWHM of restoring beam " + "specified as emaj emin pa. \n" + "By default these are taken from the fits header " + "of the residual image.", + ) + p.add_argument( + "--threshold", + default=5, + type=float, + help="Multiple of the rms in the residual to threshold " + "on. \n" + "Only components above threshold*rms will be fit.", + ) + p.add_argument( + "--maxDR", + default=100, + type=float, + help="Maximum dynamic range used to determine the " + "threshold above which components need to be fit. \n" + "Only used if residual is not passed in.", + ) + p.add_argument( + "--ncpu", + default=0, + type=int, + help="Number of threads to use. \n" "Default of zero means use all threads", + ) + p.add_argument( + "--beammodel", + default=None, + type=str, + help="Fits beam model to use. \n" + "It is assumed that the pattern is path_to_beam/" + "name_corr_re/im.fits. \n" + "Provide only the path up to name " + "e.g. /home/user/beams/meerkat_lband. \n" + "Patterns mathing corr are determined " + "automatically. \n" + "Only real and imaginary beam models currently " + "supported.", + ) + p.add_argument( + "--output", + default="aiIbc", + type=str, + help="Outputs to write. Letter correspond to: \n" + "a - alpha map \n" + "i - I0 map \n" + "I - reconstructed cube form alpha and I0 \n" + "b - interpolated beam \n" + "c - restoring beam used for convolution \n" + "Default is to write all of them", + ) + p.add_argument( + "--padding_frac", default=0.2, type=float, help="Padding factor for FFT's." + ) return p def main(args): - ref_hdr = fits.getheader(args.fitsresidual) if args.beampars is None: print("Attempting to take beampars from residual fits header") - emaj = ref_hdr['BMAJ1'] - emin = ref_hdr['BMIN1'] - pa = ref_hdr['BPA1'] + emaj = ref_hdr["BMAJ1"] + emin = ref_hdr["BMIN1"] + pa = ref_hdr["BPA1"] beampars = (emaj, emin, pa) else: beampars = tuple(args.beampars) @@ -265,49 +318,49 @@ def main(args): print("Using emaj = %3.2e, emin = %3.2e, PA = %3.2e" % beampars) # load images - model = np.ascontiguousarray(fits.getdata(args.fitsmodel).squeeze(), - dtype=np.float64) + model = np.ascontiguousarray( + fits.getdata(args.fitsmodel).squeeze(), dtype=np.float64 + ) mhdr = fits.getheader(args.fitsmodel) - if mhdr['CUNIT1'] != "DEG" and mhdr['CUNIT1'] != "deg": + if mhdr["CUNIT1"] != "DEG" and mhdr["CUNIT1"] != "deg": raise ValueError("Image units must be in degrees") - npix_l = mhdr['NAXIS1'] - refpix_l = mhdr['CRPIX1'] - delta_l = mhdr['CDELT1'] - l_coord = np.arange(1 - refpix_l, 1 + npix_l - refpix_l)*delta_l + npix_l = mhdr["NAXIS1"] + refpix_l = mhdr["CRPIX1"] + delta_l = mhdr["CDELT1"] + l_coord = np.arange(1 - refpix_l, 1 + npix_l - refpix_l) * delta_l - if mhdr['CUNIT2'] != "DEG" and mhdr['CUNIT2'] != "deg": + if mhdr["CUNIT2"] != "DEG" and mhdr["CUNIT2"] != "deg": raise ValueError("Image units must be in degrees") - npix_m = mhdr['NAXIS2'] - refpix_m = mhdr['CRPIX2'] - delta_m = mhdr['CDELT2'] - m_coord = np.arange(1 - refpix_m, 1 + npix_m - refpix_m)*delta_m + npix_m = mhdr["NAXIS2"] + refpix_m = mhdr["CRPIX2"] + delta_m = mhdr["CDELT2"] + m_coord = np.arange(1 - refpix_m, 1 + npix_m - refpix_m) * delta_m print("Image shape = ", (npix_l, npix_m)) # get frequencies - if mhdr["CTYPE4"] == 'FREQ': + if mhdr["CTYPE4"] == "FREQ": freq_axis = 4 - nband = mhdr['NAXIS4'] - refpix_nu = mhdr['CRPIX4'] - delta_nu = mhdr['CDELT4'] # assumes units are Hz - ref_freq = mhdr['CRVAL4'] - ncorr = mhdr['NAXIS3'] - elif mhdr["CTYPE3"] == 'FREQ': + nband = mhdr["NAXIS4"] + refpix_nu = mhdr["CRPIX4"] + delta_nu = mhdr["CDELT4"] # assumes units are Hz + ref_freq = mhdr["CRVAL4"] + ncorr = mhdr["NAXIS3"] + elif mhdr["CTYPE3"] == "FREQ": freq_axis = 3 - nband = mhdr['NAXIS3'] - refpix_nu = mhdr['CRPIX3'] - delta_nu = mhdr['CDELT3'] # assumes units are Hz - ref_freq = mhdr['CRVAL3'] - ncorr = mhdr['NAXIS4'] + nband = mhdr["NAXIS3"] + refpix_nu = mhdr["CRPIX3"] + delta_nu = mhdr["CDELT3"] # assumes units are Hz + ref_freq = mhdr["CRVAL3"] + ncorr = mhdr["NAXIS4"] else: raise ValueError("Freq axis must be 3rd or 4th") if ncorr > 1: raise ValueError("Only Stokes I cubes supported") - freqs = ref_freq + np.arange(1 - refpix_nu, - 1 + nband - refpix_nu) * delta_nu + freqs = ref_freq + np.arange(1 - refpix_nu, 1 + nband - refpix_nu) * delta_nu print("Cube frequencies:") with np.printoptions(precision=2): @@ -325,15 +378,19 @@ def main(args): if args.fitsresidual is not None: resid = fits.getdata(args.fitsresidual).squeeze().astype(np.float64) rms = np.std(resid) - rms_cube = np.std(resid.reshape(nband, npix_l*npix_m), axis=1).ravel() + rms_cube = np.std(resid.reshape(nband, npix_l * npix_m), axis=1).ravel() threshold = args.threshold * rms - print("Setting cutoff threshold as %i times the rms " - "of the residual" % args.threshold) + print( + "Setting cutoff threshold as %i times the rms " + "of the residual" % args.threshold + ) del resid else: - print("No residual provided. Setting threshold i.t.o dynamic range. " - "Max dynamic range is %i" % args.maxDR) - threshold = model.max()/args.maxDR + print( + "No residual provided. Setting threshold i.t.o dynamic range. " + "Max dynamic range is %i" % args.maxDR + ) + threshold = model.max() / args.maxDR if args.channelweights is None: rms_cube = None @@ -343,9 +400,11 @@ def main(args): minimage = np.amin(model, axis=0) maskindices = np.argwhere(minimage > threshold) if not maskindices.size: - raise ValueError("No components found above threshold. " - "Try lowering your threshold." - "Max of convolved model is %3.2e" % model.max()) + raise ValueError( + "No components found above threshold. " + "Try lowering your threshold." + "Max of convolved model is %3.2e" % model.max() + ) fitcube = model[:, maskindices[:, 0], maskindices[:, 1]].T print(xx.shape, yy.shape, maskindices.shape) @@ -359,7 +418,7 @@ def main(args): # set weights for fit if rms_cube is not None: print("Using RMS in each imaging band to determine weights.") - weights = np.where(rms_cube > 0, 1.0/rms_cube**2, 0.0) + weights = np.where(rms_cube > 0, 1.0 / rms_cube**2, 0.0) # normalise weights /= weights.max() else: @@ -367,14 +426,16 @@ def main(args): weights = np.ones(nband, dtype=np.float64) ncomps, _ = fitcube.shape - fitcube = da.from_array(fitcube.astype(np.float64), - chunks=(ncomps//args.ncpu, nband)) + fitcube = da.from_array( + fitcube.astype(np.float64), chunks=(ncomps // args.ncpu, nband) + ) weights = da.from_array(weights.astype(np.float64), chunks=(nband)) freqsdask = da.from_array(freqs.astype(np.float64), chunks=(nband)) print("Fitting %i components" % ncomps) - alpha, _, Iref, _ = fit_spi_components(fitcube, weights, freqsdask, - np.float64(ref_freq)).compute() + alpha, _, Iref, _ = fit_spi_components( + fitcube, weights, freqsdask, np.float64(ref_freq) + ).compute() print("Done. Writing output.") alphamap = np.zeros([npix_l, npix_m]) @@ -386,46 +447,82 @@ def main(args): if args.outfile is None: # find last / tmp = args.fitsmodel[::-1] - idx = tmp.find('/') + idx = tmp.find("/") if idx != -1: outfile = args.fitsmodel[0:-idx] else: - outfile = 'image-' + outfile = "image-" else: outfile = args.outfile hdu = fits.PrimaryHDU(header=mhdr) - if 'I' in args.output: + if "I" in args.output: # get the reconstructed cube - Irec_cube = i0map[None, :, :] * \ - (freqs[:, None, None]/ref_freq)**alphamap[None, :, :] + Irec_cube = ( + i0map[None, :, :] + * (freqs[:, None, None] / ref_freq) ** alphamap[None, :, :] + ) # save it if freq_axis == 3: hdu.data = Irec_cube[None, :, :, :] elif freq_axis == 4: hdu.data = Irec_cube[:, None, :, :] - name = outfile + 'Irec_cube.fits' + name = outfile + "Irec_cube.fits" hdu.writeto(name, overwrite=True) print("Wrote reconstructed cube to %s" % name) - if args.beammodel is not None and 'b' in args.output: + if args.beammodel is not None and "b" in args.output: beam_map = np.zeros((nband, npix_l, npix_m)) beam_map[:, maskindices[:, 0], maskindices[:, 1]] = beam_source.T if freq_axis == 3: hdu.data = beam_map[None, :, :, :] elif freq_axis == 4: hdu.data = beam_map[:, None, :, :] - name = outfile + 'interpolated_beam_cube.fits' + name = outfile + "interpolated_beam_cube.fits" hdu.writeto(name, overwrite=True) print("Wrote interpolated beam cube to %s" % name) - hdr_keys = ['SIMPLE', 'BITPIX', 'NAXIS', 'NAXIS1', 'NAXIS2', 'NAXIS3', - 'NAXIS4', 'BUNIT', 'BMAJ', 'BMIN', 'BPA', 'EQUINOX', 'BTYPE', - 'TELESCOP', 'OBSERVER', 'OBJECT', 'ORIGIN', 'CTYPE1', 'CTYPE2', - 'CTYPE3', 'CTYPE4', 'CRPIX1', 'CRPIX2', 'CRPIX3', 'CRPIX4', - 'CRVAL1', 'CRVAL2', 'CRVAL3', 'CRVAL4', 'CDELT1', 'CDELT2', - 'CDELT3', 'CDELT4', 'CUNIT1', 'CUNIT2', 'CUNIT3', 'CUNIT4', - 'SPECSYS', 'DATE-OBS'] + hdr_keys = [ + "SIMPLE", + "BITPIX", + "NAXIS", + "NAXIS1", + "NAXIS2", + "NAXIS3", + "NAXIS4", + "BUNIT", + "BMAJ", + "BMIN", + "BPA", + "EQUINOX", + "BTYPE", + "TELESCOP", + "OBSERVER", + "OBJECT", + "ORIGIN", + "CTYPE1", + "CTYPE2", + "CTYPE3", + "CTYPE4", + "CRPIX1", + "CRPIX2", + "CRPIX3", + "CRPIX4", + "CRVAL1", + "CRVAL2", + "CRVAL3", + "CRVAL4", + "CDELT1", + "CDELT2", + "CDELT3", + "CDELT4", + "CUNIT1", + "CUNIT2", + "CUNIT3", + "CUNIT4", + "SPECSYS", + "DATE-OBS", + ] new_hdr = {} for key in hdr_keys: @@ -441,26 +538,26 @@ def main(args): new_hdr = fits.Header(new_hdr) # save alpha map - if 'a' in args.output: + if "a" in args.output: hdu = fits.PrimaryHDU(header=new_hdr) hdu.data = alphamap - name = outfile + 'alpha.fits' + name = outfile + "alpha.fits" hdu.writeto(name, overwrite=True) print("Wrote alpha map to %s" % name) # save I0 map - if 'i' in args.output: + if "i" in args.output: hdu = fits.PrimaryHDU(header=new_hdr) hdu.data = i0map - name = outfile + 'I0.fits' + name = outfile + "I0.fits" hdu.writeto(name, overwrite=True) print("Wrote I0 map to %s" % name) # save clean beam for consistency check - if 'c' in args.output: + if "c" in args.output: hdu = fits.PrimaryHDU(header=new_hdr) hdu.data = gausskern - name = outfile + 'clean-beam.fits' + name = outfile + "clean-beam.fits" hdu.writeto(name, overwrite=True) print("Wrote clean beam to %s" % name) @@ -472,9 +569,11 @@ def main(args): if args.ncpu: from multiprocessing.pool import ThreadPool + dask.config.set(pool=ThreadPool(args.ncpu)) else: import multiprocessing + args.ncpu = multiprocessing.cpu_count() print("Using %i threads" % args.ncpu) diff --git a/africanus/model/spi/tests/test_component_spi.py b/africanus/model/spi/tests/test_component_spi.py index bb5af18b9..bdee6db7d 100644 --- a/africanus/model/spi/tests/test_component_spi.py +++ b/africanus/model/spi/tests/test_component_spi.py @@ -13,6 +13,7 @@ def test_fit_spi_components_vs_scipy(): :return: """ from africanus.model.spi import fit_spi_components + curve_fit = pytest.importorskip("scipy.optimize").curve_fit np.random.seed(123) @@ -30,12 +31,13 @@ def test_fit_spi_components_vs_scipy(): sigma = np.abs(0.25 + 0.1 * np.random.randn(nfreqs)) data = model + sigma[None, :] * np.random.randn(ncomps, nfreqs) - weights = 1.0/sigma**2 + weights = 1.0 / sigma**2 alpha1, alphavar1, I01, I0var1 = fit_spi_components( - data, weights, freqs.squeeze(), freq0, tol=1e-8) + data, weights, freqs.squeeze(), freq0, tol=1e-8 + ) def spi_func(nu, I0, alpha, beam=1.0): - return beam * I0 * nu ** alpha + return beam * I0 * nu**alpha I02 = np.zeros(ncomps) I0var2 = np.zeros(ncomps) @@ -43,12 +45,18 @@ def spi_func(nu, I0, alpha, beam=1.0): alphavar2 = np.zeros(ncomps) for i in range(ncomps): - def fit_func(nu, I0, alpha): return spi_func(nu, I0, alpha, - beam=beams[i]) - popt, pcov = curve_fit(fit_func, (freqs / freq0).squeeze(), - data[i, :], sigma=np.diag(sigma**2), - p0=np.array([1.0, -0.7]), - absolute_sigma=False) + + def fit_func(nu, I0, alpha): + return spi_func(nu, I0, alpha, beam=beams[i]) + + popt, pcov = curve_fit( + fit_func, + (freqs / freq0).squeeze(), + data[i, :], + sigma=np.diag(sigma**2), + p0=np.array([1.0, -0.7]), + absolute_sigma=False, + ) I02[i] = popt[0] I0var2[i] = pcov[0, 0] alpha2[i] = popt[1] @@ -63,6 +71,7 @@ def fit_func(nu, I0, alpha): return spi_func(nu, I0, alpha, def test_dask_fit_spi_components_vs_np(): from africanus.model.spi import fit_spi_components as np_fit_spi from africanus.model.spi.dask import fit_spi_components + da = pytest.importorskip("dask.array") np.random.seed(123) @@ -77,7 +86,7 @@ def test_dask_fit_spi_components_vs_np(): sigma = np.abs(0.25 + 0.1 * np.random.randn(nfreqs)) data = model + sigma[None, :] * np.random.randn(ncomps, nfreqs) - weights = 1.0/sigma**2 + weights = 1.0 / sigma**2 freqs = freqs.squeeze() alpha1, alphavar1, I01, I0var1 = np_fit_spi(data, weights, freqs, freq0) @@ -86,10 +95,9 @@ def test_dask_fit_spi_components_vs_np(): weights_dask = da.from_array(weights, chunks=(nfreqs)) freqs_dask = da.from_array(freqs, chunks=(nfreqs)) - alpha2, alphavar2, I02, I0var2 = fit_spi_components(data_dask, - weights_dask, - freqs_dask, - freq0).compute() + alpha2, alphavar2, I02, I0var2 = fit_spi_components( + data_dask, weights_dask, freqs_dask, freq0 + ).compute() np.testing.assert_array_almost_equal(alpha1, alpha2, decimal=6) np.testing.assert_array_almost_equal(alphavar1, alphavar2, decimal=6) diff --git a/africanus/model/wsclean/dask.py b/africanus/model/wsclean/dask.py index 8db1d9fa6..e088abde6 100644 --- a/africanus/model/wsclean/dask.py +++ b/africanus/model/wsclean/dask.py @@ -1,8 +1,7 @@ # -*- coding: utf-8 -*- -from africanus.model.wsclean.spec_model import (spectra as np_spectra, - SPECTRA_DOCS) +from africanus.model.wsclean.spec_model import spectra as np_spectra, SPECTRA_DOCS from africanus.util.requirements import requires_optional try: @@ -17,22 +16,29 @@ def spectra_wrapper(stokes, spi, log_si, ref_freq, frequency): return np_spectra(stokes, spi[0], log_si, ref_freq, frequency) -@requires_optional('dask.array', opt_import_error) +@requires_optional("dask.array", opt_import_error) def spectra(stokes, spi, log_si, ref_freq, frequency): corrs = tuple("corr-%d" % i for i in range(len(stokes.shape[1:]))) log_si_schema = None if isinstance(log_si, bool) else ("source",) - return da.blockwise(spectra_wrapper, ("source", "chan") + corrs, - stokes, ("source",) + corrs, - spi, ("source", "spi") + corrs, - log_si, log_si_schema, - ref_freq, ("source",), - frequency, ("chan",), - dtype=stokes.dtype) + return da.blockwise( + spectra_wrapper, + ("source", "chan") + corrs, + stokes, + ("source",) + corrs, + spi, + ("source", "spi") + corrs, + log_si, + log_si_schema, + ref_freq, + ("source",), + frequency, + ("chan",), + dtype=stokes.dtype, + ) try: - spectra.__doc__ = SPECTRA_DOCS.substitute( - array_type=":class:`dask.array.Array`") + spectra.__doc__ = SPECTRA_DOCS.substitute(array_type=":class:`dask.array.Array`") except AttributeError: pass diff --git a/africanus/model/wsclean/file_model.py b/africanus/model/wsclean/file_model.py index e8e54ce32..b278424a1 100644 --- a/africanus/model/wsclean/file_model.py +++ b/africanus/model/wsclean/file_model.py @@ -7,15 +7,13 @@ import numpy as np -hour_re = re.compile(r"(?P[+-]*)" - r"(?P\d+):" - r"(?P\d+):" - r"(?P\d+\.?\d*)") +hour_re = re.compile( + r"(?P[+-]*)" r"(?P\d+):" r"(?P\d+):" r"(?P\d+\.?\d*)" +) -deg_re = re.compile(r"(?P[+-])*" - r"(?P\d+)\." - r"(?P\d+)\." - r"(?P\d+\.?\d*)") +deg_re = re.compile( + r"(?P[+-])*" r"(?P\d+)\." r"(?P\d+)\." r"(?P\d+\.?\d*)" +) def _hour_converter(hour_str): @@ -25,10 +23,10 @@ def _hour_converter(hour_str): raise ValueError("Error parsing '%s'" % hour_str) value = float(m.group("hours")) / 24.0 - value += float(m.group("mins")) / (24.0*60.0) - value += float(m.group("secs")) / (24.0*60.0*60.0) + value += float(m.group("mins")) / (24.0 * 60.0) + value += float(m.group("secs")) / (24.0 * 60.0 * 60.0) - if m.group("sign") == '-': + if m.group("sign") == "-": value = -value return 2.0 * math.pi * value @@ -41,17 +39,17 @@ def _deg_converter(deg_str): raise ValueError(f"Error parsing '{deg_str}'") value = float(m.group("degs")) / 360.0 - value += float(m.group("mins")) / (360.0*60.0) - value += float(m.group("secs")) / (360.0*60.0*60.0) + value += float(m.group("mins")) / (360.0 * 60.0) + value += float(m.group("secs")) / (360.0 * 60.0 * 60.0) - if m.group("sign") == '-': + if m.group("sign") == "-": value = -value return 2.0 * math.pi * value def arcsec2rad(arcseconds=0.0): - return np.deg2rad(float(arcseconds) / 3600.) + return np.deg2rad(float(arcseconds) / 3600.0) def spi_converter(spi): @@ -64,26 +62,27 @@ def spi_converter(spi): _COLUMN_CONVERTERS = { - 'Name': str, - 'Type': str, - 'Ra': _hour_converter, - 'Dec': _deg_converter, - 'I': float, - 'SpectralIndex': spi_converter, - 'LogarithmicSI': lambda x: bool(x == "true"), - 'ReferenceFrequency': float, - 'MajorAxis': arcsec2rad, - 'MinorAxis': arcsec2rad, - 'Orientation': lambda x=0.0: np.deg2rad(float(x)), + "Name": str, + "Type": str, + "Ra": _hour_converter, + "Dec": _deg_converter, + "I": float, + "SpectralIndex": spi_converter, + "LogarithmicSI": lambda x: bool(x == "true"), + "ReferenceFrequency": float, + "MajorAxis": arcsec2rad, + "MinorAxis": arcsec2rad, + "Orientation": lambda x=0.0: np.deg2rad(float(x)), } # Split on commas, ignoring within [] brackets -_COMMA_SPLIT_RE = re.compile(r',\s*(?=[^\]]*(?:\[|$))') +_COMMA_SPLIT_RE = re.compile(r",\s*(?=[^\]]*(?:\[|$))") # Parse columm headers, handling possible defaults -_COL_HEADER_RE = re.compile(r"^\s*?(?P.*?)" - r"(\s*?=\s*?'(?P.*?)'\s*?){0,1}$") +_COL_HEADER_RE = re.compile( + r"^\s*?(?P.*?)" r"(\s*?=\s*?'(?P.*?)'\s*?){0,1}$" +) def _parse_col_descriptor(column_descriptor): @@ -98,7 +97,7 @@ def _parse_col_descriptor(column_descriptor): if m is None: raise ValueError(f"'{column}' is not a valid column header") - name, default = m.group('name', 'default') + name, default = m.group("name", "default") columns.append(name) defaults.append(default) @@ -110,8 +109,7 @@ def _parse_header(header): format_str, col_desc = (c.strip() for c in header.split("=", 1)) if format_str != "Format": - raise ValueError(f"'{format_str}' does not " - f"appear to be a wsclean header") + raise ValueError(f"'{format_str}' does not " f"appear to be a wsclean header") return _parse_col_descriptor(col_desc) @@ -123,8 +121,10 @@ def _parse_lines(fh, line_nr, column_names, defaults, converters): components = [c.strip() for c in re.split(_COMMA_SPLIT_RE, line)] if len(components) != len(column_names): - raise ValueError(f"line {line_nr} '{line}' should " - f"have {len(column_names)} components") + raise ValueError( + f"line {line_nr} '{line}' should " + f"have {len(column_names)} components" + ) # Iterate through each column's data it = zip(column_names, components, converters, source_data, defaults) @@ -140,7 +140,8 @@ def _parse_lines(fh, line_nr, column_names, defaults, converters): f"on line {line_nr} and no default was " f"supplied either. Attempting to " f"generate a default produced the " - f"following exception {e}") + f"following exception {e}" + ) value = default else: @@ -157,8 +158,7 @@ def _parse_lines(fh, line_nr, column_names, defaults, converters): spi_column = columns["SpectralIndex"] log_spi_column = columns["LogarithmicSI"] except KeyError as e: - raise ValueError(f"WSClean Model File missing " - f"required column {str(e)}") + raise ValueError(f"WSClean Model File missing " f"required column {str(e)}") it = zip(name_column, flux_column, spi_column, log_spi_column) @@ -167,15 +167,19 @@ def _parse_lines(fh, line_nr, column_names, defaults, converters): good = True if not math.isfinite(flux): - warnings.warn(f"Non-finite I {flux} encountered " - f"for source {name}. This source model will " - f"be zeroed.") + warnings.warn( + f"Non-finite I {flux} encountered " + f"for source {name}. This source model will " + f"be zeroed." + ) good = False if not all(map(math.isfinite, spi)): - warnings.warn(f"Non-finite SpectralIndex {spi} encountered " - f"for source {name}. This source model will " - f"be zeroed.") + warnings.warn( + f"Non-finite SpectralIndex {spi} encountered " + f"for source {name}. This source model will " + f"be zeroed." + ) good = False if good: @@ -232,7 +236,7 @@ def load(filename): try: # Search for a header until we find a non-empty string - header = '' + header = "" line_nr = 1 for headers in fh: @@ -244,8 +248,9 @@ def load(filename): line_nr += 1 if not header: - raise ValueError(f"'{filename}' does not contain " - f"a valid wsclean header") + raise ValueError( + f"'{filename}' does not contain " f"a valid wsclean header" + ) column_names, defaults = _parse_header(header) diff --git a/africanus/model/wsclean/spec_model.py b/africanus/model/wsclean/spec_model.py index ee51a1998..56d9c2357 100644 --- a/africanus/model/wsclean/spec_model.py +++ b/africanus/model/wsclean/spec_model.py @@ -7,22 +7,22 @@ def ordinary_spectral_model(I, coeffs, log_poly, ref_freq, freq): # noqa: E741 - """ Numpy ordinary polynomial implementation """ + """Numpy ordinary polynomial implementation""" coeffs_idx = np.arange(1, coeffs.shape[1] + 1) # (source, chan, coeffs-comp) term = (freq[None, :, None] / ref_freq[:, None, None]) - 1.0 - term = term**coeffs_idx[None, None, :] - term = coeffs[:, None, :]*term + term = term ** coeffs_idx[None, None, :] + term = coeffs[:, None, :] * term return I[:, None] + term.sum(axis=2) def log_spectral_model(I, coeffs, log_poly, ref_freq, freq): # noqa: E741 - """ Numpy logarithmic polynomial implementation """ + """Numpy logarithmic polynomial implementation""" coeffs_idx = np.arange(1, coeffs.shape[1] + 1) # (source, chan, coeffs-comp) term = np.log(freq[None, :, None] / ref_freq[:, None, None]) - term = term**coeffs_idx[None, None, :] - term = coeffs[:, None, :]*term + term = term ** coeffs_idx[None, None, :] + term = coeffs[:, None, :] * term return I[:, None] * np.exp(term.sum(axis=2)) @@ -33,10 +33,12 @@ def _check_log_poly_shape(coeffs, log_poly): @overload(_check_log_poly_shape) def overload_check_log_poly_shape(coeffs, log_poly): if isinstance(log_poly, types.npytypes.Array): + def impl(coeffs, log_poly): if coeffs.shape[0] != log_poly.shape[0]: raise ValueError("coeffs.shape[0] != log_poly.shape[0]") elif isinstance(log_poly, types.scalars.Boolean): + def impl(coeffs, log_poly): pass else: @@ -52,9 +54,11 @@ def _log_polynomial(log_poly, s): @overload(_log_polynomial) def overload_log_polynomial(log_poly, s): if isinstance(log_poly, types.npytypes.Array): + def impl(log_poly, s): return log_poly[s] elif isinstance(log_poly, types.scalars.Boolean): + def impl(log_poly, s): return log_poly else: @@ -74,15 +78,15 @@ def spectra_impl(I, coeffs, log_poly, ref_freq, frequency): # noqa: E741 @overload(spectra_impl, jit_option=JIT_OPTIONS) def nb_spectra(I, coeffs, log_poly, ref_freq, frequency): # noqa: E741 - arg_dtypes = tuple(np.dtype(a.dtype.name) for a - in (I, coeffs, ref_freq, frequency)) + arg_dtypes = tuple(np.dtype(a.dtype.name) for a in (I, coeffs, ref_freq, frequency)) dtype = np.result_type(*arg_dtypes) def impl(I, coeffs, log_poly, ref_freq, frequency): # noqa: E741 if not (I.shape[0] == coeffs.shape[0] == ref_freq.shape[0]): print(I.shape, coeffs.shape, ref_freq.shape) - raise ValueError("first dimensions of I, coeffs " - "and ref_freq don't match.") + raise ValueError( + "first dimensions of I, coeffs " "and ref_freq don't match." + ) _check_log_poly_shape(coeffs, log_poly) @@ -103,7 +107,7 @@ def impl(I, coeffs, log_poly, ref_freq, frequency): # noqa: E741 spectral_model[s, f] = 0 for c in range(ncoeffs): - term = coeffs[s, c] * np.log(nu/rf)**(c + 1) + term = coeffs[s, c] * np.log(nu / rf) ** (c + 1) spectral_model[s, f] += term spectral_model[s, f] = I[s] * np.exp(spectral_model[s, f]) @@ -116,7 +120,7 @@ def impl(I, coeffs, log_poly, ref_freq, frequency): # noqa: E741 for c in range(ncoeffs): term = coeffs[s, c] - term *= ((nu/rf) - 1.0)**(c + 1) + term *= ((nu / rf) - 1.0) ** (c + 1) spectral_model[s, f] += term return spectral_model @@ -124,7 +128,8 @@ def impl(I, coeffs, log_poly, ref_freq, frequency): # noqa: E741 return impl -SPECTRA_DOCS = DocstringTemplate(r""" +SPECTRA_DOCS = DocstringTemplate( + r""" Produces a spectral model from a polynomial expansion of a wsclean file model. Depending on how `log_poly` is set ordinary or logarithmic polynomials are used to produce @@ -172,10 +177,10 @@ def impl(I, coeffs, log_poly, ref_freq, frequency): # noqa: E741 ------- spectral_model : $(array_type) Spectral Model of shape :code:`(source, chan)` -""") +""" +) try: - spectra.__doc__ = SPECTRA_DOCS.substitute( - array_type=":class:`numpy.ndarray`") + spectra.__doc__ = SPECTRA_DOCS.substitute(array_type=":class:`numpy.ndarray`") except AttributeError: pass diff --git a/africanus/model/wsclean/tests/conftest.py b/africanus/model/wsclean/tests/conftest.py index ebf3fce84..a371c8e7b 100644 --- a/africanus/model/wsclean/tests/conftest.py +++ b/africanus/model/wsclean/tests/conftest.py @@ -6,7 +6,7 @@ import pytest -_WSCLEAN_MODEL_FILE = (""" +_WSCLEAN_MODEL_FILE = """ Format = Name, Type, Ra, Dec, I, SpectralIndex, LogarithmicSI, ReferenceFrequency='125584411.621094', MajorAxis, MinorAxis, Orientation s0c0,POINT,-08:28:05.152,39.35.08.511,0.000748810650400475,[-0.00695379313004673,-0.0849693907803257],true,125584411.621094,,, s0c1,POINT,08:22:27.658,39.37.38.353,-0.000154968071120503,[-0.000898135869319762,0.0183710297781511],false,125584411.621094,,, @@ -16,7 +16,7 @@ s1c1,GAUSSIAN,07:51:09.24,42.32.46.177,0.000660490865128381,[0.00404869217508666,-0.011844732049232],false,125584411.621094,83.6144111272856,83.6144111272856,0 s1c2,GAUSSIAN,07:51:09.24,42.32.46.177,0.000660490865128381,[0.00404869217508666,-0.011844732049232],false,,83.6144111272856,83.6144111272856,45.0 s1c3,GAUSSIAN,07:51:09.24,42.32.46.177,nan,[nan,inf],false,,83.6144111272856,83.6144111272856,45.0 -""") # noqa +""" # noqa @pytest.fixture diff --git a/africanus/model/wsclean/tests/test_spectral_model.py b/africanus/model/wsclean/tests/test_spectral_model.py index d7ccb3de7..81a42d514 100644 --- a/africanus/model/wsclean/tests/test_spectral_model.py +++ b/africanus/model/wsclean/tests/test_spectral_model.py @@ -6,23 +6,27 @@ import pytest from africanus.model.wsclean.file_model import load -from africanus.model.wsclean.spec_model import (ordinary_spectral_model, - log_spectral_model, spectra) +from africanus.model.wsclean.spec_model import ( + ordinary_spectral_model, + log_spectral_model, + spectra, +) from africanus.model.wsclean.dask import spectra as dask_spectra @pytest.fixture def freq(): - return np.linspace(.856e9, .856e9*2, 16) + return np.linspace(0.856e9, 0.856e9 * 2, 16) @pytest.fixture def spectral_model_inputs(wsclean_model_file): sources = dict(load(wsclean_model_file)) - I, spi, log_si, ref_freq = (sources[n] for n in ("I", "SpectralIndex", - "LogarithmicSI", - "ReferenceFrequency")) + I, spi, log_si, ref_freq = ( + sources[n] + for n in ("I", "SpectralIndex", "LogarithmicSI", "ReferenceFrequency") + ) I = np.asarray(I) # noqa spi = np.asarray(spi) @@ -74,16 +78,16 @@ def test_dask_spectral_model(spectral_model_inputs, freq): spi[log_si] = np.abs(spi[log_si]) # Compute spectral model with numpy implementations - ordinary_spec_model = ordinary_spectral_model(I, spi, log_si, - ref_freq, freq) - log_spec_model = log_spectral_model(I, spi, log_si, - ref_freq, freq) + ordinary_spec_model = ordinary_spectral_model(I, spi, log_si, ref_freq, freq) + log_spec_model = log_spectral_model(I, spi, log_si, ref_freq, freq) # Choose between ordinary and log spectral index # based on log_si array - spec_model = np.where(log_si[:, None] == True, # noqa - log_spec_model, - ordinary_spec_model) + spec_model = np.where( + log_si[:, None] == True, # noqa + log_spec_model, + ordinary_spec_model, + ) # Create dask arrays src_chunks = (4, 4) diff --git a/africanus/model/wsclean/tests/test_wsclean_model_file.py b/africanus/model/wsclean/tests/test_wsclean_model_file.py index 46a37c2d3..61406ead5 100644 --- a/africanus/model/wsclean/tests/test_wsclean_model_file.py +++ b/africanus/model/wsclean/tests/test_wsclean_model_file.py @@ -10,35 +10,46 @@ def test_wsclean_model_file(wsclean_model_file): sources = dict(load(wsclean_model_file)) - (name, stype, ra, dec, I, - spi, log_si, ref_freq, - major, minor, orientation) = (sources[n] for n in ( - "Name", "Type", "Ra", "Dec", "I", - "SpectralIndex", "LogarithmicSI", - "ReferenceFrequency", - "MajorAxis", "MinorAxis", "Orientation")) + (name, stype, ra, dec, I, spi, log_si, ref_freq, major, minor, orientation) = ( + sources[n] + for n in ( + "Name", + "Type", + "Ra", + "Dec", + "I", + "SpectralIndex", + "LogarithmicSI", + "ReferenceFrequency", + "MajorAxis", + "MinorAxis", + "Orientation", + ) + ) # Seven sources - assert (len(I) == len(spi) == len(log_si) == len(ref_freq) == 8) + assert len(I) == len(spi) == len(log_si) == len(ref_freq) == 8 # Name and type read correctly assert name[0] == "s0c0" and stype[0] == "POINT" # Check ra conversion for line 0 file entry (-float, float, float) - hours, mins, secs = (-8., 28., 5.152) - expected_ra0 = -2.0 * np.pi * ( - (-hours / 24.0) + - (mins / (24.0*60.0)) + - (secs / (24.0*60.0*60.0))) + hours, mins, secs = (-8.0, 28.0, 5.152) + expected_ra0 = ( + -2.0 + * np.pi + * ((-hours / 24.0) + (mins / (24.0 * 60.0)) + (secs / (24.0 * 60.0 * 60.0))) + ) assert ra[0] == expected_ra0 # Check dec conversion for line 0 file entry - degs, mins, secs = (39., 35., 8.511) - expected_dec0 = 2.0 * np.pi * ( - (degs / 360.0) + - (mins / (360.0*60.0)) + - (secs / (360.0*60.0*60.0))) + degs, mins, secs = (39.0, 35.0, 8.511) + expected_dec0 = ( + 2.0 + * np.pi + * ((degs / 360.0) + (mins / (360.0 * 60.0)) + (secs / (360.0 * 60.0 * 60.0))) + ) assert dec[0] == expected_dec0 @@ -50,19 +61,21 @@ def test_wsclean_model_file(wsclean_model_file): # Check ra conversion for line 2 file entry (int, not float, seconds) hours, mins, secs = (8, 18, 44) - expected_ra2 = 2.0 * np.pi * ( - (hours / 24.0) + - (mins / (24.0*60.0)) + - (secs / (24.0*60.0*60.0))) + expected_ra2 = ( + 2.0 + * np.pi + * ((hours / 24.0) + (mins / (24.0 * 60.0)) + (secs / (24.0 * 60.0 * 60.0))) + ) assert ra[2] == expected_ra2 # Check dec conversion for line 2 file entry (int, not float, seconds) degs, mins, secs = (39, 38, 37) - expected_dec2 = 2.0 * np.pi * ( - (degs / 360.0) + - (mins / (360.0*60.0)) + - (secs / (360.0*60.0*60.0))) + expected_dec2 = ( + 2.0 + * np.pi + * ((degs / 360.0) + (mins / (360.0 * 60.0)) + (secs / (360.0 * 60.0 * 60.0))) + ) assert dec[2] == expected_dec2 @@ -72,10 +85,11 @@ def test_wsclean_model_file(wsclean_model_file): # Check dec conversion for line 4 file entry (+int, not float, seconds) degs, mins, secs = (+41, 47, 17.131) - expected_dec4 = 2.0 * np.pi * ( - (degs / 360.0) + - (mins / (360.0*60.0)) + - (secs / (360.0*60.0*60.0))) + expected_dec4 = ( + 2.0 + * np.pi + * ((degs / 360.0) + (mins / (360.0 * 60.0)) + (secs / (360.0 * 60.0 * 60.0))) + ) assert dec[4] == expected_dec4 diff --git a/africanus/rime/cuda/beam.cu.j2 b/africanus/rime/cuda/beam.cu.j2 index 93fb1fc06..01b265724 100644 --- a/africanus/rime/cuda/beam.cu.j2 +++ b/africanus/rime/cuda/beam.cu.j2 @@ -278,4 +278,3 @@ extern "C" __global__ void {{kernel_name}}( } } - diff --git a/africanus/rime/cuda/beam.py b/africanus/rime/cuda/beam.py index abf9c9a69..f01ae9284 100644 --- a/africanus/rime/cuda/beam.py +++ b/africanus/rime/cuda/beam.py @@ -43,24 +43,44 @@ def _generate_interp_kernel(beam_freq_map, frequencies): block = (1024, 1, 1) - code = render(kernel_name=name, - beam_nud_limit=BEAM_NUD_LIMIT, - blockdimx=block[0], - beam_freq_type=_get_typename(beam_freq_map.dtype), - freq_type=_get_typename(frequencies.dtype)) + code = render( + kernel_name=name, + beam_nud_limit=BEAM_NUD_LIMIT, + blockdimx=block[0], + beam_freq_type=_get_typename(beam_freq_map.dtype), + freq_type=_get_typename(frequencies.dtype), + ) dtype = np.result_type(beam_freq_map, frequencies) return cp.RawKernel(code, name), block, dtype -def _main_key_fn(beam, beam_lm_ext, beam_freq_map, - lm, parangles, pointing_errors, - antenna_scaling, frequencies, - dde_dims, ncorr): - return (beam.dtype, beam.ndim, beam_lm_ext.dtype, beam_freq_map.dtype, - lm.dtype, parangles.dtype, pointing_errors.dtype, - antenna_scaling.dtype, frequencies.dtype, dde_dims, ncorr) +def _main_key_fn( + beam, + beam_lm_ext, + beam_freq_map, + lm, + parangles, + pointing_errors, + antenna_scaling, + frequencies, + dde_dims, + ncorr, +): + return ( + beam.dtype, + beam.ndim, + beam_lm_ext.dtype, + beam_freq_map.dtype, + lm.dtype, + parangles.dtype, + pointing_errors.dtype, + antenna_scaling.dtype, + frequencies.dtype, + dde_dims, + ncorr, + ) # Value to use in a bit shift to recover channel from flattened @@ -69,16 +89,22 @@ def _main_key_fn(beam, beam_lm_ext, beam_freq_map, @memoize_on_key(_main_key_fn) -def _generate_main_kernel(beam, beam_lm_ext, beam_freq_map, - lm, parangles, pointing_errors, - antenna_scaling, frequencies, - dde_dims, ncorr): - +def _generate_main_kernel( + beam, + beam_lm_ext, + beam_freq_map, + lm, + parangles, + pointing_errors, + antenna_scaling, + frequencies, + dde_dims, + ncorr, +): beam_lw, beam_mh, beam_nud = beam.shape[:3] if beam_lw < 2 or beam_mh < 2 or beam_nud < 2: - raise ValueError("(beam_lw, beam_mh, beam_nud) < 2 " - "to linearly interpolate") + raise ValueError("(beam_lw, beam_mh, beam_nud) < 2 " "to linearly interpolate") # Create template render = jinja_env.get_template(str(_MAIN_TEMPLATE_PATH)).render @@ -95,54 +121,63 @@ def _generate_main_kernel(beam, beam_lm_ext, beam_freq_map, try: corr_shift = _corr_shifter[ncorr] except KeyError: - raise ValueError("Number of Correlations not in %s" - % list(_corr_shifter.keys())) + raise ValueError( + "Number of Correlations not in %s" % list(_corr_shifter.keys()) + ) - coord_type = np.result_type(beam_lm_ext, lm, parangles, - pointing_errors, antenna_scaling, - np.float32) + coord_type = np.result_type( + beam_lm_ext, lm, parangles, pointing_errors, antenna_scaling, np.float32 + ) assert coord_type in (np.float32, np.float64) - code = render(kernel_name=name, - blockdimx=block[0], - blockdimy=block[1], - blockdimz=block[2], - corr_shift=corr_shift, - ncorr=ncorr, - beam_nud_limit=BEAM_NUD_LIMIT, - # Beam type and manipulation functions - beam_type=_get_typename(beam.real.dtype), - beam_dims=beam.ndim, - make2_beam_fn=cuda_function('make2', beam.real.dtype), - beam_sqrt_fn=cuda_function('sqrt', beam.real.dtype), - beam_rsqrt_fn=cuda_function('rsqrt', beam.real.dtype), - # Coordinate type and manipulation functions - FT=_get_typename(coord_type), - floor_fn=cuda_function('floor', coord_type), - min_fn=cuda_function('min', coord_type), - max_fn=cuda_function('max', coord_type), - cos_fn=cuda_function('cos', coord_type), - sin_fn=cuda_function('sin', coord_type), - lm_ext_type=_get_typename(beam_lm_ext.dtype), - beam_freq_type=_get_typename(beam_freq_map.dtype), - lm_type=_get_typename(lm.dtype), - pa_type=_get_typename(parangles.dtype), - pe_type=_get_typename(pointing_errors.dtype), - as_type=_get_typename(antenna_scaling.dtype), - freq_type=_get_typename(frequencies.dtype), - dde_type=_get_typename(beam.real.dtype), - dde_dims=dde_dims) + code = render( + kernel_name=name, + blockdimx=block[0], + blockdimy=block[1], + blockdimz=block[2], + corr_shift=corr_shift, + ncorr=ncorr, + beam_nud_limit=BEAM_NUD_LIMIT, + # Beam type and manipulation functions + beam_type=_get_typename(beam.real.dtype), + beam_dims=beam.ndim, + make2_beam_fn=cuda_function("make2", beam.real.dtype), + beam_sqrt_fn=cuda_function("sqrt", beam.real.dtype), + beam_rsqrt_fn=cuda_function("rsqrt", beam.real.dtype), + # Coordinate type and manipulation functions + FT=_get_typename(coord_type), + floor_fn=cuda_function("floor", coord_type), + min_fn=cuda_function("min", coord_type), + max_fn=cuda_function("max", coord_type), + cos_fn=cuda_function("cos", coord_type), + sin_fn=cuda_function("sin", coord_type), + lm_ext_type=_get_typename(beam_lm_ext.dtype), + beam_freq_type=_get_typename(beam_freq_map.dtype), + lm_type=_get_typename(lm.dtype), + pa_type=_get_typename(parangles.dtype), + pe_type=_get_typename(pointing_errors.dtype), + as_type=_get_typename(antenna_scaling.dtype), + freq_type=_get_typename(frequencies.dtype), + dde_type=_get_typename(beam.real.dtype), + dde_dims=dde_dims, + ) # Complex output type return cp.RawKernel(code, name), block, dtype -@requires_optional('cupy', opt_import_error) -def beam_cube_dde(beam, beam_lm_ext, beam_freq_map, - lm, parangles, pointing_errors, - antenna_scaling, frequencies): - +@requires_optional("cupy", opt_import_error) +def beam_cube_dde( + beam, + beam_lm_ext, + beam_freq_map, + lm, + parangles, + pointing_errors, + antenna_scaling, + frequencies, +): corrs = beam.shape[3:] if beam.shape[2] >= BEAM_NUD_LIMIT: @@ -152,7 +187,7 @@ def beam_cube_dde(beam, beam_lm_ext, beam_freq_map, ntime, na = parangles.shape nchan = frequencies.shape[0] ncorr = reduce(mul, corrs, 1) - nchancorr = nchan*ncorr + nchancorr = nchan * ncorr oshape = (nsrc, ntime, na, nchan) + corrs @@ -166,14 +201,18 @@ def beam_cube_dde(beam, beam_lm_ext, beam_freq_map, ikernel, iblock, idt = _generate_interp_kernel(beam_freq_map, frequencies) # Generate main beam cube kernel - kernel, block, dtype = _generate_main_kernel(fbeam, beam_lm_ext, - beam_freq_map, - lm, parangles, - pointing_errors, - antenna_scaling, - frequencies, - len(oshape), - ncorr) + kernel, block, dtype = _generate_main_kernel( + fbeam, + beam_lm_ext, + beam_freq_map, + lm, + parangles, + pointing_errors, + antenna_scaling, + frequencies, + len(oshape), + ncorr, + ) # Call frequency interpolation kernel igrid = grids((nchan, 1, 1), iblock) freq_data = cp.empty((3, nchan), dtype=frequencies.dtype) @@ -189,10 +228,23 @@ def beam_cube_dde(beam, beam_lm_ext, beam_freq_map, grid = grids((nchancorr, na, ntime), block) try: - kernel(grid, block, (fbeam, beam_lm_ext, beam_freq_map, - lm, parangles, pointing_errors, - antenna_scaling, frequencies, freq_data, - nsrc, out)) + kernel( + grid, + block, + ( + fbeam, + beam_lm_ext, + beam_freq_map, + lm, + parangles, + pointing_errors, + antenna_scaling, + frequencies, + freq_data, + nsrc, + out, + ), + ) except CompileException: log.exception(format_code(kernel.code)) raise @@ -202,6 +254,7 @@ def beam_cube_dde(beam, beam_lm_ext, beam_freq_map, try: beam_cube_dde.__doc__ = BEAM_CUBE_DOCS.substitute( - array_type=":class:`cupy.ndarray`") + array_type=":class:`cupy.ndarray`" + ) except AttributeError: pass diff --git a/africanus/rime/cuda/feeds.py b/africanus/rime/cuda/feeds.py index c4d03eddf..4b54acdd7 100644 --- a/africanus/rime/cuda/feeds.py +++ b/africanus/rime/cuda/feeds.py @@ -48,11 +48,13 @@ def _generate_kernel(parallactic_angles, feed_type): render = jinja_env.get_template(_TEMPLATE_PATH).render name = "feed_rotation" - code = render(kernel_name=name, - feed_type=feed_type, - sincos_fn=cuda_function('sincos', dtype), - pa_type=_get_typename(dtype), - out_type=_get_typename(dtype)) + code = render( + kernel_name=name, + feed_type=feed_type, + sincos_fn=cuda_function("sincos", dtype), + pa_type=_get_typename(dtype), + out_type=_get_typename(dtype), + ) # Complex output type out_dtype = np.result_type(dtype, np.complex64) @@ -60,8 +62,8 @@ def _generate_kernel(parallactic_angles, feed_type): @requires_optional("cupy", opt_import_error) -def feed_rotation(parallactic_angles, feed_type='linear'): - """ Cupy implementation of the feed_rotation kernel. """ +def feed_rotation(parallactic_angles, feed_type="linear"): + """Cupy implementation of the feed_rotation kernel.""" kernel, block, out_dtype = _generate_kernel(parallactic_angles, feed_type) in_shape = parallactic_angles.shape parallactic_angles = parallactic_angles.ravel() @@ -79,6 +81,7 @@ def feed_rotation(parallactic_angles, feed_type='linear'): try: feed_rotation.__doc__ = FEED_ROTATION_DOCS.substitute( - array_type=":class:`cupy.ndarray`") + array_type=":class:`cupy.ndarray`" + ) except AttributeError: pass diff --git a/africanus/rime/cuda/phase.py b/africanus/rime/cuda/phase.py index 87beed634..346e44141 100644 --- a/africanus/rime/cuda/phase.py +++ b/africanus/rime/cuda/phase.py @@ -44,16 +44,18 @@ def _generate_kernel(lm, uvw, frequency): render = jinja_env.get_template(_TEMPLATE_PATH).render name = "phase_delay" - code = render(kernel_name=name, - lm_type=_get_typename(lm.dtype), - uvw_type=_get_typename(uvw.dtype), - freq_type=_get_typename(frequency.dtype), - out_type=_get_typename(out_dtype), - sqrt_fn=cuda_function('sqrt', lm.dtype), - sincos_fn=cuda_function('sincos', out_dtype), - minus_two_pi_over_c=minus_two_pi_over_c, - blockdimx=blockdimx, - blockdimy=blockdimy) + code = render( + kernel_name=name, + lm_type=_get_typename(lm.dtype), + uvw_type=_get_typename(uvw.dtype), + freq_type=_get_typename(frequency.dtype), + out_type=_get_typename(out_dtype), + sqrt_fn=cuda_function("sqrt", lm.dtype), + sincos_fn=cuda_function("sincos", out_dtype), + minus_two_pi_over_c=minus_two_pi_over_c, + blockdimx=blockdimx, + blockdimy=blockdimy, + ) # Complex output type out_dtype = np.result_type(out_dtype, np.complex64) @@ -64,8 +66,9 @@ def _generate_kernel(lm, uvw, frequency): def phase_delay(lm, uvw, frequency): kernel, block, out_dtype = _generate_kernel(lm, uvw, frequency) grid = grids((frequency.shape[0], uvw.shape[0], 1), block) - out = cp.empty(shape=(lm.shape[0], uvw.shape[0], frequency.shape[0]), - dtype=out_dtype) + out = cp.empty( + shape=(lm.shape[0], uvw.shape[0], frequency.shape[0]), dtype=out_dtype + ) try: kernel(grid, block, (lm, uvw, frequency, out)) @@ -78,6 +81,7 @@ def phase_delay(lm, uvw, frequency): try: phase_delay.__doc__ = PHASE_DELAY_DOCS.substitute( - array_type=':class:`cupy.ndarray`') + array_type=":class:`cupy.ndarray`" + ) except AttributeError: pass diff --git a/africanus/rime/cuda/predict.py b/africanus/rime/cuda/predict.py index 9b28af5e6..4ca55ceef 100644 --- a/africanus/rime/cuda/predict.py +++ b/africanus/rime/cuda/predict.py @@ -8,7 +8,7 @@ import numpy as np -from africanus.rime.predict import (PREDICT_DOCS, predict_checks) +from africanus.rime.predict import PREDICT_DOCS, predict_checks from africanus.util.code import format_code, memoize_on_key from africanus.util.cuda import cuda_type, grids from africanus.util.jinja2 import jinja_env @@ -30,21 +30,38 @@ def _key_fn(*args): - """ Hash on array datatypes and rank """ - return tuple((a.dtype, a.ndim) - if isinstance(a, (np.ndarray, cp.ndarray)) - else a for a in args) + """Hash on array datatypes and rank""" + return tuple( + (a.dtype, a.ndim) if isinstance(a, (np.ndarray, cp.ndarray)) else a + for a in args + ) @memoize_on_key(_key_fn) -def _generate_kernel(time_index, antenna1, antenna2, - dde1_jones, source_coh, dde2_jones, - die1_jones, base_vis, die2_jones, - corrs, out_ndim): - - tup = predict_checks(time_index, antenna1, antenna2, - dde1_jones, source_coh, dde2_jones, - die1_jones, base_vis, die2_jones) +def _generate_kernel( + time_index, + antenna1, + antenna2, + dde1_jones, + source_coh, + dde2_jones, + die1_jones, + base_vis, + die2_jones, + corrs, + out_ndim, +): + tup = predict_checks( + time_index, + antenna1, + antenna2, + dde1_jones, + source_coh, + dde2_jones, + die1_jones, + base_vis, + die2_jones, + ) (have_ddes1, have_coh, have_ddes2, have_dies1, have_bvis, have_dies2) = tup @@ -63,8 +80,9 @@ def _generate_kernel(time_index, antenna1, antenna2, name = "predict_vis" # Complex output type - out_dtype = np.result_type(dde1_jones, source_coh, dde2_jones, - die1_jones, base_vis, die2_jones) + out_dtype = np.result_type( + dde1_jones, source_coh, dde2_jones, die1_jones, base_vis, die2_jones + ) ncorrs = reduce(mul, corrs, 1) @@ -74,38 +92,50 @@ def _generate_kernel(time_index, antenna1, antenna2, block = (blockdimx, blockdimy, 1) - code = render(kernel_name=name, blockdimx=blockdimx, blockdimy=blockdimy, - have_dde1=have_ddes1, - dde1_type=cuda_type(dde1_jones) if have_ddes1 else "int", - dde1_ndim=dde1_jones.ndim if have_ddes1 else 1, - have_dde2=have_ddes2, - dde2_type=cuda_type(dde2_jones) if have_ddes2 else "int", - dde2_ndim=dde2_jones.ndim if have_ddes2 else 1, - have_coh=have_coh, - coh_type=cuda_type(source_coh) if have_coh else "int", - coh_ndim=source_coh.ndim if have_coh else 1, - have_die1=have_dies1, - die1_type=cuda_type(die1_jones) if have_dies1 else "int", - die1_ndim=die1_jones.ndim if have_dies1 else 1, - have_base_vis=have_bvis, - base_vis_type=cuda_type(base_vis) if have_bvis else "int", - base_vis_ndim=base_vis.ndim if have_bvis else 1, - have_die2=have_dies2, - die2_type=cuda_type(die2_jones) if have_dies2 else "int", - die2_ndim=die2_jones.ndim if have_dies2 else 1, - out_type=cuda_type(out_dtype), - corrs=ncorrs, - out_ndim=out_ndim, - warp_size=32) + code = render( + kernel_name=name, + blockdimx=blockdimx, + blockdimy=blockdimy, + have_dde1=have_ddes1, + dde1_type=cuda_type(dde1_jones) if have_ddes1 else "int", + dde1_ndim=dde1_jones.ndim if have_ddes1 else 1, + have_dde2=have_ddes2, + dde2_type=cuda_type(dde2_jones) if have_ddes2 else "int", + dde2_ndim=dde2_jones.ndim if have_ddes2 else 1, + have_coh=have_coh, + coh_type=cuda_type(source_coh) if have_coh else "int", + coh_ndim=source_coh.ndim if have_coh else 1, + have_die1=have_dies1, + die1_type=cuda_type(die1_jones) if have_dies1 else "int", + die1_ndim=die1_jones.ndim if have_dies1 else 1, + have_base_vis=have_bvis, + base_vis_type=cuda_type(base_vis) if have_bvis else "int", + base_vis_ndim=base_vis.ndim if have_bvis else 1, + have_die2=have_dies2, + die2_type=cuda_type(die2_jones) if have_dies2 else "int", + die2_ndim=die2_jones.ndim if have_dies2 else 1, + out_type=cuda_type(out_dtype), + corrs=ncorrs, + out_ndim=out_ndim, + warp_size=32, + ) return cp.RawKernel(code, name), block, out_dtype @requires_optional("cupy", opt_import_error) -def predict_vis(time_index, antenna1, antenna2, - dde1_jones=None, source_coh=None, dde2_jones=None, - die1_jones=None, base_vis=None, die2_jones=None): - """ Cupy implementation of the feed_rotation kernel. """ +def predict_vis( + time_index, + antenna1, + antenna2, + dde1_jones=None, + source_coh=None, + dde2_jones=None, + die1_jones=None, + base_vis=None, + die2_jones=None, +): + """Cupy implementation of the feed_rotation kernel.""" have_ddes = dde1_jones is not None and dde2_jones is not None have_dies = die1_jones is not None and die2_jones is not None @@ -130,8 +160,9 @@ def predict_vis(time_index, antenna1, antenna2, chan = base_vis.shape[1] corrs = base_vis.shape[2:] else: - raise ValueError("Insufficient inputs supplied for determining " - "the output shape") + raise ValueError( + "Insufficient inputs supplied for determining " "the output shape" + ) ncorrs = len(corrs) @@ -164,19 +195,21 @@ def predict_vis(time_index, antenna1, antenna2, out_shape = (row, chan) + (flat_corrs,) - kernel, block, out_dtype = _generate_kernel(time_index, - antenna1, - antenna2, - dde1_jones, - source_coh, - dde2_jones, - die1_jones, - base_vis, - die2_jones, - corrs, - len(out_shape)) - - grid = grids((chan*flat_corrs, row, 1), block) + kernel, block, out_dtype = _generate_kernel( + time_index, + antenna1, + antenna2, + dde1_jones, + source_coh, + dde2_jones, + die1_jones, + base_vis, + die2_jones, + corrs, + len(out_shape), + ) + + grid = grids((chan * flat_corrs, row, 1), block) out = cp.empty(shape=out_shape, dtype=out_dtype) # Normalise the time index @@ -184,10 +217,18 @@ def predict_vis(time_index, antenna1, antenna2, # Normalise the time index with a device-wide reduction norm_time_index = time_index - time_index.min() - args = (norm_time_index, antenna1, antenna2, - dde1_jones, source_coh, dde2_jones, - die1_jones, base_vis, die2_jones, - out) + args = ( + norm_time_index, + antenna1, + antenna2, + dde1_jones, + source_coh, + dde2_jones, + die1_jones, + base_vis, + die2_jones, + out, + ) try: kernel(grid, block, tuple(a for a in args if a is not None)) @@ -200,10 +241,10 @@ def predict_vis(time_index, antenna1, antenna2, try: predict_vis.__doc__ = PREDICT_DOCS.substitute( - array_type=":class:`cupy.ndarray`", - get_time_index=":code:`cp.unique(time, " - "return_inverse=True)[1]`", - extra_args="", - extra_notes="") + array_type=":class:`cupy.ndarray`", + get_time_index=":code:`cp.unique(time, " "return_inverse=True)[1]`", + extra_args="", + extra_notes="", + ) except AttributeError: pass diff --git a/africanus/rime/cuda/test_shuffle.py b/africanus/rime/cuda/test_shuffle.py index 2e1acb9ba..8cd1d2da8 100644 --- a/africanus/rime/cuda/test_shuffle.py +++ b/africanus/rime/cuda/test_shuffle.py @@ -11,7 +11,8 @@ def test_cuda_shuffle_transpose(): cp = pytest.importorskip("cupy") jinja2 = pytest.importorskip("jinja2") - _TEMPLATE = jinja2.Template(""" + _TEMPLATE = jinja2.Template( + """ #include #define debug {{debug}} @@ -105,23 +106,25 @@ def test_cuda_shuffle_transpose(): output[v + {{corr}}*nvis] = values[{{corr}}]; {%- endfor %} } - """) + """ + ) nvis = 32 ncorrs = 4 dtype = np.int32 dtypes = { - np.float32: 'float', - np.float64: 'double', - np.int32: 'int', + np.float32: "float", + np.float64: "double", + np.int32: "int", } - code = _TEMPLATE.render(type=dtypes[dtype], warp_size=32, - corrs=ncorrs, debug="false") + code = _TEMPLATE.render( + type=dtypes[dtype], warp_size=32, corrs=ncorrs, debug="false" + ) kernel = cp.RawKernel(code, "kernel") - inputs = cp.arange(nvis*ncorrs, dtype=dtype).reshape(nvis, ncorrs) + inputs = cp.arange(nvis * ncorrs, dtype=dtype).reshape(nvis, ncorrs) outputs = cp.empty_like(inputs) args = (inputs, outputs) block = (256, 1, 1) @@ -133,8 +136,7 @@ def test_cuda_shuffle_transpose(): print(format_code(kernel.code)) raise - np.testing.assert_array_almost_equal(cp.asnumpy(inputs), - cp.asnumpy(outputs)) + np.testing.assert_array_almost_equal(cp.asnumpy(inputs), cp.asnumpy(outputs)) return # Dead code @@ -229,7 +231,8 @@ def test_cuda_shuffle_transpose_2(ncorrs): # https://homes.cs.washington.edu/~cdel/papers/sc13-shuffle-abstract.pdf # http://sc13.supercomputing.org/sites/default/files/PostersArchive/spost142.html - _TEMPLATE = jinja2.Template(""" + _TEMPLATE = jinja2.Template( + """ #include {%- if (corrs < 1 or (corrs.__and__(corrs - 1) != 0)) %} @@ -329,26 +332,29 @@ def test_cuda_shuffle_transpose_2(ncorrs): output[v + {{corr}}*nvis] = values[{{corr}}]; {%- endfor %} } - """) # noqa + """ + ) # noqa nvis = 32 dtype = np.int32 dtypes = { - np.float32: 'float', - np.float64: 'double', - np.int32: 'int', + np.float32: "float", + np.float64: "double", + np.int32: "int", } - code = _TEMPLATE.render(type=dtypes[dtype], - throw=throw_helper, - register_assign_cycles=register_assign_cycles, - warp_size=32, - corrs=ncorrs, - debug="false") + code = _TEMPLATE.render( + type=dtypes[dtype], + throw=throw_helper, + register_assign_cycles=register_assign_cycles, + warp_size=32, + corrs=ncorrs, + debug="false", + ) kernel = cp.RawKernel(code, "kernel") - inputs = cp.arange(nvis*ncorrs, dtype=dtype).reshape(nvis, ncorrs) + inputs = cp.arange(nvis * ncorrs, dtype=dtype).reshape(nvis, ncorrs) outputs = cp.empty_like(inputs) args = (inputs, outputs) block = (256, 1, 1) @@ -360,8 +366,7 @@ def test_cuda_shuffle_transpose_2(ncorrs): print(format_code(kernel.code)) raise - np.testing.assert_array_almost_equal(cp.asnumpy(inputs), - cp.asnumpy(outputs)) + np.testing.assert_array_almost_equal(cp.asnumpy(inputs), cp.asnumpy(outputs)) return # Dead code diff --git a/africanus/rime/cuda/tests/test_cuda_beam.py b/africanus/rime/cuda/tests/test_cuda_beam.py index 5332f3702..544073c3a 100644 --- a/africanus/rime/cuda/tests/test_cuda_beam.py +++ b/africanus/rime/cuda/tests/test_cuda_beam.py @@ -8,7 +8,7 @@ from africanus.rime import beam_cube_dde as np_beam_cube_dde from africanus.rime.cuda.beam import beam_cube_dde as cp_beam_cude_dde -cp = pytest.importorskip('cupy') +cp = pytest.importorskip("cupy") @pytest.mark.parametrize("corrs", [(2, 2), (4,), (2,), (1,)]) @@ -18,35 +18,39 @@ def test_cuda_beam(corrs): src, time, ant, chan = 20, 29, 14, 64 beam_lw = beam_mh = beam_nud = 50 - beam = (rs.normal(size=(beam_lw, beam_mh, beam_nud) + corrs) + - rs.normal(size=(beam_lw, beam_mh, beam_nud) + corrs)*1j) + beam = ( + rs.normal(size=(beam_lw, beam_mh, beam_nud) + corrs) + + rs.normal(size=(beam_lw, beam_mh, beam_nud) + corrs) * 1j + ) beam_lm_ext = np.array(([[-0.5, 0.5], [-0.5, 0.5]])) lm = rs.normal(size=(src, 2)) - 0.5 if chan == 1: - freqs = np.array([.856e9*3 / 2]) + freqs = np.array([0.856e9 * 3 / 2]) else: - freqs = np.linspace(.856e9, 2*.856e9, chan) + freqs = np.linspace(0.856e9, 2 * 0.856e9, chan) - beam_freq_map = np.linspace(.856e9, 2*.856e9, beam_nud) + beam_freq_map = np.linspace(0.856e9, 2 * 0.856e9, beam_nud) parangles = rs.normal(size=(time, ant)) point_errors = rs.normal(size=(time, ant, chan, 2)) ant_scales = rs.normal(size=(ant, chan, 2)) - np_ddes = np_beam_cube_dde(beam, beam_lm_ext, beam_freq_map, - lm, parangles, point_errors, ant_scales, - freqs) - - cp_ddes = cp_beam_cude_dde(cp.asarray(beam), - cp.asarray(beam_lm_ext), - cp.asarray(beam_freq_map), - cp.asarray(lm), - cp.asarray(parangles), - cp.asarray(point_errors), - cp.asarray(ant_scales), - cp.asarray(freqs)) + np_ddes = np_beam_cube_dde( + beam, beam_lm_ext, beam_freq_map, lm, parangles, point_errors, ant_scales, freqs + ) + + cp_ddes = cp_beam_cude_dde( + cp.asarray(beam), + cp.asarray(beam_lm_ext), + cp.asarray(beam_freq_map), + cp.asarray(lm), + cp.asarray(parangles), + cp.asarray(point_errors), + cp.asarray(ant_scales), + cp.asarray(freqs), + ) assert_array_almost_equal(np_ddes, cp.asnumpy(cp_ddes)) diff --git a/africanus/rime/cuda/tests/test_cuda_feed_rotation.py b/africanus/rime/cuda/tests/test_cuda_feed_rotation.py index 95d94819f..ac7a8f65c 100644 --- a/africanus/rime/cuda/tests/test_cuda_feed_rotation.py +++ b/africanus/rime/cuda/tests/test_cuda_feed_rotation.py @@ -12,7 +12,7 @@ @pytest.mark.parametrize("shape", [(10, 7), (8,)]) @pytest.mark.parametrize("dtype", [np.float32, np.float64]) def test_cuda_feed_rotation(feed_type, shape, dtype): - cp = pytest.importorskip('cupy') + cp = pytest.importorskip("cupy") pa = np.random.random(shape).astype(dtype) diff --git a/africanus/rime/cuda/tests/test_cuda_phase_delay.py b/africanus/rime/cuda/tests/test_cuda_phase_delay.py index cd801a909..f2a123f44 100644 --- a/africanus/rime/cuda/tests/test_cuda_phase_delay.py +++ b/africanus/rime/cuda/tests/test_cuda_phase_delay.py @@ -8,21 +8,17 @@ from africanus.rime.cuda.phase import phase_delay as cp_phase_delay -@pytest.mark.parametrize("dtype, decimal", [ - (np.float32, 5), - (np.float64, 6) -]) +@pytest.mark.parametrize("dtype, decimal", [(np.float32, 5), (np.float64, 6)]) def test_cuda_phase_delay(dtype, decimal): - cp = pytest.importorskip('cupy') + cp = pytest.importorskip("cupy") - lm = 0.01*np.random.random((10, 2)).astype(dtype) + lm = 0.01 * np.random.random((10, 2)).astype(dtype) uvw = np.random.random((100, 3)).astype(dtype) - freq = np.linspace(.856e9, 2*.856e9, 70, dtype=dtype) + freq = np.linspace(0.856e9, 2 * 0.856e9, 70, dtype=dtype) - cp_cplx_phase = cp_phase_delay(cp.asarray(lm), - cp.asarray(uvw), - cp.asarray(freq)) + cp_cplx_phase = cp_phase_delay(cp.asarray(lm), cp.asarray(uvw), cp.asarray(freq)) np_cplx_phase = np_phase_delay(lm, uvw, freq) - np.testing.assert_array_almost_equal(cp.asnumpy(cp_cplx_phase), - np_cplx_phase, decimal=decimal) + np.testing.assert_array_almost_equal( + cp.asnumpy(cp_cplx_phase), np_cplx_phase, decimal=decimal + ) diff --git a/africanus/rime/cuda/tests/test_cuda_predict.py b/africanus/rime/cuda/tests/test_cuda_predict.py index 3faedb1f7..2d72f5c01 100644 --- a/africanus/rime/cuda/tests/test_cuda_predict.py +++ b/africanus/rime/cuda/tests/test_cuda_predict.py @@ -6,29 +6,31 @@ from africanus.rime.predict import predict_vis as np_predict_vis from africanus.rime.cuda.predict import predict_vis -from africanus.rime.tests.test_predict import (corr_shape_parametrization, - die_presence_parametrization, - dde_presence_parametrization, - chunk_parametrization, - rc) +from africanus.rime.tests.test_predict import ( + corr_shape_parametrization, + die_presence_parametrization, + dde_presence_parametrization, + chunk_parametrization, + rc, +) @corr_shape_parametrization @dde_presence_parametrization @die_presence_parametrization @chunk_parametrization -def test_cuda_predict_vis(corr_shape, idm, einsum_sig1, einsum_sig2, - a1j, blj, a2j, g1j, bvis, g2j, - chunks): +def test_cuda_predict_vis( + corr_shape, idm, einsum_sig1, einsum_sig2, a1j, blj, a2j, g1j, bvis, g2j, chunks +): np.random.seed(40) - cp = pytest.importorskip('cupy') + cp = pytest.importorskip("cupy") - s = sum(chunks['source']) - t = sum(chunks['time']) - a = sum(chunks['antenna']) - c = sum(chunks['channels']) - r = sum(chunks['rows']) + s = sum(chunks["source"]) + t = sum(chunks["time"]) + a = sum(chunks["antenna"]) + c = sum(chunks["channels"]) + r = sum(chunks["rows"]) a1_jones = rc((s, t, a, c) + corr_shape) bl_jones = rc((s, r, c) + corr_shape) @@ -38,35 +40,42 @@ def test_cuda_predict_vis(corr_shape, idm, einsum_sig1, einsum_sig2, g2_jones = rc((t, a, c) + corr_shape) # Add 10 to the index to test time index normalisation - time_idx = np.concatenate([np.full(rows, i+10, dtype=np.int32) - for i, rows in enumerate(chunks['rows'])]) + time_idx = np.concatenate( + [np.full(rows, i + 10, dtype=np.int32) for i, rows in enumerate(chunks["rows"])] + ) - ant1 = np.concatenate([np.random.randint(0, a, rows, dtype=np.int32) - for rows in chunks['rows']]) + ant1 = np.concatenate( + [np.random.randint(0, a, rows, dtype=np.int32) for rows in chunks["rows"]] + ) - ant2 = np.concatenate([np.random.randint(0, a, rows, dtype=np.int32) - for rows in chunks['rows']]) + ant2 = np.concatenate( + [np.random.randint(0, a, rows, dtype=np.int32) for rows in chunks["rows"]] + ) assert ant1.size == r - model_vis = predict_vis(cp.asarray(time_idx), - cp.asarray(ant1), - cp.asarray(ant2), - cp.asarray(a1_jones) if a1j else None, - cp.asarray(bl_jones) if blj else None, - cp.asarray(a2_jones) if a2j else None, - cp.asarray(g1_jones) if g1j else None, - cp.asarray(base_vis) if bvis else None, - cp.asarray(g2_jones) if g2j else None) + model_vis = predict_vis( + cp.asarray(time_idx), + cp.asarray(ant1), + cp.asarray(ant2), + cp.asarray(a1_jones) if a1j else None, + cp.asarray(bl_jones) if blj else None, + cp.asarray(a2_jones) if a2j else None, + cp.asarray(g1_jones) if g1j else None, + cp.asarray(base_vis) if bvis else None, + cp.asarray(g2_jones) if g2j else None, + ) - np_model_vis = np_predict_vis(time_idx, - ant1, - ant2, - a1_jones if a1j else None, - bl_jones if blj else None, - a2_jones if a2j else None, - g1_jones if g1j else None, - base_vis if bvis else None, - g2_jones if g2j else None) + np_model_vis = np_predict_vis( + time_idx, + ant1, + ant2, + a1_jones if a1j else None, + bl_jones if blj else None, + a2_jones if a2j else None, + g1_jones if g1j else None, + base_vis if bvis else None, + g2_jones if g2j else None, + ) np.testing.assert_array_almost_equal(cp.asnumpy(model_vis), np_model_vis) diff --git a/africanus/rime/cuda/tests/test_macros.py b/africanus/rime/cuda/tests/test_macros.py index 1387ea3ba..9c3ebc751 100644 --- a/africanus/rime/cuda/tests/test_macros.py +++ b/africanus/rime/cuda/tests/test_macros.py @@ -11,29 +11,29 @@ @pytest.mark.parametrize("ncorrs", [1, 2, 4, 8]) -@pytest.mark.parametrize("dtype", [np.int32, np.float64, np.float32, - np.complex64, np.complex128]) +@pytest.mark.parametrize( + "dtype", [np.int32, np.float64, np.float32, np.complex64, np.complex128] +) @pytest.mark.parametrize("nvis", [9, 10, 11, 32, 1025]) @pytest.mark.parametrize("debug", ["false"]) def test_cuda_inplace_warp_transpose(ncorrs, dtype, nvis, debug): - cp = pytest.importorskip('cupy') + cp = pytest.importorskip("cupy") path = pjoin("rime", "cuda", "tests", "test_warp_transpose.cu.j2") render = jinja_env.get_template(path).render dtypes = { - np.float32: 'float', - np.float64: 'double', - np.int32: 'int', - np.complex64: 'float2', - np.complex128: 'double2', + np.float32: "float", + np.float64: "double", + np.int32: "int", + np.complex64: "float2", + np.complex128: "double2", } - code = render(type=dtypes[dtype], warp_size=32, - corrs=ncorrs, debug=debug) + code = render(type=dtypes[dtype], warp_size=32, corrs=ncorrs, debug=debug) kernel = cp.RawKernel(code, "kernel") - inputs = cp.arange(nvis*ncorrs, dtype=dtype).reshape(nvis, ncorrs) + inputs = cp.arange(nvis * ncorrs, dtype=dtype).reshape(nvis, ncorrs) outputs = cp.empty_like(inputs) args = (inputs, outputs) block = (256, 1, 1) @@ -45,5 +45,4 @@ def test_cuda_inplace_warp_transpose(ncorrs, dtype, nvis, debug): print(format_code(kernel.code)) raise - np.testing.assert_array_almost_equal(cp.asnumpy(inputs), - cp.asnumpy(outputs)) + np.testing.assert_array_almost_equal(cp.asnumpy(inputs), cp.asnumpy(outputs)) diff --git a/africanus/rime/dask.py b/africanus/rime/dask.py index 25c775e3e..80b2b8f1c 100644 --- a/africanus/rime/dask.py +++ b/africanus/rime/dask.py @@ -37,7 +37,7 @@ def _phase_delay_wrap(lm, uvw, frequency, convention): @requires_optional("dask.array", da_import_error) def phase_delay(lm, uvw, frequency, convention="fourier"): - """ Dask wrapper for phase_delay function """ + """Dask wrapper for phase_delay function""" return da.core.blockwise( _phase_delay_wrap, ("source", "row", "chan"), @@ -58,7 +58,6 @@ def _parangle_wrapper(t, ap, fc, **kw): @requires_optional("dask.array", da_import_error) def parallactic_angles(times, antenna_positions, field_centre, **kwargs): - return da.core.blockwise( _parangle_wrapper, ("time", "ant"), @@ -69,7 +68,7 @@ def parallactic_angles(times, antenna_positions, field_centre, **kwargs): field_centre, ("fc",), dtype=times.dtype, - **kwargs + **kwargs, ) @@ -118,7 +117,6 @@ def transform_sources( frequency, dtype=None, ): - if dtype is None: dtype = np.float64 @@ -176,7 +174,6 @@ def beam_cube_dde( antenna_scaling, frequencies, ): - if not all(len(c) == 1 for c in beam.chunks): raise ValueError("Beam chunking unsupported") diff --git a/africanus/rime/dask_predict.py b/africanus/rime/dask_predict.py index e25b2ba13..194871207 100644 --- a/africanus/rime/dask_predict.py +++ b/africanus/rime/dask_predict.py @@ -13,11 +13,15 @@ from africanus.util.requirements import requires_optional -from africanus.rime.predict import (PREDICT_DOCS, predict_checks, - predict_vis as np_predict_vis) +from africanus.rime.predict import ( + PREDICT_DOCS, + predict_checks, + predict_vis as np_predict_vis, +) from africanus.rime.wsclean_predict import ( - WSCLEAN_PREDICT_DOCS, - wsclean_predict_main as wsclean_predict_body) + WSCLEAN_PREDICT_DOCS, + wsclean_predict_main as wsclean_predict_body, +) from africanus.model.wsclean.spec_model import spectra as wsclean_spectra @@ -69,8 +73,9 @@ def __init__( ): self.func = func self.output_indices = tuple(output_indices) - self.indices = tuple((name, tuple(ind) if ind is not None else ind) - for name, ind in indices) + self.indices = tuple( + (name, tuple(ind) if ind is not None else ind) for name, ind in indices + ) self.numblocks = numblocks if axis is None: @@ -82,12 +87,14 @@ def __init__( self.feed_index = feed_index self.axis = axis - token = tokenize(self.func, - self.output_indices, - self.indices, - self.numblocks, - self.feed_index, - self.axis) + token = tokenize( + self.func, + self.output_indices, + self.indices, + self.numblocks, + self.feed_index, + self.axis, + ) self.func_name = funcname(self.func) self.name = "-".join((self.func_name, token)) @@ -109,9 +116,7 @@ def _dict(self): dim_map = {k: i for i, k in enumerate(out_dims)} dsk = {} - int_name = "-".join((self.func_name, - "intermediate", - tokenize(self.name))) + int_name = "-".join((self.func_name, "intermediate", tokenize(self.name))) # Iterate over the output keys creating associated task for out_ind in product(*[range(dim_blocks[d]) for d in out_dims]): @@ -125,19 +130,23 @@ def _dict(self): # Otherwise feed in the result of the last operation else: - task.append((int_name,) + - # Index last reduction block - # always in first axis - (out_ind[0] - 1,) + - out_ind[1:]) + task.append( + (int_name,) + + + # Index last reduction block + # always in first axis + (out_ind[0] - 1,) + + out_ind[1:] + ) elif ind is None: # Literal arg, embed task.append(arg) else: # Derive input key from output key indices - task.append(tuple(_ind_map(arg, ind, out_ind, - dim_map, dim_blocks))) + task.append( + tuple(_ind_map(arg, ind, out_ind, dim_map, dim_blocks)) + ) # Final block if out_ind[0] == last_block: @@ -169,12 +178,24 @@ def _out_numblocks(self): return {k: v for k, v in d.items() if k in self.output_indices} -def linear_reduction(time_index, antenna1, antenna2, - dde1_jones, source_coh, dde2_jones, - predict_check_tup, out_dtype): - - (have_ddes1, have_coh, have_ddes2, - have_dies1, have_bvis, have_dies2) = predict_check_tup +def linear_reduction( + time_index, + antenna1, + antenna2, + dde1_jones, + source_coh, + dde2_jones, + predict_check_tup, + out_dtype, +): + ( + have_ddes1, + have_coh, + have_ddes2, + have_dies1, + have_bvis, + have_dies2, + ) = predict_check_tup have_ddes = have_ddes1 and have_ddes2 @@ -185,63 +206,83 @@ def linear_reduction(time_index, antenna1, antenna2, else: raise ValueError("need ddes or source coherencies") - args = [(time_index, ("row",)), - (antenna1, ("row",)), - (antenna2, ("row",)), - (dde1_jones, ("source", "row", "ant", "chan") + cdims), - (source_coh, ("source", "row", "chan") + cdims), - (dde2_jones, ("source", "row", "ant", "chan") + cdims), - (None, None), - (None, None), - (None, None)] - - name_args = [(None, None) if a is None else - (a.name, i) if isinstance(a, da.Array) else - (a, i) for a, i in args] - - numblocks = {a.name: a.numblocks - for a, i in args - if a is not None} - - lr = LinearReduction(np_predict_vis, ("row", "chan") + cdims, - name_args, - numblocks=numblocks, - feed_index=7, - axis='source') - - graph = HighLevelGraph.from_collections(lr.name, lr, - [a for a, i in args - if a is not None]) - - chunk_map = {d: arg.chunks[i] for arg, ind in args - if arg is not None and ind is not None - for i, d in enumerate(ind)} - chunk_map['row'] = time_index.chunks[0] # Override - - chunks = tuple(chunk_map[d] for d in ('row', 'chan') + cdims) + args = [ + (time_index, ("row",)), + (antenna1, ("row",)), + (antenna2, ("row",)), + (dde1_jones, ("source", "row", "ant", "chan") + cdims), + (source_coh, ("source", "row", "chan") + cdims), + (dde2_jones, ("source", "row", "ant", "chan") + cdims), + (None, None), + (None, None), + (None, None), + ] + + name_args = [ + (None, None) + if a is None + else (a.name, i) + if isinstance(a, da.Array) + else (a, i) + for a, i in args + ] + + numblocks = {a.name: a.numblocks for a, i in args if a is not None} + + lr = LinearReduction( + np_predict_vis, + ("row", "chan") + cdims, + name_args, + numblocks=numblocks, + feed_index=7, + axis="source", + ) + + graph = HighLevelGraph.from_collections( + lr.name, lr, [a for a, i in args if a is not None] + ) + + chunk_map = { + d: arg.chunks[i] + for arg, ind in args + if arg is not None and ind is not None + for i, d in enumerate(ind) + } + chunk_map["row"] = time_index.chunks[0] # Override + + chunks = tuple(chunk_map[d] for d in ("row", "chan") + cdims) return da.Array(graph, lr.name, chunks, dtype=out_dtype) -def _predict_coh_wrapper(time_index, antenna1, antenna2, - dde1_jones, source_coh, dde2_jones, - base_vis, - reduce_single_source=False): - +def _predict_coh_wrapper( + time_index, + antenna1, + antenna2, + dde1_jones, + source_coh, + dde2_jones, + base_vis, + reduce_single_source=False, +): if reduce_single_source: # All these arrays contract over a single 'source' chunk dde1_jones = dde1_jones[0] if dde1_jones else None source_coh = source_coh[0] if source_coh else None dde2_jones = dde2_jones[0] if dde2_jones else None - vis = np_predict_vis(time_index, antenna1, antenna2, - # dde1_jones contracts over a single 'ant' chunk - dde1_jones[0] if dde1_jones else None, - source_coh, - # dde2_jones contracts over a single 'ant' chunk - dde2_jones[0] if dde2_jones else None, - None, - base_vis, - None) + vis = np_predict_vis( + time_index, + antenna1, + antenna2, + # dde1_jones contracts over a single 'ant' chunk + dde1_jones[0] if dde1_jones else None, + source_coh, + # dde2_jones contracts over a single 'ant' chunk + dde2_jones[0] if dde2_jones else None, + None, + base_vis, + None, + ) if reduce_single_source: return vis @@ -249,26 +290,43 @@ def _predict_coh_wrapper(time_index, antenna1, antenna2, return vis[None, ...] -def _predict_dies_wrapper(time_index, antenna1, antenna2, - die1_jones, base_vis, die2_jones): - - return np_predict_vis(time_index, antenna1, antenna2, - None, - None, - None, - # die1_jones loses the 'ant' dim - die1_jones[0] if die1_jones else None, - base_vis, - # die2_jones loses the 'ant' dim - die2_jones[0] if die2_jones else None) - - -def parallel_reduction(time_index, antenna1, antenna2, - dde1_jones, source_coh, dde2_jones, - predict_check_tup, out_dtype): - """ Does a standard dask tree reduction over source coherencies """ - (have_ddes1, have_coh, have_ddes2, - have_dies1, have_bvis, have_dies2) = predict_check_tup +def _predict_dies_wrapper( + time_index, antenna1, antenna2, die1_jones, base_vis, die2_jones +): + return np_predict_vis( + time_index, + antenna1, + antenna2, + None, + None, + None, + # die1_jones loses the 'ant' dim + die1_jones[0] if die1_jones else None, + base_vis, + # die2_jones loses the 'ant' dim + die2_jones[0] if die2_jones else None, + ) + + +def parallel_reduction( + time_index, + antenna1, + antenna2, + dde1_jones, + source_coh, + dde2_jones, + predict_check_tup, + out_dtype, +): + """Does a standard dask tree reduction over source coherencies""" + ( + have_ddes1, + have_coh, + have_ddes2, + have_dies1, + have_bvis, + have_dies2, + ) = predict_check_tup have_ddes = have_ddes1 and have_ddes2 @@ -283,33 +341,55 @@ def parallel_reduction(time_index, antenna1, antenna2, src_coh_dims = ("src", "row", "chan") + cdims coherencies = da.blockwise( - _predict_coh_wrapper, src_coh_dims, - time_index, ("row",), - antenna1, ("row",), - antenna2, ("row",), - dde1_jones, None if dde1_jones is None else ajones_dims, - source_coh, None if source_coh is None else src_coh_dims, - dde2_jones, None if dde2_jones is None else ajones_dims, - None, None, + _predict_coh_wrapper, + src_coh_dims, + time_index, + ("row",), + antenna1, + ("row",), + antenna2, + ("row",), + dde1_jones, + None if dde1_jones is None else ajones_dims, + source_coh, + None if source_coh is None else src_coh_dims, + dde2_jones, + None if dde2_jones is None else ajones_dims, + None, + None, # time+row dimension chunks are equivalent but differently sized align_arrays=False, # Force row dimension to take row chunking scheme, # instead of time chunking scheme - adjust_chunks={'row': time_index.chunks[0]}, - meta=np.empty((0,)*len(src_coh_dims), dtype=out_dtype), - dtype=out_dtype) + adjust_chunks={"row": time_index.chunks[0]}, + meta=np.empty((0,) * len(src_coh_dims), dtype=out_dtype), + dtype=out_dtype, + ) return coherencies.sum(axis=0) -def apply_dies(time_index, antenna1, antenna2, - die1_jones, base_vis, die2_jones, - predict_check_tup, out_dtype): - """ Apply any Direction-Independent Effects and Base Visibilities """ +def apply_dies( + time_index, + antenna1, + antenna2, + die1_jones, + base_vis, + die2_jones, + predict_check_tup, + out_dtype, +): + """Apply any Direction-Independent Effects and Base Visibilities""" # Now apply any Direction Independent Effect Terms - (have_ddes1, have_coh, have_ddes2, - have_dies1, have_bvis, have_dies2) = predict_check_tup + ( + have_ddes1, + have_coh, + have_ddes2, + have_dies1, + have_bvis, + have_dies2, + ) = predict_check_tup have_dies = have_dies1 and have_dies2 @@ -335,78 +415,119 @@ def apply_dies(time_index, antenna1, antenna2, vis_dims = ("row", "chan") + cdims return da.blockwise( - _predict_dies_wrapper, vis_dims, - time_index, ("row",), - antenna1, ("row",), - antenna2, ("row",), - die1_jones, None if die1_jones is None else gjones_dims, - base_vis, None if base_vis is None else vis_dims, - die2_jones, None if die2_jones is None else gjones_dims, + _predict_dies_wrapper, + vis_dims, + time_index, + ("row",), + antenna1, + ("row",), + antenna2, + ("row",), + die1_jones, + None if die1_jones is None else gjones_dims, + base_vis, + None if base_vis is None else vis_dims, + die2_jones, + None if die2_jones is None else gjones_dims, # time+row dimension chunks are equivalent but differently sized align_arrays=False, # Force row dimension to take row chunking scheme, # instead of time chunking scheme - adjust_chunks={'row': time_index.chunks[0]}, - meta=np.empty((0,)*len(vis_dims), dtype=out_dtype), - dtype=out_dtype) - - -@requires_optional('dask.array', opt_import_error) -def predict_vis(time_index, antenna1, antenna2, - dde1_jones=None, source_coh=None, dde2_jones=None, - die1_jones=None, base_vis=None, die2_jones=None, - streams=None): - - predict_check_tup = predict_checks(time_index, antenna1, antenna2, - dde1_jones, source_coh, dde2_jones, - die1_jones, base_vis, die2_jones) - - (have_ddes1, have_coh, have_ddes2, - have_dies1, have_bvis, have_dies2) = predict_check_tup + adjust_chunks={"row": time_index.chunks[0]}, + meta=np.empty((0,) * len(vis_dims), dtype=out_dtype), + dtype=out_dtype, + ) + + +@requires_optional("dask.array", opt_import_error) +def predict_vis( + time_index, + antenna1, + antenna2, + dde1_jones=None, + source_coh=None, + dde2_jones=None, + die1_jones=None, + base_vis=None, + die2_jones=None, + streams=None, +): + predict_check_tup = predict_checks( + time_index, + antenna1, + antenna2, + dde1_jones, + source_coh, + dde2_jones, + die1_jones, + base_vis, + die2_jones, + ) + + ( + have_ddes1, + have_coh, + have_ddes2, + have_dies1, + have_bvis, + have_dies2, + ) = predict_check_tup have_ddes = have_ddes1 and have_ddes2 if have_ddes: if dde1_jones.shape[2] != dde1_jones.chunks[2][0]: - raise ValueError("Subdivision of antenna dimension into " - "multiple chunks is not supported.") + raise ValueError( + "Subdivision of antenna dimension into " + "multiple chunks is not supported." + ) if dde2_jones.shape[2] != dde2_jones.chunks[2][0]: - raise ValueError("Subdivision of antenna dimension into " - "multiple chunks is not supported.") + raise ValueError( + "Subdivision of antenna dimension into " + "multiple chunks is not supported." + ) if dde1_jones.chunks != dde2_jones.chunks: raise ValueError("dde1_jones.chunks != dde2_jones.chunks") if len(dde1_jones.chunks[1]) != len(time_index.chunks[0]): - raise ValueError("Number of row chunks (%s) does not equal " - "number of time chunks (%s)." % - (time_index.chunks[0], dde1_jones.chunks[1])) + raise ValueError( + "Number of row chunks (%s) does not equal " + "number of time chunks (%s)." + % (time_index.chunks[0], dde1_jones.chunks[1]) + ) have_dies = have_dies1 and have_dies2 if have_dies: if die1_jones.shape[1] != die1_jones.chunks[1][0]: - raise ValueError("Subdivision of antenna dimension into " - "multiple chunks is not supported.") + raise ValueError( + "Subdivision of antenna dimension into " + "multiple chunks is not supported." + ) if die2_jones.shape[1] != die2_jones.chunks[1][0]: - raise ValueError("Subdivision of antenna dimension into " - "multiple chunks is not supported.") + raise ValueError( + "Subdivision of antenna dimension into " + "multiple chunks is not supported." + ) if die1_jones.chunks != die2_jones.chunks: raise ValueError("die1_jones.chunks != die2_jones.chunks") if len(die1_jones.chunks[0]) != len(time_index.chunks[0]): - raise ValueError("Number of row chunks (%s) does not equal " - "number of time chunks (%s)." % - (time_index.chunks[0], die1_jones.chunks[1])) + raise ValueError( + "Number of row chunks (%s) does not equal " + "number of time chunks (%s)." + % (time_index.chunks[0], die1_jones.chunks[1]) + ) # Infer the output dtype dtype_arrays = [dde1_jones, source_coh, dde2_jones, die1_jones, die2_jones] - out_dtype = np.result_type(*(np.dtype(a.dtype.name) - for a in dtype_arrays - if a is not None)) + out_dtype = np.result_type( + *(np.dtype(a.dtype.name) for a in dtype_arrays if a is not None) + ) # Apply direction dependent effects if have_coh or have_ddes: @@ -414,23 +535,27 @@ def predict_vis(time_index, antenna1, antenna2, # the gains because coherencies are chunked over source which # must be summed and added to the (possibly present) base visibilities if streams is True: - sum_coherencies = linear_reduction(time_index, - antenna1, - antenna2, - dde1_jones, - source_coh, - dde2_jones, - predict_check_tup, - out_dtype) + sum_coherencies = linear_reduction( + time_index, + antenna1, + antenna2, + dde1_jones, + source_coh, + dde2_jones, + predict_check_tup, + out_dtype, + ) else: - sum_coherencies = parallel_reduction(time_index, - antenna1, - antenna2, - dde1_jones, - source_coh, - dde2_jones, - predict_check_tup, - out_dtype) + sum_coherencies = parallel_reduction( + time_index, + antenna1, + antenna2, + dde1_jones, + source_coh, + dde2_jones, + predict_check_tup, + out_dtype, + ) else: assert have_dies or have_bvis sum_coherencies = None @@ -444,57 +569,91 @@ def predict_vis(time_index, antenna1, antenna2, if not have_bvis: # Set base_vis = summed coherencies base_vis = sum_coherencies - predict_check_tup = (have_ddes1, have_coh, have_ddes2, - have_dies1, True, have_dies2) + predict_check_tup = ( + have_ddes1, + have_coh, + have_ddes2, + have_dies1, + True, + have_dies2, + ) else: base_vis += sum_coherencies # Apply direction independent effects - return apply_dies(time_index, antenna1, antenna2, - die1_jones, base_vis, die2_jones, - predict_check_tup, out_dtype) + return apply_dies( + time_index, + antenna1, + antenna2, + die1_jones, + base_vis, + die2_jones, + predict_check_tup, + out_dtype, + ) def wsclean_spectrum_wrapper(flux, coeffs, log_poly, ref_freq, frequency): return wsclean_spectra(flux, coeffs[0], log_poly, ref_freq, frequency) -def wsclean_body_wrapper(uvw, lm, source_type, gauss_shape, - frequency, spectrum, dtype_): - return wsclean_predict_body(uvw[0], lm[0], source_type, - gauss_shape[0], frequency, spectrum, - dtype_)[None, :] - - -@requires_optional('dask.array', opt_import_error) -def wsclean_predict(uvw, lm, source_type, flux, coeffs, - log_poly, ref_freq, gauss_shape, frequency): - spectrum_dtype = np.result_type(*(a.dtype for a in (flux, coeffs, - log_poly, ref_freq, - frequency))) - - spectrum = da.blockwise(wsclean_spectrum_wrapper, ("source", "chan"), - flux, ("source",), - coeffs, ("source", "comp"), - log_poly, ("source",), - ref_freq, ("source",), - frequency, ("chan",), - dtype=spectrum_dtype) - - out_dtype = np.result_type(uvw.dtype, lm.dtype, frequency.dtype, - spectrum.dtype, np.complex64) - - vis = da.blockwise(wsclean_body_wrapper, ("source", "row", "chan", "corr"), - uvw, ("row", "uvw"), - lm, ("source", "lm"), - source_type, ("source",), - gauss_shape, ("source", "gauss"), - frequency, ("chan",), - spectrum, ("source", "chan"), - out_dtype, None, - adjust_chunks={"source": 1}, - new_axes={"corr": 1}, - dtype=out_dtype) +def wsclean_body_wrapper( + uvw, lm, source_type, gauss_shape, frequency, spectrum, dtype_ +): + return wsclean_predict_body( + uvw[0], lm[0], source_type, gauss_shape[0], frequency, spectrum, dtype_ + )[None, :] + + +@requires_optional("dask.array", opt_import_error) +def wsclean_predict( + uvw, lm, source_type, flux, coeffs, log_poly, ref_freq, gauss_shape, frequency +): + spectrum_dtype = np.result_type( + *(a.dtype for a in (flux, coeffs, log_poly, ref_freq, frequency)) + ) + + spectrum = da.blockwise( + wsclean_spectrum_wrapper, + ("source", "chan"), + flux, + ("source",), + coeffs, + ("source", "comp"), + log_poly, + ("source",), + ref_freq, + ("source",), + frequency, + ("chan",), + dtype=spectrum_dtype, + ) + + out_dtype = np.result_type( + uvw.dtype, lm.dtype, frequency.dtype, spectrum.dtype, np.complex64 + ) + + vis = da.blockwise( + wsclean_body_wrapper, + ("source", "row", "chan", "corr"), + uvw, + ("row", "uvw"), + lm, + ("source", "lm"), + source_type, + ("source",), + gauss_shape, + ("source", "gauss"), + frequency, + ("chan",), + spectrum, + ("source", "chan"), + out_dtype, + None, + adjust_chunks={"source": 1}, + new_axes={"corr": 1}, + dtype=out_dtype, + ) return vis.sum(axis=0) @@ -568,14 +727,16 @@ def wsclean_predict(uvw, lm, source_type, flux, coeffs, try: predict_vis.__doc__ = PREDICT_DOCS.substitute( - array_type=":class:`dask.array.Array`", - get_time_index=":code:`time.map_blocks(" - "lambda a: np.unique(a, " - "return_inverse=True)[1])`", - extra_args=EXTRA_DASK_ARGS, - extra_notes=EXTRA_DASK_NOTES) + array_type=":class:`dask.array.Array`", + get_time_index=":code:`time.map_blocks(" + "lambda a: np.unique(a, " + "return_inverse=True)[1])`", + extra_args=EXTRA_DASK_ARGS, + extra_notes=EXTRA_DASK_NOTES, + ) except AttributeError: pass wsclean_predict.__doc__ = WSCLEAN_PREDICT_DOCS.substitute( - array_type=":class:`dask.array.Array`") + array_type=":class:`dask.array.Array`" +) diff --git a/africanus/rime/examples/predict.py b/africanus/rime/examples/predict.py index 7d2592660..e9ed01e60 100644 --- a/africanus/rime/examples/predict.py +++ b/africanus/rime/examples/predict.py @@ -24,20 +24,25 @@ from africanus.util.beams import beam_filenames, beam_grids from africanus.coordinates.dask import radec_to_lm -from africanus.rime.dask import (phase_delay, predict_vis, parallactic_angles, - beam_cube_dde, feed_rotation) +from africanus.rime.dask import ( + phase_delay, + predict_vis, + parallactic_angles, + beam_cube_dde, + feed_rotation, +) from africanus.model.coherency.dask import convert from africanus.model.spectral.dask import spectral_model from africanus.model.shape.dask import gaussian as gaussian_shape from africanus.util.requirements import requires_optional -_einsum_corr_indices = 'ijkl' +_einsum_corr_indices = "ijkl" def _brightness_schema(corrs, index): if corrs == 4: - return "sf" + _einsum_corr_indices[index:index + 2], index + 1 + return "sf" + _einsum_corr_indices[index : index + 2], index + 1 else: return "sfi", index @@ -62,10 +67,10 @@ def _bl_jones_output_schema(corrs, index): _rime_term_map = { - 'brightness': _brightness_schema, - 'phase_delay': _phase_delay_schema, - 'spi': _spi_schema, - 'gauss_shape': _gauss_shape_schema, + "brightness": _brightness_schema, + "phase_delay": _phase_delay_schema, + "spi": _spi_schema, + "gauss_shape": _gauss_shape_schema, } @@ -87,12 +92,14 @@ def corr_schema(pol): corr_types = pol.CORR_TYPE.data[0] if corrs == 4: - return [[corr_types[0], corr_types[1]], - [corr_types[2], corr_types[3]]] # (2, 2) shape + return [ + [corr_types[0], corr_types[1]], + [corr_types[2], corr_types[3]], + ] # (2, 2) shape elif corrs == 2: - return [corr_types[0], corr_types[1]] # (2, ) shape + return [corr_types[0], corr_types[1]] # (2, ) shape elif corrs == 1: - return [corr_types[0]] # (1, ) shape + return [corr_types[0]] # (1, ) shape else: raise ValueError("corrs %d not in (1, 2, 4)" % corrs) @@ -116,9 +123,10 @@ def baseline_jones_multiply(corrs, *args): input_einsum_schemas.append(einsum_schema) if not len(einsum_schema) == array.ndim: - raise ValueError("%s len(%s) == %d != %s.ndim" - % (name, einsum_schema, - len(einsum_schema), array.shape)) + raise ValueError( + "%s len(%s) == %d != %s.ndim" + % (name, einsum_schema, len(einsum_schema), array.shape) + ) output_schema = _bl_jones_output_schema(corrs, corr_index) schema = ",".join(input_einsum_schemas) + output_schema @@ -135,17 +143,20 @@ def create_parser(): p.add_argument("-b", "--beam", default=None) p.add_argument("-l", "--l-axis", default="L") p.add_argument("-m", "--m-axis", default="M") - p.add_argument("-iuvw", "--invert-uvw", action="store_true", - help="Invert UVW coordinates. Useful if we want " - "compare our visibilities against MeqTrees") + p.add_argument( + "-iuvw", + "--invert-uvw", + action="store_true", + help="Invert UVW coordinates. Useful if we want " + "compare our visibilities against MeqTrees", + ) return p @lru_cache(maxsize=16) def load_beams(beam_file_schema, corr_types, l_axis, m_axis): - class FITSFile(object): - """ Exists so that fits file is closed when last ref is gc'd """ + """Exists so that fits file is closed when last ref is gc'd""" def __init__(self, filename): self.hdul = hdul = fits.open(filename) @@ -177,12 +188,16 @@ def __init__(self, filename): raise ValueError("BEAM FITS Header Files differ") # Map FITS header type to NumPy type - BITPIX_MAP = {8: np.dtype('uint8').type, 16: np.dtype('int16').type, - 32: np.dtype('int32').type, -32: np.dtype('float32').type, - -64: np.dtype('float64').type} + BITPIX_MAP = { + 8: np.dtype("uint8").type, + 16: np.dtype("int16").type, + 32: np.dtype("int32").type, + -32: np.dtype("float32").type, + -64: np.dtype("float64").type, + } header = flat_headers[0] - bitpix = header['BITPIX'] + bitpix = header["BITPIX"] try: dtype = BITPIX_MAP[bitpix] @@ -191,13 +206,14 @@ def __init__(self, filename): else: dtype = np.result_type(dtype, np.complex64) - if not header['NAXIS'] == 3: - raise ValueError("FITS must have exactly three axes. " - "L or X, M or Y and FREQ. NAXIS != 3") + if not header["NAXIS"] == 3: + raise ValueError( + "FITS must have exactly three axes. " "L or X, M or Y and FREQ. NAXIS != 3" + ) - (l_ax, l_grid), (m_ax, m_grid), (nu_ax, nu_grid) = beam_grids(header, - l_axis, - m_axis) + (l_ax, l_grid), (m_ax, m_grid), (nu_ax, nu_grid) = beam_grids( + header, l_axis, m_axis + ) # Shape of each correlation shape = (l_grid.shape[0], m_grid.shape[0], nu_grid.shape[0]) @@ -207,16 +223,15 @@ def __init__(self, filename): def _load_correlation(re, im, ax): # Read real and imaginary for each correlation - return (re.hdul[0].data.transpose(ax) + - im.hdul[0].data.transpose(ax)*1j) + return re.hdul[0].data.transpose(ax) + im.hdul[0].data.transpose(ax) * 1j # Create delayed loads of the beam beam_loader = dask.delayed(_load_correlation) - beam_corrs = [beam_loader(re, im, ax) - for c, (corr, (re, im)) in enumerate(beam_files)] - beam_corrs = [da.from_delayed(bc, shape=shape, dtype=dtype) - for bc in beam_corrs] + beam_corrs = [ + beam_loader(re, im, ax) for c, (corr, (re, im)) in enumerate(beam_files) + ] + beam_corrs = [da.from_delayed(bc, shape=shape, dtype=dtype) for bc in beam_corrs] # Stack correlations and rechunk to one great big block beam = da.stack(beam_corrs, axis=3) @@ -272,8 +287,7 @@ def parse_sky_model(filename, chunks): U = source.flux.U V = source.flux.V - spectrum = (getattr(source, "spectrum", _empty_spectrum) - or _empty_spectrum) + spectrum = getattr(source, "spectrum", _empty_spectrum) or _empty_spectrum try: # Extract reference frequency @@ -284,7 +298,7 @@ def parse_sky_model(filename, chunks): try: # Extract SPI for I. # Zero Q, U and V to get 1 on the exponential - spi = [[spectrum.spi]*4] + spi = [[spectrum.spi] * 4] except AttributeError: # Default I SPI to -0.7 spi = [[0, 0, 0, 0]] @@ -309,25 +323,26 @@ def parse_sky_model(filename, chunks): raise ValueError("Unknown source morphology %s" % typecode) Point = namedtuple("Point", ["radec", "stokes", "spi", "ref_freq"]) - Gauss = namedtuple("Gauss", ["radec", "stokes", "spi", "ref_freq", - "shape"]) + Gauss = namedtuple("Gauss", ["radec", "stokes", "spi", "ref_freq", "shape"]) source_data = {} if len(point_radec) > 0: - source_data['point'] = Point( - da.from_array(point_radec, chunks=(chunks, -1)), - da.from_array(point_stokes, chunks=(chunks, -1)), - da.from_array(point_spi, chunks=(chunks, 1, -1)), - da.from_array(point_ref_freq, chunks=chunks)) + source_data["point"] = Point( + da.from_array(point_radec, chunks=(chunks, -1)), + da.from_array(point_stokes, chunks=(chunks, -1)), + da.from_array(point_spi, chunks=(chunks, 1, -1)), + da.from_array(point_ref_freq, chunks=chunks), + ) if len(gauss_radec) > 0: - source_data['gauss'] = Gauss( - da.from_array(gauss_radec, chunks=(chunks, -1)), - da.from_array(gauss_stokes, chunks=(chunks, -1)), - da.from_array(gauss_spi, chunks=(chunks, 1, -1)), - da.from_array(gauss_ref_freq, chunks=chunks), - da.from_array(gauss_shape, chunks=(chunks, -1))) + source_data["gauss"] = Gauss( + da.from_array(gauss_radec, chunks=(chunks, -1)), + da.from_array(gauss_stokes, chunks=(chunks, -1)), + da.from_array(gauss_spi, chunks=(chunks, 1, -1)), + da.from_array(gauss_ref_freq, chunks=chunks), + da.from_array(gauss_shape, chunks=(chunks, -1)), + ) return source_data @@ -345,9 +360,16 @@ def support_tables(args): {name: dataset} """ - n = {k: '::'.join((args.ms, k)) for k - in ("ANTENNA", "DATA_DESCRIPTION", "FIELD", - "SPECTRAL_WINDOW", "POLARIZATION")} + n = { + k: "::".join((args.ms, k)) + for k in ( + "ANTENNA", + "DATA_DESCRIPTION", + "FIELD", + "SPECTRAL_WINDOW", + "POLARIZATION", + ) + } # All rows at once lazy_tables = {"ANTENNA": xds_from_table(n["ANTENNA"])} @@ -356,12 +378,9 @@ def support_tables(args): # Fixed shape rows "DATA_DESCRIPTION": xds_from_table(n["DATA_DESCRIPTION"]), # Variably shaped, need a dataset per row - "FIELD": xds_from_table(n["FIELD"], - group_cols="__row__"), - "SPECTRAL_WINDOW": xds_from_table(n["SPECTRAL_WINDOW"], - group_cols="__row__"), - "POLARIZATION": xds_from_table(n["POLARIZATION"], - group_cols="__row__"), + "FIELD": xds_from_table(n["FIELD"], group_cols="__row__"), + "SPECTRAL_WINDOW": xds_from_table(n["SPECTRAL_WINDOW"], group_cols="__row__"), + "POLARIZATION": xds_from_table(n["POLARIZATION"], group_cols="__row__"), } lazy_tables.update(dask.compute(compute_tables)[0]) @@ -369,14 +388,14 @@ def support_tables(args): def _zero_pes(parangles, frequency, dtype_): - """ Create zeroed pointing errors """ + """Create zeroed pointing errors""" ntime, na = parangles.shape nchan = frequency.shape[0] return np.zeros((ntime, na, nchan, 2), dtype=dtype_) def _unity_ant_scales(parangles, frequency, dtype_): - """ Create zeroed antenna scalings """ + """Create zeroed antenna scalings""" _, na = parangles[0].shape nchan = frequency.shape[0] return np.ones((na, nchan, 2), dtype=dtype_) @@ -392,20 +411,21 @@ def dde_factory(args, ms, ant, field, pol, lm, utime, frequency): if not len(corr_type) == 4: raise ValueError("Need four correlations for DDEs") - parangles = parallactic_angles(utime, ant.POSITION.data, - field.PHASE_DIR.data[0][0]) + parangles = parallactic_angles(utime, ant.POSITION.data, field.PHASE_DIR.data[0][0]) corr_type_set = set(corr_type) if corr_type_set.issubset(set([9, 10, 11, 12])): - pol_type = 'linear' + pol_type = "linear" elif corr_type_set.issubset(set([5, 6, 7, 8])): - pol_type = 'circular' + pol_type = "circular" else: - raise ValueError("Cannot determine polarisation type " - "from correlations %s. Constructing " - "a feed rotation matrix will not be " - "possible." % (corr_type,)) + raise ValueError( + "Cannot determine polarisation type " + "from correlations %s. Constructing " + "a feed rotation matrix will not be " + "possible." % (corr_type,) + ) # Construct feed rotation feed_rot = feed_rotation(parangles, pol_type) @@ -413,38 +433,46 @@ def dde_factory(args, ms, ant, field, pol, lm, utime, frequency): dtype = np.result_type(parangles, frequency) # Create zeroed pointing errors - zpe = da.blockwise(_zero_pes, ("time", "ant", "chan", "comp"), - parangles, ("time", "ant"), - frequency, ("chan",), - dtype, None, - new_axes={"comp": 2}, - dtype=dtype) + zpe = da.blockwise( + _zero_pes, + ("time", "ant", "chan", "comp"), + parangles, + ("time", "ant"), + frequency, + ("chan",), + dtype, + None, + new_axes={"comp": 2}, + dtype=dtype, + ) # Created zeroed antenna scaling factors - zas = da.blockwise(_unity_ant_scales, ("ant", "chan", "comp"), - parangles, ("time", "ant"), - frequency, ("chan",), - dtype, None, - new_axes={"comp": 2}, - dtype=dtype) + zas = da.blockwise( + _unity_ant_scales, + ("ant", "chan", "comp"), + parangles, + ("time", "ant"), + frequency, + ("chan",), + dtype, + None, + new_axes={"comp": 2}, + dtype=dtype, + ) # Load the beam information - beam, lm_ext, freq_map = load_beams(args.beam, corr_type, - args.l_axis, args.m_axis) + beam, lm_ext, freq_map = load_beams(args.beam, corr_type, args.l_axis, args.m_axis) # Introduce the correlation axis beam = beam.reshape(beam.shape[:3] + (2, 2)) - beam_dde = beam_cube_dde(beam, lm_ext, freq_map, lm, parangles, - zpe, zas, - frequency) + beam_dde = beam_cube_dde(beam, lm_ext, freq_map, lm, parangles, zpe, zas, frequency) # Multiply the beam by the feed rotation to form the DDE term return da.einsum("stafij,tajk->stafik", beam_dde, feed_rot) -def vis_factory(args, source_type, sky_model, - ms, ant, field, spw, pol): +def vis_factory(args, source_type, sky_model, ms, ant, field, spw, pol): try: source = sky_model[source_type] except KeyError: @@ -463,14 +491,11 @@ def vis_factory(args, source_type, sky_model, # (source, spi, corrs) # Apply spectral mode to stokes parameters - stokes = spectral_model(source.stokes, - source.spi, - source.ref_freq, - frequency, - base=0) + stokes = spectral_model( + source.stokes, source.spi, source.ref_freq, frequency, base=0 + ) - brightness = convert(stokes, ["I", "Q", "U", "V"], - corr_schema(pol)) + brightness = convert(stokes, ["I", "Q", "U", "V"], corr_schema(pol)) bl_jones_args = ["phase_delay", phase] @@ -484,26 +509,25 @@ def vis_factory(args, source_type, sky_model, # Unique times and time index for each row chunk # The index is not global meta = np.empty((0,), dtype=tuple) - utime_inv = ms.TIME.data.map_blocks(np.unique, return_inverse=True, - meta=meta, dtype=tuple) + utime_inv = ms.TIME.data.map_blocks( + np.unique, return_inverse=True, meta=meta, dtype=tuple + ) # Need unique times for parallactic angles nan_chunks = (tuple(np.nan for _ in utime_inv.chunks[0]),) - utime = utime_inv.map_blocks(getitem, 0, - chunks=nan_chunks, - dtype=ms.TIME.dtype) + utime = utime_inv.map_blocks(getitem, 0, chunks=nan_chunks, dtype=ms.TIME.dtype) time_idx = utime_inv.map_blocks(getitem, 1, dtype=np.int32) jones = baseline_jones_multiply(corrs, *bl_jones_args) dde = dde_factory(args, ms, ant, field, pol, lm, utime, frequency) - return predict_vis(time_idx, ms.ANTENNA1.data, ms.ANTENNA2.data, - dde, jones, dde, None, None, None) + return predict_vis( + time_idx, ms.ANTENNA1.data, ms.ANTENNA2.data, dde, jones, dde, None, None, None + ) -@requires_optional("dask.array", "Tigger", - "daskms", opt_import_error) +@requires_optional("dask.array", "Tigger", "daskms", opt_import_error) def predict(args): # Convert source data into dask arrays sky_model = parse_sky_model(args.sky_model, args.model_chunks) @@ -521,15 +545,16 @@ def predict(args): writes = [] # Construct a graph for each DATA_DESC_ID - for xds in xds_from_ms(args.ms, - columns=["UVW", "ANTENNA1", "ANTENNA2", "TIME"], - group_cols=["FIELD_ID", "DATA_DESC_ID"], - chunks={"row": args.row_chunks}): - + for xds in xds_from_ms( + args.ms, + columns=["UVW", "ANTENNA1", "ANTENNA2", "TIME"], + group_cols=["FIELD_ID", "DATA_DESC_ID"], + chunks={"row": args.row_chunks}, + ): # Perform subtable joins ant = ant_ds[0] - field = field_ds[xds.attrs['FIELD_ID']] - ddid = ddid_ds[xds.attrs['DATA_DESC_ID']] + field = field_ds[xds.attrs["FIELD_ID"]] + ddid = ddid_ds[xds.attrs["DATA_DESC_ID"]] spw = spw_ds[ddid.SPECTRAL_WINDOW_ID.data[0]] pol = pol_ds[ddid.POLARIZATION_ID.data[0]] @@ -537,9 +562,10 @@ def predict(args): corrs = pol.NUM_CORR.data[0] # Generate visibility expressions for each source type - source_vis = [vis_factory(args, stype, sky_model, - xds, ant, field, spw, pol) - for stype in sky_model.keys()] + source_vis = [ + vis_factory(args, stype, sky_model, xds, ant, field, spw, pol) + for stype in sky_model.keys() + ] # Sum visibilities together vis = sum(source_vis) @@ -551,7 +577,7 @@ def predict(args): # Assign visibilities to MODEL_DATA array on the dataset xds = xds.assign(MODEL_DATA=(("row", "chan", "corr"), vis)) # Create a write to the table - write = xds_to_table(xds, args.ms, ['MODEL_DATA']) + write = xds_to_table(xds, args.ms, ["MODEL_DATA"]) # Add to the list of writes writes.append(write) diff --git a/africanus/rime/examples/predict_shapelet.py b/africanus/rime/examples/predict_shapelet.py index b6a157310..d005d67d6 100644 --- a/africanus/rime/examples/predict_shapelet.py +++ b/africanus/rime/examples/predict_shapelet.py @@ -39,7 +39,7 @@ def _brightness_schema(corrs, index): if corrs == 4: - return "sf" + _einsum_corr_indices[index: index + 2], index + 1 + return "sf" + _einsum_corr_indices[index : index + 2], index + 1 else: return "sfi", index @@ -204,9 +204,7 @@ def parse_sky_model(filename, chunks): U = source.flux.U V = source.flux.V - spectrum = ( - getattr(source, "spectrum", _empty_spectrum) or _empty_spectrum - ) + spectrum = getattr(source, "spectrum", _empty_spectrum) or _empty_spectrum try: # Extract reference frequency ref_freq = spectrum.freq0 @@ -253,9 +251,7 @@ def parse_sky_model(filename, chunks): raise ValueError("Unknown source morphology %s" % typecode) Point = namedtuple("Point", ["radec", "stokes", "spi", "ref_freq"]) - Gauss = namedtuple( - "Gauss", ["radec", "stokes", "spi", "ref_freq", "shape"] - ) + Gauss = namedtuple("Gauss", ["radec", "stokes", "spi", "ref_freq", "shape"]) Shapelet = namedtuple( "Shapelet", ["radec", "stokes", "spi", "ref_freq", "beta", "coeffs"] ) @@ -322,12 +318,8 @@ def support_tables(args): "DATA_DESCRIPTION": xds_from_table(n["DATA_DESCRIPTION"]), # Variably shaped, need a dataset per row "FIELD": xds_from_table(n["FIELD"], group_cols="__row__"), - "SPECTRAL_WINDOW": xds_from_table( - n["SPECTRAL_WINDOW"], group_cols="__row__" - ), - "POLARIZATION": xds_from_table( - n["POLARIZATION"], group_cols="__row__" - ), + "SPECTRAL_WINDOW": xds_from_table(n["SPECTRAL_WINDOW"], group_cols="__row__"), + "POLARIZATION": xds_from_table(n["POLARIZATION"], group_cols="__row__"), } lazy_tables.update(dask.compute(compute_tables)[0]) @@ -335,23 +327,21 @@ def support_tables(args): def _zero_pes(parangles, frequency, dtype_): - """ Create zeroed pointing errors """ + """Create zeroed pointing errors""" ntime, na = parangles.shape nchan = frequency.shape[0] return np.zeros((ntime, na, nchan, 2), dtype=dtype_) def _unity_ant_scales(parangles, frequency, dtype_): - """ Create zeroed antenna scalings """ + """Create zeroed antenna scalings""" _, na = parangles[0].shape nchan = frequency.shape[0] return np.ones((na, nchan, 2), dtype=dtype_) -def zernike_factory( - args, ms, ant, field, pol, lm, utime, frequency, jon, nrow=None -): - """ Generate a primary beam DDE using Zernike polynomials """ +def zernike_factory(args, ms, ant, field, pol, lm, utime, frequency, jon, nrow=None): + """Generate a primary beam DDE using Zernike polynomials""" if not args.zernike: return None @@ -382,9 +372,7 @@ def zernike_factory( pointing_errors = da.from_array( np.zeros((ntime, na, nchan, 2)), chunks=(time_chunks, na, nchan, 2) ) - antenna_scaling = da.from_array( - np.ones((na, nchan, 2)), chunks=(na, nchan, 2) - ) + antenna_scaling = da.from_array(np.ones((na, nchan, 2)), chunks=(na, nchan, 2)) parangles = da.from_array( parallactic_angles( np.array(utime)[:ntime], @@ -419,9 +407,7 @@ def zernike_factory( # Call Zernike_dde dde_r = zernike_dde( - da.from_array( - zernike_coords, chunks=(3, nsrc, time_chunks, na, nchan) - ), + da.from_array(zernike_coords, chunks=(3, nsrc, time_chunks, na, nchan)), da.from_array(coeffs_r, chunks=coeffs_r.shape), da.from_array(noll_index_r, chunks=noll_index_r.shape), parangles, @@ -430,9 +416,7 @@ def zernike_factory( pointing_errors, ) dde_i = zernike_dde( - da.from_array( - zernike_coords, chunks=(3, nsrc, time_chunks, na, nchan) - ), + da.from_array(zernike_coords, chunks=(3, nsrc, time_chunks, na, nchan)), da.from_array(coeffs_i, chunks=coeffs_i.shape), da.from_array(noll_index_i, chunks=noll_index_i.shape), parangles, @@ -489,9 +473,7 @@ def vis_factory(args, source_type, sky_model, ms, ant, field, spw, pol): # Need unique times for parallactic angles nan_chunks = (tuple(np.nan for _ in utime_inv.chunks[0]),) - utime = utime_inv.map_blocks( - getitem, 0, chunks=nan_chunks, dtype=ms.TIME.dtype - ) + utime = utime_inv.map_blocks(getitem, 0, chunks=nan_chunks, dtype=ms.TIME.dtype) time_idx = utime_inv.map_blocks(getitem, 1, dtype=np.int32) @@ -572,7 +554,6 @@ def predict(args): group_cols=["FIELD_ID", "DATA_DESC_ID"], chunks={"row": args.row_chunks}, ): - # Perform subtable joins ant = ant_ds[0] field = field_ds[xds.attrs["FIELD_ID"]] diff --git a/africanus/rime/examples/tests/cmp_codex_vs_meq.py b/africanus/rime/examples/tests/cmp_codex_vs_meq.py index ccd2f7448..8299fcfe7 100644 --- a/africanus/rime/examples/tests/cmp_codex_vs_meq.py +++ b/africanus/rime/examples/tests/cmp_codex_vs_meq.py @@ -11,8 +11,7 @@ import numpy as np -from africanus.rime.examples.predict import (predict, - create_parser as predict_parser) +from africanus.rime.examples.predict import predict, create_parser as predict_parser from africanus.util.requirements import requires_optional from africanus.testing.beam_factory import beam_factory @@ -24,7 +23,7 @@ opt_import_error = None -@requires_optional('pyrap.tables', opt_import_error) +@requires_optional("pyrap.tables", opt_import_error) def inspect_polarisation_type(args): linear_corr_types = set([9, 10, 11, 12]) circular_corr_types = set([5, 6, 7, 8]) @@ -41,41 +40,44 @@ def inspect_polarisation_type(args): elif discovered_corr_types.issubset(circular_corr_types): return "circular" - raise ValueError("MS Correlation types are not wholly " - "linear or circular: %s" % discovered_corr_types) + raise ValueError( + "MS Correlation types are not wholly " + "linear or circular: %s" % discovered_corr_types + ) def cmp_script_factory(args, pol_type): beam_pattern = args.beam.replace("$", r"\$") - return ["python", - "cmp_codex_vs_meq.py", - args.ms, - "-sm " + args.sky_model, - '-b "' + beam_pattern + '"', - "--run-predict"] + return [ + "python", + "cmp_codex_vs_meq.py", + args.ms, + "-sm " + args.sky_model, + '-b "' + beam_pattern + '"', + "--run-predict", + ] def meqtrees_command_factory(args, pol_type): # Directory in which meqtree-related files are read/written - meq_dir = 'meqtrees' + meq_dir = "meqtrees" # Scripts - meqpipe = 'meqtree-pipeliner.py' + meqpipe = "meqtree-pipeliner.py" # Meqtree profile and script - cfg_file = pjoin(meq_dir, 'tdlconf.profiles') - sim_script = pjoin(meq_dir, 'turbo-sim.py') + cfg_file = pjoin(meq_dir, "tdlconf.profiles") + sim_script = pjoin(meq_dir, "turbo-sim.py") meqtrees_vis_column = "CORRECTED_DATA" - if pol_type == 'linear': - cfg_section = '-'.join(('codex', 'compare', 'linear')) - elif pol_type == 'circular': - cfg_section = '-'.join(('codex', 'compare', 'circular')) + if pol_type == "linear": + cfg_section = "-".join(("codex", "compare", "linear")) + elif pol_type == "circular": + cfg_section = "-".join(("codex", "compare", "circular")) else: - raise ValueError("pol_type %s not in ('circular', 'linear')" - % pol_type) + raise ValueError("pol_type %s not in ('circular', 'linear')" % pol_type) # $ is a special pattern is most shells, escape it beam_pattern = args.beam.replace("$", r"\$") @@ -86,34 +88,35 @@ def meqtrees_command_factory(args, pol_type): cmd_list = [ # Meqtree Pipeline script - '$(which {})'.format(meqpipe), + "$(which {})".format(meqpipe), # Configuration File - '-c', cfg_file, + "-c", + cfg_file, # Configuration section '"[{section}]"'.format(section=cfg_section), # Measurement Set - 'ms_sel.msname={ms}'.format(ms=args.ms), + "ms_sel.msname={ms}".format(ms=args.ms), # Tigger sky file - 'tiggerlsm.filename={sm}'.format(sm=args.sky_model), + "tiggerlsm.filename={sm}".format(sm=args.sky_model), # Output column - 'ms_sel.output_column={c}'.format(c=meqtrees_vis_column), + "ms_sel.output_column={c}".format(c=meqtrees_vis_column), # Enable the beam? - 'me.e_enable={e}'.format(e=1), + "me.e_enable={e}".format(e=1), # Enable feed rotation - 'me.l_enable={e}'.format(e=1), + "me.l_enable={e}".format(e=1), # Beam FITS file pattern 'pybeams_fits.filename_pattern="{p}"'.format(p=beam_pattern), # FITS L and M AXIS - 'pybeams_fits.l_axis={lax}'.format(lax=args.l_axis), - 'pybeams_fits.m_axis={max}'.format(max=args.m_axis), + "pybeams_fits.l_axis={lax}".format(lax=args.l_axis), + "pybeams_fits.m_axis={max}".format(max=args.m_axis), sim_script, - '=simulate' + "=simulate", ] return cmd_list -@requires_optional('pyrap.tables', opt_import_error) +@requires_optional("pyrap.tables", opt_import_error) def compare_columns(args, codex_column, meqtrees_column): with pt.table(args.ms) as T: codex_vis = T.getcol(codex_column) @@ -131,12 +134,15 @@ def compare_columns(args, codex_column, meqtrees_column): print("Codex Africanus visibilities agrees with MeqTrees") return True - bad_vis_file = 'bad_visibilities.txt' + bad_vis_file = "bad_visibilities.txt" # Some visibilities differ, do some analysis - print("Codex Africanus differs from MeqTrees by {nc}/{t} " - "visibilities. Writing them out to '{bvf}'" - .format(nc=problems[0].size, t=not_close.size, bvf=bad_vis_file)) + print( + "Codex Africanus differs from MeqTrees by {nc}/{t} " + "visibilities. Writing them out to '{bvf}'".format( + nc=problems[0].size, t=not_close.size, bvf=bad_vis_file + ) + ) mb_problems = codex_vis[problems] meq_problems = meqtrees_vis[problems] @@ -151,11 +157,14 @@ def compare_columns(args, codex_column, meqtrees_column): it = itertools.islice(it, 0, 1000, 1) # Write out the problematic visibilities to file - with open(bad_vis_file, 'w') as f: + with open(bad_vis_file, "w") as f: for i, (p, mb, meq, d, amp) in it: - f.write("{i} {t} Codex Africanus: {mb} MeqTrees: {meq} " - "Difference {d} Absolute Difference {ad} \n" - .format(i=i, t=p, mb=mb, meq=meq, d=d, ad=amp)) + f.write( + "{i} {t} Codex Africanus: {mb} MeqTrees: {meq} " + "Difference {d} Absolute Difference {ad} \n".format( + i=i, t=p, mb=mb, meq=meq, d=d, ad=amp + ) + ) return False @@ -164,8 +173,7 @@ def create_beams(schema, pol_type): td = tempfile.mkdtemp(prefix="beams-") path = Path(td, schema) - filenames = beam_factory(polarisation_type=pol_type, - schema=path, npix=257) + filenames = beam_factory(polarisation_type=pol_type, schema=path, npix=257) return path, filenames @@ -196,25 +204,25 @@ def compare(): nrow = min(row_chunk, nrows - r) exemplar = T.getcol("MODEL_DATA", startrow=r, nrow=nrow) - T.putcol("MODEL_DATA", np.zeros_like(exemplar), - startrow=r, nrow=nrow) - T.putcol("CORRECTED_DATA", np.zeros_like(exemplar), - startrow=r, nrow=nrow) + T.putcol("MODEL_DATA", np.zeros_like(exemplar), startrow=r, nrow=nrow) + T.putcol( + "CORRECTED_DATA", np.zeros_like(exemplar), startrow=r, nrow=nrow + ) pol_type = inspect_polarisation_type(args) - beam_path, filenames = create_beams("beams_$(corr)_$(reim).fits", - pol_type) + beam_path, filenames = create_beams("beams_$(corr)_$(reim).fits", pol_type) args.beam = str(beam_path) meq_cmd = " ".join(meqtrees_command_factory(args, pol_type)) cmp_cmd = " ".join(cmp_script_factory(args, pol_type)) - print("\nRUN THE FOLLOWING COMMAND IN A SEPARATE " - "MEQTREES ENVIRONMENT, PREFERABLY BUILT FROM SOURCE\n" - "https://github.com/ska-sa/meqtrees/wiki/BuildFromSource" - "\n\n%s\n\n\n" % meq_cmd) + print( + "\nRUN THE FOLLOWING COMMAND IN A SEPARATE " + "MEQTREES ENVIRONMENT, PREFERABLY BUILT FROM SOURCE\n" + "https://github.com/ska-sa/meqtrees/wiki/BuildFromSource" + "\n\n%s\n\n\n" % meq_cmd + ) - print("\nTHEN RUN THIS IN THE CURRENT ENVIRONMENT" - "\n\n%s\n\n\n" % cmp_cmd) + print("\nTHEN RUN THIS IN THE CURRENT ENVIRONMENT" "\n\n%s\n\n\n" % cmp_cmd) return True diff --git a/africanus/rime/fast_beam_cubes.py b/africanus/rime/fast_beam_cubes.py index e3ceb6b3b..881beb9d6 100644 --- a/africanus/rime/fast_beam_cubes.py +++ b/africanus/rime/fast_beam_cubes.py @@ -11,8 +11,7 @@ def freq_grid_interp(frequency, beam_freq_map): # Interpolated grid coordinate beam_nud = beam_freq_map.shape[0] - freq_data = np.empty((frequency.shape[0], 3), - dtype=frequency.dtype) + freq_data = np.empty((frequency.shape[0], 3), dtype=frequency.dtype) for f in range(frequency.shape[0]): freq = frequency[f] @@ -56,10 +55,16 @@ def freq_grid_interp(frequency, beam_freq_map): @njit(nogil=True, cache=True) -def beam_cube_dde(beam, beam_lm_extents, beam_freq_map, - lm, parallactic_angles, point_errors, antenna_scaling, - frequency): - +def beam_cube_dde( + beam, + beam_lm_extents, + beam_freq_map, + lm, + parallactic_angles, + point_errors, + antenna_scaling, + frequency, +): nsrc = lm.shape[0] ntime, nants = parallactic_angles.shape nchan = frequency.shape[0] @@ -70,7 +75,7 @@ def beam_cube_dde(beam, beam_lm_extents, beam_freq_map, raise ValueError("beam_lw, beam_mh and beam_nud must be >= 2") # Flatten correlations - ncorrs = reduce(lambda x, y: x*y, corrs, 1) + ncorrs = reduce(lambda x, y: x * y, corrs, 1) lower_l, upper_l = beam_lm_extents[0] lower_m, upper_m = beam_lm_extents[1] @@ -130,16 +135,16 @@ def beam_cube_dde(beam, beam_lm_extents, beam_freq_map, tm = sm + point_errors[t, a, f, 1] # Rotate lm coordinate angle - vl = tl*cos_pa - tm*sin_pa - vm = tl*sin_pa + tm*cos_pa + vl = tl * cos_pa - tm * sin_pa + vm = tl * sin_pa + tm * cos_pa # Scale by antenna scaling vl *= antenna_scaling[a, f, 0] vm *= antenna_scaling[a, f, 1] # Shift into the cube coordinate system - vl = lscale*(vl - lower_l) - vm = mscale*(vm - lower_m) + vl = lscale * (vl - lower_l) + vm = mscale * (vm - lower_m) # Clamp the coordinates to the edges of the cube vl = max(zero, min(vl, lmaxf)) @@ -163,28 +168,28 @@ def beam_cube_dde(beam, beam_lm_extents, beam_freq_map, # Accumulate lower cube correlations beam_scratch[:] = fbeam[gl0, gm0, gc0, :] - weight = (one - ld)*(one - md)*nud + weight = (one - ld) * (one - md) * nud for c in range(ncorrs): absc_sum[c] += weight * np.abs(beam_scratch[c]) corr_sum[c] += weight * beam_scratch[c] beam_scratch[:] = fbeam[gl1, gm0, gc0, :] - weight = ld*(one - md)*nud + weight = ld * (one - md) * nud for c in range(ncorrs): absc_sum[c] += weight * np.abs(beam_scratch[c]) corr_sum[c] += weight * beam_scratch[c] beam_scratch[:] = fbeam[gl0, gm1, gc0, :] - weight = (one - ld)*md*nud + weight = (one - ld) * md * nud for c in range(ncorrs): absc_sum[c] += weight * np.abs(beam_scratch[c]) corr_sum[c] += weight * beam_scratch[c] beam_scratch[:] = fbeam[gl1, gm1, gc0, :] - weight = ld*md*nud + weight = ld * md * nud for c in range(ncorrs): absc_sum[c] += weight * np.abs(beam_scratch[c]) @@ -192,28 +197,28 @@ def beam_cube_dde(beam, beam_lm_extents, beam_freq_map, # Accumulate upper cube correlations beam_scratch[:] = fbeam[gl0, gm0, gc1, :] - weight = (one - ld)*(one - md)*inv_nud + weight = (one - ld) * (one - md) * inv_nud for c in range(ncorrs): absc_sum[c] += weight * np.abs(beam_scratch[c]) corr_sum[c] += weight * beam_scratch[c] beam_scratch[:] = fbeam[gl1, gm0, gc1, :] - weight = ld*(one - md)*inv_nud + weight = ld * (one - md) * inv_nud for c in range(ncorrs): absc_sum[c] += weight * np.abs(beam_scratch[c]) corr_sum[c] += weight * beam_scratch[c] beam_scratch[:] = fbeam[gl0, gm1, gc1, :] - weight = (one - ld)*md*inv_nud + weight = (one - ld) * md * inv_nud for c in range(ncorrs): absc_sum[c] += weight * np.abs(beam_scratch[c]) corr_sum[c] += weight * beam_scratch[c] beam_scratch[:] = fbeam[gl1, gm1, gc1, :] - weight = ld*md*inv_nud + weight = ld * md * inv_nud for c in range(ncorrs): absc_sum[c] += weight * np.abs(beam_scratch[c]) @@ -284,11 +289,13 @@ def beam_cube_dde(beam, beam_lm_extents, beam_freq_map, ddes : $(array_type) Direction Dependent Effects of shape :code:`(source, time, ant, chan, corr, corr)` - """) + """ +) try: beam_cube_dde.__doc__ = BEAM_CUBE_DOCS.substitute( - array_type=":class:`numpy.ndarray`") + array_type=":class:`numpy.ndarray`" + ) except AttributeError: pass diff --git a/africanus/rime/feeds.py b/africanus/rime/feeds.py index 8002071f5..684a28374 100644 --- a/africanus/rime/feeds.py +++ b/africanus/rime/feeds.py @@ -38,8 +38,8 @@ def _nb_feed_rotation(parallactic_angles, feed_type, feed_rotation): feed_rotation.real[i, 0, 0] = pa_cos feed_rotation.imag[i, 0, 0] = -pa_sin - feed_rotation[i, 0, 1] = 0.0 + 0.0*1j - feed_rotation[i, 1, 0] = 0.0 + 0.0*1j + feed_rotation[i, 0, 1] = 0.0 + 0.0 * 1j + feed_rotation[i, 1, 0] = 0.0 + 0.0 * 1j feed_rotation.real[i, 1, 1] = pa_cos feed_rotation.imag[i, 1, 1] = pa_sin else: @@ -48,10 +48,10 @@ def _nb_feed_rotation(parallactic_angles, feed_type, feed_rotation): return feed_rotation.reshape(shape + (2, 2)) -def feed_rotation(parallactic_angles, feed_type='linear'): - if feed_type == 'linear': +def feed_rotation(parallactic_angles, feed_type="linear"): + if feed_type == "linear": poltype = 0 - elif feed_type == 'circular': + elif feed_type == "circular": poltype = 1 else: raise ValueError("Invalid feed_type '%s'" % feed_type) @@ -61,9 +61,10 @@ def feed_rotation(parallactic_angles, feed_type='linear'): elif parallactic_angles.dtype == np.float64: dtype = np.complex128 else: - raise ValueError("parallactic_angles has " - "none-floating point type %s" - % parallactic_angles.dtype) + raise ValueError( + "parallactic_angles has " + "none-floating point type %s" % parallactic_angles.dtype + ) # Create result array with flattened parangles shape = (reduce(mul, parallactic_angles.shape),) + (2, 2) @@ -72,7 +73,8 @@ def feed_rotation(parallactic_angles, feed_type='linear'): return _nb_feed_rotation(parallactic_angles, poltype, result) -FEED_ROTATION_DOCS = DocstringTemplate(r""" +FEED_ROTATION_DOCS = DocstringTemplate( + r""" Computes the 2x2 feed rotation (L) matrix from the ``parallactic_angles``. @@ -102,10 +104,12 @@ def feed_rotation(parallactic_angles, feed_type='linear'): ------- feed_matrix : $(array_type) Feed rotation matrix of shape :code:`(pa0, pa1,...,pan,2,2)` -""") +""" +) try: feed_rotation.__doc__ = FEED_ROTATION_DOCS.substitute( - array_type=":class:`numpy.ndarray`") + array_type=":class:`numpy.ndarray`" + ) except AttributeError: pass diff --git a/africanus/rime/jax/phase.py b/africanus/rime/jax/phase.py index 5898a5de2..9087ace08 100644 --- a/africanus/rime/jax/phase.py +++ b/africanus/rime/jax/phase.py @@ -12,7 +12,7 @@ from africanus.util.requirements import requires_optional -@requires_optional('jax', opt_import_error) +@requires_optional("jax", opt_import_error) def phase_delay(lm, uvw, frequency): one = lm.dtype.type(1.0) neg_two_pi_over_c = lm.dtype.type(minus_two_pi_over_c) @@ -26,8 +26,6 @@ def phase_delay(lm, uvw, frequency): n = jnp.sqrt(one - l**2 - m**2) - one - real_phase = (neg_two_pi_over_c * - (l * u + m * v + n * w) * - frequency[None, None, :]) + real_phase = neg_two_pi_over_c * (l * u + m * v + n * w) * frequency[None, None, :] - return jnp.exp(jnp.complex64(1j)*real_phase) + return jnp.exp(jnp.complex64(1j) * real_phase) diff --git a/africanus/rime/jax/tests/test_jax_phase_delay.py b/africanus/rime/jax/tests/test_jax_phase_delay.py index 7b188129e..52d875f08 100644 --- a/africanus/rime/jax/tests/test_jax_phase_delay.py +++ b/africanus/rime/jax/tests/test_jax_phase_delay.py @@ -10,13 +10,13 @@ @pytest.mark.parametrize("dtype", [np.float32, np.float64]) def test_jax_phase_delay(dtype): - jax = pytest.importorskip('jax') + jax = pytest.importorskip("jax") np.random.seed(0) uvw = np.random.random(size=(100, 3)).astype(dtype) - lm = np.random.random(size=(10, 2)).astype(dtype)*0.001 - frequency = np.linspace(.856e9, .856e9*2, 64).astype(dtype) + lm = np.random.random(size=(10, 2)).astype(dtype) * 0.001 + frequency = np.linspace(0.856e9, 0.856e9 * 2, 64).astype(dtype) # Compute complex phase np_complex_phase = np_phase_delay(lm, uvw, frequency) diff --git a/africanus/rime/parangles.py b/africanus/rime/parangles.py index ff207c9d8..589a126c1 100644 --- a/africanus/rime/parangles.py +++ b/africanus/rime/parangles.py @@ -3,25 +3,22 @@ import warnings -from .parangles_astropy import (have_astropy_parangles, - astropy_parallactic_angles) -from .parangles_casa import (have_casa_parangles, - casa_parallactic_angles) +from .parangles_astropy import have_astropy_parangles, astropy_parallactic_angles +from .parangles_casa import have_casa_parangles, casa_parallactic_angles _discovered_backends = [] if have_astropy_parangles: - _discovered_backends.append('astropy') + _discovered_backends.append("astropy") if have_casa_parangles: - _discovered_backends.append('casa') + _discovered_backends.append("casa") -_standard_backends = set(['casa', 'astropy', 'test']) +_standard_backends = set(["casa", "astropy", "test"]) -def parallactic_angles(times, antenna_positions, field_centre, - backend='casa'): +def parallactic_angles(times, antenna_positions, field_centre, backend="casa"): """ Computes parallactic angles per timestep for the given reference antenna position and field centre. @@ -52,24 +49,20 @@ def parallactic_angles(times, antenna_positions, field_centre, Parallactic angles of shape :code:`(time,ant)` """ if backend not in _standard_backends: - raise ValueError("'%s' is not one of the " - "standard backends '%s'" - % (backend, _standard_backends)) + raise ValueError( + "'%s' is not one of the " + "standard backends '%s'" % (backend, _standard_backends) + ) if not field_centre.shape == (2,): - raise ValueError("Invalid field_centre shape %s" % - (field_centre.shape,)) + raise ValueError("Invalid field_centre shape %s" % (field_centre.shape,)) - if backend == 'astropy': - warnings.warn('astropy backend currently returns the incorrect values') - return astropy_parallactic_angles(times, - antenna_positions, - field_centre) - elif backend == 'casa': - return casa_parallactic_angles(times, - antenna_positions, - field_centre) - elif backend == 'test': - return times[:, None]*(antenna_positions.sum(axis=1)[None, :]) + if backend == "astropy": + warnings.warn("astropy backend currently returns the incorrect values") + return astropy_parallactic_angles(times, antenna_positions, field_centre) + elif backend == "casa": + return casa_parallactic_angles(times, antenna_positions, field_centre) + elif backend == "test": + return times[:, None] * (antenna_positions.sum(axis=1)[None, :]) else: raise ValueError("Invalid backend %s" % backend) diff --git a/africanus/rime/parangles_astropy.py b/africanus/rime/parangles_astropy.py index 2b0ea4e64..d91ef8a21 100644 --- a/africanus/rime/parangles_astropy.py +++ b/africanus/rime/parangles_astropy.py @@ -4,8 +4,7 @@ from africanus.util.requirements import requires_optional try: - from astropy.coordinates import (EarthLocation, SkyCoord, - AltAz, CIRS) + from astropy.coordinates import EarthLocation, SkyCoord, AltAz, CIRS from astropy.time import Time from astropy import units except ImportError as e: @@ -16,7 +15,7 @@ have_astropy_parangles = True -@requires_optional('astropy', astropy_import_error) +@requires_optional("astropy", astropy_import_error) def astropy_parallactic_angles(times, antenna_positions, field_centre): """ Computes parallactic angles per timestep for the given @@ -26,12 +25,11 @@ def astropy_parallactic_angles(times, antenna_positions, field_centre): fc = field_centre # Convert from MJD second to MJD - times = Time(times / 86400.00, format='mjd', scale='utc') + times = Time(times / 86400.00, format="mjd", scale="utc") - ap = EarthLocation.from_geocentric( - ap[:, 0], ap[:, 1], ap[:, 2], unit='m') - fc = SkyCoord(ra=fc[0], dec=fc[1], unit=units.rad, frame='fk5') - pole = SkyCoord(ra=0, dec=90, unit=units.deg, frame='fk5') + ap = EarthLocation.from_geocentric(ap[:, 0], ap[:, 1], ap[:, 2], unit="m") + fc = SkyCoord(ra=fc[0], dec=fc[1], unit=units.rad, frame="fk5") + pole = SkyCoord(ra=0, dec=90, unit=units.deg, frame="fk5") cirs_frame = CIRS(obstime=times) pole_cirs = pole.transform_to(cirs_frame) diff --git a/africanus/rime/parangles_casa.py b/africanus/rime/parangles_casa.py index 3df7fd8c8..aef808cde 100644 --- a/africanus/rime/parangles_casa.py +++ b/africanus/rime/parangles_casa.py @@ -20,9 +20,10 @@ _thread_local = threading.local() -@requires_optional('pyrap.measures', 'pyrap.quanta', casa_import_error) -def casa_parallactic_angles(times, antenna_positions, field_centre, - zenith_frame='AZEL'): +@requires_optional("pyrap.measures", "pyrap.quanta", casa_import_error) +def casa_parallactic_angles( + times, antenna_positions, field_centre, zenith_frame="AZEL" +): """ Computes parallactic angles per timestep for the given reference antenna position and field centre. @@ -35,26 +36,28 @@ def casa_parallactic_angles(times, antenna_positions, field_centre, _thread_local.meas_serv = meas_serv = pyrap.measures.measures() # Create direction measure for the zenith - zenith = meas_serv.direction(zenith_frame, '0deg', '90deg') + zenith = meas_serv.direction(zenith_frame, "0deg", "90deg") # Create position measures for each antenna - reference_positions = [meas_serv.position( - 'itrf', - *(pq.quantity(x, 'm') for x in pos)) - for pos in antenna_positions] + reference_positions = [ + meas_serv.position("itrf", *(pq.quantity(x, "m") for x in pos)) + for pos in antenna_positions + ] # Compute field centre in radians - fc_rad = meas_serv.direction('J2000', *(pq.quantity(f, 'rad') - for f in field_centre)) - - return np.asarray([ - # Set current time as the reference frame - meas_serv.do_frame(meas_serv.epoch("UTC", pq.quantity(t, "s"))) - and - [ # Set antenna position as the reference frame - meas_serv.do_frame(rp) - and - meas_serv.posangle(fc_rad, zenith).get_value("rad") - for rp in reference_positions + fc_rad = meas_serv.direction( + "J2000", *(pq.quantity(f, "rad") for f in field_centre) + ) + + return np.asarray( + [ + # Set current time as the reference frame + meas_serv.do_frame(meas_serv.epoch("UTC", pq.quantity(t, "s"))) + and [ # Set antenna position as the reference frame + meas_serv.do_frame(rp) + and meas_serv.posangle(fc_rad, zenith).get_value("rad") + for rp in reference_positions + ] + for t in times ] - for t in times]) + ) diff --git a/africanus/rime/phase.py b/africanus/rime/phase.py index 8fce52a7e..576194341 100644 --- a/africanus/rime/phase.py +++ b/africanus/rime/phase.py @@ -9,26 +9,26 @@ @njit(**JIT_OPTIONS) -def phase_delay(lm, uvw, frequency, convention='fourier'): +def phase_delay(lm, uvw, frequency, convention="fourier"): return phase_delay_impl(lm, uvw, frequency, convention=convention) -def phase_delay_impl(lm, uvw, frequency, convention='fourier'): +def phase_delay_impl(lm, uvw, frequency, convention="fourier"): raise NotImplementedError @overload(phase_delay_impl, jit_options=JIT_OPTIONS) -def nb_phase_delay(lm, uvw, frequency, convention='fourier'): +def nb_phase_delay(lm, uvw, frequency, convention="fourier"): # Bake constants in with the correct type one = lm.dtype(1.0) zero = lm.dtype(0.0) neg_two_pi_over_c = lm.dtype(minus_two_pi_over_c) out_dtype = infer_complex_dtype(lm, uvw, frequency) - def _phase_delay_impl(lm, uvw, frequency, convention='fourier'): - if convention == 'fourier': + def _phase_delay_impl(lm, uvw, frequency, convention="fourier"): + if convention == "fourier": constant = neg_two_pi_over_c - elif convention == 'casa': + elif convention == "casa": constant = -neg_two_pi_over_c else: raise ValueError("convention not in ('fourier', 'casa')") @@ -104,10 +104,12 @@ def _phase_delay_impl(lm, uvw, frequency, convention='fourier'): ------- complex_phase : $(array_type) complex of shape :code:`(source, row, chan)` - """) + """ +) try: phase_delay.__doc__ = PHASE_DELAY_DOCS.substitute( - array_type=":class:`numpy.ndarray`") + array_type=":class:`numpy.ndarray`" + ) except AttributeError: pass diff --git a/africanus/rime/predict.py b/africanus/rime/predict.py index cc7e13e13..f8a4a7929 100644 --- a/africanus/rime/predict.py +++ b/africanus/rime/predict.py @@ -4,8 +4,7 @@ import numpy as np from africanus.util.docs import DocstringTemplate -from africanus.util.numba import (is_numba_type_none, JIT_OPTIONS, - njit, overload) +from africanus.util.numba import is_numba_type_none, JIT_OPTIONS, njit, overload JONES_NOT_PRESENT = 0 @@ -51,8 +50,7 @@ def _get_jones_types(name, numba_ndarray_type, corr_1_dims, corr_2_dims): elif numba_ndarray_type.ndim == corr_2_dims: return JONES_2X2 else: - raise ValueError("%s.ndim not in (%d, %d)" % - (name, corr_1_dims, corr_2_dims)) + raise ValueError("%s.ndim not in (%d, %d)" % (name, corr_1_dims, corr_2_dims)) def jones_mul_factory(have_ddes, have_coh, jones_type, accumulate): @@ -91,6 +89,7 @@ def jones_mul_factory(have_ddes, have_coh, jones_type, accumulate): if have_coh and have_ddes: if jones_type == JONES_1_OR_2: + def jones_mul(a1j, blj, a2j, jout): for c in range(jout.shape[0]): if accumulate: @@ -99,6 +98,7 @@ def jones_mul(a1j, blj, a2j, jout): jout[c] = a1j[c] * blj[c] * np.conj(a2j[c]) elif jones_type == JONES_2X2: + def jones_mul(a1j, blj, a2j, jout): a2_xx_H = np.conj(a2j[0, 0]) a2_xy_H = np.conj(a2j[0, 1]) @@ -125,6 +125,7 @@ def jones_mul(a1j, blj, a2j, jout): raise ex elif have_ddes and not have_coh: if jones_type == JONES_1_OR_2: + def jones_mul(a1j, a2j, jout): for c in range(jout.shape[0]): if accumulate: @@ -133,6 +134,7 @@ def jones_mul(a1j, a2j, jout): jout[c] = a1j[c] * np.conj(a2j[c]) elif jones_type == JONES_2X2: + def jones_mul(a1j, a2j, jout): a2_xx_H = np.conj(a2j[0, 0]) a2_xy_H = np.conj(a2j[0, 1]) @@ -153,6 +155,7 @@ def jones_mul(a1j, a2j, jout): raise ex elif not have_ddes and have_coh: if jones_type == JONES_1_OR_2: + def jones_mul(blj, jout): for c in range(jout.shape[0]): if accumulate: @@ -163,6 +166,7 @@ def jones_mul(blj, jout): jout[c] = blj[c] elif jones_type == JONES_2X2: + def jones_mul(blj, jout): if accumulate: jout[0, 0] += blj[0, 0] @@ -187,10 +191,11 @@ def jones_mul(): def sum_coherencies_factory(have_ddes, have_coh, jones_type): - """ Factory function generating a function that sums coherencies """ + """Factory function generating a function that sums coherencies""" jones_mul = jones_mul_factory(have_ddes, have_coh, jones_type, True) if have_ddes and have_coh: + def sum_coh_fn(time, ant1, ant2, a1j, blj, a2j, tmin, cout): for s in range(a1j.shape[0]): for r in range(time.shape[0]): @@ -199,12 +204,15 @@ def sum_coh_fn(time, ant1, ant2, a1j, blj, a2j, tmin, cout): a2 = ant2[r] for f in range(a1j.shape[3]): - jones_mul(a1j[s, ti, a1, f], - blj[s, r, f], - a2j[s, ti, a2, f], - cout[r, f]) + jones_mul( + a1j[s, ti, a1, f], + blj[s, r, f], + a2j[s, ti, a2, f], + cout[r, f], + ) elif have_ddes and not have_coh: + def sum_coh_fn(time, ant1, ant2, a1j, blj, a2j, tmin, cout): for s in range(a1j.shape[0]): for r in range(time.shape[0]): @@ -213,12 +221,11 @@ def sum_coh_fn(time, ant1, ant2, a1j, blj, a2j, tmin, cout): a2 = ant2[r] for f in range(a1j.shape[3]): - jones_mul(a1j[s, ti, a1, f], - a2j[s, ti, a2, f], - cout[r, f]) + jones_mul(a1j[s, ti, a1, f], a2j[s, ti, a2, f], cout[r, f]) elif not have_ddes and have_coh: if jones_type == JONES_2X2: + def sum_coh_fn(time, ant1, ant2, a1j, blj, a2j, tmin, cout): for s in range(blj.shape[0]): for r in range(blj.shape[1]): @@ -227,6 +234,7 @@ def sum_coh_fn(time, ant1, ant2, a1j, blj, a2j, tmin, cout): for c2 in range(blj.shape[4]): cout[r, f, c1, c2] += blj[s, r, f, c1, c2] else: + def sum_coh_fn(time, ant1, ant2, a1j, blj, a2j, tmin, cout): # TODO(sjperkins): Without this, these loops # produce an incorrect value @@ -245,48 +253,82 @@ def sum_coh_fn(time, ant1, ant2, a1j, blj, a2j, tmin, cout): def output_factory(have_ddes, have_coh, have_dies, have_base_vis, out_dtype): - """ Factory function generating a function that creates function output """ + """Factory function generating a function that creates function output""" if have_ddes: - def output(time_index, dde1_jones, source_coh, dde2_jones, - die1_jones, base_vis, die2_jones): + + def output( + time_index, + dde1_jones, + source_coh, + dde2_jones, + die1_jones, + base_vis, + die2_jones, + ): row = time_index.shape[0] chan = dde1_jones.shape[3] corrs = dde1_jones.shape[4:] return np.zeros((row, chan) + corrs, dtype=out_dtype) elif have_coh: - def output(time_index, dde1_jones, source_coh, dde2_jones, - die1_jones, base_vis, die2_jones): + + def output( + time_index, + dde1_jones, + source_coh, + dde2_jones, + die1_jones, + base_vis, + die2_jones, + ): row = time_index.shape[0] chan = source_coh.shape[2] corrs = source_coh.shape[3:] return np.zeros((row, chan) + corrs, dtype=out_dtype) elif have_dies: - def output(time_index, dde1_jones, source_coh, dde2_jones, - die1_jones, base_vis, die2_jones): + + def output( + time_index, + dde1_jones, + source_coh, + dde2_jones, + die1_jones, + base_vis, + die2_jones, + ): row = time_index.shape[0] chan = die1_jones.shape[2] corrs = die1_jones.shape[3:] return np.zeros((row, chan) + corrs, dtype=out_dtype) elif have_base_vis: - def output(time_index, dde1_jones, source_coh, dde2_jones, - die1_jones, base_vis, die2_jones): + + def output( + time_index, + dde1_jones, + source_coh, + dde2_jones, + die1_jones, + base_vis, + die2_jones, + ): row = time_index.shape[0] chan = base_vis.shape[1] corrs = base_vis.shape[2:] return np.zeros((row, chan) + corrs, dtype=out_dtype) else: - raise ValueError("Insufficient inputs were supplied " - "for determining the output shape") + raise ValueError( + "Insufficient inputs were supplied " "for determining the output shape" + ) # TODO(sjperkins) # perhaps inline="always" on resolution of # https://github.com/numba/numba/issues/4691 - return njit(nogil=True, inline='never')(output) + return njit(nogil=True, inline="never")(output) def add_coh_factory(have_bvis): if have_bvis: + def add_coh(base_vis, add_coh_cout): add_coh_cout += base_vis else: @@ -307,9 +349,8 @@ def apply_dies_factory(have_dies, jones_type): jones_mul = jones_mul_factory(have_dies, True, jones_type, False) if have_dies: - def apply_dies(time, ant1, ant2, - die1_jones, die2_jones, - tmin, dies_out): + + def apply_dies(time, ant1, ant2, die1_jones, die2_jones, tmin, dies_out): # Iterate over rows for r in range(time.shape[0]): ti = time[r] - tmin @@ -318,13 +359,15 @@ def apply_dies(time, ant1, ant2, # Iterate over channels for c in range(dies_out.shape[1]): - jones_mul(die1_jones[ti, a1, c], dies_out[r, c], - die2_jones[ti, a2, c], dies_out[r, c]) + jones_mul( + die1_jones[ti, a1, c], + dies_out[r, c], + die2_jones[ti, a2, c], + dies_out[r, c], + ) else: # noop - def apply_dies(time, ant1, ant2, - die1_jones, die2_jones, - tmin, dies_out): + def apply_dies(time, ant1, ant2, die1_jones, die2_jones, tmin, dies_out): pass return njit(nogil=True, inline="always")(apply_dies) @@ -334,11 +377,18 @@ def _default_none_check(arg): return arg is not None -def predict_checks(time_index, antenna1, antenna2, - dde1_jones, source_coh, dde2_jones, - die1_jones, base_vis, die2_jones, - none_check=_default_none_check): - +def predict_checks( + time_index, + antenna1, + antenna2, + dde1_jones, + source_coh, + dde2_jones, + die1_jones, + base_vis, + die2_jones, + none_check=_default_none_check, +): have_ddes1 = none_check(dde1_jones) have_coh = none_check(source_coh) have_ddes2 = none_check(dde2_jones) @@ -351,12 +401,10 @@ def predict_checks(time_index, antenna1, antenna2, assert antenna2.ndim == 1 if have_ddes1 ^ have_ddes2: - raise ValueError("Both dde1_jones and dde2_jones " - "must be present or absent") + raise ValueError("Both dde1_jones and dde2_jones " "must be present or absent") if have_dies1 ^ have_dies2: - raise ValueError("Both die1_jones and die2_jones " - "must be present or absent") + raise ValueError("Both die1_jones and die2_jones " "must be present or absent") have_ddes = have_ddes1 and have_ddes2 have_dies = have_dies1 and have_dies2 @@ -389,7 +437,7 @@ def predict_checks(time_index, antenna1, antenna2, if have_ddes: ndim = dde1_jones.ndim - expected_sizes.append([ndim, ndim - 1, ndim - 2, ndim - 1]), + (expected_sizes.append([ndim, ndim - 1, ndim - 2, ndim - 1]),) if have_coh: ndim = source_coh.ndim @@ -404,54 +452,96 @@ def predict_checks(time_index, antenna1, antenna2, expected_sizes.append([ndim + 2, ndim + 1, ndim, ndim + 1]) if not all(expected_sizes[0] == s for s in expected_sizes[1:]): - raise ValueError("One of the following pre-conditions is broken " - "(missing values are ignored):\n" - "dde_jones{1,2}.ndim == source_coh.ndim + 1\n" - "dde_jones{1,2}.ndim == base_vis.ndim + 2\n" - "dde_jones{1,2}.ndim == die_jones{1,2}.ndim + 1") + raise ValueError( + "One of the following pre-conditions is broken " + "(missing values are ignored):\n" + "dde_jones{1,2}.ndim == source_coh.ndim + 1\n" + "dde_jones{1,2}.ndim == base_vis.ndim + 2\n" + "dde_jones{1,2}.ndim == die_jones{1,2}.ndim + 1" + ) - return (have_ddes1, have_coh, have_ddes2, - have_dies1, have_bvis, have_dies2) + return (have_ddes1, have_coh, have_ddes2, have_dies1, have_bvis, have_dies2) @njit(**JIT_OPTIONS) -def predict_vis(time_index, antenna1, antenna2, - dde1_jones=None, source_coh=None, dde2_jones=None, - die1_jones=None, base_vis=None, die2_jones=None): - return predict_vis_impl(time_index, antenna1, antenna2, - dde1_jones=dde1_jones, - source_coh=source_coh, - dde2_jones=dde2_jones, - die1_jones=die1_jones, - base_vis=base_vis, - die2_jones=die2_jones) - - -def predict_vis_impl(time_index, antenna1, antenna2, - dde1_jones=None, source_coh=None, dde2_jones=None, - die1_jones=None, base_vis=None, die2_jones=None): +def predict_vis( + time_index, + antenna1, + antenna2, + dde1_jones=None, + source_coh=None, + dde2_jones=None, + die1_jones=None, + base_vis=None, + die2_jones=None, +): + return predict_vis_impl( + time_index, + antenna1, + antenna2, + dde1_jones=dde1_jones, + source_coh=source_coh, + dde2_jones=dde2_jones, + die1_jones=die1_jones, + base_vis=base_vis, + die2_jones=die2_jones, + ) + + +def predict_vis_impl( + time_index, + antenna1, + antenna2, + dde1_jones=None, + source_coh=None, + dde2_jones=None, + die1_jones=None, + base_vis=None, + die2_jones=None, +): raise NotImplementedError @overload(predict_vis_impl, jit_options=JIT_OPTIONS) -def nb_predict_vis(time_index, antenna1, antenna2, - dde1_jones=None, source_coh=None, dde2_jones=None, - die1_jones=None, base_vis=None, die2_jones=None): - - tup = predict_checks(time_index, antenna1, antenna2, - dde1_jones, source_coh, dde2_jones, - die1_jones, base_vis, die2_jones, - lambda x: not is_numba_type_none(x)) +def nb_predict_vis( + time_index, + antenna1, + antenna2, + dde1_jones=None, + source_coh=None, + dde2_jones=None, + die1_jones=None, + base_vis=None, + die2_jones=None, +): + tup = predict_checks( + time_index, + antenna1, + antenna2, + dde1_jones, + source_coh, + dde2_jones, + die1_jones, + base_vis, + die2_jones, + lambda x: not is_numba_type_none(x), + ) (have_ddes1, have_coh, have_ddes2, have_dies1, have_bvis, have_dies2) = tup # Infer the output dtype - dtype_arrays = (dde1_jones, source_coh, dde2_jones, - die1_jones, base_vis, die2_jones) - - out_dtype = np.result_type(*(np.dtype(a.dtype.name) - for a in dtype_arrays - if not is_numba_type_none(a))) + dtype_arrays = ( + dde1_jones, + source_coh, + dde2_jones, + die1_jones, + base_vis, + die2_jones, + ) + + out_dtype = np.result_type( + *(np.dtype(a.dtype.name) for a in dtype_arrays if not is_numba_type_none(a)) + ) jones_types = [ _get_jones_types("dde1_jones", dde1_jones, 5, 6), @@ -459,7 +549,8 @@ def nb_predict_vis(time_index, antenna1, antenna2, _get_jones_types("dde2_jones", dde2_jones, 5, 6), _get_jones_types("die1_jones", die1_jones, 4, 5), _get_jones_types("base_vis", base_vis, 3, 4), - _get_jones_types("die2_jones", die2_jones, 4, 5)] + _get_jones_types("die2_jones", die2_jones, 4, 5), + ] ptypes = [t for t in jones_types if t != JONES_NOT_PRESENT] @@ -475,35 +566,53 @@ def nb_predict_vis(time_index, antenna1, antenna2, have_dies = have_dies1 and have_dies2 # Create functions that we will use inside our predict function - out_fn = output_factory(have_ddes, have_coh, - have_dies, have_bvis, out_dtype) + out_fn = output_factory(have_ddes, have_coh, have_dies, have_bvis, out_dtype) sum_coh_fn = sum_coherencies_factory(have_ddes, have_coh, jones_type) apply_dies_fn = apply_dies_factory(have_dies, jones_type) add_coh_fn = add_coh_factory(have_bvis) - def _predict_vis_fn(time_index, antenna1, antenna2, - dde1_jones=None, source_coh=None, dde2_jones=None, - die1_jones=None, base_vis=None, die2_jones=None): - + def _predict_vis_fn( + time_index, + antenna1, + antenna2, + dde1_jones=None, + source_coh=None, + dde2_jones=None, + die1_jones=None, + base_vis=None, + die2_jones=None, + ): # Get the output shape - out = out_fn(time_index, dde1_jones, source_coh, dde2_jones, - die1_jones, base_vis, die2_jones) + out = out_fn( + time_index, + dde1_jones, + source_coh, + dde2_jones, + die1_jones, + base_vis, + die2_jones, + ) # Minimum time index, used to normalise within function tmin = time_index.min() # Sum coherencies if any - sum_coh_fn(time_index, antenna1, antenna2, - dde1_jones, source_coh, dde2_jones, - tmin, out) + sum_coh_fn( + time_index, + antenna1, + antenna2, + dde1_jones, + source_coh, + dde2_jones, + tmin, + out, + ) # Add base visibilities to the output, if any add_coh_fn(base_vis, out) # Apply direction independent effects, if any - apply_dies_fn(time_index, antenna1, antenna2, - die1_jones, die2_jones, - tmin, out) + apply_dies_fn(time_index, antenna1, antenna2, die1_jones, die2_jones, tmin, out) return out @@ -511,32 +620,37 @@ def _predict_vis_fn(time_index, antenna1, antenna2, @njit(**JIT_OPTIONS) -def apply_gains(time_index, antenna1, antenna2, - die1_jones, corrupted_vis, die2_jones): - return apply_gains_impl(time_index, antenna1, antenna2, - die1_jones, corrupted_vis, die2_jones) +def apply_gains(time_index, antenna1, antenna2, die1_jones, corrupted_vis, die2_jones): + return apply_gains_impl( + time_index, antenna1, antenna2, die1_jones, corrupted_vis, die2_jones + ) -def apply_gains_impl(time_index, antenna1, antenna2, - die1_jones, corrupted_vis, die2_jones): +def apply_gains_impl( + time_index, antenna1, antenna2, die1_jones, corrupted_vis, die2_jones +): raise NotImplementedError @overload(apply_gains_impl, jit_options=JIT_OPTIONS) -def nb_apply_gains(time_index, antenna1, antenna2, - die1_jones, corrupted_vis, die2_jones): - - def impl(time_index, antenna1, antenna2, - die1_jones, corrupted_vis, die2_jones): - return predict_vis(time_index, antenna1, antenna2, - die1_jones=die1_jones, - base_vis=corrupted_vis, - die2_jones=die2_jones) +def nb_apply_gains( + time_index, antenna1, antenna2, die1_jones, corrupted_vis, die2_jones +): + def impl(time_index, antenna1, antenna2, die1_jones, corrupted_vis, die2_jones): + return predict_vis( + time_index, + antenna1, + antenna2, + die1_jones=die1_jones, + base_vis=corrupted_vis, + die2_jones=die2_jones, + ) return impl -PREDICT_DOCS = DocstringTemplate(r""" +PREDICT_DOCS = DocstringTemplate( + r""" Multiply Jones terms together to form model visibilities according to the following formula: @@ -622,21 +736,23 @@ def impl(time_index, antenna1, antenna2, ------- visibilities : $(array_type) Model visibilities of shape :code:`(row,chan,corr_1,corr_2)` -""") +""" +) try: predict_vis.__doc__ = PREDICT_DOCS.substitute( - array_type=":class:`numpy.ndarray`", - get_time_index=":code:`np.unique(time, " - "return_inverse=True)[1]`", - extra_args="", - extra_notes="") + array_type=":class:`numpy.ndarray`", + get_time_index=":code:`np.unique(time, " "return_inverse=True)[1]`", + extra_args="", + extra_notes="", + ) except AttributeError: pass -APPLY_GAINS_DOCS = DocstringTemplate(r""" +APPLY_GAINS_DOCS = DocstringTemplate( + r""" Apply gains to corrupted visibilities in order to recover the true visibilities. @@ -676,11 +792,13 @@ def impl(time_index, antenna1, antenna2, ------- true_vis : $(array_type) True visibilities of shape :code:`(row,chan,corr_1,corr_2)` -""") +""" +) try: apply_gains.__doc__ = APPLY_GAINS_DOCS.substitute( - array_type=":class:`numpy.ndarray`", - wrapper_func=":func:`~africanus.rime.predict_vis`") + array_type=":class:`numpy.ndarray`", + wrapper_func=":func:`~africanus.rime.predict_vis`", + ) except AttributeError: pass diff --git a/africanus/rime/tests/conftest.py b/africanus/rime/tests/conftest.py index 576635e7d..5ad193fc5 100644 --- a/africanus/rime/tests/conftest.py +++ b/africanus/rime/tests/conftest.py @@ -10,20 +10,23 @@ @pytest.fixture def wsrt_ants(): - """ Westerbork antenna positions """ - return np.array([ - [3828763.10544699, 442449.10566454, 5064923.00777], - [3828746.54957258, 442592.13950824, 5064923.00792], - [3828729.99081359, 442735.17696417, 5064923.00829], - [3828713.43109885, 442878.2118934, 5064923.00436], - [3828696.86994428, 443021.24917264, 5064923.00397], - [3828680.31391933, 443164.28596862, 5064923.00035], - [3828663.75159173, 443307.32138056, 5064923.00204], - [3828647.19342757, 443450.35604638, 5064923.0023], - [3828630.63486201, 443593.39226634, 5064922.99755], - [3828614.07606798, 443736.42941621, 5064923.], - [3828609.94224429, 443772.19450029, 5064922.99868], - [3828601.66208572, 443843.71178407, 5064922.99963], - [3828460.92418735, 445059.52053929, 5064922.99071], - [3828452.64716351, 445131.03744105, 5064922.98793]], - dtype=np.float64) + """Westerbork antenna positions""" + return np.array( + [ + [3828763.10544699, 442449.10566454, 5064923.00777], + [3828746.54957258, 442592.13950824, 5064923.00792], + [3828729.99081359, 442735.17696417, 5064923.00829], + [3828713.43109885, 442878.2118934, 5064923.00436], + [3828696.86994428, 443021.24917264, 5064923.00397], + [3828680.31391933, 443164.28596862, 5064923.00035], + [3828663.75159173, 443307.32138056, 5064923.00204], + [3828647.19342757, 443450.35604638, 5064923.0023], + [3828630.63486201, 443593.39226634, 5064922.99755], + [3828614.07606798, 443736.42941621, 5064923.0], + [3828609.94224429, 443772.19450029, 5064922.99868], + [3828601.66208572, 443843.71178407, 5064922.99963], + [3828460.92418735, 445059.52053929, 5064922.99071], + [3828452.64716351, 445131.03744105, 5064922.98793], + ], + dtype=np.float64, + ) diff --git a/africanus/rime/tests/test_fast_beams.py b/africanus/rime/tests/test_fast_beams.py index 4645bebf0..9de7450b7 100644 --- a/africanus/rime/tests/test_fast_beams.py +++ b/africanus/rime/tests/test_fast_beams.py @@ -11,18 +11,18 @@ def rf(*a, **kw): def rc(*a, **kw): - return rf(*a, **kw) + 1j*rf(*a, **kw) + return rf(*a, **kw) + 1j * rf(*a, **kw) @pytest.fixture def beam_freq_map(): - return np.array([.5, .56, .7, .91, 1.0]) + return np.array([0.5, 0.56, 0.7, 0.91, 1.0]) @pytest.fixture def beam_freq_map_montblanc(): - """ Montblanc doesn't handle values outside the cube in the same way """ - return np.array([.4, .56, .7, .91, 1.1]) + """Montblanc doesn't handle values outside the cube in the same way""" + return np.array([0.4, 0.56, 0.7, 0.91, 1.1]) @pytest.fixture @@ -37,15 +37,15 @@ def freqs(): 4. One value (1.1) above the beam freq range """ - return np.array([.4, .5, .6, .7, .8, .9, 1.0, 1.1]) + return np.array([0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1]) def test_fast_beam_small(): - """ Small beam test, interpolation of one soure at [0.1, 0.1] """ + """Small beam test, interpolation of one soure at [0.1, 0.1]""" np.random.seed(42) # One frequency, to the lower side of the beam frequency map - freq = np.asarray([.3]) + freq = np.asarray([0.3]) beam_freq_map = np.asarray([0.0, 1.0]) beam_lw = 2 @@ -66,9 +66,16 @@ def test_fast_beam_small(): beam = rc((beam_lw, beam_mh, beam_nud, 1)) - ddes = beam_cube_dde(beam, beam_extents, beam_freq_map, - lm, parangles, point_errors, antenna_scaling, - freq) + ddes = beam_cube_dde( + beam, + beam_extents, + beam_freq_map, + lm, + parangles, + point_errors, + antenna_scaling, + freq, + ) # Pen and paper the expected value lower_l = beam_extents[0, 0] @@ -82,15 +89,15 @@ def test_fast_beam_small(): abs_sum = np.zeros((1,), dtype=beam.real.dtype) # Weights of the sample at each grid point - wt0 = (1 - ld)*(1 - md)*(1 - chd) - wt1 = ld*(1 - md)*(1 - chd) - wt2 = (1 - ld)*md*(1 - chd) - wt3 = ld*md*(1 - chd) + wt0 = (1 - ld) * (1 - md) * (1 - chd) + wt1 = ld * (1 - md) * (1 - chd) + wt2 = (1 - ld) * md * (1 - chd) + wt3 = ld * md * (1 - chd) - wt4 = (1 - ld)*(1 - md)*chd - wt5 = ld*(1 - md)*chd - wt6 = (1 - ld)*md*chd - wt7 = ld*md*chd + wt4 = (1 - ld) * (1 - md) * chd + wt5 = ld * (1 - md) * chd + wt6 = (1 - ld) * md * chd + wt7 = ld * md * chd # Sum lower channel correlations corr_sum[:] += wt0 * beam[0, 0, 0, 0] @@ -116,7 +123,7 @@ def test_fast_beam_small(): corr_sum *= abs_sum / np.abs(corr_sum) - assert_array_almost_equal([[[[[0.470255+0.4786j]]]]], ddes) + assert_array_almost_equal([[[[[0.470255 + 0.4786j]]]]], ddes) assert_array_almost_equal(ddes.squeeze(), corr_sum.squeeze()) @@ -129,8 +136,7 @@ def test_grid_interpolate(freqs, beam_freq_map): # Frequencies (first -- 0.8 and last -- 1.1) # outside the beam result in scaling, - assert_array_almost_equal(freq_scale, - [0.8, 1., 1., 1., 1., 1., 1., 1.1]) + assert_array_almost_equal(freq_scale, [0.8, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.1]) # Frequencies outside the beam are snapped to 0 if below # and beam_nud - 2 if above. # Frequencies on the edges are similarly snapped @@ -139,14 +145,7 @@ def test_grid_interpolate(freqs, beam_freq_map): # frequency is snapped to the first and last beam freq value # if outside (first and last values) # Third frequency value is also exactly on a grid point - exp_diff = [1., - 1., - 0.71428571, - 1., - 0.52380952, - 0.04761905, - 0., - 0.] + exp_diff = [1.0, 1.0, 0.71428571, 1.0, 0.52380952, 0.04761905, 0.0, 0.0] assert_array_almost_equal(fgrid_diff, exp_diff) @@ -176,35 +175,48 @@ def test_dask_fast_beams(freqs, beam_freq_map): antenna_scaling = np.random.random(size=(ants, chans, 2)) # Make random values more representative - lm = (lm - 0.5)*0.0001 # Shift lm to around the centre - parangles *= np.pi / 12 # parangles to 15 degrees max - point_errors *= 0.001 # Pointing errors - antenna_scaling *= 0.0001 # Antenna scaling + lm = (lm - 0.5) * 0.0001 # Shift lm to around the centre + parangles *= np.pi / 12 # parangles to 15 degrees max + point_errors *= 0.001 # Pointing errors + antenna_scaling *= 0.0001 # Antenna scaling # Beam variables beam = rc((beam_lw, beam_mh, beam_nud, 2, 2)) beam_lm_extents = np.asarray([[-1.0, 1.0], [-1.0, 1.0]]) # Compute numba ddes - ddes = beam_cube_dde(beam, beam_lm_extents, beam_freq_map, - lm, parangles, point_errors, antenna_scaling, - freqs) + ddes = beam_cube_dde( + beam, + beam_lm_extents, + beam_freq_map, + lm, + parangles, + point_errors, + antenna_scaling, + freqs, + ) # Create dask arrays da_beam = da.from_array(beam, chunks=beam.shape) da_beam_freq_map = da.from_array(beam_freq_map, chunks=beam_freq_map.shape) da_lm = da.from_array(lm, chunks=(src_c, 2)) da_parangles = da.from_array(parangles, chunks=(time_c, ants_c)) - da_point_errors = da.from_array(point_errors, - chunks=(time_c, ants_c, chan_c, 2)) + da_point_errors = da.from_array(point_errors, chunks=(time_c, ants_c, chan_c, 2)) da_ant_scale = da.from_array(antenna_scaling, chunks=(ants_c, chan_c, 2)) da_extents = da.from_array(beam_lm_extents, chunks=beam_lm_extents.shape) da_freqs = da.from_array(freqs, chunks=chan_c) # dask ddes - da_ddes = dask_beam_cube_dde(da_beam, da_extents, da_beam_freq_map, - da_lm, da_parangles, da_point_errors, - da_ant_scale, da_freqs) + da_ddes = dask_beam_cube_dde( + da_beam, + da_extents, + da_beam_freq_map, + da_lm, + da_parangles, + da_point_errors, + da_ant_scale, + da_freqs, + ) # Should be strictly equal assert_array_equal(da_ddes.compute(), ddes) @@ -212,7 +224,7 @@ def test_dask_fast_beams(freqs, beam_freq_map): @pytest.mark.parametrize("dtype", [np.float32, np.float64]) def test_fast_beams_vs_montblanc(freqs, beam_freq_map_montblanc, dtype): - """ Test that the numba beam matches montblanc implementation """ + """Test that the numba beam matches montblanc implementation""" mb_tf_mod = pytest.importorskip("montblanc.impl.rime.tensorflow") tf = pytest.importorskip("tensorflow") @@ -237,41 +249,69 @@ def test_fast_beams_vs_montblanc(freqs, beam_freq_map_montblanc, dtype): antenna_scaling = np.random.random(size=(ants, chans, 2)).astype(dtype) # Make random values more representative - lm = (lm - 0.5)*0.0001 # Shift lm to around the centre - parangles *= np.pi / 12 # parangles to 15 degrees max - point_errors *= 0.001 # Pointing errors - antenna_scaling *= 0.0001 # Antenna scaling + lm = (lm - 0.5) * 0.0001 # Shift lm to around the centre + parangles *= np.pi / 12 # parangles to 15 degrees max + point_errors *= 0.001 # Pointing errors + antenna_scaling *= 0.0001 # Antenna scaling # Beam variables beam = rc((beam_lw, beam_mh, beam_nud, 2, 2)).astype(ctype) beam_lm_extents = np.asarray([[-1.0, 1.0], [-1.0, 1.0]]).astype(dtype) - ddes = beam_cube_dde(beam, beam_lm_extents, beam_freq_map, - lm, parangles, point_errors, antenna_scaling, - freqs) + ddes = beam_cube_dde( + beam, + beam_lm_extents, + beam_freq_map, + lm, + parangles, + point_errors, + antenna_scaling, + freqs, + ) assert ddes.shape == (src, time, ants, chans, 2, 2) rime = mb_tf_mod.load_tf_lib() # Montblanc beam extent format is different - mb_beam_extents = np.array([beam_lm_extents[0, 0], - beam_lm_extents[1, 0], - beam_freq_map[0], - beam_lm_extents[0, 1], - beam_lm_extents[1, 1], - beam_freq_map[-1]], dtype=dtype) + mb_beam_extents = np.array( + [ + beam_lm_extents[0, 0], + beam_lm_extents[1, 0], + beam_freq_map[0], + beam_lm_extents[0, 1], + beam_lm_extents[1, 1], + beam_freq_map[-1], + ], + dtype=dtype, + ) # Montblanc wants flattened correlations mb_beam = beam.reshape(beam.shape[:3] + (-1,)) - np_args = [lm, freqs, point_errors, antenna_scaling, - np.sin(parangles), np.cos(parangles), - mb_beam_extents, beam_freq_map, mb_beam] + np_args = [ + lm, + freqs, + point_errors, + antenna_scaling, + np.sin(parangles), + np.cos(parangles), + mb_beam_extents, + beam_freq_map, + mb_beam, + ] # Argument string name list - arg_names = ["lm", "frequency", "point_errors", "antenna_scaling", - "parallactic_angle_sin", "parallactic_angle_cos", - "beam_extents", "beam_freq_map", "e_beam"] + arg_names = [ + "lm", + "frequency", + "point_errors", + "antenna_scaling", + "parallactic_angle_sin", + "parallactic_angle_cos", + "beam_extents", + "beam_freq_map", + "e_beam", + ] # Constructor tensorflow variables tf_args = [tf.Variable(v, name=n) for v, n in zip(np_args, arg_names)] diff --git a/africanus/rime/tests/test_parangles.py b/africanus/rime/tests/test_parangles.py index c78080f12..43d3c75b1 100644 --- a/africanus/rime/tests/test_parangles.py +++ b/africanus/rime/tests/test_parangles.py @@ -4,8 +4,8 @@ from africanus.rime.parangles import _discovered_backends -no_casa = 'casa' not in _discovered_backends -no_astropy = 'astropy' not in _discovered_backends +no_casa = "casa" not in _discovered_backends +no_astropy = "astropy" not in _discovered_backends def _julian_day(year, month, day): @@ -27,9 +27,14 @@ def _julian_day(year, month, day): # Formula below from # http://scienceworld.wolfram.com/astronomy/JulianDate.html # Also agrees with https://gist.github.com/jiffyclub/1294443 - return (367*year - int(7*(year + int((month+9)/12))/4) - - int((3*(int(year + (month - 9)/7)/100)+1)/4) - + int(275*month/9) + day + 1721028.5) + return ( + 367 * year + - int(7 * (year + int((month + 9) / 12)) / 4) + - int((3 * (int(year + (month - 9) / 7) / 100) + 1) / 4) + + int(275 * month / 9) + + day + + 1721028.5 + ) def _modified_julian_date(year, month, day): @@ -58,25 +63,31 @@ def _observation_endpoints(year, month, day, hour_duration): in Modified Julian Date seconds """ start = _modified_julian_date(year, month, day) - end = start + hour_duration / 24. + end = start + hour_duration / 24.0 # Convert to seconds - start *= 86400. - end *= 86400. + start *= 86400.0 + end *= 86400.0 return (start, end) @pytest.mark.flaky(min_passes=1, max_runs=3) -@pytest.mark.parametrize('backend', [ - 'test', - pytest.param('casa', marks=pytest.mark.skipif( - no_casa, - reason='python-casascore not installed')), - pytest.param('astropy', marks=pytest.mark.skipif( - no_astropy, - reason="astropy not installed"))]) -@pytest.mark.parametrize('observation', [(2018, 1, 1, 4)]) +@pytest.mark.parametrize( + "backend", + [ + "test", + pytest.param( + "casa", + marks=pytest.mark.skipif(no_casa, reason="python-casascore not installed"), + ), + pytest.param( + "astropy", + marks=pytest.mark.skipif(no_astropy, reason="astropy not installed"), + ), + ], +) +@pytest.mark.parametrize("observation", [(2018, 1, 1, 4)]) def test_parallactic_angles(observation, wsrt_ants, backend): import numpy as np from africanus.rime import parallactic_angles @@ -91,13 +102,14 @@ def test_parallactic_angles(observation, wsrt_ants, backend): @pytest.mark.flaky(min_passes=1, max_runs=3) -@pytest.mark.skipif(no_casa or no_astropy, - reason="Neither python-casacore or astropy installed") +@pytest.mark.skipif( + no_casa or no_astropy, reason="Neither python-casacore or astropy installed" +) # Parametrize on observation length and error tolerance -@pytest.mark.parametrize('obs_and_tol', [ - ((2018, 1, 1, 4), "10s"), - ((2018, 2, 20, 8), "10s"), - ((2018, 11, 2, 4), "10s")]) +@pytest.mark.parametrize( + "obs_and_tol", + [((2018, 1, 1, 4), "10s"), ((2018, 2, 20, 8), "10s"), ((2018, 11, 2, 4), "10s")], +) def test_compare_astropy_and_casa(obs_and_tol, wsrt_ants): """ Compare astropy and python-casacore parallactic angle implementations. @@ -115,32 +127,38 @@ def test_compare_astropy_and_casa(obs_and_tol, wsrt_ants): time = np.linspace(start, end, 5) ant = wsrt_ants[:4, :] - fc = np.array([0., 1.04719755], dtype=np.float64) + fc = np.array([0.0, 1.04719755], dtype=np.float64) astro_pa = astropy_parallactic_angles(time, ant, fc) - casa_pa = casa_parallactic_angles(time, ant, fc, zenith_frame='AZELGEO') + casa_pa = casa_parallactic_angles(time, ant, fc, zenith_frame="AZELGEO") # Convert to angle degrees - astro_pa = Angle(astro_pa, unit=units.deg).wrap_at(180*units.deg) - casa_pa = Angle(casa_pa*units.rad, unit=units.deg).wrap_at(180*units.deg) + astro_pa = Angle(astro_pa, unit=units.deg).wrap_at(180 * units.deg) + casa_pa = Angle(casa_pa * units.rad, unit=units.deg).wrap_at(180 * units.deg) # Difference in degrees, wrapped at 180 - diff = np.abs((astro_pa - casa_pa).wrap_at(180*units.deg)) + diff = np.abs((astro_pa - casa_pa).wrap_at(180 * units.deg)) assert np.all(np.abs(diff) < Angle(rtol)) @pytest.mark.flaky(min_passes=1, max_runs=3) -@pytest.mark.parametrize('backend', [ - 'test', - pytest.param('casa', marks=pytest.mark.skipif( - no_casa, - reason='python-casascore not installed')), - pytest.param('astropy', marks=pytest.mark.skipif( - no_astropy, - reason="astropy not installed"))]) -@pytest.mark.parametrize('observation', [(2018, 1, 1, 4)]) +@pytest.mark.parametrize( + "backend", + [ + "test", + pytest.param( + "casa", + marks=pytest.mark.skipif(no_casa, reason="python-casascore not installed"), + ), + pytest.param( + "astropy", + marks=pytest.mark.skipif(no_astropy, reason="astropy not installed"), + ), + ], +) +@pytest.mark.parametrize("observation", [(2018, 1, 1, 4)]) def test_dask_parallactic_angles(observation, wsrt_ants, backend): - da = pytest.importorskip('dask.array') + da = pytest.importorskip("dask.array") from africanus.rime import parallactic_angles as np_parangle from africanus.rime.dask import parallactic_angles as da_parangle diff --git a/africanus/rime/tests/test_predict.py b/africanus/rime/tests/test_predict.py index 041715162..b1e919a4a 100644 --- a/africanus/rime/tests/test_predict.py +++ b/africanus/rime/tests/test_predict.py @@ -14,54 +14,65 @@ def rf(*a, **kw): def rc(*a, **kw): - return rf(*a, **kw) + 1j*rf(*a, **kw) - - -chunk_parametrization = pytest.mark.parametrize("chunks", [ - { - 'source': (2, 3, 4, 2, 2, 2, 2, 2, 2), - 'time': (2, 1, 1), - 'rows': (4, 4, 2), - 'antenna': (4,), - 'channels': (3, 2), - }]) + return rf(*a, **kw) + 1j * rf(*a, **kw) + + +chunk_parametrization = pytest.mark.parametrize( + "chunks", + [ + { + "source": (2, 3, 4, 2, 2, 2, 2, 2, 2), + "time": (2, 1, 1), + "rows": (4, 4, 2), + "antenna": (4,), + "channels": (3, 2), + } + ], +) corr_shape_parametrization = pytest.mark.parametrize( - 'corr_shape, idm, einsum_sig1, einsum_sig2', [ + "corr_shape, idm, einsum_sig1, einsum_sig2", + [ ((1,), (1,), "srci,srci,srci->rci", "rci,rci,rci->rci"), ((2,), (1, 1), "srci,srci,srci->rci", "rci,rci,rci->rci"), - ((2, 2), ((1, 0), (0, 1)), - "srcij,srcjk,srclk->rcil", "rcij,rcjk,rclk->rcil") - ]) - - -dde_presence_parametrization = pytest.mark.parametrize('a1j,blj,a2j', [ - [True, True, True], - [True, False, True], - [False, True, False], -]) - -die_presence_parametrization = pytest.mark.parametrize('g1j,bvis,g2j', [ - [True, True, True], - [True, False, True], - [False, True, False], -]) + ((2, 2), ((1, 0), (0, 1)), "srcij,srcjk,srclk->rcil", "rcij,rcjk,rclk->rcil"), + ], +) + + +dde_presence_parametrization = pytest.mark.parametrize( + "a1j,blj,a2j", + [ + [True, True, True], + [True, False, True], + [False, True, False], + ], +) + +die_presence_parametrization = pytest.mark.parametrize( + "g1j,bvis,g2j", + [ + [True, True, True], + [True, False, True], + [False, True, False], + ], +) @corr_shape_parametrization @dde_presence_parametrization @die_presence_parametrization @chunk_parametrization -def test_predict_vis(corr_shape, idm, einsum_sig1, einsum_sig2, - a1j, blj, a2j, g1j, bvis, g2j, - chunks): +def test_predict_vis( + corr_shape, idm, einsum_sig1, einsum_sig2, a1j, blj, a2j, g1j, bvis, g2j, chunks +): from africanus.rime.predict import predict_vis - s = sum(chunks['source']) - t = sum(chunks['time']) - a = sum(chunks['antenna']) - c = sum(chunks['channels']) - r = sum(chunks['rows']) + s = sum(chunks["source"]) + t = sum(chunks["time"]) + a = sum(chunks["antenna"]) + c = sum(chunks["channels"]) + r = sum(chunks["rows"]) a1_jones = rc((s, t, a, c) + corr_shape) bl_jones = rc((s, r, c) + corr_shape) @@ -77,13 +88,17 @@ def test_predict_vis(corr_shape, idm, einsum_sig1, einsum_sig2, assert ant1.size == r - model_vis = predict_vis(time_idx, ant1, ant2, - a1_jones if a1j else None, - bl_jones if blj else None, - a2_jones if a2j else None, - g1_jones if g1j else None, - base_vis if bvis else None, - g2_jones if g2j else None) + model_vis = predict_vis( + time_idx, + ant1, + ant2, + a1_jones if a1j else None, + bl_jones if blj else None, + a2_jones if a2j else None, + g1_jones if g1j else None, + base_vis if bvis else None, + g2_jones if g2j else None, + ) assert model_vis.shape == (r, c) + corr_shape @@ -115,11 +130,10 @@ def _id(array): @dde_presence_parametrization @die_presence_parametrization @chunk_parametrization -def test_dask_predict_vis(corr_shape, idm, einsum_sig1, einsum_sig2, - a1j, blj, a2j, g1j, bvis, g2j, - chunks): - - da = pytest.importorskip('dask.array') +def test_dask_predict_vis( + corr_shape, idm, einsum_sig1, einsum_sig2, a1j, blj, a2j, g1j, bvis, g2j, chunks +): + da = pytest.importorskip("dask.array") import numpy as np import dask @@ -127,18 +141,18 @@ def test_dask_predict_vis(corr_shape, idm, einsum_sig1, einsum_sig2, from africanus.rime.dask import predict_vis # chunk sizes - sc = chunks['source'] - tc = chunks['time'] - rrc = chunks['rows'] - ac = chunks['antenna'] - cc = chunks['channels'] + sc = chunks["source"] + tc = chunks["time"] + rrc = chunks["rows"] + ac = chunks["antenna"] + cc = chunks["channels"] # dimension sizes - s = sum(sc) # sources - t = sum(tc) # times - a = sum(ac) # antennas - c = sum(cc) # channels - r = sum(rrc) # rows + s = sum(sc) # sources + t = sum(tc) # times + a = sum(ac) # antennas + c = sum(cc) # channels + r = sum(rrc) # rows a1_jones = rc((s, t, a, c) + corr_shape) a2_jones = rc((s, t, a, c) + corr_shape) @@ -154,13 +168,17 @@ def test_dask_predict_vis(corr_shape, idm, einsum_sig1, einsum_sig2, assert ant1.size == r - np_model_vis = np_predict_vis(time_idx, ant1, ant2, - a1_jones if a1j else None, - bl_jones if blj else None, - a2_jones if a2j else None, - g1_jones if g1j else None, - base_vis if bvis else None, - g2_jones if g2j else None) + np_model_vis = np_predict_vis( + time_idx, + ant1, + ant2, + a1_jones if a1j else None, + bl_jones if blj else None, + a2_jones if a2j else None, + g1_jones if g1j else None, + base_vis if bvis else None, + g2_jones if g2j else None, + ) da_time_idx = da.from_array(time_idx, chunks=rrc) da_ant1 = da.from_array(ant1, chunks=rrc) @@ -173,19 +191,22 @@ def test_dask_predict_vis(corr_shape, idm, einsum_sig1, einsum_sig2, da_base_vis = da.from_array(base_vis, chunks=(rrc, cc) + corr_shape) da_g2_jones = da.from_array(g2_jones, chunks=(tc, ac, cc) + corr_shape) - args = (da_time_idx, da_ant1, da_ant2, - da_a1_jones if a1j else None, - da_bl_jones if blj else None, - da_a2_jones if a2j else None, - da_g1_jones if g1j else None, - da_base_vis if bvis else None, - da_g2_jones if g2j else None) + args = ( + da_time_idx, + da_ant1, + da_ant2, + da_a1_jones if a1j else None, + da_bl_jones if blj else None, + da_a2_jones if a2j else None, + da_g1_jones if g1j else None, + da_base_vis if bvis else None, + da_g2_jones if g2j else None, + ) stream_model_vis = predict_vis(*args, streams=True) fan_model_vis = predict_vis(*args, streams=False) - stream_model_vis, fan_model_vis = dask.compute(stream_model_vis, - fan_model_vis) + stream_model_vis, fan_model_vis = dask.compute(stream_model_vis, fan_model_vis) assert_array_almost_equal(fan_model_vis, np_model_vis) assert_array_almost_equal(stream_model_vis, fan_model_vis) diff --git a/africanus/rime/tests/test_rime.py b/africanus/rime/tests/test_rime.py index eac29d4f8..c7ee8c4ff 100644 --- a/africanus/rime/tests/test_rime.py +++ b/africanus/rime/tests/test_rime.py @@ -13,19 +13,16 @@ def rf(*a, **kw): def rc(*a, **kw): - return rf(*a, **kw) + 1j*rf(*a, **kw) + return rf(*a, **kw) + 1j * rf(*a, **kw) -@pytest.mark.parametrize("convention, sign", [ - ('fourier', 1), - ('casa', -1) -]) +@pytest.mark.parametrize("convention, sign", [("fourier", 1), ("casa", -1)]) def test_phase_delay(convention, sign): from africanus.rime import phase_delay uvw = np.random.random(size=(100, 3)) lm = np.random.random(size=(10, 2)) - frequency = np.linspace(.856e9, .856e9*2, 64, endpoint=True) + frequency = np.linspace(0.856e9, 0.856e9 * 2, 64, endpoint=True) from africanus.constants import minus_two_pi_over_c @@ -46,8 +43,8 @@ def test_phase_delay(convention, sign): # Test singular value vs a point in the output n = np.sqrt(1.0 - l**2 - m**2) - 1.0 - phase = sign*minus_two_pi_over_c*(u*l + v*m + w*n)*freq - assert np.all(np.exp(1j*phase) == complex_phase[lm_i, uvw_i, freq_i]) + phase = sign * minus_two_pi_over_c * (u * l + v * m + w * n) * freq + assert np.all(np.exp(1j * phase) == complex_phase[lm_i, uvw_i, freq_i]) def test_feed_rotation(): @@ -58,37 +55,36 @@ def test_feed_rotation(): pa_sin = np.sin(parangles) pa_cos = np.cos(parangles) - fr = feed_rotation(parangles, feed_type='linear') + fr = feed_rotation(parangles, feed_type="linear") np_expr = np.stack([pa_cos, pa_sin, -pa_sin, pa_cos], axis=2) assert np.allclose(fr, np_expr.reshape(10, 5, 2, 2)) - fr = feed_rotation(parangles, feed_type='circular') + fr = feed_rotation(parangles, feed_type="circular") zeros = np.zeros_like(pa_sin) - np_expr = np.stack([pa_cos - 1j*pa_sin, zeros, - zeros, pa_cos + 1j*pa_sin], axis=2) + np_expr = np.stack( + [pa_cos - 1j * pa_sin, zeros, zeros, pa_cos + 1j * pa_sin], axis=2 + ) assert np.allclose(fr, np_expr.reshape(10, 5, 2, 2)) -@pytest.mark.parametrize("convention, sign", [ - ('fourier', 1), - ('casa', -1) -]) +@pytest.mark.parametrize("convention, sign", [("fourier", 1), ("casa", -1)]) def test_dask_phase_delay(convention, sign): - da = pytest.importorskip('dask.array') + da = pytest.importorskip("dask.array") from africanus.rime import phase_delay as np_phase_delay from africanus.rime.dask import phase_delay as dask_phase_delay # So that 1 > 1 - l**2 - m**2 >= 0 - lm = np.random.random(size=(10, 2))*0.01 + lm = np.random.random(size=(10, 2)) * 0.01 uvw = np.random.random(size=(100, 3)) - frequency = np.linspace(.856e9, .856e9*2, 64, endpoint=True) + frequency = np.linspace(0.856e9, 0.856e9 * 2, 64, endpoint=True) dask_lm = da.from_array(lm, chunks=(5, 2)) dask_uvw = da.from_array(uvw, chunks=(25, 3)) dask_frequency = da.from_array(frequency, chunks=16) - dask_phase = dask_phase_delay(dask_lm, dask_uvw, dask_frequency, - convention=convention) + dask_phase = dask_phase_delay( + dask_lm, dask_uvw, dask_frequency, convention=convention + ) np_phase = np_phase_delay(lm, uvw, frequency, convention=convention) # Should agree completely @@ -96,7 +92,7 @@ def test_dask_phase_delay(convention, sign): def test_dask_feed_rotation(): - da = pytest.importorskip('dask.array') + da = pytest.importorskip("dask.array") import numpy as np from africanus.rime import feed_rotation as np_feed_rotation from africanus.rime.dask import feed_rotation @@ -104,8 +100,8 @@ def test_dask_feed_rotation(): parangles = np.random.random((10, 5)) dask_parangles = da.from_array(parangles, chunks=(5, (2, 3))) - np_fr = np_feed_rotation(parangles, feed_type='linear') - assert np.all(np_fr == feed_rotation(dask_parangles, feed_type='linear')) + np_fr = np_feed_rotation(parangles, feed_type="linear") + assert np.all(np_fr == feed_rotation(dask_parangles, feed_type="linear")) - np_fr = np_feed_rotation(parangles, feed_type='circular') - assert np.all(np_fr == feed_rotation(dask_parangles, feed_type='circular')) + np_fr = np_feed_rotation(parangles, feed_type="circular") + assert np.all(np_fr == feed_rotation(dask_parangles, feed_type="circular")) diff --git a/africanus/rime/tests/test_wsclean_predict.py b/africanus/rime/tests/test_wsclean_predict.py index 1092feb8b..234681023 100644 --- a/africanus/rime/tests/test_wsclean_predict.py +++ b/africanus/rime/tests/test_wsclean_predict.py @@ -9,21 +9,25 @@ from africanus.model.wsclean.spec_model import spectra from africanus.rime.wsclean_predict import wsclean_predict -chunk_parametrization = pytest.mark.parametrize("chunks", [ - { - 'source': (2, 3, 4, 2, 2, 2, 2, 2, 2), - 'time': (2, 1, 1), - 'rows': (4, 4, 2), - 'antenna': (4,), - 'channels': (3, 2), - }]) +chunk_parametrization = pytest.mark.parametrize( + "chunks", + [ + { + "source": (2, 3, 4, 2, 2, 2, 2, 2, 2), + "time": (2, 1, 1), + "rows": (4, 4, 2), + "antenna": (4,), + "channels": (3, 2), + } + ], +) @chunk_parametrization def test_wsclean_predict(chunks): - row = sum(chunks['rows']) - src = sum(chunks['source']) - chan = sum(chunks['channels']) + row = sum(chunks["rows"]) + src = sum(chunks["source"]) + chan = sum(chunks["channels"]) rs = np.random.RandomState(42) source_sel = rs.randint(0, 2, src).astype(np.bool_) @@ -31,22 +35,23 @@ def test_wsclean_predict(chunks): gauss_shape = rs.normal(size=(src, 3)) uvw = rs.normal(size=(row, 3)) - lm = rs.normal(size=(src, 2))*1e-5 + lm = rs.normal(size=(src, 2)) * 1e-5 flux = rs.normal(size=src) coeffs = rs.normal(size=(src, 2)) log_poly = rs.randint(0, 2, src, dtype=np.bool_) flux[log_poly] = np.abs(flux[log_poly]) coeffs[log_poly] = np.abs(coeffs[log_poly]) - freq = np.linspace(.856e9, 2*.856e9, chan) + freq = np.linspace(0.856e9, 2 * 0.856e9, chan) ref_freq = np.full(src, freq[freq.shape[0] // 2]) # WSClean visibilities - vis = wsclean_predict(uvw, lm, source_type, flux, coeffs, - log_poly, ref_freq, gauss_shape, freq) + vis = wsclean_predict( + uvw, lm, source_type, flux, coeffs, log_poly, ref_freq, gauss_shape, freq + ) # Compute it another way. Note the CASA coordinate convention # used by wsclean - phase = phase_delay(lm, uvw, freq, convention='casa') + phase = phase_delay(lm, uvw, freq, convention="casa") spectrum = spectra(flux, coeffs, log_poly, ref_freq, freq) shape = gaussian(uvw, freq, gauss_shape) # point sources don't' contribute to the shape @@ -60,12 +65,11 @@ def test_wsclean_predict(chunks): def test_dask_wsclean_predict(chunks): da = pytest.importorskip("dask.array") - from africanus.rime.dask_predict import ( - wsclean_predict as dask_wsclean_predict) + from africanus.rime.dask_predict import wsclean_predict as dask_wsclean_predict - row = sum(chunks['rows']) - src = sum(chunks['source']) - chan = sum(chunks['channels']) + row = sum(chunks["rows"]) + src = sum(chunks["source"]) + chan = sum(chunks["channels"]) rs = np.random.RandomState(42) source_sel = rs.randint(0, 2, src).astype(np.bool_) @@ -74,31 +78,38 @@ def test_dask_wsclean_predict(chunks): gauss_shape = rs.normal(size=(src, 3)) uvw = rs.normal(size=(row, 3)) - lm = rs.normal(size=(src, 2))*1e-5 + lm = rs.normal(size=(src, 2)) * 1e-5 flux = rs.normal(size=src) coeffs = rs.normal(size=(src, 2)) log_poly = rs.randint(0, 2, src, dtype=np.bool_) flux[log_poly] = np.abs(flux[log_poly]) coeffs[log_poly] = np.abs(coeffs[log_poly]) - freq = np.linspace(.856e9, 2*.856e9, chan) + freq = np.linspace(0.856e9, 2 * 0.856e9, chan) ref_freq = np.full(src, freq[freq.shape[0] // 2]) - da_uvw = da.from_array(uvw, chunks=(chunks['rows'], 3)) - da_lm = da.from_array(lm, chunks=(chunks['source'], 2)) - da_source_type = da.from_array(source_type, chunks=chunks['source']) - da_gauss_shape = da.from_array(gauss_shape, chunks=(chunks['source'], 3)) - da_flux = da.from_array(flux, chunks=chunks['source']) - da_coeffs = da.from_array(coeffs, chunks=(chunks['source'], 2)) - da_log_poly = da.from_array(log_poly, chunks=chunks['source']) - da_ref_freq = da.from_array(ref_freq, chunks=chunks['source']) + da_uvw = da.from_array(uvw, chunks=(chunks["rows"], 3)) + da_lm = da.from_array(lm, chunks=(chunks["source"], 2)) + da_source_type = da.from_array(source_type, chunks=chunks["source"]) + da_gauss_shape = da.from_array(gauss_shape, chunks=(chunks["source"], 3)) + da_flux = da.from_array(flux, chunks=chunks["source"]) + da_coeffs = da.from_array(coeffs, chunks=(chunks["source"], 2)) + da_log_poly = da.from_array(log_poly, chunks=chunks["source"]) + da_ref_freq = da.from_array(ref_freq, chunks=chunks["source"]) da_freq = da.from_array(freq) - vis = wsclean_predict(uvw, lm, source_type, flux, - coeffs, log_poly, ref_freq, - gauss_shape, freq) - da_vis = dask_wsclean_predict(da_uvw, da_lm, da_source_type, - da_flux, da_coeffs, - da_log_poly, da_ref_freq, - da_gauss_shape, da_freq) + vis = wsclean_predict( + uvw, lm, source_type, flux, coeffs, log_poly, ref_freq, gauss_shape, freq + ) + da_vis = dask_wsclean_predict( + da_uvw, + da_lm, + da_source_type, + da_flux, + da_coeffs, + da_log_poly, + da_ref_freq, + da_gauss_shape, + da_freq, + ) assert_almost_equal(vis, da_vis) diff --git a/africanus/rime/tests/test_zernike.py b/africanus/rime/tests/test_zernike.py index 488669a1c..7291a92d4 100644 --- a/africanus/rime/tests/test_zernike.py +++ b/africanus/rime/tests/test_zernike.py @@ -3,11 +3,11 @@ def test_zernike_func_xx_corr(coeff_xx, noll_index_xx, eidos_data_xx): - """ Tests reconstruction of xx correlation against eidos """ + """Tests reconstruction of xx correlation against eidos""" from africanus.rime import zernike_dde npix = 17 - nsrc = npix ** 2 + nsrc = npix**2 ntime = 1 na = 1 nchan = 1 @@ -54,11 +54,11 @@ def test_zernike_func_xx_corr(coeff_xx, noll_index_xx, eidos_data_xx): def test_zernike_func_xy_corr(coeff_xy, noll_index_xy, eidos_data_xy): - """ Tests reconstruction of xy correlation against eidos """ + """Tests reconstruction of xy correlation against eidos""" from africanus.rime import zernike_dde npix = 17 - nsrc = npix ** 2 + nsrc = npix**2 ntime = 1 na = 1 nchan = 1 @@ -105,11 +105,11 @@ def test_zernike_func_xy_corr(coeff_xy, noll_index_xy, eidos_data_xy): def test_zernike_func_yx_corr(coeff_yx, noll_index_yx, eidos_data_yx): - """ Tests reconstruction of yx correlation against eidos """ + """Tests reconstruction of yx correlation against eidos""" from africanus.rime import zernike_dde npix = 17 - nsrc = npix ** 2 + nsrc = npix**2 ntime = 1 na = 1 nchan = 1 @@ -156,11 +156,11 @@ def test_zernike_func_yx_corr(coeff_yx, noll_index_yx, eidos_data_yx): def test_zernike_func_yy_corr(coeff_yy, noll_index_yy, eidos_data_yy): - """ Tests reconstruction of yy correlation against eidos """ + """Tests reconstruction of yy correlation against eidos""" from africanus.rime import zernike_dde npix = 17 - nsrc = npix ** 2 + nsrc = npix**2 ntime = 1 na = 1 nchan = 1 @@ -207,11 +207,11 @@ def test_zernike_func_yy_corr(coeff_yy, noll_index_yy, eidos_data_yy): def test_zernike_multiple_dims(coeff_xx, noll_index_xx): - """ Tests that we can call zernike_dde with multiple dimensions """ + """Tests that we can call zernike_dde with multiple dimensions""" from africanus.rime import zernike_dde as np_zernike_dde npix = 17 - nsrc = npix ** 2 + nsrc = npix**2 ntime = 10 na = 7 nchan = 8 @@ -258,14 +258,14 @@ def test_zernike_multiple_dims(coeff_xx, noll_index_xx): def test_dask_zernike(coeff_xx, noll_index_xx): - """ Tests that dask zernike_dde agrees with numpy zernike_dde """ + """Tests that dask zernike_dde agrees with numpy zernike_dde""" da = pytest.importorskip("dask.array") from africanus.rime.dask import zernike_dde from africanus.rime import zernike_dde as np_zernike_dde npix = 17 - nsrc = npix ** 2 + nsrc = npix**2 ntime = 10 na = 7 nchan = 8 @@ -317,9 +317,7 @@ def test_dask_zernike(coeff_xx, noll_index_xx): coords = da.from_array(coords, (3, npix, time_c, ant_c, chan_c)) coeffs = da.from_array(coeffs, (ant_c, chan_c, corr1, corr2, npoly)) - noll_indices = da.from_array( - noll_indices, (ant_c, chan_c, corr1, corr2, npoly) - ) + noll_indices = da.from_array(noll_indices, (ant_c, chan_c, corr1, corr2, npoly)) parallactic_angles = da.from_array(parallactic_angles) frequency_scaling = da.from_array(frequency_scaling) @@ -424,9 +422,7 @@ def coeff_yy(): @pytest.fixture def noll_index_xx(): - return np.array( - [10, 3, 21, 36, 0, 55, 16, 28, 37, 46, 23, 6, 15, 2, 5, 7, 57] - ) + return np.array([10, 3, 21, 36, 0, 55, 16, 28, 37, 46, 23, 6, 15, 2, 5, 7, 57]) @pytest.fixture @@ -441,9 +437,7 @@ def noll_index_yx(): @pytest.fixture def noll_index_yy(): - return np.array( - [10, 3, 21, 36, 0, 55, 28, 16, 11, 23, 37, 46, 6, 2, 15, 5, 29] - ) + return np.array([10, 3, 21, 36, 0, 55, 28, 16, 11, 23, 37, 46, 6, 2, 15, 5, 29]) @pytest.fixture diff --git a/africanus/rime/transform.py b/africanus/rime/transform.py index eed0e5526..6e9686f6d 100644 --- a/africanus/rime/transform.py +++ b/africanus/rime/transform.py @@ -10,8 +10,9 @@ @jit(nopython=True, nogil=True, cache=True) -def _nb_transform_sources(lm, parallactic_angles, pointing_errors, - antenna_scaling, frequency, coords): +def _nb_transform_sources( + lm, parallactic_angles, pointing_errors, antenna_scaling, frequency, coords +): """ numba implementation of :func:`~africanus.rime.transform_sources` @@ -27,8 +28,8 @@ def _nb_transform_sources(lm, parallactic_angles, pointing_errors, l, m = lm[s] # Rotate source coordinate by parallactic angle - l = l*pa_cos - m*pa_sin # noqa - m = l*pa_sin + m*pa_cos + l = l * pa_cos - m * pa_sin # noqa + m = l * pa_sin + m * pa_cos # Add pointing errors l += pointing_errors[t, a, 0] # noqa @@ -36,15 +37,16 @@ def _nb_transform_sources(lm, parallactic_angles, pointing_errors, # Scale by antenna scaling factors for c in range(nchan): - coords[0, s, t, a, c] = l*antenna_scaling[a, c] - coords[1, s, t, a, c] = m*antenna_scaling[a, c] + coords[0, s, t, a, c] = l * antenna_scaling[a, c] + coords[1, s, t, a, c] = m * antenna_scaling[a, c] coords[2, s, t, a, c] = frequency[c] return coords -def transform_sources(lm, parallactic_angles, pointing_errors, - antenna_scaling, frequency, dtype=None): +def transform_sources( + lm, parallactic_angles, pointing_errors, antenna_scaling, frequency, dtype=None +): """ Creates beam sampling coordinates suitable for use in :func:`~africanus.rime.beam_cube_dde` by: @@ -92,5 +94,6 @@ def transform_sources(lm, parallactic_angles, pointing_errors, dtype = np.float64 if dtype is None else dtype coords = np.empty((3, nsrc, ntime, na, nchan), dtype=dtype) - return _nb_transform_sources(lm, parallactic_angles, pointing_errors, - antenna_scaling, frequency, coords) + return _nb_transform_sources( + lm, parallactic_angles, pointing_errors, antenna_scaling, frequency, coords + ) diff --git a/africanus/rime/wsclean_predict.py b/africanus/rime/wsclean_predict.py index c490dbc07..fc86ccfde 100644 --- a/africanus/rime/wsclean_predict.py +++ b/africanus/rime/wsclean_predict.py @@ -9,9 +9,7 @@ @njit(**JIT_OPTIONS) -def wsclean_predict_main(uvw, lm, source_type, gauss_shape, - frequency, spectrum, dtype): - +def wsclean_predict_main(uvw, lm, source_type, gauss_shape, frequency, spectrum, dtype): fwhm = 2.0 * np.sqrt(2.0 * np.log(2.0)) fwhminv = 1.0 / fwhm gauss_scale = fwhminv * np.sqrt(2.0) * np.pi / lightspeed @@ -30,7 +28,7 @@ def wsclean_predict_main(uvw, lm, source_type, gauss_shape, for s in range(nsrc): l = lm[s, 0] # noqa m = lm[s, 1] - n = np.sqrt(n1 - l*l - m*m) - n1 + n = np.sqrt(n1 - l * l - m * m) - n1 if source_type[s] == "POINT": for r in range(nrow): @@ -38,14 +36,14 @@ def wsclean_predict_main(uvw, lm, source_type, gauss_shape, v = uvw[r, 1] w = uvw[r, 2] - real_phase = two_pi_over_c*(u*l + v*m + w*n) + real_phase = two_pi_over_c * (u * l + v * m + w * n) for f in range(nchan): p = real_phase * frequency[f] re = np.cos(p) * spectrum[s, f] im = np.sin(p) * spectrum[s, f] - vis[r, f, 0] += re + im*1j + vis[r, f, 0] += re + im * 1j elif source_type[s] == "GAUSSIAN": emaj, emin, angle = gauss_shape[s] @@ -60,11 +58,11 @@ def wsclean_predict_main(uvw, lm, source_type, gauss_shape, w = uvw[r, 2] # Compute phase term - real_phase = two_pi_over_c*(u*l + v*m + w*n) + real_phase = two_pi_over_c * (u * l + v * m + w * n) # Gaussian shape term bits - u1 = (u*em - v*el)*er - v1 = u*el + v*em + u1 = (u * em - v * el) * er + v1 = u * el + v * em for f in range(nchan): p = real_phase * frequency[f] @@ -78,43 +76,50 @@ def wsclean_predict_main(uvw, lm, source_type, gauss_shape, re *= shape im *= shape - vis[r, f, 0] += re + im*1j + vis[r, f, 0] += re + im * 1j else: - raise ValueError("source_type must be " - "POINT or GAUSSIAN") + raise ValueError("source_type must be " "POINT or GAUSSIAN") return vis @njit(**JIT_OPTIONS) -def wsclean_predict(uvw, lm, source_type, flux, coeffs, - log_poly, ref_freq, gauss_shape, frequency): - return wsclean_predict_impl(uvw, lm, source_type, flux, coeffs, - log_poly, ref_freq, gauss_shape, frequency) +def wsclean_predict( + uvw, lm, source_type, flux, coeffs, log_poly, ref_freq, gauss_shape, frequency +): + return wsclean_predict_impl( + uvw, lm, source_type, flux, coeffs, log_poly, ref_freq, gauss_shape, frequency + ) -def wsclean_predict_impl(uvw, lm, source_type, flux, coeffs, - log_poly, ref_freq, gauss_shape, frequency): +def wsclean_predict_impl( + uvw, lm, source_type, flux, coeffs, log_poly, ref_freq, gauss_shape, frequency +): raise NotImplementedError @overload(wsclean_predict_impl, jit_options=JIT_OPTIONS) -def nb_wsclean_predict(uvw, lm, source_type, flux, coeffs, - log_poly, ref_freq, gauss_shape, frequency): - arg_dtypes = tuple(np.dtype(a.dtype.name) for a - in (uvw, lm, flux, coeffs, ref_freq, frequency)) +def nb_wsclean_predict( + uvw, lm, source_type, flux, coeffs, log_poly, ref_freq, gauss_shape, frequency +): + arg_dtypes = tuple( + np.dtype(a.dtype.name) for a in (uvw, lm, flux, coeffs, ref_freq, frequency) + ) dtype = np.result_type(np.complex64, *arg_dtypes) - def impl(uvw, lm, source_type, flux, coeffs, log_poly, - ref_freq, gauss_shape, frequency): + def impl( + uvw, lm, source_type, flux, coeffs, log_poly, ref_freq, gauss_shape, frequency + ): spectrum = spectra(flux, coeffs, log_poly, ref_freq, frequency) - return wsclean_predict_main(uvw, lm, source_type, gauss_shape, - frequency, spectrum, dtype) + return wsclean_predict_main( + uvw, lm, source_type, gauss_shape, frequency, spectrum, dtype + ) return impl -WSCLEAN_PREDICT_DOCS = DocstringTemplate(""" +WSCLEAN_PREDICT_DOCS = DocstringTemplate( + """ Predict visibilities from a `WSClean sky model `_. @@ -155,7 +160,9 @@ def impl(uvw, lm, source_type, flux, coeffs, log_poly, ------- visibilities : $(array_type) Complex visibilities of shape :code:`(row, chan, 1)` -""") +""" +) wsclean_predict.__doc__ = WSCLEAN_PREDICT_DOCS.substitute( - array_type=":class:`numpy.ndarray`") + array_type=":class:`numpy.ndarray`" +) diff --git a/africanus/rime/zernike.py b/africanus/rime/zernike.py index 98e3e96f9..3a88a6c67 100644 --- a/africanus/rime/zernike.py +++ b/africanus/rime/zernike.py @@ -53,9 +53,7 @@ def zernike(j, rho, phi): @jit(nogil=True, nopython=True, cache=True) def _convert_coords(l_coords, m_coords): - rho, phi = ((l_coords ** 2 + m_coords ** 2) ** 0.5), np.arctan2( - l_coords, m_coords - ) + rho, phi = ((l_coords**2 + m_coords**2) ** 0.5), np.arctan2(l_coords, m_coords) return rho, phi @@ -118,7 +116,7 @@ def zernike_dde( antenna_scaling, pointing_errors, ): - """ Wrapper for :func:`nb_zernike_dde` """ + """Wrapper for :func:`nb_zernike_dde`""" _, sources, times, ants, chans = coords.shape # ant, chan, corr_1, ..., corr_n, poly corr_shape = coeffs.shape[2:-1] diff --git a/africanus/testing/beam_factory.py b/africanus/testing/beam_factory.py index 653a6d6a7..25ae16494 100644 --- a/africanus/testing/beam_factory.py +++ b/africanus/testing/beam_factory.py @@ -33,41 +33,43 @@ } -@requires_optional('astropy.io.fits', opt_import_error) -def beam_factory(polarisation_type='linear', - frequency=None, - npix=257, - dtype=np.float64, - schema=DEFAULT_SCHEMA, - overwrite=True): - """ Generate a MeqTrees compliant beam cube """ +@requires_optional("astropy.io.fits", opt_import_error) +def beam_factory( + polarisation_type="linear", + frequency=None, + npix=257, + dtype=np.float64, + schema=DEFAULT_SCHEMA, + overwrite=True, +): + """Generate a MeqTrees compliant beam cube""" if npix % 2 != 1: raise ValueError("npix '%d' must be odd" % npix) # MeerKAT l-band, 64 channels if frequency is None: - frequency = np.linspace(.856e9, .856e9*2, 64, - endpoint=True, dtype=np.float64) + frequency = np.linspace( + 0.856e9, 0.856e9 * 2, 64, endpoint=True, dtype=np.float64 + ) # Generate a linear space of grid frequencies - gfrequency = np.linspace(frequency[0], frequency[-1], - 33, dtype=np.float64) + gfrequency = np.linspace(frequency[0], frequency[-1], 33, dtype=np.float64) bandwidth = gfrequency[-1] - frequency[0] - bandwidth_delta = bandwidth / gfrequency.shape[0]-1 + bandwidth_delta = bandwidth / gfrequency.shape[0] - 1 - if polarisation_type == 'linear': + if polarisation_type == "linear": CORR = LINEAR_CORRELATIONS - elif polarisation_type == 'circular': + elif polarisation_type == "circular": CORR = CIRCULAR_CORRELATIONS else: raise ValueError("Invalid polarisation_type %s" % polarisation_type) extent_deg = 3.0 - coords = np.linspace(-extent_deg/2, extent_deg/2, npix, endpoint=True) + coords = np.linspace(-extent_deg / 2, extent_deg / 2, npix, endpoint=True) - crpix = 1 + npix // 2 # Reference pixel (FORTRAN) - crval = coords[crpix - 1] # Reference value + crpix = 1 + npix // 2 # Reference pixel (FORTRAN) + crval = coords[crpix - 1] # Reference value cdelt = extent_deg / (npix - 1) # Delta # List of key values of the form: @@ -79,56 +81,61 @@ def beam_factory(polarisation_type='linear', # We put them in a list so that they are added to the # FITS header in the correct order axis1 = [ - ("CTYPE", ('X', "points right on the sky")), - ("CUNIT", ('DEG', 'degrees')), + ("CTYPE", ("X", "points right on the sky")), + ("CUNIT", ("DEG", "degrees")), ("NAXIS", (npix, "number of X")), ("CRPIX", (crpix, "reference pixel (one relative)")), ("CRVAL", (crval, "degrees")), - ("CDELT", (cdelt, "degrees"))] + ("CDELT", (cdelt, "degrees")), + ] axis2 = [ - ("CTYPE", ('Y', "points up on the sky")), - ("CUNIT", ('DEG', 'degrees')), + ("CTYPE", ("Y", "points up on the sky")), + ("CUNIT", ("DEG", "degrees")), ("NAXIS", (npix, "number of Y")), ("CRPIX", (crpix, "reference pixel (one relative)")), ("CRVAL", (crval, "degrees")), - ("CDELT", (cdelt, "degrees"))] + ("CDELT", (cdelt, "degrees")), + ] axis3 = [ - ("CTYPE", ('FREQ', )), + ("CTYPE", ("FREQ",)), ("CUNIT", None), ("NAXIS", (gfrequency.shape[0], "number of FREQ")), ("CRPIX", (1, "reference frequency position")), ("CRVAL", (gfrequency[0], "reference frequency")), - ("CDELT", (bandwidth_delta, "frequency step in Hz"))] + ("CDELT", (bandwidth_delta, "frequency step in Hz")), + ] axes = [axis1, axis2, axis3] metadata = [ - ('SIMPLE', True), - ('BITPIX', BITPIX_MAP[dtype]), - ('NAXIS', len(axes)), - ('OBSERVER', "Astronomer McAstronomerFace"), - ('ORIGIN', "Artificial"), - ('TELESCOP', "Telescope"), - ('OBJECT', 'beam'), - ('EQUINOX', 2000.0), + ("SIMPLE", True), + ("BITPIX", BITPIX_MAP[dtype]), + ("NAXIS", len(axes)), + ("OBSERVER", "Astronomer McAstronomerFace"), + ("ORIGIN", "Artificial"), + ("TELESCOP", "Telescope"), + ("OBJECT", "beam"), + ("EQUINOX", 2000.0), ] # Create header and set metadata header = fits.Header(metadata) # Now set the key value entries for each axis - ax_info = [('%s%d' % (k, a),) + vt - for a, axis_data in enumerate(axes, 1) - for k, vt in axis_data - if vt is not None] + ax_info = [ + ("%s%d" % (k, a),) + vt + for a, axis_data in enumerate(axes, 1) + for k, vt in axis_data + if vt is not None + ] header.update(ax_info) # Now setup the GFREQS # Jitter them randomly, except for the endpoints - frequency_jitter = np.random.random(size=gfrequency.shape)-0.5 - frequency_jitter *= 0.1*bandwidth_delta + frequency_jitter = np.random.random(size=gfrequency.shape) - 0.5 + frequency_jitter *= 0.1 * bandwidth_delta frequency_jitter[0] = frequency_jitter[-1] = 0.0 gfrequency += frequency_jitter @@ -136,16 +143,16 @@ def beam_factory(polarisation_type='linear', assert np.all(np.diff(gfrequency) >= 0.0) for i, gfreq in enumerate(gfrequency, 1): - header['GFREQ%d' % i] = gfreq + header["GFREQ%d" % i] = gfreq # Figure out the beam filenames from the schema filenames = beam_filenames(str(schema), CORR) # Westerbork beam model coords = np.deg2rad(coords) - r = np.sqrt(coords[None, :, None]**2 + coords[None, None, :]**2) + r = np.sqrt(coords[None, :, None] ** 2 + coords[None, None, :] ** 2) fq = gfrequency[:, None, None] - beam = np.cos(np.minimum(65*fq*1e-9*r, 1.0881))**3 + beam = np.cos(np.minimum(65 * fq * 1e-9 * r, 1.0881)) ** 3 for filename in [f for ri_pair in filenames.values() for f in ri_pair]: primary_hdu = fits.PrimaryHDU(beam, header=header) diff --git a/africanus/testing/tests/test_beam_factory.py b/africanus/testing/tests/test_beam_factory.py index 15da64ddc..8f1c3144a 100644 --- a/africanus/testing/tests/test_beam_factory.py +++ b/africanus/testing/tests/test_beam_factory.py @@ -8,12 +8,10 @@ @pytest.mark.parametrize("pol_type", ["linear", "circular"]) def test_beam_factory(tmp_path, pol_type): - fits = pytest.importorskip('astropy.io.fits') + fits = pytest.importorskip("astropy.io.fits") schema = tmp_path / "test_beam_$(corr)_$(reim).fits" - filenames = beam_factory(schema=schema, - npix=15, - polarisation_type=pol_type) + filenames = beam_factory(schema=schema, npix=15, polarisation_type=pol_type) for corr, (re_file, im_file) in filenames.items(): with fits.open(re_file), fits.open(im_file): diff --git a/africanus/util/beams.py b/africanus/util/beams.py index 8fffeff45..91c56946d 100644 --- a/africanus/util/beams.py +++ b/africanus/util/beams.py @@ -17,22 +17,21 @@ class FitsAxes(object): def __init__(self, header=None): # Create an zero-dimensional object if no header supplied - self._ndims = ndims = 0 if header is None else header['NAXIS'] + self._ndims = ndims = 0 if header is None else header["NAXIS"] # Extract header information for each dimension - axr = list(range(1, ndims+1)) - self._naxis = [header.get('NAXIS%d' % n) for n in axr] - self._ctype = [header.get('CTYPE%d' % n, n).strip() for n in axr] - self._crval = [header.get('CRVAL%d' % n, 0) for n in axr] + axr = list(range(1, ndims + 1)) + self._naxis = [header.get("NAXIS%d" % n) for n in axr] + self._ctype = [header.get("CTYPE%d" % n, n).strip() for n in axr] + self._crval = [header.get("CRVAL%d" % n, 0) for n in axr] # Convert right pixel from FORTRAN to C indexing - self._crpix = [header['CRPIX%d' % n]-1 for n in axr] - self._cdelt = [header.get('CDELT%d' % n, 1) for n in axr] - self._cunit = [header.get('CUNIT%d' % n, '').strip().upper() - for n in axr] + self._crpix = [header["CRPIX%d" % n] - 1 for n in axr] + self._cdelt = [header.get("CDELT%d" % n, 1) for n in axr] + self._cunit = [header.get("CUNIT%d" % n, "").strip().upper() for n in axr] def axis_and_sign(ax_str, default=None): - """ Extract axis and sign from given axis string """ + """Extract axis and sign from given axis string""" if not ax_str: if default: return default, 1.0 @@ -42,7 +41,7 @@ def axis_and_sign(ax_str, default=None): if not isinstance(ax_str, str): raise TypeError("ax_str must be a string") - return (ax_str[1:], -1.0) if ax_str[0] == '-' else (ax_str, 1.0) + return (ax_str[1:], -1.0) if ax_str[0] == "-" else (ax_str, 1.0) class BeamAxes(FitsAxes): @@ -70,32 +69,39 @@ def __init__(self, header=None): super(BeamAxes, self).__init__(header) # Check for custom irregular grid format. # Currently only implemented for FREQ dimension. - irregular_grid = [np.asarray( - [header.get('G%s%d' % (self._ctype[i], j), None) - for j in range(1, self._naxis[i]+1)]) - for i in range(self._ndims)] + irregular_grid = [ + np.asarray( + [ + header.get("G%s%d" % (self._ctype[i], j), None) + for j in range(1, self._naxis[i] + 1) + ] + ) + for i in range(self._ndims) + ] # Irregular grids are only valid if values exist for all grid points - self._irreg = [all(x is not None for x in irregular_grid[i]) - for i in range(self._ndims)] + self._irreg = [ + all(x is not None for x in irregular_grid[i]) for i in range(self._ndims) + ] def _regular_grid(i): - """ Construct a regular grid from a FitsAxes object and index """ + """Construct a regular grid from a FitsAxes object and index""" R = np.arange(0.0, float(self._naxis[i])) - return (R - self._crpix[i])*self._cdelt[i] + self._crval[i] + return (R - self._crpix[i]) * self._cdelt[i] + self._crval[i] - self._grid = [None]*self._ndims + self._grid = [None] * self._ndims for i in range(self._ndims): # Convert any degree axes to radians - if self._cunit[i] == 'DEG': - self._cunit[i] = 'RAD' + if self._cunit[i] == "DEG": + self._cunit[i] = "RAD" self._crval[i] = np.deg2rad(self._crval[i]) self._cdelt[i] = np.deg2rad(self._cdelt[i]) # Set up the grid - self._grid[i] = (_regular_grid(i) if not self._irreg[i] - else irregular_grid[i]) + self._grid[i] = ( + _regular_grid(i) if not self._irreg[i] else irregular_grid[i] + ) @property def ndims(self): @@ -174,9 +180,9 @@ def beam_grids(header, l_axis=None, m_axis=None): # Find the relevant axes for i in range(beam_axes.ndims): - if beam_axes.ctype[i].upper() in ('L', 'X', 'PX'): + if beam_axes.ctype[i].upper() in ("L", "X", "PX"): l = i # noqa - elif beam_axes.ctype[i].upper() in ('M', 'Y', 'PY'): + elif beam_axes.ctype[i].upper() in ("M", "Y", "PY"): m = i elif beam_axes.ctype[i] == "FREQ": freq = i @@ -200,7 +206,7 @@ def beam_grids(header, l_axis=None, m_axis=None): m_grid = beam_axes.grid[m] * m_sign freq_grid = beam_axes.grid[freq] - return ((l+1, l_grid), (m+1, m_grid), (freq+1, freq_grid)) + return ((l + 1, l_grid), (m + 1, m_grid), (freq + 1, freq_grid)) class FitsFilenameTemplate(string.Template): @@ -209,6 +215,7 @@ class FitsFilenameTemplate(string.Template): with a $(identifier) braced pattern expected by FITS beam filename schema """ + pattern = r""" %(delim)s(?: (?P%(delim)s) | # Escape sequence of two delimiters @@ -216,13 +223,15 @@ class FitsFilenameTemplate(string.Template): \((?P%(id)s)\) | # delimiter and a braced identifier (?P) # Other ill-formed delimiter exprs ) - """ % {'delim': re.escape(string.Template.delimiter), - 'id': string.Template.idpattern} + """ % { + "delim": re.escape(string.Template.delimiter), + "id": string.Template.idpattern, + } -CIRCULAR_CORRELATIONS = ('rr', 'rl', 'lr', 'll') -LINEAR_CORRELATIONS = ('xx', 'xy', 'yx', 'yy') -REIM = ('re', 'im') +CIRCULAR_CORRELATIONS = ("rr", "rl", "lr", "ll") +LINEAR_CORRELATIONS = ("xx", "xy", "yx", "yy") +REIM = ("re", "im") def _re_im_filenames(corr, template): @@ -230,17 +239,17 @@ def _re_im_filenames(corr, template): for ri in REIM: try: - filename = template.substitute(corr=corr.lower(), - CORR=corr.upper(), - reim=ri.lower(), - REIM=ri.upper()) + filename = template.substitute( + corr=corr.lower(), CORR=corr.upper(), reim=ri.lower(), REIM=ri.upper() + ) except KeyError: - raise ValueError("Invalid filename schema '%s'. " - "FITS Beam filename schemas " - "must follow forms such as " - "'beam_$(corr)_$(reim).fits' or " - "'beam_$(CORR)_$(REIM).fits." - % template.template) + raise ValueError( + "Invalid filename schema '%s'. " + "FITS Beam filename schemas " + "must follow forms such as " + "'beam_$(corr)_$(reim).fits' or " + "'beam_$(CORR)_$(REIM).fits." % template.template + ) else: filenames.append(filename) @@ -301,5 +310,4 @@ def beam_filenames(filename_schema, corr_types): else: corr_names.append(corr_name.lower()) - return OrderedDict((c, _re_im_filenames(c, template)) - for c in corr_names) + return OrderedDict((c, _re_im_filenames(c, template)) for c in corr_names) diff --git a/africanus/util/casa_types.py b/africanus/util/casa_types.py index 94948fb4a..d840b154b 100644 --- a/africanus/util/casa_types.py +++ b/africanus/util/casa_types.py @@ -34,7 +34,8 @@ "Plinear", "PFtotal", "PFlinear", - "Pangle"] + "Pangle", +] """ List of stokes types as defined in Measurement Set 2.0 and Stokes.h in casacore: diff --git a/africanus/util/cmdline.py b/africanus/util/cmdline.py index c8e719c66..46ed130ff 100644 --- a/africanus/util/cmdline.py +++ b/africanus/util/cmdline.py @@ -6,7 +6,7 @@ import builtins # builtin function whitelist -_BUILTIN_WHITELIST = frozenset(['slice']) +_BUILTIN_WHITELIST = frozenset(["slice"]) _missing = _BUILTIN_WHITELIST.difference(dir(builtins)) if len(_missing) > 0: raise ValueError("'%s' are not valid builtin functions.'" % list(_missing)) @@ -50,10 +50,11 @@ def _eval_value(stmt_value): func_name = stmt_value.func.id if func_name not in _BUILTIN_WHITELIST: - raise ValueError("Function '%s' in '%s' is not builtin. " - "Available builtins: '%s'" % - (func_name, assign_str, - list(_BUILTIN_WHITELIST))) + raise ValueError( + "Function '%s' in '%s' is not builtin. " + "Available builtins: '%s'" + % (func_name, assign_str, list(_BUILTIN_WHITELIST)) + ) # Recursively pass arguments through this same function if stmt_value.args is not None: @@ -63,8 +64,7 @@ def _eval_value(stmt_value): # Recursively pass keyword arguments through this same function if stmt_value.keywords is not None: - kwargs = {kw.arg: _eval_value(kw.value) for kw - in stmt_value.keywords} + kwargs = {kw.arg: _eval_value(kw.value) for kw in stmt_value.keywords} else: kwargs = {} @@ -77,12 +77,14 @@ def _eval_value(stmt_value): variables = {} # Parse the assignment string - stmts = ast.parse(assign_str, mode='single').body + stmts = ast.parse(assign_str, mode="single").body for i, stmt in enumerate(stmts): if not isinstance(stmt, ast.Assign): - raise ValueError("Statement %d in '%s' is not a " - "variable assignment." % (i, assign_str)) + raise ValueError( + "Statement %d in '%s' is not a " + "variable assignment." % (i, assign_str) + ) # Evaluate assignment lhs values = _eval_value(stmt.value) @@ -99,9 +101,11 @@ def _eval_value(stmt_value): # Require all tuple/list elements to be variable names, # although anything else is probably a syntax error if not all(isinstance(e, ast.Name) for e in target.elts): - raise ValueError("Tuple unpacking in assignment %d " - "in expression '%s' failed as not all " - "tuple contents are variable names.") + raise ValueError( + "Tuple unpacking in assignment %d " + "in expression '%s' failed as not all " + "tuple contents are variable names." + ) # Promote for zip and length checking if not isinstance(values, (tuple, list)): @@ -110,17 +114,20 @@ def _eval_value(stmt_value): elements = values if not len(target.elts) == len(elements): - raise ValueError("Unpacking '%s' into a tuple/list in " - "assignment %d of expression '%s' " - "failed. The number of tuple elements " - "did not match the number of values." - % (values, i, assign_str)) + raise ValueError( + "Unpacking '%s' into a tuple/list in " + "assignment %d of expression '%s' " + "failed. The number of tuple elements " + "did not match the number of values." % (values, i, assign_str) + ) # Unpack for variable, value in zip(target.elts, elements): variables[variable.id] = value else: - raise TypeError("'%s' types are not supported" - "as assignment targets." % type(target)) + raise TypeError( + "'%s' types are not supported" + "as assignment targets." % type(target) + ) return variables diff --git a/africanus/util/code.py b/africanus/util/code.py index 3b3d78c72..37eb91e86 100644 --- a/africanus/util/code.py +++ b/africanus/util/code.py @@ -37,10 +37,9 @@ def format_code(code): str Code prefixed with line numbers """ - lines = [''] - lines.extend(["%-5d %s" % (i, l) for i, l - in enumerate(code.split('\n'), 1)]) - return '\n'.join(lines) + lines = [""] + lines.extend(["%-5d %s" % (i, l) for i, l in enumerate(code.split("\n"), 1)]) + return "\n".join(lines) class memoize_on_key(object): diff --git a/africanus/util/cub.py b/africanus/util/cub.py index 14268f129..6658e8aa8 100644 --- a/africanus/util/cub.py +++ b/africanus/util/cub.py @@ -18,14 +18,14 @@ from africanus.util.files import sha_hash_file _cub_dir = pjoin(include_dir, "cub") -_cub_url = 'https://github.com/NVlabs/cub/archive/1.8.0.zip' -_cub_sha_hash = '836f523a34c32a7e99fba36b30abfe7a68d41d4b' -_cub_version_str = 'Current release: v1.8.0 (02/16/2018)' +_cub_url = "https://github.com/NVlabs/cub/archive/1.8.0.zip" +_cub_sha_hash = "836f523a34c32a7e99fba36b30abfe7a68d41d4b" +_cub_version_str = "Current release: v1.8.0 (02/16/2018)" _cub_version = "1.8.0" -_cub_zip_dir = 'cub-' + _cub_version +_cub_zip_dir = "cub-" + _cub_version _cub_download_filename = "cub-" + _cub_version + ".zip" -_cub_header = pjoin(_cub_dir, 'cub', 'cub.cuh') -_cub_readme = pjoin(_cub_dir, 'README.md') +_cub_header = pjoin(_cub_dir, "cub", "cub.cuh") +_cub_readme = pjoin(_cub_dir, "README.md") _cub_new_unzipped_path = _cub_dir @@ -50,28 +50,25 @@ def download_cub(archive_file): def is_cub_installed(readme_filename, header_filename, cub_version_str): # Check if the cub.h exists - if (not os.path.exists(header_filename) or - not os.path.isfile(header_filename)): - + if not os.path.exists(header_filename) or not os.path.isfile(header_filename): reason = "CUB header '{}' does not exist".format(header_filename) return (False, reason) # Check if the README.md exists - if (not os.path.exists(readme_filename) or - not os.path.isfile(readme_filename)): - + if not os.path.exists(readme_filename) or not os.path.isfile(readme_filename): reason = "CUB readme '{}' does not exist".format(readme_filename) return (False, reason) # Search for the version string, returning True if found - with open(readme_filename, 'r') as f: + with open(readme_filename, "r") as f: for line in f: if line.find(cub_version_str) != -1: return (True, "") # Nothing found! reason = "CUB version string '{}' not found in '{}'".format( - cub_version_str, readme_filename) + cub_version_str, readme_filename + ) return (False, reason) @@ -91,22 +88,21 @@ def _install_cub(): sha_hash = download_cub(archive) # Compare against our supplied hash if _cub_sha_hash != sha_hash: - msg = ('Hash of file %s downloaded from %s ' - 'is %s and does not match the expected ' - 'hash of %s.') % ( - _cub_download_filename, _cub_url, - _cub_sha_hash, sha_hash) + msg = ( + "Hash of file %s downloaded from %s " + "is %s and does not match the expected " + "hash of %s." + ) % (_cub_download_filename, _cub_url, _cub_sha_hash, sha_hash) raise InstallCubException(msg) # Unzip into include/cub - with ZipFile(archive, 'r') as zip_file: + with ZipFile(archive, "r") as zip_file: # Remove any existing install try: shutil.rmtree(_cub_dir, ignore_errors=True) except Exception as e: - raise InstallCubException("Removing %s failed\n%s" % ( - _cub_dir, str(e))) + raise InstallCubException("Removing %s failed\n%s" % (_cub_dir, str(e))) try: # Unzip into temporary directory @@ -118,16 +114,14 @@ def _install_cub(): # Move shutil.move(unzip_path, _cub_dir) except Exception as e: - raise InstallCubException("Extracting %s failed\n%s" % ( - archive, str(e))) + raise InstallCubException("Extracting %s failed\n%s" % (archive, str(e))) finally: shutil.rmtree(tmpdir, ignore_errors=True) log.info("NVIDIA cub archive unzipped into '%s'" % _cub_dir) # Final check on installation - there, reason = is_cub_installed(_cub_readme, _cub_header, - _cub_version_str) + there, reason = is_cub_installed(_cub_readme, _cub_header, _cub_version_str) if not there: raise InstallCubException(reason) @@ -136,8 +130,7 @@ def _install_cub(): _cub_install_lock = Lock() with _cub_install_lock: - _cub_installed, _ = is_cub_installed(_cub_readme, _cub_header, - _cub_version_str) + _cub_installed, _ = is_cub_installed(_cub_readme, _cub_header, _cub_version_str) def cub_dir(): diff --git a/africanus/util/cuda.py b/africanus/util/cuda.py index 0e3437106..065a08620 100644 --- a/africanus/util/cuda.py +++ b/africanus/util/cuda.py @@ -23,51 +23,49 @@ cuda_fns = { np.dtype(np.float32): { - 'abs': 'fabsf', - 'cos': 'cosf', - 'floor': 'floorf', - 'make2': 'make_float2', - 'max': 'fmaxf', - 'min': 'fminf', - 'rsqrt': 'rsqrtf', - 'sqrt': 'sqrtf', - 'sin': 'sinf', - 'sincos': 'sincosf', - 'sincospi': 'sincospif', + "abs": "fabsf", + "cos": "cosf", + "floor": "floorf", + "make2": "make_float2", + "max": "fmaxf", + "min": "fminf", + "rsqrt": "rsqrtf", + "sqrt": "sqrtf", + "sin": "sinf", + "sincos": "sincosf", + "sincospi": "sincospif", }, np.dtype(np.float64): { - 'abs': 'fabs', - 'cos': 'cos', - 'floor': 'floor', - 'make2': 'make_double2', - 'max': 'fmax', - 'min': 'fmin', - 'rsqrt': 'rsqrt', - 'sin': 'sin', - 'sincos': 'sincos', - 'sincospi': 'sincospi', - 'sqrt': 'sqrt', + "abs": "fabs", + "cos": "cos", + "floor": "floor", + "make2": "make_double2", + "max": "fmax", + "min": "fmin", + "rsqrt": "rsqrt", + "sin": "sin", + "sincos": "sincos", + "sincospi": "sincospi", + "sqrt": "sqrt", }, } numpy_to_cuda_type_map = { - np.dtype('int8'): "char", - np.dtype('uint8'): "unsigned char", - np.dtype('int16'): "short", - np.dtype('uint16'): "unsigned short", - np.dtype('int32'): "int", - np.dtype('uint32'): "unsigned int", - np.dtype('float32'): "float", - np.dtype('float64'): "double", - np.dtype('complex64'): "float2", - np.dtype('complex128'): "double2" + np.dtype("int8"): "char", + np.dtype("uint8"): "unsigned char", + np.dtype("int16"): "short", + np.dtype("uint16"): "unsigned short", + np.dtype("int32"): "int", + np.dtype("uint32"): "unsigned int", + np.dtype("float32"): "float", + np.dtype("float64"): "double", + np.dtype("complex64"): "float2", + np.dtype("complex128"): "double2", } # Also map the types -numpy_to_cuda_type_map.update({k.type: v - for k, v - in numpy_to_cuda_type_map.items()}) +numpy_to_cuda_type_map.update({k.type: v for k, v in numpy_to_cuda_type_map.items()}) def grids(dims, blocks): @@ -85,14 +83,18 @@ def grids(dims, blocks): `(x, y, z)` grid size tuple """ if not len(dims) == 3: - raise ValueError("dims must be an (x, y, z) tuple. " - "CUDA dimension ordering is inverted compared " - "to NumPy") + raise ValueError( + "dims must be an (x, y, z) tuple. " + "CUDA dimension ordering is inverted compared " + "to NumPy" + ) if not len(blocks) == 3: - raise ValueError("blocks must be an (x, y, z) tuple. " - "CUDA dimension ordering is inverted compared " - "to NumPy") + raise ValueError( + "blocks must be an (x, y, z) tuple. " + "CUDA dimension ordering is inverted compared " + "to NumPy" + ) return tuple((d + b - 1) // b for d, b in zip(dims, blocks)) diff --git a/africanus/util/dask_util.py b/africanus/util/dask_util.py index 457a40c40..8006df520 100644 --- a/africanus/util/dask_util.py +++ b/africanus/util/dask_util.py @@ -66,14 +66,14 @@ def __iadd__(self, other): return self def __add__(self, other): - return TaskData(self.completed + other.completed, - self.total + other.total, - self.time_sum + other.time_sum) + return TaskData( + self.completed + other.completed, + self.total + other.total, + self.time_sum + other.time_sum, + ) def __repr__(self): - return "TaskData(%s, %s, %s)" % (self.completed, - self.total, - self.time_sum) + return "TaskData(%s, %s, %s)" % (self.completed, self.total, self.time_sum) __str__ = __repr__ @@ -115,9 +115,12 @@ def update_bar(elapsed, prev_completed, prev_estimated, pb): percent = int(100 * fraction) msg = "\r[{0:{1}.{1}}] | {2}% Complete (Estimate) | {3} / ~{4}".format( - bar, pb._width, percent, - format_time(elapsed), - "???" if estimated == 0.0 else format_time(estimated)) + bar, + pb._width, + percent, + format_time(elapsed), + "???" if estimated == 0.0 else format_time(estimated), + ) with suppress(ValueError): pb._file.write(msg) @@ -135,10 +138,9 @@ def timer_func(pb): prev_estimated = 0.0 if elapsed > pb._minimum: - prev_completed, prev_estimated = update_bar(elapsed, - prev_completed, - prev_estimated, - pb) + prev_completed, prev_estimated = update_bar( + elapsed, prev_completed, prev_estimated, pb + ) time.sleep(pb._dt) @@ -179,6 +181,7 @@ class EstimatingProgressBar(Callback): Update resolution in seconds, default is 1.0 seconds. """ + @requires_optional("dask", opt_import_err) def __init__(self, minimum=0, width=42, dt=1.0, out=default_out): if out is None: diff --git a/africanus/util/docs.py b/africanus/util/docs.py index aa0a18935..007509f0f 100644 --- a/africanus/util/docs.py +++ b/africanus/util/docs.py @@ -30,7 +30,7 @@ def doc_tuple_to_str(doc_tuple, replacements=None): if replacements is not None: fields = (mod_docs(f, replacements) for f in fields) - return ''.join(fields) + return "".join(fields) class DefaultOut(object): @@ -48,6 +48,7 @@ class DocstringTemplate(Template): Overrides the ${identifer} braced pattern in the string Template with a $(identifier) braced pattern """ + pattern = r""" %(delim)s(?: (?P%(delim)s) | # Escape sequence of two delimiters @@ -55,5 +56,4 @@ class DocstringTemplate(Template): \((?P%(id)s)\) | # delimiter and a braced identifier (?P) # Other ill-formed delimiter exprs ) - """ % {'delim': re.escape(Template.delimiter), - 'id': Template.idpattern} + """ % {"delim": re.escape(Template.delimiter), "id": Template.idpattern} diff --git a/africanus/util/files.py b/africanus/util/files.py index 748e6a390..d83b7d8cd 100644 --- a/africanus/util/files.py +++ b/africanus/util/files.py @@ -5,11 +5,11 @@ def sha_hash_file(filename): - """ Compute the SHA1 hash of filename """ + """Compute the SHA1 hash of filename""" hash_sha = sha1() - with open(filename, 'rb') as f: - for chunk in iter(lambda: f.read(1024*1024), b""): + with open(filename, "rb") as f: + for chunk in iter(lambda: f.read(1024 * 1024), b""): hash_sha.update(chunk) return hash_sha.hexdigest() diff --git a/africanus/util/jinja2.py b/africanus/util/jinja2.py index 5c30e0877..347c87ee9 100644 --- a/africanus/util/jinja2.py +++ b/africanus/util/jinja2.py @@ -76,6 +76,7 @@ class FakeEnvironment(object): Fake jinja2 environment, for which attribute/dict type access will fail """ + @requires_optional("jinja2") def __getitem__(self, key): raise NotImplementedError() @@ -107,19 +108,19 @@ def _jinja2_env_factory(): except ImportError: return FakeEnvironment() - loader = PackageLoader('africanus', '.') - autoescape = select_autoescape(['j2', 'cu.j2']) - env = Environment(loader=loader, - autoescape=autoescape, - extensions=['jinja2.ext.do']) + loader = PackageLoader("africanus", ".") + autoescape = select_autoescape(["j2", "cu.j2"]) + env = Environment( + loader=loader, autoescape=autoescape, extensions=["jinja2.ext.do"] + ) # TODO(sjperkins) # Find a better way to set globals # perhaps search the package tree for e.g. # `jinja2_setup`.py files, whose contents # are inspected and assigned into the globals dict - env.globals['register_assign_cycles'] = register_assign_cycles - env.globals['throw'] = throw_helper + env.globals["register_assign_cycles"] = register_assign_cycles + env.globals["throw"] = throw_helper return env diff --git a/africanus/util/numba.py b/africanus/util/numba.py index ada203069..17083e971 100644 --- a/africanus/util/numba.py +++ b/africanus/util/numba.py @@ -50,5 +50,6 @@ def is_numba_type_none(arg): boolean True if the type represents None """ - return (isinstance(arg, types.misc.NoneType) or - (isinstance(arg, types.misc.Omitted) and arg.value is None)) + return isinstance(arg, types.misc.NoneType) or ( + isinstance(arg, types.misc.Omitted) and arg.value is None + ) diff --git a/africanus/util/nvcc.py b/africanus/util/nvcc.py index 28f5f47f6..638c2dc7b 100644 --- a/africanus/util/nvcc.py +++ b/africanus/util/nvcc.py @@ -34,11 +34,11 @@ def get_path(key): - return os.environ.get(key, '').split(os.pathsep) + return os.environ.get(key, "").split(os.pathsep) def search_on_path(filenames): - for p in get_path('PATH'): + for p in get_path("PATH"): for filename in filenames: full = os.path.join(p, filename) if os.path.exists(full): @@ -46,15 +46,15 @@ def search_on_path(filenames): return None -PLATFORM_DARWIN = sys.platform.startswith('darwin') -PLATFORM_LINUX = sys.platform.startswith('linux') -PLATFORM_WIN32 = sys.platform.startswith('win32') +PLATFORM_DARWIN = sys.platform.startswith("darwin") +PLATFORM_LINUX = sys.platform.startswith("linux") +PLATFORM_WIN32 = sys.platform.startswith("win32") minimum_cuda_version = 8000 minimum_cudnn_version = 5000 maximum_cudnn_version = 7999 -_cuda_path = 'NOT_INITIALIZED' +_cuda_path = "NOT_INITIALIZED" _compiler_base_options = None _cuda_info = None @@ -73,29 +73,30 @@ def get_cuda_path(): # Use a magic word to represent the cache not filled because None is a # valid return value. - if _cuda_path != 'NOT_INITIALIZED': + if _cuda_path != "NOT_INITIALIZED": return _cuda_path - nvcc_path = search_on_path(('nvcc', 'nvcc.exe')) + nvcc_path = search_on_path(("nvcc", "nvcc.exe")) cuda_path_default = None if nvcc_path is None: - log.warn('nvcc not in path. Please set path to nvcc.') + log.warn("nvcc not in path. Please set path to nvcc.") else: cuda_path_default = os.path.normpath( - os.path.join(os.path.dirname(nvcc_path), '..')) + os.path.join(os.path.dirname(nvcc_path), "..") + ) - cuda_path = os.environ.get('CUDA_PATH', '') # Nvidia default on Windows + cuda_path = os.environ.get("CUDA_PATH", "") # Nvidia default on Windows if len(cuda_path) > 0 and cuda_path != cuda_path_default: - log.warn('nvcc path != CUDA_PATH') - log.warn('nvcc path: %s' % cuda_path_default) - log.warn('CUDA_PATH: %s' % cuda_path) + log.warn("nvcc path != CUDA_PATH") + log.warn("nvcc path: %s" % cuda_path_default) + log.warn("CUDA_PATH: %s" % cuda_path) if os.path.exists(cuda_path): _cuda_path = cuda_path elif cuda_path_default is not None: _cuda_path = cuda_path_default - elif os.path.exists('/usr/local/cuda'): - _cuda_path = '/usr/local/cuda' + elif os.path.exists("/usr/local/cuda"): + _cuda_path = "/usr/local/cuda" else: _cuda_path = None @@ -103,7 +104,7 @@ def get_cuda_path(): def get_nvcc_path(): - nvcc = os.environ.get('NVCC', None) + nvcc = os.environ.get("NVCC", None) if nvcc: return distutils.split_quoted(nvcc) @@ -112,9 +113,9 @@ def get_nvcc_path(): return None if PLATFORM_WIN32: - nvcc_bin = 'bin/nvcc.exe' + nvcc_bin = "bin/nvcc.exe" else: - nvcc_bin = 'bin/nvcc' + nvcc_bin = "bin/nvcc" nvcc_path = os.path.join(cuda_path, nvcc_bin) if os.path.exists(nvcc_path): @@ -131,29 +132,29 @@ def get_compiler_setting(): define_macros = [] if cuda_path: - include_dirs.append(os.path.join(cuda_path, 'include')) + include_dirs.append(os.path.join(cuda_path, "include")) if PLATFORM_WIN32: - library_dirs.append(os.path.join(cuda_path, 'bin')) - library_dirs.append(os.path.join(cuda_path, 'lib', 'x64')) + library_dirs.append(os.path.join(cuda_path, "bin")) + library_dirs.append(os.path.join(cuda_path, "lib", "x64")) else: - library_dirs.append(os.path.join(cuda_path, 'lib64')) - library_dirs.append(os.path.join(cuda_path, 'lib')) + library_dirs.append(os.path.join(cuda_path, "lib64")) + library_dirs.append(os.path.join(cuda_path, "lib")) if PLATFORM_DARWIN: - library_dirs.append('/usr/local/cuda/lib') + library_dirs.append("/usr/local/cuda/lib") if PLATFORM_WIN32: - nvtoolsext_path = os.environ.get('NVTOOLSEXT_PATH', '') + nvtoolsext_path = os.environ.get("NVTOOLSEXT_PATH", "") if os.path.exists(nvtoolsext_path): - include_dirs.append(os.path.join(nvtoolsext_path, 'include')) - library_dirs.append(os.path.join(nvtoolsext_path, 'lib', 'x64')) + include_dirs.append(os.path.join(nvtoolsext_path, "include")) + library_dirs.append(os.path.join(nvtoolsext_path, "lib", "x64")) else: - define_macros.append(('CUPY_NO_NVTX', '1')) + define_macros.append(("CUPY_NO_NVTX", "1")) return { - 'include_dirs': include_dirs, - 'library_dirs': library_dirs, - 'define_macros': define_macros, - 'language': 'c++', + "include_dirs": include_dirs, + "library_dirs": library_dirs, + "define_macros": define_macros, + "language": "c++", } @@ -180,9 +181,7 @@ def _match_output_lines(output_lines, regexs): def get_compiler_base_options(): - """Returns base options for nvcc compiler. - - """ + """Returns base options for nvcc compiler.""" global _compiler_base_options if _compiler_base_options is None: _compiler_base_options = _get_compiler_base_options() @@ -195,35 +194,37 @@ def _get_compiler_base_options(): # and try to compose base options according to it. nvcc_path = get_nvcc_path() with _tempdir() as temp_dir: - test_cu_path = os.path.join(temp_dir, 'test.cu') - test_out_path = os.path.join(temp_dir, 'test.out') - with open(test_cu_path, 'w') as f: - f.write('int main() { return 0; }') + test_cu_path = os.path.join(temp_dir, "test.cu") + test_out_path = os.path.join(temp_dir, "test.out") + with open(test_cu_path, "w") as f: + f.write("int main() { return 0; }") proc = subprocess.Popen( - nvcc_path + ['-o', test_out_path, test_cu_path], + nvcc_path + ["-o", test_out_path, test_cu_path], stdout=subprocess.PIPE, - stderr=subprocess.PIPE) + stderr=subprocess.PIPE, + ) stdoutdata, stderrdata = proc.communicate() - stderrlines = stderrdata.split(b'\n') + stderrlines = stderrdata.split(b"\n") if proc.returncode != 0: - # No supported host compiler matches = _match_output_lines( stderrlines, [ - b'^ERROR: No supported gcc/g\\+\\+ host compiler found, ' - b'but .* is available.$', - b'^ *Use \'nvcc (.*)\' to use that instead.$', - ]) + b"^ERROR: No supported gcc/g\\+\\+ host compiler found, " + b"but .* is available.$", + b"^ *Use 'nvcc (.*)' to use that instead.$", + ], + ) if matches is not None: base_opts = matches[1].group(1) - base_opts = base_opts.decode('utf8').split(' ') + base_opts = base_opts.decode("utf8").split(" ") return base_opts # Unknown error raise RuntimeError( - 'Encountered unknown error while testing nvcc:\n' + - stderrdata.decode('utf8')) + "Encountered unknown error while testing nvcc:\n" + + stderrdata.decode("utf8") + ) return [] @@ -231,7 +232,7 @@ def _get_compiler_base_options(): def _get_cuda_info(): nvcc_path = get_nvcc_path() - code = ''' + code = """ #include #include int main(int argc, char* argv[]) { @@ -272,31 +273,32 @@ def _get_cuda_info(): return 0; } - ''' # noqa + """ # noqa with _tempdir() as temp_dir: - test_cu_path = os.path.join(temp_dir, 'test.cu') - test_out_path = os.path.join(temp_dir, 'test.out') + test_cu_path = os.path.join(temp_dir, "test.cu") + test_out_path = os.path.join(temp_dir, "test.out") - with open(test_cu_path, 'w') as f: + with open(test_cu_path, "w") as f: f.write(code) proc = subprocess.Popen( - nvcc_path + ['-o', test_out_path, test_cu_path], + nvcc_path + ["-o", test_out_path, test_cu_path], stdout=subprocess.PIPE, - stderr=subprocess.PIPE) + stderr=subprocess.PIPE, + ) stdoutdata, stderrdata = proc.communicate() if proc.returncode != 0: - raise RuntimeError("Cannot determine " - "compute architecture {0}" - .format(stderrdata)) + raise RuntimeError( + "Cannot determine " "compute architecture {0}".format(stderrdata) + ) try: out = subprocess.check_output(test_out_path) except Exception as e: - msg = 'Cannot execute a stub file.\nOriginal error: {0}'.format(e) + msg = "Cannot execute a stub file.\nOriginal error: {0}".format(e) raise Exception(msg) return ast.literal_eval(out) @@ -317,70 +319,84 @@ def _format_cuda_version(version): def get_cuda_version(formatted=False): """Return CUDA Toolkit version cached in check_cuda_version().""" - _cuda_version = get_cuda_info()['cuda_version'] + _cuda_version = get_cuda_info()["cuda_version"] if _cuda_version < minimum_cuda_version: - raise ValueError('CUDA version is too old: %d' - 'CUDA v7.0 or newer is required' % _cuda_version) + raise ValueError( + "CUDA version is too old: %d" + "CUDA v7.0 or newer is required" % _cuda_version + ) return str(_cuda_version) if formatted else _cuda_version def get_gencode_options(): - return ["--generate-code=arch=compute_{a},code=sm_{a}".format( - a=dev['major']*10 + dev['minor']) - for dev in get_cuda_info()['devices']] + return [ + "--generate-code=arch=compute_{a},code=sm_{a}".format( + a=dev["major"] * 10 + dev["minor"] + ) + for dev in get_cuda_info()["devices"] + ] class _UnixCCompiler(unixccompiler.UnixCCompiler): src_extensions = list(unixccompiler.UnixCCompiler.src_extensions) - src_extensions.append('.cu') + src_extensions.append(".cu") def _compile(self, obj, src, ext, cc_args, extra_postargs, pp_opts): # For sources other than CUDA C ones, just call the super class method. - if os.path.splitext(src)[1] != '.cu': + if os.path.splitext(src)[1] != ".cu": return unixccompiler.UnixCCompiler._compile( - self, obj, src, ext, cc_args, extra_postargs, pp_opts) + self, obj, src, ext, cc_args, extra_postargs, pp_opts + ) # For CUDA C source files, compile them with NVCC. _compiler_so = self.compiler_so try: nvcc_path = get_nvcc_path() base_opts = get_compiler_base_options() - self.set_executable('compiler_so', nvcc_path) + self.set_executable("compiler_so", nvcc_path) cuda_version = get_cuda_version() # noqa: triggers cuda inspection - postargs = get_gencode_options() + [ - '-O2', '--compiler-options="-fPIC"'] + postargs = get_gencode_options() + ["-O2", '--compiler-options="-fPIC"'] postargs += extra_postargs # print('NVCC options:', postargs) return unixccompiler.UnixCCompiler._compile( - self, obj, src, ext, base_opts + cc_args, postargs, pp_opts) + self, obj, src, ext, base_opts + cc_args, postargs, pp_opts + ) finally: self.compiler_so = _compiler_so class _MSVCCompiler(msvccompiler.MSVCCompiler): - _cu_extensions = ['.cu'] + _cu_extensions = [".cu"] src_extensions = list(unixccompiler.UnixCCompiler.src_extensions) src_extensions.extend(_cu_extensions) - def _compile_cu(self, sources, output_dir=None, macros=None, - include_dirs=None, debug=0, extra_preargs=None, - extra_postargs=None, depends=None): + def _compile_cu( + self, + sources, + output_dir=None, + macros=None, + include_dirs=None, + debug=0, + extra_preargs=None, + extra_postargs=None, + depends=None, + ): # Compile CUDA C files, mainly derived from UnixCCompiler._compile(). - macros, objects, extra_postargs, pp_opts, _build = \ - self._setup_compile(output_dir, macros, include_dirs, sources, - depends, extra_postargs) + macros, objects, extra_postargs, pp_opts, _build = self._setup_compile( + output_dir, macros, include_dirs, sources, depends, extra_postargs + ) compiler_so = get_nvcc_path() cc_args = self._get_cc_args(pp_opts, debug, extra_preargs) cuda_version = get_cuda_version() # noqa: triggers cuda inspection - postargs = get_gencode_options() + ['-O2'] - postargs += ['-Xcompiler', '/MD'] + postargs = get_gencode_options() + ["-O2"] + postargs += ["-Xcompiler", "/MD"] postargs += extra_postargs # print('NVCC options:', postargs) @@ -390,7 +406,7 @@ def _compile_cu(self, sources, output_dir=None, macros=None, except KeyError: continue try: - self.spawn(compiler_so + cc_args + [src, '-o', obj] + postargs) + self.spawn(compiler_so + cc_args + [src, "-o", obj] + postargs) except errors.DistutilsExecError as e: raise errors.CompileError(str(e)) @@ -401,14 +417,13 @@ def compile(self, sources, **kwargs): cu_sources = [] other_sources = [] for source in sources: - if os.path.splitext(source)[1] == '.cu': + if os.path.splitext(source)[1] == ".cu": cu_sources.append(source) else: other_sources.append(source) # Compile source files other than CUDA C ones. - other_objects = msvccompiler.MSVCCompiler.compile( - self, other_sources, **kwargs) + other_objects = msvccompiler.MSVCCompiler.compile(self, other_sources, **kwargs) # Compile CUDA C sources. cu_objects = self._compile_cu(cu_sources, **kwargs) @@ -446,7 +461,7 @@ def stdchannel_redirected(stdchannel, dest_filename): try: oldstdchannel = os.dup(stdchannel.fileno()) - dest_file = open(dest_filename, 'w') + dest_file = open(dest_filename, "w") os.dup2(dest_file.fileno(), stdchannel.fileno()) yield @@ -458,22 +473,21 @@ def stdchannel_redirected(stdchannel, dest_filename): @requires_optional("cupy", cupy_import_error) -def compile_using_nvcc(source, options=None, arch=None, filename='kern.cu'): +def compile_using_nvcc(source, options=None, arch=None, filename="kern.cu"): options = options or [] if arch is None: cuda_info = get_cuda_info() - arch = min([dev['major']*10 + dev['minor'] - for dev in cuda_info['devices']]) + arch = min([dev["major"] * 10 + dev["minor"] for dev in cuda_info["devices"]]) cc = get_compiler() settings = get_compiler_setting() arch = "--generate-code=arch=compute_{a},code=sm_{a}".format(a=arch) - options += ['-cubin'] + options += ["-cubin"] cupy_path = resource_filename("cupy", pjoin("core", "include")) - settings['include_dirs'].append(cupy_path) + settings["include_dirs"].append(cupy_path) with _tempdir() as tmpdir: tmpfile = pjoin(tmpdir, filename) @@ -485,21 +499,25 @@ def compile_using_nvcc(source, options=None, arch=None, filename='kern.cu'): stderr_file = pjoin(tmpdir, "stderr.txt") with stdchannel_redirected(sys.stderr, stderr_file): - objects = cc.compile([tmpfile], - include_dirs=settings['include_dirs'], - macros=settings['define_macros'], - extra_postargs=options) + objects = cc.compile( + [tmpfile], + include_dirs=settings["include_dirs"], + macros=settings["define_macros"], + extra_postargs=options, + ) except errors.CompileError as e: with open(stderr_file, "r") as f: errs = f.read() - lines = ["The following source code", - format_code(source), - "", - "created the following compilation errors", - "", - errs.strip(), - str(e).strip()] + lines = [ + "The following source code", + format_code(source), + "", + "created the following compilation errors", + "", + errs.strip(), + str(e).strip(), + ] ex = errors.CompileError("\n".join(lines)) raise (ex, None, sys.exc_info()[2]) diff --git a/africanus/util/patterns.py b/africanus/util/patterns.py index 9035cf193..fbbff5653 100644 --- a/africanus/util/patterns.py +++ b/africanus/util/patterns.py @@ -11,18 +11,16 @@ def freeze(arg): - """ Recursively generates a hashable object from arg """ + """Recursively generates a hashable object from arg""" if isinstance(arg, set): return tuple(map(freeze, sorted(arg))) elif isinstance(arg, (tuple, list)): return tuple(map(freeze, arg)) elif isinstance(arg, (dict, OrderedDict)): - return frozenset((freeze(k), freeze(v)) for k, v - in sorted(arg.items())) + return frozenset((freeze(k), freeze(v)) for k, v in sorted(arg.items())) elif isinstance(arg, ndarray): if arg.nbytes > 10: - warn(f"freezing ndarray of size {arg.nbytes} " - f" is probably inefficient") + warn(f"freezing ndarray of size {arg.nbytes} " f" is probably inefficient") return freeze(arg.tolist()) else: return arg @@ -58,6 +56,7 @@ def __init__(self, *args, **kw): Instantiation of object instances is thread-safe. """ + MISSING = object() def __init__(self, *args, **kwargs): @@ -67,18 +66,23 @@ def __init__(self, *args, **kwargs): def __call__(cls, *args, **kwargs): signature = inspect.signature(cls.__init__) - positional_in_kwargs = [p.name for p in signature.parameters.values() - if p.kind == p.POSITIONAL_OR_KEYWORD - and p.default == p.empty - and p.name in kwargs] + positional_in_kwargs = [ + p.name + for p in signature.parameters.values() + if p.kind == p.POSITIONAL_OR_KEYWORD + and p.default == p.empty + and p.name in kwargs + ] if positional_in_kwargs: - warn(f"Positional arguments {positional_in_kwargs} were " - f"supplied as keyword arguments to " - f"{cls.__init__}{signature}. " - f"This may create separate Multiton instances " - f"for what is intended to be a unique set of " - f"arguments.") + warn( + f"Positional arguments {positional_in_kwargs} were " + f"supplied as keyword arguments to " + f"{cls.__init__}{signature}. " + f"This may create separate Multiton instances " + f"for what is intended to be a unique set of " + f"arguments." + ) key = freeze(args + (kwargs if kwargs else Multiton.MISSING,)) @@ -233,7 +237,7 @@ def _read(file_proxy, file_range): "__lazy_kwargs__", "__lazy_object__", "__lazy_lock__", - "__weakref__" + "__weakref__", ) __lazy_members__ = set(__slots__) @@ -242,8 +246,7 @@ def __init__(self, fn, *args, **kwargs): ex = ValueError("fn must be a callable or a tuple of two callables") if isinstance(fn, tuple): - if (len(fn) != 2 or not callable(fn[0]) - or (fn[1] and not callable(fn[1]))): + if len(fn) != 2 or not callable(fn[0]) or (fn[1] and not callable(fn[1])): raise ex self.__lazy_fn__, self.__lazy_finaliser__ = fn @@ -258,18 +261,19 @@ def __init__(self, fn, *args, **kwargs): def __lazy_eq__(self, other): return ( - isinstance(other, LazyProxy) and - self.__lazy_fn__ == other.__lazy_fn__ and - self.__lazy_finaliser__ == other.__lazy_finaliser__ and - self.__lazy_args__ == other.__lazy_args__ and - self.__lazy_kwargs__ == other.__lazy_kwargs__) + isinstance(other, LazyProxy) + and self.__lazy_fn__ == other.__lazy_fn__ + and self.__lazy_finaliser__ == other.__lazy_finaliser__ + and self.__lazy_args__ == other.__lazy_args__ + and self.__lazy_kwargs__ == other.__lazy_kwargs__ + ) def __lazy_hash__(self): return ( self.__lazy_fn__, self.__lazy_finaliser__, self.__lazy_args__, - frozenset(self.__lazy_kwargs__.items()) + frozenset(self.__lazy_kwargs__.items()), ).__hash__() @classmethod @@ -305,7 +309,8 @@ def __lazy_raise_on_invalid_frames__(cls, frame, depth=10): if frame.f_code in INVALID_LAZY_CONTEXTS: raise InvalidLazyContext( f"Attempted to create a LazyObject within a call " - f"to {frame.f_code.co_name}") + f"to {frame.f_code.co_name}" + ) depth -= 1 frame = frame.f_back @@ -338,15 +343,14 @@ def __lazy_obj_from_args__(cls, proxy): cls.__lazy_raise_on_invalid_frames__(inspect.currentframe()) # Create __lazy_object__ - lazy_obj = proxy.__lazy_fn__(*proxy.__lazy_args__, - **proxy.__lazy_kwargs__) + lazy_obj = proxy.__lazy_fn__( + *proxy.__lazy_args__, **proxy.__lazy_kwargs__ + ) proxy.__lazy_object__ = lazy_obj # Create finaliser if provided if proxy.__lazy_finaliser__: - weakref.finalize(proxy, - proxy.__lazy_finaliser__, - lazy_obj) + weakref.finalize(proxy, proxy.__lazy_finaliser__, lazy_obj) return lazy_obj @@ -370,10 +374,18 @@ def __delattr__(self, name): return object.__delattr__(self.__lazy_object__, name) def __reduce__(self): - return (self.__lazy_from_args__, - (((self.__lazy_fn__, self.__lazy_finaliser__) - if self.__lazy_finaliser__ else self.__lazy_fn__), - self.__lazy_args__, self.__lazy_kwargs__)) + return ( + self.__lazy_from_args__, + ( + ( + (self.__lazy_fn__, self.__lazy_finaliser__) + if self.__lazy_finaliser__ + else self.__lazy_fn__ + ), + self.__lazy_args__, + self.__lazy_kwargs__, + ), + ) class LazyProxyMultiton(LazyProxy, metaclass=Multiton): diff --git a/africanus/util/requirements.py b/africanus/util/requirements.py index a3664bcd5..79f28ef51 100644 --- a/africanus/util/requirements.py +++ b/africanus/util/requirements.py @@ -12,12 +12,16 @@ def _missing_packages(fn, packages, import_errors): if len(import_errors) > 0: import_err_str = "\n".join((str(e) for e in import_errors)) - return ("%s requires installation of " - "the following packages: %s.\n" - "%s" % (fn, packages, import_err_str)) + return "%s requires installation of " "the following packages: %s.\n" "%s" % ( + fn, + packages, + import_err_str, + ) else: - return ("%s requires installation of the following packages: %s. " - % (fn, tuple(packages))) + return "%s requires installation of the following packages: %s. " % ( + fn, + tuple(packages), + ) class MissingPackageException(Exception): @@ -70,6 +74,7 @@ def function(*args, **kwargs): """ # Return a bare wrapper if we're on readthedocs if on_rtd(): + def _function_decorator(fn): def _wrapper(*args, **kwargs): pass @@ -106,18 +111,21 @@ def _wrapper(*args, **kwargs): honour_pytest_marker = False # Just wrong else: - raise TypeError("requirements must be " - "None, strings or ImportErrors. " - "Received %s" % requirement) + raise TypeError( + "requirements must be " + "None, strings or ImportErrors. " + "Received %s" % requirement + ) # Requested requirement import succeeded, but there were user # import errors that we now re-raise if have_requirements and len(import_errors) > 0: - raise ImportError("Successfully imported %s " - "but the following user-supplied " - "ImportErrors ocurred: \n%s" % - (actual_imports, - '\n'.join((str(e) for e in import_errors)))) + raise ImportError( + "Successfully imported %s " + "but the following user-supplied " + "ImportErrors ocurred: \n%s" + % (actual_imports, "\n".join((str(e) for e in import_errors))) + ) def _function_decorator(fn): # We have requirements, return the original function @@ -126,24 +134,28 @@ def _function_decorator(fn): # We don't have requirements, produce a failing wrapper def _wrapper(*args, **kwargs): - """ Empty docstring """ + """Empty docstring""" # We're running test cases if honour_pytest_marker and in_pytest(): try: import pytest except ImportError as e: - raise ImportError("Marked as in a pytest " - "test case, but pytest cannot " - "be imported! %s" % str(e)) + raise ImportError( + "Marked as in a pytest " + "test case, but pytest cannot " + "be imported! %s" % str(e) + ) else: msg = _missing_packages( - fn.__name__, missing_requirements, import_errors) + fn.__name__, missing_requirements, import_errors + ) pytest.skip(msg) # Raise the exception else: msg = _missing_packages( - fn.__name__, missing_requirements, import_errors) + fn.__name__, missing_requirements, import_errors + ) raise MissingPackageException(msg) return decorate(fn, _wrapper) diff --git a/africanus/util/testing.py b/africanus/util/testing.py index 807b4bb55..56642b624 100644 --- a/africanus/util/testing.py +++ b/africanus/util/testing.py @@ -7,7 +7,7 @@ from threading import Lock -__run_marker = {'in_pytest': False} +__run_marker = {"in_pytest": False} __run_marker_lock = Lock() @@ -18,15 +18,15 @@ def in_pytest(): - """ Return True if we're marked as executing inside pytest """ + """Return True if we're marked as executing inside pytest""" with __run_marker_lock: - return __run_marker['in_pytest'] + return __run_marker["in_pytest"] def mark_in_pytest(in_pytest=True): - """ Mark if we're in a pytest run """ + """Mark if we're in a pytest run""" if type(in_pytest) is not bool: - raise TypeError('in_pytest %s is not a boolean' % in_pytest) + raise TypeError("in_pytest %s is not a boolean" % in_pytest) with __run_marker_lock: - __run_marker['in_pytest'] = in_pytest + __run_marker["in_pytest"] = in_pytest diff --git a/africanus/util/tests/test_beam_utils.py b/africanus/util/tests/test_beam_utils.py index f2d54cbd4..79bdfebd6 100644 --- a/africanus/util/tests/test_beam_utils.py +++ b/africanus/util/tests/test_beam_utils.py @@ -10,71 +10,71 @@ @pytest.fixture def fits_header(): return { - "SIMPLE": 'T', # / conforms to FITS standard + "SIMPLE": "T", # / conforms to FITS standard "BITPIX": -64, # / array data type - "NAXIS": 3, # / number of array dimensions - "NAXIS1": 513, - "NAXIS2": 513, - "NAXIS3": 33, - "EXTEND": 'T', - "DATE": '2015-05-20 12:40:12.507624', - "DATE-OB": '2015-05-20 12:40:12.507624', - "ORIGIN": 'SOMEONE ', - "TELESCO": 'VLA ', - "OBJECT": 'beam ', - "EQUINOX": 2000.0, - "CTYPE1": 'L ', # points right on the sky - "CUNIT1": 'DEG ', - "CDELT1": 0.011082, # degrees - "CRPIX1": 257, # reference pixel (one relative) - "CRVAL1": 0.0110828777007, - "CTYPE2": 'M ', # points up on the sky - "CUNIT2": 'DEG ', - "CDELT2": 0.011082, # degrees - "CRPIX2": 257, # reference pixel (one relative) - "CRVAL2": -2.14349358381E-07, - "CTYPE3": 'FREQ ', - "CDELT3": 1008000.0, # frequency step in Hz - "CRPIX3": 1, # reference frequency postion - "CRVAL3": 1400256000.0, # reference frequency - "CTYPE4": 'STOKES ', - "CDELT4": 1, - "CRPIX4": 1, + "NAXIS": 3, # / number of array dimensions + "NAXIS1": 513, + "NAXIS2": 513, + "NAXIS3": 33, + "EXTEND": "T", + "DATE": "2015-05-20 12:40:12.507624", + "DATE-OB": "2015-05-20 12:40:12.507624", + "ORIGIN": "SOMEONE ", + "TELESCO": "VLA ", + "OBJECT": "beam ", + "EQUINOX": 2000.0, + "CTYPE1": "L ", # points right on the sky + "CUNIT1": "DEG ", + "CDELT1": 0.011082, # degrees + "CRPIX1": 257, # reference pixel (one relative) + "CRVAL1": 0.0110828777007, + "CTYPE2": "M ", # points up on the sky + "CUNIT2": "DEG ", + "CDELT2": 0.011082, # degrees + "CRPIX2": 257, # reference pixel (one relative) + "CRVAL2": -2.14349358381e-07, + "CTYPE3": "FREQ ", + "CDELT3": 1008000.0, # frequency step in Hz + "CRPIX3": 1, # reference frequency postion + "CRVAL3": 1400256000.0, # reference frequency + "CTYPE4": "STOKES ", + "CDELT4": 1, + "CRPIX4": 1, "CRVAL4": -5, - "GFREQ1": 1400256000.0, - "GFREQ2": 1401267006.481463, - "GFREQ3": 1402322911.080775, - "GFREQ4": 1403413869.993157, - "GFREQ5": 1404446534.122004, - "GFREQ6": 1405431839.039557, - "GFREQ7": 1406450580.210605, - "GFREQ8": 1407565986.781461, - "GFREQ9": 1408540601.110557, - "GFREQ10": 1409590690.509872, - "GFREQ11": 1410635261.125197, - "GFREQ12": 1411713397.984036, - "GFREQ13": 1412731853.361315, - "GFREQ14": 1413826544.202757, - "GFREQ15": 1414823303.16869, - "GFREQ16": 1415817968.786441, - "GFREQ17": 1416889091.051286, - "GFREQ18": 1417937927.157403, - "GFREQ19": 1419010194.848117, - "GFREQ20": 1420027703.693506, - "GFREQ21": 1421107695.319375, - "GFREQ22": 1422148567.69773, - "GFREQ23": 1423184370.515572, - "GFREQ24": 1424165878.168865, - "GFREQ25": 1425208894.904767, - "GFREQ26": 1426298839.860366, - "GFREQ27": 1427265196.336215, - "GFREQ28": 1428354727.177189, - "GFREQ29": 1429435689.132821, - "GFREQ30": 1430380674.10678, - "GFREQ31": 1431456384.211675, - "GFREQ32": 1432512000.0, - "GFREQ33": 1432456789.0, # Last GFREQ hard-coded to - # something non-linear + "GFREQ1": 1400256000.0, + "GFREQ2": 1401267006.481463, + "GFREQ3": 1402322911.080775, + "GFREQ4": 1403413869.993157, + "GFREQ5": 1404446534.122004, + "GFREQ6": 1405431839.039557, + "GFREQ7": 1406450580.210605, + "GFREQ8": 1407565986.781461, + "GFREQ9": 1408540601.110557, + "GFREQ10": 1409590690.509872, + "GFREQ11": 1410635261.125197, + "GFREQ12": 1411713397.984036, + "GFREQ13": 1412731853.361315, + "GFREQ14": 1413826544.202757, + "GFREQ15": 1414823303.16869, + "GFREQ16": 1415817968.786441, + "GFREQ17": 1416889091.051286, + "GFREQ18": 1417937927.157403, + "GFREQ19": 1419010194.848117, + "GFREQ20": 1420027703.693506, + "GFREQ21": 1421107695.319375, + "GFREQ22": 1422148567.69773, + "GFREQ23": 1423184370.515572, + "GFREQ24": 1424165878.168865, + "GFREQ25": 1425208894.904767, + "GFREQ26": 1426298839.860366, + "GFREQ27": 1427265196.336215, + "GFREQ28": 1428354727.177189, + "GFREQ29": 1429435689.132821, + "GFREQ30": 1430380674.10678, + "GFREQ31": 1431456384.211675, + "GFREQ32": 1432512000.0, + "GFREQ33": 1432456789.0, # Last GFREQ hard-coded to + # something non-linear } @@ -84,34 +84,35 @@ def test_fits_axes(fits_header): beam_axes = BeamAxes(fits_header) # L axis converted to radian - assert beam_axes.ctype[0] == fits_header['CTYPE1'].strip() == 'L' - assert fits_header['CUNIT1'].strip() == 'DEG' - assert beam_axes.cunit[0] == 'RAD' - assert beam_axes.crval[0] == np.deg2rad(fits_header['CRVAL1']) - assert beam_axes.cdelt[0] == np.deg2rad(fits_header['CDELT1']) + assert beam_axes.ctype[0] == fits_header["CTYPE1"].strip() == "L" + assert fits_header["CUNIT1"].strip() == "DEG" + assert beam_axes.cunit[0] == "RAD" + assert beam_axes.crval[0] == np.deg2rad(fits_header["CRVAL1"]) + assert beam_axes.cdelt[0] == np.deg2rad(fits_header["CDELT1"]) # M axis converted to radian and sign flipped - assert fits_header['CTYPE2'].strip() == 'M' - assert beam_axes.ctype[1] == 'M' - assert fits_header['CUNIT2'].strip() == 'DEG' - assert beam_axes.cunit[1] == 'RAD' - assert beam_axes.crval[1] == np.deg2rad(fits_header['CRVAL2']) - assert beam_axes.cdelt[1] == np.deg2rad(fits_header['CDELT2']) + assert fits_header["CTYPE2"].strip() == "M" + assert beam_axes.ctype[1] == "M" + assert fits_header["CUNIT2"].strip() == "DEG" + assert beam_axes.cunit[1] == "RAD" + assert beam_axes.crval[1] == np.deg2rad(fits_header["CRVAL2"]) + assert beam_axes.cdelt[1] == np.deg2rad(fits_header["CDELT2"]) # GFREQS used for the frequency grid - gfreqs = [fits_header.get('GFREQ%d' % (i+1)) for i - in range(fits_header['NAXIS3'])] + gfreqs = [ + fits_header.get("GFREQ%d" % (i + 1)) for i in range(fits_header["NAXIS3"]) + ] assert_array_almost_equal(beam_axes.grid[2], np.asarray(gfreqs)) # Now remove a GFREQ, forcing usage of the regular frequency grid fits_header = fits_header.copy() - del fits_header['GFREQ30'] + del fits_header["GFREQ30"] beam_axes = BeamAxes(fits_header) R = np.arange(beam_axes.naxis[2]) - g = (R - beam_axes.crpix[2])*beam_axes.cdelt[2] + beam_axes.crval[2] + g = (R - beam_axes.crpix[2]) * beam_axes.cdelt[2] + beam_axes.crval[2] assert_array_equal(g, beam_axes.grid[2]) @@ -124,42 +125,44 @@ def test_beam_grids(fits_header, header_l, header_m, l_axis, m_axis): from africanus.util.beams import beam_grids, axis_and_sign hdr = fits_header - hdr['CTYPE1'] = header_l - hdr['CTYPE2'] = header_m + hdr["CTYPE1"] = header_l + hdr["CTYPE2"] = header_m l_ax, l_sgn = axis_and_sign(l_axis, "L") m_ax, m_sgn = axis_and_sign(m_axis, "M") # Extract l, m and frequency axes and grids - (l, l_grid), (m, m_grid), (freq, freq_grid) = beam_grids(fits_header, - l_axis, m_axis) + (l, l_grid), (m, m_grid), (freq, freq_grid) = beam_grids( + fits_header, l_axis, m_axis + ) # Check expected L - assert hdr['CTYPE%d' % l] == header_l - crval = hdr['CRVAL%d' % l] - cdelt = hdr['CDELT%d' % l] - crpix = hdr['CRPIX%d' % l] - 1 # C-indexing - R = np.arange(0.0, float(hdr['NAXIS%d' % l])) + assert hdr["CTYPE%d" % l] == header_l + crval = hdr["CRVAL%d" % l] + cdelt = hdr["CDELT%d" % l] + crpix = hdr["CRPIX%d" % l] - 1 # C-indexing + R = np.arange(0.0, float(hdr["NAXIS%d" % l])) - exp_l = (R - crpix)*cdelt + crval + exp_l = (R - crpix) * cdelt + crval exp_l = np.deg2rad(exp_l) * l_sgn assert_array_almost_equal(exp_l, l_grid) - assert hdr['CTYPE%d' % m] == header_m - crval = hdr['CRVAL%d' % m] - cdelt = hdr['CDELT%d' % m] - crpix = hdr['CRPIX%d' % m] - 1 # C-indexing - R = np.arange(0.0, float(hdr['NAXIS%d' % m])) + assert hdr["CTYPE%d" % m] == header_m + crval = hdr["CRVAL%d" % m] + cdelt = hdr["CDELT%d" % m] + crpix = hdr["CRPIX%d" % m] - 1 # C-indexing + R = np.arange(0.0, float(hdr["NAXIS%d" % m])) - exp_m = (R - crpix)*cdelt + crval + exp_m = (R - crpix) * cdelt + crval exp_m = np.deg2rad(exp_m) * m_sgn assert_array_almost_equal(exp_m, m_grid) # GFREQS used for the frequency grid - gfreqs = [fits_header.get('GFREQ%d' % (i+1)) for i - in range(fits_header['NAXIS3'])] + gfreqs = [ + fits_header.get("GFREQ%d" % (i + 1)) for i in range(fits_header["NAXIS3"]) + ] assert_array_almost_equal(freq_grid, gfreqs) @@ -168,31 +171,31 @@ def test_beam_filenames(): from africanus.util.beams import beam_filenames assert beam_filenames("beam_$(corr)_$(reim).fits", [9, 10, 11, 12]) == { - 'xx': ['beam_xx_re.fits', 'beam_xx_im.fits'], - 'xy': ['beam_xy_re.fits', 'beam_xy_im.fits'], - 'yx': ['beam_yx_re.fits', 'beam_yx_im.fits'], - 'yy': ['beam_yy_re.fits', 'beam_yy_im.fits'] + "xx": ["beam_xx_re.fits", "beam_xx_im.fits"], + "xy": ["beam_xy_re.fits", "beam_xy_im.fits"], + "yx": ["beam_yx_re.fits", "beam_yx_im.fits"], + "yy": ["beam_yy_re.fits", "beam_yy_im.fits"], } assert beam_filenames("beam_$(corr)_$(reim).fits", [5, 6, 7, 8]) == { - 'rr': ['beam_rr_re.fits', 'beam_rr_im.fits'], - 'rl': ['beam_rl_re.fits', 'beam_rl_im.fits'], - 'lr': ['beam_lr_re.fits', 'beam_lr_im.fits'], - 'll': ['beam_ll_re.fits', 'beam_ll_im.fits'] + "rr": ["beam_rr_re.fits", "beam_rr_im.fits"], + "rl": ["beam_rl_re.fits", "beam_rl_im.fits"], + "lr": ["beam_lr_re.fits", "beam_lr_im.fits"], + "ll": ["beam_ll_re.fits", "beam_ll_im.fits"], } assert beam_filenames("beam_$(CORR)_$(reim).fits", [9, 10, 11, 12]) == { - 'xx': ['beam_XX_re.fits', 'beam_XX_im.fits'], - 'xy': ['beam_XY_re.fits', 'beam_XY_im.fits'], - 'yx': ['beam_YX_re.fits', 'beam_YX_im.fits'], - 'yy': ['beam_YY_re.fits', 'beam_YY_im.fits'] + "xx": ["beam_XX_re.fits", "beam_XX_im.fits"], + "xy": ["beam_XY_re.fits", "beam_XY_im.fits"], + "yx": ["beam_YX_re.fits", "beam_YX_im.fits"], + "yy": ["beam_YY_re.fits", "beam_YY_im.fits"], } assert beam_filenames("beam_$(corr)_$(REIM).fits", [9, 10, 11, 12]) == { - 'xx': ['beam_xx_RE.fits', 'beam_xx_IM.fits'], - 'xy': ['beam_xy_RE.fits', 'beam_xy_IM.fits'], - 'yx': ['beam_yx_RE.fits', 'beam_yx_IM.fits'], - 'yy': ['beam_yy_RE.fits', 'beam_yy_IM.fits'] + "xx": ["beam_xx_RE.fits", "beam_xx_IM.fits"], + "xy": ["beam_xy_RE.fits", "beam_xy_IM.fits"], + "yx": ["beam_yx_RE.fits", "beam_yx_IM.fits"], + "yy": ["beam_yy_RE.fits", "beam_yy_IM.fits"], } @@ -213,14 +216,12 @@ def test_inverse_interp(): grid = np.arange(values.size) initial = np.stack((values, grid)) - interp = interp1d(values, grid, bounds_error=False, - fill_value='extrapolate') + interp = interp1d(values, grid, bounds_error=False, fill_value="extrapolate") assert np.all(initial == np.stack((values, interp(values)))) # Monotonically increasing values = np.flipud(values) assert np.all(np.diff(values) > 0) initial = np.stack((values, grid)) - interp = interp1d(values, grid, bounds_error=False, - fill_value='extrapolate') + interp = interp1d(values, grid, bounds_error=False, fill_value="extrapolate") assert np.all(initial == np.stack((values, interp(values)))) diff --git a/africanus/util/tests/test_nvcc_compiler.py b/africanus/util/tests/test_nvcc_compiler.py index 5c251950c..90c6c501a 100644 --- a/africanus/util/tests/test_nvcc_compiler.py +++ b/africanus/util/tests/test_nvcc_compiler.py @@ -9,7 +9,7 @@ def test_nvcc_compiler(tmpdir): from africanus.util.nvcc import compile_using_nvcc - cp = pytest.importorskip('cupy') + cp = pytest.importorskip("cupy") code = """ #include @@ -26,7 +26,7 @@ def test_nvcc_compiler(tmpdir): } """ - mod = compile_using_nvcc(code, options=['-I ' + cub_dir()]) + mod = compile_using_nvcc(code, options=["-I " + cub_dir()]) kernel = mod.get_function("kernel") inputs = cp.arange(1024, dtype=cp.int32) outputs = cp.empty_like(inputs) diff --git a/africanus/util/tests/test_patterns.py b/africanus/util/tests/test_patterns.py index 618dbc38e..7ee2b4294 100644 --- a/africanus/util/tests/test_patterns.py +++ b/africanus/util/tests/test_patterns.py @@ -4,8 +4,7 @@ import pytest -from africanus.util.patterns import ( - Multiton, LazyProxy, LazyProxyMultiton) +from africanus.util.patterns import Multiton, LazyProxy, LazyProxyMultiton class DummyResource: diff --git a/africanus/util/tests/test_progress_bar.py b/africanus/util/tests/test_progress_bar.py index da3939123..c068d7923 100644 --- a/africanus/util/tests/test_progress_bar.py +++ b/africanus/util/tests/test_progress_bar.py @@ -17,11 +17,11 @@ def test_progress_bar(): assert " 1m 0s" == format_time(60) assert " 1m 1s" == format_time(61) - assert " 2h 6m" == format_time(2*60*60 + 6*60) - assert " 2h 6m" == format_time(2*60*60 + 6*60 + 59) - assert " 2h 7m" == format_time(2*60*60 + 7*60) - assert " 2h 7m" == format_time(2*60*60 + 7*60 + 1) + assert " 2h 6m" == format_time(2 * 60 * 60 + 6 * 60) + assert " 2h 6m" == format_time(2 * 60 * 60 + 6 * 60 + 59) + assert " 2h 7m" == format_time(2 * 60 * 60 + 7 * 60) + assert " 2h 7m" == format_time(2 * 60 * 60 + 7 * 60 + 1) - assert " 5d 2h" == format_time(5*60*60*24 + 2*60*60 + 500) + assert " 5d 2h" == format_time(5 * 60 * 60 * 24 + 2 * 60 * 60 + 500) - assert " 5w 2d" == format_time(5*60*60*24*7 + 2*60*60*24 + 500) + assert " 5w 2d" == format_time(5 * 60 * 60 * 24 * 7 + 2 * 60 * 60 * 24 + 500) diff --git a/africanus/util/tests/test_requirements.py b/africanus/util/tests/test_requirements.py index fe4459da5..3e323e2da 100644 --- a/africanus/util/tests/test_requirements.py +++ b/africanus/util/tests/test_requirements.py @@ -6,25 +6,25 @@ import pytest -from africanus.util.requirements import (requires_optional, - MissingPackageException) +from africanus.util.requirements import requires_optional, MissingPackageException from africanus.util.testing import force_missing_pkg_exception as force_tag def test_requires_optional_missing_import(): - @requires_optional('sys', 'bob', force_tag) + @requires_optional("sys", "bob", force_tag) def f(*args, **kwargs): pass with pytest.raises(MissingPackageException) as e: f(1, a=2) - assert ("f requires installation of the following packages: ('bob',)." - in str(e.value)) + assert "f requires installation of the following packages: ('bob',)." in str( + e.value + ) def test_requires_optional_pass_import_error(): - assert 'clearly_missing_and_nonexistent_package' not in sys.modules + assert "clearly_missing_and_nonexistent_package" not in sys.modules try: import clearly_missing_and_nonexistent_package # noqa @@ -34,7 +34,8 @@ def test_requires_optional_pass_import_error(): me = None with pytest.raises(ImportError) as e: - @requires_optional('sys', 'os', me, force_tag) + + @requires_optional("sys", "os", me, force_tag) def f(*args, **kwargs): pass diff --git a/africanus/util/tests/test_util_shape.py b/africanus/util/tests/test_util_shape.py index 8fe829545..5f1aaaebc 100644 --- a/africanus/util/tests/test_util_shape.py +++ b/africanus/util/tests/test_util_shape.py @@ -11,14 +11,17 @@ def test_corr_shape(): from africanus.util.shapes import corr_shape for i in range(10): - assert corr_shape(i, 'flat') == (i,) + assert corr_shape(i, "flat") == (i,) - assert corr_shape(1, 'matrix') == (1,) - assert corr_shape(2, 'matrix') == (2,) - assert corr_shape(4, 'matrix') == (2, 2,) + assert corr_shape(1, "matrix") == (1,) + assert corr_shape(2, "matrix") == (2,) + assert corr_shape(4, "matrix") == ( + 2, + 2, + ) with pytest.raises(ValueError, match=r"ncorr not in \(1, 2, 4\)"): - corr_shape(3, 'matrix') + corr_shape(3, "matrix") def test_aggregate_chunks(): diff --git a/africanus/util/trove.py b/africanus/util/trove.py index 37354297a..daa69c1fc 100644 --- a/africanus/util/trove.py +++ b/africanus/util/trove.py @@ -18,14 +18,14 @@ from africanus.util.files import sha_hash_file _trove_dir = pjoin(include_dir, "trove") -_trove_url = 'https://github.com/bryancatanzaro/trove/archive/master.zip' -_trove_sha_hash = 'f0bfdfb347fdfe5aca20b0357360dc793ad30cd3' -_trove_version_str = 'Current release: v1.8.0 (02/16/2018)' +_trove_url = "https://github.com/bryancatanzaro/trove/archive/master.zip" +_trove_sha_hash = "f0bfdfb347fdfe5aca20b0357360dc793ad30cd3" +_trove_version_str = "Current release: v1.8.0 (02/16/2018)" _trove_version = "master" -_trove_zip_dir = 'trove-' + _trove_version +_trove_zip_dir = "trove-" + _trove_version _trove_download_filename = "trove-" + _trove_version + ".zip" -_trove_header = pjoin(_trove_dir, 'trove', 'trove.cuh') -_trove_readme = pjoin(_trove_dir, 'README.md') +_trove_header = pjoin(_trove_dir, "trove", "trove.cuh") +_trove_readme = pjoin(_trove_dir, "README.md") _trove_new_unzipped_path = _trove_dir @@ -50,9 +50,7 @@ def download_trove(archive_file): def is_trove_installed(readme_filename): # Check if the README.md exists - if (not os.path.exists(readme_filename) or - not os.path.isfile(readme_filename)): - + if not os.path.exists(readme_filename) or not os.path.isfile(readme_filename): reason = "trove readme '{}' does not exist".format(readme_filename) return (False, reason) @@ -74,22 +72,21 @@ def _install_trove(): sha_hash = download_trove(archive) # Compare against our supplied hash if _trove_sha_hash != sha_hash: - msg = ('Hash of file %s downloaded from %s ' - 'is %s and does not match the expected ' - 'hash of %s.') % ( - _trove_download_filename, _trove_url, - sha_hash, _trove_sha_hash) + msg = ( + "Hash of file %s downloaded from %s " + "is %s and does not match the expected " + "hash of %s." + ) % (_trove_download_filename, _trove_url, sha_hash, _trove_sha_hash) raise InstallTroveException(msg) # Unzip into include/trove - with ZipFile(archive, 'r') as zip_file: + with ZipFile(archive, "r") as zip_file: # Remove any existing install try: shutil.rmtree(_trove_dir, ignore_errors=True) except Exception as e: - raise InstallTroveException("Removing %s failed\n%s" % ( - _trove_dir, str(e))) + raise InstallTroveException("Removing %s failed\n%s" % (_trove_dir, str(e))) try: # Unzip into temporary directory @@ -101,8 +98,7 @@ def _install_trove(): # Move shutil.move(unzip_path, _trove_dir) except Exception as e: - raise InstallTroveException("Extracting %s failed\n%s" % ( - archive, str(e))) + raise InstallTroveException("Extracting %s failed\n%s" % (archive, str(e))) finally: shutil.rmtree(tmpdir, ignore_errors=True) diff --git a/africanus/util/type_inference.py b/africanus/util/type_inference.py index e1e8e7073..e0b70711d 100644 --- a/africanus/util/type_inference.py +++ b/africanus/util/type_inference.py @@ -1,4 +1,3 @@ - # -*- coding: utf-8 -*- @@ -23,5 +22,5 @@ def _numpy_dtype(arg): def infer_complex_dtype(*args): - """ Infer complex datatype from arg inputs """ + """Infer complex datatype from arg inputs""" return np.result_type(np.complex64, *(_numpy_dtype(a) for a in args)) diff --git a/docs/api.rst b/docs/api.rst index 42f473bdf..cb20ea73b 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -16,4 +16,4 @@ API linalg-api.rst gps-api.rst - experimental.rst \ No newline at end of file + experimental.rst diff --git a/docs/averaging-api.rst b/docs/averaging-api.rst index f3ec52386..68439e403 100644 --- a/docs/averaging-api.rst +++ b/docs/averaging-api.rst @@ -215,4 +215,3 @@ Dask .. autofunction:: time_and_channel .. autofunction:: bda - diff --git a/docs/calibration-api.rst b/docs/calibration-api.rst index ff977da4c..3836b7da3 100644 --- a/docs/calibration-api.rst +++ b/docs/calibration-api.rst @@ -48,7 +48,7 @@ scenario is determined from the shapes of the input gains and the input data. This module also provides a number of utilities which -are useful for calibration. +are useful for calibration. Utils +++++ @@ -63,7 +63,7 @@ Numpy residual_vis correct_vis compute_and_corrupt_vis - + .. autofunction:: corrupt_vis .. autofunction:: residual_vis @@ -80,7 +80,7 @@ Dask residual_vis correct_vis compute_and_corrupt_vis - + .. autofunction:: corrupt_vis .. autofunction:: residual_vis diff --git a/docs/conf.py b/docs/conf.py index b42d54327..e630ae6aa 100755 --- a/docs/conf.py +++ b/docs/conf.py @@ -23,7 +23,7 @@ import os import sys -sys.path.insert(0, os.path.abspath('..')) +sys.path.insert(0, os.path.abspath("..")) import sphinx_rtd_theme import africanus @@ -36,32 +36,34 @@ # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones. -extensions = ['sphinx.ext.autodoc', - 'sphinx.ext.autosummary', - 'sphinx.ext.viewcode', - 'sphinx.ext.mathjax', - 'sphinx.ext.intersphinx', - 'sphinx.ext.extlinks', - 'numpydoc'] +extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "sphinx.ext.viewcode", + "sphinx.ext.mathjax", + "sphinx.ext.intersphinx", + "sphinx.ext.extlinks", + "numpydoc", +] -autodoc_mock_imports = ['numpy', 'numba'] +autodoc_mock_imports = ["numpy", "numba"] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # # source_suffix = ['.rst', '.md'] -source_suffix = '.rst' +source_suffix = ".rst" # The master toctree document. -master_doc = 'index' +master_doc = "index" # General information about the project. -project = u'Codex Africanus' -copyright = u"2018, Simon Perkins" -author = u"Simon Perkins" +project = "Codex Africanus" +copyright = "2018, Simon Perkins" +author = "Simon Perkins" # The version info for the project you're documenting, acts as replacement # for |version| and |release|, also used in various other places throughout @@ -82,10 +84,10 @@ # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This patterns also effect to html_static_path and html_extra_path -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = "sphinx" # If true, `todo` and `todoList` produce output, else they produce nothing. todo_include_todos = False @@ -96,7 +98,7 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'sphinx_rtd_theme' +html_theme = "sphinx_rtd_theme" # Theme options are theme-specific and customize the look and feel of a # theme further. For a list of options available for each theme, see the @@ -107,13 +109,13 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] # -- Options for HTMLHelp output --------------------------------------- # Output file base name for HTML help builder. -htmlhelp_basename = 'africanusdoc' +htmlhelp_basename = "africanusdoc" # -- Options for LaTeX output ------------------------------------------ @@ -122,15 +124,12 @@ # The paper size ('letterpaper' or 'a4paper'). # # 'papersize': 'letterpaper', - # The font size ('10pt', '11pt' or '12pt'). # # 'pointsize': '10pt', - # Additional stuff for the LaTeX preamble. # # 'preamble': '', - # Latex figure (float) alignment # # 'figure_align': 'htbp', @@ -140,9 +139,13 @@ # (source start file, target name, title, author, documentclass # [howto, manual, or own class]). latex_documents = [ - (master_doc, 'africanus.tex', - u'Codex Africanus Documentation', - u'Simon Perkins', 'manual'), + ( + master_doc, + "africanus.tex", + "Codex Africanus Documentation", + "Simon Perkins", + "manual", + ), ] @@ -150,11 +153,7 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [ - (master_doc, 'africanus', - u'Codex Africanus Documentation', - [author], 1) -] +man_pages = [(master_doc, "africanus", "Codex Africanus Documentation", [author], 1)] numpydoc_class_members_toctree = False @@ -164,24 +163,27 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, 'africanus', - u'Codex Africanus Documentation', - author, - 'africanus', - 'One line description of project.', - 'Miscellaneous'), + ( + master_doc, + "africanus", + "Codex Africanus Documentation", + author, + "africanus", + "One line description of project.", + "Miscellaneous", + ), ] extlinks = { - 'issue': ('https://github.com/ska-sa/codex-africanus/issues/%s', 'GH#'), - 'pr': ('https://github.com/ska-sa/codex-africanus/pull/%s', 'GH#') + "issue": ("https://github.com/ska-sa/codex-africanus/issues/%s", "GH#"), + "pr": ("https://github.com/ska-sa/codex-africanus/pull/%s", "GH#"), } intersphinx_mapping = { - 'cupy': ('https://docs-cupy.chainer.org/en/latest/', None), - 'dask': ('https://dask.pydata.org/en/latest/', None), - 'numba': ('https://numba.pydata.org/numba-doc/dev/', None), - 'numpy': ('https://numpy.org/doc/stable/', None), - 'python': ('https://docs.python.org/3/', None), - 'scipy': ('https://docs.scipy.org/doc/scipy/', None), + "cupy": ("https://docs-cupy.chainer.org/en/latest/", None), + "dask": ("https://dask.pydata.org/en/latest/", None), + "numba": ("https://numba.pydata.org/numba-doc/dev/", None), + "numpy": ("https://numpy.org/doc/stable/", None), + "python": ("https://docs.python.org/3/", None), + "scipy": ("https://docs.scipy.org/doc/scipy/", None), } diff --git a/docs/dft-api.rst b/docs/dft-api.rst index d8ec470ef..ee7fd20f4 100644 --- a/docs/dft-api.rst +++ b/docs/dft-api.rst @@ -88,4 +88,3 @@ Dask .. autofunction:: im_to_vis .. autofunction:: vis_to_im - diff --git a/docs/gps-api.rst b/docs/gps-api.rst index c59ae692d..ad10b9f18 100644 --- a/docs/gps-api.rst +++ b/docs/gps-api.rst @@ -16,4 +16,4 @@ Numpy exponential_squared .. autofunction:: abs_diff -.. autofunction:: exponential_squared \ No newline at end of file +.. autofunction:: exponential_squared diff --git a/docs/linalg-api.rst b/docs/linalg-api.rst index 700c3e3ab..bbdb1fe80 100644 --- a/docs/linalg-api.rst +++ b/docs/linalg-api.rst @@ -14,9 +14,9 @@ as a kronecker matrix of the individual matrices i.e. .. math:: K = K_0 \\otimes K_1 \\otimes K_2 \\otimes \\cdots - + Matrices which exhibit this structure can exploit -properties of the kronecker product to avoid +properties of the kronecker product to avoid explicitly expanding the matrix :math:`K`. This module implements some common linear algebra operations which leverages this property for diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 000000000..488c6524c --- /dev/null +++ b/ruff.toml @@ -0,0 +1,15 @@ +exclude = ["turbo-sim.py"] +line-length = 88 +target-version = "py310" + +select = [ + # flake8-builtins + "A", + # flake8-bugbear + "B", + # isort + "I001", + "I002", + # tidy imports + "TID" +] diff --git a/setup.cfg b/setup.cfg index 5e08770b1..44ee33bda 100644 --- a/setup.cfg +++ b/setup.cfg @@ -23,4 +23,3 @@ collect_ignore = ['setup.py'] [pycodestyle] ignore = E121,E123,E126,E133,E226,E241,E242,E704,W503,W504,E741 -exclude = turbo-sim.py diff --git a/setup.py b/setup.py index fce7cadfd..0a0767677 100644 --- a/setup.py +++ b/setup.py @@ -20,7 +20,7 @@ # astropy breaks with numpy 1.15.3 # https://github.com/astropy/astropy/issues/7943 "numpy >= 1.14.0, != 1.15.3", - "numba >= 0.53.1" + "numba >= 0.53.1", ] extras_require = { @@ -31,7 +31,12 @@ "astropy": ["astropy >= 4.0"], "python-casacore": ["python-casacore >= 3.4.0, != 3.5.0"], "ducc0": ["ducc0 >= 0.9.0"], - "testing": ["pytest", "flaky", "pytest-flake8 >= 1.0.6", "flake8 >= 4.0.0, < 5.0.0"], + "testing": [ + "pytest", + "flaky", + "pytest-flake8 >= 1.0.6", + "flake8 >= 4.0.0, < 5.0.0", + ], } with open(str(Path("africanus", "install", "extras_require.py")), "w") as f: