Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rockstar+RAMSES functionality #238

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 179 additions & 2 deletions tangos/input_handlers/yt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import glob
import os
from typing import List, Tuple

import numpy as np

Expand Down Expand Up @@ -46,8 +47,90 @@ def load_object(self, ts_extension, finder_id, finder_offset, object_typetag='ha
def load_tracked_region(self, ts_extension, track_data, mode=None):
    """Tracked particle regions are unsupported by the yt-backed handler."""
    # No yt equivalent of tangos tracked regions exists yet.
    raise NotImplementedError("Tracked regions not implemented for yt")

def match_objects(self, ts1, ts2, halo_min, halo_max, dm_only=False, threshold=0.005, object_typetag='halo'):
raise NotImplementedError("Matching halos still needs to be implemented for yt")
def match_objects(
    self,
    ts1: str,
    ts2: str,
    halo_min: int,
    halo_max: int,
    dm_only: bool = False,
    threshold: float = 0.005,
    object_typetag: str = "halo",
    output_handler_for_ts2=None,
    fuzzy_match_kwa=None,
) -> List[List[Tuple[int, float]]]:
    """Match halos between timesteps *ts1* and *ts2* by particle-ID overlap.

    Parameters
    ----------
    ts1, ts2 : timestep extensions for the source and target catalogues.
    halo_min, halo_max : inclusive range of *tangos* halo indices to consider
        (``None`` means unbounded on that side).
    dm_only : accepted for interface compatibility; not used by this
        implementation.
    threshold : minimum fraction of a ts1 halo's particles that must land in
        a ts2 halo for the link to be kept.
    object_typetag : only "halo" catalogues are meaningful here.
    output_handler_for_ts2 : handler used to load the ts2 catalogue; required.
    fuzzy_match_kwa : fuzzy matching options; not supported (must be empty).

    Returns
    -------
    One list per ts1 halo in range; each list holds ``(ts2_halo_index,
    weight)`` links sorted by decreasing weight, where weight is the shared
    particle fraction.
    """
    if output_handler_for_ts2 is None:
        raise NotImplementedError(
            "Alternative output_handler_for_ts2 is not implemented for yt."
        )
    if fuzzy_match_kwa:
        raise NotImplementedError(
            "Fuzzy matching is not implemented for yt."
        )

    # Interpret missing bounds as "no restriction".
    if halo_min is None:
        halo_min = 0
    if halo_max is None:
        halo_max = np.inf

    h1, _ = self._load_halo_cat(ts1, object_typetag)
    # output_handler_for_ts2 is guaranteed non-None by the check above, so the
    # previously unreachable `self._load_halo_cat(ts2, ...)` branch is removed.
    h2, _ = output_handler_for_ts2._load_halo_cat(ts2, object_typetag)

    ids2 = h2.r["particle_identifier"].astype(int)

    # Collect the particle ids of every selected ts2 halo, plus the (tangos)
    # halo index each particle belongs to.  BUGFIX: both arrays are filtered
    # on the *tangos* enumeration index — previously members2 was filtered on
    # the rockstar identifier value while members2halo2 was filtered on the
    # tangos index, which desynchronised the two arrays whenever rockstar ids
    # are not 0..N-1 in enumeration order.
    member_blocks = []
    block_halo_indices = []
    for itangos, irockstar in enumerate(ids2):
        if not (halo_min <= itangos <= halo_max):
            continue
        member_ids = h2.halo("halos", irockstar).member_ids
        member_blocks.append(member_ids)
        block_halo_indices.append(np.repeat(itangos, len(member_ids)))

    if member_blocks:
        members2 = np.concatenate(member_blocks)
        members2halo2 = np.concatenate(block_halo_indices)
    else:
        # No ts2 halo in range: avoid np.concatenate([]) raising ValueError.
        members2 = np.empty(0, dtype=int)
        members2halo2 = np.empty(0, dtype=int)

    # Compute the intersection of each ts1 halo with all selected ts2 halos.
    cat: List[List[Tuple[int, float]]] = []
    for ihalo1_tangos, ihalo1_rockstar in enumerate(h1.r["particle_identifier"].astype(int)):
        if not (halo_min <= ihalo1_tangos <= halo_max):
            continue

        ids1 = h1.halo("halos", ihalo1_rockstar).member_ids
        # Flag the ts2 particles that also belong to this ts1 halo.
        # (np.isin supersedes the deprecated np.in1d.)
        shared = np.isin(members2, ids1)
        if not shared.any():
            cat.append([])
            continue

        # ts2 (tangos) halo index of every shared particle.
        idhalo2 = members2halo2[shared]

        # Count shared particles per ts2 halo; weight is the shared fraction
        # of the ts1 halo's particles.
        idhalo2, counts = np.unique(idhalo2, return_counts=True)
        weights = counts / len(ids1)

        # Sort links by decreasing weight.
        order = np.argsort(weights)[::-1]
        idhalo2 = idhalo2[order]
        weights = weights[order]

        # Keep only significant links.
        keep = weights > threshold
        if not keep.any():
            cat.append([])
            continue

        cat.append(list(zip(idhalo2[keep], weights[keep])))

    return cat

def enumerate_objects(self, ts_extension, object_typetag="halo", min_halo_particles=config.min_halo_particles):
if object_typetag!="halo":
Expand Down Expand Up @@ -85,6 +168,100 @@ def _load_halo_cat_without_caching(self, ts_extension, snapshot_file):
def get_properties(self):
    """Return simulation-level properties; the yt handler extracts none."""
    return {}

def available_object_property_names_for_timestep(self, ts_extension, object_typetag):
    """Return the names of all 'halos'-type fields in the halo catalogue."""
    catalogue, _ = self._load_halo_cat(ts_extension, object_typetag)
    names = []
    for field_type, field_name in catalogue.field_list:
        # Only catalogue-level ("halos") fields are tangos properties here.
        if field_type == "halos":
            names.append(field_name)
    return names


def iterate_object_properties_for_timestep(self, ts_extension, object_typetag, property_names):
    """Yield one row of properties per halo for the given timestep.

    Each yielded tuple is ``(finder_id, finder_offset, prop1, prop2, ...)``
    with one property value per name in *property_names*.

    Tries the superclass implementation first and falls back to reading the
    catalogue directly through yt when that raises OSError (presumably a
    missing statistics file — confirm against the base handler).
    """
    try:
        # Prefer the superclass route when it is available.
        yield from super().iterate_object_properties_for_timestep(ts_extension, object_typetag, property_names)
        return
    except OSError:
        # NOTE(review): if the superclass generator raises OSError *after*
        # yielding some rows, those rows would be duplicated by the yt
        # fallback below — this assumes the OSError occurs before any yield.
        pass
    h, ad = self._load_halo_cat(ts_extension, object_typetag)

    # Qualify each requested property name with the "halos" field type for yt.
    props_with_ftype = [
        ("halos", name) for name in property_names
    ]

    # Force yt to read all requested fields in a single pass.
    ad.get_data(props_with_ftype)

    # Both finder_id and finder_offset are the halo's catalogue position
    # (0..Nhalo-1) in this fallback path.
    Nhalo = len(ad["halos", "particle_identifier"])
    yield from zip(range(Nhalo), range(Nhalo), *(
        ad[_] for _ in props_with_ftype
    ))


class YtRamsesRockstarInputHandler(YtInputHandler):
    """Handler for RAMSES simulation outputs with Rockstar halo catalogues, read via yt."""

    patterns = ["output_0????"]
    auxiliary_file_patterns = ["halos_*.bin"]

    def load_timestep_without_caching(self, ts_extension, mode=None):
        """Load the RAMSES snapshot for *ts_extension* with yt; custom modes are unsupported."""
        if mode is not None:
            raise ValueError("Custom load modes are not supported with yt")
        return yt.load(self._extension_to_filename(ts_extension))

    def _load_halo_cat_without_caching(self, ts_extension, snapshot_file):
        """Load the Rockstar binary catalogue matching *ts_extension*.

        Returns a ``(catalogue_dataset, catalogue_data)`` tuple, where
        catalogue_data is the dataset's all_data() container.
        """
        # Check whether datasets.txt exists (i.e., if rockstar was run with yt)
        if os.path.exists(self._extension_to_filename("datasets.txt")):
            fnum = read_datasets(self._extension_to_filename(""), ts_extension)
        else:  # otherwise, assume a one-to-one correspondence
            overdir = self._extension_to_filename("")
            # Glob snapshot directories whose names have the same prefix and
            # length as this timestep's top-level directory (e.g. output_?????).
            snapfiles = glob.glob(overdir + ts_extension[:2] + len(ts_extension[2:].split('/')[0]) * '?')
            rockfiles = glob.glob(overdir + "out_*.list")
            # Sort the rockstar outputs numerically by the index embedded in
            # their filenames ("out_<n>.list"); plain string sort would not do.
            sortind = np.array([int(rname.split('.')[0].split('_')[-1]) for rname in rockfiles])
            sortord = np.argsort(sortind)
            snapfiles.sort()
            rockfiles = np.array(rockfiles)[sortord]
            # Pair the requested snapshot with the rockstar output of the same
            # rank in the two sorted lists (the assumed one-to-one mapping).
            timestep_ind = np.argwhere(np.array([s.split('/')[-1] for s in snapfiles]) == ts_extension.split('/')[0])[0]
            fnum = int(rockfiles[timestep_ind][0].split('.')[0].split('_')[-1])
        cat = yt.load(self._extension_to_filename(f"halos_{fnum}.0.bin"))
        cat_data = cat.all_data()
        # Check whether rockstar was run with Behroozi's distribution or Wise's:
        # negative particle identifiers indicate the alternate binary format.
        if np.any(cat_data["halos", "particle_identifier"] < 0):
            # Reload from scratch so format_revision takes effect before any
            # data is cached on the dataset — TODO confirm the reload is
            # required by yt's RockstarDataset.
            cat = yt.load(self._extension_to_filename(f"halos_{fnum}.0.bin"))
            cat.parameters['format_revision'] = 2
            cat_data = cat.all_data()
        return cat, cat_data

    def enumerate_objects(self, ts_extension, object_typetag="halo", min_halo_particles=config.min_halo_particles):
        """Yield ``(tangos_index, finder_id, NDM, NStar, NGas)`` per sufficiently large halo."""
        if object_typetag != "halo":
            return
        if self._can_enumerate_objects_from_statfile(ts_extension, object_typetag):
            yield from self._enumerate_objects_from_statfile(ts_extension, object_typetag)
        else:
            # logger.warning replaces the deprecated logger.warn alias.
            logger.warning("No halo statistics file found for timestep %r", ts_extension)
            logger.warning(" => enumerating %ss directly using yt", object_typetag)

            _catalogue, catalogue_data = self._load_halo_cat(ts_extension, object_typetag)
            num_objects = len(catalogue_data["halos", "virial_radius"])

            # Make sure this isn't garbage collected while we iterate
            _f = self.load_timestep(ts_extension)

            finder_ids = catalogue_data["halos", "particle_identifier"]
            for i in range(num_objects):
                finder_id = int(finder_ids[i])
                obj = self.load_object(
                    ts_extension,
                    finder_id,
                    i,
                    object_typetag
                )
                NDM = len(obj["DM", "particle_ones"])
                NGas = 0  # RAMSES gas is cell-based, so no gas particle count is available
                NStar = len(obj["star", "particle_ones"])
                if NDM + NGas + NStar > min_halo_particles:
                    yield i, finder_id, NDM, NStar, NGas

    def load_object(self, ts_extension, finder_id, finder_offset, object_typetag='halo', mode=None):
        """Return a yt sphere of virial radius centred on rockstar halo *finder_id*."""
        f = self.load_timestep(ts_extension, mode)
        cat, cat_dat = self._load_halo_cat(ts_extension, object_typetag)
        index = np.argwhere(cat_dat["halos", "particle_identifier"] == finder_id)[0, 0]
        center = cat_dat["halos", "particle_position"][index]
        # Shift from the catalogue's coordinate origin to the snapshot's.
        center += f.domain_left_edge - cat.domain_left_edge
        radius = cat_dat["halos", "virial_radius"][index]
        return f.sphere(center, radius)


class YtChangaAHFInputHandler(YtInputHandler):
patterns = ["*.00???", "*.00????"]
Expand Down
5 changes: 2 additions & 3 deletions tangos/util/read_datasets_file.py
Original file line number Diff line number Diff line change
def read_datasets(basedir, filename):
    """Look up the Rockstar output number recorded for *filename* in datasets.txt.

    Each line of ``<basedir>/datasets.txt`` is expected to hold a dataset path
    followed by an integer output number; the first line whose path contains
    *filename* as a substring wins.

    Raises AssertionError when datasets.txt is absent or holds no matching
    entry (AssertionError is kept for backward compatibility with callers).
    """
    datasets_path = os.path.join(basedir, "datasets.txt")
    if not os.path.exists(datasets_path):
        raise AssertionError("Unable to open datasets.txt")
    with open(datasets_path) as f:
        for line in f:
            columns = line.split()
            # Guard against blank lines, which would make columns[0] raise.
            if columns and filename in columns[0]:
                return int(columns[1])
    raise AssertionError(f"datasets.txt does not contain an entry matching {filename!r}")
Loading