Skip to content

Commit

Permalink
Merge pull request tonistiigi#198 from tonistiigi/tonistiigi/fix-hard…
Browse files Browse the repository at this point in the history
…link-filter

fix hardlink filter regression
  • Loading branch information
jedevc authored Apr 24, 2024
2 parents 497d33b + 16fccd4 commit 91a3fc4
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 1 deletion.
68 changes: 68 additions & 0 deletions hardlinks.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package fsutil

import (
"context"
"io"
gofs "io/fs"
"os"
"syscall"

Expand Down Expand Up @@ -46,3 +49,68 @@ func (v *Hardlinks) HandleChange(kind ChangeKind, p string, fi os.FileInfo, err

return nil
}

// WithHardlinkReset returns a FS that fixes hardlinks for FS that has been filtered
// so that original hardlink sources might be missing
func WithHardlinkReset(fs FS) FS {
return &hardlinkFilter{fs: fs}
}

type hardlinkFilter struct {
fs FS
}

var _ FS = &hardlinkFilter{}

func (r *hardlinkFilter) Walk(ctx context.Context, target string, fn gofs.WalkDirFunc) error {
seenFiles := make(map[string]string)
return r.fs.Walk(ctx, target, func(path string, entry gofs.DirEntry, err error) error {
if err != nil {
return err
}

fi, err := entry.Info()
if err != nil {
return err
}

if fi.IsDir() || fi.Mode()&os.ModeSymlink != 0 {
return fn(path, entry, nil)
}

stat, ok := fi.Sys().(*types.Stat)
if !ok {
return errors.WithStack(&os.PathError{Path: path, Err: syscall.EBADMSG, Op: "fileinfo without stat info"})
}

if stat.Linkname != "" {
if v, ok := seenFiles[stat.Linkname]; !ok {
seenFiles[stat.Linkname] = stat.Path
stat.Linkname = ""
entry = &dirEntryWithStat{DirEntry: entry, stat: stat}
} else {
if v != stat.Path {
stat.Linkname = v
entry = &dirEntryWithStat{DirEntry: entry, stat: stat}
}
}
}

seenFiles[path] = stat.Path

return fn(path, entry, nil)
})
}

func (r *hardlinkFilter) Open(p string) (io.ReadCloser, error) {
return r.fs.Open(p)
}

type dirEntryWithStat struct {
gofs.DirEntry
stat *types.Stat
}

func (d *dirEntryWithStat) Info() (gofs.FileInfo, error) {
return &StatInfo{d.stat}, nil
}
66 changes: 66 additions & 0 deletions receive_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,72 @@ func TestCopySwitchDirToFile(t *testing.T) {
`, b.String())
}

func TestHardlinkFilter(t *testing.T) {
d, err := tmpDir(changeStream([]string{
"ADD bar file data1",
"ADD foo file >bar",
"ADD foo2 file >bar",
}))
assert.NoError(t, err)
defer os.RemoveAll(d)

assert.NoError(t, err)
defer os.RemoveAll(d)
fs, err := NewFS(d)
assert.NoError(t, err)
fs, err = NewFilterFS(fs, &FilterOpt{})
assert.NoError(t, err)
fs, err = NewFilterFS(fs, &FilterOpt{
IncludePatterns: []string{"foo*"},
Map: func(_ string, s *types.Stat) MapResult {
s.Uid = 0
s.Gid = 0
return MapResultKeep
},
})
assert.NoError(t, err)

dest := t.TempDir()

eg, ctx := errgroup.WithContext(context.Background())
s1, s2 := sockPairProto(ctx)

eg.Go(func() error {
defer s1.(*fakeConnProto).closeSend()
return Send(ctx, s1, fs, nil)
})
eg.Go(func() error {
return Receive(ctx, s2, dest, ReceiveOpt{
Filter: func(p string, s *types.Stat) bool {
if p == "foo2" {
require.Equal(t, "foo", s.Linkname)
}
if runtime.GOOS != "windows" {
// On Windows, Getuid() and Getgid() always return -1
// See: https://pkg.go.dev/os#Getgid
// See: https://pkg.go.dev/os#Geteuid
s.Uid = uint32(os.Getuid())
s.Gid = uint32(os.Getgid())
}
return true
},
})
})
assert.NoError(t, eg.Wait())

dt, err := os.ReadFile(filepath.Join(dest, "foo"))
assert.NoError(t, err)
assert.Equal(t, "data1", string(dt))

st1, err := os.Stat(filepath.Join(dest, "foo"))
assert.NoError(t, err)

st2, err := os.Stat(filepath.Join(dest, "foo2"))
assert.NoError(t, err)

assert.True(t, os.SameFile(st1, st2))
}

func TestCopySimple(t *testing.T) {
d, err := tmpDir(changeStream([]string{
"ADD foo file data1",
Expand Down
2 changes: 1 addition & 1 deletion send.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ type Stream interface {
func Send(ctx context.Context, conn Stream, fs FS, progressCb func(int, bool)) error {
s := &sender{
conn: &syncStream{Stream: conn},
fs: fs,
fs: WithHardlinkReset(fs),
files: make(map[uint32]string),
progressCb: progressCb,
sendpipeline: make(chan *sendHandle, 128),
Expand Down

0 comments on commit 91a3fc4

Please sign in to comment.