diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 46eb5775..aab5bc05 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -7,7 +7,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.8, 3.9, "3.10", "3.11", "3.12"] + python-version: ["3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v2 @@ -26,7 +26,6 @@ jobs: ${{ runner.os }}-pip-${{ runner.os }}- - name: Install dependencies run: | - python -m pip install --upgrade pip pip install flake8 pytest if [ -f requirements.txt ]; then pip install -r requirements.txt; fi if [ -f optional_requirements.txt ]; then pip install -r optional_requirements.txt; fi diff --git a/docs/source/loading_data/index.rst b/docs/source/loading_data/index.rst index f0e3b6c9..a2af6f5b 100644 --- a/docs/source/loading_data/index.rst +++ b/docs/source/loading_data/index.rst @@ -4,9 +4,9 @@ Loading Data The main purpose of :mod:`swiftsimio` is to load data. This section will tell you all about four main objects: -+ :obj:`swiftsimio.reader.SWIFTUnits`, responsible for creating a correspondence between ++ :obj:`swiftsimio.metadata.objects.SWIFTUnits`, responsible for creating a correspondence between the SWIFT units and :mod:`unyt` objects. -+ :obj:`swiftsimio.reader.SWIFTMetadata`, responsible for loading any required information ++ :obj:`swiftsimio.metadata.objects.SWIFTMetadata`, responsible for loading any required information from the SWIFT headers into python-readable data. + :obj:`swiftsimio.reader.SWIFTDataset`, responsible for holding all particle data, and keeping track of the above two objects. @@ -47,8 +47,8 @@ notebook, and you will see that it contains several sub-objects: simulation. + ``data.dark_matter``, likewise containing information about the dark matter particles in the simulation. -+ ``data.metadata``, an instance of :obj:`swiftsimio.reader.SWIFTMetadata` -+ ``data.units``, an instance of :obj:`swiftsimio.reader.SWIFTUnits` ++ ``data.metadata``, an instance of :obj:`swiftsimio.metadata.objects.SWIFTSnapshotMetadata` ++ ``data.units``, an instance of :obj:`swiftsimio.metadata.objects.SWIFTUnits` Using metadata -------------- @@ -268,3 +268,32 @@ in SWIFT will be automatically read. data = sw.load( "extra_test.hdf5", ) + + +Halo Catalogues +--------------- + +SWIFT-compatible halo catalogues, such as those written with SOAP, can be +loaded entirely transparently with ``swiftsimio``. All of the functionality +available for snapshots (masking, visualisation, etc.) is generally also +available for these files, provided that they conform to the correct +metadata standard. + +An example SOAP file is available at +``http://virgodb.cosma.dur.ac.uk/swift-webstorage/IOExamples/soap_example.hdf5`` + +You can load SOAP files as follows: + +.. code-block:: python + + from swiftsimio import load + + catalogue = load("soap_example.hdf5") + + print(catalogue.spherical_overdensity_200_mean.total_mass) + + # >>> [ 591. 328.5 361. 553. 530. 507. 795. + # 574. 489.5 233.75 0. 1406. 367.5 2308. + # ... + # 0. 534. 0. 191.75 1450. 600. 290. 
] 10000000000.0*Msun (Physical) diff --git a/pyproject.toml b/pyproject.toml index 6a7f55f9..8407d060 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,27 +12,27 @@ packages = [ "swiftsimio.metadata.particle", "swiftsimio.metadata.unit", "swiftsimio.metadata.writer", + "swiftsimio.metadata.soap", "swiftsimio.visualisation", "swiftsimio.visualisation.projection_backends", "swiftsimio.visualisation.slice_backends", "swiftsimio.visualisation.tools", - "swiftsimio.visualisation.smoothing_length" + "swiftsimio.visualisation.smoothing_length", ] [project] name = "swiftsimio" -version="8.0.1" +version="9.0.0" authors = [ { name="Josh Borrow", email="josh@joshborrow.com" }, ] -description="SWIFTsim (swift.dur.ac.uk) i/o routines for python." +description="SWIFTsim (swiftsim.com) i/o routines for python." readme = "README.md" -requires-python = ">3.8.0" +requires-python = ">3.10.0" classifiers = [ - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "License :: OSI Approved :: GNU Lesser General Public License v3 or later (LGPLv3+)", "Operating System :: OS Independent", ] diff --git a/swiftsimio/__init__.py b/swiftsimio/__init__.py index a8208757..f9fa573e 100644 --- a/swiftsimio/__init__.py +++ b/swiftsimio/__init__.py @@ -1,5 +1,5 @@ from .reader import * -from .writer import SWIFTWriterDataset +from .snapshot_writer import SWIFTSnapshotWriter from .masks import SWIFTMask from .statistics import SWIFTStatisticsFile from .__version__ import __version__ @@ -75,7 +75,7 @@ def mask(filename, spatial_only=True) -> SWIFTMask: """ units = SWIFTUnits(filename) - metadata = SWIFTMetadata(filename, units) + metadata = metadata_discriminator(filename, units) return SWIFTMask(metadata=metadata, spatial_only=spatial_only) @@ -109,5 +109,4 @@ def load_statistics(filename) -> SWIFTStatisticsFile: return SWIFTStatisticsFile(filename=filename) -# Rename this object to something simpler. -Writer = SWIFTWriterDataset +Writer = SWIFTSnapshotWriter diff --git a/swiftsimio/masks.py b/swiftsimio/masks.py index 10b0eb19..b3f406ca 100644 --- a/swiftsimio/masks.py +++ b/swiftsimio/masks.py @@ -3,6 +3,8 @@ snapshots. """ +import warnings + import unyt import h5py @@ -48,6 +50,11 @@ def __init__(self, metadata: SWIFTMetadata, spatial_only=True): self.units = metadata.units self.spatial_only = spatial_only + if not self.metadata.masking_valid: + raise NotImplementedError( + f"Masking not supported for {self.metadata.output_type} filetype" + ) + if self.metadata.partial_snapshot: raise InvalidSnapshot( "You cannot use masks on partial snapshots. Please use the virtual " @@ -65,9 +72,11 @@ def _generate_empty_masks(self): types. """ - for ptype in self.metadata.present_particle_names: + for group_name in self.metadata.present_group_names: setattr( - self, ptype, np.ones(getattr(self.metadata, f"n_{ptype}"), dtype=bool) + self, + group_name, + np.ones(getattr(self.metadata, f"n_{group_name}"), dtype=bool), ) return @@ -100,12 +109,15 @@ def _unpack_cell_metadata(self): # contain at least one of each type of particle). 
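For orientation, the cell counts and offsets unpacked here are what drive spatial masking. A minimal usage sketch, mirroring the documented masking workflow and assuming a hypothetical snapshot ``snap.hdf5``:

.. code-block:: python

    import swiftsimio as sw

    mask = sw.mask("snap.hdf5")
    boxsize = mask.metadata.boxsize

    # Restrict to the lower corner of the volume; only cells
    # overlapping this region are then read from disk.
    load_region = [[0.0 * b, 0.5 * b] for b in boxsize]
    mask.constrain_spatial(load_region)

    data = sw.load("snap.hdf5", mask=mask)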
sort = None - for ptype, pname in zip( - self.metadata.present_particle_types, self.metadata.present_particle_names + for group, group_name in zip( + self.metadata.present_groups, self.metadata.present_group_names ): - part_type = f"PartType{ptype}" - counts = count_handle[part_type][:] - offsets = offset_handle[part_type][:] + if self.metadata.shared_cell_counts is None: + counts = count_handle[group][:] + offsets = offset_handle[group][:] + else: + counts = count_handle[self.metadata.shared_cell_counts][:] + offsets = offset_handle[self.metadata.shared_cell_counts][:] # When using MPI, we cannot assume that these are sorted. if sort is None: @@ -113,8 +125,8 @@ def _unpack_cell_metadata(self): # types if some datasets do not have particles in a cell! sort = np.argsort(offsets) - self.offsets[pname] = offsets[sort] - self.counts[pname] = counts[sort] + self.offsets[group_name] = offsets[sort] + self.counts[group_name] = counts[sort] # Also need to sort centers in the same way self.centers = unyt.unyt_array(centers_handle[:][sort], units=self.units.length) @@ -128,7 +140,7 @@ def _unpack_cell_metadata(self): def constrain_mask( self, - ptype: str, + group_name: str, quantity: str, lower: unyt.array.unyt_quantity, upper: unyt.array.unyt_quantity, @@ -139,13 +151,13 @@ def constrain_mask( We update the mask such that - lower < ptype.quantity <= upper + lower < group_name.quantity <= upper The quantities must have units attached. Parameters ---------- - ptype : str + group_name : str particle type quantity : str @@ -169,23 +181,17 @@ def constrain_mask( print("Please re-initialise the SWIFTMask object with spatial_only=False") return - current_mask = getattr(self, ptype) + current_mask = getattr(self, group_name) - particle_metadata = getattr(self.metadata, f"{ptype}_properties") + group_metadata = getattr(self.metadata, f"{group_name}_properties") unit_dict = { - k: v - for k, v in zip( - particle_metadata.field_names, particle_metadata.field_units - ) + k: v for k, v in zip(group_metadata.field_names, group_metadata.field_units) } unit = unit_dict[quantity] handle_dict = { - k: v - for k, v in zip( - particle_metadata.field_names, particle_metadata.field_paths - ) + k: v for k, v in zip(group_metadata.field_names, group_metadata.field_paths) } handle = handle_dict[quantity] @@ -203,7 +209,7 @@ def constrain_mask( current_mask[current_mask] = new_mask - setattr(self, ptype, current_mask) + setattr(self, group_name, current_mask) return @@ -282,7 +288,7 @@ def _generate_cell_mask(self, restrict): return cell_mask - def _update_spatial_mask(self, restrict, ptype: str, cell_mask: np.array): + def _update_spatial_mask(self, restrict, group_name: str, cell_mask: np.array): """ Updates the particle mask using the cell mask. 
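To make the counts/offsets bookkeeping below concrete: each selected top-level cell contributes one contiguous ``[offset, offset + count)`` slice of a group's arrays. A toy sketch with invented values:

.. code-block:: python

    import numpy as np

    counts = np.array([3, 5, 2])     # particles per selected cell
    offsets = np.array([0, 10, 40])  # start of each cell's slice

    # Same construction as the spatial-only branch below:
    ranges = np.array([[o, c + o] for c, o in zip(counts, offsets)])
    # ranges -> [[0, 3], [10, 15], [40, 42]]; np.sum(counts) == 10 rows read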
@@ -296,7 +302,7 @@ def _update_spatial_mask(self, restrict, ptype: str, cell_mask: np.array): restrict : list currently unused - ptype : str + group_name : str particle type to update cell_mask : np.array @@ -304,20 +310,20 @@ """ if self.spatial_only: - counts = self.counts[ptype][cell_mask] - offsets = self.offsets[ptype][cell_mask] + counts = self.counts[group_name][cell_mask] + offsets = self.offsets[group_name][cell_mask] this_mask = [[o, c + o] for c, o in zip(counts, offsets)] - setattr(self, ptype, np.array(this_mask)) - setattr(self, f"{ptype}_size", np.sum(counts)) + setattr(self, group_name, np.array(this_mask)) + setattr(self, f"{group_name}_size", np.sum(counts)) else: - counts = self.counts[ptype][~cell_mask] - offsets = self.offsets[ptype][~cell_mask] + counts = self.counts[group_name][~cell_mask] + offsets = self.offsets[group_name][~cell_mask] # We must do the whole boolean mask business. - this_mask = getattr(self, ptype) + this_mask = getattr(self, group_name) for count, offset in zip(counts, offsets): this_mask[offset : count + offset] = False @@ -367,8 +373,8 @@ def constrain_spatial(self, restrict, intersect: bool = False): # we just make a new mask self.cell_mask = self._generate_cell_mask(restrict) - for ptype in self.metadata.present_particle_names: - self._update_spatial_mask(restrict, ptype, self.cell_mask) + for group_name in self.metadata.present_group_names: + self._update_spatial_mask(restrict, group_name, self.cell_mask) return @@ -391,19 +397,38 @@ def convert_masks_to_ranges(self): # Use the accelerate.ranges_from_array function to convert # This into a set of ranges. - for ptype in self.metadata.present_particle_names: + for group_name in self.metadata.present_group_names: setattr( self, - ptype, + group_name, # Because it nests things in a list for some reason. - np.where(getattr(self, ptype))[0], + np.where(getattr(self, group_name))[0], ) - setattr(self, f"{ptype}_size", getattr(self, ptype).size) + setattr(self, f"{group_name}_size", getattr(self, group_name).size) - for ptype in self.metadata.present_particle_names: - setattr(self, ptype, ranges_from_array(getattr(self, ptype))) + for group_name in self.metadata.present_group_names: + setattr(self, group_name, ranges_from_array(getattr(self, group_name))) + + return + def constrain_index(self, index: int): + """ + Constrain the mask to a single row. + + Intended for use with SOAP catalogues; masks the catalogue to read only a single row. + + Parameters + ---------- + index : int + The index of the row to select. 
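A sketch of how ``constrain_index`` is intended to be used, assuming a SOAP catalogue at the hypothetical path ``soap.hdf5``:

.. code-block:: python

    import swiftsimio as sw

    mask = sw.mask("soap.hdf5")
    mask.constrain_index(42)  # every group is masked to [[42, 43]]

    # Each dataset in the loaded catalogue now has exactly one row.
    catalogue = sw.load("soap.hdf5", mask=mask)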
+ """ + if not self.metadata.filetype == "SOAP": + warnings.warn("Not masking a SOAP catalogue, nothing constrained.") + return + for group_name in self.metadata.present_group_names: + setattr(self, group_name, np.array([[index, index + 1]])) + setattr(self, f"{group_name}_size", 1) return def get_masked_counts_offsets(self) -> (Dict[str, np.array], Dict[str, np.array]): diff --git a/swiftsimio/metadata/__init__.py b/swiftsimio/metadata/__init__.py index 1d08b2fc..83ecd4b3 100644 --- a/swiftsimio/metadata/__init__.py +++ b/swiftsimio/metadata/__init__.py @@ -6,6 +6,8 @@ from .particle import particle_types from .particle import particle_fields +from .soap import soap_types + from .unit import unit_types from .unit import unit_fields diff --git a/swiftsimio/metadata/metadata/metadata_fields.py b/swiftsimio/metadata/metadata/metadata_fields.py index 3e8f1d47..182a3907 100644 --- a/swiftsimio/metadata/metadata/metadata_fields.py +++ b/swiftsimio/metadata/metadata/metadata_fields.py @@ -23,8 +23,12 @@ header_unpack_arrays = { "BoxSize": "boxsize", "NumPart_ThisFile": "num_part", + "NumGroup_ThisFile": "num_group", + "NumSubhalos_ThisFile": "num_subhalo", "CanHaveTypes": "has_type", "NumFilesPerSnapshot": "num_files_per_snapshot", + "OutputType": "output_type", + "SubhaloTypes": "subhalo_types", } # Some of these 'arrays' are really types of mass table, so unpack diff --git a/swiftsimio/metadata/objects.py b/swiftsimio/metadata/objects.py new file mode 100644 index 00000000..e3bd3b00 --- /dev/null +++ b/swiftsimio/metadata/objects.py @@ -0,0 +1,1301 @@ +""" +Objects describing the metadata in SWIFTsimIO files. There is a main +abstract class, ``SWIFTMetadata``, that contains the required base +methods to correctly represent the internal representation of an +HDF5 file to what SWIFTsimIO expects to be able to unpack into the +object notation (e.g. PartType0/Coordinates -> gas.coordinates). +""" + + +import numpy as np +import unyt + +import h5py +from swiftsimio.conversions import swift_cosmology_to_astropy +from swiftsimio import metadata +from swiftsimio.objects import cosmo_array, cosmo_factor, a +from abc import ABC, abstractmethod + +import re +import warnings + +from datetime import datetime +from pathlib import Path + +from typing import List, Optional + + +class SWIFTMetadata(ABC): + """ + An abstract base class for all SWIFT-related file metadata. + """ + + # Underlying path to the file that this metadata is associated with. + filename: str + # The units object associated with this file. All SWIFT metadata objects + # must use this units system. + units: "SWIFTUnits" + # The header dictionary which will later be unpackaged according to the + # metadata fields. + header: dict + # Whether this type of file can be masked or not (this is a fixed parameter + # that should probably not be changed at run-time). + masking_valid: bool = False + # Whether this file uses shared metadata cell counts for all particle types + # (as is the case in SOAP) or whether each type (e.g. Gas, Dark Matter, etc.) + # has its own top-level cell grid counts. + shared_cell_counts: str | None = None + + @abstractmethod + def __init__(self, filename, units: "SWIFTUnits"): + raise NotImplementedError + + @property + def handle(self): + # Handle, which is shared with units. Units handles + # file opening and closing. 
+ return self.units.handle + + def load_groups(self): + """ + Loads the groups and metadata into objects: + + metadata._properties + + This contains eight arrays, + + metadata._properties.field_names + metadata._properties.field_paths + metadata._properties.field_units + metadata._properties.field_cosmologies + metadata._properties.field_descriptions + metadata._properties.field_compressions + metadata._properties.field_physicals + metadata._properties.field_valid_transforms + + As well as some more information about the group. + """ + + for group, name in zip(self.present_groups, self.present_group_names): + filetype_metadata = SWIFTGroupMetadata( + group=group, + group_name=name, + metadata=self, + scale_factor=self.scale_factor, + ) + setattr(self, f"{name}_properties", filetype_metadata) + + return + + def get_metadata(self): + """ + Loads the metadata as specified in metadata.metadata_fields. + """ + + for field, name in metadata.metadata_fields.metadata_fields_to_read.items(): + try: + setattr(self, name, dict(self.handle[field].attrs)) + except KeyError: + setattr(self, name, None) + + return + + def postprocess_header(self): + """ + Some minor postprocessing on the header to local variables. + """ + + # These are just read straight in to variables + header_unpack_arrays_units = metadata.metadata_fields.generate_units_header_unpack_arrays( + m=self.units.mass, + l=self.units.length, + t=self.units.time, + I=self.units.current, + T=self.units.temperature, + ) + + for field, name in metadata.metadata_fields.header_unpack_arrays.items(): + try: + if name in header_unpack_arrays_units.keys(): + setattr( + self, + name, + unyt.unyt_array( + self.header[field], units=header_unpack_arrays_units[name] + ), + ) + # This is required or we automatically get everything in CGS! + getattr(self, name).convert_to_units( + header_unpack_arrays_units[name] + ) + else: + # Must not have any units! Oh well. + setattr(self, name, self.header[field]) + except KeyError: + # Must not be present, just skip it + continue + + # Now unpack the 'mass table' type items: + for field, name in metadata.metadata_fields.header_unpack_mass_tables.items(): + try: + setattr( + self, + name, + MassTable( + base_mass_table=self.header[field], mass_units=self.units.mass + ), + ) + except KeyError: + setattr( + self, + name, + MassTable( + base_mass_table=np.zeros( + len(metadata.particle_types.particle_name_underscores) + ), + mass_units=self.units.mass, + ), + ) + + # These must be unpacked as 'real' strings (i.e. converted to utf-8) + + for field, name in metadata.metadata_fields.header_unpack_string.items(): + try: + # Deal with h5py's quirkiness that fixed-sized and variable-sized + # strings are read as strings or bytes + # See: https://github.com/h5py/h5py/issues/2172 + raw = self.header[field] + try: + string = raw.decode("utf-8") + except AttributeError: + string = raw + setattr(self, name, string) + except KeyError: + # Must not be present, just skip it + setattr(self, name, "") + + # These must be unpacked as they are stored as length-1 arrays + + header_unpack_float_units = metadata.metadata_fields.generate_units_header_unpack_single_float( + m=self.units.mass, + l=self.units.length, + t=self.units.time, + I=self.units.current, + T=self.units.temperature, + ) + + for field, names in metadata.metadata_fields.header_unpack_single_float.items(): + try: + if isinstance(names, list): + # Sometimes we store a list in case we have multiple names, for example + # Redshift -> metadata.redshift AND metadata.z. 
Can't just do the iteration + # because we may loop over the letters in the string. + for variable in names: + if variable in header_unpack_float_units.keys(): + # We have an associated unit! + unit = header_unpack_float_units[variable] + setattr( + self, + variable, + unyt.unyt_quantity(self.header[field][0], units=unit), + ) + else: + # No unit + setattr(self, variable, self.header[field][0]) + else: + # We can just check for the unit and set the attribute + variable = names + if variable in header_unpack_float_units.keys(): + # We have an associated unit! + unit = header_unpack_float_units[variable] + setattr( + self, + variable, + unyt.unyt_quantity(self.header[field][0], units=unit), + ) + else: + # No unit + setattr(self, variable, self.header[field][0]) + except KeyError: + # Must not be present, just skip it + continue + + # These are special cases, sorry! + # Date and time of snapshot dump + try: + try: + # Try and decode bytes, otherwise save raw string + snapshot_date = self.header.get( + "SnapshotDate", self.header.get("Snapshot date", b"") + ).decode("utf-8") + except AttributeError: + snapshot_date = self.header.get( + "SnapshotDate", self.header.get("Snapshot date", "") + ) + try: + self.snapshot_date = datetime.strptime( + snapshot_date, "%H:%M:%S %Y-%m-%d %Z" + ) + except ValueError: + # Backwards compatibility; this was used previously due to simplicity + # but is not portable between regions. So if you ran a simulation on + # a British (en_GB) machine, and then tried to read on a Dutch + # machine (nl_NL), this would _not_ work because %c is different. + try: + self.snapshot_date = datetime.strptime(snapshot_date, "%c\n") + except ValueError: + # Oh dear this has gone _very_ wrong. Let's just keep it as a string. + self.snapshot_date = snapshot_date + except KeyError: + # Old file + pass + + # get photon group edges RT dataset from the SubgridScheme group + try: + self.photon_group_edges = ( + self.handle["SubgridScheme/PhotonGroupEdges"][:] / self.units.time + ) + except KeyError: + self.photon_group_edges = None + + # get reduced speed of light RT dataset from the SubgridScheme group + try: + self.reduced_lightspeed = ( + self.handle["SubgridScheme/ReducedLightspeed"][0] + * self.units.length + / self.units.time + ) + except KeyError: + self.reduced_lightspeed = None + + # Store these separately as self.n_gas = number of gas particles for example + for (part_number, (_, part_name)) in enumerate( + metadata.particle_types.particle_name_underscores.items() + ): + try: + setattr(self, f"n_{part_name}", self.num_part[part_number]) + except IndexError: + # Backwards compatibility; mass/number table can change size. + setattr(self, f"n_{part_name}", 0) + + # Need to unpack the gas gamma for cosmology + try: + self.gas_gamma = self.hydro_scheme["Adiabatic index"] + except (KeyError, TypeError): + # We can set a default and print a message whenever we require this value + self.gas_gamma = None + + try: + self.a = self.scale_factor + except AttributeError: + # These must always be present for the initialisation of cosmology properties + self.a = 1.0 + self.scale_factor = 1.0 + + return + + def extract_cosmology(self): + """ + Creates an astropy.cosmology object from the internal cosmology system. + + This will be saved as ``self.cosmology``. 
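Once ``extract_cosmology`` has run, downstream code can use the astropy object directly; a small sketch (hypothetical filename, and only meaningful for cosmological runs where ``cosmology`` is not ``None``):

.. code-block:: python

    import swiftsimio as sw

    data = sw.load("cosmo_snap.hdf5")
    cosmology = data.metadata.cosmology  # astropy.cosmology instance or None

    if cosmology is not None:
        print(cosmology.H0)                    # Hubble constant
        print(cosmology.age(data.metadata.z))  # age at the snapshot redshift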
+ """ + + if self.cosmology_raw is not None: + cosmo = self.cosmology_raw + else: + cosmo = {"Cosmological run": 0} + + if cosmo.get("Cosmological run", 0): + self.cosmology = swift_cosmology_to_astropy(cosmo, units=self.units) + else: + self.cosmology = None + + return + + @property + @abstractmethod + def present_groups(self) -> list[str]: + """ + A property giving the present particle groups in the file to be unpackaged + into top-level properties. For instance, in a regular snapshot, this would be + ["PartType0", "PartType1", "PartType4", ...]. In SOAP, this would be + ["SO/200_crit", "SO/200_mean", ...], i.e. one per aperture. + """ + raise NotImplementedError + + @property + @abstractmethod + def present_group_names(self) -> list[str]: + """ + A property giving the mapping for the names in ``present_groups`` to what the + objects are called on the SWIFTsimIO objects. For instance, in a regular snapshot, + this would be ["gas", "dark_matter", "stars", ...]. In SOAP, this would be + ["spherical_overdensity_200_crit", ...]. + """ + raise NotImplementedError + + @property + def partial_snapshot(self) -> bool: + """ + A property defining whether this is a partial snapshot (e.g. a `.0.hdf5` file) or + a full/virtual snapsoht covering all particles. This must be computed at run-time. + """ + return False + + @staticmethod + @abstractmethod + def get_nice_name(group: str) -> str: + """ + Converts the group name to a 'nice name' (i.e. for printing) for the SWIFTsimIO objects. + """ + raise NotImplementedError + + +class MassTable(object): + """ + Extracts a mass table to local variables based on the + particle type names. + """ + + def __init__(self, base_mass_table: np.array, mass_units: unyt.unyt_quantity): + """ + Parameters + ---------- + + base_mass_table : np.array + Mass table of the same length as the number of particle types. + + mass_units : unyt_quantity + Base mass units for the simulation. + """ + + # TODO: Extract these names from the files themselves if possible. + + for index, name in metadata.particle_types.particle_name_underscores.items(): + try: + setattr( + self, + name, + unyt.unyt_quantity(base_mass_table[index], units=mass_units), + ) + except IndexError: + # Backwards compatible. + setattr(self, name, None) + + return + + def __str__(self): + return f"Mass table for {' '.join(metadata.particle_types.particle_name_underscores.values())}" + + def __repr__(self): + return self.__str__() + + +class MappingTable(object): + """ + A mapping table from one named column instance to the other. + Initially designed for the mapping between dust and elements. + """ + + def __init__( + self, + data: np.ndarray, + named_columns_x: List[str], + named_columns_y: List[str], + named_columns_x_name: str, + named_columns_y_name: str, + ): + """ + Parameters + ---------- + + data: np.ndarray + The data array providing the mapping between the named + columns. Should be of size N x M, where N is the number + of elements in ``named_columns_x`` and M the number + of elements in ``named_columns_y``. + + named_columns_x: List[str] + The names of the columns in the first axis. + + named_columns_y: List[str] + The names of the columns in the second axis. + + named_columns_x_name: str + The name of the first mapping. + + named_columns_y_name: str + The name of the second mapping. 
+ """ + + self.data = data + self.named_columns_x = named_columns_x + self.named_columns_y = named_columns_y + self.named_columns_x_name = named_columns_x_name + self.named_columns_y_name = named_columns_y_name + + for x, name_x in enumerate(named_columns_x): + for y, name_y in enumerate(named_columns_y): + setattr(self, f"{name_x.lower()}_to_{name_y.lower()}", data[x][y]) + + return + + def __str__(self): + return ( + f"Mapping table from {self.named_columns_x_name} to " + f"{self.named_columns_y_name}, containing {len(self.data)} " + f"by {len(self.data[0])} elements." + ) + + def __repr__(self): + return f"{self.__str__()}. Raw data: " "\n" f"{self.data}." + + +class SWIFTGroupMetadata(object): + """ + Object that contains the metadata for one hdf5 group. + + This, for instance, could be part type 0, or 'gas'. This will load in + the names of all datasets, their units, possible named fields, + and their cosmology, and present them for use in the actual i/o routines. + + Methods + ------- + load_metadata(self): + Loads the required metadata. + load_field_names(self): + Loads in the field names. + load_field_units(self): + Loads in the units from each dataset. + load_field_descriptions(self): + Loads in descriptions of the fields for each dataset. + load_field_compressions(self): + Loads in compressions of the fields for each dataset. + load_cosmology(self): + Loads in the field cosmologies. + load_physical(self): + Loads in whether the field is saved as comoving or physical. + load_valid_transforms(self): + Loads in whether the field can be converted to comoving. + load_named_columns(self): + Loads the named column data for relevant fields. + """ + + def __init__( + self, + group: str, + group_name: str, + metadata: "SWIFTMetadata", + scale_factor: float, + ): + """ + Constructor for SWIFTGroupMetadata class + + Parameters + ---------- + group: str + the name of the group in the hdf5 file + group_name : str + the corresponding group name for swiftsimio + metadata : SWIFTMetadata + the snapshot metadata + scale_factor : float + the snapshot scale factor + """ + self.group = group + self.group_name = group_name + self.metadata = metadata + self.units = metadata.units + self.scale_factor = scale_factor + + self.filename = metadata.filename + + self.load_metadata() + + return + + def __str__(self): + return f"Metadata class for {self.group} ({self.group_name})" + + def __repr__(self): + return self.__str__() + + def load_metadata(self): + """ + Loads the required metadata. + + This includes loading the field names, units and descriptions, as well as the + cosmology metadata and any custom named columns + """ + + self.load_field_names() + self.load_field_units() + self.load_field_descriptions() + self.load_field_compressions() + self.load_cosmology() + self.load_physical() + self.load_valid_transforms() + self.load_named_columns() + + def load_field_names(self): + """ + Loads in only the field names. 
+ """ + + # regular expression for camel case to snake case + # https://stackoverflow.com/a/1176023 + def convert(name): + return re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower() + + # Skip fields which are groups themselves + self.field_paths = [] + self.field_names = [] + for item in self.metadata.handle[f"{self.group}"].keys(): + # Skip fields which are groups themselves + if f"{self.group}/{item}" not in self.metadata.present_groups: + self.field_paths.append(f"{self.group}/{item}") + self.field_names.append(convert(item)) + + return + + def load_field_units(self): + """ + Loads in the units from each dataset. + """ + + unit_dict = { + "I": self.units.current, + "L": self.units.length, + "M": self.units.mass, + "T": self.units.temperature, + "t": self.units.time, + } + + def get_units(unit_attribute): + units = 1.0 + + for exponent, unit in unit_dict.items(): + # We store the 'unit exponent' in the SWIFT metadata. This corresponds + # to the power we need to raise each unit to, to return the correct units + try: + # Need to check if the exponent is 0 manually because of float precision + unit_exponent = unit_attribute[f"U_{exponent} exponent"][0] + if unit_exponent != 0.0: + units *= unit ** unit_exponent + except KeyError: + # Can't load that data! + # We should probably warn the user here... + pass + + # Deal with case where we _really_ have a dimensionless quantity. Comparing with + # 1.0 doesn't work, beacause in these cases unyt reverts to a floating point + # comparison. + try: + units.units + except AttributeError: + units = None + + return units + + self.field_units = [ + get_units(self.metadata.handle[x].attrs) for x in self.field_paths + ] + + return + + def load_field_descriptions(self): + """ + Loads in the text descriptions of the fields for each dataset. + For SOAP filetypes a description of the mask is included. + """ + + def get_desc(dataset): + try: + description = dataset.attrs["Description"].decode("utf-8") + except AttributeError: + # Description is saved as a string not bytes + description = dataset.attrs["Description"] + except KeyError: + # Can't load description! + description = "No description available" + + is_masked = dataset.attrs.get("Masked", False) + if not is_masked: + return description + " Not masked." + + mask_datasets = dataset.attrs["Mask Datasets"] + mask_threshold = dataset.attrs["Mask Threshold"] + if len(mask_datasets) == 1: + mask_str = f" Only computed for objects with {mask_datasets[0]} >= {mask_threshold}." + else: + mask_str = f' Only computed for objects where {" + ".join(mask_datasets)} >= {mask_threshold}.' + return description + mask_str + + self.field_descriptions = [ + get_desc(self.metadata.handle[x]) for x in self.field_paths + ] + + return + + def load_field_compressions(self): + """ + Loads in the string describing the compression filters of the fields for each dataset. + """ + + def get_comp(dataset): + try: + # SOAP catalogues can be compressed/uncompressed + is_compressed = dataset.attrs["Is Compressed"] + except KeyError: + is_compressed = True + try: + comp = dataset.attrs["Lossy compression filter"].decode("utf-8") + except AttributeError: + # Compression is saved as str not bytes + comp = dataset.attrs["Lossy compression filter"] + except KeyError: + # Can't load compression string! + comp = "No compression info available" + + return comp if is_compressed else "Not compressed." 
+ + self.field_compressions = [ + get_comp(self.metadata.handle[x]) for x in self.field_paths + ] + + return + + def load_cosmology(self): + """ + Loads in the field cosmologies. + """ + + current_scale_factor = self.scale_factor + + def get_cosmo(dataset): + try: + cosmo_exponent = dataset.attrs["a-scale exponent"][0] + except: + # Can't load, 'graceful' fallback. + cosmo_exponent = 0.0 + + a_factor_this_dataset = a ** cosmo_exponent + + return cosmo_factor(a_factor_this_dataset, current_scale_factor) + + self.field_cosmologies = [ + get_cosmo(self.metadata.handle[x]) for x in self.field_paths + ] + + return + + def load_physical(self): + """ + Loads in whether the field is saved as comoving or physical. + """ + + def get_physical(dataset): + try: + physical = dataset.attrs["Value stored as physical"][0] == 1 + except: + physical = False + return physical + + self.field_physicals = [ + get_physical(self.metadata.handle[x]) for x in self.field_paths + ] + + return + + def load_valid_transforms(self): + """ + Loads in whether the field can be converted to comoving. + """ + + def get_valid_transform(dataset): + try: + valid_transform = ( + dataset.attrs["Property can be converted to comoving"][0] == 1 + ) + except: + valid_transform = True + return valid_transform + + self.field_valid_transforms = [ + get_valid_transform(self.metadata.handle[x]) for x in self.field_paths + ] + + return + + def load_named_columns(self): + """ + Loads the named column data for relevant fields. + """ + + named_columns = {} + + for field in self.field_paths: + property_name = field.split("/")[-1] + + # Not all datasets have named columns + named_columns_metadata = getattr(self.metadata, "named_columns", {}) + + if property_name in named_columns_metadata.keys(): + field_names = self.metadata.named_columns[property_name] + + # Now need to make a decision on capitalisation. If we have a set of + # words with only one capital in them, then it's likely that they are + # element names or something similar, so they should be lower case. + # If on average we have many more capitals, then they are likely to be + # ionized fractions (e.g. HeII) and so we want to leave them with their + # original capitalisation. + + num_capitals = lambda x: sum(1 for c in x if c.isupper()) + mean_num_capitals = sum(map(num_capitals, field_names)) / len( + field_names + ) + + if mean_num_capitals < 1.01: + # Decapitalise them as they are likely individual element names + formatted_field_names = [x.lower() for x in field_names] + else: + formatted_field_names = field_names + + named_columns[field] = formatted_field_names + else: + named_columns[field] = None + + self.named_columns = named_columns + + return + + +class SWIFTUnits(object): + """ + Generates a unyt system that can then be used with the SWIFT data. + + These give the unit mass, length, time, current, and temperature as + unyt unit variables in simulation units. I.e. you can take any value + that you get out of the code and multiply it by the appropriate values + to get it 'unyt-ified' with the correct units. 
+ + Attributes + ---------- + mass : float + unit for mass used + length : float + unit for length used + time : float + unit for time used + current : float + unit for current used + temperature : float + unit for temperature used + + """ + + def __init__(self, filename: Path, handle: Optional[h5py.File] = None): + """ + SWIFTUnits constructor + + Sets filename for file to read units from and gets unit dictionary + + Parameters + ---------- + + filename : Path + Name of file to read units from + + handle: h5py.File, optional + The h5py file handle, optional. Will open a new handle with the + filename if required. + + """ + self.filename = filename + self._handle = handle + + self.get_unit_dictionary() + + return + + @property + def handle(self): + """ + Property that gets the file handle, which can be shared + with other objects for efficiency reasons. + """ + if isinstance(self._handle, h5py.File): + # Can be open or closed, let's test. + try: + file = self._handle.file + + return self._handle + except ValueError: + # This will be the case if there is no active file handle + pass + + self._handle = h5py.File(self.filename, "r") + + return self._handle + + def get_unit_dictionary(self): + """ + Store unit data and metadata + + Length 1 arrays are used to store the unit data. This dictionary + also contains the metadata information that connects the unyt + objects to the names that are stored in the SWIFT snapshots. + """ + + self.units = { + name: unyt.unyt_quantity( + value[0], units=metadata.unit_types.unit_names_to_unyt[name] + ) + for name, value in self.handle["Units"].attrs.items() + } + + # We now unpack this into variables. + self.mass = metadata.unit_types.find_nearest_base_unit( + self.units["Unit mass in cgs (U_M)"], "mass" + ) + self.length = metadata.unit_types.find_nearest_base_unit( + self.units["Unit length in cgs (U_L)"], "length" + ) + self.time = metadata.unit_types.find_nearest_base_unit( + self.units["Unit time in cgs (U_t)"], "time" + ) + self.current = metadata.unit_types.find_nearest_base_unit( + self.units["Unit current in cgs (U_I)"], "current" + ) + self.temperature = metadata.unit_types.find_nearest_base_unit( + self.units["Unit temperature in cgs (U_T)"], "temperature" + ) + + def __del__(self): + if isinstance(self._handle, h5py.File): + self._handle.close() + + +def metadata_discriminator(filename: str, units: SWIFTUnits) -> "SWIFTMetadata": + """ + Discriminates between the different types of metadata objects read from SWIFT-compatible + files. + + Parameters + ---------- + + filename : str + Name of the file to read metadata from + + units : SWIFTUnits + The units object associated with the file + + + Returns + ------- + + SWIFTMetadata + The appropriate metadata object for the file type + """ + # Old snapshots did not have this attribute, so we need to default to FullVolume + file_type = units.handle["Header"].attrs.get("OutputType", "FullVolume") + + if isinstance(file_type, bytes): + file_type = file_type.decode("utf-8") + + if file_type in ["FullVolume"]: + return SWIFTSnapshotMetadata(filename, units) + elif file_type in ["SOAP"]: + return SWIFTSOAPMetadata(filename, units) + elif file_type in ["FOF"]: + return SWIFTFOFMetadata(filename, units) + else: + raise ValueError(f"File type {file_type} not recognised.") + + +class SWIFTSnapshotMetadata(SWIFTMetadata): + """ + SWIFT Metadata for a snapshot-style file containing particle + information. For more documentation, see the main :cls:`SWIFTMetadata` + class. 
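The discriminator is how ``swiftsimio`` decides which metadata class to build whenever a file is opened (the same pattern ``swiftsimio.mask`` uses); a direct-use sketch with a hypothetical filename:

.. code-block:: python

    from swiftsimio.metadata.objects import SWIFTUnits, metadata_discriminator

    units = SWIFTUnits("output.hdf5")
    meta = metadata_discriminator("output.hdf5", units)

    # FullVolume -> SWIFTSnapshotMetadata, SOAP -> SWIFTSOAPMetadata,
    # FOF -> SWIFTFOFMetadata; anything else raises ValueError.
    print(type(meta).__name__)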
+ """ + + masking_valid: bool = True + + def __init__(self, filename, units: SWIFTUnits): + """ + Constructor for SWIFTMetadata object + + Parameters + ---------- + + filename : str + name of file to read from + + units : SWIFTUnits + the units being used + """ + self.filename = filename + self.units = units + + self.get_metadata() + self.get_named_column_metadata() + self.get_mapping_metadata() + + self.postprocess_header() + + self.load_groups() + self.extract_cosmology() + + # After we've loaded all this metadata, we can safely release the file handle. + self.handle.close() + + return + + def get_named_column_metadata(self): + """ + Loads the custom named column metadata (if it exists) from + SubgridScheme/NamedColumns. + """ + + try: + data = self.handle["SubgridScheme/NamedColumns"] + + self.named_columns = { + k: [x.decode("utf-8") for x in data[k][:]] for k in data.keys() + } + except KeyError: + self.named_columns = {} + + return + + def get_mapping_metadata(self): + """ + Gets the mappings based on the named columns (must have already been read), + from the form: + + SubgridScheme/{X}To{Y}Mapping. + + Includes a hack of `Dust` -> `Grains` that will be deprecated. + """ + + try: + possible_keys = self.handle["SubgridScheme"].keys() + + available_keys = [key for key in possible_keys if key.endswith("Mapping")] + available_data = [ + self.handle[f"SubgridScheme/{key}"][:] for key in available_keys + ] + except KeyError: + available_keys = [] + available_data = [] + + # Keys have form {X}To{Y}Mapping + + # regular expression for camel case to snake case + # https://stackoverflow.com/a/1176023 + def convert(name): + return re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower() + + regex = r"([a-zA-Z]*)To([a-zA-Z]*)Mapping" + compiled = re.compile(regex) + + for key, data in zip(available_keys, available_data): + match = compiled.match(key) + snake_case = convert(key) + + if match: + x = match.group(1) + y = match.group(2) + + if x == "Grain": + warnings.warn( + "Use of the GrainToElementMapping is deprecated, please use a newer " + "version of SWIFT to run this simulation.", + DeprecationWarning, + ) + + x = "Dust" + + named_column_name_x = [ + key for key in self.named_columns.keys() if key.startswith(x) + ][0] + named_column_name_y = [ + key for key in self.named_columns.keys() if key.startswith(y) + ][0] + + setattr( + self, + snake_case, + MappingTable( + data=data, + named_columns_x=self.named_columns[named_column_name_x], + named_columns_y=self.named_columns[named_column_name_y], + named_columns_x_name=named_column_name_x, + named_columns_y_name=named_column_name_y, + ), + ) + + return + + @property + def present_groups(self): + """ + The groups containing datasets that are present in the file. + """ + types = np.where(np.array(getattr(self, "has_type", self.num_part)) != 0)[0] + return [f"PartType{i}" for i in types] + + @property + def present_group_names(self): + """ + The names of the groups that we want to expose. 
+ """ + + return [ + metadata.particle_types.particle_name_underscores[x] + for x in self.present_groups + ] + + @property + def code_info(self) -> str: + """ + Gets a nicely printed set of code information with: + + Name (Git Branch) + Git Revision + Git Date + """ + + def get_string(x): + return self.code[x].decode("utf-8") + + output = ( + f"{get_string('Code')} ({get_string('Git Branch')})\n" + f"{get_string('Git Revision')}\n" + f"{get_string('Git Date')}" + ) + + return output + + @property + def compiler_info(self) -> str: + """ + Gets information about the compiler and formats it as: + + Compiler Name (Compiler Version) + MPI library + """ + + def get_string(x): + return self.code[x].decode("utf-8") + + output = ( + f"{get_string('Compiler Name')} ({get_string('Compiler Version')})\n" + f"{get_string('MPI library')}" + ) + + return output + + @property + def library_info(self) -> str: + """ + Gets information about the libraries used and formats it as: + + FFTW vFFTW library version + GSL vGSL library version + HDF5 vHDF5 library version + """ + + def get_string(x): + return self.code[f"{x} library version"].decode("utf-8") + + output = ( + f"FFTW v{get_string('FFTW')}\n" + f"GSL v{get_string('GSL')}\n" + f"HDF5 v{get_string('HDF5')}" + ) + + return output + + @property + def hydro_info(self) -> str: + r""" + Gets information about the hydro scheme and formats it as: + + Scheme + Kernel function in DimensionD + $\eta$ = Kernel eta (Kernel target N_ngb $N_{ngb}$) + $C_{\rm CFL}$ = CFL parameter + """ + + def get_float(x): + return "{:4.2f}".format(self.hydro_scheme[x][0]) + + def get_int(x): + return int(self.hydro_scheme[x][0]) + + def get_string(x): + return self.hydro_scheme[x].decode("utf-8") + + output = ( + f"{get_string('Scheme')}\n" + f"{get_string('Kernel function')} in {get_int('Dimension')}D\n" + rf"$\eta$ = {get_float('Kernel eta')} " + rf"({get_float('Kernel target N_ngb')} $N_{{ngb}}$)" + "\n" + rf"$C_{{\rm CFL}}$ = {get_float('CFL parameter')}" + ) + + return output + + @property + def viscosity_info(self) -> str: + r""" + Gets information about the viscosity scheme and formats it as: + + Viscosity Model + $\alpha_{V, 0}$ = Alpha viscosity, $\ell_V$ = Viscosity decay length [internal units], $\beta_V$ = Beta viscosity + Alpha viscosity (min) < $\alpha_V$ < Alpha viscosity (max) + """ + + def get_float(x): + return "{:4.2f}".format(self.hydro_scheme[x][0]) + + def get_string(x): + return self.hydro_scheme[x].decode("utf-8") + + output = ( + f"{get_string('Viscosity Model')}\n" + rf"$\alpha_{{V, 0}}$ = {get_float('Alpha viscosity')}, " + rf"$\ell_V$ = {get_float('Viscosity decay length [internal units]')}, " + rf"$\beta_V$ = {get_float('Beta viscosity')}" + "\n" + rf"{get_float('Alpha viscosity (min)')} < $\alpha_V$ < {get_float('Alpha viscosity (max)')}" + ) + + return output + + @property + def diffusion_info(self) -> str: + """ + Gets information about the diffusion scheme and formats it as: + + $\alpha_{D, 0}$ = Diffusion alpha, $\beta_D$ = Diffusion beta + Diffusion alpha (min) < $\alpha_D$ < Diffusion alpha (max) + """ + + def get_float(x): + return "{:4.2f}".format(self.hydro_scheme[x][0]) + + output = ( + rf"$\alpha_{{D, 0}}$ = {get_float('Diffusion alpha')}, " + rf"$\beta_D$ = {get_float('Diffusion beta')}" + "\n" + rf"${get_float('Diffusion alpha (min)')} < " + rf"\alpha_D < {get_float('Diffusion alpha (max)')}$" + ) + + return output + + @property + def partial_snapshot(self) -> bool: + """ + Whether or not this snapshot is partial (e.g. 
a "x.0.hdf5" file), or + a file describing an entire snapshot. + """ + + # Partial snapshots have num_files_per_snapshot set to 1. Virtual snapshots + # collating multiple sub-snapshots together have num_files_per_snapshot = 1. + + return self.num_files_per_snapshot > 1 + + @staticmethod + def get_nice_name(group): + return metadata.particle_types.particle_name_class[group] + + +class SWIFTFOFMetadata(SWIFTMetadata): + """ + SWIFT Metadata for a snapshot-style file containing particle + information. For more documentation, see the main :cls:`SWIFTMetadata` + class. + """ + + def __init__(self, filename: str, units: SWIFTUnits): + self.filename = filename + self.units = units + + self.get_metadata() + self.postprocess_header() + + self.load_groups() + + # After we've loaded all this metadata, we can safely release the file handle. + self.handle.close() + + return + + @property + def present_groups(self): + """ + The groups containing datasets that are present in the file. + """ + return ["Groups"] + + @property + def present_group_names(self): + """ + The names of the groups that we want to expose. + """ + return ["fof_groups"] + + @staticmethod + def get_nice_name(group): + return "FOFGroups" + + +class SWIFTSOAPMetadata(SWIFTMetadata): + """ + SWIFT Metadata for a snapshot-style file containing particle + information. For more documentation, see the main :cls:`SWIFTMetadata` + class. + """ + + masking_valid: bool = True + shared_cell_counts: str = "Subhalos" + + def __init__(self, filename: str, units: SWIFTUnits): + self.filename = filename + self.units = units + + self.get_metadata() + self.postprocess_header() + + self.load_groups() + + # After we've loaded all this metadata, we can safely release the file handle. + self.handle.close() + + return + + @property + def present_groups(self): + """ + The groups containing datasets that are present in the file. + """ + return self.subhalo_types + + @property + def present_group_names(self): + """ + The names of the groups that we want to expose. 
+ """ + return [ + metadata.soap_types.get_soap_name_underscore(x) for x in self.present_groups + ] + + @staticmethod + def get_nice_name(group): + return metadata.soap_types.get_soap_name_nice(group) diff --git a/swiftsimio/metadata/particle/particle_types.py b/swiftsimio/metadata/particle/particle_types.py index f2b10a26..dd9eb150 100644 --- a/swiftsimio/metadata/particle/particle_types.py +++ b/swiftsimio/metadata/particle/particle_types.py @@ -4,31 +4,31 @@ # Describes the conversion of particle types to names particle_name_underscores = { - 0: "gas", - 1: "dark_matter", - 2: "boundary", - 3: "sinks", - 4: "stars", - 5: "black_holes", - 6: "neutrinos", + "PartType0": "gas", + "PartType1": "dark_matter", + "PartType2": "boundary", + "PartType3": "sinks", + "PartType4": "stars", + "PartType5": "black_holes", + "PartType6": "neutrinos", } particle_name_class = { - 0: "Gas", - 1: "DarkMatter", - 2: "Boundary", - 3: "Sinks", - 4: "Stars", - 5: "BlackHoles", - 6: "Neutrinos", + "PartType0": "Gas", + "PartType1": "DarkMatter", + "PartType2": "Boundary", + "PartType3": "Sinks", + "PartType4": "Stars", + "PartType5": "BlackHoles", + "PartType6": "Neutrinos", } particle_name_text = { - 0: "Gas", - 1: "Dark Matter", - 2: "Boundary", - 3: "Sinks", - 4: "Stars", - 5: "Black Holes", - 6: "Neutrinos", + "PartType0": "Gas", + "PartType1": "Dark Matter", + "PartType2": "Boundary", + "PartType3": "Sinks", + "PartType4": "Stars", + "PartType5": "Black Holes", + "PartType6": "Neutrinos", } diff --git a/swiftsimio/metadata/soap/__init__.py b/swiftsimio/metadata/soap/__init__.py new file mode 100644 index 00000000..343739f2 --- /dev/null +++ b/swiftsimio/metadata/soap/__init__.py @@ -0,0 +1 @@ +from .soap_types import * diff --git a/swiftsimio/metadata/soap/soap_types.py b/swiftsimio/metadata/soap/soap_types.py new file mode 100644 index 00000000..cecb2730 --- /dev/null +++ b/swiftsimio/metadata/soap/soap_types.py @@ -0,0 +1,32 @@ +""" +Includes the fancy names. 
+""" + +# Describes the conversion of hdf5 groups to names +def get_soap_name_underscore(group: str) -> str: + soap_name_underscores = { + "BoundSubhalo": "bound_subhalo", + "InputHalos": "input_halos", + "InclusiveSphere": "inclusive_sphere", + "ExclusiveSphere": "exclusive_sphere", + "SO": "spherical_overdensity", + "SOAP": "soap", + "ProjectedAperture": "projected_aperture", + } + split_name = group.split("/") + split_name[0] = soap_name_underscores[split_name[0]] + return "_".join(name.lower() for name in split_name) + + +def get_soap_name_nice(group: str) -> str: + soap_name_nice = { + "BoundSubhalo": "BoundSubhalo", + "InputHalos": "InputHalos", + "InclusiveSphere": "InclusiveSphere", + "ExclusiveSphere": "ExclusiveSphere", + "SO": "SphericalOverdensity", + "SOAP": "SOAP", + "ProjectedAperture": "ProjectedAperture", + } + split_name = group.split("/") + return "".join(name.capitalize() for name in split_name) diff --git a/swiftsimio/objects.py b/swiftsimio/objects.py index 6b78fa9d..2bbe3248 100644 --- a/swiftsimio/objects.py +++ b/swiftsimio/objects.py @@ -96,10 +96,10 @@ heaviside, matmul, ) -from numpy.core.umath import _ones_like +from numpy._core.umath import _ones_like try: - from numpy.core.umath import clip + from numpy._core.umath import clip except ImportError: clip = None @@ -107,6 +107,11 @@ a = sympy.symbols("a") +class InvalidConversionError(Exception): + def __init__(self, message="Could not convert to comoving coordinates"): + self.message = message + + def _propagate_cosmo_array_attributes(func): def wrapped(self, *args, **kwargs): ret = func(self, *args, **kwargs) @@ -116,6 +121,8 @@ def wrapped(self, *args, **kwargs): ret.cosmo_factor = self.cosmo_factor if hasattr(self, "comoving"): ret.comoving = self.comoving + if hasattr(self, "valid_transform"): + ret.valid_transform = self.valid_transform return ret return wrapped @@ -591,7 +598,7 @@ class cosmo_array(unyt_array): Cosmology array class. This inherits from the unyt.unyt_array, and adds - three variables: compression, cosmo_factor, and comoving. + four variables: compression, cosmo_factor, comoving, and valid_transform. Data is assumed to be comoving when passed to the object but you can override this by setting the latter flag to be False. @@ -615,8 +622,12 @@ class cosmo_array(unyt_array): String describing any compression that was applied to this array in the hdf5 file. 
+ valid_transform: bool + if True then the array can be converted from physical to comoving units + """ + # TODO: _cosmo_factor_ufunc_registry = { add: _preserve_cosmo_factor, subtract: _preserve_cosmo_factor, @@ -718,6 +729,7 @@ def __new__( name=None, cosmo_factor=None, comoving=True, + valid_transform=True, compression=None, ): """ @@ -753,6 +765,8 @@ def __new__( coordinates comoving : bool flag to indicate whether using comoving coordinates + valid_transform : bool + flag to indicate whether this array can be converted to comoving compression : string description of the compression filters that were applied to that array in the hdf5 file @@ -800,6 +814,15 @@ def __new__( obj.cosmo_factor = cosmo_factor obj.comoving = comoving obj.compression = compression + obj.valid_transform = valid_transform + if not obj.valid_transform: + assert ( + not obj.comoving + ), "Cosmo arrays without a valid transform to comoving units must be physical" + if obj.comoving: + assert ( + obj.valid_transform + ), "Comoving Cosmo arrays must be able to be transformed to physical" return obj @@ -810,6 +833,7 @@ def __array_finalize__(self, obj): self.cosmo_factor = getattr(obj, "cosmo_factor", None) self.comoving = getattr(obj, "comoving", True) self.compression = getattr(obj, "compression", None) + self.valid_transform = getattr(obj, "valid_transform", True) def __str__(self): if self.comoving: @@ -819,6 +843,15 @@ def __str__(self): return super().__str__() + " " + comoving_str + def __repr__(self): + if self.comoving: + comoving_str = ", comoving=True)" + else: + comoving_str = ", comoving=False)" + + # Remove final parenthesis and append comoving flag + return super().__repr__()[:-1] + comoving_str + def __reduce__(self): """ Pickle reduction method @@ -828,7 +861,9 @@ def __reduce__(self): """ np_ret = super(cosmo_array, self).__reduce__() obj_state = np_ret[2] - cosmo_state = (((self.cosmo_factor, self.comoving),) + obj_state[:],) + cosmo_state = ( + ((self.cosmo_factor, self.comoving, self.valid_transform),) + obj_state[:], + ) new_ret = np_ret[:2] + cosmo_state + np_ret[3:] return new_ret @@ -840,7 +875,7 @@ def __setstate__(self, state): state and pass the rest to unyt_array.__setstate__. """ super(cosmo_array, self).__setstate__(state[1:]) - self.cosmo_factor, self.comoving = state[0] + self.cosmo_factor, self.comoving, self.valid_transform = state[0] # Wrap functions that return copies of cosmo_arrays so that our # attributes get passed through: @@ -879,12 +914,13 @@ def convert_to_comoving(self) -> None: """ if self.comoving: return - else: - # Best to just modify values as otherwise we're just going to have - # to do a convert_to_units anyway. - values = self.d - values /= self.cosmo_factor.a_factor - self.comoving = True + if not self.valid_transform: + raise InvalidConversionError + # Best to just modify values as otherwise we're just going to have + # to do a convert_to_units anyway. 
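A sketch of the new ``valid_transform`` behaviour (values invented; physical-only quantities, such as some SOAP properties, cannot be moved back to comoving):

.. code-block:: python

    from swiftsimio.objects import (
        InvalidConversionError, a, cosmo_array, cosmo_factor,
    )

    arr = cosmo_array(
        [1.0, 2.0],
        "Mpc",
        cosmo_factor=cosmo_factor(a ** 1, 0.5),
        comoving=False,          # required when valid_transform=False
        valid_transform=False,
    )

    try:
        arr.convert_to_comoving()
    except InvalidConversionError:
        print("stored as physical; no comoving transform available")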
+ values = self.d + values /= self.cosmo_factor.a_factor + self.comoving = True def convert_to_physical(self) -> None: """ @@ -922,6 +958,8 @@ def to_comoving(self): cosmo_array copy of cosmo_array in comoving units """ + if not self.valid_transform: + raise InvalidConversionError copied_data = self.in_units(self.units, cosmo_factor=self.cosmo_factor) copied_data.convert_to_comoving() @@ -947,7 +985,13 @@ def compatible_with_physical(self): @classmethod def from_astropy( - cls, arr, unit_registry=None, comoving=True, cosmo_factor=None, compression=None + cls, + arr, + unit_registry=None, + comoving=True, + cosmo_factor=None, + compression=None, + valid_transform=True, ): """ Convert an AstroPy "Quantity" to a cosmo_array. @@ -967,6 +1011,8 @@ compression : string String describing any compression that was applied to this array in the hdf5 file. + valid_transform : bool + flag to indicate whether this array can be converted to comoving Example ------- @@ -979,12 +1025,19 @@ obj.comoving = comoving obj.cosmo_factor = cosmo_factor obj.compression = compression + obj.valid_transform = valid_transform return obj @classmethod def from_pint( - cls, arr, unit_registry=None, comoving=True, cosmo_factor=None, compression=None + cls, + arr, + unit_registry=None, + comoving=True, + cosmo_factor=None, + compression=None, + valid_transform=True, ): """ Convert a Pint "Quantity" to a cosmo_array. @@ -1004,6 +1057,8 @@ compression : string String describing any compression that was applied to this array in the hdf5 file. + valid_transform : bool + flag to indicate whether this array can be converted to comoving Examples -------- @@ -1022,9 +1077,11 @@ obj.comoving = comoving obj.cosmo_factor = cosmo_factor obj.compression = compression + obj.valid_transform = valid_transform return obj + # TODO: def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): cms = [ (hasattr(inp, "comoving"), getattr(inp, "comoving", None)) for inp in inputs diff --git a/swiftsimio/reader.py b/swiftsimio/reader.py index ba6be8c3..809b3817 100644 --- a/swiftsimio/reader.py +++ b/swiftsimio/reader.py @@ -4,17 +4,23 @@ + SWIFTUnits, which is a unit system that can be queried for units (and converts arrays to relevant unyt arrays when read from the HDF5 file) + SWIFTMetadata, which contains all of the metadata from the file -+ __SWIFTParticleDataset, which contains particle information but should never be ++ __SWIFTGroupDataset, which contains particle information but should never be directly accessed. Use generate_dataset to create one of these. The reasoning here is that properties can only be added to the class afterwards, and not directly in an _instance_ of the class. + SWIFTDataset, a container class for all of the above. """ - -from swiftsimio import metadata + from swiftsimio.accelerated import read_ranges_from_file from swiftsimio.objects import cosmo_array, cosmo_factor, a -from swiftsimio.conversions import swift_cosmology_to_astropy + +from swiftsimio.metadata.objects import ( + metadata_discriminator, + SWIFTUnits, + SWIFTGroupMetadata, + SWIFTMetadata, +) import re import h5py @@ -28,1033 +34,6 @@ from typing import Union, Callable, List, Optional -class MassTable(object): - """ - Extracts a mass table to local variables based on the - particle type names. 
- """ - - def __init__(self, base_mass_table: np.array, mass_units: unyt.unyt_quantity): - """ - Parameters - ---------- - - base_mass_table : np.array - Mass table of the same length as the number of particle types. - - mass_units : unyt_quantity - Base mass units for the simulation. - """ - - for index, name in metadata.particle_types.particle_name_underscores.items(): - try: - setattr( - self, - name, - unyt.unyt_quantity(base_mass_table[index], units=mass_units), - ) - except IndexError: - # Backwards compatible. - setattr(self, name, None) - - return - - def __str__(self): - return f"Mass table for {' '.join(metadata.particle_types.particle_name_underscores.values())}" - - def __repr__(self): - return self.__str__() - - -class MappingTable(object): - """ - A mapping table from one named column instance to the other. - Initially designed for the mapping between dust and elements. - """ - - def __init__( - self, - data: np.ndarray, - named_columns_x: List[str], - named_columns_y: List[str], - named_columns_x_name: str, - named_columns_y_name: str, - ): - """ - Parameters - ---------- - - data: np.ndarray - The data array providing the mapping between the named - columns. Should be of size N x M, where N is the number - of elements in ``named_columns_x`` and M the number - of elements in ``named_columns_y``. - - named_columns_x: List[str] - The names of the columns in the first axis. - - named_columns_y: List[str] - The names of the columns in the second axis. - - named_columns_x_name: str - The name of the first mapping. - - named_columns_y_name: str - The name of the second mapping. - """ - - self.data = data - self.named_columns_x = named_columns_x - self.named_columns_y = named_columns_y - self.named_columns_x_name = named_columns_x_name - self.named_columns_y_name = named_columns_y_name - - for x, name_x in enumerate(named_columns_x): - for y, name_y in enumerate(named_columns_y): - setattr(self, f"{name_x.lower()}_to_{name_y.lower()}", data[x][y]) - - return - - def __str__(self): - return ( - f"Mapping table from {self.named_columns_x_name} to " - f"{self.named_columns_y_name}, containing {len(self.data)} " - f"by {len(self.data[0])} elements." - ) - - def __repr__(self): - return f"{self.__str__()}. Raw data: " "\n" f"{self.data}." - - -class SWIFTUnits(object): - """ - Generates a unyt system that can then be used with the SWIFT data. - - These give the unit mass, length, time, current, and temperature as - unyt unit variables in simulation units. I.e. you can take any value - that you get out of the code and multiply it by the appropriate values - to get it 'unyt-ified' with the correct units. - - Attributes - ---------- - mass : float - unit for mass used - length : float - unit for length used - time : float - unit for time used - current : float - unit for current used - temperature : float - unit for temperature used - - """ - - def __init__(self, filename: Path, handle: Optional[h5py.File] = None): - """ - SWIFTUnits constructor - - Sets filename for file to read units from and gets unit dictionary - - Parameters - ---------- - - filename : Path - Name of file to read units from - - handle: h5py.File, optional - The h5py file handle, optional. Will open a new handle with the - filename if required. - - """ - self.filename = filename - self._handle = handle - - self.get_unit_dictionary() - - return - - @property - def handle(self): - """ - Property that gets the file handle, which can be shared - with other objects for efficiency reasons. 
- """ - if isinstance(self._handle, h5py.File): - # Can be open or closed, let's test. - try: - file = self._handle.file - - return self._handle - except ValueError: - # This will be the case if there is no active file handle - pass - - self._handle = h5py.File(self.filename, "r") - - return self._handle - - def get_unit_dictionary(self): - """ - Store unit data and metadata - - Length 1 arrays are used to store the unit data. This dictionary - also contains the metadata information that connects the unyt - objects to the names that are stored in the SWIFT snapshots. - """ - - self.units = { - name: unyt.unyt_quantity( - value[0], units=metadata.unit_types.unit_names_to_unyt[name] - ) - for name, value in self.handle["Units"].attrs.items() - } - - # We now unpack this into variables. - self.mass = metadata.unit_types.find_nearest_base_unit( - self.units["Unit mass in cgs (U_M)"], "mass" - ) - self.length = metadata.unit_types.find_nearest_base_unit( - self.units["Unit length in cgs (U_L)"], "length" - ) - self.time = metadata.unit_types.find_nearest_base_unit( - self.units["Unit time in cgs (U_t)"], "time" - ) - self.current = metadata.unit_types.find_nearest_base_unit( - self.units["Unit current in cgs (U_I)"], "current" - ) - self.temperature = metadata.unit_types.find_nearest_base_unit( - self.units["Unit temperature in cgs (U_T)"], "temperature" - ) - - def __del__(self): - if isinstance(self._handle, h5py.File): - self._handle.close() - - -class SWIFTMetadata(object): - """ - Loads all metadata (apart from Units, those are handled by SWIFTUnits) - into dictionaries. - - This also does some extra parsing on some well-used metadata. - """ - - # Name of the file that has been read from - filename: str - # Unit instance associated with this file - units: SWIFTUnits - # Header dictionary, metadata about snapshot. - header: dict - - def __init__(self, filename, units: SWIFTUnits): - """ - Constructor for SWIFTMetadata object - - Parameters - ---------- - - filename : str - name of file to read from - - units : SWIFTUnits - the units being used - """ - self.filename = filename - self.units = units - - self.get_metadata() - self.get_named_column_metadata() - self.get_mapping_metadata() - - self.postprocess_header() - - self.load_particle_types() - self.extract_cosmology() - - # After we've loaded all this metadata, we can safely release the file handle. - self.handle.close() - - return - - @property - def handle(self): - # Handle, which is shared with units. Units handles - # file opening and closing. - return self.units.handle - - def get_metadata(self): - """ - Loads the metadata as specified in metadata.metadata_fields. - """ - - for field, name in metadata.metadata_fields.metadata_fields_to_read.items(): - try: - setattr(self, name, dict(self.handle[field].attrs)) - except KeyError: - setattr(self, name, None) - - return - - def get_named_column_metadata(self): - """ - Loads the custom named column metadata (if it exists) from - SubgridScheme/NamedColumns. - """ - - try: - data = self.handle["SubgridScheme/NamedColumns"] - - self.named_columns = { - k: [x.decode("utf-8") for x in data[k][:]] for k in data.keys() - } - except KeyError: - self.named_columns = {} - - return - - def get_mapping_metadata(self): - """ - Gets the mappings based on the named columns (must have already been read), - from the form: - - SubgridScheme/{X}To{Y}Mapping. - - Includes a hack of `Dust` -> `Grains` that will be deprecated. 
- """ - - try: - possible_keys = self.handle["SubgridScheme"].keys() - - available_keys = [key for key in possible_keys if key.endswith("Mapping")] - available_data = [ - self.handle[f"SubgridScheme/{key}"][:] for key in available_keys - ] - except KeyError: - available_keys = [] - available_data = [] - - # Keys have form {X}To{Y}Mapping - - # regular expression for camel case to snake case - # https://stackoverflow.com/a/1176023 - def convert(name): - return re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower() - - regex = r"([a-zA-Z]*)To([a-zA-Z]*)Mapping" - compiled = re.compile(regex) - - for key, data in zip(available_keys, available_data): - match = compiled.match(key) - snake_case = convert(key) - - if match: - x = match.group(1) - y = match.group(2) - - if x == "Grain": - warnings.warn( - "Use of the GrainToElementMapping is deprecated, please use a newer " - "version of SWIFT to run this simulation.", - DeprecationWarning, - ) - - x = "Dust" - - named_column_name_x = [ - key for key in self.named_columns.keys() if key.startswith(x) - ][0] - named_column_name_y = [ - key for key in self.named_columns.keys() if key.startswith(y) - ][0] - - setattr( - self, - snake_case, - MappingTable( - data=data, - named_columns_x=self.named_columns[named_column_name_x], - named_columns_y=self.named_columns[named_column_name_y], - named_columns_x_name=named_column_name_x, - named_columns_y_name=named_column_name_y, - ), - ) - - return - - def postprocess_header(self): - """ - Some minor postprocessing on the header to local variables. - """ - - # These are just read straight in to variables - header_unpack_arrays_units = metadata.metadata_fields.generate_units_header_unpack_arrays( - m=self.units.mass, - l=self.units.length, - t=self.units.time, - I=self.units.current, - T=self.units.temperature, - ) - - for field, name in metadata.metadata_fields.header_unpack_arrays.items(): - try: - if name in header_unpack_arrays_units.keys(): - setattr( - self, - name, - unyt.unyt_array( - self.header[field], units=header_unpack_arrays_units[name] - ), - ) - # This is required or we automatically get everything in CGS! - getattr(self, name).convert_to_units( - header_unpack_arrays_units[name] - ) - else: - # Must not have any units! Oh well. - setattr(self, name, self.header[field]) - except KeyError: - # Must not be present, just skip it - continue - - # Now unpack the 'mass table' type items: - for field, name in metadata.metadata_fields.header_unpack_mass_tables.items(): - try: - setattr( - self, - name, - MassTable( - base_mass_table=self.header[field], mass_units=self.units.mass - ), - ) - except KeyError: - setattr( - self, - name, - MassTable( - base_mass_table=np.zeros( - len(metadata.particle_types.particle_name_underscores) - ), - mass_units=self.units.mass, - ), - ) - - # These must be unpacked as 'real' strings (i.e. 
converted to utf-8) - - for field, name in metadata.metadata_fields.header_unpack_string.items(): - try: - # Deal with h5py's quirkiness that fixed-sized and variable-sized - # strings are read as strings or bytes - # See: https://github.com/h5py/h5py/issues/2172 - raw = self.header[field] - try: - string = raw.decode("utf-8") - except AttributeError: - string = raw - setattr(self, name, string) - except KeyError: - # Must not be present, just skip it - setattr(self, name, "") - - # These must be unpacked as they are stored as length-1 arrays - - header_unpack_float_units = metadata.metadata_fields.generate_units_header_unpack_single_float( - m=self.units.mass, - l=self.units.length, - t=self.units.time, - I=self.units.current, - T=self.units.temperature, - ) - - for field, names in metadata.metadata_fields.header_unpack_single_float.items(): - try: - if isinstance(names, list): - # Sometimes we store a list in case we have multiple names, for example - # Redshift -> metadata.redshift AND metadata.z. Can't just do the iteration - # because we may loop over the letters in the string. - for variable in names: - if variable in header_unpack_float_units.keys(): - # We have an associated unit! - unit = header_unpack_float_units[variable] - setattr( - self, - variable, - unyt.unyt_quantity(self.header[field][0], units=unit), - ) - else: - # No unit - setattr(self, variable, self.header[field][0]) - else: - # We can just check for the unit and set the attribute - variable = names - if variable in header_unpack_float_units.keys(): - # We have an associated unit! - unit = header_unpack_float_units[variable] - setattr( - self, - variable, - unyt.unyt_quantity(self.header[field][0], units=unit), - ) - else: - # No unit - setattr(self, variable, self.header[field][0]) - except KeyError: - # Must not be present, just skip it - continue - - # These are special cases, sorry! - # Date and time of snapshot dump - try: - try: - self.snapshot_date = datetime.strptime( - self.header.get( - "SnapshotDate", self.header.get("Snapshot date", b"") - ).decode("utf-8"), - "%H:%M:%S %Y-%m-%d %Z", - ) - except ValueError: - # Backwards compatibility; this was used previously due to simplicity - # but is not portable between regions. So if you ran a simulation on - # a British (en_GB) machine, and then tried to read on a Dutch - # machine (nl_NL), this would _not_ work because %c is different. - try: - self.snapshot_date = datetime.strptime( - self.header.get( - "SnapshotDate", self.header.get("Snapshot date", b"") - ).decode("utf-8"), - "%c\n", - ) - except ValueError: - # Oh dear this has gone _very_wrong. Let's just keep it as a string. 
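# (The plain-string fallback of this hunk follows below.) For reference,
# the primary, locale-independent format above parses timestamps such as
# this hypothetical one:
from datetime import datetime

datetime.strptime("14:03:07 2024-05-01 UTC", "%H:%M:%S %Y-%m-%d %Z")
# -> datetime.datetime(2024, 5, 1, 14, 3, 7)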
- self.snapshot_date = self.header.get( - "SnapshotDate", self.header.get("Snapshot date", b"") - ).decode("utf-8") - except KeyError: - # Old file - pass - - # get photon group edges RT dataset from the SubgridScheme group - try: - self.photon_group_edges = ( - self.handle["SubgridScheme/PhotonGroupEdges"][:] / self.units.time - ) - except KeyError: - self.photon_group_edges = None - - # get reduced speed of light RT dataset from the SubgridScheme group - try: - self.reduced_lightspeed = ( - self.handle["SubgridScheme/ReducedLightspeed"][0] - * self.units.length - / self.units.time - ) - except KeyError: - self.reduced_lightspeed = None - - # Store these separately as self.n_gas = number of gas particles for example - for ( - part_number, - part_name, - ) in metadata.particle_types.particle_name_underscores.items(): - try: - setattr(self, f"n_{part_name}", self.num_part[part_number]) - except IndexError: - # Backwards compatibility; mass/number table can change size. - setattr(self, f"n_{part_name}", 0) - - # Need to unpack the gas gamma for cosmology - try: - self.gas_gamma = self.hydro_scheme["Adiabatic index"] - except (KeyError, TypeError): - print("Could not find gas gamma, assuming 5./3.") - self.gas_gamma = 5.0 / 3.0 - - try: - self.a = self.scale_factor - except AttributeError: - # These must always be present for the initialisation of cosmology properties - self.a = 1.0 - self.scale_factor = 1.0 - - return - - def load_particle_types(self): - """ - Loads the particle types and metadata into objects: - - metadata._properties - - This contains six arrays, - - metadata._properties.field_names - metadata._properties.field_paths - metadata._properties.field_units - metadata._properties.field_cosmologies - metadata._properties.field_descriptions - metadata._properties.field_compressions - - As well as some more information about the particle type. - """ - - for particle_type, particle_name in zip( - self.present_particle_types, self.present_particle_names - ): - setattr( - self, - f"{particle_name}_properties", - SWIFTParticleTypeMetadata( - particle_type=particle_type, - particle_name=particle_name, - metadata=self, - scale_factor=self.scale_factor, - ), - ) - - return - - def extract_cosmology(self): - """ - Creates an astropy.cosmology object from the internal cosmology system. - - This will be saved as ``self.cosmology``. - """ - - if self.cosmology_raw is not None: - cosmo = self.cosmology_raw - else: - cosmo = {"Cosmological run": 0} - - if cosmo.get("Cosmological run", 0): - self.cosmology = swift_cosmology_to_astropy(cosmo, units=self.units) - else: - self.cosmology = None - - return - - @property - def present_particle_types(self): - """ - The particle types that are present in the file. - """ - - return np.where(np.array(getattr(self, "has_type", self.num_part)) != 0)[0] - - @property - def present_particle_names(self): - """ - The particle _names_ that are present in the simulation. 
- """ - - return [ - metadata.particle_types.particle_name_underscores[x] - for x in self.present_particle_types - ] - - @property - def code_info(self) -> str: - """ - Gets a nicely printed set of code information with: - - Name (Git Branch) - Git Revision - Git Date - """ - - def get_string(x): - return self.code[x].decode("utf-8") - - output = ( - f"{get_string('Code')} ({get_string('Git Branch')})\n" - f"{get_string('Git Revision')}\n" - f"{get_string('Git Date')}" - ) - - return output - - @property - def compiler_info(self) -> str: - """ - Gets information about the compiler and formats it as: - - Compiler Name (Compiler Version) - MPI library - """ - - def get_string(x): - return self.code[x].decode("utf-8") - - output = ( - f"{get_string('Compiler Name')} ({get_string('Compiler Version')})\n" - f"{get_string('MPI library')}" - ) - - return output - - @property - def library_info(self) -> str: - """ - Gets information about the libraries used and formats it as: - - FFTW vFFTW library version - GSL vGSL library version - HDF5 vHDF5 library version - """ - - def get_string(x): - return self.code[f"{x} library version"].decode("utf-8") - - output = ( - f"FFTW v{get_string('FFTW')}\n" - f"GSL v{get_string('GSL')}\n" - f"HDF5 v{get_string('HDF5')}" - ) - - return output - - @property - def hydro_info(self) -> str: - r""" - Gets information about the hydro scheme and formats it as: - - Scheme - Kernel function in DimensionD - $\eta$ = Kernel eta (Kernel target N_ngb $N_{ngb}$) - $C_{\rm CFL}$ = CFL parameter - """ - - def get_float(x): - return "{:4.2f}".format(self.hydro_scheme[x][0]) - - def get_int(x): - return int(self.hydro_scheme[x][0]) - - def get_string(x): - return self.hydro_scheme[x].decode("utf-8") - - output = ( - f"{get_string('Scheme')}\n" - f"{get_string('Kernel function')} in {get_int('Dimension')}D\n" - rf"$\eta$ = {get_float('Kernel eta')} " - rf"({get_float('Kernel target N_ngb')} $N_{{ngb}}$)" - "\n" - rf"$C_{{\rm CFL}}$ = {get_float('CFL parameter')}" - ) - - return output - - @property - def viscosity_info(self) -> str: - r""" - Gets information about the viscosity scheme and formats it as: - - Viscosity Model - $\alpha_{V, 0}$ = Alpha viscosity, $\ell_V$ = Viscosity decay length [internal units], $\beta_V$ = Beta viscosity - Alpha viscosity (min) < $\alpha_V$ < Alpha viscosity (max) - """ - - def get_float(x): - return "{:4.2f}".format(self.hydro_scheme[x][0]) - - def get_string(x): - return self.hydro_scheme[x].decode("utf-8") - - output = ( - f"{get_string('Viscosity Model')}\n" - rf"$\alpha_{{V, 0}}$ = {get_float('Alpha viscosity')}, " - rf"$\ell_V$ = {get_float('Viscosity decay length [internal units]')}, " - rf"$\beta_V$ = {get_float('Beta viscosity')}" - "\n" - rf"{get_float('Alpha viscosity (min)')} < $\alpha_V$ < {get_float('Alpha viscosity (max)')}" - ) - - return output - - @property - def diffusion_info(self) -> str: - """ - Gets information about the diffusion scheme and formats it as: - - $\alpha_{D, 0}$ = Diffusion alpha, $\beta_D$ = Diffusion beta - Diffusion alpha (min) < $\alpha_D$ < Diffusion alpha (max) - """ - - def get_float(x): - return "{:4.2f}".format(self.hydro_scheme[x][0]) - - output = ( - rf"$\alpha_{{D, 0}}$ = {get_float('Diffusion alpha')}, " - rf"$\beta_D$ = {get_float('Diffusion beta')}" - "\n" - rf"${get_float('Diffusion alpha (min)')} < " - rf"\alpha_D < {get_float('Diffusion alpha (max)')}$" - ) - - return output - - @property - def partial_snapshot(self) -> bool: - """ - Whether or not this snapshot is partial (e.g. 
a "x.0.hdf5" file), or - a file describing an entire snapshot. - """ - - # Partial snapshots have num_files_per_snapshot set to 1. Virtual snapshots - # collating multiple sub-snapshots together have num_files_per_snapshot = 1. - - return self.num_files_per_snapshot > 1 - - -class SWIFTParticleTypeMetadata(object): - """ - Object that contains the metadata for one particle type. - - This, for instance, could be part type 0, or 'gas'. This will load in - the names of all particle datasets, their units, possible named fields, - and their cosmology, and present them for use in the actual i/o routines. - - Methods - ------- - load_metadata(self): - Loads the required metadata. - load_field_names(self): - Loads in the field names. - load_field_units(self): - Loads in the units from each dataset. - load_field_descriptions(self): - Loads in descriptions of the fields for each dataset. - load_field_compressions(self): - Loads in compressions of the fields for each dataset. - load_cosmology(self): - Loads in the field cosmologies. - load_named_columns(self): - Loads the named column data for relevant fields. - """ - - def __init__( - self, - particle_type: int, - particle_name: str, - metadata: SWIFTMetadata, - scale_factor: float, - ): - """ - Constructor for SWIFTParticleTypeMetadata class - - Parameters - ---------- - partycle_type : int - the integer particle type - particle_name : str - the corresponding particle name - metadata : SWIFTMetadata - the snapshot metadata - scale_factor : float - the snapshot scale factor - """ - self.particle_type = particle_type - self.particle_name = particle_name - self.metadata = metadata - self.units = metadata.units - self.scale_factor = scale_factor - - self.filename = metadata.filename - - self.load_metadata() - - return - - def __str__(self): - return f"Metadata class for PartType{self.particle_type} ({self.particle_name})" - - def __repr__(self): - return self.__str__() - - def load_metadata(self): - """ - Loads the required metadata. - - This includes loading the field names, units and descriptions, as well as the - cosmology metadata and any custom named columns - """ - - self.load_field_names() - self.load_field_units() - self.load_field_descriptions() - self.load_field_compressions() - self.load_cosmology() - self.load_named_columns() - - def load_field_names(self): - """ - Loads in only the field names. - """ - - # regular expression for camel case to snake case - # https://stackoverflow.com/a/1176023 - def convert(name): - return re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower() - - self.field_paths = [ - f"PartType{self.particle_type}/{item}" - for item in self.metadata.handle[f"PartType{self.particle_type}"].keys() - ] - - self.field_names = [ - convert(item) - for item in self.metadata.handle[f"PartType{self.particle_type}"].keys() - ] - - return - - def load_field_units(self): - """ - Loads in the units from each dataset. - """ - - unit_dict = { - "I": self.units.current, - "L": self.units.length, - "M": self.units.mass, - "T": self.units.temperature, - "t": self.units.time, - } - - def get_units(unit_attribute): - units = 1.0 - - for exponent, unit in unit_dict.items(): - # We store the 'unit exponent' in the SWIFT metadata. 
This corresponds - # to the power we need to raise each unit to, to return the correct units - try: - # Need to check if the exponent is 0 manually because of float precision - unit_exponent = unit_attribute[f"U_{exponent} exponent"][0] - if unit_exponent != 0.0: - units *= unit ** unit_exponent - except KeyError: - # Can't load that data! - # We should probably warn the user here... - pass - - # Deal with case where we _really_ have a dimensionless quantity. Comparing with - # 1.0 doesn't work, beacause in these cases unyt reverts to a floating point - # comparison. - try: - units.units - except AttributeError: - units = None - - return units - - self.field_units = [ - get_units(self.metadata.handle[x].attrs) for x in self.field_paths - ] - - return - - def load_field_descriptions(self): - """ - Loads in the text descriptions of the fields for each dataset. - """ - - def get_desc(dataset): - try: - description = dataset.attrs["Description"].decode("utf-8") - except KeyError: - # Can't load description! - description = "No description available" - - return description - - self.field_descriptions = [ - get_desc(self.metadata.handle[x]) for x in self.field_paths - ] - - return - - def load_field_compressions(self): - """ - Loads in the string describing the compression filters of the fields for each dataset. - """ - - def get_comp(dataset): - try: - comp = dataset.attrs["Lossy compression filter"].decode("utf-8") - except KeyError: - # Can't load compression string! - comp = "No compression info available" - - return comp - - self.field_compressions = [ - get_comp(self.metadata.handle[x]) for x in self.field_paths - ] - - return - - def load_cosmology(self): - """ - Loads in the field cosmologies. - """ - - current_scale_factor = self.scale_factor - - def get_cosmo(dataset): - try: - cosmo_exponent = dataset.attrs["a-scale exponent"][0] - except: - # Can't load, 'graceful' fallback. - cosmo_exponent = 0.0 - - a_factor_this_dataset = a ** cosmo_exponent - - return cosmo_factor(a_factor_this_dataset, current_scale_factor) - - self.field_cosmologies = [ - get_cosmo(self.metadata.handle[x]) for x in self.field_paths - ] - - return - - def load_named_columns(self): - """ - Loads the named column data for relevant fields. - """ - - named_columns = {} - - for field in self.field_paths: - property_name = field.split("/")[-1] - - if property_name in self.metadata.named_columns.keys(): - field_names = self.metadata.named_columns[property_name] - - # Now need to make a decision on capitalisation. If we have a set of - # words with only one capital in them, then it's likely that they are - # element names or something similar, so they should be lower case. - # If on average we have many more capitals, then they are likely to be - # ionized fractions (e.g. HeII) and so we want to leave them with their - # original capitalisation. 
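# (The capitalisation heuristic continues below.) For the unit metadata
# read earlier in this class: a dataset's units are the product of the
# base units raised to the stored "U_* exponent" attributes. A small,
# self-contained sketch with hypothetical exponents for a mass density:
import unyt

exponents = {"U_M exponent": 1.0, "U_L exponent": -3.0, "U_t exponent": 0.0}
base_units = {"U_M exponent": unyt.g, "U_L exponent": unyt.cm, "U_t exponent": unyt.s}

units = 1.0
for name, unit in base_units.items():
    if exponents.get(name, 0.0) != 0.0:  # skip zero exponents, as above
        units = units * unit ** exponents[name]
# units -> 1.0 g/cm**3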
- - num_capitals = lambda x: sum(1 for c in x if c.isupper()) - mean_num_capitals = sum(map(num_capitals, field_names)) / len( - field_names - ) - - if mean_num_capitals < 1.01: - # Decapitalise them as they are likely individual element names - formatted_field_names = [x.lower() for x in field_names] - else: - formatted_field_names = field_names - - named_columns[field] = formatted_field_names - else: - named_columns[field] = None - - self.named_columns = named_columns - - return - - def generate_getter( filename, name: str, @@ -1065,6 +44,8 @@ def generate_getter( cosmo_factor: cosmo_factor, description: str, compression: str, + physical: bool, + valid_transform: bool, columns: Union[None, slice] = None, ): """ @@ -1110,6 +91,14 @@ def generate_getter( String describing the lossy compression filters that were applied to the data (read from the HDF5 file). + physical: bool + Bool that describes whether the data in the file is stored in comoving + or physical units. + + valid_transform: bool + Bool that describes whether converting this field from physical to comoving + units is a valid operation. + columns: np.lib.index_tricks.IndexEpression, optional Index expression corresponding to which columns to read from the numpy array. If not provided, we read all columns and return an n-dimensional array. @@ -1121,7 +110,7 @@ def generate_getter( getter: callable A callable object that gets the value of the array that has been saved to ``_name``. This function takes only ``self`` from the - :obj:``__SWIFTParticleDataset`` class. + :obj:``__SWIFTGroupDataset`` class. Notes @@ -1181,6 +170,8 @@ def getter(self): cosmo_factor=cosmo_factor, name=description, compression=compression, + comoving=not physical, + valid_transform=valid_transform, ), ) else: @@ -1197,6 +188,8 @@ def getter(self): cosmo_factor=cosmo_factor, name=description, compression=compression, + comoving=not physical, + valid_transform=valid_transform, ), ) except KeyError: @@ -1257,7 +250,7 @@ def deleter(self): return deleter -class __SWIFTParticleDataset(object): +class __SWIFTGroupDataset(object): """ Creates empty property fields @@ -1270,26 +263,26 @@ class __SWIFTParticleDataset(object): creates empty properties to be accessed through setter and getter functions """ - def __init__(self, particle_metadata: SWIFTParticleTypeMetadata): + def __init__(self, group_metadata: SWIFTGroupMetadata): """ - Constructor for SWIFTParticleDataset class + Constructor for SWIFTGroupDatasets class This function primarily calls the generate_empty_properties function to ensure that defaults are set correctly. 
     Parameters
     ----------
 
-    particle_metadata : SWIFTParticleTypeMetadata
+    group_metadata : SWIFTGroupMetadata
         the metadata used to generate empty properties
     """
 
-        self.filename = particle_metadata.filename
-        self.units = particle_metadata.units
+        self.filename = group_metadata.filename
+        self.units = group_metadata.units
 
-        self.particle_type = particle_metadata.particle_type
-        self.particle_name = particle_metadata.particle_name
+        self.group = group_metadata.group
+        self.group_name = group_metadata.group_name
 
-        self.particle_metadata = particle_metadata
-        self.metadata = particle_metadata.metadata
+        self.group_metadata = group_metadata
+        self.metadata = group_metadata.metadata
 
         self.generate_empty_properties()
 
@@ -1305,7 +298,7 @@ def generate_empty_properties(self):
         """
 
         for field_name, field_path in zip(
-            self.particle_metadata.field_names, self.particle_metadata.field_paths
+            self.group_metadata.field_names, self.group_metadata.field_paths
         ):
             if field_path in self.metadata.handle:
                 setattr(self, f"_{field_name}", None)
@@ -1323,7 +316,7 @@ class __SWIFTNamedColumnDataset(object):
     """
     Holder class for individual named datasets. Very similar to
-    __SWIFTParticleDataset but much simpler.
+    __SWIFTGroupDataset but much simpler.
     """
 
     def __init__(self, field_path: str, named_columns: List[str], name: str):
@@ -1379,9 +372,9 @@ def __eq__(self, other):
         return self.named_columns == other.named_columns and self.name == other.name
 
 
-def generate_dataset(particle_metadata: SWIFTParticleTypeMetadata, mask):
+def generate_datasets(group_metadata: SWIFTGroupMetadata, mask):
     """
-    Generates a SWIFTParticleDataset _class_ that corresponds to the
+    Generates a SWIFTGroupDatasets _class_ that corresponds to the
     particle type given.
 
     We _must_ do the following _outside_ of the class itself, as one
@@ -1391,44 +384,49 @@
     Here we loop through all of the possible properties in the metadata
     file. We then use the builtin property() function and some generators
     to create setters and getters for those properties. This will allow them
-    to be accessed from outside by using SWIFTParticleDataset.name, where
+    to be accessed from outside by using SWIFTGroupDatasets.name, where
     the name is, for example, coordinates.
 
     Parameters
     ----------
 
-    particle_metadata : SWIFTParticleTypeMetadata
-        the metadata for the particle type
+    group_metadata : SWIFTGroupMetadata
+        the metadata for the group
 
     mask : SWIFTMask
-        the mask object for the dataset
+        the mask object for the datasets
     """
 
-    filename = particle_metadata.filename
-    particle_type = particle_metadata.particle_type
-    particle_name = particle_metadata.particle_name
-    particle_nice_name = metadata.particle_types.particle_name_class[particle_type]
+    filename = group_metadata.filename
+    group = group_metadata.group
+    group_name = group_metadata.group_name
+
+    group_nice_name = group_metadata.metadata.get_nice_name(group)
 
     # Mask is an object that contains all masks for all possible datasets.
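# How such a mask is typically produced (a hedged sketch; the filename is
# hypothetical, and the calls shown are the public swiftsimio interface):
import swiftsimio as sw

mask = sw.mask("cosmological_volume.hdf5")
# Restrict to the lower-left octant of the box before loading.
mask.constrain_spatial([[0.0 * b, 0.5 * b] for b in mask.metadata.boxsize])
data = sw.load("cosmological_volume.hdf5", mask=mask)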
if mask is not None: - mask_array = getattr(mask, particle_name) - mask_size = getattr(mask, f"{particle_name}_size") + mask_array = getattr(mask, group_name) + mask_size = getattr(mask, f"{group_name}_size") else: mask_array = None mask_size = -1 # Set up an iterator for us to loop over for all fields - field_paths = particle_metadata.field_paths - field_names = particle_metadata.field_names - field_cosmologies = particle_metadata.field_cosmologies - field_units = particle_metadata.field_units - field_descriptions = particle_metadata.field_descriptions - field_compressions = particle_metadata.field_compressions - field_named_columns = particle_metadata.named_columns + field_paths = group_metadata.field_paths + field_names = group_metadata.field_names + field_cosmologies = group_metadata.field_cosmologies + field_units = group_metadata.field_units + field_physicals = group_metadata.field_physicals + field_valid_transforms = group_metadata.field_valid_transforms + field_descriptions = group_metadata.field_descriptions + field_compressions = group_metadata.field_compressions + field_named_columns = group_metadata.named_columns dataset_iterator = zip( field_paths, field_names, field_cosmologies, field_units, + field_physicals, + field_valid_transforms, field_descriptions, field_compressions, ) @@ -1437,7 +435,7 @@ def generate_dataset(particle_metadata: SWIFTParticleTypeMetadata, mask): # for different particle types. We initially fill a dict with the properties that # we want, and then create a single instance of our class. - this_dataset_bases = (__SWIFTParticleDataset, object) + this_dataset_bases = (__SWIFTGroupDataset, object) this_dataset_dict = {} for ( @@ -1445,6 +443,8 @@ def generate_dataset(particle_metadata: SWIFTParticleTypeMetadata, mask): field_name, field_cosmology, field_unit, + field_physical, + field_valid_transform, field_description, field_compression, ) in dataset_iterator: @@ -1462,6 +462,8 @@ def generate_dataset(particle_metadata: SWIFTParticleTypeMetadata, mask): cosmo_factor=field_cosmology, description=field_description, compression=field_compression, + physical=field_physical, + valid_transform=field_valid_transform, ), generate_setter(field_name), generate_deleter(field_name), @@ -1488,6 +490,8 @@ def generate_dataset(particle_metadata: SWIFTParticleTypeMetadata, mask): cosmo_factor=field_cosmology, description=f"{field_description} [Column {index}, {column}]", compression=field_compression, + physical=field_physical, + valid_transform=field_valid_transform, columns=np.s_[index], ), generate_setter(column), @@ -1495,7 +499,7 @@ def generate_dataset(particle_metadata: SWIFTParticleTypeMetadata, mask): ) ThisNamedColumnDataset = type( - f"{particle_nice_name}{field_path.split('/')[-1]}Columns", + f"{group_nice_name}{field_path.split('/')[-1]}Columns", this_named_column_dataset_bases, this_named_column_dataset_dict, ) @@ -1507,9 +511,9 @@ def generate_dataset(particle_metadata: SWIFTParticleTypeMetadata, mask): this_dataset_dict[field_name] = field_property ThisDataset = type( - f"{particle_nice_name}Dataset", this_dataset_bases, this_dataset_dict + f"{group_nice_name}Dataset", this_dataset_bases, this_dataset_dict ) - empty_dataset = ThisDataset(particle_metadata) + empty_dataset = ThisDataset(group_metadata) return empty_dataset @@ -1520,7 +524,7 @@ class SWIFTDataset(object): + SWIFTUnits, + SWIFTMetadata, - + SWIFTParticleDataset + + SWIFTGroupDatasets This object, in essence, completely represents a SWIFT snapshot. 
You can access the different particles as follows: @@ -1564,7 +568,7 @@ def __init__(self, filename, mask=None): self.get_units() self.get_metadata() - self.create_particle_datasets() + self.create_datasets() return @@ -1599,14 +603,14 @@ def get_metadata(self): this function again if you mess things up. """ - self.metadata = SWIFTMetadata(self.filename, self.units) + self.metadata = metadata_discriminator(self.filename, self.units) return - def create_particle_datasets(self): + def create_datasets(self): """ - Creates particle datasets for whatever particle types and names - are specified in metadata.particle_types. + Creates datasets for whichever groups + are specified in metadata.present_group_names. These can then be accessed using their underscore names, e.g. gas. """ @@ -1614,12 +618,12 @@ def create_particle_datasets(self): if not hasattr(self, "metadata"): self.get_metadata() - for particle_name in self.metadata.present_particle_names: + for group_name in self.metadata.present_group_names: setattr( self, - particle_name, - generate_dataset( - getattr(self.metadata, f"{particle_name}_properties"), self.mask + group_name, + generate_datasets( + getattr(self.metadata, f"{group_name}_properties"), self.mask ), ) diff --git a/swiftsimio/writer.py b/swiftsimio/snapshot_writer.py similarity index 95% rename from swiftsimio/writer.py rename to swiftsimio/snapshot_writer.py index 99b8899f..2d9a4206 100644 --- a/swiftsimio/writer.py +++ b/swiftsimio/snapshot_writer.py @@ -1,7 +1,7 @@ """ Contains functions and objects for creating SWIFT datasets. -Essentially all you want to do is use SWIFTWriterDataset and fill the attributes +Essentially all you want to do is use SWIFTSnapshotWriter and fill the attributes that are required for each particle type. More information is available in the README. """ @@ -17,6 +17,23 @@ from swiftsimio.metadata.cosmology.cosmology_fields import a_exponents +def _ptype_str_to_int(ptype_str): + """ + Convert a string like `"PartType0"` to an integer (in this example, `0`). + + Parameters + ---------- + ptype_str : str + The particle type string. + + Returns + ------- + out : int + The corresponding integer. + """ + return int(ptype_str.strip("PartType")) if "PartType" in ptype_str else ptype_str + + class __SWIFTWriterParticleDataset(object): """ A particle dataset for _writing_ with. 
This is explicitly different @@ -59,7 +76,6 @@ def __init__(self, unit_system: Union[unyt.UnitSystem, str], particle_type: int) self.unit_system = unit_system self.particle_type = particle_type - self.particle_handle = f"PartType{self.particle_type}" self.particle_name = metadata.particle_types.particle_name_underscores[ self.particle_type ] @@ -191,7 +207,7 @@ def write_particle_group(self, file_handle: h5py.File, compress: bool): flag to indicate whether to turn on gzip compression """ - particle_group = file_handle.create_group(self.particle_handle) + particle_group = file_handle.create_group(self.particle_type) if compress: compression = "gzip" @@ -223,7 +239,7 @@ def write_particle_group_metadata( for name, output_handle in getattr( metadata.required_fields, self.particle_name ).items(): - obj = file_handle[f"/PartType{self.particle_type}/{output_handle}"] + obj = file_handle[f"/{self.particle_type}/{output_handle}"] for attr_name, attr_value in dset_attributes[output_handle].items(): obj.attrs.create(attr_name, attr_value) @@ -477,7 +493,7 @@ def generate_dataset( return empty_dataset -class SWIFTWriterDataset(object): +class SWIFTSnapshotWriter(object): """ The SWIFT writer dataset. This is used to store all particle arrays and do some extra processing before writing a HDF5 file containing: @@ -500,7 +516,7 @@ def __init__( scale_factor: np.float32 = 1.0, ): """ - Creates SWIFTWriterDataset object + Creates SWIFTSnapshotWriter object Parameters ---------- @@ -601,14 +617,24 @@ def _write_metadata(self, handle: h5py.File, names_to_write: List): names_to_write : list list of metadata fields to write """ - part_types = max(metadata.particle.particle_name_underscores.keys()) + 1 + part_types = ( + max( + [ + _ptype_str_to_int(k) + for k in metadata.particle.particle_name_underscores.keys() + ] + ) + + 1 + ) number_of_particles = [0] * part_types mass_table = [0.0] * part_types for number, name in metadata.particle_types.particle_name_underscores.items(): if name in names_to_write: - number_of_particles[number] = getattr(self, name).n_part - mass_table[number] = getattr(self, name).masses[0] + number_of_particles[_ptype_str_to_int(number)] = getattr( + self, name + ).n_part + mass_table[_ptype_str_to_int(number)] = getattr(self, name).masses[0] attrs = { "BoxSize": self.box_size, diff --git a/swiftsimio/subset_writer.py b/swiftsimio/subset_writer.py index f73e7363..235cbbe2 100644 --- a/swiftsimio/subset_writer.py +++ b/swiftsimio/subset_writer.py @@ -28,14 +28,13 @@ def get_swift_name(name: str) -> str: str SWIFT particle type corresponding to `name` (e.g. 
PartType0) """ - part_type_nums = [ + part_type_names = [ k for k, v in metadata.particle_types.particle_name_underscores.items() ] part_types = [ v for k, v in metadata.particle_types.particle_name_underscores.items() ] - part_type_num = part_type_nums[part_types.index(name)] - return f"PartType{part_type_num}" + return part_type_names[part_types.index(name)] def get_dataset_mask( @@ -66,7 +65,7 @@ def get_dataset_mask( suffix = "" if suffix is None else suffix if "PartType" in dataset_name: - part_type = [int(x) for x in filter(str.isdigit, dataset_name)][0] + part_type = dataset_name.lstrip("/").split("/")[0] mask_name = metadata.particle_types.particle_name_underscores[part_type] return getattr(mask, f"{mask_name}{suffix}", None) else: diff --git a/swiftsimio/visualisation/power_spectrum.py b/swiftsimio/visualisation/power_spectrum.py index 8f24e978..cdd56f17 100644 --- a/swiftsimio/visualisation/power_spectrum.py +++ b/swiftsimio/visualisation/power_spectrum.py @@ -10,7 +10,7 @@ from swiftsimio.optional_packages import tqdm from swiftsimio.accelerated import jit, NUM_THREADS, prange from swiftsimio import cosmo_array -from swiftsimio.reader import __SWIFTParticleDataset +from swiftsimio.reader import __SWIFTGroupDataset from typing import Optional, Dict, Tuple @@ -169,7 +169,7 @@ def deposit_parallel( def render_to_deposit( - data: __SWIFTParticleDataset, + data: __SWIFTGroupDataset, resolution: int, project: str = "masses", folding: int = 0, diff --git a/swiftsimio/visualisation/projection.py b/swiftsimio/visualisation/projection.py index 1a92730f..fa0738b3 100644 --- a/swiftsimio/visualisation/projection.py +++ b/swiftsimio/visualisation/projection.py @@ -22,7 +22,7 @@ from unyt import unyt_array, unyt_quantity, exceptions from swiftsimio import SWIFTDataset, cosmo_array -from swiftsimio.reader import __SWIFTParticleDataset +from swiftsimio.reader import __SWIFTGroupDataset from swiftsimio.accelerated import jit, NUM_THREADS, prange from swiftsimio.visualisation.projection_backends import backends, backends_parallel @@ -42,7 +42,7 @@ def project_pixel_grid( - data: __SWIFTParticleDataset, + data: __SWIFTGroupDataset, boxsize: unyt_array, resolution: int, project: Union[str, None] = "masses", @@ -65,7 +65,7 @@ def project_pixel_grid( Parameters ---------- - data: __SWIFTParticleDataset + data: __SWIFTGroupDataset The SWIFT dataset that you wish to visualise (get this from ``load``) boxsize: unyt_array diff --git a/swiftsimio/visualisation/ray_trace.py b/swiftsimio/visualisation/ray_trace.py index 21c7229f..61960109 100644 --- a/swiftsimio/visualisation/ray_trace.py +++ b/swiftsimio/visualisation/ray_trace.py @@ -13,7 +13,7 @@ import math from swiftsimio.objects import cosmo_array -from swiftsimio.reader import __SWIFTParticleDataset, SWIFTDataset +from swiftsimio.reader import __SWIFTGroupDataset, SWIFTDataset from swiftsimio.visualisation.projection_backends.kernels import ( kernel_gamma, kernel_double_precision as kernel, @@ -216,7 +216,7 @@ def core_panels_parallel( def panel_pixel_grid( - data: __SWIFTParticleDataset, + data: __SWIFTGroupDataset, boxsize: unyt.unyt_array, resolution: int, panels: int, diff --git a/tests/subset_write_test.py b/tests/subset_write_test.py index 1e7bd126..fcd06a85 100644 --- a/tests/subset_write_test.py +++ b/tests/subset_write_test.py @@ -34,7 +34,7 @@ def compare_data_contents(A, B): A_type = getattr(A, part_type) B_type = getattr(B, part_type) particle_dataset_field_names = set( - A_type.particle_metadata.field_names + 
B_type.particle_metadata.field_names + A_type.group_metadata.field_names + B_type.group_metadata.field_names ) for attr in particle_dataset_field_names: diff --git a/tests/test_data.py b/tests/test_data.py index f40fe30a..1cc32d3e 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -104,8 +104,8 @@ def test_units(filename): # Now need to extract the particle paths in the original hdf5 file # for comparison... - paths = numpy_array(field.particle_metadata.field_paths) - names = numpy_array(field.particle_metadata.field_names) + paths = numpy_array(field.group_metadata.field_paths) + names = numpy_array(field.group_metadata.field_names) for property in properties: # Read the 0th element, and compare in CGS units. diff --git a/tests/test_extraparts.py b/tests/test_extraparts.py index 89c432fa..796d4cc3 100644 --- a/tests/test_extraparts.py +++ b/tests/test_extraparts.py @@ -87,9 +87,9 @@ def test_write(): ) # Specify a new type in the metadata - currently done by editing the dictionaries directly. # TODO: Remove this terrible way of setting up different particle types. - swp.particle_name_underscores[6] = "extratype" - swp.particle_name_class[6] = "Extratype" - swp.particle_name_text[6] = "Extratype" + swp.particle_name_underscores["PartType7"] = "extratype" + swp.particle_name_class["PartType7"] = "Extratype" + swp.particle_name_text["PartType7"] = "Extratype" swmw.extratype = {"smoothing_length": "SmoothingLength", **swmw.shared} @@ -110,9 +110,9 @@ def test_write(): x.write("extra_test.hdf5") # Clean up these global variables we screwed around with... - swp.particle_name_underscores.pop(6) - swp.particle_name_class.pop(6) - swp.particle_name_text.pop(6) + swp.particle_name_underscores.pop("PartType7") + swp.particle_name_class.pop("PartType7") + swp.particle_name_text.pop("PartType7") def test_read(): @@ -120,9 +120,9 @@ def test_read(): Tests whether swiftsimio can handle a new particle type. Has a few asserts to check the data is read in correctly. """ - swp.particle_name_underscores[6] = "extratype" - swp.particle_name_class[6] = "Extratype" - swp.particle_name_text[6] = "Extratype" + swp.particle_name_underscores["PartType7"] = "extratype" + swp.particle_name_class["PartType7"] = "Extratype" + swp.particle_name_text["PartType7"] = "Extratype" swmw.extratype = {"smoothing_length": "SmoothingLength", **swmw.shared} @@ -136,6 +136,6 @@ def test_read(): os.remove("extra_test.hdf5") # Clean up these global variables we screwed around with... 
- swp.particle_name_underscores.pop(6) - swp.particle_name_class.pop(6) - swp.particle_name_text.pop(6) + swp.particle_name_underscores.pop("PartType7") + swp.particle_name_class.pop("PartType7") + swp.particle_name_text.pop("PartType7") diff --git a/tests/test_soap.py b/tests/test_soap.py new file mode 100644 index 00000000..0a6a4b0d --- /dev/null +++ b/tests/test_soap.py @@ -0,0 +1,14 @@ +""" +Tests that we can open SOAP files +""" + +from tests.helper import requires + +from swiftsimio import load + + +@requires("soap_example.hdf5") +def test_soap_can_load(filename): + data = load(filename) + + return diff --git a/tests/test_visualisation.py b/tests/test_visualisation.py index 79b59f3a..99a91790 100644 --- a/tests/test_visualisation.py +++ b/tests/test_visualisation.py @@ -249,6 +249,7 @@ def test_volume_parallel(): masses, hsml, resolution, + 1, 1.0, 1.0, 1.0, @@ -329,7 +330,7 @@ def test_render_outside_region(): slice_scatter_parallel(x, y, z, m, h, 0.2, resolution, 1.0, 1.0, 1.0) - volume_render.scatter_parallel(x, y, z, m, h, resolution, 1.0, 1.0, 1.0) + volume_render.scatter_parallel(x, y, z, m, h, resolution, 1, 1.0, 1.0, 1.0) @requires("cosmological_volume.hdf5") @@ -345,10 +346,10 @@ def test_comoving_versus_physical(filename): # conversion in this case img = func(data, resolution=256, project=None) assert img.comoving - assert img.cosmo_factor.expr == a ** aexp + assert (img.cosmo_factor.expr - a ** (aexp)).simplify() == 0 img = func(data, resolution=256, project="densities") assert img.comoving - assert img.cosmo_factor.expr == a ** (aexp - 3.0) + assert (img.cosmo_factor.expr - a ** (aexp - 3.0)).simplify() == 0 # try to mix comoving coordinates with a physical variable data.gas.densities.convert_to_physical() with pytest.raises(AttributeError, match="not compatible with comoving"): @@ -363,11 +364,11 @@ def test_comoving_versus_physical(filename): img = func(data, resolution=256, project="masses") # check that we get a physical result assert not img.comoving - assert img.cosmo_factor.expr == a ** aexp + assert (img.cosmo_factor.expr - a ** aexp).simplify() == 0 # densities are still compatible with physical img = func(data, resolution=256, project="densities") assert not img.comoving - assert img.cosmo_factor.expr == a ** (aexp - 3.0) + assert (img.cosmo_factor.expr - a ** (aexp - 3.0)).simplify() == 0 # now try again with comoving densities data.gas.densities.convert_to_comoving() with pytest.raises(AttributeError, match="not compatible with physical"):
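# Why the test comparisons above now use ``(e1 - e2).simplify() == 0``:
# ``==`` on sympy expressions tests structure, not value, so two
# mathematically equal but differently-arranged exponents can compare
# unequal. A standalone illustration (assumes sympy, which backs
# ``cosmo_factor.expr``; the symbols below are illustrative):
import sympy

a, aexp = sympy.symbols("a aexp")
e1 = a ** (aexp - 3.0)
e2 = a ** aexp * a ** -3.0

print(e1 == e2)                   # may be False: structural comparison
print((e1 - e2).simplify() == 0)  # True: same mathematical value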