
Commit 55d4437

Introduce parallel_tasks synchronized loops (so that their state can be saved)
apontzen committed Dec 15, 2023
1 parent 4b852bc commit 55d4437
Showing 4 changed files with 210 additions and 61 deletions.
15 changes: 14 additions & 1 deletion tangos/parallel_tasks/__init__.py
@@ -99,7 +99,20 @@ def distributed(items, allow_resume=False, resumption_id=None):
return items
else:
from . import jobs
return jobs.parallel_iterate(items, allow_resume, resumption_id)
return jobs.distributed_iterate(items, allow_resume, resumption_id)

def synchronized(items, allow_resume=False, resumption_id=None):
"""Return an iterator that consumes all items on all processors.
Optionally, if allow_resume is True, then the iterator will resume from the last point it reached
provided argv and the stack trace are unchanged. If resumption_id is not None, then
the stack trace is ignored and only resumption_id needs to match."""
if _backend_name=='null':
return items
else:
from . import jobs
return jobs.synchronized_iterate(items, allow_resume, resumption_id)




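As background for the new entry point, here is a minimal usage sketch (not part of the commit) contrasting `pt.distributed` with the new `pt.synchronized`. It borrows the `'multiprocessing-3'` backend (rank 0 serving jobs, ranks 1 and 2 doing work) and the launch pattern from tests/test_parallel_tasks.py below:

```python
import tangos.parallel_tasks as pt

def _work():
    # distributed: each item is handed to exactly one worker rank
    for i in pt.distributed(list(range(10)), allow_resume=True):
        print(f"rank-exclusive work on item {i}")

    # synchronized: every rank iterates over every item in lockstep;
    # with allow_resume=True the loop position survives a crash
    for i in pt.synchronized(list(range(10)), allow_resume=True):
        print(f"all-ranks work on item {i}")

pt.use('multiprocessing-3')   # rank 0 serves jobs; ranks 1 and 2 run _work
pt.launch(_work)
```
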
82 changes: 73 additions & 9 deletions tangos/parallel_tasks/jobs.py
@@ -129,6 +129,33 @@ def __eq__(self, other):
return self._jobs_complete == other._jobs_complete


class SynchronizedIterationState(IterationState):
def _first_incomplete_job_after(self, job):
if job is None:
job = -1 # so that we scan over all possible jobs when the loop first enters
for i in range(job+1, len(self._jobs_complete)):
if not self._jobs_complete[i]:
return i
return None

def _is_still_running_somewhere(self, job):
for v in self._rank_running_job.values():
if v == job:
return True
return False

def next_job(self, for_rank):
previous_job = self._rank_running_job[for_rank]
my_next_job = self._first_incomplete_job_after(previous_job)
if my_next_job is None:
del self._rank_running_job[for_rank]
else:
self._rank_running_job[for_rank] = my_next_job

if not self._is_still_running_somewhere(previous_job):
self.mark_complete(previous_job)

return my_next_job

current_num_jobs = None
current_iteration_state = None
@@ -138,7 +165,7 @@ def __eq__(self, other):
class MessageStartIteration(message.Message):
def process(self):
global current_iteration_state, current_stack_hash, current_is_resumable
req_jobs, req_hash, allow_resume = self.contents
req_jobs, req_hash, allow_resume, synchronized = self.contents

if current_iteration_state is None:
current_num_jobs = req_jobs
@@ -147,9 +174,13 @@ def process(self):
# convert sys.argv to a string which is quoted/escaped as needed
argv_string = " ".join([pipes.quote(arg) for arg in sys.argv])

current_iteration_state = IterationState.from_context(req_jobs, argv=" ".join(sys.argv),
IteratorClass = SynchronizedIterationState if synchronized else IterationState

current_iteration_state = IteratorClass.from_context(req_jobs, argv=" ".join(sys.argv),
stack_hash=req_hash,
allow_resume=allow_resume)
# note: if running a 'synchronized' loop, only rank 1 will be requesting jobs, and the other ranks
# will be synchronized just via a simple barrier. This allows for a much simpler implementation.


else:
@@ -189,7 +220,7 @@ def process(self):
if current_iteration_state.finished():
current_iteration_state = None

def parallel_iterate(task_list, allow_resume=False, resumption_id=None):
def distributed_iterate(task_list, allow_resume=False, resumption_id=None):
"""Sets up an iterator returning items of task_list.
If allow_resume is True, then the iterator will resume from the last point it reached
@@ -198,13 +229,10 @@ def parallel_iterate(task_list, allow_resume=False, resumption_id=None):
"""
from . import backend, barrier

if resumption_id is None:
stack_string = "\n".join(traceback.format_stack())
# we need a hash of stack_string that is stable across runs.
resumption_id = hashlib.sha256(stack_string.encode('utf-8')).hexdigest()
resumption_id = resumption_id or _autogenerate_resume_id()

assert backend is not None, "Parallelism is not initialised"
MessageStartIteration((len(task_list), resumption_id, allow_resume)).send(0)
MessageStartIteration((len(task_list), resumption_id, allow_resume, False)).send(0)
barrier()

while True:
@@ -217,6 +245,42 @@ def parallel_iterate(task_list, allow_resume=False, resumption_id=None):
else:
yield task_list[job]


def _autogenerate_resume_id():
stack_string = "\n".join(traceback.format_stack())
# we need a hash of stack_string that is stable across runs.
resumption_id = hashlib.sha256(stack_string.encode('utf-8')).hexdigest()
return resumption_id


def synchronized_iterate(task_list, allow_resume=False, resumption_id=None):
"""Like distributed_iterate, but all processes see all tasks.
The main advantage is the ability to resume if allow_resume is True"""
from . import backend, barrier

resumption_id = resumption_id or _autogenerate_resume_id()

assert backend is not None, "Parallelism is not initialised"

MessageStartIteration((len(task_list), resumption_id, allow_resume, True)).send(0)
barrier()

while True:
MessageRequestJob().send(0)
job = MessageDeliverJob.receive(0).contents

if job is None:
barrier()
return

yield task_list[job]






def generate_task_list_and_parallel_iterate(task_list_function, allow_resume=False):
"""Call task_list_function on only one rank, and then parallel iterate with all ranks"""
from . import backend
@@ -231,4 +295,4 @@ def generate_task_list_and_parallel_iterate(task_list_function, allow_resume=False):
log.logger.debug("awaiting rank 1 generating task list")
task_list = MessageDistributeJobList.receive(0).contents
log.logger.debug("task_list = %r",task_list)
return parallel_iterate(task_list, allow_resume=allow_resume)
return distributed_iterate(task_list, allow_resume=allow_resume)
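
To illustrate the completion rule in SynchronizedIterationState above — a job is only flagged complete once no rank is still running it, so a resumed run restarts at the first job any rank had not finished — here is a hypothetical standalone model (not the real class, which inherits its bookkeeping from IterationState; the `{rank: None}` initial shape of `_rank_running_job` and the explicit `prev is not None` guard are assumptions of this sketch):

```python
class _MiniSyncState:
    def __init__(self, n_jobs, ranks):
        self._jobs_complete = [False] * n_jobs
        self._rank_running_job = {r: None for r in ranks}  # assumed initial shape

    def next_job(self, for_rank):
        prev = self._rank_running_job[for_rank]
        start = 0 if prev is None else prev + 1
        nxt = next((i for i in range(start, len(self._jobs_complete))
                    if not self._jobs_complete[i]), None)
        if nxt is None:
            del self._rank_running_job[for_rank]   # this rank is done
        else:
            self._rank_running_job[for_rank] = nxt
        if prev is not None and prev not in self._rank_running_job.values():
            self._jobs_complete[prev] = True       # last rank has moved past prev
        return nxt

state = _MiniSyncState(3, ranks=(1, 2))
assert state.next_job(1) == 0 and state.next_job(2) == 0
assert state.next_job(1) == 1      # rank 2 is still on job 0, so it stays open
assert state.next_job(2) == 1      # now job 0 is marked complete
assert state._jobs_complete == [True, False, False]
```
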
9 changes: 7 additions & 2 deletions tangos/parallel_tasks/testing.py
@@ -6,9 +6,14 @@ def initialise_log():
with open(FILENAME, "w") as f:
f.write("")

def get_log():
def get_log(remove_process_ids=False):
if remove_process_ids:
processor = lambda s: s.strip()[4:]
else:
processor = lambda s: s.strip()

with open(FILENAME) as f:
return f.readlines()
return [processor(s) for s in f.readlines()]

def log(message):
ServerLogMessage(message).send(0)
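The new remove_process_ids option strips a rank prefix by slicing four characters off each stripped line. Assuming the log format seen in the tests below ("[<rank>] <message>"), that behaves as follows:

```python
line = "[2] Doing task 7\n"
assert line.strip() == "[2] Doing task 7"   # remove_process_ids=False
assert line.strip()[4:] == "Doing task 7"   # remove_process_ids=True
# caveat: the fixed [4:] slice assumes single-digit ranks;
# "[10] Doing task 7".strip()[4:] would leave a stray leading space
```
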
165 changes: 116 additions & 49 deletions tests/test_parallel_tasks.py
@@ -106,82 +106,149 @@ def test_inconsistent_loops_rejected():
pt.launch(_test_loops_different_backtrace)


def _test_resume_loop(attempt):
for i in pt.distributed(list(range(10)), allow_resume=True, resumption_id=101):
with pt.ExclusiveLock("write_to_file", 0.01):
with open("test_resume_loop.log", "a") as f:
f.write(f"Start job {i}\n")
def _test_synchronized_loop():
for i in pt.synchronized(list(range(10))):
pt_testing.log(f"Doing task {i}")
pass

def test_synchronized_loop():
pt.use('multiprocessing-3')
pt_testing.initialise_log()
pt.launch(_test_synchronized_loop)
log = pt_testing.get_log()
assert len(log) == 20
for i in range(10):
for r in (1,2):
assert log.count(f"[{r}] Doing task {i}") == 1


def _test_resume_loop(attempt, mode='distributed'):
if mode=='distributed':
iterator = pt.distributed(list(range(10)), allow_resume=True, resumption_id=1)
# must provide a resumption_id because when we resume the stack trace is different
elif mode=='synchronized':
iterator = pt.synchronized(list(range(10)), allow_resume=True, resumption_id=2)
else:
raise ValueError("Unknown test mode")

for i in iterator:
pt_testing.log(f"Start job {i}")
pt.barrier() # make sure start is logged before kicking up a fuss

if i==5 and attempt==0:
raise ValueError("Suspend processing")

pt.barrier()
time.sleep(0.02)
logger.info(f"Finished job {i}")
with pt.ExclusiveLock("write_to_file", 0.01):
with open("test_resume_loop.log", "a") as f:
f.write(f"Finish job {i}\n")
pt_testing.log(f"Finish job {i}")

pt.barrier()
def test_resume_loop():


@pytest.mark.parametrize("mode", ("distributed", "synchronized"))
def test_resume_loop(mode):
pt.use("multiprocessing-3")

pt.jobs.IterationState.clear_resume_state()

try:
os.unlink("test_resume_loop.log")
except FileNotFoundError:
pass
pt_testing.initialise_log()

with pytest.raises(ValueError):
pt.launch(_test_resume_loop, args=(0,))
pt.launch(_test_resume_loop, args=(0, mode))

with open("test_resume_loop.log") as f:
lines = [s.strip() for s in f.readlines()]
lines = pt_testing.get_log(remove_process_ids=True)

for i in range(6):
assert lines.count(f"Start job {i}") == 1
if i<=3:
assert lines.count(f"Finish job {i}") == 1
else:
assert lines.count(f"Finish job {i}") == 0
expected_when_distributed = [
(f"Start job {i}", 1) for i in range(6)
] + [
(f"Finish job {i}", 1) for i in range(4)
] + [
(f"Finish job {i}", 0) for i in range(4,6)
]

pt.launch(_test_resume_loop, args=(1,))
expected_when_synchronized = [
(f"Start job {i}", 2) for i in range(6)
] + [
(f"Finish job {i}", 2) for i in range(5)
] + [
(f"Finish job 5", 0)
]

with open("test_resume_loop.log") as f:
lines = [s.strip() for s in f.readlines()]
expected = expected_when_distributed if mode=='distributed' else expected_when_synchronized

for i in range(4):
assert lines.count(f"Start job {i}") == 1
assert lines.count(f"Finish job {i}") == 1
for line, count in expected:
assert lines.count(line) == count

for i in range(4,6):
assert lines.count(f"Start job {i}") == 2 # started twice
assert lines.count(f"Finish job {i}") == 1 # finished once
pt_testing.initialise_log()

pt.launch(_test_resume_loop, args=(1, mode))

lines = pt_testing.get_log(remove_process_ids=True)

expected_when_distributed = [
(f"Start job {i}", 1) for i in range(4, 10)
] + [
(f"Finish job {i}", 1) for i in range(4, 10)
] + [
(f"Start job {i}", 0) for i in range(4)
]

expected_when_synchronized = [
(f"Start job {i}", 2) for i in range(5, 10)
] + [
(f"Finish job {i}", 2) for i in range(5, 10)
] + [
(f"Start job {i}", 0) for i in range(5)
]

expected = expected_when_distributed if mode == 'distributed' else expected_when_synchronized

for i in range(7,10):
assert lines.count(f"Start job {i}") == 1
assert lines.count(f"Finish job {i}") == 1
for line, count in expected:
assert lines.count(line) == count


def _test_empty_loop():
for _ in pt.distributed([]):
assert False

def _test_empty_loop(mode):
if mode=='distributed':
for _ in pt.distributed([]):
assert False
elif mode=='synchronized':
for _ in pt.synchronized([]):
assert False
else:
raise ValueError("Unknown test mode")

def test_empty_loop():
@pytest.mark.parametrize("mode", ("distributed", "synchronized"))
def test_empty_loop(mode):
pt.use("multiprocessing-3")
pt.launch(_test_empty_loop)
pt.launch(_test_empty_loop, args=(mode,))

def _test_empty_then_non_empty_loop():
for _ in pt.distributed([]):
pass
def _test_empty_then_non_empty_loop(mode):
if mode=='distributed':
iterator = pt.distributed
elif mode=='synchronized':
iterator = pt.synchronized
else:
raise ValueError("Unknown test mode")

for _ in pt.distributed([1,2,3]):
pass
for _ in iterator([]):
pt_testing.log(f"Should not appear")

for i in iterator([1,2,3]):
pt_testing.log(f"Doing task {i}")

def test_empty_then_non_empty_loop():
@pytest.mark.parametrize("mode", ("distributed", "synchronized"))
def test_empty_then_non_empty_loop(mode):
pt.use("multiprocessing-3")
pt.launch(_test_empty_then_non_empty_loop)
pt_testing.initialise_log()
pt.launch(_test_empty_then_non_empty_loop, args=(mode,))
log = pt_testing.get_log(remove_process_ids=True)
if mode=='distributed':
assert len(log)==3
assert "Doing task 3" in log
assert "Should not appear" not in log
elif mode=='synchronized':
assert len(log)==6
assert log.count("Doing task 3")==2
assert "Should not appear" not in log


def _test_synchronize_db_creator():
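For reference, the resume behaviour exercised by test_resume_loop can be condensed into a short driver. This is a sketch assuming, as the pytest.raises usage above implies, that pt.launch propagates the worker's exception:

```python
import tangos.parallel_tasks as pt

def _worker(attempt):
    for i in pt.distributed(list(range(10)), allow_resume=True, resumption_id=1):
        if i == 5 and attempt == 0:
            raise ValueError("Suspend processing")   # simulate a crash mid-loop

pt.use("multiprocessing-3")
pt.jobs.IterationState.clear_resume_state()  # discard any stale saved state
try:
    pt.launch(_worker, args=(0,))   # first run aborts at job 5
except ValueError:
    pass
pt.launch(_worker, args=(1,))       # second run resumes the unfinished jobs
```
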
