Skip to content

Commit

Permalink
Merge pull request #259 from pynbody/hbtplus
Browse files Browse the repository at this point in the history
HBT+ support
  • Loading branch information
apontzen authored Jun 21, 2024
2 parents 3d888c4 + 678f7d8 commit 6ef63d8
Show file tree
Hide file tree
Showing 8 changed files with 150 additions and 17 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/integration-test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ jobs:
- name: Verify database
working-directory: test_tutorial_build
run: |
wget https://zenodo.org/record/11122073/files/reference_database.db?download=1 -O reference_database.db -nv
wget https://zenodo.org/record/12192344/files/reference_database.db?download=1 -O reference_database.db -nv
tangos diff data.db reference_database.db --property-tolerance dm_density_profile 1e-2 0 --property-tolerance gas_map 1e-2 0 --property-tolerance gas_map_sideon 1e-2 0 --property-tolerance gas_map_faceon 1e-2 0
# --property-tolerance dm_density_profile here is because if a single particle crosses between bins
# (which seems to happen due to differing library versions), the profile can change by this much
Expand Down
2 changes: 0 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,6 @@ def get_version(rel_path):
"Intended Audience :: Science/Research",
"Topic :: Scientific/Engineering :: Astronomy",
"Programming Language :: Python",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
Expand Down
23 changes: 20 additions & 3 deletions tangos/input_handlers/finding.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,12 @@ class PatternBasedFileDiscovery:
# will not be used for finding the snapshots, but if files matching these patterns are present
# the handler is more likely to be selected automatically. See e.g. GadgetRockstarInputHandler.

enable_autoselect = True # set to False if user must select manually via --input-handler when adding

@classmethod
def best_matching_handler(cls, basename):
handler_names = []
handler_timestep_lengths = []
handler_scores = []
base = os.path.join(config.base, basename)

# Add all subclasses, sub-subclasses, sub-subclasses, ...
Expand All @@ -57,8 +59,23 @@ def best_matching_handler(cls, basename):
timesteps_detected = find(basename=base + "/", patterns=possible_handler.patterns)
other_files_detected = find(basename=base+"/", patterns=possible_handler.auxiliary_file_patterns)
handler_names.append(possible_handler)
handler_timestep_lengths.append(len(timesteps_detected)+len(other_files_detected))
best_handler = handler_names[np.argmax(handler_timestep_lengths)]
handler_scores.append(len(timesteps_detected)+len(other_files_detected))

handler_scores = np.array(handler_scores)

max_length_mask = (handler_scores == np.max(handler_scores))
if np.sum(max_length_mask) > 1:
logger.warning("Multiple handlers have the same score. Adding specialisation to the decision process.")
# work out how many subclasses below HandlerBase each handler is:
handler_subclass_depths = np.array([len(handler.__mro__) for handler in handler_names])
handler_scores += handler_subclass_depths

max_length_mask = (handler_scores == np.max(handler_scores))
if np.sum(max_length_mask) > 1:
logger.warning("Multiple handlers still have the same score, after adding specialisation score. Choosing the first one.")


best_handler = handler_names[np.argmax(handler_scores)]
logger.debug("Detected best handler (of %d) is %s",len(all_possible_handlers), best_handler)

return best_handler
Expand Down
92 changes: 85 additions & 7 deletions tangos/input_handlers/pynbody.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from __future__ import annotations

import glob
import os
import os.path
import pathlib
import time
import weakref
from collections import defaultdict
Expand All @@ -12,10 +15,15 @@

pynbody = None # deferred import; occurs when a PynbodyInputHandler is constructed

from typing import TYPE_CHECKING

from .. import config
from ..log import logger
from . import HandlerBase, finding

if TYPE_CHECKING:
import pynbody

_loaded_halocats = {}

class DummyTimeStep:
Expand Down Expand Up @@ -77,7 +85,7 @@ def get_timestep_properties(self, ts_extension):
'available': True}
return results

def load_timestep_without_caching(self, ts_extension, mode=None):
def load_timestep_without_caching(self, ts_extension, mode=None) -> pynbody.snapshot.simsnap.SimSnap:
if mode=='partial' or mode is None:
f = pynbody.load(self._extension_to_filename(ts_extension))
f.physical_units()
Expand All @@ -92,7 +100,7 @@ def load_timestep_without_caching(self, ts_extension, mode=None):
def _build_kdtree(self, timestep, mode):
timestep.build_tree()

def load_region(self, ts_extension, region_specification, mode=None, expected_number_of_queries=None):
def load_region(self, ts_extension, region_specification, mode=None, expected_number_of_queries=None) -> pynbody.snapshot.simsnap.SimSnap:
timestep = self.load_timestep(ts_extension, mode)

