diff --git a/cmd/compact.go b/cmd/compact.go
index cb2d115d2..0147aadf8 100644
--- a/cmd/compact.go
+++ b/cmd/compact.go
@@ -55,6 +55,7 @@ func NewCompactCommand(ctx context.Context) *cobra.Command {
 			compactOptions := &brtypes.CompactOptions{
 				RestoreOptions:  options,
 				CompactorConfig: opts.compactorConfig,
+				TempDir:         opts.snapstoreConfig.TempDir,
 			}
 
 			snapshot, err := cp.Compact(ctx, compactOptions)
diff --git a/pkg/compactor/compactor.go b/pkg/compactor/compactor.go
index be9140f5d..7d38671b7 100644
--- a/pkg/compactor/compactor.go
+++ b/pkg/compactor/compactor.go
@@ -166,7 +166,7 @@ func (cp *Compactor) Compact(ctx context.Context, opts *brtypes.CompactOptions)
 	isFinal := compactorRestoreOptions.BaseSnapshot.IsFinal
 
 	cc := &compressor.CompressionConfig{Enabled: isCompressed, CompressionPolicy: compressionPolicy}
-	snapshot, err := etcdutil.TakeAndSaveFullSnapshot(snapshotReqCtx, clientMaintenance, cp.store, etcdRevision, cc, suffix, isFinal, cp.logger)
+	snapshot, err := etcdutil.TakeAndSaveFullSnapshot(snapshotReqCtx, clientMaintenance, cp.store, opts.TempDir, etcdRevision, cc, suffix, isFinal, cp.logger)
 	if err != nil {
 		return nil, err
 	}
diff --git a/pkg/compactor/compactor_test.go b/pkg/compactor/compactor_test.go
index e10e5d23d..b899fdeb5 100644
--- a/pkg/compactor/compactor_test.go
+++ b/pkg/compactor/compactor_test.go
@@ -60,12 +60,18 @@ var _ = Describe("Running Compactor", func() {
 	var compactorConfig *brtypes.CompactorConfig
 	var compactOptions *brtypes.CompactOptions
 	var compactedSnapshot *brtypes.Snapshot
+	var snapstoreConfig *brtypes.SnapstoreConfig
 	var tempRestoreDir string
 	var tempDataDir string
 
 	BeforeEach(func() {
 		dir = fmt.Sprintf("%s/etcd/snapshotter.bkp", testSuiteDir)
-		store, err = snapstore.GetSnapstore(&brtypes.SnapstoreConfig{Container: dir, Provider: "Local"})
+		snapstoreConfig = &brtypes.SnapstoreConfig{
+			Container: dir,
+			Provider:  "Local",
+		}
+
+		store, err = snapstore.GetSnapstore(snapstoreConfig)
 		Expect(err).ShouldNot(HaveOccurred())
 		fmt.Println("The store where compaction will save snapshot is: ", store)
 
@@ -104,6 +110,7 @@ var _ = Describe("Running Compactor", func() {
 		compactOptions = &brtypes.CompactOptions{
 			RestoreOptions:  restoreOpts,
 			CompactorConfig: compactorConfig,
+			TempDir:         snapstoreConfig.TempDir,
 		}
 	})
 
diff --git a/pkg/etcdutil/etcdutil.go b/pkg/etcdutil/etcdutil.go
index fd5c06093..e962c62a1 100644
--- a/pkg/etcdutil/etcdutil.go
+++ b/pkg/etcdutil/etcdutil.go
@@ -5,9 +5,14 @@
 package etcdutil
 
 import (
+	"bytes"
 	"context"
+	"crypto/sha256"
 	"crypto/tls"
 	"fmt"
+	"io"
+	"os"
+	"path/filepath"
 	"time"
 
 	"github.com/gardener/etcd-backup-restore/pkg/compressor"
@@ -23,6 +28,10 @@ import (
 	"go.etcd.io/etcd/pkg/transport"
 )
 
+const (
+	hashBufferSize = 4 * 1024 * 1024 // 4 MB
+)
+
 // NewFactory returns a Factory that constructs new clients using the supplied ETCD client configuration.
 func NewFactory(cfg brtypes.EtcdConnectionConfig, opts ...client.Option) client.Factory {
 	options := &client.Options{}
@@ -232,7 +241,7 @@ func GetEtcdEndPointsSorted(ctx context.Context, clientMaintenance client.Mainte
 		endPoint = etcdEndpoints[0]
 	} else {
 		return nil, nil, &errors.EtcdError{
-			Message: fmt.Sprintf("etcd endpoints are not passed correctly"),
+			Message: "etcd endpoints are not passed correctly",
 		}
 	}
 
@@ -254,7 +263,7 @@ func GetEtcdEndPointsSorted(ctx context.Context, clientMaintenance client.Mainte
 }
 
 // TakeAndSaveFullSnapshot takes full snapshot and save it to store
-func TakeAndSaveFullSnapshot(ctx context.Context, client client.MaintenanceCloser, store brtypes.SnapStore, lastRevision int64, cc *compressor.CompressionConfig, suffix string, isFinal bool, logger *logrus.Entry) (*brtypes.Snapshot, error) {
+func TakeAndSaveFullSnapshot(ctx context.Context, client client.MaintenanceCloser, store brtypes.SnapStore, tempDir string, lastRevision int64, cc *compressor.CompressionConfig, suffix string, isFinal bool, logger *logrus.Entry) (*brtypes.Snapshot, error) {
 	startTime := time.Now()
 	rc, err := client.Snapshot(ctx)
 	if err != nil {
@@ -265,22 +274,40 @@ func TakeAndSaveFullSnapshot(ctx context.Context, client client.MaintenanceClose
 	timeTaken := time.Since(startTime)
 	logger.Infof("Total time taken by Snapshot API: %f seconds.", timeTaken.Seconds())
 
+	var snapshotData io.ReadCloser
+	snapshotTempDBPath := filepath.Join(tempDir, "db")
+	if snapshotData, err = checkFullSnapshotIntegrity(rc, snapshotTempDBPath, logger); err != nil {
+		logger.Errorf("verification of full snapshot SHA256 hash has failed: %v", err)
+		return nil, err
+	}
+	logger.Info("full snapshot SHA256 hash has been successfully verified.")
+
+	defer func() {
+		if err := os.Remove(snapshotTempDBPath); err != nil {
+			logger.Warnf("failed to remove temporary full snapshot file: %v", err)
+		}
+	}()
+
 	if cc.Enabled {
 		startTimeCompression := time.Now()
-		rc, err = compressor.CompressSnapshot(rc, cc.CompressionPolicy)
+		snapshotData, err = compressor.CompressSnapshot(snapshotData, cc.CompressionPolicy)
 		if err != nil {
 			return nil, fmt.Errorf("unable to obtain reader for compressed file: %v", err)
 		}
 		timeTakenCompression := time.Since(startTimeCompression)
 		logger.Infof("Total time taken in full snapshot compression: %f seconds.", timeTakenCompression.Seconds())
 	}
-	defer rc.Close()
+	defer func() {
+		if err := snapshotData.Close(); err != nil {
+			logger.Warnf("failed to close snapshot data file: %v", err)
+		}
+	}()
 
 	logger.Infof("Successfully opened snapshot reader on etcd")
 
 	// Then save the snapshot to the store.
 	snapshot := snapstore.NewSnapshot(brtypes.SnapshotKindFull, 0, lastRevision, suffix, isFinal)
-	if err := store.Save(*snapshot, rc); err != nil {
+	if err := store.Save(*snapshot, snapshotData); err != nil {
 		timeTaken := time.Since(startTime)
 		metrics.SnapshotDurationSeconds.With(prometheus.Labels{metrics.LabelKind: brtypes.SnapshotKindFull, metrics.LabelSucceeded: metrics.ValueSucceededFalse}).Observe(timeTaken.Seconds())
 		return nil, &errors.SnapstoreError{
@@ -294,3 +321,87 @@ func TakeAndSaveFullSnapshot(ctx context.Context, client client.MaintenanceClose
 
 	return snapshot, nil
 }
+
+// checkFullSnapshotIntegrity verifies the integrity of the full snapshot by comparing
+// the appended SHA256 hash of the full snapshot with the calculated SHA256 hash of the full snapshot data.
+func checkFullSnapshotIntegrity(snapshotData io.ReadCloser, snapTempDBFilePath string, logger *logrus.Entry) (io.ReadCloser, error) {
+	logger.Info("checking the full snapshot integrity with the help of SHA256")
+
+	// If a previous temp db file already exists, remove it.
+	if err := os.Remove(snapTempDBFilePath); err != nil && !os.IsNotExist(err) {
+		return nil, err
+	}
+
+	db, err := os.OpenFile(snapTempDBFilePath, os.O_RDWR|os.O_CREATE, 0600)
+	if err != nil {
+		return nil, err
+	}
+
+	if _, err := io.Copy(db, snapshotData); err != nil {
+		return nil, err
+	}
+
+	lastOffset, err := db.Seek(0, io.SeekEnd)
+	if err != nil {
+		return nil, err
+	}
+	// 512 is chosen because it's the minimum disk sector size in most systems.
+	hasHash := (lastOffset % 512) == sha256.Size
+	if !hasHash {
+		return nil, fmt.Errorf("SHA256 hash seems to be missing from snapshot data")
+	}
+
+	totalSnapshotBytes, err := db.Seek(-sha256.Size, io.SeekEnd)
+	if err != nil {
+		return nil, err
+	}
+
+	// get snapshot SHA256 hash
+	sha := make([]byte, sha256.Size)
+	if _, err := db.Read(sha); err != nil {
+		return nil, fmt.Errorf("failed to read SHA256 from snapshot data %v", err)
+	}
+
+	buf := make([]byte, hashBufferSize)
+	hash := sha256.New()
+
+	logger.Infof("Total no. of bytes received from snapshot api call with SHA: %d", lastOffset)
+	logger.Infof("Total no. of bytes received from snapshot api call without SHA: %d", totalSnapshotBytes)
+
+	// reset the file pointer back to the start
+	currentOffset, err := db.Seek(0, io.SeekStart)
+	if err != nil {
+		return nil, err
+	}
+
+	for currentOffset+hashBufferSize <= totalSnapshotBytes {
+		offset, err := db.Read(buf)
+		if err != nil {
+			return nil, fmt.Errorf("unable to read snapshot data into buffer to calculate SHA256: %v", err)
+		}
+
+		hash.Write(buf[:offset])
+		currentOffset += int64(offset)
+	}
+
+	if currentOffset < totalSnapshotBytes {
+		if _, err := db.Read(buf); err != nil {
+			return nil, fmt.Errorf("unable to read last chunk of snapshot data into buffer to calculate SHA256: %v", err)
+		}
+
+		hash.Write(buf[:totalSnapshotBytes-currentOffset])
+	}
+
+	dbSha := hash.Sum(nil)
+	if !bytes.Equal(sha, dbSha) {
+		return nil, fmt.Errorf("expected SHA256 for full snapshot: %v, got %v", sha, dbSha)
+	}
+
+	// reset the file pointer back to the start
+	if _, err := db.Seek(0, io.SeekStart); err != nil {
+		return nil, err
+	}
+
+	// the full snapshot of the database has been successfully verified.
+	return db, nil
+}
diff --git a/pkg/snapshot/snapshotter/snapshotter.go b/pkg/snapshot/snapshotter/snapshotter.go
index bac30f99b..0a3342b77 100644
--- a/pkg/snapshot/snapshotter/snapshotter.go
+++ b/pkg/snapshot/snapshotter/snapshotter.go
@@ -364,7 +364,7 @@ func (ssr *Snapshotter) takeFullSnapshot(isFinal bool) (*brtypes.Snapshot, error
 		}
 		defer clientMaintenance.Close()
 
-		s, err := etcdutil.TakeAndSaveFullSnapshot(ctx, clientMaintenance, ssr.store, lastRevision, ssr.compressionConfig, compressionSuffix, isFinal, ssr.logger)
+		s, err := etcdutil.TakeAndSaveFullSnapshot(ctx, clientMaintenance, ssr.store, ssr.snapstoreConfig.TempDir, lastRevision, ssr.compressionConfig, compressionSuffix, isFinal, ssr.logger)
 		if err != nil {
 			return nil, err
 		}
diff --git a/pkg/types/compactor.go b/pkg/types/compactor.go
index 8e86fec75..cda424cd7 100644
--- a/pkg/types/compactor.go
+++ b/pkg/types/compactor.go
@@ -25,6 +25,7 @@ const (
 type CompactOptions struct {
 	*RestoreOptions
 	*CompactorConfig
+	TempDir string
 }
 
 // CompactorConfig holds all configuration options related to `compact` subcommand.
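
Note (illustrative, not part of the patch): the check added in checkFullSnapshotIntegrity relies on the etcd Snapshot API appending a SHA-256 of all preceding bytes as the final 32 bytes of the stream. The sketch below shows the same trailer verification performed against a snapshot file already on disk; verifySnapshotSHA256 and the snapshot.db path are hypothetical names used only for this example.

package main

import (
	"bytes"
	"crypto/sha256"
	"fmt"
	"io"
	"os"
)

// verifySnapshotSHA256 checks that the last 32 bytes of the file equal the
// SHA-256 of everything before them, i.e. the trailer etcd appends to the
// snapshot stream.
func verifySnapshotSHA256(path string) error {
	f, err := os.Open(path)
	if err != nil {
		return err
	}
	defer f.Close()

	info, err := f.Stat()
	if err != nil {
		return err
	}
	size := info.Size()
	if size < sha256.Size {
		return fmt.Errorf("file too small (%d bytes) to contain a SHA-256 trailer", size)
	}

	// Hash everything before the 32-byte trailer.
	h := sha256.New()
	if _, err := io.CopyN(h, f, size-sha256.Size); err != nil {
		return err
	}

	// The remaining 32 bytes are the hash appended by etcd.
	want := make([]byte, sha256.Size)
	if _, err := io.ReadFull(f, want); err != nil {
		return err
	}

	if got := h.Sum(nil); !bytes.Equal(got, want) {
		return fmt.Errorf("SHA-256 mismatch: snapshot is corrupt or truncated (expected %x, got %x)", want, got)
	}
	return nil
}

func main() {
	// "snapshot.db" is a placeholder for a snapshot produced via the Snapshot API
	// (for example with `etcdctl snapshot save`).
	if err := verifySnapshotSHA256("snapshot.db"); err != nil {
		fmt.Fprintln(os.Stderr, "verification failed:", err)
		os.Exit(1)
	}
	fmt.Println("snapshot SHA-256 trailer verified")
}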