Merge pull request #245 from pynbody/pt-loop-resume
Enable resuming certain tasks after e.g. tangos add terminates unexpectedly
apontzen authored Dec 14, 2023
2 parents 4e79538 + 82e2b58 commit 6b4cef2
Showing 6 changed files with 324 additions and 34 deletions.
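For orientation, here is a minimal sketch (not part of the commit itself) of how a calling tool can opt in to the new behaviour. Only parallel_tasks.distributed with its new allow_resume keyword, parallel_tasks.use and parallel_tasks.launch come from tangos; the file list, backend choice and do_work helper are purely illustrative:

    from tangos import parallel_tasks

    def process_files(filenames):
        # With allow_resume=True, jobs recorded as complete by an earlier,
        # interrupted run with the same argv and call stack are skipped.
        for filename in parallel_tasks.distributed(filenames, allow_resume=True):
            do_work(filename)  # hypothetical per-file work function

    parallel_tasks.use("multiprocessing-4")  # backend choice is illustrative
    parallel_tasks.launch(process_files, args=(["step.1", "step.2", "step.3"],))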
4 changes: 2 additions & 2 deletions tangos/parallel_tasks/__init__.py
@@ -84,7 +84,7 @@ def launch(function, args=None, backend_kwargs=None):

return result

def distributed(file_list, proc=None, of=None):
def distributed(file_list, proc=None, of=None, allow_resume=False, resumption_id=None):
"""Distribute a list of tasks between all nodes"""

if type(file_list) == set:
@@ -102,7 +102,7 @@ def distributed(file_list, proc=None, of=None):
return file_list[i:j + 1]
else:
from . import jobs
return jobs.parallel_iterate(file_list)
return jobs.parallel_iterate(file_list, allow_resume, resumption_id)


def _exec_function_or_server(function, connection_info, args):
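Where the call stack is not stable between runs, the new resumption_id argument lets the caller pin the resumption context explicitly instead of relying on the stack-trace hash computed in jobs.parallel_iterate. A hedged sketch; the identifier string and the process helper are made up:

    # With an explicit resumption_id, the stack-trace hash is not used; the stored
    # completion map is matched on this identifier (plus argv and the job count).
    for item in parallel_tasks.distributed(task_list, allow_resume=True,
                                           resumption_id="my-crosslink-loop"):
        process(item)  # hypothetical work on each task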
206 changes: 179 additions & 27 deletions tangos/parallel_tasks/jobs.py
@@ -1,21 +1,165 @@
import base64
import getpass
import hashlib
import os
import pickle
import pipes
import sys
import traceback
import zlib

from .. import log
from . import message

j = -1
num_jobs = None
current_job = None

class InconsistentJobList(RuntimeError):
pass

class InconsistentContext(RuntimeError):
pass

class IterationState:
_stored_iteration_states = {}
def __init__(self, context, jobs_complete, /, backend_size=None):
from . import backend
self._context = context
self._jobs_complete = jobs_complete
self._rank_running_job = {i: None for i in range(1,backend_size or backend.size())}

def __len__(self):
return len(self._jobs_complete)

def to_string(self):
return base64.a85encode(
zlib.compress(
pickle.dumps(self._jobs_complete)
)
).decode('ascii')

@classmethod
def from_string(cls, string, context=None, backend_size=None):
jobs_complete = pickle.loads(
zlib.decompress(
base64.a85decode(string.encode('ascii'))
)
)
return cls(context, jobs_complete, backend_size=backend_size)

@classmethod
def from_context(cls, num_jobs, argv=None, stack_hash=None, allow_resume=None, backend_size=None):
context = (argv, stack_hash, num_jobs)
if allow_resume:
cmap = cls._get_stored_completion_map_from_context(context)
if cmap is not None:
r = cls.from_string(cmap, context)
log.logger.info(
f"Resuming from previous run. {r.count_complete()} of {len(r)} jobs are already complete.")
log.logger.info(
f"To prevent tangos from doing this, you can delete the file {cls._resume_state_filename()}")
return r

return cls(context, [False]*num_jobs, backend_size=backend_size)


@classmethod
def _resume_state_filename(cls):
return f"tangos_resume_state_{getpass.getuser()}.pickle"

@classmethod
def _get_stored_completion_maps(cls):
maps = {}
try:
with open(cls._resume_state_filename(), "rb") as f:
maps = pickle.load(f)
except FileNotFoundError:
pass
return maps

@classmethod
def _store_completion_maps(cls, maps):
with open(cls._resume_state_filename(), "wb") as f:
pickle.dump(maps, f)
@classmethod
def _get_stored_completion_map_from_context(cls, context):
maps = cls._get_stored_completion_maps()
return maps.get(context, None)

@classmethod
def clear_resume_state(cls):
try:
os.unlink(cls._resume_state_filename())
except FileNotFoundError:
pass

def _store_completion_map(self):
# In principle, there could be a race condition if more than one tangos process
# is ongoing on the same filesystem. However, this is very unlikely to happen
# since updating the completion map is very quick and happens quite rarely,
# so we don't worry about it.
maps = self._get_stored_completion_maps()
maps[self._context] = self.to_string()
self._store_completion_maps(maps)

def mark_complete(self, job):
if job is None:
return
self._jobs_complete[job] = True
self._store_completion_map()

def next_job(self, for_rank):
if for_rank in self._rank_running_job:
self.mark_complete(self._rank_running_job[for_rank])
del self._rank_running_job[for_rank]

for i in range(len(self._jobs_complete)):
if not self._jobs_complete[i] and i not in self._rank_running_job.values():
self._rank_running_job[for_rank] = i
return i
return None

def finished(self):
# not enough for all jobs to be complete, must also have notified all ranks (this matters
# if some ranks never did any work at all)
return all(self._jobs_complete) and len(self._rank_running_job)==0

def count_complete(self):
return sum(self._jobs_complete)

def __eq__(self, other):
return self._jobs_complete == other._jobs_complete



current_num_jobs = None
current_iteration_state = None
current_stack_hash = None
current_is_resumable = False

class MessageStartIteration(message.Message):
def process(self):
global num_jobs, current_job
if num_jobs is None:
num_jobs = self.contents
current_job = 0
else:
if num_jobs != self.contents:
raise RuntimeError("Number of jobs (%d) expected by rank %d is inconsistent with %d" % (
self.contents, self.source, num_jobs))
global current_iteration_state, current_stack_hash, current_is_resumable
req_jobs, req_hash, allow_resume = self.contents

if current_iteration_state is None:
current_num_jobs = req_jobs
current_stack_hash = req_hash
current_is_resumable = allow_resume
# convert sys.argv to a string which is quoted/escaped as needed
argv_string = " ".join([pipes.quote(arg) for arg in sys.argv])

current_iteration_state = IterationState.from_context(req_jobs, argv=" ".join(sys.argv),
stack_hash=req_hash,
allow_resume=allow_resume)


else:
if len(current_iteration_state) != req_jobs:
raise InconsistentJobList(f"Number of jobs ({req_jobs}) expected by rank {self.source} "
f"is inconsistent with {len(current_iteration_state)}")
if current_stack_hash != req_hash:
raise InconsistentContext(f"Inconsistent stack from rank {self.source} when entering parallel loop")
if current_is_resumable != allow_resume:
raise InconsistentContext(f"Inconsistent allow_resume flag from rank {self.source} when entering parallel loop")

class MessageDeliverJob(message.Message):
pass
@@ -29,30 +173,38 @@ def process(self):
MessageDistributeJobList(self.contents).send(rank)
class MessageRequestJob(message.Message):
def process(self):
global j, num_jobs, current_job
global current_iteration_state
source = self.source
if current_job is not None and num_jobs>0:
log.logger.debug("Send job %d of %d to node %d", current_job, num_jobs, source)
assert current_iteration_state is not None # should not be requesting jobs if we are not in a loop

job = current_iteration_state.next_job(source)

if job is not None:
log.logger.debug("Send job %d of %d to node %d", job, len(current_iteration_state), source)
else:
num_jobs = None
current_job = None # in case num_jobs=0, still want to send 'end of loop' signal to client
log.logger.debug("Finished jobs; notify node %d", source)

MessageDeliverJob(current_job).send(source)
MessageDeliverJob(job).send(source)

if current_job is not None:
current_job += 1
if current_job == num_jobs:
num_jobs = None
current_job = None
if current_iteration_state.finished():
current_iteration_state = None

def parallel_iterate(task_list):
"""Sets up an iterator returning items of task_list. """
def parallel_iterate(task_list, allow_resume=False, resumption_id=None):
"""Sets up an iterator returning items of task_list.
If allow_resume is True, then the iterator will resume from the last point it reached
provided argv and the stack trace are unchanged. If resumption_id is not None, then
the stack trace is ignored and only resumption_id needs to match.
"""
from . import backend, barrier

if resumption_id is None:
stack_string = "\n".join(traceback.format_stack())
# we need a hash of stack_string that is stable across runs.
resumption_id = hashlib.sha256(stack_string.encode('utf-8')).hexdigest()

assert backend is not None, "Parallelism is not initialised"
MessageStartIteration(len(task_list)).send(0)
MessageStartIteration((len(task_list), resumption_id, allow_resume)).send(0)
barrier()

while True:
Expand All @@ -65,7 +217,7 @@ def parallel_iterate(task_list):
else:
yield task_list[job]

def generate_task_list_and_parallel_iterate(task_list_function):
def generate_task_list_and_parallel_iterate(task_list_function, allow_resume=False):
"""Call task_list_function on only one rank, and then parallel iterate with all ranks"""
from . import backend

@@ -79,4 +231,4 @@ def generate_task_list_and_parallel_iterate(task_list_function):
log.logger.debug("awaiting rank 1 generating task list")
task_list = MessageDistributeJobList.receive(0).contents
log.logger.debug("task_list = %r",task_list)
return parallel_iterate(task_list)
return parallel_iterate(task_list, allow_resume=allow_resume)
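The persistence layer added above keeps one completion map per (argv, stack hash, job count) context, serialised as a compressed, base85-encoded pickle in a per-user state file. A rough, self-contained sketch of the round trip, with the context tuple and sizes invented for illustration:

    from tangos.parallel_tasks.jobs import IterationState

    context = ("tangos write my_prop", "abc123", 4)  # (argv, stack hash, num jobs) - illustrative
    state = IterationState(context, [False] * 4, backend_size=3)

    state.mark_complete(0)        # also rewrites the per-user resume-state pickle on disk
    encoded = state.to_string()   # zlib-compressed, base85-encoded pickle of the completion flags

    restored = IterationState.from_string(encoded, context=context, backend_size=3)
    assert restored == state                 # equality compares the completion flags only
    assert restored.count_complete() == 1

    IterationState.clear_resume_state()      # remove the on-disk state file again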
2 changes: 1 addition & 1 deletion tangos/tools/crosslink.py
@@ -46,7 +46,7 @@ def run_calculation_loop(self):
logger.error("No timesteps found to link")
return

pair_list = parallel_tasks.distributed(pair_list)
pair_list = parallel_tasks.distributed(pair_list, allow_resume=True)

object_type = core.halo.SimulationObjectBase.object_typecode_from_tag(self.args.type_)

4 changes: 3 additions & 1 deletion tangos/tools/property_writer.py
@@ -59,6 +59,8 @@ def add_parser_arguments(self, parser):
help='Process timesteps in random order')
parser.add_argument('--with-prerequisites', action='store_true',
help='Automatically calculate any missing prerequisites for the properties')
parser.add_argument('--no-resume', action='store_true',
help="Prevent resumption from a previous calculation, even if tangos thinks it's possible")
parser.add_argument('--load-mode', action='store', choices=['all', 'partial', 'server', 'server-partial', 'server-shared-mem'],
required=False, default=None,
help="Select a load-mode: " \
@@ -138,7 +140,7 @@ def _get_parallel_timestep_iterator(self):
ma_files = self.files
else:
# In all other cases, different timesteps are distributed to different nodes
ma_files = parallel_tasks.distributed(self.files)
ma_files = parallel_tasks.distributed(self.files, allow_resume=not self.options.no_resume)
return ma_files

def _get_parallel_halo_iterator(self, items):
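On the command-line side, the new flag simply flips the allow_resume argument passed to parallel_tasks.distributed. A minimal sketch mirroring what the test below does; the property name is illustrative, and actually running the writer still requires a configured database and parallel backend:

    from tangos.tools import property_writer

    writer = property_writer.PropertyWriter()
    # Including --no-resume forces allow_resume=False in _get_parallel_timestep_iterator
    writer.parse_command_line(["my_property", "--no-resume"])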
21 changes: 18 additions & 3 deletions tests/test_db_writer.py
@@ -101,9 +101,12 @@ def calculate(self, data, entry):
return 1,


def run_writer_with_args(*args, parallel=False):
def run_writer_with_args(*args, parallel=False, allow_resume=False):
writer = property_writer.PropertyWriter()
writer.parse_command_line(args)
if allow_resume:
writer.parse_command_line(args)
else:
writer.parse_command_line((*args, "--no-resume"))

def _runner():
stored_log = log.LogCapturer()
@@ -122,13 +125,25 @@ def test_basic_writing(fresh_database):
run_writer_with_args("dummy_property")
_assert_properties_as_expected()


def test_parallel_writing(fresh_database):
parallel_tasks.use('multiprocessing-2')
run_writer_with_args("dummy_property", parallel=True)

_assert_properties_as_expected()

def test_resuming(fresh_database):
parallel_tasks.use("multiprocessing-2")
log = []
for allow_resume in [False, False, True]:
log.append(run_writer_with_args("dummy_property", parallel=True, allow_resume=allow_resume))

for i in [0,1]:
assert "Resuming from previous run" not in log[i]
assert len(log[i].split("\n"))>2 # should have done lots of stuff,
# even if second time it ultimately wrote nothing into the db

assert len(log[2].split("\n"))==2 # it should resume at the end, and so do nothing other than log a message


def _assert_properties_as_expected():
assert db.get_halo("dummy_sim_1/step.1/1")['dummy_property'] == 1.0