From c18e13be9de318fa307c21db040c38ac38f9d3f9 Mon Sep 17 00:00:00 2001 From: Mehmet YILMAZ Date: Sat, 1 Jun 2024 15:48:21 +0300 Subject: [PATCH] Add Managed Identity option to azure storage updates --- pkg/storages/azure/folder.go | 44 ++++++++++++++++++++++++++----- pkg/storages/azure/folder_test.go | 30 ++++++++++++++++++--- 2 files changed, 65 insertions(+), 9 deletions(-) diff --git a/pkg/storages/azure/folder.go b/pkg/storages/azure/folder.go index b8599ce3e..78867e9f6 100644 --- a/pkg/storages/azure/folder.go +++ b/pkg/storages/azure/folder.go @@ -28,6 +28,7 @@ const ( BufferSizeSetting = "AZURE_BUFFER_SIZE" MaxBuffersSetting = "AZURE_MAX_BUFFERS" TryTimeoutSetting = "AZURE_TRY_TIMEOUT" + ClientIDSetting = "AZURE_CLIENT_ID" minBufferSize = 1024 defaultBufferSize = 8 * 1024 * 1024 minBuffers = 1 @@ -40,8 +41,9 @@ const ( type AzureAuthType string const ( - AzureAccessKeyAuth AzureAuthType = "AzureAccessKeyAuth" - AzureSASTokenAuth AzureAuthType = "AzureSASTokenAuth" + AzureAccessKeyAuth AzureAuthType = "AzureAccessKeyAuth" + AzureSASTokenAuth AzureAuthType = "AzureSASTokenAuth" + AzureManagedIdentityAuth AzureAuthType = "AzureManagedIdentityAuth" ) var SettingList = []string{ @@ -52,6 +54,7 @@ var SettingList = []string{ EndpointSuffix, BufferSizeSetting, MaxBuffersSetting, + ClientIDSetting, } func NewFolderError(err error, format string, args ...interface{}) storage.Error { @@ -78,6 +81,31 @@ func NewFolder( } } +func getContainerClientWithManagedIndetity( + accountName string, + storageEndpointSuffix string, + containerName string, + timeout time.Duration, + clientID string) (*azblob.ContainerClient, error) { + cred, err := azidentity.NewManagedIdentityCredential(&azidentity.ManagedIdentityCredentialOptions{ + ID: azidentity.ClientID(clientID), + }) + if err != nil { + return nil, err + } + + containerURLString := fmt.Sprintf("https://%s.blob.%s/%s", accountName, storageEndpointSuffix, containerName) + _, err = url.Parse(containerURLString) + if err != nil { + return nil, NewFolderError(err, "Unable to parse service URL") + } + + containerClient, err := azblob.NewContainerClient(containerURLString, cred, &azblob.ClientOptions{ + Retry: policy.RetryOptions{TryTimeout: timeout}, + }) + return containerClient, err +} + func getContainerClientWithSASToken( accountName string, storageEndpointSuffix string, @@ -136,9 +164,9 @@ func getContainerClient( return containerClient, err } -func configureAuthType(settings map[string]string) (AzureAuthType, string, string) { +func configureAuthType(settings map[string]string) (AzureAuthType, string, string, string) { var ok bool - var accountToken, accessKey string + var accountToken, accessKey, clientID string var authType AzureAuthType if accessKey, ok = settings[AccessKeySetting]; ok { @@ -149,9 +177,11 @@ func configureAuthType(settings map[string]string) (AzureAuthType, string, strin if !strings.HasPrefix(accountToken, "?") { accountToken = "?" + accountToken } + } else if clientID, ok = settings[ClientIDSetting]; ok { + authType = AzureManagedIdentityAuth } - return authType, accountToken, accessKey + return authType, accountToken, accessKey, clientID } func ConfigureFolder(prefix string, settings map[string]string) (storage.Folder, error) { @@ -161,7 +191,7 @@ func ConfigureFolder(prefix string, settings map[string]string) (storage.Folder, return nil, NewCredentialError(AccountSetting) } - authType, accountToken, accountKey := configureAuthType(settings) + authType, accountToken, accountKey, clientID := configureAuthType(settings) var credential *azblob.SharedKeyCredential var err error @@ -199,6 +229,8 @@ func ConfigureFolder(prefix string, settings map[string]string) (storage.Folder, var containerClient *azblob.ContainerClient if authType == AzureSASTokenAuth { containerClient, err = getContainerClientWithSASToken(accountName, storageEndpointSuffix, containerName, timeout, accountToken) + } else if authType == AzureManagedIdentityAuth { + containerClient, err = getContainerClientWithManagedIndetity(accountName, storageEndpointSuffix, containerName, timeout, clientID) } else if authType == AzureAccessKeyAuth { containerClient, err = getContainerClientWithAccessKey(accountName, storageEndpointSuffix, containerName, timeout, credential) } else { diff --git a/pkg/storages/azure/folder_test.go b/pkg/storages/azure/folder_test.go index da4b1f637..cda81a0b0 100644 --- a/pkg/storages/azure/folder_test.go +++ b/pkg/storages/azure/folder_test.go @@ -2,6 +2,7 @@ package azure import ( "testing" + "time" "github.com/stretchr/testify/assert" "github.com/wal-g/wal-g/pkg/storages/storage" @@ -22,24 +23,47 @@ var ConfigureAuthType = configureAuthType func TestConfigureAccessKeyAuthType(t *testing.T) { settings := map[string]string{AccessKeySetting: "foo"} - authType, accountToken, accessKey := ConfigureAuthType(settings) + authType, accountToken, accessKey, clientID := ConfigureAuthType(settings) assert.Equal(t, authType, AzureAccessKeyAuth) assert.Empty(t, accountToken) assert.Equal(t, accessKey, "foo") + assert.Empty(t, clientID) } func TestConfigureSASTokenAuth(t *testing.T) { settings := map[string]string{SasTokenSetting: "foo"} - authType, accountToken, accessKey := ConfigureAuthType(settings) + authType, accountToken, accessKey, clientID := ConfigureAuthType(settings) assert.Equal(t, authType, AzureSASTokenAuth) assert.Equal(t, accountToken, "?foo") assert.Empty(t, accessKey) + assert.Empty(t, clientID) } func TestConfigureDefaultAuth(t *testing.T) { settings := make(map[string]string) - authType, accountToken, accessKey := ConfigureAuthType(settings) + authType, accountToken, accessKey, clientID := ConfigureAuthType(settings) assert.Empty(t, authType) assert.Empty(t, accountToken) assert.Empty(t, accessKey) + assert.Empty(t, clientID) +} + +func TestConfigureManagedIdentityAuth(t *testing.T) { + settings := map[string]string{ClientIDSetting: "foo"} + authType, accountToken, accessKey, clientID := ConfigureAuthType(settings) + assert.Equal(t, authType, AzureManagedIdentityAuth) + assert.Empty(t, accountToken) + assert.Empty(t, accessKey) + assert.Equal(t, clientID, "foo") +} +func TestGetContainerClientWithManagedIdentity(t *testing.T) { + accountName := "test-account" + storageEndpointSuffix := "test-endpoint" + containerName := "test-container" + timeout := time.Minute + clientID := "test-client-id" + + containerClient, err := getContainerClientWithManagedIndetity(accountName, storageEndpointSuffix, containerName, timeout, clientID) + assert.NoError(t, err) + assert.NotNil(t, containerClient) }