diff --git a/CHANGELOG.md b/CHANGELOG.md index b5d0aaaef..8f55ff6a3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,12 @@ Based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). ## HEAD +## 1.20.1 + +### Added + +* Fixing code that wansn't expeting oauth email as null value. + ## 1.20.0 ### Added diff --git a/app/data/mock/account_store.go b/app/data/mock/account_store.go index 10c484bee..5e952d5e8 100644 --- a/app/data/mock/account_store.go +++ b/app/data/mock/account_store.go @@ -107,7 +107,7 @@ func (s *accountStore) AddOauthAccount(accountID int, provider, providerID, emai now := time.Now() oauthAccount := &models.OauthAccount{ - Email: email, + Email: &email, AccountID: accountID, Provider: provider, ProviderID: providerID, @@ -130,7 +130,7 @@ func (s *accountStore) UpdateOauthAccount(accountID int, provider, email string) for i, oauthAccount := range oauthAccounts { if oauthAccount.Provider == provider { - s.oauthAccountsByID[accountID][i].Email = email + s.oauthAccountsByID[accountID][i].Email = &email return true, nil } } diff --git a/app/data/mysql/account_store_test.go b/app/data/mysql/account_store_test.go index e39269a21..44d91cb72 100644 --- a/app/data/mysql/account_store_test.go +++ b/app/data/mysql/account_store_test.go @@ -17,4 +17,27 @@ func TestAccountStore(t *testing.T) { db.MustExec("TRUNCATE oauth_accounts") tester(t, store) } + + t.Run("handle oauth email with null value", func(t *testing.T) { + account, err := store.Create("migrated-user", []byte("old")) + require.NoError(t, err) + + err = store.AddOauthAccount(account.ID, "provider", "provider_id", "", "token") + require.NoError(t, err) + + result, err := db.Exec("UPDATE oauth_accounts SET email = NULL WHERE account_id = ?", account.ID) + require.NoError(t, err) + + rowsAffected, err := result.RowsAffected() + require.NoError(t, err) + + require.Equal(t, int64(1), rowsAffected) + + oAccounts, err := store.GetOauthAccounts(account.ID) + require.NoError(t, err) + + require.Len(t, oAccounts, 1) + require.True(t, oAccounts[0].Email == nil) + require.Equal(t, oAccounts[0].GetEmail(), "") + }) } diff --git a/app/data/postgres/account_store_test.go b/app/data/postgres/account_store_test.go index 6009fd39c..c1b9e5200 100644 --- a/app/data/postgres/account_store_test.go +++ b/app/data/postgres/account_store_test.go @@ -44,4 +44,27 @@ func TestAccountStore(t *testing.T) { db.MustExec("TRUNCATE oauth_accounts") tester(t, store) } + + t.Run("handle oauth email with null value", func(t *testing.T) { + account, err := store.Create("migrated-user", []byte("old")) + require.NoError(t, err) + + err = store.AddOauthAccount(account.ID, "provider", "provider_id", "", "token") + require.NoError(t, err) + + result, err := db.Exec("UPDATE oauth_accounts SET email = NULL WHERE account_id = $1", account.ID) + require.NoError(t, err) + + rowsAffected, err := result.RowsAffected() + require.NoError(t, err) + + require.Equal(t, int64(1), rowsAffected) + + oAccounts, err := store.GetOauthAccounts(account.ID) + require.NoError(t, err) + + require.Len(t, oAccounts, 1) + require.True(t, oAccounts[0].Email == nil) + require.Equal(t, oAccounts[0].GetEmail(), "") + }) } diff --git a/app/models/oauth_account.go b/app/models/oauth_account.go index 1c366c640..14720f728 100644 --- a/app/models/oauth_account.go +++ b/app/models/oauth_account.go @@ -10,12 +10,20 @@ type OauthAccount struct { AccountID int `db:"account_id"` Provider string ProviderID string `db:"provider_id"` - Email string `db:"email"` + Email *string `db:"email"` AccessToken string `db:"access_token"` CreatedAt time.Time `db:"created_at"` UpdatedAt time.Time `db:"updated_at"` } +func (a OauthAccount) GetEmail() string { + if a.Email != nil { + return *a.Email + } + + return "" +} + func (o OauthAccount) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Provider string `json:"provider"` @@ -24,6 +32,6 @@ func (o OauthAccount) MarshalJSON() ([]byte, error) { }{ Provider: o.Provider, ProviderID: o.ProviderID, - Email: o.Email, + Email: o.GetEmail(), }) } diff --git a/app/services/account_getter_test.go b/app/services/account_getter_test.go index 5938e0507..8bec14a33 100644 --- a/app/services/account_getter_test.go +++ b/app/services/account_getter_test.go @@ -53,10 +53,10 @@ func TestAccountGetter(t *testing.T) { require.Equal(t, 2, len(oAccounts)) require.Equal(t, "test", oAccounts[0].Provider) require.Equal(t, "ID1", oAccounts[0].ProviderID) - require.Equal(t, "email1", oAccounts[0].Email) + require.Equal(t, "email1", oAccounts[0].GetEmail()) require.Equal(t, "trial", oAccounts[1].Provider) require.Equal(t, "ID2", oAccounts[1].ProviderID) - require.Equal(t, "email2", oAccounts[1].Email) + require.Equal(t, "email2", oAccounts[1].GetEmail()) }) } diff --git a/app/services/identity_reconciler.go b/app/services/identity_reconciler.go index 8aa457a69..2d972d208 100644 --- a/app/services/identity_reconciler.go +++ b/app/services/identity_reconciler.go @@ -97,7 +97,7 @@ func updateUserInfo(accountStore data.AccountStore, accountID int, providerName continue } - if oAccount.Email != providerUser.Email { + if oAccount.GetEmail() != providerUser.Email { _, err = accountStore.UpdateOauthAccount(accountID, oAccount.Provider, providerUser.Email) if err != nil { return errors.Wrap(err, "UpdateOauthAccount") diff --git a/app/services/identity_reconciler_test.go b/app/services/identity_reconciler_test.go index 0b35e9377..661a75042 100644 --- a/app/services/identity_reconciler_test.go +++ b/app/services/identity_reconciler_test.go @@ -102,7 +102,7 @@ func TestIdentityReconciler(t *testing.T) { oAccounts, err := store.GetOauthAccounts(account.ID) assert.NoError(t, err) assert.Equal(t, 1, len(oAccounts)) - assert.Equal(t, email, oAccounts[0].Email) + assert.Equal(t, email, oAccounts[0].GetEmail()) }) t.Run("update oauth email when is outdated", func(t *testing.T) { @@ -123,6 +123,6 @@ func TestIdentityReconciler(t *testing.T) { oAccounts, err := store.GetOauthAccounts(account.ID) assert.NoError(t, err) assert.Equal(t, 1, len(oAccounts)) - assert.Equal(t, email, oAccounts[0].Email) + assert.Equal(t, email, oAccounts[0].GetEmail()) }) } diff --git a/server/handlers/get_account_test.go b/server/handlers/get_account_test.go index 737dafc4b..cf0668eec 100644 --- a/server/handlers/get_account_test.go +++ b/server/handlers/get_account_test.go @@ -80,7 +80,7 @@ func assertGetAccountResponse(t *testing.T, res *http.Response, acc *models.Acco oAccounts = append(oAccounts, map[string]interface{}{ "provider": oAcc.Provider, "provider_account_id": oAcc.ProviderID, - "email": oAcc.Email, + "email": oAcc.GetEmail(), }) }