Improvements to loop tracking/resuming #246

Merged: 8 commits, Dec 15, 2023
43 changes: 25 additions & 18 deletions tangos/parallel_tasks/__init__.py
@@ -84,25 +84,36 @@ def launch(function, args=None, backend_kwargs=None):
 
     return result
 
-def distributed(file_list, proc=None, of=None, allow_resume=False, resumption_id=None):
-    """Distribute a list of tasks between all nodes"""
+def distributed(items, allow_resume=False, resumption_id=None):
+    """Return an iterator that consumes the items, distributed across all processors
+    (i.e. each item is consumed by only one processor, in a dynamic way).
 
-    if type(file_list) == set:
-        file_list = list(file_list)
+    Optionally, if allow_resume is True, then the iterator will resume from the last point it reached
+    provided argv and the stack trace are unchanged. If resumption_id is not None, then
+    the stack trace is ignored and only resumption_id needs to match."""
+
+    if type(items) == set:
+        items = list(items)
+
+    if _backend_name=='null':
+        return items
+    else:
+        from . import jobs
+        return jobs.distributed_iterate(items, allow_resume, resumption_id)
+
+def synchronized(items, allow_resume=False, resumption_id=None):
+    """Return an iterator that consumes all items on all processors.
+
+    Optionally, if allow_resume is True, then the iterator will resume from the last point it reached
+    provided argv and the stack trace are unchanged. If resumption_id is not None, then
+    the stack trace is ignored and only resumption_id needs to match."""
     if _backend_name=='null':
-        if proc is None:
-            proc = 1
-            of = 1
-        i = (len(file_list) * (proc - 1)) // of
-        j = (len(file_list) * proc) // of - 1
-        assert proc <= of and proc > 0
-        if proc == of:
-            j += 1
-        return file_list[i:j + 1]
+        return items
     else:
         from . import jobs
-        return jobs.parallel_iterate(file_list, allow_resume, resumption_id)
+        return jobs.synchronized_iterate(items, allow_resume, resumption_id)
 
 
 
 
 def _exec_function_or_server(function, connection_info, args):
@@ -129,11 +140,7 @@ def _server_thread():
     # available job and move on. Jobs are labelled through the
     # provided iterator
 
-    j = -1
-    num_jobs = None
-    current_job = None
     alive = [True for i in range(backend.size())]
-    awaiting_barrier = [False for i in range(backend.size())]
 
     while any(alive[1:]):
         obj = message.Message.receive()
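The two new entry points above replace the old static `proc`/`of` slicing with dynamically scheduled, resumable iterators. A minimal usage sketch (the backend string, item list and driver function are hypothetical; `use`, `launch`, `distributed` and `synchronized` are the package-level API shown in this diff):

```python
# Hypothetical driver script illustrating the new iterator API.
import tangos.parallel_tasks as pt

def process_items():
    items = ["item_%d" % i for i in range(16)]  # invented work items

    # Each item is handed to exactly one processor, assigned dynamically;
    # with allow_resume=True an interrupted run restarts where it left off.
    for item in pt.distributed(items, allow_resume=True):
        print("one rank handles", item)

    # Every processor iterates over every item, with progress still tracked
    # so that a resumed run skips iterations already completed everywhere.
    for item in pt.synchronized(items, allow_resume=True):
        print("all ranks handle", item)

if __name__ == "__main__":
    pt.use("multiprocessing-4")  # assumed backend selection call
    pt.launch(process_items)
```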
44 changes: 11 additions & 33 deletions tangos/parallel_tasks/accumulative_statistics.py
@@ -3,47 +3,27 @@
 
 from ..config import PROPERTY_WRITER_PARALLEL_STATISTICS_TIME_BETWEEN_UPDATES
 from ..log import logger
-from .message import Message
+from .message import BarrierMessageWithResponse, Message
 
-_new_accumulator_requested_for_ranks = []
-_new_accumulator = None
 _existing_accumulators = []
-class CreateNewAccumulatorMessage(Message):
+class CreateNewAccumulatorMessage(BarrierMessageWithResponse):
 
-    def process(self):
-        from . import backend
-        global _new_accumulator, _new_accumulator_requested_for_ranks, _existing_accumulators
-        assert issubclass(self.contents, StatisticsAccumulatorBase)
-        if _new_accumulator is None:
-            _new_accumulator = self.contents()
-            _new_accumulator_requested_for_ranks = [self.source]
-        else:
-            assert self.source not in _new_accumulator_requested_for_ranks
-            assert isinstance(_new_accumulator, self.contents)
-            _new_accumulator_requested_for_ranks.append(self.source)
-
-        from . import backend
-
-        if len(_new_accumulator_requested_for_ranks) == backend.size()-1:
-            self._confirm_new_accumulator()
-
-    def _confirm_new_accumulator(self):
-        global _new_accumulator, _new_accumulator_requested_for_ranks, _existing_accumulators
+    def process_global(self):
         from . import backend, on_exit_parallelism
 
+        new_accumulator = self.contents()
         accumulator_id = len(_existing_accumulators)
-        _existing_accumulators.append(_new_accumulator)
+        _existing_accumulators.append(new_accumulator)
 
-        locally_bound_accumulator = _new_accumulator
-        logger.debug("Created new accumulator of type %s with id %d" % (locally_bound_accumulator.__class__.__name__, accumulator_id))
+        locally_bound_accumulator = new_accumulator
+        logger.debug("Created new accumulator of type %s with id %d" % (
+            locally_bound_accumulator.__class__.__name__, accumulator_id))
         on_exit_parallelism(lambda: locally_bound_accumulator.report_to_log_if_needed(logger, 0.05))
 
-        _new_accumulator = None
-        _new_accumulator_requested_for_ranks = []
+        self.respond(accumulator_id)
 
-        for destination in range(1, backend.size()):
-            AccumulatorIdMessage(accumulator_id).send(destination)
-class AccumulatorIdMessage(Message):
-    pass
 class AccumulateStatisticsMessage(Message):
     def process(self):
         global _existing_accumulators
@@ -63,9 +43,7 @@ def __init__(self, allow_parallel=False):
         self._parallel = allow_parallel and parallelism_is_active() and backend.rank() != 0
         if self._parallel:
             logger.debug(f"Registering {self.__class__}")
-            CreateNewAccumulatorMessage(self.__class__).send(0)
-            logger.debug(f"Awaiting accumulator id for {self.__class__}")
-            self.id = AccumulatorIdMessage.receive(0).contents
+            self.id = CreateNewAccumulatorMessage(self.__class__).send_and_get_response(0)
             logger.debug(f"Received accumulator id={ self.id}")
 
     def report_to_server(self):
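The rewrite above collapses the old three-way handshake (count requests per rank, then broadcast `AccumulatorIdMessage` replies) into a single `BarrierMessageWithResponse` round trip. A hedged sketch with an invented message class; only the flow (client `send_and_get_response`, server `process_global` plus `respond`) mirrors `CreateNewAccumulatorMessage` above:

```python
# Invented example class; the message flow follows this diff.
from tangos.parallel_tasks.message import BarrierMessageWithResponse

_registry = []  # server-side state, analogous to _existing_accumulators

class RegisterObjectMessage(BarrierMessageWithResponse):
    def process_global(self):
        # Runs once on the server after every rank has sent the message;
        # respond() delivers the same answer back to each waiting rank.
        _registry.append(self.contents())
        self.respond(len(_registry) - 1)  # id of the newly created object

# On each worker rank, registration is now a single blocking call:
#     my_id = RegisterObjectMessage(SomeClass).send_and_get_response(0)
```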
24 changes: 4 additions & 20 deletions tangos/parallel_tasks/barrier.py
@@ -1,29 +1,13 @@
 from . import message
 
-awaiting_barrier = None
-
-
-class MessageBarrierPass(message.Message):
-    pass
-
-class MessageBarrier(message.Message):
-    def process(self):
-        from . import backend
-        global awaiting_barrier
-        if awaiting_barrier is None:
-            awaiting_barrier = [False for i in range(backend.size())]
-
-        awaiting_barrier[self.source] = True
-        if all(awaiting_barrier[1:]):
-            for i in range(1, backend.size()):
-                MessageBarrierPass().send(i)
-            awaiting_barrier = [False for i in range(backend.size())]
-
+class SimpleBarrierMessage(message.BarrierMessageWithResponse):
+    def process_global(self):
+        self.respond(None)
 
 def barrier():
     from . import backend, parallelism_is_active
     if not parallelism_is_active():
         return
     assert backend.rank()!=0, "The server process cannot take part in a barrier"
-    MessageBarrier().send(0)
-    MessageBarrierPass.receive(0)
+    SimpleBarrierMessage().send_and_get_response(0) # awaits response which only comes when all processes reach barrier
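A barrier is now just the degenerate case of the same pattern: every worker rank blocks in `send_and_get_response`, and the server's `process_global` (hence the response) fires only once all of them have arrived. A sketch of the guarantee this gives calling code, assuming `barrier` is callable at package level (as the `from . import backend, barrier` usage in jobs.py suggests); the staging functions are invented:

```python
# Sketch only: illustrates what barrier() promises to per-rank code.
from tangos.parallel_tasks import barrier

def do_stage_one():
    print("stage one on this rank")  # stand-in for real per-rank work

def do_stage_two():
    print("stage two on this rank")

def staged_work():
    do_stage_one()
    barrier()       # blocks until every worker rank has reached this line
    do_stage_two()  # safe to assume stage one finished on all ranks
```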
10 changes: 3 additions & 7 deletions tangos/parallel_tasks/database.py
@@ -4,13 +4,10 @@
 from . import message, remote_import
 
 
-class MessageRequestCreatorId(message.Message):
+class MessageRequestCreatorId(message.MessageWithResponse):
     def process(self):
         creator_id = core.creator.get_creator_id()
-        MessageDeliverCreatorId(creator_id).send(self.source)
-
-class MessageDeliverCreatorId(message.Message):
-    pass
+        self.respond(creator_id)
 
 
 def synchronize_creator_object(session=None):
@@ -28,6 +25,5 @@ def synchronize_creator_object(session=None):
         return
 
     remote_import.ImportRequestMessage(__name__).send(0)
-    MessageRequestCreatorId().send(0)
-    id = MessageDeliverCreatorId.receive(0).contents
+    id = MessageRequestCreatorId().send_and_get_response(0)
     core.creator.set_creator(session.query(core.creator.Creator).filter_by(id=id).first())
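Note the contrast with the accumulator change: `MessageRequestCreatorId` derives from plain `MessageWithResponse`, so `process` runs as soon as the single request arrives and only the sender gets a reply; no barrier is involved. An invented illustration of this request/response variant:

```python
# Invented example; the flow follows MessageRequestCreatorId above.
from tangos.parallel_tasks import message

class QueryServerValue(message.MessageWithResponse):
    def process(self):
        # Runs on the server for this one request; the reply goes back
        # to self.source, i.e. only to the rank that asked.
        self.respond(42)  # placeholder answer

# On any worker rank, independently of the others:
#     value = QueryServerValue().send_and_get_response(0)
```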
146 changes: 103 additions & 43 deletions tangos/parallel_tasks/jobs.py
@@ -129,40 +129,68 @@ def __eq__(self, other):
         return self._jobs_complete == other._jobs_complete
 
 
-current_num_jobs = None
-current_iteration_state = None
-current_stack_hash = None
-current_is_resumable = False
+class SynchronizedIterationState(IterationState):
+    def _first_incomplete_job_after(self, job):
+        if job is None:
+            job = -1 # so that we scan over all possible jobs when the loop first enters
+        for i in range(job+1, len(self._jobs_complete)):
+            if not self._jobs_complete[i]:
+                return i
+        return None
+
+    def _is_still_running_somewhere(self, job):
+        for v in self._rank_running_job.values():
+            if v == job:
+                return True
+        return False
 
-class MessageStartIteration(message.Message):
-    def process(self):
-        global current_iteration_state, current_stack_hash, current_is_resumable
-        req_jobs, req_hash, allow_resume = self.contents
+    def next_job(self, for_rank):
+        previous_job = self._rank_running_job[for_rank]
+        my_next_job = self._first_incomplete_job_after(previous_job)
+        if my_next_job is None:
+            del self._rank_running_job[for_rank]
+        else:
+            self._rank_running_job[for_rank] = my_next_job
 
-        if current_iteration_state is None:
-            current_num_jobs = req_jobs
-            current_stack_hash = req_hash
-            current_is_resumable = allow_resume
-            # convert sys.argv to a string which is quoted/escaped as needed
-            argv_string = " ".join([pipes.quote(arg) for arg in sys.argv])
+        # NB the next line assumes that no process can ever get two steps ahead of another process
+        # This is not actually enforced here, it's enforced by a barrier inside the synchronized_iterate
+        # function below. See also test_parallel_tasks.py::test_overtaking_synchronized_loop
+        if previous_job is not None and (not self._is_still_running_somewhere(previous_job)):
+            self.mark_complete(previous_job)
 
-            current_iteration_state = IterationState.from_context(req_jobs, argv=" ".join(sys.argv),
-                                                                  stack_hash=req_hash,
-                                                                  allow_resume=allow_resume)
+        return my_next_job
+
+_next_iteration_state_id = 0
+_iteration_states = {}
+
+
+class MessageStartIteration(message.BarrierMessageWithResponse):
+    def process_global(self):
+        global _next_iteration_state_id, _iteration_states
+        req_jobs, req_hash, allow_resume, synchronized = self.contents
+
+        argv_string = " ".join([pipes.quote(arg) for arg in sys.argv])
+
+        IteratorClass = SynchronizedIterationState if synchronized else IterationState
+
+        my_id = _next_iteration_state_id
+        _iteration_states[my_id] = IteratorClass.from_context(req_jobs, argv=argv_string,
+                                                              stack_hash=req_hash,
+                                                              allow_resume=allow_resume)
+        _next_iteration_state_id += 1
+
+        self.respond(my_id)
+
+    def assert_consistent(self, other):
+        assert type(self) == type(other)
+        self_njobs = self.contents[0]
+        other_njobs = other.contents[0]
+        if self_njobs != other_njobs:
+            raise InconsistentJobList("Inconsistent number of jobs between different processes")
+        if self.contents != other.contents:
+            raise InconsistentContext("Inconsistency in requested loops between different processes")
 
-        else:
-            if len(current_iteration_state) != req_jobs:
-                raise InconsistentJobList(f"Number of jobs ({req_jobs}) expected by rank {self.source} "
-                                          f"is inconsistent with {len(current_iteration_state)}")
-            if current_stack_hash != req_hash:
-                raise InconsistentContext(f"Inconsistent stack from rank {self.source} when entering parallel loop")
-            if current_is_resumable != allow_resume:
-                raise InconsistentContext(f"Inconsistent allow_resume flag from rank {self.source} when entering parallel loop")
 
-class MessageDeliverJob(message.Message):
-    pass
 
 class MessageDistributeJobList(message.Message):
     def process(self):
@@ -171,10 +199,13 @@ def process(self):
         for rank in range(1, backend.size()):
             if rank != self.source:
                 MessageDistributeJobList(self.contents).send(rank)
-class MessageRequestJob(message.Message):
+
+class MessageRequestJob(message.MessageWithResponse):
     def process(self):
-        global current_iteration_state
+        iterator_id = self.contents
+        current_iteration_state = _iteration_states.get(iterator_id, None)
         source = self.source
 
         assert current_iteration_state is not None # should not be requesting jobs if we are not in a loop
 
         job = current_iteration_state.next_job(source)
@@ -184,12 +215,12 @@ def process(self):
         else:
             log.logger.debug("Finished jobs; notify node %d", source)
 
-        MessageDeliverJob(job).send(source)
-
         if current_iteration_state.finished():
-            current_iteration_state = None
+            del _iteration_states[iterator_id]
+
+        self.respond(job)
 
-def parallel_iterate(task_list, allow_resume=False, resumption_id=None):
+def distributed_iterate(task_list, allow_resume=False, resumption_id=None):
     """Sets up an iterator returning items of task_list.
 
     If allow_resume is True, then the iterator will resume from the last point it reached
@@ -198,25 +229,54 @@ def parallel_iterate(task_list, allow_resume=False, resumption_id=None):
     """
     from . import backend, barrier
 
-    if resumption_id is None:
-        stack_string = "\n".join(traceback.format_stack())
-        # we need a hash of stack_string that is stable across runs.
-        resumption_id = hashlib.sha256(stack_string.encode('utf-8')).hexdigest()
+    resumption_id = resumption_id or _autogenerate_resume_id()
 
     assert backend is not None, "Parallelism is not initialised"
-    MessageStartIteration((len(task_list), resumption_id, allow_resume)).send(0)
+    iteration_id = MessageStartIteration((len(task_list), resumption_id, allow_resume, False)).send_and_get_response(0)
     barrier()
 
     while True:
-        MessageRequestJob().send(0)
-        job = MessageDeliverJob.receive(0).contents
-
+        job = MessageRequestJob(iteration_id).send_and_get_response(0)
         if job is None:
             barrier()
             return
         else:
             yield task_list[job]
 
+
+def _autogenerate_resume_id():
+    stack_string = "\n".join(traceback.format_stack())
+    # we need a hash of stack_string that is stable across runs.
+    resumption_id = hashlib.sha256(stack_string.encode('utf-8')).hexdigest()
+    return resumption_id
+
+
+def synchronized_iterate(task_list, allow_resume=False, resumption_id=None):
+    """Like distributed_iterate, but all processes see all tasks.
+
+    The main advantage is the ability to resume if allow_resume is True"""
+    from . import backend, barrier
+
+    resumption_id = resumption_id or _autogenerate_resume_id()
+
+    assert backend is not None, "Parallelism is not initialised"
+
+    iteration_id = MessageStartIteration((len(task_list), resumption_id, allow_resume, True)).send_and_get_response(0)
+    barrier()
+
+    while True:
+        job = MessageRequestJob(iteration_id).send_and_get_response(0)
+        barrier() # this is crucial to keep things in sync (see comment in SynchronizedIterationState.next_job)
+        if job is None:
+            return
+
+        yield task_list[job]
+
+
+
 
 def generate_task_list_and_parallel_iterate(task_list_function, allow_resume=False):
     """Call task_list_function on only one rank, and then parallel iterate with all ranks"""
     from . import backend
@@ -231,4 +291,4 @@ def generate_task_list_and_parallel_iterate(task_list_function, allow_resume=False):
         log.logger.debug("awaiting rank 1 generating task list")
         task_list = MessageDistributeJobList.receive(0).contents
         log.logger.debug("task_list = %r",task_list)
-    return parallel_iterate(task_list, allow_resume=allow_resume)
+    return distributed_iterate(task_list, allow_resume=allow_resume)
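The `_autogenerate_resume_id` helper is what makes `allow_resume` safe by default: a loop only continues a previous run's progress if it is entered from the same call site (and, per the docstrings in `__init__.py`, with the same argv). A standalone sketch of that identity rule, following the code above:

```python
# Standalone sketch of the resumption-identity rule used above.
import hashlib
import traceback

def autogenerate_resume_id():
    # Hash of the current stack trace: stable across runs provided the loop
    # is entered from the same place in the same source files.
    stack_string = "\n".join(traceback.format_stack())
    return hashlib.sha256(stack_string.encode("utf-8")).hexdigest()

# Loops entered from different call sites get different ids, so their stored
# progress never collides; an explicit resumption_id bypasses the stack hash
# and matching is done on the supplied id alone.
```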