Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: implement best practices for transaction through helper method #1605

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 29 additions & 40 deletions app/controlplane/pkg/data/attestationstate.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,50 +53,39 @@ func (r *AttestationStateRepo) Initialized(ctx context.Context, runID uuid.UUID)

// baseDigest, when provided will be used to check that it matches the digest of the state currently in the DB
// if the digests do not match, the state has been modified and the caller should retry
func (r *AttestationStateRepo) Save(ctx context.Context, runID uuid.UUID, state []byte, baseDigest string) (err error) {
tx, err := r.data.DB.Tx(ctx)
if err != nil {
return fmt.Errorf("failed to create transaction: %w", err)
}

defer func() {
// Unblock the row if there was an error
if err != nil {
_ = tx.Rollback()
func (r *AttestationStateRepo) Save(ctx context.Context, runID uuid.UUID, state []byte, baseDigest string) error {
return WithTx(ctx, r.data.DB, func(tx *ent.Tx) error {
// compared the provided digest with the digest of the state in the DB
// TODO: make digest check mandatory on updates
if baseDigest != "" {
// Get the run but BLOCK IT for update
run, err := tx.WorkflowRun.Query().ForUpdate().Where(workflowrun.ID(runID)).Only(ctx)
if err != nil && !ent.IsNotFound(err) {
return fmt.Errorf("failed to read attestation state: %w", err)
} else if run == nil || run.AttestationState == nil {
return biz.NewErrNotFound("attestation state")
}

// calculate the digest of the current state
storedDigest, err := digest(run.AttestationState)
if err != nil {
return fmt.Errorf("failed to calculate digest: %w", err)
}

if baseDigest != storedDigest {
return biz.NewErrAttestationStateConflict(storedDigest, baseDigest)
}
}
}()

// compared the provided digest with the digest of the state in the DB
// TODO: make digest check mandatory on updates
if baseDigest != "" {
// Get the run but BLOCK IT for update
run, err := tx.WorkflowRun.Query().ForUpdate().Where(workflowrun.ID(runID)).Only(ctx)
// Update it in the DB if the digest matches
err := tx.WorkflowRun.UpdateOneID(runID).SetAttestationState(state).Exec(ctx)
if err != nil && !ent.IsNotFound(err) {
return fmt.Errorf("failed to read attestation state: %w", err)
} else if run == nil || run.AttestationState == nil {
return biz.NewErrNotFound("attestation state")
}

// calculate the digest of the current state
storedDigest, err := digest(run.AttestationState)
if err != nil {
return fmt.Errorf("failed to calculate digest: %w", err)
}

if baseDigest != storedDigest {
return biz.NewErrAttestationStateConflict(storedDigest, baseDigest)
return fmt.Errorf("failed to store attestation state: %w", err)
} else if err != nil {
return biz.NewErrNotFound("workflow run")
}
}

// Update it in the DB if the digest matches
err = tx.WorkflowRun.UpdateOneID(runID).SetAttestationState(state).Exec(ctx)
if err != nil && !ent.IsNotFound(err) {
return fmt.Errorf("failed to store attestation state: %w", err)
} else if err != nil {
return biz.NewErrNotFound("workflow run")
}

return tx.Commit()
return nil
})
}

func (r *AttestationStateRepo) Read(ctx context.Context, runID uuid.UUID) ([]byte, string, error) {
Expand Down
150 changes: 67 additions & 83 deletions app/controlplane/pkg/data/casbackend.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,106 +83,90 @@ func (r *CASBackendRepo) FindFallbackBackend(ctx context.Context, orgID uuid.UUI

// Create creates a new CAS backend in the given organization
// If it's set as default, it will unset the previous default backend
func (r *CASBackendRepo) Create(ctx context.Context, opts *biz.CASBackendCreateOpts) (b *biz.CASBackend, err error) {
tx, err := r.data.DB.Tx(ctx)
if err != nil {
return nil, fmt.Errorf("failed to create transaction: %w", err)
}
func (r *CASBackendRepo) Create(ctx context.Context, opts *biz.CASBackendCreateOpts) (*biz.CASBackend, error) {
var (
backend *ent.CASBackend
err error
)
if err := WithTx(ctx, r.data.DB, func(tx *ent.Tx) error {
// 1 - unset default backend for all the other backends in the org
if opts.Default {
if err := tx.CASBackend.Update().
Where(casbackend.HasOrganizationWith(organization.ID(opts.OrgID))).
Where(casbackend.Default(true)).
SetDefault(false).
Exec(ctx); err != nil {
return fmt.Errorf("failed to clear previous default backend: %w", err)
}
}

defer func() {
// Unblock the row if there was an error
// 2 - create the new backend and set it as default if needed
backend, err = tx.CASBackend.Create().
SetName(opts.Name).
SetOrganizationID(opts.OrgID).
SetLocation(opts.Location).
SetDescription(opts.Description).
SetFallback(opts.Fallback).
SetProvider(opts.Provider).
SetDefault(opts.Default).
SetSecretName(opts.SecretName).
SetMaxBlobSizeBytes(opts.MaxBytes).
Save(ctx)
if err != nil {
_ = tx.Rollback()
}
}()

// 1 - unset default backend for all the other backends in the org
if opts.Default {
if err := tx.CASBackend.Update().
Where(casbackend.HasOrganizationWith(organization.ID(opts.OrgID))).
Where(casbackend.Default(true)).
SetDefault(false).
Exec(ctx); err != nil {
return nil, fmt.Errorf("failed to clear previous default backend: %w", err)
}
}
if ent.IsConstraintError(err) {
return biz.NewErrAlreadyExists(err)
}

// 2 - create the new backend and set it as default if needed
backend, err := tx.CASBackend.Create().
SetName(opts.Name).
SetOrganizationID(opts.OrgID).
SetLocation(opts.Location).
SetDescription(opts.Description).
SetFallback(opts.Fallback).
SetProvider(opts.Provider).
SetDefault(opts.Default).
SetSecretName(opts.SecretName).
SetMaxBlobSizeBytes(opts.MaxBytes).
Save(ctx)
if err != nil {
if ent.IsConstraintError(err) {
return nil, biz.NewErrAlreadyExists(err)
return fmt.Errorf("failed to create backend: %w", err)
}

return nil, fmt.Errorf("failed to create backend: %w", err)
}

// 3 - commit the transaction
if err := tx.Commit(); err != nil {
return nil, fmt.Errorf("failed to commit transaction: %w", err)
return nil
}); err != nil {
return nil, err
}

// Return the backend from the DB to have consistent marshalled object
return r.FindByID(ctx, backend.ID)
}

func (r *CASBackendRepo) Update(ctx context.Context, opts *biz.CASBackendUpdateOpts) (b *biz.CASBackend, err error) {
tx, err := r.data.DB.Tx(ctx)
if err != nil {
return nil, fmt.Errorf("failed to create transaction: %w", err)
}

defer func() {
// Unblock the row if there was an error
if err != nil {
_ = tx.Rollback()
func (r *CASBackendRepo) Update(ctx context.Context, opts *biz.CASBackendUpdateOpts) (*biz.CASBackend, error) {
var (
backend *ent.CASBackend
err error
)
if err = WithTx(ctx, r.data.DB, func(tx *ent.Tx) error {
// 1 - unset default backend for all the other backends in the org
if opts.Default {
if err := tx.CASBackend.Update().
Where(casbackend.HasOrganizationWith(organization.ID(opts.OrgID))).
Where(casbackend.Default(true)).
SetDefault(false).
Exec(ctx); err != nil {
return fmt.Errorf("failed to clear previous default backend: %w", err)
}
}
}()

// 1 - unset default backend for all the other backends in the org
if opts.Default {
if err := tx.CASBackend.Update().
Where(casbackend.HasOrganizationWith(organization.ID(opts.OrgID))).
Where(casbackend.Default(true)).
SetDefault(false).
Exec(ctx); err != nil {
return nil, fmt.Errorf("failed to clear previous default backend: %w", err)
}
}

// 2 - Chain the list of updates
// TODO: allow setting values as empty, currently it's not possible.
// We do it in other models by providing pointers to string + setNillableX methods
updateChain := tx.CASBackend.UpdateOneID(opts.ID).SetDefault(opts.Default)
if opts.Description != "" {
updateChain = updateChain.SetDescription(opts.Description)
}
// 2 - Chain the list of updates
// TODO: allow setting values as empty, currently it's not possible.
// We do it in other models by providing pointers to string + setNillableX methods
updateChain := tx.CASBackend.UpdateOneID(opts.ID).SetDefault(opts.Default)
if opts.Description != "" {
updateChain = updateChain.SetDescription(opts.Description)
}

// If secretName is provided we set it
if opts.SecretName != "" {
updateChain = updateChain.SetSecretName(opts.SecretName)
}
// If secretName is provided we set it
if opts.SecretName != "" {
updateChain = updateChain.SetSecretName(opts.SecretName)
}

backend, err := updateChain.Save(ctx)
if err != nil {
backend, err = updateChain.Save(ctx)
if err != nil {
return err
}
return nil
}); err != nil {
return nil, err
}

// 3 - commit the transaction
if err := tx.Commit(); err != nil {
return nil, fmt.Errorf("failed to commit transaction: %w", err)
}

return r.FindByID(ctx, backend.ID)
}

Expand Down
27 changes: 25 additions & 2 deletions app/controlplane/pkg/data/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,11 @@ package data

import (
"context"
"database/sql"
"fmt"
"io"
"time"

"database/sql"

"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"

Expand Down Expand Up @@ -160,3 +159,27 @@ func toTimePtr(t time.Time) *time.Time {
func orgScopedQuery(client *ent.Client, orgID uuid.UUID) *ent.OrganizationQuery {
return client.Organization.Query().Where(organization.ID(orgID))
}

// WithTx initiates a transaction and wraps the DB function
func WithTx(ctx context.Context, client *ent.Client, fn func(tx *ent.Tx) error) error {
tx, err := client.Tx(ctx)
if err != nil {
return err
}
defer func() {
if v := recover(); v != nil {
_ = tx.Rollback()
panic(v)
}
}()
if err = fn(tx); err != nil {
if rerr := tx.Rollback(); rerr != nil {
err = fmt.Errorf("%w: rolling back transaction: %w", err, rerr)
}
return err
}
if err = tx.Commit(); err != nil {
return fmt.Errorf("committing transaction: %w", err)
}
return nil
}
31 changes: 10 additions & 21 deletions app/controlplane/pkg/data/integration.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,29 +112,18 @@ func (r *IntegrationRepo) FindByNameInOrg(ctx context.Context, orgID uuid.UUID,
return entIntegrationToBiz(integration), nil
}

func (r *IntegrationRepo) SoftDelete(ctx context.Context, id uuid.UUID) (err error) {
tx, err := r.data.DB.Tx(ctx)
if err != nil {
return err
}

defer func() {
// Unblock the row if there was an error
if err != nil {
_ = tx.Rollback()
func (r *IntegrationRepo) SoftDelete(ctx context.Context, id uuid.UUID) error {
return WithTx(ctx, r.data.DB, func(tx *ent.Tx) error {
// soft-delete attachments associated with this workflow
if err := tx.IntegrationAttachment.Update().Where(integrationattachment.HasIntegrationWith(integration.ID(id))).SetDeletedAt(time.Now()).Exec(ctx); err != nil {
return err
}
}()

// soft-delete attachments associated with this workflow
if err := tx.IntegrationAttachment.Update().Where(integrationattachment.HasIntegrationWith(integration.ID(id))).SetDeletedAt(time.Now()).Exec(ctx); err != nil {
return err
}

if err := tx.Integration.UpdateOneID(id).SetDeletedAt(time.Now()).Exec(ctx); err != nil {
return err
}

return tx.Commit()
if err := tx.Integration.UpdateOneID(id).SetDeletedAt(time.Now()).Exec(ctx); err != nil {
return err
}
return nil
})
}

func entIntegrationToBiz(i *ent.Integration) *biz.Integration {
Expand Down
35 changes: 11 additions & 24 deletions app/controlplane/pkg/data/membership.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,34 +152,21 @@
return nil, err
}

// For the found user, we must, in a transaction.
tx, err := r.data.DB.Tx(ctx)
if err != nil {
return nil, err
}

defer func() {
// Unblock the row if there was an error
if err != nil {
_ = tx.Rollback()
if err = WithTx(ctx, r.data.DB, func(tx *ent.Tx) error {
// 1 - Set all the memberships to current=false
if err = tx.Membership.Update().Where(membership.HasUserWith(user.ID(m.Edges.User.ID))).
SetCurrent(false).Exec(ctx); err != nil {
return err
}
}()

// 1 - Set all the memberships to current=false
if err = tx.Membership.Update().Where(membership.HasUserWith(user.ID(m.Edges.User.ID))).
SetCurrent(false).Exec(ctx); err != nil {
return nil, err
}

// 2 - Set the referenced membership to current=true
if err = tx.Membership.UpdateOneID(membershipID).SetCurrent(true).Exec(ctx); err != nil {
return nil, err
}

if err := tx.Commit(); err != nil {
// 2 - Set the referenced membership to current=true
if err = tx.Membership.UpdateOneID(membershipID).SetCurrent(true).Exec(ctx); err != nil {
return err
}
}); err != nil {

Check failure on line 166 in app/controlplane/pkg/data/membership.go

View workflow job for this annotation

GitHub Actions / Test (main-module)

missing return

Check failure on line 166 in app/controlplane/pkg/data/membership.go

View workflow job for this annotation

GitHub Actions / Test (main-module)

missing return

Check failure on line 166 in app/controlplane/pkg/data/membership.go

View workflow job for this annotation

GitHub Actions / Test (main-module)

missing return

Check failure on line 166 in app/controlplane/pkg/data/membership.go

View workflow job for this annotation

GitHub Actions / lint (main-module)

missing return) (typecheck)

Check failure on line 166 in app/controlplane/pkg/data/membership.go

View workflow job for this annotation

GitHub Actions / lint (main-module)

missing return) (typecheck)
return nil, err
}

// Reload returned data
m, err = r.loadMembership(ctx, membershipID)
if err != nil {
Expand Down
Loading
Loading