-
Notifications
You must be signed in to change notification settings - Fork 33
/
test_attach.py
57 lines (39 loc) · 1.87 KB
/
test_attach.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import logging
import os
from datetime import timedelta
import pytest
from sagemaker import Session
# noinspection PyProtectedMember
from sagemaker.estimator import _TrainingJob
from sagemaker.pytorch import PyTorch
from sagemaker_ssh_helper.detached_sagemaker import DetachedEstimator
from sagemaker_ssh_helper.wrapper import SSHEstimatorWrapper
# noinspection DuplicatedCode
def test_attach_estimator():
    """Start a training job, attach to it by name, and check the model artifact.

    The job is launched with ``fit(wait=False)``; a second wrapper is then
    obtained through ``SSHEstimatorWrapper.attach`` using only the job name,
    and that attached wrapper is used to connect and to wait for the job.
    """
    entry_script = os.path.basename('source_dir/training/train.py')
    ssh_dependencies = [SSHEstimatorWrapper.dependency_dir()]
    max_run_seconds = int(timedelta(minutes=15).total_seconds())

    estimator = PyTorch(
        entry_point=entry_script,
        source_dir='source_dir/training/',
        dependencies=ssh_dependencies,
        base_job_name='ssh-training',
        framework_version='1.9.1',
        py_version='py38',
        instance_count=1,
        instance_type='ml.m5.xlarge',
        max_run=max_run_seconds,
        keep_alive_period_in_seconds=1800,
        container_log_level=logging.INFO,
    )

    # The wrapper returned here is deliberately discarded: the point of the
    # test is to work through a wrapper re-created from the job name below.
    SSHEstimatorWrapper.create(estimator, connection_wait_time_seconds=600)
    estimator.fit(wait=False)

    training_job: _TrainingJob = estimator.latest_training_job
    attached_wrapper = SSHEstimatorWrapper.attach(training_job.name)
    attached_wrapper.print_ssh_info()
    # NOTE(review): 11022 is presumably the local port for the SSH tunnel —
    # confirm against the wrapper's API.
    attached_wrapper.start_ssm_connection_and_continue(11022)
    attached_wrapper.wait_training_job()

    archive_position = estimator.model_data.find("model.tar.gz")
    assert archive_position != -1
def test_cannot_fit_detached_estimator():
    """Expect ``SSHEstimatorWrapper.create`` to raise ``ValueError`` for a detached estimator."""
    sagemaker_session = Session()
    detached = DetachedEstimator.attach('training-job-name', sagemaker_session)

    with pytest.raises(ValueError):
        SSHEstimatorWrapper.create(detached)
def test_can_fetch_job_name_from_detached_estimator():
    """Check that a wrapper attached by job name reports that same name."""
    expected_name = 'training-job-name'
    attached_wrapper = SSHEstimatorWrapper.attach(expected_name, Session())

    assert attached_wrapper.training_job_name() == expected_name