Skip to content

Commit

Permalink
Style changes introduced by black
Browse files Browse the repository at this point in the history
  • Loading branch information
tomfaulhaber committed Nov 10, 2022
1 parent 8ceb970 commit 7a33b93
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 24 deletions.
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
source_suffix = ".rst" # The suffix of source filenames.
master_doc = "index" # The master toctree document.

copyright = u"%s, Amazon" % datetime.now().year
copyright = "%s, Amazon" % datetime.now().year

pygments_style = "default"

Expand Down
2 changes: 1 addition & 1 deletion sagemaker_run_notebook/container_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def create_project(repo_name, role, zipfile, base_image=default_base):
sts = session.client("sts")
identity = sts.get_caller_identity()
account = identity["Account"]
partition = identity["Arn"].split(':')[1]
partition = identity["Arn"].split(":")[1]
args = {
"name": f"create-sagemaker-container-{repo_name}",
"description": f"Build the container {repo_name} for running notebooks in SageMaker",
Expand Down
9 changes: 5 additions & 4 deletions sagemaker_run_notebook/lambda_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ def execute_notebook(
):
session = ensure_session()
region = session.region_name
caller_id=session.client("sts").get_caller_identity()
partition = caller_id["Arn"].split(':')[1]
caller_id = session.client("sts").get_caller_identity()
partition = caller_id["Arn"].split(":")[1]
account = caller_id["Account"]
domain = domain_for_region(region)
if not image:
Expand Down Expand Up @@ -151,6 +151,7 @@ def ensure_session(session=None):
session = boto3.session.Session()
return session


def domain_for_region(region):
"""Get the DNS suffix for the given region.
Args:
Expand All @@ -163,8 +164,8 @@ def domain_for_region(region):
if region.startswith("us-isob-"):
return "sc2s.sgov.gov"
if region.startswith("cn-"):
return "amazonaws.com.cn"
return "amazonaws.com"
return "amazonaws.com.cn"
return "amazonaws.com"


def lambda_handler(event, context):
Expand Down
27 changes: 12 additions & 15 deletions sagemaker_run_notebook/run_notebook.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ def abbreviate_image(image):


abbrev_role_pat = re.compile(
r"arn:aws([^:]*):iam::(?P<account>\d+):role/(?P<name>[^/]+)")
r"arn:aws([^:]*):iam::(?P<account>\d+):role/(?P<name>[^/]+)"
)


def abbreviate_role(role):
Expand Down Expand Up @@ -127,15 +128,14 @@ def execute_notebook(
elif "/" not in role:
identity = session.client("sts").get_caller_identity()
account = identity["Account"]
partition = identity["Arn"].split(':')[1]
partition = identity["Arn"].split(":")[1]
role = "arn:{}:iam::{}:role/{}".format(partition, account, role)

if "/" not in image:
account = session.client("sts").get_caller_identity()["Account"]
region = session.region_name
domain = domain_for_region(region)
image = "{}.dkr.ecr.{}.{}/{}:latest".format(
account, region, domain, image)
image = "{}.dkr.ecr.{}.{}/{}:latest".format(account, region, domain, image)

if notebook == None:
notebook = input_path
Expand All @@ -145,8 +145,7 @@ def execute_notebook(
timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

job_name = (
("papermill-" + re.sub(r"[^-a-zA-Z0-9]",
"-", nb_name))[: 62 - len(timestamp)]
("papermill-" + re.sub(r"[^-a-zA-Z0-9]", "-", nb_name))[: 62 - len(timestamp)]
+ "-"
+ timestamp
)
Expand Down Expand Up @@ -636,7 +635,7 @@ def create_lambda(role=None, session=None):
if "/" not in role:
identity = session.client("sts").get_caller_identity()
account = identity["Account"]
partition = identity["Arn"].split(':')[1]
partition = identity["Arn"].split(":")[1]
role = "arn:{}:iam::{}:role/{}".format(partition, account, role)

code_bytes = zip_bytes(code_file)
Expand Down Expand Up @@ -788,9 +787,8 @@ def proc(extras):
if "/" not in image:
account = session.client("sts").get_caller_identity()["Account"]
region = session.region_name
domain = domain_for_region(region)
image = "{}.dkr.ecr.{}.{}/{}:latest".format(
account, region, domain, image)
domain = domain_for_region(region)
image = "{}.dkr.ecr.{}.{}/{}:latest".format(account, region, domain, image)

if not role:
try:
Expand All @@ -801,7 +799,7 @@ def proc(extras):
if "/" not in role:
identity = session.client("sts").get_caller_identity()
account = identity["Account"]
partition = identity["Arn"].split(':')[1]
partition = identity["Arn"].split(":")[1]
role = "arn:{}:iam::{}:role/{}".format(partition, account, role)

if input_path is None:
Expand Down Expand Up @@ -918,8 +916,7 @@ def proc(extras):
account = session.client("sts").get_caller_identity()["Account"]
region = session.region_name
domain = domain_for_region(region)
image = "{}.dkr.ecr.{}.{}/{}:latest".format(
account, region, domain, image)
image = "{}.dkr.ecr.{}.{}/{}:latest".format(account, region, domain, image)

if not role:
try:
Expand All @@ -930,7 +927,7 @@ def proc(extras):
if "/" not in role:
identity = session.client("sts").get_caller_identity()
account = identity["Account"]
partition = identity["Arn"].split(':')[1]
partition = identity["Arn"].split(":")[1]
role = "arn:{}:iam::{}:role/{}".format(partition, account, role)

if input_path is None:
Expand Down Expand Up @@ -963,7 +960,7 @@ def proc(extras):
)
identity = session.client("sts").get_caller_identity()
account = identity["Account"]
partition = identity["Arn"].split(':')[1]
partition = identity["Arn"].split(":")[1]
region = session.region_name
target_arn = "arn:{}:lambda:{}:{}:function:{}".format(
partition, region, account, lambda_function_name
Expand Down
14 changes: 11 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,19 @@
# Get our version
version = get_version(str(Path(HERE) / name / "server_extension" / "_version.py"))

lab_path = Path(HERE) / "labextension"
lab_path = Path(HERE) / "labextension"

data_files_spec = [
("share/jupyter/labextensions/%s/static" % name, str(lab_path / name / "labextension" / "static"), "**"),
("share/jupyter/labextensions/%s" % name, str(lab_path / name / "labextension"), "package.json"),
(
"share/jupyter/labextensions/%s/static" % name,
str(lab_path / name / "labextension" / "static"),
"**",
),
(
"share/jupyter/labextensions/%s" % name,
str(lab_path / name / "labextension"),
"package.json",
),
(
"etc/jupyter/jupyter_notebook_config.d",
"sagemaker_run_notebook/server_extension/jupyter-config/jupyter_notebook_config.d",
Expand Down

0 comments on commit 7a33b93

Please sign in to comment.