Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move POST /mlflow/model-versions/create endpoint. #93

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion magefiles/generate/endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ var ServiceInfoMap = map[string]ServiceGenerationInfo{
"getRegisteredModel",
// "searchRegisteredModels",
"getLatestVersions",
// "createModelVersion",
"createModelVersion",
// "updateModelVersion",
// "transitionModelVersionStage",
// "deleteModelVersion",
Expand Down
3 changes: 3 additions & 0 deletions magefiles/generate/validations.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,7 @@ var validations = map[string]string{
"Dataset_Schema": "max:1048575",
"InputTag_Key": "required,max=255",
"InputTag_Value": "required,max=500",
"CreateModelVersion_Name": "required",
"ModelVersionTag_Key": "required,max=250,validMetricParamOrTagName,pathIsUnique",
"ModelVersionTag_Value": "required,truncate=5000",
}
22 changes: 22 additions & 0 deletions mlflow_go/store/model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from mlflow.entities.model_registry import ModelVersion, RegisteredModel
from mlflow.protos.model_registry_pb2 import (
CreateModelVersion,
DeleteRegisteredModel,
GetLatestVersions,
GetRegisteredModel,
Expand Down Expand Up @@ -33,6 +34,27 @@ def __del__(self):
if hasattr(self, "service"):
get_lib().DestroyModelRegistryService(self.service.id)

def create_model_version(
self,
name,
source,
run_id=None,
tags=None,
run_link=None,
description=None,
local_model_path=None,
):
request = CreateModelVersion(
name=name,
source=source,
run_id=run_id,
tags=tags,
run_link=run_link,
description=description,
local_model_path=local_model_path,
)
return self.service.call_endpoint(get_lib().ModelRegistryServiceCreateModelVersion, request)

def get_latest_versions(self, name, stages=None):
request = GetLatestVersions(
name=name,
Expand Down
1 change: 1 addition & 0 deletions pkg/contract/service/model_registry.g.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions pkg/entities/model_tag.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package entities

type ModelTag struct {
Key string
Value string
}
8 changes: 8 additions & 0 deletions pkg/lib/model_registry.g.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

30 changes: 30 additions & 0 deletions pkg/model_registry/service/model_versions.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,39 @@ import (
"context"

"github.com/mlflow/mlflow-go/pkg/contract"
"github.com/mlflow/mlflow-go/pkg/entities"
"github.com/mlflow/mlflow-go/pkg/protos"
)

func (m *ModelRegistryService) CreateModelVersion(
ctx context.Context, input *protos.CreateModelVersion,
) (*protos.CreateModelVersion_Response, *contract.Error) {
tags := make([]entities.ModelTag, 0, len(input.Tags))
for _, tag := range input.Tags {
tags = append(tags, entities.ModelTag{
Key: tag.GetKey(),
Value: tag.GetValue(),
})
}

modelVersion, err := m.store.CreateModelVersion(
ctx,
input.GetName(),
input.GetSource(),
input.GetRunId(),
tags,
input.GetRunLink(),
input.GetDescription(),
)
if err != nil {
return nil, err
}

return &protos.CreateModelVersion_Response{
ModelVersion: modelVersion.ToProto(),
}, nil
}

func (m *ModelRegistryService) GetLatestVersions(
ctx context.Context, input *protos.GetLatestVersions,
) (*protos.GetLatestVersions_Response, *contract.Error) {
Expand Down
102 changes: 102 additions & 0 deletions pkg/model_registry/store/sql/helpers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package sql

import (
"fmt"
"net/url"
"strconv"
"strings"

"github.com/mlflow/mlflow-go/pkg/entities"
)

const (
ModelsURISuffixLatest = "latest"
)

//nolint
var ErrImproperModelURI = func(uri string) error {
return fmt.Errorf(`
Not a proper models:/ URI: %s. "Models URIs must be of the form 'models:/model_name/suffix' or
'models:/model_name@alias' where suffix is a model version, stage, or the string latest
and where alias is a registered model alias. Only one of suffix or alias can be defined at a time."`,
uri,
)
}

type ParsedModelURI struct {
Name string
Stage string
Alias string
Version string
}

func GetModelNextVersion(registeredModel *entities.RegisteredModel) int32 {
if len(registeredModel.Versions) == 0 {
return 1
}

maxVersion := int32(0)
for _, version := range registeredModel.Versions {
if version.Version > maxVersion {
maxVersion = version.Version
}
}

return maxVersion + 1
}

//nolint
func ParseModelURI(uri string) (*ParsedModelURI, error) {
parsedURI, err := url.Parse(uri)
if err != nil {
return nil, err
}

if parsedURI.Scheme != "models" {
return nil, ErrImproperModelURI(uri)
}

if !strings.HasSuffix(parsedURI.Path, "/") || len(parsedURI.Path) <= 1 {
return nil, ErrImproperModelURI(uri)
}

parts := strings.Split(strings.TrimLeft(parsedURI.Path, "/"), "/")
if len(parts) > 2 || strings.Trim(parts[0], " ") == "" {
return nil, ErrImproperModelURI(uri)
}

if len(parts) == 2 {
name, suffix := parts[0], parts[1]
if strings.Trim(suffix, " ") == "" {
return nil, ErrImproperModelURI(uri)
}
// The suffix is a specific version, e.g. "models:/AdsModel1/123"
if _, err := strconv.Atoi(suffix); err == nil {
return &ParsedModelURI{
Name: name,
Version: suffix,
}, nil
}
// The suffix is the 'latest' string (case insensitive), e.g. "models:/AdsModel1/latest"
if (strings.ToLower(suffix)) == ModelsURISuffixLatest {
return &ParsedModelURI{
Name: name,
}, nil
}
// The suffix is a specific stage (case insensitive), e.g. "models:/AdsModel1/Production"
return &ParsedModelURI{
Name: name,
Stage: suffix,
}, nil
}

aliasParts := strings.SplitN(parts[0], "@", 1)
if len(aliasParts) != 2 || strings.Trim(aliasParts[1], " ") == "" {
return nil, ErrImproperModelURI(uri)
}

return &ParsedModelURI{
Name: aliasParts[0],
Alias: aliasParts[1],
}, nil
}
137 changes: 137 additions & 0 deletions pkg/model_registry/store/sql/model_versions.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"database/sql"
"errors"
"fmt"
"net/url"
"strings"
"time"

Expand All @@ -16,6 +17,8 @@ import (
"github.com/mlflow/mlflow-go/pkg/protos"
)

const batchSize = 100

// Validate whether there is a registered model with the given name.
func assertModelExists(db *gorm.DB, name string) *contract.Error {
if err := db.Select("name").Where("name = ?", name).First(&models.RegisteredModel{}).Error; err != nil {
Expand All @@ -36,6 +39,140 @@ func assertModelExists(db *gorm.DB, name string) *contract.Error {
return nil
}

func (m *ModelRegistrySQLStore) GetModelVersion(
ctx context.Context, name, version string,
) (*entities.ModelVersion, *contract.Error) {
var modelVersion models.ModelVersion
if err := m.db.WithContext(
ctx,
).Where(
"name = ?", name,
).Where(
"version = ?", version,
).Where(
"current_stage != ?", models.StageDeletedInternal,
).First(
&modelVersion,
).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, contract.NewError(
protos.ErrorCode_RESOURCE_DOES_NOT_EXIST,
fmt.Sprintf("registered model with name=%q not found", name),
)
}

return nil, contract.NewErrorWith(
protos.ErrorCode_INTERNAL_ERROR,
fmt.Sprintf("failed to query registered model with name=%q", name),
err,
)
}

return modelVersion.ToEntity(), nil
}

//nolint:funlen,cyclop
func (m *ModelRegistrySQLStore) CreateModelVersion(
ctx context.Context,
name, source, runID string,
tags []entities.ModelTag,
runLink, description string,
) (*entities.ModelVersion, *contract.Error) {
storageLocation := source

parsedSource, err := url.Parse(source)
if err != nil {
return nil, contract.NewErrorWith(
protos.ErrorCode_INTERNAL_ERROR,
fmt.Sprintf("failed to parse source=%q", source),
err,
)
}

if parsedSource.Scheme == "models" {
parsedModelURI, err := ParseModelURI(source)
if err != nil {
return nil, contract.NewErrorWith(
protos.ErrorCode_INTERNAL_ERROR,
fmt.Sprintf("Unable to fetch model from model URI source artifact location '%s'.", source),
err,
)
}

modelVersion, contractErr := m.GetModelVersion(ctx, parsedModelURI.Name, parsedModelURI.Version)
if contractErr != nil {
return nil, contractErr
}

if modelVersion.StorageLocation != "" {
storageLocation = modelVersion.StorageLocation
} else if modelVersion.Source != "" {
storageLocation = modelVersion.Source
}
}

registeredModel, contractErr := m.GetRegisteredModel(ctx, name)
if contractErr != nil {
return nil, contractErr
}

registeredModel.LastUpdatedTime = time.Now().UnixMilli()

version := GetModelNextVersion(registeredModel)
newModelVersion := models.ModelVersion{
Name: name,
Version: version,
CreationTime: time.Now().UnixMilli(),
LastUpdatedTime: time.Now().UnixMilli(),
Description: description,
Source: source,
RunID: runID,
RunLink: runLink,
StorageLocation: storageLocation,
}

uniqueTags := map[string]string{}
for _, tag := range tags {
uniqueTags[tag.Key] = tag.Value
}

if err := m.db.WithContext(ctx).Transaction(func(transaction *gorm.DB) error {
if err = transaction.Updates(&registeredModel).Error; err != nil {
return fmt.Errorf("failed to update registered model: %w", err)
}

if err = transaction.Create(
&newModelVersion,
).Error; err != nil {
return err
}

modelTags := make([]models.ModelVersionTag, 0, len(uniqueTags))
for key, value := range uniqueTags {
modelTags = append(modelTags, models.ModelVersionTag{
Key: key,
Value: value,
Name: registeredModel.Name,
Version: version,
})
}

if err = transaction.CreateInBatches(
modelTags, batchSize,
).Error; err != nil {
return err
}

return nil
}); err != nil {
return nil, contract.NewErrorWith(
protos.ErrorCode_INTERNAL_ERROR, "failed to create model version", err,
)
}

return newModelVersion.ToEntity(), nil
}

func (m *ModelRegistrySQLStore) GetLatestVersions(
ctx context.Context, name string, stages []string,
) ([]*protos.ModelVersion, *contract.Error) {
Expand Down
6 changes: 6 additions & 0 deletions pkg/model_registry/store/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ import (
type ModelRegistryStore interface {
contract.Destroyer
GetLatestVersions(ctx context.Context, name string, stages []string) ([]*protos.ModelVersion, *contract.Error)
CreateModelVersion(
ctx context.Context,
name, source, runID string,
tags []entities.ModelTag,
runLink, description string,
) (*entities.ModelVersion, *contract.Error)
GetRegisteredModel(ctx context.Context, name string) (*entities.RegisteredModel, *contract.Error)
UpdateRegisteredModel(ctx context.Context, name, description string) (*entities.RegisteredModel, *contract.Error)
RenameRegisteredModel(ctx context.Context, name, newName string) (*entities.RegisteredModel, *contract.Error)
Expand Down
Loading
Loading