Ensuring the integrity of full snapshot before uploading it to the object store. #779

Open · wants to merge 8 commits into master
1 change: 1 addition & 0 deletions cmd/compact.go
@@ -55,6 +55,7 @@ func NewCompactCommand(ctx context.Context) *cobra.Command {
	compactOptions := &brtypes.CompactOptions{
		RestoreOptions:  options,
		CompactorConfig: opts.compactorConfig,
		TempDir:         opts.snapstoreConfig.TempDir,
	}

	snapshot, err := cp.Compact(ctx, compactOptions)
2 changes: 1 addition & 1 deletion pkg/compactor/compactor.go
@@ -166,7 +166,7 @@ func (cp *Compactor) Compact(ctx context.Context, opts *brtypes.CompactOptions)
	isFinal := compactorRestoreOptions.BaseSnapshot.IsFinal

	cc := &compressor.CompressionConfig{Enabled: isCompressed, CompressionPolicy: compressionPolicy}
	snapshot, err := etcdutil.TakeAndSaveFullSnapshot(snapshotReqCtx, clientMaintenance, cp.store, etcdRevision, cc, suffix, isFinal, cp.logger)
	snapshot, err := etcdutil.TakeAndSaveFullSnapshot(snapshotReqCtx, clientMaintenance, cp.store, opts.TempDir, etcdRevision, cc, suffix, isFinal, cp.logger)
	if err != nil {
		return nil, err
	}
9 changes: 8 additions & 1 deletion pkg/compactor/compactor_test.go
@@ -60,12 +60,18 @@ var _ = Describe("Running Compactor", func() {
	var compactorConfig *brtypes.CompactorConfig
	var compactOptions *brtypes.CompactOptions
	var compactedSnapshot *brtypes.Snapshot
	var snapstoreConfig *brtypes.SnapstoreConfig
	var tempRestoreDir string
	var tempDataDir string

	BeforeEach(func() {
		dir = fmt.Sprintf("%s/etcd/snapshotter.bkp", testSuiteDir)
		store, err = snapstore.GetSnapstore(&brtypes.SnapstoreConfig{Container: dir, Provider: "Local"})
		snapstoreConfig = &brtypes.SnapstoreConfig{
			Container: dir,
			Provider:  "Local",
		}

		store, err = snapstore.GetSnapstore(snapstoreConfig)
		Expect(err).ShouldNot(HaveOccurred())
		fmt.Println("The store where compaction will save snapshot is: ", store)

@@ -104,6 +110,7 @@ var _ = Describe("Running Compactor", func() {
		compactOptions = &brtypes.CompactOptions{
			RestoreOptions:  restoreOpts,
			CompactorConfig: compactorConfig,
			TempDir:         snapstoreConfig.TempDir,
		}
	})

121 changes: 116 additions & 5 deletions pkg/etcdutil/etcdutil.go
@@ -5,9 +5,14 @@
package etcdutil

import (
	"bytes"
	"context"
	"crypto/sha256"
	"crypto/tls"
	"fmt"
	"io"
	"os"
	"path/filepath"
	"time"

	"github.com/gardener/etcd-backup-restore/pkg/compressor"
@@ -23,6 +28,10 @@ import (
	"go.etcd.io/etcd/pkg/transport"
)

const (
	hashBufferSize = 4 * 1024 * 1024 // 4 MB
)

// NewFactory returns a Factory that constructs new clients using the supplied ETCD client configuration.
func NewFactory(cfg brtypes.EtcdConnectionConfig, opts ...client.Option) client.Factory {
	options := &client.Options{}
@@ -232,7 +241,7 @@ func GetEtcdEndPointsSorted(ctx context.Context, clientMaintenance client.Mainte
		endPoint = etcdEndpoints[0]
	} else {
		return nil, nil, &errors.EtcdError{
			Message: fmt.Sprintf("etcd endpoints are not passed correctly"),
			Message: "etcd endpoints are not passed correctly",
		}
	}

@@ -254,7 +263,7 @@ func GetEtcdEndPointsSorted(ctx context.Context, clientMaintenance client.Mainte
}

// TakeAndSaveFullSnapshot takes full snapshot and save it to store
func TakeAndSaveFullSnapshot(ctx context.Context, client client.MaintenanceCloser, store brtypes.SnapStore, lastRevision int64, cc *compressor.CompressionConfig, suffix string, isFinal bool, logger *logrus.Entry) (*brtypes.Snapshot, error) {
func TakeAndSaveFullSnapshot(ctx context.Context, client client.MaintenanceCloser, store brtypes.SnapStore, tempDir string, lastRevision int64, cc *compressor.CompressionConfig, suffix string, isFinal bool, logger *logrus.Entry) (*brtypes.Snapshot, error) {
Contributor:

Suggested change
func TakeAndSaveFullSnapshot(ctx context.Context, client client.MaintenanceCloser, store brtypes.SnapStore, tempDir string, lastRevision int64, cc *compressor.CompressionConfig, suffix string, isFinal bool, logger *logrus.Entry) (*brtypes.Snapshot, error) {
// TakeAndSaveFullSnapshot takes a full snapshot of the etcd database, verifies its integrity,
// optionally compresses it, and saves it to the specified snapshot store.
func TakeAndSaveFullSnapshot(ctx context.Context, client client.MaintenanceCloser, store brtypes.SnapStore, tempDir string, lastRevision int64, compressionConfig *compressor.CompressionConfig, suffix string, isFinal bool, logger *logrus.Entry) (*brtypes.Snapshot, error) {
	snapshotStartTime := time.Now()
	logger.Infof("Starting full snapshot process. Last revision: %d, TempDir: %s", lastRevision, tempDir)
	snapshotReader, err := createEtcdSnapshot(ctx, client, logger)
	if err != nil {
		logger.Errorf("Failed to create etcd snapshot: %v", err)
		return nil, err
	}
	defer snapshotReader.Close()
	snapshotTempDBPath := filepath.Join(tempDir, "db")
	logger.Infof("Temporary DB path for snapshot verification: %s", snapshotTempDBPath)
	// Verify snapshot integrity
	verifiedSnapshotReader, err := verifyFullSnapshotIntegrity(snapshotReader, snapshotTempDBPath, logger)
	if err != nil {
		logger.Errorf("Verification of full snapshot SHA256 hash failed: %v", err)
		return nil, err
	}
	defer cleanUpTempFile(snapshotTempDBPath, logger)
	logger.Info("Full snapshot SHA256 hash successfully verified.")
	verifiedSnapshotReader, err = compressSnapshotDataIfNeeded(verifiedSnapshotReader, compressionConfig, logger)
	if err != nil {
		return nil, err
	}
	logger.Infof("Successfully opened snapshot reader on etcd")
	// Save the snapshot to the store
	fullSnapshot := snapstore.NewSnapshot(brtypes.SnapshotKindFull, 0, lastRevision, suffix, isFinal)
	logger.Infof("Saving full snapshot to store. Snapshot details: %+v", fullSnapshot)
	if err := saveSnapshotToStore(store, fullSnapshot, verifiedSnapshotReader, snapshotStartTime, logger); err != nil {
		logger.Errorf("Failed to save snapshot to store: %v", err)
		return nil, err
	}
	logger.Infof("Total time to save full snapshot: %f seconds.", time.Since(snapshotStartTime).Seconds())
	return fullSnapshot, nil
}

Member Author:

I would not like to change the function TakeAndSaveFullSnapshot() as this is not a refactoring PR.
Feel free to open a separate issue if you want to refactor TakeAndSaveFullSnapshot().

Contributor:

Since you are touching this function and adding 15+ more lines, I don't want to grow this function further; that won't help the readability of the code. That's why I suggest moving the code into smaller functions, where each function name clearly conveys its functionality.

Member Author:

Makes sense, but I would not like to do that in the same PR, for the following reasons:

  1. It will change the scope of the PR from a feature PR to a refactoring PR, which will be hard to track after a year (say) if some bug is found in some part of the code.
  2. The refactoring will create a huge overhead of manual testing, performance testing, etc.: taking snapshots, compression of snapshots, restoration of compressed and uncompressed snapshots, plus verification of snapshots (the feature being added) would all need to be fully tested.
  3. It will completely change the scope of the PR, which will also invalidate the LGTMs of reviewers who have already reviewed it, as it effectively becomes a new PR.

That's why I don't want to do this in the same PR.

Contributor:

Point 1: Do you see this as a major refactor? It’s essentially just restructuring a few existing lines into two functions for better readability.

Point 2: This point isn’t entirely valid, as we are already planning to test this functionality as part of the release process. This includes verifying the new integrity checks that have been added.

Point 3: Whenever we introduce new changes or code, it’s common to make minor adjustments to the surrounding logic for better structure. In this case, it’s simply wrapping the logic into smaller, reusable functions.

That said, this isn’t a significant code change. Let’s get some thoughts from other reviewers as well.

WDYT, @anveshreddy18 and @renormalize?

Member Author:

Do you see this as a major refactor?

Yes, I do, because of point 2.
I have mentioned my reasoning in this comment: #779 (comment). I'm not turning this PR into a refactoring PR. Please feel free to open a separate issue later; if somebody has time, they can pick it up.

Member Author:

This point isn’t entirely valid, as we are already planning to test this functionality as part of the release process. This includes verifying the new integrity checks that have been added.

Point 2 is totally valid. Do we currently have e2e tests that do what I mentioned in point 2, or not? The answer is no, so it will create a huge overhead.

we are already planning to test this functionality as part of the release process.

Currently, nobody does performance testing in the release process. The question is not about the future; the question is whether we have it now or not. I don't think we do, and I don't see it coming in the next few weeks.

	startTime := time.Now()
	rc, err := client.Snapshot(ctx)
	if err != nil {
@@ -265,22 +274,40 @@ func TakeAndSaveFullSnapshot(ctx context.Context, client client.MaintenanceClose
	timeTaken := time.Since(startTime)
	logger.Infof("Total time taken by Snapshot API: %f seconds.", timeTaken.Seconds())

	var snapshotData io.ReadCloser
	snapshotTempDBPath := filepath.Join(tempDir, "db")
seshachalam-yv marked this conversation as resolved.
	if snapshotData, err = checkFullSnapshotIntegrity(rc, snapshotTempDBPath, logger); err != nil {
		logger.Errorf("verification of full snapshot SHA256 hash has failed: %v", err)
		return nil, err
	}
	logger.Info("full snapshot SHA256 hash has been successfully verified.")

	defer func() {
		if err := os.Remove(snapshotTempDBPath); err != nil {
			logger.Warnf("failed to remove temporary full snapshot file: %v", err)
		}
	}()

	if cc.Enabled {
		startTimeCompression := time.Now()
		rc, err = compressor.CompressSnapshot(rc, cc.CompressionPolicy)
		snapshotData, err = compressor.CompressSnapshot(snapshotData, cc.CompressionPolicy)
renormalize marked this conversation as resolved.
		if err != nil {
			return nil, fmt.Errorf("unable to obtain reader for compressed file: %v", err)
		}
		timeTakenCompression := time.Since(startTimeCompression)
		logger.Infof("Total time taken in full snapshot compression: %f seconds.", timeTakenCompression.Seconds())
	}
	defer rc.Close()
	defer func() {
		if err := snapshotData.Close(); err != nil {
			logger.Warnf("failed to close snapshot data file: %v", err)
		}
	}()

	logger.Infof("Successfully opened snapshot reader on etcd")

	// Then save the snapshot to the store.
	snapshot := snapstore.NewSnapshot(brtypes.SnapshotKindFull, 0, lastRevision, suffix, isFinal)
	if err := store.Save(*snapshot, rc); err != nil {
	if err := store.Save(*snapshot, snapshotData); err != nil {
		timeTaken := time.Since(startTime)
		metrics.SnapshotDurationSeconds.With(prometheus.Labels{metrics.LabelKind: brtypes.SnapshotKindFull, metrics.LabelSucceeded: metrics.ValueSucceededFalse}).Observe(timeTaken.Seconds())
		return nil, &errors.SnapstoreError{
@@ -294,3 +321,87 @@ func TakeAndSaveFullSnapshot(ctx context.Context, client client.MaintenanceClose

	return snapshot, nil
}

// checkFullSnapshotIntegrity verifies the integrity of the full snapshot by comparing
// the appended SHA256 hash of the full snapshot with the calculated SHA256 hash of the full snapshot data.
func checkFullSnapshotIntegrity(snapshotData io.ReadCloser, snapTempDBFilePath string, logger *logrus.Entry) (io.ReadCloser, error) {
seshachalam-yv marked this conversation as resolved.
logger.Info("checking the full snapshot integrity with the help of SHA256")

// If previous temp db file already exist then remove it.
if err := os.Remove(snapTempDBFilePath); err != nil && !os.IsNotExist(err) {
return nil, err
}

db, err := os.OpenFile(snapTempDBFilePath, os.O_RDWR|os.O_CREATE, 0600)
if err != nil {
return nil, err
}
ishan16696 marked this conversation as resolved.

	if _, err := io.Copy(db, snapshotData); err != nil {
		return nil, err
	}

	lastOffset, err := db.Seek(0, io.SeekEnd)
	if err != nil {
		return nil, err
	}
	// 512 is chosen because it's a minimum disk sector size in most systems.
	hasHash := (lastOffset % 512) == sha256.Size
	if !hasHash {
		return nil, fmt.Errorf("SHA256 hash seems to be missing from snapshot data")
	}

	totalSnapshotBytes, err := db.Seek(-sha256.Size, io.SeekEnd)
	if err != nil {
		return nil, err
	}

	// get snapshot SHA256 hash
	sha := make([]byte, sha256.Size)
	if _, err := db.Read(sha); err != nil {
		return nil, fmt.Errorf("failed to read SHA256 from snapshot data %v", err)
	}

	buf := make([]byte, hashBufferSize)
	hash := sha256.New()

	logger.Infof("Total no. of bytes received from snapshot api call with SHA: %d", lastOffset)
	logger.Infof("Total no. of bytes received from snapshot api call without SHA: %d", totalSnapshotBytes)

	// reset the file pointer back to starting
	currentOffset, err := db.Seek(0, io.SeekStart)
	if err != nil {
		return nil, err
	}

	for currentOffset+hashBufferSize <= totalSnapshotBytes {
		offset, err := db.Read(buf)
		if err != nil {
			return nil, fmt.Errorf("unable to read snapshot data into buffer to calculate SHA256: %v", err)
		}

		hash.Write(buf[:offset])
		currentOffset += int64(offset)
	}

	if currentOffset < totalSnapshotBytes {
		if _, err := db.Read(buf); err != nil {
			return nil, fmt.Errorf("unable to read last chunk of snapshot data into buffer to calculate SHA256: %v", err)
		}

		hash.Write(buf[:totalSnapshotBytes-currentOffset])
	}

	dbSha := hash.Sum(nil)
	if !bytes.Equal(sha, dbSha) {
		return nil, fmt.Errorf("expected SHA256 for full snapshot: %v, got %v", sha, dbSha)
	}

	// reset the file pointer back to starting
	if _, err := db.Seek(0, io.SeekStart); err != nil {
		return nil, err
	}

	// full-snapshot of database has been successfully verified.
	return db, nil
}
Comment on lines +324 to +407
Contributor:

Extracted certain operations, such as snapshot creation (createEtcdSnapshot) and integrity verification (verifyFullSnapshotIntegrity), into separate functions. This improves readability, maintainability, and allows for better unit testing of individual functions.

Suggested change
// checkFullSnapshotIntegrity verifies the integrity of the full snapshot by comparing
// the appended SHA256 hash of the full snapshot with the calculated SHA256 hash of the full snapshot data.
func checkFullSnapshotIntegrity(snapshotData io.ReadCloser, snapTempDBFilePath string, logger *logrus.Entry) (io.ReadCloser, error) {
	logger.Info("checking the full snapshot integrity with the help of SHA256")
	// If previous temp db file already exist then remove it.
	if err := os.Remove(snapTempDBFilePath); err != nil && !os.IsNotExist(err) {
		return nil, err
	}
	db, err := os.OpenFile(snapTempDBFilePath, os.O_RDWR|os.O_CREATE, 0600)
	if err != nil {
		return nil, err
	}
	if _, err := io.Copy(db, snapshotData); err != nil {
		return nil, err
	}
	lastOffset, err := db.Seek(0, io.SeekEnd)
	if err != nil {
		return nil, err
	}
	// 512 is chosen because it's a minimum disk sector size in most systems.
	hasHash := (lastOffset % 512) == sha256.Size
	if !hasHash {
		return nil, fmt.Errorf("SHA256 hash seems to be missing from snapshot data")
	}
	totalSnapshotBytes, err := db.Seek(-sha256.Size, io.SeekEnd)
	if err != nil {
		return nil, err
	}
	// get snapshot SHA256 hash
	sha := make([]byte, sha256.Size)
	if _, err := db.Read(sha); err != nil {
		return nil, fmt.Errorf("failed to read SHA256 from snapshot data %v", err)
	}
	buf := make([]byte, hashBufferSize)
	hash := sha256.New()
	logger.Infof("Total no. of bytes received from snapshot api call with SHA: %d", lastOffset)
	logger.Infof("Total no. of bytes received from snapshot api call without SHA: %d", totalSnapshotBytes)
	// reset the file pointer back to starting
	currentOffset, err := db.Seek(0, io.SeekStart)
	if err != nil {
		return nil, err
	}
	for currentOffset+hashBufferSize <= totalSnapshotBytes {
		offset, err := db.Read(buf)
		if err != nil {
			return nil, fmt.Errorf("unable to read snapshot data into buffer to calculate SHA256: %v", err)
		}
		hash.Write(buf[:offset])
		currentOffset += int64(offset)
	}
	if currentOffset < totalSnapshotBytes {
		if _, err := db.Read(buf); err != nil {
			return nil, fmt.Errorf("unable to read last chunk of snapshot data into buffer to calculate SHA256: %v", err)
		}
		hash.Write(buf[:totalSnapshotBytes-currentOffset])
	}
	dbSha := hash.Sum(nil)
	if !bytes.Equal(sha, dbSha) {
		return nil, fmt.Errorf("expected SHA256 for full snapshot: %v, got %v", sha, dbSha)
	}
	// reset the file pointer back to starting
	if _, err := db.Seek(0, io.SeekStart); err != nil {
		return nil, err
	}
	// full-snapshot of database has been successfully verified.
	return db, nil
}
// compressSnapshotDataIfNeeded compresses the snapshot data if compression is enabled.
func compressSnapshotDataIfNeeded(snapshotData io.ReadCloser, compressionConfig *compressor.CompressionConfig, logger *logrus.Entry) (io.ReadCloser, error) {
	if compressionConfig != nil && compressionConfig.Enabled {
		startTime := time.Now()
		logger.Infof("Compression enabled. Starting compression of snapshot data.")
		compressedData, err := compressor.CompressSnapshot(snapshotData, compressionConfig.CompressionPolicy)
		if err != nil {
			logger.Errorf("Failed to compress snapshot data: %v", err)
			return nil, fmt.Errorf("unable to obtain reader for compressed file: %v", err)
		}
		logger.Infof("Total time taken in full snapshot compression: %f seconds.", time.Since(startTime).Seconds())
		return compressedData, nil
	}
	return snapshotData, nil
}

// createEtcdSnapshot initiates the etcd snapshot process.
func createEtcdSnapshot(ctx context.Context, client client.MaintenanceCloser, logger *logrus.Entry) (io.ReadCloser, error) {
	startTime := time.Now()
	snapshotReader, err := client.Snapshot(ctx)
	if err != nil {
		return nil, &errors.EtcdError{
			Message: fmt.Sprintf("Failed to create etcd snapshot: %v", err),
		}
	}
	logger.Infof("Total time taken by Snapshot API: %f seconds.", time.Since(startTime).Seconds())
	return snapshotReader, nil
}

// verifyFullSnapshotIntegrity verifies the integrity of the full snapshot.
func verifyFullSnapshotIntegrity(snapshotData io.ReadCloser, snapTempDBFilePath string, logger *logrus.Entry) (io.ReadCloser, error) {
	logger.Info("Verifying full snapshot integrity using SHA256")
	// Remove previous temporary DB file if it exists
	if err := os.Remove(snapTempDBFilePath); err != nil && !os.IsNotExist(err) {
		return nil, fmt.Errorf("failed to remove previous temp DB file %s: %v", snapTempDBFilePath, err)
	}
	db, err := os.OpenFile(snapTempDBFilePath, os.O_RDWR|os.O_CREATE, 0600)
	if err != nil {
		return nil, fmt.Errorf("failed to create temporary DB file at %s: %v", snapTempDBFilePath, err)
	}
	buf := make([]byte, hashBufferSize)
	if _, err := io.CopyBuffer(db, snapshotData, buf); err != nil {
		return nil, fmt.Errorf("failed to copy snapshot data to temporary DB file %s: %v", snapTempDBFilePath, err)
	}
	// Verify SHA256 hash
	if err := validateSnapshotSHA256(db, logger); err != nil {
		return nil, err
	}
	// Reset the file pointer back to the beginning
	if err := resetFilePointer(db); err != nil {
		return nil, err
	}
	return db, nil
}

// validateSnapshotSHA256 checks the SHA256 hash appended to the snapshot.
func validateSnapshotSHA256(db *os.File, logger *logrus.Entry) error {
	lastOffset, err := db.Seek(0, io.SeekEnd)
	if err != nil {
		return fmt.Errorf("failed to seek to end of file %s: %v", db.Name(), err)
	}
	totalSnapshotBytes, err := db.Seek(-sha256.Size, io.SeekEnd)
	if err != nil {
		return fmt.Errorf("failed to seek to SHA256 offset in file %s: %v", db.Name(), err)
	}
	// Get snapshot SHA256 hash
	sha := make([]byte, sha256.Size)
	if _, err := io.ReadFull(db, sha); err != nil {
		return fmt.Errorf("failed to read SHA256 from snapshot data in file %s: %v", db.Name(), err)
	}
	hash := sha256.New()
	logger.Infof("Total bytes received from snapshot API call (including SHA256 hash): %d", lastOffset)
	logger.Infof("Total bytes received from snapshot API call (excluding SHA256 hash): %d", totalSnapshotBytes)
	// Reset file pointer and calculate hash
	if err := resetFilePointer(db); err != nil {
		return fmt.Errorf("failed to reset file pointer for file %s: %v", db.Name(), err)
	}
	limitedReader := io.LimitReader(db, totalSnapshotBytes)
	buf := make([]byte, hashBufferSize)
	if _, err := io.CopyBuffer(hash, limitedReader, buf); err != nil {
		return fmt.Errorf("failed to calculate SHA256 for file %s: %v", db.Name(), err)
	}
	dbSha := hash.Sum(nil)
	if !bytes.Equal(sha, dbSha) {
		return fmt.Errorf("expected SHA256 for full snapshot: %x, got %x", sha, dbSha)
	}
	return nil
}

// saveSnapshotToStore saves the snapshot to the SnapStore.
func saveSnapshotToStore(store brtypes.SnapStore, snapshot *brtypes.Snapshot, snapshotData io.ReadCloser, startTime time.Time, logger *logrus.Entry) error {
	if err := store.Save(*snapshot, snapshotData); err != nil {
		timeTaken := time.Since(startTime)
		metrics.SnapshotDurationSeconds.With(prometheus.Labels{metrics.LabelKind: brtypes.SnapshotKindFull, metrics.LabelSucceeded: metrics.ValueSucceededFalse}).Observe(timeTaken.Seconds())
		return &errors.SnapstoreError{
			Message: fmt.Errorf("failed to save snapshot: %w", err).Error(),
		}
	}
	timeTaken := time.Since(startTime)
	metrics.SnapshotDurationSeconds.With(prometheus.Labels{metrics.LabelKind: brtypes.SnapshotKindFull, metrics.LabelSucceeded: metrics.ValueSucceededTrue}).Observe(timeTaken.Seconds())
	return nil
}

// resetFilePointer resets the file pointer to the beginning of the file.
func resetFilePointer(db *os.File) error {
	if _, err := db.Seek(0, io.SeekStart); err != nil {
		return fmt.Errorf("failed to reset file pointer: %w", err)
	}
	return nil
}

// cleanUpTempFile removes the temporary file used during snapshot verification.
func cleanUpTempFile(filePath string, logger *logrus.Entry) {
	if err := os.Remove(filePath); err != nil && !os.IsNotExist(err) {
		logger.Warnf("Failed to remove temporary full snapshot file: %v", err)
	}
}

Member Author:

// resetFilePointer resets the file pointer to the beginning of the file.
func resetFilePointer(db *os.File) error {
	if _, err := db.Seek(0, io.SeekStart); err != nil {
		return fmt.Errorf("failed to reset file pointer: %w", err)
	}
	return nil
}

// cleanUpTempFile removes the temporary file used during snapshot verification.
func cleanUpTempFile(filePath string, logger *logrus.Entry) {
	if err := os.Remove(filePath); err != nil && !os.IsNotExist(err) {
		logger.Warnf("Failed to remove temporary full snapshot file: %v", err)
	}
}

I wouldn't like to create a separate function for a one-liner.

Member Author (@ishan16696, Dec 1, 2024):

// compressSnapshotDataIfNeeded compresses the snapshot data if compression is enabled.
func compressSnapshotDataIfNeeded(snapshotData io.ReadCloser, compressionConfig *compressor.CompressionConfig, logger *logrus.Entry) (io.ReadCloser, error) {
	if compressionConfig != nil && compressionConfig.Enabled {
		startTime := time.Now()
		logger.Infof("Compression enabled. Starting compression of snapshot data.")
		compressedData, err := compressor.CompressSnapshot(snapshotData, compressionConfig.CompressionPolicy)
		if err != nil {
			logger.Errorf("Failed to compress snapshot data: %v", err)
			return nil, fmt.Errorf("unable to obtain reader for compressed file: %v", err)
		}
		logger.Infof("Total time taken in full snapshot compression: %f seconds.", time.Since(startTime).Seconds())
		return compressedData, nil
	}
	return snapshotData, nil
}

For this, as I mentioned in #779 (comment), I wouldn't like to refactor, as it is out of scope for this PR.

Member Author:

buf := make([]byte, hashBufferSize)
if _, err := io.CopyBuffer(db, snapshotData, buf); err != nil {
	return nil, fmt.Errorf("failed to copy snapshot data to temporary DB file %s: %v", snapTempDBFilePath, err)
}
// Verify SHA256 hash
if err := validateSnapshotSHA256(db, logger); err != nil {
	return nil, err
}
// Reset the file pointer back to the beginning
if err := resetFilePointer(db); err != nil {
	return nil, err
}

This code is wrong; this is not how it works.

Contributor:

May I know what you mean by wrong?

Member Author (@ishan16696, Dec 2, 2024):

We have to read the snapshot data into a small buffer and keep calculating the hash:

	buf := make([]byte, hashBufferSize)
	hash := sha256.New()

	// reset the file pointer back to starting
	currentOffset, err := db.Seek(0, io.SeekStart)
	if err != nil {
		return nil, err
	}

	for currentOffset+hashBufferSize <= totalSnapshotBytes {
		offset, err := db.Read(buf)
		if err != nil {
			return nil, fmt.Errorf("unable to read snapshot data into buffer to calculate SHA256: %v", err)
		}

		hash.Write(buf[:offset])
		currentOffset += int64(offset)
	}

I see you are using io.LimitReader, but IMO we want to avoid that, as we don't want to load all the db data (~8Gi in the worst case) into memory. That's why we went with this approach of reading the data in small chunks, so that there won't be any memory spike, as you have yourself verified: #779 (comment)

Contributor:

io.LimitReader does not load all the data (~8Gi in the worst case) into memory at once. Instead, it ensures that at most totalSnapshotBytes are read from the underlying db, and the actual memory usage is determined by the buffer size provided to io.CopyBuffer (in this case, hashBufferSize). This keeps the memory footprint small and controlled while simplifying the code by removing manual offset tracking and loop management.

If there are still concerns about the behavior or specific cases where io.LimitReader might not meet expectations, I’d be happy to investigate further. Let me know if you’d like me to profile this or make adjustments!

Member Author:

OK, if you're confident about io.LimitReader, then let me explore it once and get back to you.
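For reference, here is a minimal self-contained sketch of the approach under discussion (not code from this PR; the file path, buffer size, and error handling are illustrative). It shows that io.LimitReader combined with io.CopyBuffer streams a file through a fixed-size buffer, so peak memory stays near the buffer size even for a multi-GiB snapshot:

package main

import (
	"crypto/sha256"
	"fmt"
	"io"
	"os"
)

func main() {
	// Hypothetical snapshot file: payload followed by a 32-byte SHA256 trailer.
	f, err := os.Open("/tmp/snapshot.db")
	if err != nil {
		panic(err)
	}
	defer f.Close()

	end, err := f.Seek(0, io.SeekEnd)
	if err != nil {
		panic(err)
	}
	payloadSize := end - sha256.Size // hash everything except the appended trailer
	if _, err := f.Seek(0, io.SeekStart); err != nil {
		panic(err)
	}

	hash := sha256.New()
	buf := make([]byte, 4*1024*1024) // 4 MB working buffer bounds resident memory
	// LimitReader stops after payloadSize bytes; CopyBuffer reuses buf for every
	// read/write cycle, so only ~4 MB is held in memory at any moment.
	if _, err := io.CopyBuffer(hash, io.LimitReader(f, payloadSize), buf); err != nil {
		panic(err)
	}
	fmt.Printf("computed SHA256: %x\n", hash.Sum(nil))
}

Note: since *io.LimitedReader does not implement io.WriterTo and a SHA256 digest does not implement io.ReaderFrom, io.CopyBuffer really does route every chunk through the supplied buffer rather than bypassing it.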

2 changes: 1 addition & 1 deletion pkg/snapshot/snapshotter/snapshotter.go
@@ -364,7 +364,7 @@ func (ssr *Snapshotter) takeFullSnapshot(isFinal bool) (*brtypes.Snapshot, error
	}
	defer clientMaintenance.Close()

	s, err := etcdutil.TakeAndSaveFullSnapshot(ctx, clientMaintenance, ssr.store, lastRevision, ssr.compressionConfig, compressionSuffix, isFinal, ssr.logger)
	s, err := etcdutil.TakeAndSaveFullSnapshot(ctx, clientMaintenance, ssr.store, ssr.snapstoreConfig.TempDir, lastRevision, ssr.compressionConfig, compressionSuffix, isFinal, ssr.logger)
	if err != nil {
		return nil, err
	}
1 change: 1 addition & 0 deletions pkg/types/compactor.go
@@ -25,6 +25,7 @@ const (
type CompactOptions struct {
	*RestoreOptions
	*CompactorConfig
	TempDir string
}

// CompactorConfig holds all configuration options related to `compact` subcommand.