Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: DRM to DRMAA configuration adapters #1

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
226 changes: 226 additions & 0 deletions config_adapters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
"""
Configuration adapters for mapping native specifications from DRM to DRMAA API
"""

from __future__ import annotations
from typing import (List, ClassVar, Union, Optional, TYPE_CHECKING)

from dataclasses import dataclass, asdict, fields, InitVar
from abc import ABC, abstractmethod
import re

if TYPE_CHECKING:
from drmaa import JobTemplate

# DRMAA specific fields, anything else should be put into native spec
DRMAA_FIELDS = [
"email", "deadlineTime", "errorPath", "hardRunDurationLimit",
"hardWallclockTimeLimit", "inputPath", "outputPath", "jobCategory",
"jobName", "outputPath", "workingDirectory", "transferFiles",
"remoteCommand", "args", "jobName", "jobCategory", "blockEmail"
]

TIMESTR_VALIDATE = re.compile("^(\\d+:)?[0-9][0-9]:[0-9][0-9]$")


@dataclass
class DRMAACompatible(ABC):
'''
Abstract dataclass for mapping DRM specific configuration to a
DRMAA compatible specification

Properties:
_mapped_fields: List of DRM specific keys to re-map onto
the DRMAA specification if used. Preferably users will
use the DRMAA variant of these specifications rather than
the corresponding native specification
'''

_mapped_fields: ClassVar[List[str]]

def __str__(self):
'''
Display formatted configuration for executor
'''
attrs = asdict(self)
drmaa_fields = "\n".join([
f"{field}:\t{attrs.get(field)}" for field in DRMAA_FIELDS
if attrs.get(field) is not None
])

drm_fields = "\n".join([
f"{field}:\t{attrs.get(field)}" for field in self._native_fields()
if attrs.get(field) is not None
])

return ("DRMAA Config:\n" + drmaa_fields + "\nNative Specification\n" +
drm_fields)

def get_drmaa_config(self, jt: JobTemplate) -> JobTemplate:
'''
Apply settings onto DRMAA JobTemplate
'''

for field in DRMAA_FIELDS:
value = getattr(self, field, None)
if value is not None:
setattr(jt, field, value)

jt.nativeSpecification = self.drm2drmaa()
return jt

@abstractmethod
def drm2drmaa(self) -> str:
'''
Build native specification from DRM-specific fields
'''

def _native_fields(self):
return [
f for f in asdict(self).keys()
if (f not in self._mapped_fields) and (f not in DRMAA_FIELDS)
]

def set_fields(self, **drmaa_kwargs):
for field, value in drmaa_kwargs.items():
if field not in DRMAA_FIELDS:
raise AttributeError(
"Malformed adapter class! Cannot map field"
f" {field} to a DRMAA-compliant field")

setattr(self, field, value)


@dataclass
class DRMAAConfig(DRMAACompatible):
def drm2drmaa(self):
return


@dataclass
class SlurmConfig(DRMAACompatible):
'''
Transform SLURM resource specification into DRMAA-compliant inputs

References:
See https://github.com/natefoo/slurm-drmaa for native specification
details
'''

_mapped_fields: ClassVar[List[str]] = {
"error", "output", "job_name", "time"
}

job_name: InitVar[str]
time: InitVar[str]
error: InitVar[str] = None
output: InitVar[str] = None

account: Optional[str] = None
acctg_freq: Optional[str] = None
comment: Optional[str] = None
constraint: Optional[List] = None
cpus_per_task: Optional[int] = None
contiguous: Optional[bool] = None
dependency: Optional[List] = None
exclusive: Optional[bool] = None
gres: Optional[Union[List[str], str]] = None
no_kill: Optional[bool] = None
licenses: Optional[List[str]] = None
clusters: Optional[Union[List[str], str]] = None
mail_type: Optional[str] = None
mem: Optional[int] = None
mincpus: Optional[int] = None
nodes: Optional[int] = None
ntasks: Optional[int] = None
no_requeue: Optional[bool] = None
ntasks_per_node: Optional[int] = None
partition: Optional[int] = None
qos: Optional[str] = None
requeue: Optional[bool] = None
reservation: Optional[str] = None
share: Optional[bool] = None
tmp: Optional[str] = None
nodelist: Optional[Union[List[str], str]] = None
exclude: Optional[Union[List[str], str]] = None

def __post_init__(self, job_name, time, error, output):
'''
Transform Union[List[str]] --> comma-delimited str
'''

_validate_timestr(time, "time")
super().set_fields(jobName=job_name,
hardWallclockTimeLimit=time,
errorPath=error,
outputPath=output)

self.job_name = job_name
self.time = time
self.error = error
self.output = output

for field in fields(self):
value = getattr(self, field.name)
if field.type == Union[List[str], str] and isinstance(value, list):
setattr(self, field.name, ",".join(value))

def drm2drmaa(self) -> str:
return self._transform_attrs()

def _transform_attrs(self) -> str:
'''
Remap named attributes to "-" form, excludes renaming
DRMAA-compliant fields (set in __post_init__()) then join
attributes into a nativeSpecification string
'''

out = []
for field in self._native_fields():

value = getattr(self, field)
if value is None:
continue

field_fmtd = field.replace("_", "-")
if isinstance(value, bool):
out.append(f"--{field_fmtd}")
else:
out.append(f"--{field_fmtd}={value}")
return " ".join(out)


