-
Notifications
You must be signed in to change notification settings - Fork 33
/
test_inference_registry.py
62 lines (48 loc) · 2.56 KB
/
test_inference_registry.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import logging
import os
from datetime import timedelta
import pytest
import sagemaker
from sagemaker import Predictor
from sagemaker.deserializers import JSONDeserializer
from sagemaker.pytorch import PyTorch, PyTorchPredictor
from sagemaker.serializers import JSONSerializer
from sagemaker.utils import name_from_base
# noinspection DuplicatedCode
@pytest.mark.skipif(os.getenv('PYTEST_IGNORE_SKIPS', "false") == "false",
                    reason="Not working yet")
def test_clean_inference_from_registry():
    """End-to-end test: train a PyTorch model, register it in the SageMaker
    Model Registry, deploy the registered package to a real-time endpoint,
    and verify a prediction round-trip before tearing the endpoint down.

    Requires AWS credentials and incurs SageMaker costs; skipped unless
    PYTEST_IGNORE_SKIPS is set.
    """
    estimator = PyTorch(entry_point='train_clean.py',
                        source_dir='source_dir/training_clean/',
                        framework_version='1.9.1',
                        py_version='py38',
                        instance_count=1,
                        instance_type='ml.m5.xlarge',
                        # Hard cap on training runtime to bound cost.
                        max_run=int(timedelta(minutes=15).total_seconds()),
                        # Keep the warm pool alive 30 min to speed up reruns.
                        keep_alive_period_in_seconds=1800,
                        container_log_level=logging.INFO)
    estimator.fit()
    model = estimator.create_model(entry_point='inference_clean.py',
                                   source_dir='source_dir/inference_clean/')
    # Register the trained model as a versioned package in the Model Registry.
    # Bug fix: response_types was 'test/csv' (typo) — the MIME type is 'text/csv',
    # matching content_types above.
    package: sagemaker.model.ModelPackage = model.register(content_types=['text/csv'],
                                                           response_types=['text/csv'],
                                                           model_package_name='inference-registry',
                                                           inference_instances=['ml.m5.xlarge'],
                                                           transform_instances=['ml.m5.xlarge'])
    # Block until the package reaches a terminal registration status.
    sagemaker.Session().wait_for_model_package(package.name)
    model_package_arn = package.model_package_arn
    # Lazy %-style args: the message is only formatted if INFO is enabled.
    logging.info("Registered package: %s", model_package_arn)
    # Inference: rehydrate the package purely from its ARN, as a consumer would.
    model_package = sagemaker.model.ModelPackage(role=model.role, model_package_arn=model_package_arn)
    endpoint_name = name_from_base('inference-registry')
    # TODO: replace with model.deploy to make it work
    model_package.deploy(initial_instance_count=1,
                         instance_type='ml.m5.xlarge',
                         endpoint_name=endpoint_name)
    # Note: when using model_package, deploy() returns None, need construct the predictor manually
    predictor: Predictor = PyTorchPredictor(endpoint_name)
    predictor.serializer = JSONSerializer()
    predictor.deserializer = JSONDeserializer()
    predicted_value = predictor.predict(data=[1])
    # train_clean.py learns f(x) = x + 42, so input 1 should yield 43.
    # NOTE(review): assumption inferred from the asserted value — confirm against train_clean.py.
    assert predicted_value == [43]
    # Clean up the billable endpoint; keep the config for post-mortem inspection.
    predictor.delete_endpoint(delete_endpoint_config=False)