Skip to content

Commit

Permalink
fix: null OAuth email behavior (#259)
Browse files Browse the repository at this point in the history
  • Loading branch information
diegosperes authored Apr 17, 2024
1 parent a7bdddb commit b608f2c
Show file tree
Hide file tree
Showing 9 changed files with 70 additions and 10 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@ Based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

## HEAD

## 1.20.1

### Added

* Fixing code that wansn't expeting oauth email as null value.

## 1.20.0

### Added
Expand Down
4 changes: 2 additions & 2 deletions app/data/mock/account_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func (s *accountStore) AddOauthAccount(accountID int, provider, providerID, emai

now := time.Now()
oauthAccount := &models.OauthAccount{
Email: email,
Email: &email,
AccountID: accountID,
Provider: provider,
ProviderID: providerID,
Expand All @@ -130,7 +130,7 @@ func (s *accountStore) UpdateOauthAccount(accountID int, provider, email string)

for i, oauthAccount := range oauthAccounts {
if oauthAccount.Provider == provider {
s.oauthAccountsByID[accountID][i].Email = email
s.oauthAccountsByID[accountID][i].Email = &email
return true, nil
}
}
Expand Down
23 changes: 23 additions & 0 deletions app/data/mysql/account_store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,27 @@ func TestAccountStore(t *testing.T) {
db.MustExec("TRUNCATE oauth_accounts")
tester(t, store)
}

t.Run("handle oauth email with null value", func(t *testing.T) {
account, err := store.Create("migrated-user", []byte("old"))
require.NoError(t, err)

err = store.AddOauthAccount(account.ID, "provider", "provider_id", "", "token")
require.NoError(t, err)

result, err := db.Exec("UPDATE oauth_accounts SET email = NULL WHERE account_id = ?", account.ID)
require.NoError(t, err)

rowsAffected, err := result.RowsAffected()
require.NoError(t, err)

require.Equal(t, int64(1), rowsAffected)

oAccounts, err := store.GetOauthAccounts(account.ID)
require.NoError(t, err)

require.Len(t, oAccounts, 1)
require.True(t, oAccounts[0].Email == nil)
require.Equal(t, oAccounts[0].GetEmail(), "")
})
}
23 changes: 23 additions & 0 deletions app/data/postgres/account_store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,27 @@ func TestAccountStore(t *testing.T) {
db.MustExec("TRUNCATE oauth_accounts")
tester(t, store)
}

t.Run("handle oauth email with null value", func(t *testing.T) {
account, err := store.Create("migrated-user", []byte("old"))
require.NoError(t, err)

err = store.AddOauthAccount(account.ID, "provider", "provider_id", "", "token")
require.NoError(t, err)

result, err := db.Exec("UPDATE oauth_accounts SET email = NULL WHERE account_id = $1", account.ID)
require.NoError(t, err)

rowsAffected, err := result.RowsAffected()
require.NoError(t, err)

require.Equal(t, int64(1), rowsAffected)

oAccounts, err := store.GetOauthAccounts(account.ID)
require.NoError(t, err)

require.Len(t, oAccounts, 1)
require.True(t, oAccounts[0].Email == nil)
require.Equal(t, oAccounts[0].GetEmail(), "")
})
}
12 changes: 10 additions & 2 deletions app/models/oauth_account.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,20 @@ type OauthAccount struct {
AccountID int `db:"account_id"`
Provider string
ProviderID string `db:"provider_id"`
Email string `db:"email"`
Email *string `db:"email"`
AccessToken string `db:"access_token"`
CreatedAt time.Time `db:"created_at"`
UpdatedAt time.Time `db:"updated_at"`
}

func (a OauthAccount) GetEmail() string {
if a.Email != nil {
return *a.Email
}

return ""
}

func (o OauthAccount) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Provider string `json:"provider"`
Expand All @@ -24,6 +32,6 @@ func (o OauthAccount) MarshalJSON() ([]byte, error) {
}{
Provider: o.Provider,
ProviderID: o.ProviderID,
Email: o.Email,
Email: o.GetEmail(),
})
}
4 changes: 2 additions & 2 deletions app/services/account_getter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,10 @@ func TestAccountGetter(t *testing.T) {
require.Equal(t, 2, len(oAccounts))
require.Equal(t, "test", oAccounts[0].Provider)
require.Equal(t, "ID1", oAccounts[0].ProviderID)
require.Equal(t, "email1", oAccounts[0].Email)
require.Equal(t, "email1", oAccounts[0].GetEmail())

require.Equal(t, "trial", oAccounts[1].Provider)
require.Equal(t, "ID2", oAccounts[1].ProviderID)
require.Equal(t, "email2", oAccounts[1].Email)
require.Equal(t, "email2", oAccounts[1].GetEmail())
})
}
2 changes: 1 addition & 1 deletion app/services/identity_reconciler.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func updateUserInfo(accountStore data.AccountStore, accountID int, providerName
continue
}

if oAccount.Email != providerUser.Email {
if oAccount.GetEmail() != providerUser.Email {
_, err = accountStore.UpdateOauthAccount(accountID, oAccount.Provider, providerUser.Email)
if err != nil {
return errors.Wrap(err, "UpdateOauthAccount")
Expand Down
4 changes: 2 additions & 2 deletions app/services/identity_reconciler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func TestIdentityReconciler(t *testing.T) {
oAccounts, err := store.GetOauthAccounts(account.ID)
assert.NoError(t, err)
assert.Equal(t, 1, len(oAccounts))
assert.Equal(t, email, oAccounts[0].Email)
assert.Equal(t, email, oAccounts[0].GetEmail())
})

t.Run("update oauth email when is outdated", func(t *testing.T) {
Expand All @@ -123,6 +123,6 @@ func TestIdentityReconciler(t *testing.T) {
oAccounts, err := store.GetOauthAccounts(account.ID)
assert.NoError(t, err)
assert.Equal(t, 1, len(oAccounts))
assert.Equal(t, email, oAccounts[0].Email)
assert.Equal(t, email, oAccounts[0].GetEmail())
})
}
2 changes: 1 addition & 1 deletion server/handlers/get_account_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func assertGetAccountResponse(t *testing.T, res *http.Response, acc *models.Acco
oAccounts = append(oAccounts, map[string]interface{}{
"provider": oAcc.Provider,
"provider_account_id": oAcc.ProviderID,
"email": oAcc.Email,
"email": oAcc.GetEmail(),
})
}

Expand Down

0 comments on commit b608f2c

Please sign in to comment.