Skip to content

Commit

Permalink
Merge pull request #254 from pynbody/kdtree-update
Browse files Browse the repository at this point in the history
Updates for latest pynbody beta
  • Loading branch information
apontzen authored May 6, 2024
2 parents 5e971b7 + 39385b1 commit 9fab28a
Show file tree
Hide file tree
Showing 15 changed files with 118 additions and 45 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/integration-test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
CXX: g++-10
steps:
- name: Install Python
uses: actions/setup-python@v2
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

Expand All @@ -33,7 +33,7 @@ jobs:
sudo apt-get update -qq
sudo apt install gcc-10 g++-10
- uses: actions/checkout@v2
- uses: actions/checkout@v4

- name: Update python pip/setuptools/wheel
run: |
Expand All @@ -54,15 +54,15 @@ jobs:
working-directory: test_tutorial_build
run: export INTEGRATION_TESTING=1; bash build.sh

- uses: actions/upload-artifact@v2
- uses: actions/upload-artifact@v4
with:
name: Tangos database
path: test_tutorial_build/data.db

- name: Verify database
working-directory: test_tutorial_build
run: |
wget https://zenodo.org/record/10825178/files/reference_database.db?download=1 -O reference_database.db -nv
wget https://zenodo.org/record/11122073/files/reference_database.db?download=1 -O reference_database.db -nv
tangos diff data.db reference_database.db --property-tolerance dm_density_profile 1e-2 0 --property-tolerance gas_map 1e-2 0 --property-tolerance gas_map_sideon 1e-2 0 --property-tolerance gas_map_faceon 1e-2 0
# --property-tolerance dm_density_profile here is because if a single particle crosses between bins
# (which seems to happen due to differing library versions), the profile can change by this much
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,14 @@
'matplotlib >= 3.0.0', # for web interface
'tqdm >= 4.59.0',
'tblib >= 3.0.0',
'packaging >= 22.0'
]