timestep._tangos_cached_regions = getattr(timestep, '_tangos_cached_regions', {})
Expand Down Expand Up @@ -130,7 +138,7 @@ def _load_region_uncached(self, timestep, ts_extension, region_specification, mo
else:
raise NotImplementedError("Load mode %r is not implemented"%mode)

def load_object(self, ts_extension, finder_id, finder_offset, object_typetag='halo', mode=None):
def load_object(self, ts_extension, finder_id, finder_offset, object_typetag='halo', mode=None) -> pynbody.snapshot.simsnap.SimSnap:
if mode=='partial':
h = self.get_catalogue(ts_extension, object_typetag)
h_file = h.load_copy(finder_id)
Expand Down Expand Up @@ -166,7 +174,7 @@ def load_object(self, ts_extension, finder_id, finder_offset, object_typetag='ha
else:
raise NotImplementedError("Load mode %r is not implemented"%mode)

def load_tracked_region(self, ts_extension, track_data, mode=None):
def load_tracked_region(self, ts_extension, track_data, mode=None) -> pynbody.snapshot.simsnap.SimSnap:
f = self.load_timestep(ts_extension, mode)
indices = self._get_indices_for_snapshot(f, track_data)
if mode=='partial':
Expand Down Expand Up @@ -206,7 +214,7 @@ def _get_indices_for_snapshot(self, f, track_data):



def get_catalogue(self, ts_extension, object_typetag):
def get_catalogue(self, ts_extension, object_typetag) -> pynbody.halo.HaloCatalogue:
if object_typetag!= 'halo':
raise ValueError("Unknown object type %r" % object_typetag)
f = self.load_timestep(ts_extension)
Expand Down Expand Up @@ -274,9 +282,10 @@ def enumerate_objects(self, ts_extension, object_typetag="halo", min_halo_partic
logger.warning("No %s statistics file found for timestep %r", object_typetag, ts_extension)

snapshot_keep_alive = self.load_timestep(ts_extension)

try:
h = self.get_catalogue(ts_extension, object_typetag)
except:
except Exception as e:
logger.warning("Unable to read %ss using pynbody; assuming step has none", object_typetag)
return

Expand Down Expand Up @@ -367,7 +376,7 @@ class GadgetSubfindInputHandler(PynbodyInputHandler):

def _is_able_to_load(self, filepath):
try:
f = eval(self.snap_class_name)(filepath)
f = eval(self.snap_class_name)(pathlib.Path(filepath))
h = f.halos()
if isinstance(h, eval(self.catalogue_class_name)):
return True
Expand Down Expand Up @@ -473,6 +482,71 @@ def _transform_extension(self, extension_name):
else:
return extension_name

class Gadget4HBTPlusInputHandler(Gadget4HDFSubfindInputHandler):
auxiliary_file_patterns = ["SubSnap_???.hdf5", "SubSnap_???.0.hdf5"]
catalogue_class_name = "pynbody.halo.hbtplus.HBTPlusCatalogueWithGroups"
_sub_parent_names = [] # although HBTplus stores this as 'HostHaloId', pynbody already translates it to 'parent'
_property_prefix_for_type = {'group': 'Group'}

@classmethod
def _construct_pynbody_halos(cls, sim, *args, **kwargs):
if kwargs.pop('subs', False):
h = pynbody.halo.hbtplus.HBTPlusCatalogue(sim)
h.load_all()
return h
else:
return super()._construct_pynbody_halos(sim, *args, **kwargs)

def _construct_group_cat(self, ts_extension):
sim = self.load_timestep(ts_extension)
groups = super()._construct_pynbody_halos(sim, subs=False)
# can't call super()._construct_group_cat because that verifies the type of the catalogue, which is wrong
# until we do the modification below

hbt_halos = self._construct_pynbody_halos(sim, subs=True)
return hbt_halos.with_groups_from(groups)

def _is_able_to_load(self, filepath):
try:
f = eval(self.snap_class_name)(pathlib.Path(filepath))
h = pynbody.halo.hbtplus.HBTPlusCatalogue(f)
return True
except (OSError, RuntimeError):
return False

def match_objects(self, ts1, ts2, halo_min, halo_max,
dm_only=False, threshold=0.005, object_typetag='halo',
output_handler_for_ts2=None,
fuzzy_match_kwa={}):

if object_typetag=='halo' and output_handler_for_ts2 is self:
# specialised case
f1 = self.load_timestep(ts1)
h1 = self.get_catalogue(ts1, 'halo')
f2 = self.load_timestep(ts2)
h2 = self.get_catalogue(ts2, 'halo')

id1_to_number1 = h1.number_mapper.index_to_number
id2_to_number2 = h2.number_mapper.index_to_number

props1 = h1.get_properties_all_halos()
props2 = h2.get_properties_all_halos()

id1_to_trackid = props1['TrackId']
id2_to_trackid = props2['TrackId']

trackid_to_id2 = {trackid: id2 for id2,trackid in enumerate(id2_to_trackid)}
number1_to_number2 = {id1_to_number1(id1): [(id2_to_number2(trackid_to_id2[trackid]), 1.0)]
if trackid in trackid_to_id2 else []
for id1, trackid in enumerate(id1_to_trackid)}

return number1_to_number2


else:
return super().match_objects(ts1, ts2, halo_min, halo_max, dm_only, threshold, object_typetag,
output_handler_for_ts2, fuzzy_match_kwa)



class GadgetRockstarInputHandler(PynbodyInputHandler):
Expand Down Expand Up @@ -751,11 +825,15 @@ class ChangaIgnoreIDLInputHandler(ChangaInputHandler):
pynbody_halo_class_name = "AHFCatalogue"
halo_stat_file_class_name = "AHFStatFile"

enable_autoselect = False

class ChangaUseIDLInputHandler(ChangaInputHandler):
pynbody_halo_class_name = "AmigaGrpCatalogue"
halo_stat_file_class_name = "AmigaIDLStatFile"
auxiliary_file_patterns = ["*.amiga.grp"]

enable_autoselect = False

from . import caterpillar, eagle, ramsesHOP

RamsesHOPInputHandler = ramsesHOP.RamsesHOPInputHandler
2 changes: 1 addition & 1 deletion tangos/input_handlers/ramsesHOP.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def enumerate_objects(self, ts_extension, object_typetag="halo", min_halo_partic
try:
h = self.get_catalogue(ts_extension, object_typetag)
h._index_parent = False
except:
except Exception as e:
logger.warning("Unable to read %ss using pynbody; assuming step has none", object_typetag)
return

Expand Down
5 changes: 3 additions & 2 deletions tangos/util/subclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@

def find_subclasses(cls: type) -> List[type]:
subclasses = cls.__subclasses__()
all_subclasses = []
all_subclasses = [cls]
while subclasses:
subclass = subclasses.pop()
subclasses.extend(subclass.__subclasses__())
all_subclasses.append(subclass)
if subclass.enable_autoselect:
all_subclasses.append(subclass)

return all_subclasses
13 changes: 12 additions & 1 deletion test_tutorial_build/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ get_tutorial_data() {
wget -nv -O - https://zenodo.org/record/5155467/files/$1.tar.gz?download=1 | tar -xzv
else
echo "Downloading mini tutorial data for $1"
wget -nv -O - https://zenodo.org/records/10825178/files/$1.tar.gz?download=1 | tar -xzv
wget -nv -O - https://zenodo.org/records/12189455/files/$1.tar.gz?download=1 | tar -xzv
fi
fi
}
Expand All @@ -33,6 +33,16 @@ build_gadget4() {
$MPI tangos $MPIBACKEND write dm_density_profile --with-prerequisites --include-only="NDM()>5000" --type=halo --for tutorial_gadget4
}

build_gadget4_hbtplus() {
get_tutorial_data tutorial_gadget4
get_tutorial_data tutorial_gadget4_hbtplus
tangos add tutorial_gadget4_hbtplus
tangos link --for tutorial_gadget4_hbtplus
tangos import-properties --for tutorial_gadget4_hbtplus
tangos import-properties --for tutorial_gadget4_hbtplus --type group
$MPI tangos $MPIBACKEND write dm_density_profile --with-prerequisites --include-only="NDM()>5000" --type=halo --for tutorial_gadget4_hbtplus
}

build_gadget_subfind() {
get_tutorial_data tutorial_gadget
tangos add tutorial_gadget --min-particles 100
Expand Down Expand Up @@ -128,6 +138,7 @@ clearup_files tutorial_gadget_rockstar
build_ramses
clearup_files tutorial_ramses
build_gadget4
build_gadget4_hbtplus
clearup_files tutorial_gadget4
build_enzo_yt
clearup_files enzo.tinycosmo
28 changes: 28 additions & 0 deletions tests/test_simulation_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,3 +213,31 @@ def test_halo_class_priority():
assert (h.get_index_list(h.ancestor) == np.arange(1000)).all()
h = db.get_halo("test_tipsy/tiny.000640/2").load()
assert (h.get_index_list(h.ancestor) == np.arange(1000, 2000)).all()

def test_input_handler_priority():
handler = pynbody_outputs.ChangaInputHandler.best_matching_handler("test_tipsy")
assert handler is DummyPynbodyHandler


# test that if we specialise further, we get the more specialised handler
class DummyPynbodyHandler2(DummyPynbodyHandler):
pass

handler = pynbody_outputs.ChangaInputHandler.best_matching_handler("test_tipsy")
assert handler is DummyPynbodyHandler2


# specialise still further, but disable autoselect so that we don't get this handler returned
class DummyPynbodyHandler3(DummyPynbodyHandler2):
enable_autoselect = False
pass

handler = pynbody_outputs.ChangaInputHandler.best_matching_handler("test_tipsy")
assert handler is DummyPynbodyHandler2

handler = DummyPynbodyHandler2.best_matching_handler("test_tipsy")
assert handler is DummyPynbodyHandler2

# if we select handler 3 manually, we should get it
handler = DummyPynbodyHandler3.best_matching_handler("test_tipsy")
assert handler is DummyPynbodyHandler3

0 comments on commit 6ef63d8

Please sign in to comment.