def _timestr_to_sec(timestr: str) -> int:
'''
Transform a time-string (D-HH:MM:SS) --> seconds
'''

days = 0
if "-" in timestr:
day_str, timestr = timestr.split('-')
days = int(day_str)

seconds = (24 * days) * (60**2)
for exp, unit in enumerate(reversed(timestr.split(":"))):
seconds += int(unit) * (60**exp)

return seconds


def _validate_timestr(timestr: str, field_name: str) -> str:
'''
Validate timestring to make sure it meets
expected format.
'''

if not isinstance(timestr, str):
raise TypeError(f"Expected {field_name} to be of type string "
f"but received {type(timestr)}!")

result = TIMESTR_VALIDATE.match(timestr)
if not result:
raise ValueError(f"Expected {field_name} to be of format "
"X...XX:XX:XX or XX:XX! "
f"but received {timestr}")

return timestr
56 changes: 56 additions & 0 deletions drmaa_patches.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
'''
Patches on DRMAA-python module
'''

from drmaa import JobTemplate, Session
from drmaa.helpers import Attribute, IntConverter


#TODO: Make sure this is actually correct?
# Works for SLURM
CORRECT_TO_STRING = [
"hardWallclockTimeLimit"
]


class PatchedIntConverter():
'''
Helper class to correctly encode Integer values
as little-endian bytes for Python 3

Info:
The standard IntConverter class attempts to convert
integer values to bytes using `bytes(value)` which
results in a zero'd byte-array of length `value`.
'''
@staticmethod
def to_drmaa(value: int) -> bytes:
return value.to_bytes(8, byteorder="little")

@staticmethod
def from_drmaa(value: bytes) -> int:
return int.from_bytes(value, byteorder="little")


class PatchedJobTemplate(JobTemplate):
def __init__(self):
'''
Dynamically patch attributes using IntConverter
'''
super(PatchedJobTemplate, self).__init__()
for attr, value in vars(JobTemplate).items():
if isinstance(value, Attribute):
if attr in CORRECT_TO_STRING:
setattr(value, "converter", None)
elif value.converter is IntConverter:
setattr(value, "converter", PatchedIntConverter)


class PatchedSession(Session):
'''
Override createJobTemplate method to return
Patched version
'''
@staticmethod
def createJobTemplate(self) -> PatchedJobTemplate:
return PatchedJobTemplate()
15 changes: 15 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
apache-airflow==2.1.4
apache-airflow-providers-ftp==2.0.1
apache-airflow-providers-http==2.0.1
apache-airflow-providers-imap==2.0.1
apache-airflow-providers-sqlite==2.0.1
apache-airflow-providers-ssh==2.2.0
coverage==5.3.1
flake8==3.8.4
pytest==6.2.5
pytest-cov==2.11.0
pytest-forked==1.3.0
pytest-mock==3.6.1
pytest-xdist==2.2.0
Sphinx==3.4.3
toml==0.10.2
93 changes: 93 additions & 0 deletions tests/test_config_adapters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""
Tests for config_adapters.py to ensure that mapping
from DRM-specific configuration to DRMAA spec works
correctly
"""

import pytest
from drmaa_executor_plugin.drmaa_patches import (PatchedJobTemplate as
JobTemplate)
from drmaa_executor_plugin.config_adapters import SlurmConfig


@pytest.fixture()
def job_template():
jt = JobTemplate()
yield jt
jt.delete()


def test_slurm_config_transforms_to_drmaa(job_template):
'''
Check whether SLURM adapter class correctly
transforms SLURM specs to DRMAA attributes
'''

error = "TEST_VALUE"
output = "TEST_VALUE"
time = "10:00:00"
job_name = "FAKE_JOB"

expected_drmaa_attrs = {
"errorPath": error,
"outputPath": output,
"hardWallclockTimeLimit": "10:00:00",
"jobName": job_name
}

slurm_config = SlurmConfig(error=error,
output=output,
time=time,
job_name=job_name)

jt = slurm_config.get_drmaa_config(job_template)

# Test attributes matches what is expected
for k, v in expected_drmaa_attrs.items():
assert getattr(jt, k) == v


def test_slurm_config_native_spec_transforms_correctly(job_template):
'''
Test whether scheduler-specific configuration is transformed
into nativeSpecification correctly
'''

job_name = "TEST"
time = "01:00"
account = "TEST"
cpus_per_task = 5
slurm_config = SlurmConfig(job_name=job_name,
time=time,
account=account,
cpus_per_task=cpus_per_task)

jt = slurm_config.get_drmaa_config(job_template)
for spec in ['account=TEST', 'cpus-per-task=5']:
assert spec in jt.nativeSpecification


def test_invalid_timestr_fails():
job_name = "TEST"
time = "FAILURE"
account = "TEST"
cpus_per_task = 10

with pytest.raises(ValueError):
SlurmConfig(job_name=job_name,
time=time,
account=account,
cpus_per_task=cpus_per_task)


def test_timestr_not_string_fails():
job_name = "TEST"
time = 10
account = "TEST"
cpus_per_task = 10

with pytest.raises(TypeError):
SlurmConfig(job_name=job_name,
time=time,
account=account,
cpus_per_task=cpus_per_task)