diff --git a/tangos/parallel_tasks/__init__.py b/tangos/parallel_tasks/__init__.py
index 4217f967..f22ed735 100644
--- a/tangos/parallel_tasks/__init__.py
+++ b/tangos/parallel_tasks/__init__.py
@@ -84,25 +84,36 @@ def launch(function, args=None, backend_kwargs=None):
     return result
 
 
-def distributed(file_list, proc=None, of=None, allow_resume=False, resumption_id=None):
-    """Distribute a list of tasks between all nodes"""
+def distributed(items, allow_resume=False, resumption_id=None):
+    """Return an iterator that consumes the items, distributed across all processors
+    (i.e. each item is consumed by only one processor, in a dynamic way).
 
-    if type(file_list) == set:
-        file_list = list(file_list)
+    Optionally, if allow_resume is True, then the iterator will resume from the last point it reached,
+    provided argv and the stack trace are unchanged. If resumption_id is not None, then
+    the stack trace is ignored and only resumption_id needs to match."""
+    if type(items) == set:
+        items = list(items)
+
+    if _backend_name=='null':
+        return items
+    else:
+        from . import jobs
+        return jobs.distributed_iterate(items, allow_resume, resumption_id)
+
+def synchronized(items, allow_resume=False, resumption_id=None):
+    """Return an iterator that consumes all items on all processors.
+
+    Optionally, if allow_resume is True, then the iterator will resume from the last point it reached,
+    provided argv and the stack trace are unchanged. If resumption_id is not None, then
+    the stack trace is ignored and only resumption_id needs to match."""
 
     if _backend_name=='null':
-        if proc is None:
-            proc = 1
-            of = 1
-        i = (len(file_list) * (proc - 1)) // of
-        j = (len(file_list) * proc) // of - 1
-        assert proc <= of and proc > 0
-        if proc == of:
-            j += 1
-        return file_list[i:j + 1]
+        return items
     else:
         from . import jobs
-        return jobs.parallel_iterate(file_list, allow_resume, resumption_id)
+        return jobs.synchronized_iterate(items, allow_resume, resumption_id)
 
 
 def _exec_function_or_server(function, connection_info, args):
@@ -129,11 +140,7 @@ def _server_thread():
     # available job and move on. Jobs are labelled through the
     # provided iterator
 
-    j = -1
-    num_jobs = None
-    current_job = None
     alive = [True for i in range(backend.size())]
-    awaiting_barrier = [False for i in range(backend.size())]
 
     while any(alive[1:]):
         obj = message.Message.receive()
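For orientation, a minimal sketch of how the two new entry points behave (using the multiprocessing backend and the `pt` alias adopted in the test suite below; the task lists here are placeholders):

import tangos.parallel_tasks as pt

def _demo():
    # distributed: each item is handed to exactly one worker, dynamically
    for item in pt.distributed(["ts1", "ts2", "ts3", "ts4"]):
        print(pt.backend.rank(), "owns", item)

    # synchronized: every worker sees every item, kept in lock-step
    for item in pt.synchronized(["ts1", "ts2"]):
        print(pt.backend.rank(), "visits", item)

pt.use("multiprocessing-3")  # rank 0 acts as server; ranks 1-2 run _demo
pt.launch(_demo)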
diff --git a/tangos/parallel_tasks/accumulative_statistics.py b/tangos/parallel_tasks/accumulative_statistics.py
index bebdd1c9..b5c12554 100644
--- a/tangos/parallel_tasks/accumulative_statistics.py
+++ b/tangos/parallel_tasks/accumulative_statistics.py
@@ -3,47 +3,27 @@
 from ..config import PROPERTY_WRITER_PARALLEL_STATISTICS_TIME_BETWEEN_UPDATES
 from ..log import logger
-from .message import Message
+from .message import BarrierMessageWithResponse, Message
 
 _new_accumulator_requested_for_ranks = []
 _new_accumulator = None
 _existing_accumulators = []
 
-class CreateNewAccumulatorMessage(Message):
+class CreateNewAccumulatorMessage(BarrierMessageWithResponse):
 
-    def process(self):
-        from . import backend
-        global _new_accumulator, _new_accumulator_requested_for_ranks, _existing_accumulators
-        assert issubclass(self.contents, StatisticsAccumulatorBase)
-        if _new_accumulator is None:
-            _new_accumulator = self.contents()
-            _new_accumulator_requested_for_ranks = [self.source]
-        else:
-            assert self.source not in _new_accumulator_requested_for_ranks
-            assert isinstance(_new_accumulator, self.contents)
-            _new_accumulator_requested_for_ranks.append(self.source)
-
-        from . import backend
-
-        if len(_new_accumulator_requested_for_ranks) == backend.size()-1:
-            self._confirm_new_accumulator()
-
-    def _confirm_new_accumulator(self):
-        global _new_accumulator, _new_accumulator_requested_for_ranks, _existing_accumulators
+    def process_global(self):
         from . import backend, on_exit_parallelism
+
+        new_accumulator = self.contents()
         accumulator_id = len(_existing_accumulators)
-        _existing_accumulators.append(_new_accumulator)
+        _existing_accumulators.append(new_accumulator)
 
-        locally_bound_accumulator = _new_accumulator
-        logger.debug("Created new accumulator of type %s with id %d" % (locally_bound_accumulator.__class__.__name__, accumulator_id))
+        locally_bound_accumulator = new_accumulator
+        logger.debug("Created new accumulator of type %s with id %d" % (
+            locally_bound_accumulator.__class__.__name__, accumulator_id))
         on_exit_parallelism(lambda: locally_bound_accumulator.report_to_log_if_needed(logger, 0.05))
 
-        _new_accumulator = None
-        _new_accumulator_requested_for_ranks = []
+        self.respond(accumulator_id)
 
-        for destination in range(1, backend.size()):
-            AccumulatorIdMessage(accumulator_id).send(destination)
 
-class AccumulatorIdMessage(Message):
-    pass
 
 class AccumulateStatisticsMessage(Message):
     def process(self):
         global _existing_accumulators
@@ -63,9 +43,7 @@ def __init__(self, allow_parallel=False):
         self._parallel = allow_parallel and parallelism_is_active() and backend.rank() != 0
         if self._parallel:
             logger.debug(f"Registering {self.__class__}")
-            CreateNewAccumulatorMessage(self.__class__).send(0)
-            logger.debug(f"Awaiting accumulator id for {self.__class__}")
-            self.id = AccumulatorIdMessage.receive(0).contents
+            self.id = CreateNewAccumulatorMessage(self.__class__).send_and_get_response(0)
             logger.debug(f"Received accumulator id={ self.id}")
 
     def report_to_server(self):
diff --git a/tangos/parallel_tasks/barrier.py b/tangos/parallel_tasks/barrier.py
index ead3eeda..0ee2b4fd 100644
--- a/tangos/parallel_tasks/barrier.py
+++ b/tangos/parallel_tasks/barrier.py
@@ -1,29 +1,13 @@
 from . import message
 
-awaiting_barrier = None
-
-
-class MessageBarrierPass(message.Message):
-    pass
-
-class MessageBarrier(message.Message):
-    def process(self):
-        from . import backend
-        global awaiting_barrier
-        if awaiting_barrier is None:
-            awaiting_barrier = [False for i in range(backend.size())]
-
-        awaiting_barrier[self.source] = True
-        if all(awaiting_barrier[1:]):
-            for i in range(1, backend.size()):
-                MessageBarrierPass().send(i)
-            awaiting_barrier = [False for i in range(backend.size())]
+class SimpleBarrierMessage(message.BarrierMessageWithResponse):
+    def process_global(self):
+        self.respond(None)
 
 def barrier():
     from . import backend, parallelism_is_active
     if not parallelism_is_active():
         return
 
     assert backend.rank()!=0, "The server process cannot take part in a barrier"
-    MessageBarrier().send(0)
-    MessageBarrierPass.receive(0)
+    SimpleBarrierMessage().send_and_get_response(0)  # awaits response which only comes when all processes reach barrier
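Caller-facing behaviour of the barrier is unchanged; only the plumbing now reuses the generic request/response machinery. A minimal sketch, assuming a launched parallel context as above:

import tangos.parallel_tasks as pt

def _phased():
    print("phase 1 on rank", pt.backend.rank())
    pt.barrier()  # returns only once every worker rank has arrived here
    print("phase 2 on rank", pt.backend.rank())

pt.use("multiprocessing-3")
pt.launch(_phased)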
diff --git a/tangos/parallel_tasks/database.py b/tangos/parallel_tasks/database.py
index 4fa4c6a4..66f2f962 100644
--- a/tangos/parallel_tasks/database.py
+++ b/tangos/parallel_tasks/database.py
@@ -4,13 +4,10 @@
 from . import message, remote_import
 
-class MessageRequestCreatorId(message.Message):
+class MessageRequestCreatorId(message.MessageWithResponse):
     def process(self):
         creator_id = core.creator.get_creator_id()
-        MessageDeliverCreatorId(creator_id).send(self.source)
-
-class MessageDeliverCreatorId(message.Message):
-    pass
+        self.respond(creator_id)
 
 
 def synchronize_creator_object(session=None):
@@ -28,6 +25,5 @@ def synchronize_creator_object(session=None):
         return
 
     remote_import.ImportRequestMessage(__name__).send(0)
-    MessageRequestCreatorId().send(0)
-    id = MessageDeliverCreatorId.receive(0).contents
+    id = MessageRequestCreatorId().send_and_get_response(0)
    core.creator.set_creator(session.query(core.creator.Creator).filter_by(id=id).first())
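The pattern above generalises: any paired request/deliver message collapses to a MessageWithResponse subclass whose server-side process calls self.respond(...), matched by a worker-side send_and_get_response(0). A sketch of the worker side for the creator id (only meaningful inside a launched context, on a rank other than 0):

from tangos.parallel_tasks import remote_import
from tangos.parallel_tasks.database import MessageRequestCreatorId

# ensure the server has imported this module so it can process the request
remote_import.ImportRequestMessage("tangos.parallel_tasks.database").send(0)
creator_id = MessageRequestCreatorId().send_and_get_response(0)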
diff --git a/tangos/parallel_tasks/jobs.py b/tangos/parallel_tasks/jobs.py
index 3e637ec1..fac9042c 100644
--- a/tangos/parallel_tasks/jobs.py
+++ b/tangos/parallel_tasks/jobs.py
@@ -129,40 +129,68 @@ def __eq__(self, other):
         return self._jobs_complete == other._jobs_complete
 
 
+class SynchronizedIterationState(IterationState):
+    def _first_incomplete_job_after(self, job):
+        if job is None:
+            job = -1 # so that we scan over all possible jobs when the loop first enters
+        for i in range(job+1, len(self._jobs_complete)):
+            if not self._jobs_complete[i]:
+                return i
+        return None
 
-current_num_jobs = None
-current_iteration_state = None
-current_stack_hash = None
-current_is_resumable = False
+    def _is_still_running_somewhere(self, job):
+        for v in self._rank_running_job.values():
+            if v == job:
+                return True
+        return False
 
-class MessageStartIteration(message.Message):
-    def process(self):
-        global current_iteration_state, current_stack_hash, current_is_resumable
-        req_jobs, req_hash, allow_resume = self.contents
+    def next_job(self, for_rank):
+        previous_job = self._rank_running_job[for_rank]
+        my_next_job = self._first_incomplete_job_after(previous_job)
+        if my_next_job is None:
+            del self._rank_running_job[for_rank]
+        else:
+            self._rank_running_job[for_rank] = my_next_job
 
-        if current_iteration_state is None:
-            current_num_jobs = req_jobs
-            current_stack_hash = req_hash
-            current_is_resumable = allow_resume
-            # convert sys.argv to a string which is quoted/escaped as needed
-            argv_string = " ".join([pipes.quote(arg) for arg in sys.argv])
+        # NB the next line assumes that no process can ever get two steps ahead of another process.
+        # This is not actually enforced here; it's enforced by a barrier inside the synchronized_iterate
+        # function below. See also test_parallel_tasks.py::test_overtaking_synchronized_loop
+        if previous_job is not None and (not self._is_still_running_somewhere(previous_job)):
+            self.mark_complete(previous_job)
 
-            current_iteration_state = IterationState.from_context(req_jobs, argv=" ".join(sys.argv),
-                                                                  stack_hash=req_hash,
-                                                                  allow_resume=allow_resume)
+        return my_next_job
+
+_next_iteration_state_id = 0
+_iteration_states = {}
+
+
+class MessageStartIteration(message.BarrierMessageWithResponse):
+    def process_global(self):
+        global _next_iteration_state_id, _iteration_states
+        req_jobs, req_hash, allow_resume, synchronized = self.contents
+
+        argv_string = " ".join([pipes.quote(arg) for arg in sys.argv])
+
+        IteratorClass = SynchronizedIterationState if synchronized else IterationState
+
+        my_id = _next_iteration_state_id
+        _iteration_states[my_id] = IteratorClass.from_context(req_jobs, argv=argv_string,
+                                                              stack_hash=req_hash,
+                                                              allow_resume=allow_resume)
+        _next_iteration_state_id += 1
+
+        self.respond(my_id)
+
+    def assert_consistent(self, other):
+        assert type(self) == type(other)
+        self_njobs = self.contents[0]
+        other_njobs = other.contents[0]
+        if self_njobs != other_njobs:
+            raise InconsistentJobList("Inconsistent number of jobs between different processes")
+        if self.contents != other.contents:
+            raise InconsistentContext("Inconsistency in requested loops between different processes")
 
-        else:
-            if len(current_iteration_state) != req_jobs:
-                raise InconsistentJobList(f"Number of jobs ({req_jobs}) expected by rank {self.source} "
-                                          f"is inconsistent with {len(current_iteration_state)}")
-            if current_stack_hash != req_hash:
-                raise InconsistentContext(f"Inconsistent stack from rank {self.source} when entering parallel loop")
-            if current_is_resumable != allow_resume:
-                raise InconsistentContext(f"Inconsistent allow_resume flag from rank {self.source} when entering parallel loop")
-
-class MessageDeliverJob(message.Message):
-    pass
 
 class MessageDistributeJobList(message.Message):
     def process(self):
@@ -171,10 +199,13 @@ def process(self):
         for rank in range(1, backend.size()):
             if rank != self.source:
                 MessageDistributeJobList(self.contents).send(rank)
 
-class MessageRequestJob(message.Message):
+
+class MessageRequestJob(message.MessageWithResponse):
     def process(self):
-        global current_iteration_state
+        iterator_id = self.contents
+        current_iteration_state = _iteration_states.get(iterator_id, None)
         source = self.source
+        assert current_iteration_state is not None # should not be requesting jobs if we are not in a loop
 
         job = current_iteration_state.next_job(source)
@@ -184,12 +215,12 @@ def process(self):
         else:
             log.logger.debug("Finished jobs; notify node %d", source)
 
-        MessageDeliverJob(job).send(source)
-
         if current_iteration_state.finished():
-            current_iteration_state = None
+            del _iteration_states[iterator_id]
 
-def parallel_iterate(task_list, allow_resume=False, resumption_id=None):
+        self.respond(job)
+
+def distributed_iterate(task_list, allow_resume=False, resumption_id=None):
     """Sets up an iterator returning items of task_list.
 
     If allow_resume is True, then the iterator will resume from the last point it reached
@@ -198,25 +229,54 @@ def parallel_iterate(task_list, allow_resume=False, resumption_id=None):
     """
     from . import backend, barrier
 
-    if resumption_id is None:
-        stack_string = "\n".join(traceback.format_stack())
-        # we need a hash of stack_string that is stable across runs.
-        resumption_id = hashlib.sha256(stack_string.encode('utf-8')).hexdigest()
+    resumption_id = resumption_id or _autogenerate_resume_id()
 
     assert backend is not None, "Parallelism is not initialised"
-    MessageStartIteration((len(task_list), resumption_id, allow_resume)).send(0)
+    iteration_id = MessageStartIteration((len(task_list), resumption_id, allow_resume, False)).send_and_get_response(0)
     barrier()
 
     while True:
-        MessageRequestJob().send(0)
-        job = MessageDeliverJob.receive(0).contents
-
+        job = MessageRequestJob(iteration_id).send_and_get_response(0)
         if job is None:
             barrier()
             return
         else:
             yield task_list[job]
 
+
+def _autogenerate_resume_id():
+    stack_string = "\n".join(traceback.format_stack())
+    # we need a hash of stack_string that is stable across runs
+    resumption_id = hashlib.sha256(stack_string.encode('utf-8')).hexdigest()
+    return resumption_id
+
+
+def synchronized_iterate(task_list, allow_resume=False, resumption_id=None):
+    """Like distributed_iterate, but all processes see all tasks.
+
+    The main advantage is the ability to resume if allow_resume is True"""
+    from . import backend, barrier
+
+    resumption_id = resumption_id or _autogenerate_resume_id()
+
+    assert backend is not None, "Parallelism is not initialised"
+
+    iteration_id = MessageStartIteration((len(task_list), resumption_id, allow_resume, True)).send_and_get_response(0)
+    barrier()
+
+    while True:
+        job = MessageRequestJob(iteration_id).send_and_get_response(0)
+        barrier() # this is crucial to keep things in sync (see comment in SynchronizedIterationState.next_job)
+        if job is None:
+            return
+
+        yield task_list[job]
+
+
 def generate_task_list_and_parallel_iterate(task_list_function, allow_resume=False):
     """Call task_list_function on only one rank, and then parallel iterate with all ranks"""
     from . import backend
@@ -231,4 +291,4 @@ def generate_task_list_and_parallel_iterate(task_list_function, allow_resume=Fal
         log.logger.debug("awaiting rank 1 generating task list")
         task_list = MessageDistributeJobList.receive(0).contents
         log.logger.debug("task_list = %r",task_list)
-    return parallel_iterate(task_list, allow_resume=allow_resume)
+    return distributed_iterate(task_list, allow_resume=allow_resume)
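A sketch of resumable iteration from user code, under the assumption that a previous run was interrupted part-way: on the rerun, completed jobs are skipped. The explicit resumption_id keeps the loop recognisable even when it is reached via a different call stack:

import tangos.parallel_tasks as pt

def _resumable():
    for i in pt.distributed(list(range(100)), allow_resume=True,
                            resumption_id="my-unique-loop-id"):
        print("working on job", i)  # placeholder for real per-item work

pt.use("multiprocessing-4")
pt.launch(_resumable)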
diff --git a/tangos/parallel_tasks/message.py b/tangos/parallel_tasks/message.py
index eb26ab2e..c46547f4 100644
--- a/tangos/parallel_tasks/message.py
+++ b/tangos/parallel_tasks/message.py
@@ -76,6 +76,51 @@ def receive(cls, source=None):
     def process(self):
         raise NotImplementedError("No process implemented for this message")
 
+class ServerResponseMessage(Message):
+    pass
+
+class MessageWithResponse(Message):
+    """An extension of the message class where the server can return a response to each process"""
+    def respond(self, response):
+        return ServerResponseMessage(response).send(self.source)
+
+    def send_and_get_response(self, destination):
+        self.send(destination)
+        return self.get_response(destination)
+
+    def get_response(self, receiving_from):
+        return ServerResponseMessage.receive(receiving_from).contents
+
+class BarrierMessageWithResponse(MessageWithResponse):
+    """An extension of the message class where the client blocks until all processes have made the request, and then the server responds"""
+    _current_barrier_message = None
+    def process(self):
+        from . import backend
+        if BarrierMessageWithResponse._current_barrier_message is None:
+            BarrierMessageWithResponse._current_barrier_message = self
+            BarrierMessageWithResponse._current_barrier_message._all_sources = [self.source]
+        else:
+            self.assert_consistent(BarrierMessageWithResponse._current_barrier_message)
+            assert self.source not in BarrierMessageWithResponse._current_barrier_message._all_sources
+            BarrierMessageWithResponse._current_barrier_message._all_sources.append(self.source)
+
+        if len(BarrierMessageWithResponse._current_barrier_message._all_sources) == backend.size()-1:
+            BarrierMessageWithResponse._current_barrier_message = None
+            self.process_global()
+
+    def process_global(self):
+        raise NotImplementedError("No process implemented for this message")
+
+    def assert_consistent(self, original_message):
+        assert type(self) == type(original_message)
+        assert self.contents == original_message.contents
+
+    def respond(self, response):
+        from . import backend
+        response = ServerResponseMessage(response)
+        for i in range(1, backend.size()):
+            response.send(i)
+
 class ExceptionMessage(Message):
     _is_exception = True
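Two reusable patterns fall out of these base classes: a per-worker round trip and an all-workers rendezvous. Hypothetical subclasses (the names here are invented for illustration):

from tangos.parallel_tasks import message

class QueryValueMessage(message.MessageWithResponse):
    """Hypothetical: one worker asks; the server answers that worker alone."""
    def process(self):  # runs on the server (rank 0)
        self.respond(42)

class PhaseGateMessage(message.BarrierMessageWithResponse):
    """Hypothetical: the reply is released only after every worker has sent
    this message; the inherited assert_consistent requires identical contents."""
    def process_global(self):  # runs once, when the last worker checks in
        self.respond("go")

# worker side (rank != 0):
#   value = QueryValueMessage().send_and_get_response(0)
#   signal = PhaseGateMessage("phase-1").send_and_get_response(0)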
diff --git a/tangos/parallel_tasks/shared_set.py b/tangos/parallel_tasks/shared_set.py
index b7a4e3f5..e802547d 100644
--- a/tangos/parallel_tasks/shared_set.py
+++ b/tangos/parallel_tasks/shared_set.py
@@ -1,22 +1,16 @@
-from .message import Message
+from .message import MessageWithResponse
 
 _remote_sets = {}
 
-class RemoteSetOperation(Message):
+class RemoteSetOperation(MessageWithResponse):
     def process(self):
         set_id, operation, value = self.contents
         global _remote_sets
         if operation=="add-if-not-exists":
             result = LocalSet(set_id).add_if_not_exists(value)
-            self.reply(result)
+            self.respond(result)
         else:
             raise ValueError("Unknown operation %s" % operation)
 
-    def reply(self, result):
-        RemoteSetResult(result).send(self.source)
-
-class RemoteSetResult(Message):
-    pass
-
 class SharedSet:
     def __new__(cls, set_id, allow_parallel=False):
         if cls is SharedSet:
@@ -46,9 +40,8 @@ def __init__(self, set_id, allow_parallel=False):
 
     def add_if_not_exists(self, value):
         """Adds to the set, and returns a boolean indicating whether the value was already present"""
-        RemoteSetOperation((self.set_id, "add-if-not-exists", value)).send(0)
-        result = RemoteSetResult.receive(0).contents
-        return result
+        return RemoteSetOperation((self.set_id, "add-if-not-exists", value)).send_and_get_response(0)
+
 
 class LocalSet(SharedSet):
     def __init__(self, set_id, allow_parallel=False):
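Client code for shared sets is unchanged; a usage sketch (inside a launched parallel context), exploiting the documented return value of add_if_not_exists:

from tangos.parallel_tasks.shared_set import SharedSet

claimed = SharedSet("claimed-work", allow_parallel=True)

# add_if_not_exists returns whether the value was already present, so the
# first rank to claim an item gets False and does the work; others skip it
if not claimed.add_if_not_exists("halo_531"):
    print("this rank processes halo_531")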
diff --git a/tangos/parallel_tasks/testing.py b/tangos/parallel_tasks/testing.py
index e006170b..a774b14b 100644
--- a/tangos/parallel_tasks/testing.py
+++ b/tangos/parallel_tasks/testing.py
@@ -6,9 +6,14 @@ def initialise_log():
     with open(FILENAME, "w") as f:
         f.write("")
 
-def get_log():
+def get_log(remove_process_ids=False):
+    if remove_process_ids:
+        processor = lambda s: s.strip()[4:]
+    else:
+        processor = lambda s: s.strip()
+
     with open(FILENAME) as f:
-        return f.readlines()
+        return [processor(s) for s in f.readlines()]
 
 def log(message):
     ServerLogMessage(message).send(0)
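The four-character slice corresponds to the "[N] " rank prefix that the log server prepends (compare the assertions on "[{r}] Doing task {i}" in the tests below). A quick illustration of the two modes, assuming a two-rank log:

raw = ["[1] Doing task 0\n", "[2] Doing task 0\n"]

# get_log() default: whitespace is stripped, rank prefixes retained
assert [s.strip() for s in raw] == ["[1] Doing task 0", "[2] Doing task 0"]

# get_log(remove_process_ids=True): rank prefixes dropped too, so output
# from different ranks compares equal
assert [s.strip()[4:] for s in raw] == ["Doing task 0"] * 2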
diff --git a/tangos/tools/property_writer.py b/tangos/tools/property_writer.py
index 83bae064..2e426b2a 100644
--- a/tangos/tools/property_writer.py
+++ b/tangos/tools/property_writer.py
@@ -131,16 +131,18 @@ def files(self):
 
 
     def _get_parallel_timestep_iterator(self):
-        if self.options.part is not None:
-            # In the case of a null backend with manual parallelism, pass the specified part specification
-            ma_files = parallel_tasks.distributed(self.files, proc=self.options.part[0], of=self.options.part[1])
+        if parallel_tasks.backend is None:
+            # Go sequentially
+            ma_files = self.files
         elif self.options.load_mode is not None and self.options.load_mode.startswith('server'):
             # In the case of loading from a centralised server, each node works on the _same_ timestep --
             # parallelism is then implemented at the halo level
-            ma_files = self.files
+            ma_files = parallel_tasks.synchronized(self.files, allow_resume=not self.options.no_resume,
+                                                   resumption_id='parallel-timestep-iterator')
         else:
             # In all other cases, different timesteps are distributed to different nodes
-            ma_files = parallel_tasks.distributed(self.files, allow_resume=not self.options.no_resume)
+            ma_files = parallel_tasks.distributed(self.files, allow_resume=not self.options.no_resume,
+                                                  resumption_id='parallel-timestep-iterator')
         return ma_files
 
     def _get_parallel_halo_iterator(self, items):
@@ -154,7 +156,7 @@ def _get_parallel_halo_iterator(self, items):
             # before all nodes have generated their local work lists
             parallel_tasks.barrier()
 
-            return parallel_tasks.distributed(items)
+            return parallel_tasks.distributed(items, allow_resume=False)
         else:
             return items
 
@@ -183,8 +185,8 @@ def _compile_inclusion_criterion(self):
         else:
             self._include = None
 
-    def _log_one_process(self, *args):
-        if parallel_tasks.backend is None or parallel_tasks.backend.rank()==1:
+    def _log_once_per_timestep(self, *args):
+        if parallel_tasks.backend is None or parallel_tasks.backend.rank()==1 or self.options.load_mode is None:
             logger.info(*args)
 
     def _summarise_timing(self):
@@ -223,8 +225,8 @@ def _build_halo_list(self, db_timestep):
         # perform filtering:
         halos = [halo_i for halo_i, include_i in zip(halos, inclusion) if include_i]
 
-        self._log_one_process("User-specified inclusion criterion excluded %d of %d halos",
-                              len(inclusion)-len(halos),len(inclusion))
+        self._log_once_per_timestep("User-specified inclusion criterion excluded %d of %d halos",
+                                    len(inclusion) - len(halos), len(inclusion))
 
         return halos
 
@@ -477,7 +479,7 @@ def run_halo_calculation(self, db_halo, existing_properties):
 
 
     def run_timestep_calculation(self, db_timestep):
-        logger.info("Processing %r", db_timestep)
+        self._log_once_per_timestep("Processing %r", db_timestep)
         self._property_calculator_instances = properties.instantiate_classes(db_timestep.simulation,
                                                                              self.options.properties,
                                                                              explain=self.options.explain_classes)
@@ -491,24 +493,24 @@ def run_timestep_calculation(self, db_timestep):
 
         self._existing_properties_all_halos = self._build_existing_properties_all_halos(db_halos)
 
-        self._log_one_process("Successfully gathered existing properties; calculating halo properties now...")
+        self._log_once_per_timestep("Successfully gathered existing properties; calculating halo properties now...")
 
-        self._log_one_process("  %d halos to consider; %d calculation routines for each of them, resulting in %d properties per halo",
-                              len(db_halos), len(self._property_calculator_instances),
-                              sum([1 if isinstance(x.names, str) else len(x.names) for x in self._property_calculator_instances])
-                              )
+        self._log_once_per_timestep("  %d halos to consider; %d calculation routines for each of them, resulting in %d properties per halo",
+                                    len(db_halos), len(self._property_calculator_instances),
+                                    sum([1 if isinstance(x.names, str) else len(x.names) for x in self._property_calculator_instances])
+                                    )
 
-        self._log_one_process("  The property modules are:")
+        self._log_once_per_timestep("  The property modules are:")
         for x in self._property_calculator_instances:
             x_type = type(x)
-            self._log_one_process(f"    {x_type.__module__}.{x_type.__qualname__}")
+            self._log_once_per_timestep(f"    {x_type.__module__}.{x_type.__qualname__}")
 
         for db_halo, existing_properties in \
                 self._get_parallel_halo_iterator(list(zip(db_halos, self._existing_properties_all_halos))):
             self._existing_properties_this_halo = existing_properties
             self.run_halo_calculation(db_halo, existing_properties)
 
-        logger.info("Done with %r",db_timestep)
+        self._log_once_per_timestep("Done with %r", db_timestep)
 
         self._unload_timestep()
 
         self.tracker.report_to_log_or_server(logger)
@@ -528,8 +530,8 @@ def _add_prerequisites_to_calculator_instances(self, db_timestep):
             for r in requirements:
                 if r not in will_calculate:
                     new_instance = properties.instantiate_class(db_timestep.simulation, r)
-                    self._log_one_process("Missing prerequisites - added class %r",type(new_instance))
-                    self._log_one_process("    providing properties %r",new_instance.names)
+                    self._log_once_per_timestep("Missing prerequisites - added class %r", type(new_instance))
+                    self._log_once_per_timestep("    providing properties %r", new_instance.names)
                     self._property_calculator_instances = [new_instance]+self._property_calculator_instances
                     self._add_prerequisites_to_calculator_instances(db_timestep) # everything has changed; start afresh
                     break
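The writer's two-level parallelism distils to the following sketch (hypothetical stand-in lists for timesteps and halos): with a server-type load mode, every rank walks the same timesteps together and the halos within each timestep are distributed; otherwise whole timesteps are distributed and each rank handles all halos of its own timesteps.

import tangos.parallel_tasks as pt

def _writer_like_loop(server_load_mode):
    timesteps = ["ts1", "ts2"]
    halos = ["h1", "h2", "h3", "h4"]
    if server_load_mode:
        for ts in pt.synchronized(timesteps, allow_resume=True,
                                  resumption_id='parallel-timestep-iterator'):
            for h in pt.distributed(halos, allow_resume=False):
                pass  # calculate properties for one halo
    else:
        for ts in pt.distributed(timesteps, allow_resume=True,
                                 resumption_id='parallel-timestep-iterator'):
            for h in halos:
                pass  # every halo of a locally-owned timestep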
diff --git a/tests/test_db_writer.py b/tests/test_db_writer.py
index 42c3c6ec..7bd43071 100644
--- a/tests/test_db_writer.py
+++ b/tests/test_db_writer.py
@@ -125,9 +125,13 @@ def test_basic_writing(fresh_database):
     run_writer_with_args("dummy_property")
     _assert_properties_as_expected()
 
-def test_parallel_writing(fresh_database):
+@pytest.mark.parametrize('load_mode', [None, 'server'])
+def test_parallel_writing(fresh_database, load_mode):
     parallel_tasks.use('multiprocessing-2')
-    run_writer_with_args("dummy_property", parallel=True)
+    if load_mode is None:
+        run_writer_with_args("dummy_property", parallel=True)
+    else:
+        run_writer_with_args("dummy_property", "--load-mode="+load_mode, parallel=True)
     _assert_properties_as_expected()
diff --git a/tests/test_parallel_tasks.py b/tests/test_parallel_tasks.py
index 3e1b7641..e7db6c0e 100644
--- a/tests/test_parallel_tasks.py
+++ b/tests/test_parallel_tasks.py
@@ -41,6 +41,23 @@ def test_add_property():
         assert tangos.get_halo(i)['my_test_property']==i
 
 
+def _test_barrier():
+    if pt.backend.rank()==1:
+        # only sleep on one process, to check barrier works
+        time.sleep(0.3)
+    pt_testing.log("Before barrier")
+    pt.barrier()
+    pt_testing.log("After barrier")
+
+def test_barrier():
+    pt.use("multiprocessing-3")
+    pt_testing.initialise_log()
+    pt.launch(_test_barrier)
+    log = pt_testing.get_log(remove_process_ids=True)
+    assert log == ["Before barrier"]*2 + ["After barrier"]*2
+
+
 def _add_two_properties_different_ranges():
     for i in pt.distributed(list(range(1,10))):
@@ -106,83 +123,164 @@ def test_inconsistent_loops_rejected():
         pt.launch(_test_loops_different_backtrace)
 
 
-def _test_resume_loop(attempt):
-    for i in pt.distributed(list(range(10)), allow_resume=True, resumption_id=101):
-        with pt.ExclusiveLock("write_to_file", 0.01):
-            with open("test_resume_loop.log", "a") as f:
-                f.write(f"Start job {i}\n")
+def _test_synchronized_loop():
+    for i in pt.synchronized(list(range(10))):
+        pt_testing.log(f"Doing task {i}")
+
+def test_synchronized_loop():
+    pt.use('multiprocessing-3')
+    pt_testing.initialise_log()
+    pt.launch(_test_synchronized_loop)
+    log = pt_testing.get_log()
+    assert len(log) == 20
+    for i in range(10):
+        for r in (1,2):
+            assert log.count(f"[{r}] Doing task {i}") == 1
+
+
+def _test_resume_loop(attempt, mode='distributed'):
+    # an explicit resumption_id must be provided because the stack trace
+    # differs between the original and the resumed run
+    if mode=='distributed':
+        iterator = pt.distributed(list(range(10)), allow_resume=True, resumption_id=1)
+    elif mode=='synchronized':
+        iterator = pt.synchronized(list(range(10)), allow_resume=True, resumption_id=2)
+    else:
+        raise ValueError("Unknown test mode")
+
+    for i in iterator:
+        pt_testing.log(f"Start job {i}")
+        pt.barrier() # make sure start is logged before kicking up a fuss
 
         if i==5 and attempt==0:
             raise ValueError("Suspend processing")
 
         pt.barrier()
-        time.sleep(0.02)
-        logger.info(f"Finished job {i}")
-        with pt.ExclusiveLock("write_to_file", 0.01):
-            with open("test_resume_loop.log", "a") as f:
-                f.write(f"Finish job {i}\n")
+        pt_testing.log(f"Finish job {i}")
 
-    pt.barrier()
 
-def test_resume_loop():
+@pytest.mark.parametrize("mode", ("distributed", "synchronized"))
+def test_resume_loop(mode):
     pt.use("multiprocessing-3")
     pt.jobs.IterationState.clear_resume_state()
-    try:
-        os.unlink("test_resume_loop.log")
-    except FileNotFoundError:
-        pass
+    pt_testing.initialise_log()
 
     with pytest.raises(ValueError):
-        pt.launch(_test_resume_loop, args=(0,))
+        pt.launch(_test_resume_loop, args=(0, mode))
 
-    with open("test_resume_loop.log") as f:
-        lines = [s.strip() for s in f.readlines()]
+    lines = pt_testing.get_log(remove_process_ids=True)
 
-    for i in range(6):
-        assert lines.count(f"Start job {i}") == 1
-        if i<=3:
-            assert lines.count(f"Finish job {i}") == 1
-        else:
-            assert lines.count(f"Finish job {i}") == 0
+    expected_when_distributed = [
+        (f"Start job {i}", 1) for i in range(6)
+    ] + [
+        (f"Finish job {i}", 1) for i in range(4)
+    ] + [
+        (f"Finish job {i}", 0) for i in range(4,6)
+    ]
+
+    expected_when_synchronized = [
+        (f"Start job {i}", 2) for i in range(6)
+    ] + [
+        (f"Finish job {i}", 2) for i in range(5)
+    ] + [
+        ("Finish job 5", 0)
+    ]
 
-    pt.launch(_test_resume_loop, args=(1,))
+    expected = expected_when_distributed if mode=='distributed' else expected_when_synchronized
 
-    with open("test_resume_loop.log") as f:
-        lines = [s.strip() for s in f.readlines()]
+    for line, count in expected:
+        assert lines.count(line) == count
 
-    for i in range(4):
-        assert lines.count(f"Start job {i}") == 1
-        assert lines.count(f"Finish job {i}") == 1
+    pt_testing.initialise_log()
 
-    for i in range(4,6):
-        assert lines.count(f"Start job {i}") == 2 # started twice
-        assert lines.count(f"Finish job {i}") == 1 # finished once
+    pt.launch(_test_resume_loop, args=(1, mode))
 
-    for i in range(7,10):
-        assert lines.count(f"Start job {i}") == 1
-        assert lines.count(f"Finish job {i}") == 1
+    lines = pt_testing.get_log(remove_process_ids=True)
 
+    expected_when_distributed = [
+        (f"Start job {i}", 1) for i in range(4, 10)
+    ] + [
+        (f"Finish job {i}", 1) for i in range(4, 10)
+    ] + [
+        (f"Start job {i}", 0) for i in range(4)
+    ]
 
-def _test_empty_loop():
-    for _ in pt.distributed([]):
-        assert False
+    expected_when_synchronized = [
+        (f"Start job {i}", 2) for i in range(5, 10)
+    ] + [
+        (f"Finish job {i}", 2) for i in range(5, 10)
+    ] + [
+        (f"Start job {i}", 0) for i in range(5)
+    ]
 
+    expected = expected_when_distributed if mode == 'distributed' else expected_when_synchronized
 
-def test_empty_loop():
-    pt.use("multiprocessing-3")
-    pt.launch(_test_empty_loop)
+    for line, count in expected:
+        assert lines.count(line) == count
 
 
-def _test_empty_then_non_empty_loop():
-    for _ in pt.distributed([]):
-        pass
+def _test_empty_loop(mode):
+    if mode=='distributed':
+        for _ in pt.distributed([]):
+            assert False
+    elif mode=='synchronized':
+        for _ in pt.synchronized([]):
+            assert False
+    else:
+        raise ValueError("Unknown test mode")
 
-    for _ in pt.distributed([1,2,3]):
-        pass
+@pytest.mark.parametrize("mode", ("distributed", "synchronized"))
+def test_empty_loop(mode):
+    pt.use("multiprocessing-3")
+    pt.launch(_test_empty_loop, args=(mode,))
 
-def test_empty_then_non_empty_loop():
+def _test_empty_then_non_empty_loop(mode):
+    if mode=='distributed':
+        iterator = pt.distributed
+    elif mode=='synchronized':
+        iterator = pt.synchronized
+    else:
+        raise ValueError("Unknown test mode")
+
+    for _ in iterator([]):
+        pt_testing.log("Should not appear")
+
+    for i in iterator([1,2,3]):
+        pt_testing.log(f"Doing task {i}")
+
+@pytest.mark.parametrize("mode", ("distributed", "synchronized"))
+def test_empty_then_non_empty_loop(mode):
     pt.use("multiprocessing-3")
-    pt.launch(_test_empty_then_non_empty_loop)
+    pt_testing.initialise_log()
+    pt.launch(_test_empty_then_non_empty_loop, args=(mode,))
+    log = pt_testing.get_log(remove_process_ids=True)
+    if mode=='distributed':
+        assert len(log)==3
+        assert "Doing task 3" in log
+        assert "Should not appear" not in log
+    elif mode=='synchronized':
+        assert len(log)==6
+        assert log.count("Doing task 3")==2
+        assert "Should not appear" not in log
+
+
+def _test_overtaking_synchronized_loop():
+    for i in pt.synchronized([0,1,2]):
+        pt_testing.log(f"Doing task {i}")
+        if pt.backend.rank()==1:
+            time.sleep(0.02)
+
+def test_overtaking_synchronized_loop():
+    # test that iterations stay synchronized even if one process tries to overtake the other
+    pt.use("multiprocessing-3")
+    pt_testing.initialise_log()
+    pt.launch(_test_overtaking_synchronized_loop)
+    log = pt_testing.get_log()
+    assert len(log)==6
 
 def _test_synchronize_db_creator():
     rank = pt.backend.rank()
@@ -204,6 +302,23 @@ def test_synchronize_db_creator():
     assert creator_1==creator_2
 
 
+def _test_nested_loop():
+    for i in pt.synchronized(list(range(3))):
+        for j in pt.distributed(list(range(3))):
+            pt_testing.log(f"Task {i},{j}")
+
+def test_nested_loop():
+    pt.use("multiprocessing-3")
+    pt_testing.initialise_log()
+    pt.launch(_test_nested_loop)
+
+    log = pt_testing.get_log(remove_process_ids=True)
+
+    for i in range(3):
+        for j in range(3):
+            assert log.count(f"Task {i},{j}")==1
+
+
 def _test_shared_locks():
     if pt.backend.rank()==1: