diff --git a/setup.py b/setup.py index 2ef91b85..e3ae92bc 100644 --- a/setup.py +++ b/setup.py @@ -22,14 +22,15 @@ 'scipy >= 0.14.0', 'more_itertools >= 8.0.0', 'matplotlib >= 3.0.0', # for web interface - 'tqdm >= 4.59.0' + 'tqdm >= 4.59.0', + 'tblib >= 3.0.0', ] tests_require = [ 'pytest >= 5.0.0', 'webtest >= 2.0', 'pyquery >= 1.3.0', - 'pynbody >= 1.3.2', + 'pynbody >= 1.5.0', 'yt>=3.4.0', 'PyMySQL>=1.0.2', ] diff --git a/tangos/__init__.py b/tangos/__init__.py index 7c36d019..e6b862f2 100644 --- a/tangos/__init__.py +++ b/tangos/__init__.py @@ -20,4 +20,4 @@ from .core import * from .query import * -__version__ = '1.8.1' +__version__ = '1.9.0' diff --git a/tangos/config.py b/tangos/config.py index de8eed3a..ad79dc85 100644 --- a/tangos/config.py +++ b/tangos/config.py @@ -88,9 +88,14 @@ diff_default_rtol = 1e-3 -# Database import: how many rows to copy at a time, and when to issue a commit +# Database import: how many rows to copy at a time DB_IMPORT_CHUNK_SIZE = 10 -DB_IMPORT_COMMIT_AFTER_CHUNKS = 500 + +# Property writer: maximum time to wait before committing properties (even if in the middle of a timestep) +PROPERTY_WRITER_MAXIMUM_TIME_BETWEEN_COMMITS = 600 # seconds + +# Property writer: don't commit at the end of a timestep unless at least this much time has elapsed since the last commit: +PROPERTY_WRITER_MINIMUM_TIME_BETWEEN_COMMITS = 300 # seconds try: from .config_local import * diff --git a/tangos/input_handlers/pynbody.py b/tangos/input_handlers/pynbody.py index 98092765..6f20782b 100644 --- a/tangos/input_handlers/pynbody.py +++ b/tangos/input_handlers/pynbody.py @@ -4,10 +4,8 @@ import time import weakref from collections import defaultdict -from itertools import chain import numpy as np -from more_itertools import always_iterable from ..util import proxy_object @@ -36,8 +34,8 @@ class PynbodyInputHandler(finding.PatternBasedFileDiscovery, HandlerBase): def __new__(cls, *args, **kwargs): import pynbody as pynbody_local - if pynbody_local.__version__<"1.2.2": - raise ImportError("Using tangos with pynbody requires pynbody 1.2.2 or later") + if pynbody_local.__version__<"1.5.0": - raise ImportError("Using tangos with pynbody requires pynbody 1.5.0 or later") global pynbody pynbody = pynbody_local @@ -82,19 +80,22 @@ def load_timestep_without_caching(self, ts_extension, mode=None): f = pynbody.load(self._extension_to_filename(ts_extension)) f.physical_units() return f - elif mode=='server' or mode=='server-partial': + elif mode in ('server', 'server-partial', 'server-shared-mem'): from ..parallel_tasks import pynbody_server as ps - return ps.RemoteSnapshotConnection(self,ts_extension) + return ps.RemoteSnapshotConnection(self, ts_extension, + shared_mem = (mode == 'server-shared-mem')) else: raise NotImplementedError("Load mode %r is not implemented"%mode) def load_region(self, ts_extension, region_specification, mode=None): - if mode is None: + if mode is None or mode=='server': timestep = self.load_timestep(ts_extension, mode) return timestep[region_specification] - elif mode=='server': + elif mode=='server-shared-mem': + from ..parallel_tasks import pynbody_server as ps timestep = self.load_timestep(ts_extension, mode) - return timestep.get_view(region_specification) + simsnap = timestep.shared_mem_view + return simsnap[region_specification].get_copy_on_access_simsnap() elif mode=='server-partial': timestep = self.load_timestep(ts_extension, mode) view = timestep.get_view(region_specification) @@ -114,19 +115,29 @@ def load_object(self, ts_extension, finder_id, finder_offset, object_typetag='ha 
h_file = h.load_copy(finder_offset) h_file.physical_units() return h_file - elif mode=='server': + elif mode=='server' : timestep = self.load_timestep(ts_extension, mode) from ..parallel_tasks import pynbody_server as ps - return timestep.get_view(ps.ObjectSpecification(finder_id, finder_offset, object_typetag)) + return timestep.get_view( + ps.snapshot_queue.ObjectSpecification(finder_id, finder_offset, object_typetag)) elif mode=='server-partial': timestep = self.load_timestep(ts_extension, mode) from ..parallel_tasks import pynbody_server as ps - view = timestep.get_view(ps.ObjectSpecification(finder_id, finder_offset, object_typetag)) + view = timestep.get_view( + ps.snapshot_queue.ObjectSpecification(finder_id, finder_offset, object_typetag)) load_index = view['remote-index-list'] logger.info("Partial load %r, taking %d particles", ts_extension, len(load_index)) f = pynbody.load(self._extension_to_filename(ts_extension), take=load_index) f.physical_units() return f + elif mode=='server-shared-mem': + timestep = self.load_timestep(ts_extension, mode) + from ..parallel_tasks import pynbody_server as ps + view = timestep.get_view( + ps.snapshot_queue.ObjectSpecification(finder_id, finder_offset, object_typetag)) + view_index = view['remote-index-list'] + return timestep.shared_mem_view[view_index].get_copy_on_access_simsnap() + elif mode is None: h = self._construct_halo_cat(ts_extension, object_typetag) return h[finder_offset] diff --git a/tangos/parallel_tasks/__init__.py b/tangos/parallel_tasks/__init__.py index 16b14c87..4e2672b6 100644 --- a/tangos/parallel_tasks/__init__.py +++ b/tangos/parallel_tasks/__init__.py @@ -139,12 +139,12 @@ def _server_thread(): obj.process() - log.logger.info("Terminating manager") + log.logger.info("Terminating manager process") def _shutdown_parallelism(): global backend - log.logger.info("Shutting down parallel_tasks") + log.logger.info("Terminating worker process") backend.barrier() backend.finalize() backend = None diff --git a/tangos/parallel_tasks/backends/multiprocessing.py b/tangos/parallel_tasks/backends/multiprocessing.py index c01a60e3..cc4c2798 100644 --- a/tangos/parallel_tasks/backends/multiprocessing.py +++ b/tangos/parallel_tasks/backends/multiprocessing.py @@ -3,6 +3,9 @@ import signal import sys import threading +from typing import Optional + +import tblib.pickling_support _slave = False _rank = None @@ -86,6 +89,8 @@ def finalize(): _pipe.send("finalize") def launch_wrapper(target_fn, rank_in, size_in, pipe_in, args_in): + tblib.pickling_support.install() + global _slave, _rank, _size, _pipe, _recv_lock _rank = rank_in _size = size_in @@ -103,10 +108,12 @@ def launch_wrapper(target_fn, rank_in, size_in, pipe_in, args_in): print("Error on a sub-process:", file=sys.stderr) traceback.print_exception(exc_type, exc_value, exc_traceback, file=sys.stderr) - _pipe.send(("error", e)) + _pipe.send(("error", exc_value, exc_traceback)) _pipe.close() +class RemoteException(Exception): + pass def launch_functions(functions, args): global _slave @@ -125,7 +132,7 @@ def launch_functions(functions, args): proc_i.start() running = [True for rank in range(num_procs)] - error = False + error: Optional[Exception] = None while any(running): for i, pipe_i in enumerate(parent_connections): @@ -136,6 +143,7 @@ def launch_functions(functions, args): running[i]=False elif isinstance(message[0], str) and message[0]=='error': error = message[1] + traceback = message[2] running = [False] break else: @@ -153,8 +161,8 @@ def launch_functions(functions, args): 
os.kill(proc_i.pid, signal.SIGTERM) proc_i.join() - if error: - raise error + if error is not None: + raise error.with_traceback(traceback) diff --git a/tangos/parallel_tasks/jobs.py b/tangos/parallel_tasks/jobs.py index effb5558..0cacfdaa 100644 --- a/tangos/parallel_tasks/jobs.py +++ b/tangos/parallel_tasks/jobs.py @@ -32,11 +32,11 @@ def process(self): global j, num_jobs, current_job source = self.source if current_job is not None and num_jobs>0: - log.logger.info("Send job %d of %d to node %d", current_job, num_jobs, source) + log.logger.debug("Send job %d of %d to node %d", current_job, num_jobs, source) else: num_jobs = None current_job = None # in case num_jobs=0, still want to send 'end of loop' signal to client - log.logger.info("Finished jobs; notify node %d", source) + log.logger.debug("Finished jobs; notify node %d", source) MessageDeliverJob(current_job).send(source) diff --git a/tangos/parallel_tasks/pynbody_server.py b/tangos/parallel_tasks/pynbody_server.py deleted file mode 100644 index ecc42232..00000000 --- a/tangos/parallel_tasks/pynbody_server.py +++ /dev/null @@ -1,381 +0,0 @@ -import gc -import pickle -import time - -import numpy as np -import pynbody - -from ..util.check_deleted import check_deleted -from . import log, parallelism_is_active, remote_import -from .message import ExceptionMessage, Message - - -class ConfirmLoadPynbodySnapshot(Message): - pass - -class ObjectSpecification: - def __init__(self, object_number, object_index, object_typetag='halo'): - self.object_number = object_number - self.object_index = object_index - self.object_typetag = object_typetag - - def __repr__(self): - return "ObjectSpecification(%d, %d, %r)"%(self.object_number, self.object_index, self.object_typetag) - - def __eq__(self, other): - if not isinstance(other, ObjectSpecification): - return False - return self.object_number==other.object_number and self.object_typetag==other.object_typetag - - def __hash__(self): - return hash((self.object_number, self.object_index, self.object_typetag)) - -class PynbodySnapshotQueue: - def __init__(self): - self.timestep_queue = [] - self.handler_queue = [] - self.load_requester_queue = [] - self.current_timestep = None - self.current_snapshot = None - self.current_subsnap_cache = {} - self.current_handler = None - self.in_use_by = [] - - def add(self, handler, filename, requester): - log.logger.debug("Pynbody server: client %d requests access to %r", requester, filename) - if filename==self.current_timestep: - self._notify_available(requester) - self.in_use_by.append(requester) - elif filename in self.timestep_queue: - queue_position = self.timestep_queue.index(filename) - self.load_requester_queue[queue_position].append(requester) - assert self.handler_queue[queue_position] == handler - else: - self.timestep_queue.append(filename) - self.handler_queue.append(handler) - self.load_requester_queue.append([requester]) - self._load_next_if_free() - - def free(self, requester): - self.in_use_by.remove(requester) - log.logger.debug("Pynbody server: client %d is now finished with %r", requester, self.current_timestep) - self._free_if_unused() - self._load_next_if_free() - - def get_subsnap(self, filter_or_object_spec, fam): - if (filter_or_object_spec, fam) in self.current_subsnap_cache: - log.logger.debug("Pynbody server: cache hit for %r (fam %r)",filter_or_object_spec, fam) - return self.current_subsnap_cache[(filter_or_object_spec, fam)] - else: - log.logger.debug("Pynbody server: cache miss for %r (fam %r)",filter_or_object_spec, fam) - subsnap 
= self.get_subsnap_uncached(filter_or_object_spec, fam) - self.current_subsnap_cache[(filter_or_object_spec, fam)] = subsnap - return subsnap - - def get_subsnap_uncached(self, filter_or_object_spec, fam): - - snap = self.current_snapshot - - if isinstance(filter_or_object_spec, pynbody.filt.Filter): - snap = snap[filter_or_object_spec] - elif isinstance(filter_or_object_spec, ObjectSpecification): - snap = self.current_handler.load_object(self.current_timestep, filter_or_object_spec.object_number, - filter_or_object_spec.object_index, - filter_or_object_spec.object_typetag) - else: - raise TypeError("filter_or_object_spec must be either a pynbody filter or an ObjectRequestInformation object") - - if fam is not None: - snap = snap[fam] - - return snap - - - - def _free_if_unused(self): - if len(self.in_use_by)==0: - log.logger.debug("Pynbody server: all clients are finished with the current snapshot; freeing.") - with check_deleted(self.current_snapshot): - self.current_snapshot = None - self.current_timestep = None - self.current_subsnap_cache = {} - self.current_handler = None - - def _notify_available(self, node): - log.logger.debug("Pynbody server: notify %d that snapshot is now available", node) - ConfirmLoadPynbodySnapshot(type(self.current_snapshot)).send(node) - - def _notify_unavailable(self, node): - log.logger.debug("Pynbody server: notify %d that snapshot is unavailable", node) - ConfirmLoadPynbodySnapshot(None).send(node) - - def _load_next_if_free(self): - if len(self.timestep_queue)==0: - return - - if self.current_handler is None: - # TODO: Error handling - self.current_timestep = self.timestep_queue.pop(0) - self.current_handler = self.handler_queue.pop(0) - - try: - self.current_snapshot = self.current_handler.load_timestep(self.current_timestep) - self.current_snapshot.physical_units() - log.logger.info("Pynbody server: loaded %r", self.current_timestep) - success = True - except OSError: - success = False - - notify = self.load_requester_queue.pop(0) - - if success: - self.in_use_by = notify - for n in notify: - self._notify_available(n) - else: - self.current_timestep = None - self.current_handler = None - self.current_snapshot = None - - for n in notify: - self._notify_unavailable(n) - self._load_next_if_free() - - else: - log.logger.info("The currently loaded snapshot is still required and so other clients will have to wait") - log.logger.info("(Currently %d snapshots are in the queue to be loaded later)", len(self.timestep_queue)) - - - -_server_queue = PynbodySnapshotQueue() - -class RequestLoadPynbodySnapshot(Message): - def process(self): - _server_queue.add(self.contents[0], self.contents[1], self.source) - -class ReleasePynbodySnapshot(Message): - def process(self): - _server_queue.free(self.source) - -class ReturnPynbodyArray(Message): - @classmethod - def deserialize(cls, source, message): - from . import backend - contents = backend.receive_numpy_array(source=source) - - if message!="": - contents = contents.view(pynbody.array.SimArray) - contents.units = pickle.loads(message) - - obj = ReturnPynbodyArray(contents) - obj.source = source - - return obj - - def serialize(self): - assert isinstance(self.contents, np.ndarray) - if hasattr(self.contents, 'units'): - serialized_info = pickle.dumps(self.contents.units) - else: - serialized_info = "" - - return serialized_info - - def send(self, destination): - # send envelope - super().send(destination) - - # send contents - from . 
import backend - backend.send_numpy_array(self.contents.view(np.ndarray), destination) - -class RequestPynbodyArray(Message): - def __init__(self, filter_or_object_spec, array, fam=None): - self.filter_or_object_spec = filter_or_object_spec - self.array = array - self.fam = fam - - @classmethod - def deserialize(cls, source, message): - obj = RequestPynbodyArray(*message) - obj.source = source - return obj - - def serialize(self): - return (self.filter_or_object_spec, self.array, self.fam) - - def process(self): - start_time = time.time() - try: - log.logger.debug("Receive request for array %r from %d",self.array,self.source) - subsnap = _server_queue.get_subsnap(self.filter_or_object_spec, self.fam) - - with subsnap.immediate_mode, subsnap.lazy_derive_off: - if subsnap._array_name_implies_ND_slice(self.array): - raise KeyError("Not transferring a single slice %r of a ND array"%self.array) - if self.array=='remote-index-list': - subarray = subsnap.get_index_list(subsnap.ancestor) - else: - subarray = subsnap[self.array] - assert isinstance(subarray, pynbody.array.SimArray) - array_result = ReturnPynbodyArray(subarray) - - except Exception as e: - array_result = ExceptionMessage(e) - - array_result.send(self.source) - del array_result - gc.collect() - log.logger.debug("Array sent after %.2fs"%(time.time()-start_time)) - - - -class ReturnPynbodySubsnapInfo(Message): - def __init__(self, families, sizes, properties, loadable_keys, fam_loadable_keys): - super().__init__() - self.families = families - self.sizes = sizes - self.properties = properties - self.loadable_keys = loadable_keys - self.fam_loadable_keys = fam_loadable_keys - - def serialize(self): - return self.families, self.sizes, self.properties, self.loadable_keys, self.fam_loadable_keys - - @classmethod - def deserialize(cls, source, message): - obj = cls(*message) - obj.source = source - return obj - - - -class RequestPynbodySubsnapInfo(Message): - def __init__(self, filename, filter_): - super().__init__() - self.filename = filename - self.filter_or_object_spec = filter_ - - @classmethod - def deserialize(cls, source, message): - obj = cls(*message) - obj.source = source - return obj - - def serialize(self): - return (self.filename, self.filter_or_object_spec) - - def process(self): - start_time = time.time() - assert(_server_queue.current_timestep == self.filename) - log.logger.debug("Received request for subsnap info, spec %r", self.filter_or_object_spec) - obj = _server_queue.get_subsnap(self.filter_or_object_spec, None) - families = obj.families() - fam_lengths = [len(obj[fam]) for fam in families] - fam_lkeys = [obj.loadable_keys(fam) for fam in families] - lkeys = obj.loadable_keys() - ReturnPynbodySubsnapInfo(families, fam_lengths, obj.properties, lkeys, fam_lkeys).send(self.source) - log.logger.debug("Subsnap info sent after %.2f",(time.time()-start_time)) - -class RemoteSubSnap(pynbody.snapshot.SimSnap): - def __init__(self, connection, filter_or_object_spec): - super().__init__() - - self.connection = connection - self._filename = connection.identity - self._server_id = connection._server_id - - RequestPynbodySubsnapInfo(connection.filename, filter_or_object_spec).send(self._server_id) - info = ReturnPynbodySubsnapInfo.receive(self._server_id) - - index = 0 - for fam, size in zip(info.families, info.sizes): - self._family_slice[fam] = slice(index, index+size) - index+=size - self._num_particles = index - - self.properties.update(info.properties) - self._loadable_keys = info.loadable_keys - self._fam_loadable_keys = 
{fam: lk for fam, lk in zip(info.families, info.fam_loadable_keys)} - self._filter_or_object_spec = filter_or_object_spec - - def _find_deriving_function(self, name): - cl = self.connection.underlying_pynbody_class - if cl in self._derived_quantity_registry \ - and name in self._derived_quantity_registry[cl]: - return self._derived_quantity_registry[cl][name] - else: - return super()._find_deriving_function(name) - - - def _load_array(self, array_name, fam=None): - RequestPynbodyArray(self._filter_or_object_spec, array_name, fam).send(self._server_id) - try: - start_time=time.time() - log.logger.debug("Send array request") - data = ReturnPynbodyArray.receive(self._server_id).contents - log.logger.debug("Array received; waited %.2fs",time.time()-start_time) - except KeyError: - raise OSError("No such array %r available from the remote"%array_name) - with self.auto_propagate_off: - if fam is None: - self[array_name] = data - else: - self[fam][array_name] = data - - -_connection_active = False - -class RemoteSnapshotConnection: - def __init__(self, input_handler, ts_extension, server_id=0): - global _connection_active - - from ..input_handlers import pynbody - assert isinstance(input_handler, pynbody.PynbodyInputHandler) - - if _connection_active: - raise RuntimeError("Each client can only have one remote snapshot connection at any time") - - - - super().__init__() - - self._server_id = server_id - self._input_handler = input_handler - self.filename = ts_extension - self.identity = "%d: %s"%(self._server_id, ts_extension) - self.connected = False - - # ensure server knows what our messages are about - remote_import.ImportRequestMessage(__name__).send(self._server_id) - - log.logger.debug("Pynbody client: attempt to connect to remote snapshot %r", ts_extension) - RequestLoadPynbodySnapshot((input_handler, ts_extension)).send(self._server_id) - self.underlying_pynbody_class = ConfirmLoadPynbodySnapshot.receive(self._server_id).contents - if self.underlying_pynbody_class is None: - raise OSError("Could not load remote snapshot %r"%ts_extension) - - _connection_active = True - self.connected = True - log.logger.info("Pynbody client: connected to remote snapshot %r", ts_extension) - - def get_view(self, filter_or_object_spec): - """Return a RemoteSubSnap that contains either the pynbody filtered region, or the specified object from a catalogue - - filter_or_object_spec is either an instance of pynbody.filt.Filter, or a tuple containing - (typetag, number), which are respectively the object type tag and object number to be loaded - """ - return RemoteSubSnap(self, filter_or_object_spec) - - def disconnect(self): - global _connection_active - - if not self.connected: - return - - ReleasePynbodySnapshot(self.filename).send(self._server_id) - _connection_active = False - self.connected = False - - def __del__(self): - self.disconnect() diff --git a/tangos/parallel_tasks/pynbody_server/__init__.py b/tangos/parallel_tasks/pynbody_server/__init__.py new file mode 100644 index 00000000..41f05b79 --- /dev/null +++ b/tangos/parallel_tasks/pynbody_server/__init__.py @@ -0,0 +1,281 @@ +import gc +import pickle +import time + +import numpy as np +import pynbody +import pynbody.snapshot.copy_on_access + +import tangos.parallel_tasks.pynbody_server.snapshot_queue + +from .. import log, remote_import +from ..message import ExceptionMessage, Message +from . 
import snapshot_queue, transfer_array +from .snapshot_queue import (ConfirmLoadPynbodySnapshot, + ReleasePynbodySnapshot, + RequestLoadPynbodySnapshot, _server_queue) + + +class ReturnPynbodyArray(Message): + + def __init__(self, contents, shared_mem = False): + self.shared_mem = shared_mem + super().__init__(contents) + + @classmethod + def deserialize(cls, source, message): + units, shared_mem = pickle.loads(message) + + contents = transfer_array.receive_array(source, use_shared_memory=shared_mem) + + if units is not None: + if not isinstance(contents, pynbody.array.SimArray): + contents = contents.view(pynbody.array.SimArray) + contents.units = units + + obj = ReturnPynbodyArray(contents, shared_mem=shared_mem) + obj.source = source + + return obj + + def serialize(self): + assert isinstance(self.contents, np.ndarray) + if hasattr(self.contents, 'units'): + units = self.contents.units + else: + units = None + + serialized_info = pickle.dumps((units, self.shared_mem)) + return serialized_info + + def send(self, destination): + # send envelope + super().send(destination) + + # send contents + transfer_array.send_array(self.contents, destination, use_shared_memory=self.shared_mem) + +class RequestPynbodyArray(Message): + def __init__(self, filter_or_object_spec, array, fam=None): + self.filter_or_object_spec = filter_or_object_spec + self.array = array + self.fam = fam + + @classmethod + def deserialize(cls, source, message): + obj = RequestPynbodyArray(*message) + obj.source = source + return obj + + def serialize(self): + return (self.filter_or_object_spec, self.array, self.fam) + + def process(self): + start_time = time.time() + try: + log.logger.debug("Receive request for array %r from %d",self.array,self.source) + subsnap = _server_queue.get_subsnap(self.filter_or_object_spec, self.fam) + transfer_via_shared_mem = _server_queue.current_shared_mem_flag + + with subsnap.immediate_mode, subsnap.lazy_derive_off: + if subsnap._array_name_implies_ND_slice(self.array): + raise KeyError("Not transferring a single slice %r of a ND array"%self.array) + if self.array=='remote-index-list': + subarray = subsnap.get_index_list(subsnap.ancestor).view(pynbody.array.SimArray) + # this won't be actually in shared memory – it's a regular numpy array + transfer_via_shared_mem = False + else: + subarray = subsnap[self.array] + assert isinstance(subarray, pynbody.array.SimArray) + array_result = ReturnPynbodyArray(subarray, transfer_via_shared_mem) + + except Exception as e: + array_result = ExceptionMessage(e) + + array_result.send(self.source) + del array_result + gc.collect() + log.logger.debug("Array sent after %.2fs"%(time.time()-start_time)) + + + +class ReturnPynbodySubsnapInfo(Message): + def __init__(self, families, sizes, properties, loadable_keys, fam_loadable_keys): + super().__init__() + self.families = families + self.sizes = sizes + self.properties = properties + self.loadable_keys = loadable_keys + self.fam_loadable_keys = fam_loadable_keys + + def serialize(self): + return self.families, self.sizes, self.properties, self.loadable_keys, self.fam_loadable_keys + + @classmethod + def deserialize(cls, source, message): + obj = cls(*message) + obj.source = source + return obj + + + +class RequestPynbodySubsnapInfo(Message): + def __init__(self, filename, filter_): + super().__init__() + self.filename = filename + self.filter_or_object_spec = filter_ + + @classmethod + def deserialize(cls, source, message): + obj = cls(*message) + obj.source = source + return obj + + def serialize(self): + 
return (self.filename, self.filter_or_object_spec) + + def process(self): + start_time = time.time() + assert(_server_queue.current_timestep == self.filename) + if self.filter_or_object_spec is not None: + log.logger.debug("Received request for subsnap info, spec %r", self.filter_or_object_spec) + else: + log.logger.debug("Received request for snapshot info") + obj = _server_queue.get_subsnap(self.filter_or_object_spec, None) + families = obj.families() + fam_lengths = [len(obj[fam]) for fam in families] + fam_lkeys = [obj.loadable_keys(fam) for fam in families] + lkeys = obj.loadable_keys() + ReturnPynbodySubsnapInfo(families, fam_lengths, obj.properties, lkeys, fam_lkeys).send(self.source) + log.logger.debug("Info sent after %.2f",(time.time()-start_time)) + + +class RemoteSnap(pynbody.snapshot.copy_on_access.UnderlyingClassMixin, pynbody.snapshot.SimSnap): + def __init__(self, connection, filter_or_object_spec): + """Create a remote snapshot object + + filter_or_object_spec can be: + - a pynbody filter + - a tuple (typetag, number) specifying an object to be loaded + - None to load the whole snapshot (only sensible in shared memory mode) + """ + super().__init__(connection.underlying_pynbody_class) + self.connection = connection + self._filename = connection.identity + self._server_id = connection._server_id + + RequestPynbodySubsnapInfo(connection.filename, filter_or_object_spec).send(self._server_id) + info = ReturnPynbodySubsnapInfo.receive(self._server_id) + + index = 0 + for fam, size in zip(info.families, info.sizes): + self._family_slice[fam] = slice(index, index+size) + index+=size + self._num_particles = index + + self.properties.update(info.properties) + self._loadable_keys = info.loadable_keys + self._fam_loadable_keys = {fam: lk for fam, lk in zip(info.families, info.fam_loadable_keys)} + self._filter_or_object_spec = filter_or_object_spec + + + + + def _load_array(self, array_name, fam=None): + RequestPynbodyArray(self._filter_or_object_spec, array_name, fam).send(self._server_id) + try: + start_time=time.time() + log.logger.debug("Send array request") + data = ReturnPynbodyArray.receive(self._server_id).contents + log.logger.debug("Array received; waited %.2fs",time.time()-start_time) + except KeyError: + raise OSError("No such array %r available from the remote"%array_name) + with self.auto_propagate_off: + if len(data.shape)==1: + ndim = 1 + elif len(data.shape)==2: + ndim = data.shape[-1] + else: + assert False, "Don't know how to handle this data shape" + + if fam is None: + self._create_array(array_name, ndim=ndim, source_array=data) + else: + self._create_family_array(array_name, fam, ndim=ndim, source_array=data) + + def _promote_family_array(self, name, *args, **kwargs): + # special logic: the normal promotion procedure would copy everything for this array out of shared memory + # which we don't want if the server can provide us with a shared memory view of the whole array + + if self.connection.shared_mem and not self.delay_promotion: + if name in self._loadable_keys: + for fam in self.families(): + try: + del self[fam][name] + except KeyError: + pass + self._load_array(name) + return + + super()._promote_family_array(name, *args, **kwargs) + + + + + + +class RemoteSnapshotConnection: + def __init__(self, input_handler, ts_extension, server_id=0, shared_mem=False): + + from ...input_handlers import pynbody + assert isinstance(input_handler, pynbody.PynbodyInputHandler) + + if tangos.parallel_tasks.pynbody_server.snapshot_queue._connection_active: + raise 
RuntimeError("Each client can only have one remote snapshot connection at any time") + + + + super().__init__() + + self._server_id = server_id + self._input_handler = input_handler + self.filename = ts_extension + self.identity = "%d: %s"%(self._server_id, ts_extension) + self.connected = False + self.shared_mem = shared_mem + + # ensure server knows what our messages are about + remote_import.ImportRequestMessage(__name__).send(self._server_id) + + log.logger.debug("Pynbody client: attempt to connect to remote snapshot %r", ts_extension) + RequestLoadPynbodySnapshot((input_handler, ts_extension, self.shared_mem)).send(self._server_id) + self.underlying_pynbody_class = ConfirmLoadPynbodySnapshot.receive(self._server_id).contents + if self.underlying_pynbody_class is None: + raise OSError("Could not load remote snapshot %r"%ts_extension) + + tangos.parallel_tasks.pynbody_server.snapshot_queue._connection_active = True + self.connected = True + log.logger.debug("Pynbody client: connected to remote snapshot %r", ts_extension) + + if self.shared_mem: + self.shared_mem_view = self.get_view(None) + + def get_view(self, filter_or_object_spec): + """Return a RemoteSubSnap that contains either the pynbody filtered region, or the specified object from a catalogue + + filter_or_object_spec is either an instance of pynbody.filt.Filter, or a tuple containing + (typetag, number), which are respectively the object type tag and object number to be loaded + """ + return RemoteSnap(self, filter_or_object_spec) + + def disconnect(self): + + if not self.connected: + return + + ReleasePynbodySnapshot(self.filename).send(self._server_id) + tangos.parallel_tasks.pynbody_server.snapshot_queue._connection_active = False + self.connected = False + + def __del__(self): + self.disconnect() diff --git a/tangos/parallel_tasks/pynbody_server/snapshot_queue.py b/tangos/parallel_tasks/pynbody_server/snapshot_queue.py new file mode 100644 index 00000000..19d3cef2 --- /dev/null +++ b/tangos/parallel_tasks/pynbody_server/snapshot_queue.py @@ -0,0 +1,172 @@ +import pynbody + +from tangos import log +from tangos.parallel_tasks.message import Message +from tangos.util.check_deleted import check_deleted + + +class ConfirmLoadPynbodySnapshot(Message): + pass + + +class PynbodySnapshotQueue: + def __init__(self): + self.timestep_queue = [] + self.handler_queue = [] + self.shared_mem_queue = [] + self.load_requester_queue = [] + self.current_timestep = None + self.current_snapshot = None + self.current_subsnap_cache = {} + self.current_handler = None + self.in_use_by = [] + + def add(self, requester, handler, filename, shared_mem=False): + log.logger.debug("Pynbody server: client %d requests access to %r", requester, filename) + if shared_mem: + log.logger.debug(" (shared memory mode)") + if filename==self.current_timestep: + self._notify_available(requester) + self.in_use_by.append(requester) + elif filename in self.timestep_queue: + queue_position = self.timestep_queue.index(filename) + self.load_requester_queue[queue_position].append(requester) + assert self.handler_queue[queue_position] == handler + assert self.shared_mem_queue[queue_position] == shared_mem + else: + self.timestep_queue.append(filename) + self.handler_queue.append(handler) + self.load_requester_queue.append([requester]) + self.shared_mem_queue.append(shared_mem) + self._load_next_if_free() + + def free(self, requester): + self.in_use_by.remove(requester) + log.logger.debug("Pynbody server: client %d is now finished with %r", requester, self.current_timestep) + 
self._free_if_unused() + self._load_next_if_free() + + def get_subsnap(self, filter_or_object_spec, fam): + if filter_or_object_spec is None: + if fam is None: + return self.current_snapshot + else: + return self.current_snapshot[fam] + elif (filter_or_object_spec, fam) in self.current_subsnap_cache: + log.logger.debug("Pynbody server: cache hit for %r (fam %r)",filter_or_object_spec, fam) + return self.current_subsnap_cache[(filter_or_object_spec, fam)] + else: + log.logger.debug("Pynbody server: cache miss for %r (fam %r)",filter_or_object_spec, fam) + subsnap = self.get_subsnap_uncached(filter_or_object_spec, fam) + self.current_subsnap_cache[(filter_or_object_spec, fam)] = subsnap + return subsnap + + def get_subsnap_uncached(self, filter_or_object_spec, fam): + + snap = self.current_snapshot + + if isinstance(filter_or_object_spec, pynbody.filt.Filter): + snap = snap[filter_or_object_spec] + elif isinstance(filter_or_object_spec, ObjectSpecification): + snap = self.current_handler.load_object(self.current_timestep, filter_or_object_spec.object_number, + filter_or_object_spec.object_index, + filter_or_object_spec.object_typetag) + else: + raise TypeError("filter_or_object_spec must be either a pynbody filter or an ObjectRequestInformation object") + + if fam is not None: + snap = snap[fam] + + return snap + + + + def _free_if_unused(self): + if len(self.in_use_by)==0: + log.logger.debug("Pynbody server: all clients are finished with the current snapshot; freeing.") + with check_deleted(self.current_snapshot): + self.current_snapshot = None + self.current_timestep = None + self.current_subsnap_cache = {} + self.current_handler = None + + def _notify_available(self, node): + log.logger.debug("Pynbody server: notify %d that snapshot is now available", node) + ConfirmLoadPynbodySnapshot(type(self.current_snapshot)).send(node) + + def _notify_unavailable(self, node): + log.logger.debug("Pynbody server: notify %d that snapshot is unavailable", node) + ConfirmLoadPynbodySnapshot(None).send(node) + + def _load_next_if_free(self): + if len(self.timestep_queue)==0: + return + + if self.current_handler is None: + # TODO: Error handling + self.current_timestep = self.timestep_queue.pop(0) + self.current_handler = self.handler_queue.pop(0) + self.current_shared_mem_flag = self.shared_mem_queue.pop(0) + notify = self.load_requester_queue.pop(0) + + try: + self.current_snapshot = self.current_handler.load_timestep(self.current_timestep) + log.logger.info("Pynbody server: loaded %r", self.current_timestep) + if self.current_shared_mem_flag: + log.logger.info(" (shared memory mode)") + self.current_snapshot._shared_arrays = True + self.current_snapshot.physical_units() + success = True + except OSError: + success = False + + if success: + self.in_use_by = notify + for n in notify: + self._notify_available(n) + else: + self.current_timestep = None + self.current_handler = None + self.current_snapshot = None + + for n in notify: + self._notify_unavailable(n) + self._load_next_if_free() + + else: + log.logger.info("The currently loaded snapshot is still required and so other clients will have to wait") + log.logger.info("(Currently %d snapshots are in the queue to be loaded later)", len(self.timestep_queue)) + + +_server_queue = PynbodySnapshotQueue() + + +class RequestLoadPynbodySnapshot(Message): + def process(self): + _server_queue.add(self.source, *self.contents) + + +class ReleasePynbodySnapshot(Message): + def process(self): + _server_queue.free(self.source) + + +_connection_active = False + + 
+class ObjectSpecification: + def __init__(self, object_number, object_index, object_typetag='halo'): + self.object_number = object_number + self.object_index = object_index + self.object_typetag = object_typetag + + def __repr__(self): + return "ObjectSpecification(%d, %d, %r)"%(self.object_number, self.object_index, self.object_typetag) + + def __eq__(self, other): + if not isinstance(other, ObjectSpecification): + return False + return self.object_number==other.object_number and self.object_typetag==other.object_typetag + + def __hash__(self): + return hash((self.object_number, self.object_index, self.object_typetag)) diff --git a/tangos/parallel_tasks/pynbody_server/transfer_array.py b/tangos/parallel_tasks/pynbody_server/transfer_array.py new file mode 100644 index 00000000..54f547a3 --- /dev/null +++ b/tangos/parallel_tasks/pynbody_server/transfer_array.py @@ -0,0 +1,42 @@ +import numpy as np +import pynbody + +from ..message import Message + + +def send_array(array: pynbody.array.SimArray, destination: int, use_shared_memory: bool = False): + if use_shared_memory: + if not hasattr(array, "_shared_fname"): + if isinstance(array, np.ndarray) and hasattr(array, "base") and hasattr(array.base, "_shared_fname"): + array._shared_fname = array.base._shared_fname # the strides/offset will point into the same memory + else: + raise ValueError("Array %r has no shared memory information" % array) + _send_array_shared_memory(array, destination) + else: + _send_array_copy(array, destination) + +def receive_array(source: int, use_shared_memory: bool = False): + if use_shared_memory: + return _receive_array_shared_memory(source) + else: + return _receive_array_copy(source) + +def _send_array_copy(array: np.ndarray, destination: int): + from .. import backend + backend.send_numpy_array(array, destination) + +def _receive_array_copy(source): + from .. 
import backend + return backend.receive_numpy_array(source) + + +class SharedMemoryArrayInfo(Message): + pass + +def _send_array_shared_memory(array: pynbody.array.SimArray, destination: int): + info = pynbody.array._shared_array_deconstruct(array, transfer_ownership=False) + SharedMemoryArrayInfo(info).send(destination) + +def _receive_array_shared_memory(source): + info = SharedMemoryArrayInfo.receive(source) + return pynbody.array._shared_array_reconstruct(info.contents) diff --git a/tangos/tools/add_simulation.py b/tangos/tools/add_simulation.py index 86af7eb7..6bb94ed3 100644 --- a/tangos/tools/add_simulation.py +++ b/tangos/tools/add_simulation.py @@ -72,9 +72,10 @@ def simulation_exists(self): return num_matches>0 def timestep_exists_for_extension(self, ts_extension): - ex = core.get_default_session().query(TimeStep).filter_by( - simulation=self._get_simulation(), - extension=ts_extension).first() + with pt.lock.SharedLock("db_write_lock"): + ex = core.get_default_session().query(TimeStep).filter_by( + simulation=self._get_simulation(), + extension=ts_extension).first() return ex is not None def add_timestep(self, ts_extension): @@ -161,4 +162,5 @@ def add_timestep_properties(self, ts): def _get_simulation(self): - return self.session.query(Simulation).filter_by(basename=self.basename).first() + with pt.lock.SharedLock("db_write_lock"): + return self.session.query(Simulation).filter_by(basename=self.basename).first() diff --git a/tangos/tools/db_importer.py b/tangos/tools/db_importer.py index 0dedaf11..de513f75 100644 --- a/tangos/tools/db_importer.py +++ b/tangos/tools/db_importer.py @@ -9,7 +9,7 @@ from sqlalchemy.schema import Column from tangos import Base, Creator, DictionaryItem, core -from tangos.config import DB_IMPORT_CHUNK_SIZE, DB_IMPORT_COMMIT_AFTER_CHUNKS +from tangos.config import DB_IMPORT_CHUNK_SIZE from tangos.core import (HaloLink, HaloProperty, Simulation, SimulationObjectBase, SimulationProperty, TimeStep) @@ -173,7 +173,6 @@ def _copy_table(from_connection, target_connection, orm_class, offsets, destinat try: target_connection.execute(insert(destination_table).values(all_rows)) - except sqlalchemy.exc.OperationalError as e: if retries>=1: raise # if this line is hit, it may reflect a data limit in the server, e.g. max_allowed_packet in MySQL @@ -181,19 +180,14 @@ def _copy_table(from_connection, target_connection, orm_class, offsets, destinat # server log, but in MySQL it does not seem to be. Reducing CHUNK_SIZE may help, or increasing # the limit on the server. - num_committed = num_done - (num_done % (DB_IMPORT_CHUNK_SIZE * DB_IMPORT_COMMIT_AFTER_CHUNKS)) - pbar.update(num_committed-num_done) # negative correction - print(f"Note: lost connection to database after {num_done} rows. Resetting to {num_committed}.") - # reset to point of last commit - num_done = num_committed + print(f"Note: lost connection to database after {num_done} rows. 
Trying again.") target_connection.rollback() # create a new connection from the target connection's engine target_connection = target_connection.engine.connect() - source_result = from_connection.execute(select(table).offset(num_committed)) + source_result = from_connection.execute(select(table).offset(num_done)) retries+=1 continue - num_done += len(all_rows) pbar.update(len(all_rows)) diff --git a/tangos/tools/property_writer.py b/tangos/tools/property_writer.py index eea3ca50..24133e0d 100644 --- a/tangos/tools/property_writer.py +++ b/tangos/tools/property_writer.py @@ -10,7 +10,7 @@ import sqlalchemy.exc import sqlalchemy.orm -from .. import core, live_calculation, parallel_tasks, properties +from .. import config, core, live_calculation, parallel_tasks, properties from ..cached_writer import insert_list from ..log import logger from ..util import proxy_object, terminalcontroller, timing_monitor @@ -28,8 +28,8 @@ class PropertyWriter(GenericTangosTool): def __init__(self): self.redirect = terminalcontroller.redirect - self._writer_timeout = 60 - self._writer_minimum = 60 # don't commit at end of halo if < 1 minute past + self._writer_timeout = config.PROPERTY_WRITER_MAXIMUM_TIME_BETWEEN_COMMITS + self._writer_minimum = config.PROPERTY_WRITER_MINIMUM_TIME_BETWEEN_COMMITS self._current_timestep_id = None self._loaded_timestep = None self._loaded_halo_id = None @@ -58,13 +58,14 @@ def add_parser_arguments(self, parser): help='Process timesteps in random order') parser.add_argument('--with-prerequisites', action='store_true', help='Automatically calculate any missing prerequisites for the properties') - parser.add_argument('--load-mode', action='store', choices=['all', 'partial', 'server', 'server-partial'], + parser.add_argument('--load-mode', action='store', choices=['all', 'partial', 'server', 'server-partial', 'server-shared-mem'], required=False, default=None, help="Select a load-mode: " \ - " --load-mode partial: each node attempts to load only the data it needs; " \ - " --load-mode server: a server process manages the data;" - " --load-mode server-partial: a server process figures out the indices to load, which are then passed to the partial loader" \ - " --load-mode all: each node loads all the data (default, and often fine for zoom simulations).") + " --load-mode partial: each processor attempts to load only the data it needs; " \ + " --load-mode server: a server process manages the data;" + " --load-mode server-partial: a server process figures out the indices to load, which are then passed to the partial loader" \ + " --load-mode all: each processor loads all the data (default, and often fine for zoom simulations)." \ + " --load-mode server-shared-mem: a server process manages the data, passing to other processes via shared memory") parser.add_argument('--type', action='store', type=str, dest='htype', help="Secify the object type to run on by tag name (or integer). 
Can be halo, group, or BH.") parser.add_argument('--hmin', action='store', type=int, default=0, @@ -179,6 +180,13 @@ def _compile_inclusion_criterion(self): else: self._include = None + def _log_one_process(self, *args): + if parallel_tasks.backend is None or parallel_tasks.backend.rank()==1: + logger.info(*args) + + def _summarise_timing_one_process(self): + if parallel_tasks.backend is None or parallel_tasks.backend.rank() == 1: + self.timing_monitor.summarise_timing(logger) def _build_halo_list(self, db_timestep): query = core.halo.SimulationObjectBase.timestep == db_timestep @@ -198,7 +206,7 @@ def _build_halo_list(self, db_timestep): if self._include: needed_properties.append(self._include) - logger.info('Gathering existing properties for all halos in timestep %r',db_timestep) + logger.debug('Gathering existing properties for all halos in timestep %r',db_timestep) halo_query = live_calculation.MultiCalculation(*needed_properties).supplement_halo_query(halo_query) halos = halo_query.all() @@ -212,8 +220,8 @@ def _build_halo_list(self, db_timestep): # perform filtering: halos = [halo_i for halo_i, include_i in zip(halos, inclusion) if include_i] - logger.info("User-specified inclusion criterion excluded %d of %d halos", - len(inclusion)-len(halos),len(inclusion)) + self._log_one_process("User-specified inclusion criterion excluded %d of %d halos", + len(inclusion)-len(halos),len(inclusion)) return halos @@ -268,7 +276,7 @@ def _commit_results_if_needed(self, end_of_timestep=False, end_of_simulation=Fal logger.info(f"...{num_properties} properties were committed") self._pending_properties = [] self._start_time = time.time() - self.timing_monitor.summarise_timing(logger) + self._summarise_timing_one_process() def _queue_results_for_later_commit(self, db_halo, names, results, existing_properties_data): for n, r in zip(names, results): @@ -334,7 +342,10 @@ def _set_current_timestep(self, db_timestep): def _set_current_halo(self, db_halo): - self._set_current_timestep(db_halo.timestep) + with parallel_tasks.lock.SharedLock("insert_list"): + # don't want this to happen in parallel with a database write -- seems to lazily fetch + # rows in the background + self._set_current_timestep(db_halo.timestep) if self._loaded_halo_id==db_halo.id: return @@ -476,17 +487,17 @@ def run_timestep_calculation(self, db_timestep): self._existing_properties_all_halos = self._build_existing_properties_all_halos(db_halos) - logger.info("Successfully gathered existing properties; calculating halo properties now...") + self._log_one_process("Successfully gathered existing properties; calculating halo properties now...") - logger.info(" %d halos to consider; %d calculation routines for each of them, resulting in %d properties per halo", + self._log_one_process(" %d halos to consider; %d calculation routines for each of them, resulting in %d properties per halo", len(db_halos), len(self._property_calculator_instances), sum([1 if isinstance(x.names, str) else len(x.names) for x in self._property_calculator_instances]) ) - logger.info(" The property modules are:") + self._log_one_process(" The property modules are:") for x in self._property_calculator_instances: x_type = type(x) - logger.info(f" {x_type.__module__}.{x_type.__qualname__}") + self._log_one_process(f" {x_type.__module__}.{x_type.__qualname__}") for db_halo, existing_properties in \ self._get_parallel_halo_iterator(list(zip(db_halos, self._existing_properties_all_halos))): @@ -497,9 +508,8 @@ def run_timestep_calculation(self, db_timestep): 
self._unload_timestep() self.tracker.report_to_log(logger) - sys.stderr.flush() - self._commit_results_if_needed(True) + self._commit_results_if_needed(end_of_timestep=True) def _add_prerequisites_to_calculator_instances(self, db_timestep): will_calculate = [] @@ -514,8 +524,8 @@ def _add_prerequisites_to_calculator_instances(self, db_timestep): for r in requirements: if r not in will_calculate: new_instance = properties.instantiate_class(db_timestep.simulation, r) - logger.info("Missing prerequisites - added class %r",type(new_instance)) - logger.info(" providing properties %r",new_instance.names) + self._log_one_process("Missing prerequisites - added class %r",type(new_instance)) + self._log_one_process(" providing properties %r",new_instance.names) self._property_calculator_instances = [new_instance]+self._property_calculator_instances self._add_prerequisites_to_calculator_instances(db_timestep) # everything has changed; start afresh break diff --git a/test_tutorial_build/build.sh b/test_tutorial_build/build.sh index cdce01e7..1cffbe80 100755 --- a/test_tutorial_build/build.sh +++ b/test_tutorial_build/build.sh @@ -20,7 +20,7 @@ detect_mpi() { echo "Detected mpirun -- will use where appropriate" else export MPIBACKEND="--backend=multiprocessing-3" # 1 process for server, 1 for each worker - export MPILOADMODE="--load-mode=server" + export MPILOADMODE="--load-mode=server-shared-mem" echo "No mpirun found; adopting multiprocessing with 2 workers" fi } diff --git a/tests/test_parallel_tasks.py b/tests/test_parallel_tasks.py index cac09195..3a8e23f0 100644 --- a/tests/test_parallel_tasks.py +++ b/tests/test_parallel_tasks.py @@ -1,6 +1,7 @@ -import sys import time +import pytest + import tangos import tangos.testing.simulation_generator from tangos import parallel_tasks as pt @@ -196,3 +197,22 @@ def test_shared_locks_in_queue(): lock_held-=1 else: assert False, "Unexpected line in log: "+line + +class ErrorOnServer(pt.message.Message): + def process(self): + raise RuntimeError("Error on server") + +def test_error_on_server(): + pt.use("multiprocessing-2") + with pytest.raises(RuntimeError) as e: + pt.launch(lambda: ErrorOnServer().send(0)) + assert "Error on server" in str(e.value) + +def test_error_on_client(): + pt.use("multiprocessing-2") + def _error_on_client(): + raise RuntimeError("Error on client") + + with pytest.raises(RuntimeError) as e: + pt.launch(_error_on_client) + assert "Error on client" in str(e.value) diff --git a/tests/test_pynbody_server.py b/tests/test_pynbody_server.py index c061f069..77fd55f8 100644 --- a/tests/test_pynbody_server.py +++ b/tests/test_pynbody_server.py @@ -9,6 +9,7 @@ import tangos.input_handlers.pynbody import tangos.parallel_tasks as pt import tangos.parallel_tasks.pynbody_server as ps +import tangos.parallel_tasks.pynbody_server.snapshot_queue class _TestHandler(tangos.input_handlers.pynbody.ChangaInputHandler): @@ -32,8 +33,8 @@ def teardown_module(): def _get_array(): test_filter = pynbody.filt.Sphere('5000 kpc') for fname in pt.distributed(["tiny.000640", "tiny.000832"]): - ps.RequestLoadPynbodySnapshot((handler, fname)).send(0) - ps.ConfirmLoadPynbodySnapshot.receive(0) + ps.snapshot_queue.RequestLoadPynbodySnapshot((handler, fname)).send(0) + ps.snapshot_queue.ConfirmLoadPynbodySnapshot.receive(0) ps.RequestPynbodyArray(test_filter, "pos").send(0) @@ -42,7 +43,7 @@ def _get_array(): remote_result = ps.ReturnPynbodyArray.receive(0).contents assert (f_local[test_filter]['pos']==remote_result).all() - ps.ReleasePynbodySnapshot().send(0) + 
ps.snapshot_queue.ReleasePynbodySnapshot().send(0) def test_get_array(): @@ -50,6 +51,52 @@ def test_get_array(): pt.launch(_get_array) +def _get_shared_array(): + if pt.backend.rank()==1: + shared_array = pynbody.array._array_factory((10,), int, True, True) + shared_array[:] = np.arange(0,10) + pt.pynbody_server.transfer_array.send_array(shared_array, 2, True) + assert shared_array[2]==2 + pt.barrier() + # change the value, to be checked in the other process + shared_array[2] = 100 + pt.barrier() + elif pt.backend.rank()==2: + shared_array = pt.pynbody_server.transfer_array.receive_array(2, True) + assert shared_array[2]==2 + pt.barrier() + # now the other process should be changing the value + pt.barrier() + assert shared_array[2]==100 + +def test_get_shared_array(): + pt.use("multiprocessing-3") + pt.launch(_get_shared_array) + +def _get_shared_array_slice(): + if pt.backend.rank()==1: + shared_array = pynbody.array._array_factory((10,), int, True, True) + shared_array[:] = np.arange(0,10) + pt.pynbody_server.transfer_array.send_array(shared_array[1:7:2], 2, True) + assert shared_array[3] == 3 + pt.barrier() + # change the value, to be checked in the other process + shared_array[3] = 100 + pt.barrier() + elif pt.backend.rank()==2: + shared_array = pt.pynbody_server.transfer_array.receive_array(2, True) + assert len(shared_array)==3 + assert shared_array[1] == 3 + pt.barrier() + # now the other process should be changing the value + pt.barrier() + assert shared_array[1]==100 + +def test_get_shared_array_slice(): + """Like test_get_shared_array, but with a slice""" + pt.use("multiprocessing-3") + pt.launch(_get_shared_array_slice) + def _test_simsnap_properties(): test_filter = pynbody.filt.Sphere('5000 kpc') conn = ps.RemoteSnapshotConnection(handler, "tiny.000640") @@ -96,7 +143,7 @@ def test_nonexistent_array(): def _test_halo_array(): conn = ps.RemoteSnapshotConnection(handler, "tiny.000640") - f = conn.get_view(ps.ObjectSpecification(1, 1)) + f = conn.get_view(ps.snapshot_queue.ObjectSpecification(1, 1)) f_local = pynbody.load(tangos.config.base+"test_simulations/test_tipsy/tiny.000640").halos()[1] assert len(f)==len(f_local) assert (f['x'] == f_local['x']).all() @@ -109,7 +156,7 @@ def test_halo_array(): def _test_remote_file_index(): conn = ps.RemoteSnapshotConnection(handler, "tiny.000640") - f = conn.get_view(ps.ObjectSpecification(1, 1)) + f = conn.get_view(ps.snapshot_queue.ObjectSpecification(1, 1)) f_local = pynbody.load(tangos.config.base+"test_simulations/test_tipsy/tiny.000640").halos()[1] local_index_list = f_local.get_index_list(f_local.ancestor) index_list = f['remote-index-list'] @@ -125,7 +172,7 @@ def _debug_print_arrays(*arrays): def _test_lazy_evaluation_is_local(): conn = ps.RemoteSnapshotConnection(handler, "tiny.000640") - f = conn.get_view(ps.ObjectSpecification(1, 1)) + f = conn.get_view(ps.snapshot_queue.ObjectSpecification(1, 1)) f_local = pynbody.load(tangos.config.base+"test_simulations/test_tipsy/tiny.000640").halos()[1] f_local.physical_units() @@ -154,7 +201,7 @@ def tipsy_specific_derived_array(sim): def _test_underlying_class(): conn = ps.RemoteSnapshotConnection(handler, "tiny.000640") - f = conn.get_view(ps.ObjectSpecification(1, 1)) + f = conn.get_view(ps.snapshot_queue.ObjectSpecification(1, 1)) f_local = pynbody.load(tangos.config.base + "test_simulations/test_tipsy/tiny.000640").halos()[1] f_local.physical_units() npt.assert_almost_equal(f['tipsy_specific_derived_array'],f_local['tipsy_specific_derived_array'], decimal=4) @@ -197,8 +244,8 @@ def 
metals(sim): """Derived array that will only be invoked for dm, since metals is present on disk for gas/stars""" return pynbody.array.SimArray(np.ones(len(sim))) -def _test_mixed_derived_loaded_arrays(): - f_remote = handler.load_object('tiny.000640', 1, 1, mode='server') +def _test_mixed_derived_loaded_arrays(mode='server'): + f_remote = handler.load_object('tiny.000640', 1, 1, mode=mode) f_local = handler.load_object('tiny.000640', 1, 1, mode=None) assert (f_remote.dm['metals'] == f_local.dm['metals']).all() assert (f_remote.st['metals'] == f_local.st['metals']).all() @@ -210,3 +257,96 @@ def test_mixed_derived_loaded_arrays(): specifically a "derived array is not writable" error on the server. This test ensures that the correct behaviour""" pt.use("multiprocessing-2") pt.launch(_test_mixed_derived_loaded_arrays) + pt.launch(lambda: _test_mixed_derived_loaded_arrays(mode='server-shared-mem')) + + +def _test_shmem_simulation(load_sphere=False): + sphere_filter = pynbody.filt.Sphere("3 Mpc") + def loader_function(**kwargs): + if load_sphere: + return handler.load_region("tiny.000640", sphere_filter, **kwargs) + else: + return handler.load_object("tiny.000640", 1, 1, **kwargs) + if pt.backend.rank()==1: + f_remote = loader_function(mode='server-shared-mem') + f_local = loader_function(mode=None) + # note we are using the velocity rather than the position because the position is already accessed + # in the case of the sphere region test. We intentionally load information family-by-family (see + # below). Using a 3d array slice (vx, rather than vel) tests that we don't accidentally just retrieve + # 1d slices - the whole 3d array should be retrieved, even though the code only asks for the x component. + assert (f_remote.dm['vx'] == f_local.dm['vx']).all() + assert (f_remote.st['vx'] == f_local.st['vx']).all() + + f_remote.dm['vx'][:] = 1337.0 # this should be a copy, not the actual shared memory array + + assert 'vel' in f_remote.dm.keys() # should have got the whole 3d array + pt.barrier() + # other rank will test that 1337.0 is *not* in the array. The reason this must be + # true is so that we don't get race conditions when two processes are processing overlapping + # regions + pt.barrier() + f = handler.load_timestep("tiny.000640", mode='server-shared-mem').shared_mem_view + # now we get the *actual* shared memory view, so updates here really should reflect into the + # other process. This isn't a particularly desirable behaviour, but it just serves to verify + # everything really is backing onto shared memory + f.dm['vx'][:] = 1234.0 + pt.barrier() + + # We now want to test what happens when we load the rest of the position array. What we don't + # want to happen is a 'local promotion' -- this would imply a copy into local memory, defeating + # the point of having shared memory mode. Instead, we want to recognise that actually we always had + # pointers into a simulation-level shared memory array, and just keep looking at that. + + assert 'vel' not in f.keys() # currently a family array + assert f.dm['vel'].ancestor._shared_fname is not None + + # prompt a promotion: + f.gas['vx'] + + assert 'vel' in f.keys() + assert f['vel']._shared_fname is not None + + # Note: in principle, there could arise a situation where the server 'promotes' the array and so unlinks + # the shared memory file, but one or more clients still have a reference to it. 
However, the OS should + # not actually delete the file until all references are closed, so this should not cause a crash - it just + # means there could be excess memory usage. For now, let's not worry about it. + + + elif pt.backend.rank()==2: + pt.barrier() # let the other rank try to corrupt things + f_remote = loader_function(mode='server-shared-mem') + assert np.all(f_remote.dm['vx'] != 1337.0) + pt.barrier() + # other process is updating the shared memory array + pt.barrier() + f = handler.load_timestep("tiny.000640", mode='server-shared-mem').shared_mem_view + assert np.all(f.dm['vx']==1234.0) + + +def test_shmem_simulation_with_halo(): + """This test ensures that a simulation can be loaded correctly in shared memory, and halos accessed""" + pt.use("multiprocessing-3") + pt.launch(_test_shmem_simulation) + +def test_shmem_simulation_with_filter(): + """This test ensures that a simulation can be loaded correctly in shared memory, and filter regions accessed""" + pt.use("multiprocessing-3") + pt.launch(lambda: _test_shmem_simulation(load_sphere=True)) + + +def _test_implict_array_promotion_shared_mem(): + + f_remote = handler.load_timestep("tiny.000640", mode='server-shared-mem').shared_mem_view + + f_remote.dm['pos'] + f_remote.gas['pos'] + + # Don't explicitly load the f_remote.star['pos']. It should implicitly get promoted: + f_remote['pos'] + + f_local = handler.load_timestep("tiny.000640", mode=None) + assert (f_remote['pos'] == f_local['pos']).all() + +def test_implicit_array_promotion_shared_mem(): + pt.use("multiprocessing-2") + pt.launch(_test_implict_array_promotion_shared_mem) diff --git a/tests/test_simulation_outputs.py b/tests/test_simulation_outputs.py index c92f7087..30c5878e 100644 --- a/tests/test_simulation_outputs.py +++ b/tests/test_simulation_outputs.py @@ -83,7 +83,7 @@ def test_load_timestep(): def test_load_halo(): add_test_simulation_to_db() pynbody_h = db.get_halo("test_tipsy/tiny.000640/1").load() - assert isinstance(pynbody_h, pynbody.snapshot.SubSnap) + assert isinstance(pynbody_h, pynbody.snapshot.IndexedSubSnap) assert len(pynbody_h)==200 assert_is_subview_of_full_file(pynbody_h)
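Usage sketch, assuming the standard tangos write entry point (the property and simulation names below are placeholders; the flags themselves are the --load-mode choice added in property_writer.py and the backend selection exercised in test_tutorial_build/build.sh above):

    # hypothetical invocation: with the multiprocessing backend, one process acts as the
    # pynbody server and the other two as workers; snapshot arrays reach the workers through
    # shared memory instead of being copied over a pipe
    tangos write my_property --for my_sim --load-mode=server-shared-mem --backend=multiprocessing-3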