tests_require = [
'pytest >= 5.0.0',
'webtest >= 2.0',
'pyquery >= 1.3.0',
'pynbody >= 2.0.0-beta.5',
'pynbody >= 2.0.0-beta.8',
'yt>=3.4.0',
'PyMySQL>=1.0.2',
]
Expand Down
6 changes: 2 additions & 4 deletions tangos/input_handlers/output_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,10 @@ def match_objects(self, ts1, ts2, halo_min, halo_max, dm_only=False, threshold=0
if halo_max is None:
halo_max = f1.max_halos
halo_max = min((halo_max,f1.max_halos,f2.max_halos))
return_matches = [tuple()]
return_matches = {}
for i in range(1,halo_max+1):
if i>=halo_min:
return_matches.append(((i, 1.0),(i+1,0.05),))
else:
return_matches.append(tuple())
return_matches[i] = (((i, 1.0),(i+1,0.05),))
return return_matches

def load_timestep_without_caching(self, ts_extension, mode=None):
Expand Down
38 changes: 21 additions & 17 deletions tangos/input_handlers/pynbody.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from collections import defaultdict

import numpy as np
from packaging.version import Version

from ..util import proxy_object

Expand Down Expand Up @@ -34,9 +35,9 @@ class PynbodyInputHandler(finding.PatternBasedFileDiscovery, HandlerBase):
def __new__(cls, *args, **kwargs):
import pynbody as pynbody_local

min_version = "2.0.0-beta.5"
min_version = "2.0.0-beta.8"

if pynbody_local.__version__ < min_version:
if Version(pynbody_local.__version__) < Version(min_version):
raise ImportError(f"Using tangos with pynbody requires pynbody {min_version} or later")

global pynbody
Expand All @@ -46,11 +47,10 @@ def __new__(cls, *args, **kwargs):

@classmethod
def _construct_pynbody_halos(cls, sim, *args, **kwargs):
if cls.pynbody_halo_class_name is None:
return sim.halos(*args, **kwargs)
else:
halo_class = getattr(pynbody.halo, cls.pynbody_halo_class_name)
return halo_class(sim, *args, **kwargs)
if cls.pynbody_halo_class_name is not None:
kwargs['priority'] = [cls.pynbody_halo_class_name]

return sim.halos(*args, **kwargs)

def _is_able_to_load(self, ts_extension):
filepath = self._extension_to_filename(ts_extension)
Expand Down Expand Up @@ -247,16 +247,22 @@ def match_objects(self, ts1, ts2, halo_min, halo_max,
if halo_max is None:
halo_max = max(len(h2), len(h1))

return self.create_bridge(f1, f2).fuzzy_match_catalog(
halo_min,
halo_max,
threshold=threshold,
only_family=only_family,
groups_1=h1,
groups_2=h2,
matches = self.create_bridge(f1, f2).fuzzy_match_halos(
h1, h2, threshold=threshold, use_family=only_family,
**fuzzy_match_kwa,
)

del_keys = []
for k in matches:
if k < halo_min or k > halo_max:
del_keys.append(k)

for k in del_keys:
del matches[k]

return matches


@classmethod
def create_bridge(cls, f1, f2):
return f1.bridge(f2)
Expand Down Expand Up @@ -493,7 +499,6 @@ def create_bridge(self, f1, f2):

def match_objects(self, ts1, ts2, halo_min, halo_max, dm_only=True, threshold=0.005,
object_typetag="halo", output_handler_for_ts2=None):
import pynbody
if not dm_only:
logger.warning(
"`match_objects` was called with dm_only=%s, but %s only supports DM-only"
Expand All @@ -509,8 +514,7 @@ def match_objects(self, ts1, ts2, halo_min, halo_max, dm_only=True, threshold=0.
dm_only=dm_only,
threshold=threshold,
object_typetag=object_typetag,
output_handler_for_ts2=output_handler_for_ts2,
fuzzy_match_kwa={"use_family": pynbody.family.dm}
output_handler_for_ts2=output_handler_for_ts2
)


Expand Down
9 changes: 7 additions & 2 deletions tangos/parallel_tasks/backends/multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

import tblib.pickling_support

from ...log import logger

_slave = False
_rank = None
_size = None
Expand Down Expand Up @@ -188,9 +190,12 @@ def launch_functions(functions, args, capture_log=False):

for proc_i in processes:
if error:
#print "multiprocessing backend: send signal to",proc_i.pid
os.kill(proc_i.pid, signal.SIGTERM)
proc_i.join()
proc_i.join(timeout=1.0)
if proc_i.is_alive():
logger.warn("Process %d did not terminate in a timely way; sending SIGKILL", proc_i.pid)
os.kill(proc_i.pid, signal.SIGKILL)
proc_i.join()

if error is not None:
raise error.with_traceback(traceback)
Expand Down
11 changes: 6 additions & 5 deletions tangos/parallel_tasks/pynbody_server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,22 +67,23 @@ def process_async(self):
log.logger.debug("Tree built after %.2fs", time.time()-start)

class ReturnSharedTree(Message):
def __init__(self, leafsize, boxsize, kdnodes, offsets):
def __init__(self, leafsize, boxsize, kdnodes, offsets, kernel_id):
super().__init__()
self.leafsize = leafsize
self.boxsize = boxsize
self.kdnodes = kdnodes
self.offsets = offsets
self.kernel_id = kernel_id

def serialize(self):
return self.leafsize, self.boxsize
return self.leafsize, self.boxsize, self.kernel_id

@classmethod
def deserialize(cls, source, message):
leafsize, boxsize = message
leafsize, boxsize, kernel_id = message
kdnodes = transfer_array.receive_array(source, use_shared_memory=True)
offsets = transfer_array.receive_array(source, use_shared_memory=True)
obj = cls(leafsize, boxsize, kdnodes, offsets)
obj = cls(leafsize, boxsize, kdnodes, offsets, kernel_id)
obj.source = source
return obj

Expand All @@ -92,7 +93,7 @@ def send(self, destination):
transfer_array.send_array(self.offsets, destination, use_shared_memory=True)

def import_tree_into_local_view(self, sim):
sim.import_tree((self.leafsize, self.boxsize, self.kdnodes, self.offsets))
sim.import_tree((self.leafsize, self.boxsize, self.kdnodes, self.offsets, self.kernel_id))


class GetSharedTree(AsyncProcessedMessage):
Expand Down
2 changes: 2 additions & 0 deletions tangos/parallel_tasks/pynbody_server/transfer_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@ def send_array(array: pynbody.array.SimArray, destination: int, use_shared_memor
if not hasattr(array, "_shared_fname"):
if isinstance(array, np.ndarray) and hasattr(array, "base") and hasattr(array.base, "_shared_fname"):
array._shared_fname = array.base._shared_fname # the strides/offset will point into the same memory
array._shared_owner = False # otherwise the memory will be deleted
else:
raise ValueError("Array %r has no shared memory information" % array)
_send_array_shared_memory(array, destination)

else:
_send_array_copy(array, destination)

Expand Down
3 changes: 2 additions & 1 deletion tangos/properties/pynbody/centring.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ def _get_centre_and_max_radius(self, particle_data):
# ensure the box is wrapped correctly by centring on one of the particles:
temporary_centre = np.array(particle_data['pos'][0])
with _recenter(particle_data, temporary_centre):
center = pynbody.analysis.halo.shrink_sphere_center(particle_data, shrink_factor=0.8, velocity=False)
center = pynbody.analysis.halo.shrink_sphere_center(particle_data, shrink_factor=0.8,
particles_for_velocity=0) # i.e., don't calc velocity

# mark_timer can be used to track timing of parts of the calculation. The results of these timings
# appears in the tangos_writer logs:
Expand Down
10 changes: 5 additions & 5 deletions tangos/properties/pynbody/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ def calculate(self, particle_data, properties):
import pynbody.analysis.angmom as angmom
size = self.plot_extent()
g, s = self._render_gas(particle_data, size), self._render_stars(particle_data, size)
with angmom.sideon(particle_data, return_transform=True,
cen_size=self.get_simulation_property("approx_resolution_kpc", 0.1)*10.):
with angmom.sideon(particle_data):
g_side, s_side = self._render_gas(particle_data, size), self._render_stars(particle_data, size)
with particle_data.rotate_x(90):
g_face, s_face = self._render_gas(particle_data, size), self._render_stars(particle_data, size)
Expand All @@ -33,8 +32,8 @@ def calculate(self, particle_data, properties):

def _render_projected(self, f, size):
import pynbody.plot
im = pynbody.plot.sph.image(f[pynbody.filt.BandPass(
'z', -size / 2, size / 2)], 'rho', size, units="Msol kpc^-2", noplot=True)
im = pynbody.plot.sph.image(f, 'rho', size, units="Msol kpc^-2", noplot=True, restrict_depth=True,
resolution=500)
return im

def _render_gas(self, f, size):
Expand All @@ -47,6 +46,7 @@ def _render_stars(self, f, size):
import pynbody.plot
if len(f.st)>0:
return pynbody.plot.stars.render(f.st[pynbody.filt.HighPass('tform',0) & pynbody.filt.BandPass('z', -size / 2, size / 2)],
width=size, plot=False, ret_im=True, mag_range=(16,22))
width=size, noplot=True, return_image=True, mag_range=(16,22),
resolution=500)
else:
return None
1 change: 1 addition & 0 deletions tangos/scripts/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ def diff(options):
differ = db_diff.TangosDbDiff(options.uri1, options.uri2, ignore_keys=options.ignore_value_of)
if options.property_tolerance is not None:
for k, rtol, atol in options.property_tolerance:
if k == '.': k = None
differ.set_tolerance(k, float(rtol), float(atol))

if options.simulation:
Expand Down
19 changes: 19 additions & 0 deletions tangos/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,25 @@ def init_blank_db_for_testing(**init_kwargs):

return db_is_blank

@contextlib.contextmanager
def blank_db_for_testing(**kwargs):
"""Context manager to create a blank database, then on exit restores the previous database.
For arguments, see init_blank_db_for_testing.
"""

old_engine = core.get_default_engine()
old_session = core.get_default_session()
old_session_class = core.Session
core._internal_session = None
core._engine = None
init_blank_db_for_testing(**kwargs)
yield
core.close_db()
core._engine = old_engine
core.Session = old_session_class
core.set_default_session(old_session)

def using_parallel_tasks(fn_or_num_processes, num_processes = 2):
"""Decorator for tests, using parallel_tasks multiprocessing backend to launch
Expand Down
2 changes: 1 addition & 1 deletion tangos/tools/crosslink.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def need_crosslink_ts(self, ts1, ts2, object_typecode=0):
def create_db_objects_from_catalog(self, cat, finder_id_to_halos_1, finder_id_to_halos_2, same_d_id):
items = []
missing_db_object = 0
for i, possibilities in enumerate(cat):
for i, possibilities in cat.items():
h1 = finder_id_to_halos_1.get(i, None)
for cat_i, weight in possibilities:
h2 = finder_id_to_halos_2.get(cat_i, None)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_pynbody_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_get_array():
@using_parallel_tasks(3)
def test_get_shared_array():
if pt.backend.rank()==1:
shared_array = pynbody.array._array_factory((10,), int, True, True)
shared_array = pynbody.array.array_factory((10,), int, True, True)
shared_array[:] = np.arange(0,10)
pt.pynbody_server.transfer_array.send_array(shared_array, 2, True)
assert shared_array[2]==2
Expand All @@ -74,7 +74,7 @@ def test_get_shared_array():
def test_get_shared_array_slice():
"""Like test_get_shared_array, but with a slice"""
if pt.backend.rank()==1:
shared_array = pynbody.array._array_factory((10,), int, True, True)
shared_array = pynbody.array.array_factory((10,), int, True, True)
shared_array[:] = np.arange(0,10)
pt.pynbody_server.transfer_array.send_array(shared_array[1:7:2], 2, True)
assert shared_array[3] == 3
Expand Down
36 changes: 34 additions & 2 deletions tests/test_simulation_outputs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import gc
import os

import numpy as np
import numpy.testing as npt
import pynbody

Expand Down Expand Up @@ -41,14 +42,14 @@ def test_handler_properties_quicker_flag():
output_manager.quicker = True
prop = output_manager.get_properties()
npt.assert_allclose(prop['approx_resolution_kpc'], 33.590757, rtol=1e-5)
npt.assert_allclose(prop['approx_resolution_Msol'], 2.412033e+10, rtol=1e-5)
npt.assert_allclose(prop['approx_resolution_Msol'], 2.412033e+10, rtol=1e-4)

def test_enumerate():
assert set(output_manager.enumerate_timestep_extensions())=={"tiny.000640","tiny.000832"}

def test_timestep_properties():
props = output_manager.get_timestep_properties("tiny.000640")
npt.assert_allclose(props['time_gyr'],2.17328504831)
npt.assert_allclose(props['time_gyr'],2.173236752357068)
npt.assert_allclose(props['redshift'], 2.96382819878)

def test_enumerate_objects():
Expand Down Expand Up @@ -181,3 +182,34 @@ def test_load_region_uses_cache():

assert id(region1a) == id(region1b)
assert id(region1a) != id(region2)


class DummyHaloClass(pynbody.halo.number_array.HaloNumberCatalogue):
def __init__(self, sim):
sim['grp'] = np.empty(len(sim), dtype=int)
sim['grp'].fill(-1)
sim['grp'][:1000] = 0
sim['grp'][1000:2000] = 1
super().__init__(sim, 'grp', ignore=-1)

@classmethod
def _can_load(cls, sim, arr_name='grp'):
return True


class DummyPynbodyHandler(pynbody_outputs.ChangaInputHandler):
pynbody_halo_class_name = "DummyHaloClass"

def _can_enumerate_objects_from_statfile(self, ts_extension, object_typetag):
return False # test requires enumerating halos via pynbody, to verify right halo class is used

def test_halo_class_priority():
with testing.blank_db_for_testing(testing_db_name="test_halo_class_priority", erase_if_exists=True):
handler = DummyPynbodyHandler("test_tipsy")

with log.LogCapturer():
add.SimulationAdderUpdater(handler).scan_simulation_and_add_all_descendants()
h = db.get_halo("test_tipsy/tiny.000640/1").load()
assert (h.get_index_list(h.ancestor) == np.arange(1000)).all()
h = db.get_halo("test_tipsy/tiny.000640/2").load()
assert (h.get_index_list(h.ancestor) == np.arange(1000, 2000)).all()
Loading

0 comments on commit 9fab28a

Please sign in to comment.