diff --git a/go.mod b/go.mod
index 249d93f296..31e71ead7e 100644
--- a/go.mod
+++ b/go.mod
@@ -35,6 +35,7 @@ require (
 	github.com/prometheus/client_model v0.6.1
 	github.com/prometheus/common v0.60.1
 	github.com/quic-go/quic-go v0.48.1
+	github.com/rqlite/sql v0.0.0-20240312185922-ffac88a740bd
 	github.com/rs/cors v1.11.1
 	github.com/santhosh-tekuri/jsonschema/v5 v5.3.1
 	github.com/seehuhn/mt19937 v1.0.0
diff --git a/go.sum b/go.sum
index 7641d29072..6efe110098 100644
--- a/go.sum
+++ b/go.sum
@@ -176,6 +176,8 @@ github.com/go-openapi/swag v0.22.4/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+
 github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
 github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI=
 github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8=
+github.com/go-test/deep v1.0.8 h1:TDsG77qcSprGbC6vTN8OuXp5g+J+b5Pcguhf7Zt61VM=
+github.com/go-test/deep v1.0.8/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE=
 github.com/go-yaml/yaml v2.1.0+incompatible/go.mod h1:w2MrLa16VYP0jy6N7M5kHaCkaLENm+P+Tv+MfurjSw0=
 github.com/godbus/dbus/v5 v5.0.3/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
 github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
@@ -577,6 +579,8 @@ github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzG
 github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
 github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
 github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
+github.com/rqlite/sql v0.0.0-20240312185922-ffac88a740bd h1:wW6BtayFoKaaDeIvXRE3SZVPOscSKlYD+X3bB749+zk=
+github.com/rqlite/sql v0.0.0-20240312185922-ffac88a740bd/go.mod h1:ib9zVtNgRKiGuoMyUqqL5aNpk+r+++YlyiVIkclVqPg=
 github.com/rs/cors v1.11.1 h1:eU3gRzXLRK57F5rKMGMZURNdIG4EoAmX8k94r9wXWHA=
 github.com/rs/cors v1.11.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU=
 github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g=
diff --git a/sql/database.go b/sql/database.go
index 4f4224b710..94a9322563 100644
--- a/sql/database.go
+++ b/sql/database.go
@@ -11,6 +11,7 @@ import (
 	"strings"
 	"sync"
 	"sync/atomic"
+	"testing"
 	"time"
 
 	sqlite "github.com/go-llsqlite/crawshaw"
@@ -236,6 +237,15 @@ func InMemory(opts ...Opt) *sqliteDatabase {
 	return db
 }
 
+// InMemoryTest returns an in-memory database for testing and ensures that the database is closed during `tb.Cleanup`.
+func InMemoryTest(tb testing.TB, opts ...Opt) *sqliteDatabase {
+	// When using an empty DB schema, we don't want to check for schema drift due to
+	// "PRAGMA user_version = 0;" in the initial schema retrieved from the DB.
+	db := InMemory(append(opts, WithNoCheckSchemaDrift())...)
+	tb.Cleanup(func() { db.Close() })
+	return db
+}
+
 // Open database with options.
 //
 // Database is opened in WAL mode and pragma synchronous=normal.
diff --git a/sql/expr/expr.go b/sql/expr/expr.go
new file mode 100644
index 0000000000..adef890a59
--- /dev/null
+++ b/sql/expr/expr.go
@@ -0,0 +1,212 @@
+// Package expr provides a simple SQL expression parser and builder.
+// It wraps the rqlite/sql package and provides a more convenient API that contains only
+// what's needed for the go-spacemesh codebase.
+package expr
+
+import (
+	"strings"
+
+	rsql "github.com/rqlite/sql"
+)
+
+// SQL operations.
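+// These are re-exported rqlite/sql tokens, meant to be combined with the
+// helpers below; e.g. Op(Ident("x"), EQ, Ident("y")) renders as "x" = "y".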
+const (
+	NE     = rsql.NE     // !=
+	EQ     = rsql.EQ     // =
+	LE     = rsql.LE     // <=
+	LT     = rsql.LT     // <
+	GT     = rsql.GT     // >
+	GE     = rsql.GE     // >=
+	BITAND = rsql.BITAND // &
+	BITOR  = rsql.BITOR  // |
+	BITNOT = rsql.BITNOT // !
+	LSHIFT = rsql.LSHIFT // <<
+	RSHIFT = rsql.RSHIFT // >>
+	PLUS   = rsql.PLUS   // +
+	MINUS  = rsql.MINUS  // -
+	STAR   = rsql.STAR   // *
+	SLASH  = rsql.SLASH  // /
+	REM    = rsql.REM    // %
+	CONCAT = rsql.CONCAT // ||
+	DOT    = rsql.DOT    // .
+	AND    = rsql.AND
+	OR     = rsql.OR
+	NOT    = rsql.NOT
+)
+
+// Expr represents a parsed SQL expression.
+type Expr = rsql.Expr
+
+// Statement represents a parsed SQL statement.
+type Statement = rsql.Statement
+
+// MustParse parses an SQL expression and panics if there's an error.
+func MustParse(s string) rsql.Expr {
+	expr, err := rsql.ParseExprString(s)
+	if err != nil {
+		panic("error parsing SQL expression: " + err.Error())
+	}
+	return expr
+}
+
+// MustParseStatement parses an SQL statement and panics if there's an error.
+func MustParseStatement(s string) rsql.Statement {
+	st, err := rsql.NewParser(strings.NewReader(s)).ParseStatement()
+	if err != nil {
+		panic("error parsing SQL statement: " + err.Error())
+	}
+	return st
+}
+
+// MaybeAnd joins together several SQL expressions with AND, ignoring any nil exprs.
+// If no non-nil expressions are passed, nil is returned.
+// If a single non-nil expression is passed, that single expression is returned.
+// Otherwise, the expressions are joined together with ANDs:
+// a AND b AND c AND d.
+func MaybeAnd(exprs ...Expr) Expr {
+	var r Expr
+	for _, expr := range exprs {
+		switch {
+		case expr == nil:
+		case r == nil:
+			r = expr
+		default:
+			r = Op(r, AND, expr)
+		}
+	}
+	return r
+}
+
+// Ident constructs an SQL identifier expression with the specified name.
+func Ident(name string) *rsql.Ident {
+	return &rsql.Ident{Name: name}
+}
+
+// Number constructs a number literal.
+func Number(value string) *rsql.NumberLit {
+	return &rsql.NumberLit{Value: value}
+}
+
+// TableSource constructs a Source clause for a SELECT statement that corresponds to
+// selecting from a single table with the specified name.
+func TableSource(name string) rsql.Source {
+	return &rsql.QualifiedTableName{Name: Ident(name)}
+}
+
+// Op constructs a binary expression such as x + y or x < y.
+func Op(x Expr, op rsql.Token, y Expr) Expr {
+	return &rsql.BinaryExpr{
+		X:  x,
+		Op: op,
+		Y:  y,
+	}
+}
+
+// Bind constructs an unnamed bind expression (?).
+func Bind() Expr {
+	return &rsql.BindExpr{Name: "?"}
+}
+
+// Between constructs a BETWEEN expression: x BETWEEN a AND b.
+func Between(x, a, b Expr) Expr {
+	return Op(x, rsql.BETWEEN, &rsql.Range{X: a, Y: b})
+}
+
+// Call constructs a call expression with the specified arguments, such as max(x).
+func Call(name string, args ...Expr) Expr {
+	return &rsql.Call{Name: Ident(name), Args: args}
+}
+
+// CountStar returns a COUNT(*) expression.
+func CountStar() Expr {
+	return &rsql.Call{Name: Ident("count"), Star: rsql.Pos{Offset: 1}}
+}
+
+// Asc constructs an ascending ORDER BY term.
+func Asc(expr Expr) *rsql.OrderingTerm {
+	return &rsql.OrderingTerm{X: expr}
+}
+
+// Desc constructs a descending ORDER BY term.
+func Desc(expr Expr) *rsql.OrderingTerm {
+	return &rsql.OrderingTerm{X: expr, Desc: rsql.Pos{Offset: 1}}
+}
+
+// SelectBuilder is used to construct a SELECT statement.
+type SelectBuilder struct {
+	st *rsql.SelectStatement
+}
+
+// Select returns a SELECT statement builder.
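+// A small usage sketch, mirroring the package tests:
+//
+//	Select(CountStar()).From(TableSource("mytable")).String()
+//	// => SELECT count(*) FROM "mytable"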
+func Select(columns ...any) SelectBuilder {
+	sb := SelectBuilder{st: &rsql.SelectStatement{}}
+	return sb.Columns(columns...)
+}
+
+// SelectBasedOn returns a SELECT statement builder based on the specified SELECT statement.
+// The statement must be a SELECT statement, otherwise SelectBasedOn panics.
+// The builder methods can be used to alter the statement.
+func SelectBasedOn(st Statement) SelectBuilder {
+	st = rsql.CloneStatement(st)
+	return SelectBuilder{st: st.(*rsql.SelectStatement)}
+}
+
+// Get returns the underlying SELECT statement.
+func (sb SelectBuilder) Get() *rsql.SelectStatement {
+	return sb.st
+}
+
+// String returns the underlying SELECT statement as a string.
+func (sb SelectBuilder) String() string {
+	return sb.st.String()
+}
+
+// Columns sets columns in the SELECT statement.
+func (sb SelectBuilder) Columns(columns ...any) SelectBuilder {
+	sb.st.Columns = make([]*rsql.ResultColumn, len(columns))
+	for n, column := range columns {
+		switch c := column.(type) {
+		case *rsql.ResultColumn:
+			sb.st.Columns[n] = c
+		case Expr:
+			sb.st.Columns[n] = &rsql.ResultColumn{Expr: c}
+		default:
+			panic("unexpected column type")
+		}
+	}
+	return sb
+}
+
+// From adds a FROM clause to the SELECT statement.
+func (sb SelectBuilder) From(s rsql.Source) SelectBuilder {
+	sb.st.Source = s
+	return sb
+}
+
+// Where adds a WHERE clause to the SELECT statement.
+func (sb SelectBuilder) Where(s Expr) SelectBuilder {
+	sb.st.WhereExpr = s
+	return sb
+}
+
+// OrderBy adds an ORDER BY clause to the SELECT statement.
+func (sb SelectBuilder) OrderBy(terms ...*rsql.OrderingTerm) SelectBuilder {
+	sb.st.OrderingTerms = terms
+	return sb
+}
+
+// Limit adds a LIMIT clause to the SELECT statement.
+func (sb SelectBuilder) Limit(limit Expr) SelectBuilder {
+	sb.st.LimitExpr = limit
+	return sb
+}
+
+// ColumnExpr returns the nth column expression from the SELECT statement.
+func ColumnExpr(st Statement, n int) Expr {
+	return st.(*rsql.SelectStatement).Columns[n].Expr
+}
+
+// WhereExpr returns the WHERE expression from the SELECT statement.
+func WhereExpr(st Statement) Expr {
+	return st.(*rsql.SelectStatement).WhereExpr
+}
diff --git a/sql/expr/expr_test.go b/sql/expr/expr_test.go
new file mode 100644
index 0000000000..8a87bc8395
--- /dev/null
+++ b/sql/expr/expr_test.go
@@ -0,0 +1,137 @@
+package expr
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestExpr(t *testing.T) {
+	for _, tc := range []struct {
+		Expr     Expr
+		Expected string
+	}{
+		{
+			Expr:     MustParse("a = ? OR x < 10"),
+			Expected: `"a" = ? OR "x" < 10`,
+		},
+		{
+			Expr:     Number("1"),
+			Expected: `1`,
+		},
+		{
+			Expr:     CountStar(),
+			Expected: `count(*)`,
+		},
+		{
+			Expr:     Op(Ident("x"), EQ, Ident("y")),
+			Expected: `"x" = "y"`,
+		},
+		{
+			Expr:     MaybeAnd(Op(Ident("x"), EQ, Ident("y"))),
+			Expected: `"x" = "y"`,
+		},
+		{
+			Expr:     MaybeAnd(Op(Ident("x"), EQ, Ident("y")), nil, nil),
+			Expected: `"x" = "y"`,
+		},
+		{
+			Expr: MaybeAnd(Op(Ident("x"), EQ, Ident("y")),
+				Op(Ident("a"), EQ, Bind())),
+			Expected: `"x" = "y" AND "a" = ?`,
+		},
+		{
+			Expr: MaybeAnd(Op(Ident("x"), EQ, Ident("y")),
+				nil,
+				Op(Ident("a"), EQ, Bind())),
+			Expected: `"x" = "y" AND "a" = ?`,
+		},
+		{
+			Expr:     MaybeAnd(),
+			Expected: "",
+		},
+		{
+			Expr:     Between(Ident("x"), Ident("y"), Bind()),
+			Expected: `"x" BETWEEN "y" AND ?`,
+		},
+		{
+			Expr:     Call("max", Ident("x")),
+			Expected: `max("x")`,
+		},
+		{
+			Expr:     MustParse("a.id"),
+			Expected: `"a"."id"`,
+		},
+	} {
+		if tc.Expected == "" {
+			require.Nil(t, tc.Expr)
+		} else {
+			require.Equal(t, tc.Expected, tc.Expr.String())
+			require.Equal(t, tc.Expected, MustParse(tc.Expected).String())
+		}
+	}
+}
+
+func TestStatement(t *testing.T) {
+	for _, tc := range []struct {
+		Statement SelectBuilder
+		Expected  string
+		Columns   []string
+	}{
+		{
+			Statement: Select(Number("1")),
+			Expected:  `SELECT 1`,
+			Columns:   []string{"1"},
+		},
+		{
+			Statement: Select(Call("max", Ident("n"))).From(TableSource("mytable")),
+			Expected:  `SELECT max("n") FROM "mytable"`,
+			Columns:   []string{`max("n")`},
+		},
+		{
+			Statement: Select(Ident("id"), Ident("n")).
+				From(TableSource("mytable")).
+				Where(Op(Ident("n"), GE, Bind())).
+				OrderBy(Asc(Ident("n"))).
+				Limit(Bind()),
+			Expected: `SELECT "id", "n" FROM "mytable" WHERE "n" >= ? ORDER BY "n" LIMIT ?`,
+			Columns:  []string{`"id"`, `"n"`},
+		},
+		{
+			Statement: Select(Ident("id")).
+				From(TableSource("mytable")).
+				OrderBy(Desc(Ident("id"))).
+				Limit(Number("10")),
+			Expected: `SELECT "id" FROM "mytable" ORDER BY "id" DESC LIMIT 10`,
+			Columns:  []string{`"id"`},
+		},
+		{
+			Statement: Select(CountStar()).From(TableSource("mytable")),
+			Expected:  `SELECT count(*) FROM "mytable"`,
+			Columns:   []string{`count(*)`},
+		},
+		{
+			Statement: SelectBasedOn(
+				MustParseStatement("select a.id from a left join b on a.x = b.x")).
+				Where(Op(Ident("id"), EQ, Bind())),
+			Expected: `SELECT "a"."id" FROM "a" LEFT JOIN "b" ON "a"."x" = "b"."x" WHERE "id" = ?`,
+			Columns:  []string{`"a"."id"`},
+		},
+		{
+			Statement: SelectBasedOn(
+				MustParseStatement("select a.id from a inner join b on a.x = b.x")).
+				Columns(CountStar()).
+				Where(Op(Ident("id"), EQ, Bind())),
+			Expected: `SELECT count(*) FROM "a" INNER JOIN "b" ON "a"."x" = "b"."x" WHERE "id" = ?`,
+			Columns:  []string{`count(*)`},
+		},
+	} {
+		require.Equal(t, tc.Expected, tc.Statement.String())
+		st := tc.Statement.Get()
+		require.Equal(t, tc.Expected, st.String())
+		for n, col := range tc.Columns {
+			require.Equal(t, col, ColumnExpr(st, n).String())
+		}
+		require.Equal(t, tc.Expected, MustParseStatement(tc.Expected).String())
+	}
+}
diff --git a/sync2/rangesync/combine_seqs.go b/sync2/rangesync/combine_seqs.go
new file mode 100644
index 0000000000..8c924fd8e3
--- /dev/null
+++ b/sync2/rangesync/combine_seqs.go
@@ -0,0 +1,196 @@
+package rangesync
+
+import (
+	"iter"
+	"slices"
+)
+
+type generator struct {
+	nextFn func() (KeyBytes, bool)
+	stop   func()
+	k      KeyBytes
+	error  SeqErrorFunc
+	done   bool
+}
+
+func gen(sr SeqResult) *generator {
+	g := &generator{error: sr.Error}
+	g.nextFn, g.stop = iter.Pull(iter.Seq[KeyBytes](sr.Seq))
+	return g
+}
+
+func (g *generator) next() (k KeyBytes, ok bool) {
+	if g.done {
+		return nil, false
+	}
+	if g.k != nil {
+		k = g.k
+		g.k = nil
+		return k, true
+	}
+	return g.nextFn()
+}
+
+func (g *generator) peek() (k KeyBytes, ok bool) {
+	if !g.done && g.k == nil {
+		g.k, ok = g.nextFn()
+		g.done = !ok
+	}
+	if g.done {
+		return nil, false
+	}
+	return g.k, true
+}
+
+type combinedSeq struct {
+	gens    []*generator
+	wrapped []*generator
+}
+
+// CombineSeqs combines multiple ordered sequences from SeqResults into one, returning the
+// smallest current key among all iterators at each step.
+// startingPoint is used to check if an iterator has wrapped around. If an iterator yields
+// a value below startingPoint, it is considered to have wrapped around.
+func CombineSeqs(startingPoint KeyBytes, srs ...SeqResult) SeqResult {
+	var err error
+	return SeqResult{
+		Seq: func(yield func(KeyBytes) bool) {
+			var c combinedSeq
+			// We clean up even if c.begin() returned an error, so that we don't
+			// leak any pull iterators created before c.begin() failed.
+			defer c.end()
+			// If c.begin() succeeds, the error is reset. If yield
+			// calls SeqResult's Error function, it will get nil until the
+			// iteration is finished.
+			if err = c.begin(startingPoint, srs); err != nil {
+				return
+			}
+			err = c.iterate(yield)
+		},
+		Error: func() error {
+			return err
+		},
+	}
+}
+
+func (c *combinedSeq) begin(startingPoint KeyBytes, srs []SeqResult) error {
+	for _, sr := range srs {
+		g := gen(sr)
+		k, ok := g.peek()
+		if !ok {
+			continue
+		}
+		if err := g.error(); err != nil {
+			return err
+		}
+		if startingPoint != nil && k.Compare(startingPoint) < 0 {
+			c.wrapped = append(c.wrapped, g)
+		} else {
+			c.gens = append(c.gens, g)
+		}
+	}
+	if len(c.gens) == 0 {
+		// all iterators wrapped around
+		c.gens = c.wrapped
+		c.wrapped = nil
+	}
+	return nil
+}
+
+func (c *combinedSeq) end() {
+	for _, g := range c.gens {
+		g.stop()
+	}
+	for _, g := range c.wrapped {
+		g.stop()
+	}
+}
+
+func (c *combinedSeq) aheadGen() (ahead *generator, aheadIdx int, err error) {
+	// remove any exhausted generators
+	j := 0
+	for i := range c.gens {
+		_, ok := c.gens[i].peek()
+		if ok {
+			c.gens[j] = c.gens[i]
+			j++
+		} else if err = c.gens[i].error(); err != nil {
+			return nil, 0, err
+		}
+	}
+	c.gens = c.gens[:j]
+	// if all the generators have wrapped around, move on to the wrapped generators
+	if len(c.gens) == 0 {
+		if len(c.wrapped) == 0 {
+			return nil, 0, nil
+		}
+		c.gens = c.wrapped
+		c.wrapped = nil
+	}
+	ahead = c.gens[0]
+	aheadIdx = 0
+	aK, _ := ahead.peek()
+	if err := ahead.error(); err != nil {
+		return nil, 0, err
+	}
+	for i := 1; i < len(c.gens); i++ {
+		curK, ok := c.gens[i].peek()
+		if !ok {
+			// If not all of the generators have wrapped around, then we
+			// already did a successful peek() on this generator above, so it
+			// should not be exhausted here.
+			// If all of the generators have wrapped around, then we have
+			// moved on to the wrapped generators, and a generator may only
+			// end up in the wrapped list after a successful peek(), too.
+			// So if we get here, the combinedSeq code is broken.
+			panic("BUG: unexpected exhausted generator")
+		}
+		if curK != nil {
+			if curK.Compare(aK) < 0 {
+				ahead = c.gens[i]
+				aheadIdx = i
+				aK = curK
+			}
+		}
+	}
+	return ahead, aheadIdx, nil
+}
+
+func (c *combinedSeq) iterate(yield func(KeyBytes) bool) error {
+	for {
+		g, idx, err := c.aheadGen()
+		if err != nil {
+			return err
+		}
+		if g == nil {
+			return nil
+		}
+		k, ok := g.next()
+		if !ok {
+			if err := g.error(); err != nil {
+				return err
+			}
+			c.gens = slices.Delete(c.gens, idx, idx+1)
+			continue
+		}
+		if !yield(k) {
+			return nil
+		}
+		newK, ok := g.peek()
+		if !ok {
+			if err := g.error(); err != nil {
+				return err
+			}
+			// if this iterator is exhausted, it'll be removed by the
+			// next aheadGen call
+			continue
+		}
+		if k.Compare(newK) >= 0 {
+			// the iterator has wrapped around, move it to the wrapped
+			// list which will be used after all the iterators have
+			// wrapped around
+			c.wrapped = append(c.wrapped, g)
+			c.gens = slices.Delete(c.gens, idx, idx+1)
+		}
+	}
+}
diff --git a/sync2/rangesync/combine_seqs_test.go b/sync2/rangesync/combine_seqs_test.go
new file mode 100644
index 0000000000..ec1fbc8849
--- /dev/null
+++ b/sync2/rangesync/combine_seqs_test.go
@@ -0,0 +1,230 @@
+package rangesync
+
+import (
+	"errors"
+	"strings"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+var seqTestErr = errors.New("test error")
+
+type fakeSeqItem struct {
+	k    string
+	err  error
+	stop bool
+}
+
+func mkFakeSeqItem(s string) fakeSeqItem {
+	switch s {
+	case "!":
+		return fakeSeqItem{err: seqTestErr}
+	case "$":
+		return fakeSeqItem{stop: true}
+	default:
+		return fakeSeqItem{k: s}
+	}
+}
+
+type fakeSeq []fakeSeqItem
+
+func mkFakeSeq(s string) fakeSeq {
+	seq := make(fakeSeq, len(s))
+	for n, c := range s {
+		seq[n] = mkFakeSeqItem(string(c))
+	}
+	return seq
+}
+
+func (seq fakeSeq) items(startIdx int) SeqResult {
+	if startIdx > len(seq) {
+		panic("bad startIdx")
+	}
+	var err error
+	return SeqResult{
+		Seq: func(yield func(KeyBytes) bool) {
+			err = nil
+			if len(seq) == 0 {
+				return
+			}
+			n := startIdx
+			for {
+				if n == len(seq) {
+					n = 0
+				}
+				item := seq[n]
+				if item.err != nil {
+					err = item.err
+					return
+				}
+				if item.stop || !yield(KeyBytes(item.k)) {
+					return
+				}
+				n++
+			}
+		},
+		Error: func() error {
+			return err
+		},
+	}
+}
+
+func seqToStr(t *testing.T, sr SeqResult) string {
+	var sb strings.Builder
+	var firstK KeyBytes
+	wrap := 0
+	var s string
+	for k := range sr.Seq {
+		require.NoError(t, sr.Error())
+		if wrap != 0 {
+			// after wraparound, make sure the sequence is repeated
+			if k.Compare(firstK) == 0 {
+				// arrived at the first element for the second time
+				return s
+			}
+			require.Equal(t, s[wrap], k[0])
+			wrap++
+			continue
+		}
+		require.NotNil(t, k)
+		if firstK == nil {
+			firstK = k
+		} else if k.Compare(firstK) == 0 {
+			s = sb.String() // wraparound
+			wrap = 1
+			continue
+		}
+		sb.Write(k)
+	}
+	if err := sr.Error(); err != nil {
+		require.Equal(t, seqTestErr, err)
+		sb.WriteString("!") // error
+		return sb.String()
+	}
+	return sb.String() + "$" // stop
+}
+
+func TestCombineSeqs(t *testing.T) {
+	for _, tc := range []struct {
+		// In each seq, $ means the end of the sequence (lack of $ means wraparound),
+		// and ! means an error. For example, "abcd$" yields a, b, c, d and then
+		// stops, while "abcd" keeps yielding a, b, c, d in a cycle.
+		seqs          []string
+		indices       []int
+		result        string
+		startingPoint string
+	}{
+		{
+			seqs:          []string{"abcd"},
+			indices:       []int{0},
+			result:        "abcd",
+			startingPoint: "a",
+		},
+		{
+			seqs:          []string{"abcd"},
+			indices:       []int{0},
+			result:        "abcd",
+			startingPoint: "c",
+		},
+		{
+			seqs:          []string{"abcd"},
+			indices:       []int{2},
+			result:        "cdab",
+			startingPoint: "c",
+		},
+		{
+			seqs:          []string{"abcd$"},
+			indices:       []int{0},
+			result:        "abcd$",
+			startingPoint: "a",
+		},
+		{
+			seqs:          []string{"abcd!"},
+			indices:       []int{0},
+			result:        "abcd!",
+			startingPoint: "a",
+		},
+		{
+			seqs:          []string{"abcd", "efgh"},
+			indices:       []int{0, 0},
+			result:        "abcdefgh",
+			startingPoint: "a",
+		},
+		{
+			seqs:          []string{"aceg", "bdfh"},
+			indices:       []int{0, 0},
+			result:        "abcdefgh",
+			startingPoint: "a",
+		},
+		{
+			seqs:          []string{"abcd$", "efgh$"},
+			indices:       []int{0, 0},
+			result:        "abcdefgh$",
+			startingPoint: "a",
+		},
+		{
+			seqs:          []string{"aceg$", "bdfh$"},
+			indices:       []int{0, 0},
+			result:        "abcdefgh$",
+			startingPoint: "a",
+		},
+		{
+			seqs:          []string{"abcd!", "efgh!"},
+			indices:       []int{0, 0},
+			result:        "abcd!",
+			startingPoint: "a",
+		},
+		{
+			seqs:          []string{"aceg!", "bdfh!"},
+			indices:       []int{0, 0},
+			result:        "abcdefg!",
+			startingPoint: "a",
+		},
+		{
+			// wraparound:
+			// "ac"+"bdefgh"
+			// abcdefgh ==>
+			// defghabc
+			// starting point is d.
+			// Each sequence must either start after the starting point, or
+			// all of its elements are considered to be below the starting
+			// point. "ac" is considered to be wrapped around initially
+			seqs:          []string{"ac", "bdefgh"},
+			indices:       []int{0, 1},
+			result:        "defghabc",
+			startingPoint: "d",
+		},
+		{
+			seqs:          []string{"bc", "ae"},
+			indices:       []int{0, 1},
+			result:        "eabc",
+			startingPoint: "d",
+		},
+		{
+			seqs:          []string{"ac", "bfg", "deh"},
+			indices:       []int{0, 0, 0},
+			result:        "abcdefgh",
+			startingPoint: "a",
+		},
+		{
+			seqs:          []string{"abdefgh", "c"},
+			indices:       []int{0, 0},
+			result:        "abcdefgh",
+			startingPoint: "a",
+		},
+	} {
+		t.Run("", func(t *testing.T) {
+			var seqs []SeqResult
+			for n, s := range tc.seqs {
+				seqs = append(seqs, mkFakeSeq(s).items(tc.indices[n]))
+			}
+			startingPoint := KeyBytes(tc.startingPoint)
+			combined := CombineSeqs(startingPoint, seqs...)
+			for range 3 { // make sure the sequence is reusable
+				require.Equal(t, tc.result, seqToStr(t, combined),
+					"combine %v (indices %v) starting with %s",
+					tc.seqs, tc.indices, tc.startingPoint)
+			}
+		})
+	}
+}
diff --git a/sync2/sqlstore/dbseq.go b/sync2/sqlstore/dbseq.go
new file mode 100644
index 0000000000..bdcadddca0
--- /dev/null
+++ b/sync2/sqlstore/dbseq.go
@@ -0,0 +1,181 @@
+package sqlstore
+
+import (
+	"errors"
+	"slices"
+
+	"github.com/spacemeshos/go-spacemesh/sql"
+	"github.com/spacemeshos/go-spacemesh/sync2/rangesync"
+)
+
+// dbSeq represents a sequence of IDs from a database table.
+type dbSeq struct {
+	// database
+	db sql.Executor
+	// starting point
+	from rangesync.KeyBytes
+	// table snapshot to use
+	sts *SyncedTableSnapshot
+	// currently used chunk size
+	chunkSize int
+	// timestamp used to fetch recent IDs
+	// (nanoseconds since epoch, 0 if not in the "recent" mode)
+	ts int64
+	// maximum value for chunkSize
+	maxChunkSize int
+	// current chunk of items
+	chunk []rangesync.KeyBytes
+	// position within the current chunk
+	pos int
+	// length of each key in bytes
+	keyLen int
+	// true if there is only a single chunk in the sequence.
+	// It is set after loading the initial chunk and finding that it's the only one.
+	singleChunk bool
+}
+
+// idsFromTable iterates over the id field values in a database table.
+func idsFromTable(
+	db sql.Executor,
+	sts *SyncedTableSnapshot,
+	from rangesync.KeyBytes,
+	ts int64,
+	chunkSize int,
+	maxChunkSize int,
+) rangesync.SeqResult {
+	if from == nil {
+		panic("BUG: idsFromTable: nil from")
+	}
+	if maxChunkSize <= 0 {
+		panic("BUG: idsFromTable: maxChunkSize must be > 0")
+	}
+	if chunkSize <= 0 {
+		chunkSize = 1
+	} else if chunkSize > maxChunkSize {
+		chunkSize = maxChunkSize
+	}
+	var err error
+	return rangesync.SeqResult{
+		Seq: func(yield func(k rangesync.KeyBytes) bool) {
+			s := &dbSeq{
+				db:           db,
+				from:         from.Clone(),
+				sts:          sts,
+				chunkSize:    chunkSize,
+				ts:           ts,
+				maxChunkSize: maxChunkSize,
+				keyLen:       len(from),
+				chunk:        make([]rangesync.KeyBytes, 1),
+				singleChunk:  false,
+			}
+			if err = s.load(); err != nil {
+				return
+			}
+			err = s.iterate(yield)
+		},
+		Error: func() error {
+			return err
+		},
+	}
+}
+
+// load makes sure the current chunk is loaded.
+func (s *dbSeq) load() error {
+	s.pos = 0
+	if s.singleChunk {
+		// we have a single-chunk DB sequence, don't need to reload,
+		// just wrap around
+		return nil
+	}
+
+	n := 0
+	// make sure the chunk is large enough
+	if cap(s.chunk) < s.chunkSize {
+		s.chunk = make([]rangesync.KeyBytes, s.chunkSize)
+	} else {
+		// if the chunk size was reduced due to a short chunk before wraparound, we need
+		// to extend it back
+		s.chunk = s.chunk[:s.chunkSize]
+	}
+
+	var ierr, err error
+	dec := func(stmt *sql.Statement) bool {
+		if n >= len(s.chunk) {
+			ierr = errors.New("too many rows")
+			return false
+		}
+		// we reuse existing slices when possible for retrieving new IDs
+		id := s.chunk[n]
+		if id == nil {
+			id = make([]byte, s.keyLen)
+			s.chunk[n] = id
+		}
+		stmt.ColumnBytes(0, id)
+		n++
+		return true
+	}
+	if s.ts <= 0 {
+		err = s.sts.LoadRange(s.db, s.from, s.chunkSize, dec)
+	} else {
+		err = s.sts.LoadRecent(s.db, s.from, s.chunkSize, s.ts, dec)
+	}
+
+	fromZero := s.from.IsZero()
+	s.chunkSize = min(s.chunkSize*2, s.maxChunkSize)
+	switch {
+	case ierr != nil:
+		return ierr
+	case err != nil:
+		return err
+	case n == 0:
+		// empty chunk
+		if fromZero {
+			// already wrapped around or started from 0,
+			// the set is empty
+			s.chunk = nil
+			return nil
+		}
+		// wrap around
+		s.from.Zero()
+		return s.load()
+	case n < len(s.chunk):
+		// short chunk means there are no more items after it,
+		// start the next chunk from 0
+		s.from.Zero()
+		s.chunk = s.chunk[:n]
+		// wrapping around on an incomplete chunk that started
+		// from 0 means we have just a single chunk
+		s.singleChunk = fromZero
+	default:
+		// use the last item incremented by 1 as the start of the next chunk
+		copy(s.from, s.chunk[n-1])
+		// Inc may wrap around if it's 0xffff...fff, but it's fine
+		if s.from.Inc() {
+			// if we wrapped around and the current chunk started from 0,
+			// we have just a single chunk
+			s.singleChunk = fromZero
+		}
+	}
+	return nil
+}
+
+// iterate iterates over the table rows.
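+// Unless the table is empty or an error occurs, the sequence never terminates
+// on its own: upon reaching the end of the table it wraps around to the zero
+// key, so the caller is expected to stop the iteration (for example, upon
+// seeing the first yielded key again).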
+func (s *dbSeq) iterate(yield func(k rangesync.KeyBytes) bool) error { + if len(s.chunk) == 0 { + return nil + } + for { + if s.pos >= len(s.chunk) { + panic("BUG: bad dbSeq position") + } + if !yield(slices.Clone(s.chunk[s.pos])) { + return nil + } + s.pos++ + if s.pos >= len(s.chunk) { + if err := s.load(); err != nil { + return err + } + } + } +} diff --git a/sync2/sqlstore/dbseq_test.go b/sync2/sqlstore/dbseq_test.go new file mode 100644 index 0000000000..94a9b830b1 --- /dev/null +++ b/sync2/sqlstore/dbseq_test.go @@ -0,0 +1,259 @@ +package sqlstore_test + +import ( + "encoding/hex" + "slices" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/spacemeshos/go-spacemesh/sync2/rangesync" + "github.com/spacemeshos/go-spacemesh/sync2/sqlstore" +) + +func TestDBRangeIterator(t *testing.T) { + for _, tc := range []struct { + items []rangesync.KeyBytes + from rangesync.KeyBytes + fromN int + }{ + { + items: nil, + from: rangesync.KeyBytes{0x00, 0x00, 0x00, 0x00}, + }, + { + items: nil, + from: rangesync.KeyBytes{0x80, 0x00, 0x00, 0x00}, + }, + { + items: nil, + from: rangesync.KeyBytes{0xff, 0xff, 0xff, 0xff}, + }, + { + items: []rangesync.KeyBytes{ + {0x00, 0x00, 0x00, 0x00}, + }, + from: rangesync.KeyBytes{0x00, 0x00, 0x00, 0x00}, + fromN: 0, + }, + { + items: []rangesync.KeyBytes{ + {0x00, 0x00, 0x00, 0x00}, + }, + from: rangesync.KeyBytes{0x01, 0x00, 0x00, 0x00}, + fromN: 0, + }, + { + items: []rangesync.KeyBytes{ + {0x00, 0x00, 0x00, 0x00}, + }, + from: rangesync.KeyBytes{0xff, 0xff, 0xff, 0xff}, + fromN: 0, + }, + { + items: []rangesync.KeyBytes{ + {0x01, 0x02, 0x03, 0x04}, + }, + from: rangesync.KeyBytes{0x00, 0x00, 0x00, 0x00}, + fromN: 0, + }, + { + items: []rangesync.KeyBytes{ + {0x01, 0x02, 0x03, 0x04}, + }, + from: rangesync.KeyBytes{0x01, 0x00, 0x00, 0x00}, + fromN: 0, + }, + { + items: []rangesync.KeyBytes{ + {0x01, 0x02, 0x03, 0x04}, + }, + from: rangesync.KeyBytes{0xff, 0xff, 0xff, 0xff}, + fromN: 0, + }, + { + items: []rangesync.KeyBytes{ + {0xff, 0xff, 0xff, 0xff}, + }, + from: rangesync.KeyBytes{0x00, 0x00, 0x00, 0x00}, + fromN: 0, + }, + { + items: []rangesync.KeyBytes{ + {0xff, 0xff, 0xff, 0xff}, + }, + from: rangesync.KeyBytes{0x01, 0x00, 0x00, 0x00}, + fromN: 0, + }, + { + items: []rangesync.KeyBytes{ + {0xff, 0xff, 0xff, 0xff}, + }, + from: rangesync.KeyBytes{0xff, 0xff, 0xff, 0xff}, + fromN: 0, + }, + { + items: []rangesync.KeyBytes{ + 0: {0x00, 0x00, 0x00, 0x01}, + 1: {0x00, 0x00, 0x00, 0x03}, + 2: {0x00, 0x00, 0x00, 0x05}, + 3: {0x00, 0x00, 0x00, 0x07}, + }, + from: rangesync.KeyBytes{0x00, 0x00, 0x00, 0x00}, + fromN: 0, + }, + { + items: []rangesync.KeyBytes{ + 0: {0x00, 0x00, 0x00, 0x01}, + 1: {0x00, 0x00, 0x00, 0x03}, + 2: {0x00, 0x00, 0x00, 0x05}, + 3: {0x00, 0x00, 0x00, 0x07}, + }, + from: rangesync.KeyBytes{0x00, 0x00, 0x00, 0x01}, + fromN: 0, + }, + { + items: []rangesync.KeyBytes{ + 0: {0x00, 0x00, 0x00, 0x01}, + 1: {0x00, 0x00, 0x00, 0x03}, + 2: {0x00, 0x00, 0x00, 0x05}, + 3: {0x00, 0x00, 0x00, 0x07}, + }, + from: rangesync.KeyBytes{0x00, 0x00, 0x00, 0x02}, + fromN: 1, + }, + { + items: []rangesync.KeyBytes{ + 0: {0x00, 0x00, 0x00, 0x01}, + 1: {0x00, 0x00, 0x00, 0x03}, + 2: {0x00, 0x00, 0x00, 0x05}, + 3: {0x00, 0x00, 0x00, 0x07}, + }, + from: rangesync.KeyBytes{0x00, 0x00, 0x00, 0x03}, + fromN: 1, + }, + { + items: []rangesync.KeyBytes{ + 0: {0x00, 0x00, 0x00, 0x01}, + 1: {0x00, 0x00, 0x00, 0x03}, + 2: {0x00, 0x00, 0x00, 0x05}, + 3: {0x00, 0x00, 0x00, 0x07}, + }, + from: rangesync.KeyBytes{0x00, 0x00, 0x00, 0x05}, + fromN: 2, + }, + { + 
			items: []rangesync.KeyBytes{
				0: {0x00, 0x00, 0x00, 0x01},
				1: {0x00, 0x00, 0x00, 0x03},
				2: {0x00, 0x00, 0x00, 0x05},
				3: {0x00, 0x00, 0x00, 0x07},
			},
			from:  rangesync.KeyBytes{0x00, 0x00, 0x00, 0x07},
			fromN: 3,
		},
		{
			items: []rangesync.KeyBytes{
				0: {0x00, 0x00, 0x00, 0x01},
				1: {0x00, 0x00, 0x00, 0x03},
				2: {0x00, 0x00, 0x00, 0x05},
				3: {0x00, 0x00, 0x00, 0x07},
			},
			from:  rangesync.KeyBytes{0xff, 0xff, 0xff, 0xff},
			fromN: 0,
		},
		{
			items: []rangesync.KeyBytes{
				0:  {0x00, 0x00, 0x00, 0x01},
				1:  {0x00, 0x00, 0x00, 0x03},
				2:  {0x00, 0x00, 0x00, 0x05},
				3:  {0x00, 0x00, 0x00, 0x07},
				4:  {0x00, 0x00, 0x01, 0x00},
				5:  {0x00, 0x00, 0x03, 0x00},
				6:  {0x00, 0x01, 0x00, 0x00},
				7:  {0x00, 0x05, 0x00, 0x00},
				8:  {0x03, 0x05, 0x00, 0x00},
				9:  {0x09, 0x05, 0x00, 0x00},
				10: {0x0a, 0x05, 0x00, 0x00},
				11: {0xff, 0xff, 0xff, 0xff},
			},
			from:  rangesync.KeyBytes{0x00, 0x00, 0x03, 0x01},
			fromN: 6,
		},
		{
			items: []rangesync.KeyBytes{
				0:  {0x00, 0x00, 0x00, 0x01},
				1:  {0x00, 0x00, 0x00, 0x03},
				2:  {0x00, 0x00, 0x00, 0x05},
				3:  {0x00, 0x00, 0x00, 0x07},
				4:  {0x00, 0x00, 0x01, 0x00},
				5:  {0x00, 0x00, 0x03, 0x00},
				6:  {0x00, 0x01, 0x00, 0x00},
				7:  {0x00, 0x05, 0x00, 0x00},
				8:  {0x03, 0x05, 0x00, 0x00},
				9:  {0x09, 0x05, 0x00, 0x00},
				10: {0x0a, 0x05, 0x00, 0x00},
				11: {0xff, 0xff, 0xff, 0xff},
			},
			from:  rangesync.KeyBytes{0x00, 0x01, 0x00, 0x00},
			fromN: 6,
		},
		{
			items: []rangesync.KeyBytes{
				0:  {0x00, 0x00, 0x00, 0x01},
				1:  {0x00, 0x00, 0x00, 0x03},
				2:  {0x00, 0x00, 0x00, 0x05},
				3:  {0x00, 0x00, 0x00, 0x07},
				4:  {0x00, 0x00, 0x01, 0x00},
				5:  {0x00, 0x00, 0x03, 0x00},
				6:  {0x00, 0x01, 0x00, 0x00},
				7:  {0x00, 0x05, 0x00, 0x00},
				8:  {0x03, 0x05, 0x00, 0x00},
				9:  {0x09, 0x05, 0x00, 0x00},
				10: {0x0a, 0x05, 0x00, 0x00},
				11: {0xff, 0xff, 0xff, 0xff},
			},
			from:  rangesync.KeyBytes{0xff, 0xff, 0xff, 0xff},
			fromN: 11,
		},
	} {
		t.Run("", func(t *testing.T) {
			db := sqlstore.CreateDB(t, 4)
			sqlstore.InsertDBItems(t, db, tc.items)
			st := &sqlstore.SyncedTable{
				TableName: "foo",
				IDColumn:  "id",
			}
			sts, err := st.Snapshot(db)
			require.NoError(t, err)
			for startChunkSize := 1; startChunkSize < 12; startChunkSize++ {
				for maxChunkSize := 1; maxChunkSize < 12; maxChunkSize++ {
					sr := sqlstore.IDSFromTable(db, sts, tc.from, -1,
						startChunkSize, maxChunkSize)
					for range 3 { // make sure the sequence is reusable
						var collected []rangesync.KeyBytes
						var firstK rangesync.KeyBytes
						for k := range sr.Seq {
							if firstK == nil {
								firstK = k
							} else if k.Compare(firstK) == 0 {
								break
							}
							collected = append(collected, k)
						}
						require.NoError(t, sr.Error())
						expected := slices.Concat(
							tc.items[tc.fromN:], tc.items[:tc.fromN])
						require.Equal(
							t, expected, collected,
							"count=%d from=%s maxChunkSize=%d",
							len(tc.items), hex.EncodeToString(tc.from),
							maxChunkSize)
					}
				}
			}
		})
	}
}
diff --git a/sync2/sqlstore/export_test.go b/sync2/sqlstore/export_test.go
new file mode 100644
index 0000000000..6896fc2d2e
--- /dev/null
+++ b/sync2/sqlstore/export_test.go
@@ -0,0 +1,12 @@
+package sqlstore
+
+import "github.com/spacemeshos/go-spacemesh/sql/expr"
+
+var IDSFromTable = idsFromTable
+
+func (st *SyncedTable) GenSelectAll() expr.Statement      { return st.genSelectAll() }
+func (st *SyncedTable) GenCount() expr.Statement          { return st.genCount() }
+func (st *SyncedTable) GenSelectMaxRowID() expr.Statement { return st.genSelectMaxRowID() }
+func (st *SyncedTable) GenSelectRange() expr.Statement    { return st.genSelectRange() }
+func (st *SyncedTable) GenRecentCount() expr.Statement    { return st.genRecentCount() }
+func (st *SyncedTable) GenSelectRecent() expr.Statement   { return st.genSelectRecent() }
diff --git a/sync2/sqlstore/interface.go b/sync2/sqlstore/interface.go
new file mode 100644
index 0000000000..ecb38e93e1
--- /dev/null
+++ b/sync2/sqlstore/interface.go
@@ -0,0 +1,21 @@
+package sqlstore
+
+import (
+	"github.com/spacemeshos/go-spacemesh/sync2/rangesync"
+)
+
+// IDStore represents a store of IDs (keys).
+type IDStore interface {
+	// Clone makes a copy of the store.
+	// It is expected to be an O(1) operation.
+	Clone() IDStore
+	// Release releases the resources associated with the store.
+	Release()
+	// RegisterKey registers the key with the store.
+	RegisterKey(k rangesync.KeyBytes) error
+	// All returns all keys in the store.
+	All() rangesync.SeqResult
+	// From returns all keys in the store starting from the given key.
+	// sizeHint is a hint for the expected number of keys to be returned.
+	From(from rangesync.KeyBytes, sizeHint int) rangesync.SeqResult
+}
diff --git a/sync2/sqlstore/sqlidstore.go b/sync2/sqlstore/sqlidstore.go
new file mode 100644
index 0000000000..e26fa37ad7
--- /dev/null
+++ b/sync2/sqlstore/sqlidstore.go
@@ -0,0 +1,80 @@
+package sqlstore
+
+import (
+	"github.com/spacemeshos/go-spacemesh/sql"
+	"github.com/spacemeshos/go-spacemesh/sync2/rangesync"
+)
+
+// max chunk size to use for dbSeq.
+const sqlMaxChunkSize = 1024
+
+// SQLIDStore is an implementation of IDStore that is based on a database table snapshot.
+type SQLIDStore struct {
+	db     sql.Executor
+	sts    *SyncedTableSnapshot
+	keyLen int
+}
+
+var _ IDStore = &SQLIDStore{}
+
+// NewSQLIDStore creates a new SQLIDStore.
+func NewSQLIDStore(db sql.Executor, sts *SyncedTableSnapshot, keyLen int) *SQLIDStore {
+	return &SQLIDStore{
+		db:     db,
+		sts:    sts,
+		keyLen: keyLen,
+	}
+}
+
+// Clone creates a new SQLIDStore that shares the same database connection and table snapshot.
+// Implements IDStore.
+func (s *SQLIDStore) Clone() IDStore {
+	return NewSQLIDStore(s.db, s.sts, s.keyLen)
+}
+
+// RegisterKey is a no-op for SQLIDStore, as the underlying table is not updated
+// immediately upon receiving new items.
+// Implements IDStore.
+func (s *SQLIDStore) RegisterKey(k rangesync.KeyBytes) error {
+	// keys are expected to be registered by the handler code
+	return nil
+}
+
+// All returns all IDs in the store.
+// Implements IDStore.
+func (s *SQLIDStore) All() rangesync.SeqResult {
+	return s.From(make(rangesync.KeyBytes, s.keyLen), 1)
+}
+
+// From returns IDs in the store starting from the given key.
+// Implements IDStore.
+func (s *SQLIDStore) From(from rangesync.KeyBytes, sizeHint int) rangesync.SeqResult {
+	if len(from) != s.keyLen {
+		panic("BUG: invalid key length")
+	}
+	return idsFromTable(s.db, s.sts, from, -1, sizeHint, sqlMaxChunkSize)
+}
+
+// Since returns IDs in the store starting from the given key and timestamp.
+func (s *SQLIDStore) Since(from rangesync.KeyBytes, since int64) (rangesync.SeqResult, int) {
+	if len(from) != s.keyLen {
+		panic("BUG: invalid key length")
+	}
+	count, err := s.sts.LoadRecentCount(s.db, since)
+	if err != nil {
+		return rangesync.ErrorSeqResult(err), 0
+	}
+	if count == 0 {
+		return rangesync.EmptySeqResult(), 0
+	}
+	return idsFromTable(s.db, s.sts, from, since, 1, sqlMaxChunkSize), count
+}
+
+// SetSnapshot sets the table snapshot to use for the store.
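+// It can be used to advance the store to a newer snapshot of the same
+// underlying table without re-creating the store.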
+func (s *SQLIDStore) SetSnapshot(sts *SyncedTableSnapshot) { + s.sts = sts +} + +// Release is a no-op for SQLIDStore. +// Implements IDStore. +func (s *SQLIDStore) Release() {} diff --git a/sync2/sqlstore/sqlidstore_test.go b/sync2/sqlstore/sqlidstore_test.go new file mode 100644 index 0000000000..b8e75fb6e2 --- /dev/null +++ b/sync2/sqlstore/sqlidstore_test.go @@ -0,0 +1,82 @@ +package sqlstore_test + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sync2/rangesync" + "github.com/spacemeshos/go-spacemesh/sync2/sqlstore" +) + +func TestSQLIdStore(t *testing.T) { + const keyLen = 12 + db := sql.InMemoryTest(t) + _, err := db.Exec( + fmt.Sprintf("create table foo(id char(%d) not null primary key, received int)", keyLen), + nil, nil) + require.NoError(t, err) + for _, row := range []struct { + id rangesync.KeyBytes + ts int64 + }{ + { + id: rangesync.KeyBytes{0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0}, + ts: 100, + }, + { + id: rangesync.KeyBytes{0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0}, + ts: 200, + }, + { + id: rangesync.KeyBytes{0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0}, + ts: 300, + }, + { + id: rangesync.KeyBytes{0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0}, + ts: 400, + }, + } { + _, err := db.Exec( + "insert into foo (id, received) values (?, ?)", + func(stmt *sql.Statement) { + stmt.BindBytes(1, row.id) + stmt.BindInt64(2, row.ts) + }, nil) + require.NoError(t, err) + } + st := sqlstore.SyncedTable{ + TableName: "foo", + IDColumn: "id", + TimestampColumn: "received", + } + sts, err := st.Snapshot(db) + require.NoError(t, err) + + store := sqlstore.NewSQLIDStore(db, sts, keyLen) + actualIDs, err := store.From(rangesync.KeyBytes{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 5).FirstN(5) + require.NoError(t, err) + require.Equal(t, []rangesync.KeyBytes{ + {0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0}, + {0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0}, + {0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0}, + {0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0}, + {0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0}, // wrapped around + }, actualIDs) + + actualIDs1, err := store.All().FirstN(5) + require.NoError(t, err) + require.Equal(t, actualIDs, actualIDs1) + + sr, count := store.Since(rangesync.KeyBytes{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 300) + require.Equal(t, 2, count) + actualIDs, err = sr.FirstN(3) + require.NoError(t, err) + require.Equal(t, []rangesync.KeyBytes{ + {0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0}, + {0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0}, + {0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0}, // wrapped around + }, actualIDs) +} diff --git a/sync2/sqlstore/syncedtable.go b/sync2/sqlstore/syncedtable.go new file mode 100644 index 0000000000..0d6c252a21 --- /dev/null +++ b/sync2/sqlstore/syncedtable.go @@ -0,0 +1,289 @@ +package sqlstore + +import ( + "errors" + "fmt" + + "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/expr" + "github.com/spacemeshos/go-spacemesh/sync2/rangesync" +) + +// Binder is a function that binds filter expression parameters to a SQL statement. +type Binder func(s *sql.Statement) + +// SyncedTable represents a table that can be used with SQLIDStore. +type SyncedTable struct { + // The name of the table. + TableName string + // The name of the ID column. + IDColumn string + // The name of the timestamp column. + TimestampColumn string + // The filter expression. + Filter expr.Expr + // The binder function for the bind parameters appearing in the filter expression. 
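+	// For example, with Filter `epoch = ?`, the Binder is expected to bind the
+	// epoch value to bind parameter 1 (see the "filter" test cases).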
+ Binder Binder + queries map[string]string +} + +func (st *SyncedTable) cacheQuery(name string, gen func() expr.Statement) string { + s, ok := st.queries[name] + if ok { + return s + } + if st.queries == nil { + st.queries = make(map[string]string) + } + s = gen().String() + st.queries[name] = s + return s +} + +func (st *SyncedTable) exec( + db sql.Executor, + name string, + gen func() expr.Statement, + enc sql.Encoder, + dec sql.Decoder, +) error { + _, err := db.Exec(st.cacheQuery(name, gen), func(stmt *sql.Statement) { + if st.Binder != nil { + st.Binder(stmt) + } + enc(stmt) + }, dec) + return err +} + +// genSelectMaxRowID generates a SELECT statement that returns the maximum rowid in the +// table. +func (st *SyncedTable) genSelectMaxRowID() expr.Statement { + return expr.Select(expr.Call("max", expr.Ident("rowid"))). + From(expr.TableSource(st.TableName)). + Get() +} + +// rowIDCutoff returns an expression that represents a rowid cutoff, that is, limits the +// rowid to be less than or equal to a bind parameter. +func (st *SyncedTable) rowIDCutoff() expr.Expr { + return expr.Op(expr.Ident("rowid"), expr.LE, expr.Bind()) +} + +// timestampCutoff returns an expression that represents a timestamp cutoff, that is, limits the +// timestamp to be greater than or equal to a bind parameter. +func (st *SyncedTable) timestampCutoff() expr.Expr { + return expr.Op(expr.Ident(st.TimestampColumn), expr.GE, expr.Bind()) +} + +// genSelectAll generates a SELECT statement that returns all the rows in the table +// satisfying the filter expression and the rowid cutoff. +func (st *SyncedTable) genSelectAll() expr.Statement { + return expr.Select(expr.Ident(st.IDColumn)). + From(expr.TableSource(st.TableName)). + Where(expr.MaybeAnd(st.Filter, st.rowIDCutoff())). + Get() +} + +// genCount generates a SELECT statement that returns the number of rows in the table +// satisfying the filter expression and the rowid cutoff. +func (st *SyncedTable) genCount() expr.Statement { + return expr.Select(expr.Call("count", expr.Ident(st.IDColumn))). + From(expr.TableSource(st.TableName)). + Where(expr.MaybeAnd(st.Filter, st.rowIDCutoff())). + Get() +} + +// genSelectAllSinceSnapshot generates a SELECT statement that returns all the rows in the +// table satisfying the filter expression that have rowid between the specified min and +// max parameter values, inclusive. +func (st *SyncedTable) genSelectAllSinceSnapshot() expr.Statement { + return expr.Select(expr.Ident(st.IDColumn)). + From(expr.TableSource(st.TableName)). + Where(expr.MaybeAnd( + st.Filter, + expr.Between(expr.Ident("rowid"), expr.Bind(), expr.Bind()))). + Get() +} + +// genSelectRange generates a SELECT statement that returns the rows in the table +// satisfying the filter expression, the rowid cutoff and the specified ID range. +func (st *SyncedTable) genSelectRange() expr.Statement { + return expr.Select(expr.Ident(st.IDColumn)). + From(expr.TableSource(st.TableName)). + Where(expr.MaybeAnd( + st.Filter, + expr.Op(expr.Ident(st.IDColumn), expr.GE, expr.Bind()), + st.rowIDCutoff())). + OrderBy(expr.Asc(expr.Ident(st.IDColumn))). + Limit(expr.Bind()). + Get() +} + +// genRecentCount generates a SELECT statement that returns the number of rows in the table +// added starting with the specified timestamp, taking into account the filter expression +// and the rowid cutoff. +func (st *SyncedTable) genRecentCount() expr.Statement { + return expr.Select(expr.Call("count", expr.Ident(st.IDColumn))). + From(expr.TableSource(st.TableName)). 
+		Where(expr.MaybeAnd(st.Filter, st.rowIDCutoff(), st.timestampCutoff())).
+		Get()
+}
+
+// genSelectRecent generates a SELECT statement that returns the rows in the table added
+// starting with the specified timestamp, taking into account the filter expression and
+// the rowid cutoff.
+func (st *SyncedTable) genSelectRecent() expr.Statement {
+	return expr.Select(expr.Ident(st.IDColumn)).
+		From(expr.TableSource(st.TableName)).
+		Where(expr.MaybeAnd(
+			st.Filter,
+			expr.Op(expr.Ident(st.IDColumn), expr.GE, expr.Bind()),
+			st.rowIDCutoff(), st.timestampCutoff())).
+		OrderBy(expr.Asc(expr.Ident(st.IDColumn))).
+		Limit(expr.Bind()).
+		Get()
+}
+
+// loadMaxRowID returns the max rowid in the table.
+func (st *SyncedTable) loadMaxRowID(db sql.Executor) (maxRowID int64, err error) {
+	nRows, err := db.Exec(
+		st.cacheQuery("selectMaxRowID", st.genSelectMaxRowID), nil,
+		func(st *sql.Statement) bool {
+			maxRowID = st.ColumnInt64(0)
+			return true
+		})
+	if nRows != 1 {
+		return 0, fmt.Errorf("expected 1 row, got %d", nRows)
+	}
+	return maxRowID, err
+}
+
+// Snapshot creates a snapshot of the table based on its current max rowid value.
+func (st *SyncedTable) Snapshot(db sql.Executor) (*SyncedTableSnapshot, error) {
+	maxRowID, err := st.loadMaxRowID(db)
+	if err != nil {
+		return nil, err
+	}
+	return &SyncedTableSnapshot{st, maxRowID}, nil
+}
+
+// SyncedTableSnapshot represents a snapshot of an append-only table.
+// The snapshotting relies on rowids of the table rows never decreasing
+// as new rows are added.
+// Each snapshot inherits the filter expression from the table, so all the rows relevant to
+// the snapshot are always filtered using that expression, if it's specified.
+type SyncedTableSnapshot struct {
+	*SyncedTable
+	maxRowID int64
+}
+
+// Load loads all the table rows belonging to a snapshot.
+func (sts *SyncedTableSnapshot) Load(
+	db sql.Executor,
+	dec func(stmt *sql.Statement) bool,
+) error {
+	return sts.exec(db, "selectAll", sts.genSelectAll, func(stmt *sql.Statement) {
+		stmt.BindInt64(stmt.BindParamCount(), sts.maxRowID)
+	}, dec)
+}
+
+// LoadCount returns the number of rows in the snapshot.
+func (sts *SyncedTableSnapshot) LoadCount(
+	db sql.Executor,
+) (int, error) {
+	var count int
+	err := sts.exec(
+		db, "count", sts.genCount,
+		func(stmt *sql.Statement) {
+			stmt.BindInt64(stmt.BindParamCount(), sts.maxRowID)
+		},
+		func(stmt *sql.Statement) bool {
+			count = stmt.ColumnInt(0)
+			return true
+		})
+	return count, err
+}
+
+// LoadSinceSnapshot loads rows added since the specified previous snapshot.
+func (sts *SyncedTableSnapshot) LoadSinceSnapshot(
+	db sql.Executor,
+	prev *SyncedTableSnapshot,
+	dec func(stmt *sql.Statement) bool,
+) error {
+	return sts.exec(
+		db, "selectAllSinceSnapshot", sts.genSelectAllSinceSnapshot,
+		func(stmt *sql.Statement) {
+			nParams := stmt.BindParamCount()
+			stmt.BindInt64(nParams-1, prev.maxRowID+1)
+			stmt.BindInt64(nParams, sts.maxRowID)
+		},
+		dec)
+}
+
+// LoadRange loads IDs starting from the specified ID.
+// limit specifies the maximum number of IDs to load.
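+// Like Load, it only sees rows at or below the snapshot's rowid cutoff,
+// additionally filtered by the table's filter expression, if any.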
+func (sts *SyncedTableSnapshot) LoadRange( + db sql.Executor, + fromID rangesync.KeyBytes, + limit int, + dec func(stmt *sql.Statement) bool, +) error { + return sts.exec( + db, "selectRange", sts.genSelectRange, + func(stmt *sql.Statement) { + nParams := stmt.BindParamCount() + stmt.BindBytes(nParams-2, fromID) + stmt.BindInt64(nParams-1, sts.maxRowID) + stmt.BindInt64(nParams, int64(limit)) + }, + dec) +} + +var errNoTimestampColumn = errors.New("no timestamp column") + +// LoadRecentCount returns the number of rows added since the specified timestamp. +func (sts *SyncedTableSnapshot) LoadRecentCount( + db sql.Executor, + since int64, +) (int, error) { + if sts.TimestampColumn == "" { + return 0, errNoTimestampColumn + } + var count int + err := sts.exec( + db, "genRecentCount", sts.genRecentCount, + func(stmt *sql.Statement) { + nParams := stmt.BindParamCount() + stmt.BindInt64(nParams-1, sts.maxRowID) + stmt.BindInt64(nParams, since) + }, + func(stmt *sql.Statement) bool { + count = stmt.ColumnInt(0) + return true + }) + return count, err +} + +// LoadRecent loads rows added since the specified timestamp. +func (sts *SyncedTableSnapshot) LoadRecent( + db sql.Executor, + fromID rangesync.KeyBytes, + limit int, + since int64, + dec func(stmt *sql.Statement) bool, +) error { + if sts.TimestampColumn == "" { + return errNoTimestampColumn + } + return sts.exec( + db, "selectRecent", sts.genSelectRecent, + func(stmt *sql.Statement) { + nParams := stmt.BindParamCount() + stmt.BindBytes(nParams-3, fromID) + stmt.BindInt64(nParams-2, sts.maxRowID) + stmt.BindInt64(nParams-1, since) + stmt.BindInt64(nParams, int64(limit)) + }, + dec) +} diff --git a/sync2/sqlstore/syncedtable_test.go b/sync2/sqlstore/syncedtable_test.go new file mode 100644 index 0000000000..28a944ab35 --- /dev/null +++ b/sync2/sqlstore/syncedtable_test.go @@ -0,0 +1,404 @@ +package sqlstore_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/spacemeshos/go-spacemesh/common/util" + "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/expr" + "github.com/spacemeshos/go-spacemesh/sync2/rangesync" + "github.com/spacemeshos/go-spacemesh/sync2/sqlstore" +) + +func TestSyncedTable_GenSQL(t *testing.T) { + for _, tc := range []struct { + name string + st *sqlstore.SyncedTable + all string + count string + maxRowID string + idRange string + recent string + recentCount string + }{ + { + name: "no filter", + st: &sqlstore.SyncedTable{ + TableName: "atxs", + IDColumn: "id", + TimestampColumn: "received", + }, + all: `SELECT "id" FROM "atxs" WHERE "rowid" <= ?`, + count: `SELECT count("id") FROM "atxs" WHERE "rowid" <= ?`, + maxRowID: `SELECT max("rowid") FROM "atxs"`, + idRange: `SELECT "id" FROM "atxs" WHERE "id" >= ? AND "rowid" <= ? ` + + `ORDER BY "id" LIMIT ?`, + recent: `SELECT "id" FROM "atxs" WHERE "id" >= ? AND "rowid" <= ? ` + + `AND "received" >= ? ORDER BY "id" LIMIT ?`, + recentCount: `SELECT count("id") FROM "atxs" WHERE "rowid" <= ? ` + + `AND "received" >= ?`, + }, + { + name: "filter", + st: &sqlstore.SyncedTable{ + TableName: "atxs", + IDColumn: "id", + Filter: expr.MustParse("epoch = ?"), + TimestampColumn: "received", + }, + all: `SELECT "id" FROM "atxs" WHERE "epoch" = ? AND "rowid" <= ?`, + count: `SELECT count("id") FROM "atxs" WHERE "epoch" = ? AND "rowid" <= ?`, + maxRowID: `SELECT max("rowid") FROM "atxs"`, + idRange: `SELECT "id" FROM "atxs" WHERE "epoch" = ? AND "id" >= ? ` + + `AND "rowid" <= ? 
ORDER BY "id" LIMIT ?`, + recent: `SELECT "id" FROM "atxs" WHERE "epoch" = ? AND "id" >= ? ` + + `AND "rowid" <= ? AND "received" >= ? ORDER BY "id" LIMIT ?`, + recentCount: `SELECT count("id") FROM "atxs" WHERE "epoch" = ? ` + + `AND "rowid" <= ? AND "received" >= ?`, + }, + } { + t.Run(tc.name, func(t *testing.T) { + require.Equal(t, tc.all, tc.st.GenSelectAll().String(), "all") + require.Equal(t, tc.count, tc.st.GenCount().String(), "count") + require.Equal(t, tc.maxRowID, tc.st.GenSelectMaxRowID().String(), "maxRowID") + require.Equal(t, tc.idRange, tc.st.GenSelectRange().String(), "idRange") + require.Equal(t, tc.recent, tc.st.GenSelectRecent().String(), "recent") + require.Equal(t, tc.recentCount, tc.st.GenRecentCount().String(), "recentCount") + }) + } +} + +func TestSyncedTable_LoadIDs(t *testing.T) { + var db sql.Database + type row struct { + id string + epoch int + ts int + } + rows := []row{ + {"0451cd036aff0367b07590032da827b516b63a4c1b36ea9a253dcf9a7e084980", 1, 100}, + {"0e75d10a8e98a4307dd9d0427dc1d2ebf9e45b602d159ef62c5da95197159844", 1, 110}, + {"18040e78f834b879a9585fba90f6f5e7394dc3bb27f20829baf6bfc9e1bfe44b", 2, 120}, + {"1a9b743abdabe7970041ba2006c0e8bb51a27b1dbfd1a8c70ef5e7703ddeaa55", 2, 150}, + {"1b49b5a17161995cc288523637bd63af5bed99f4f7188effb702da8a7a4beee1", 2, 180}, + {"2023eee75bec75da61ad7644bd43f02b9397a72cf489565cb53a4337975a290b", 3, 190}, + {"24b31a6acc8cd13b119dd5aa81a6c3803250a8a79eb32231f16b09e0971f1b23", 3, 220}, + } + + insertRows := func(rows []row) { + for _, r := range rows { + _, err := db.Exec("insert into atxs (id, epoch, received) values (?, ?, ?)", + func(stmt *sql.Statement) { + stmt.BindBytes(1, util.FromHex(r.id)) + stmt.BindInt64(2, int64(r.epoch)) + stmt.BindInt64(3, int64(r.ts)) + }, nil) + require.NoError(t, err) + } + } + + initDB := func() { + db = sql.InMemoryTest(t) + _, err := db.Exec(`create table atxs ( + id char(32) not null primary key, + epoch int, + received int)`, nil, nil) + require.NoError(t, err) + insertRows(rows) + } + + mkDecode := func(ids *[]string) func(stmt *sql.Statement) bool { + return func(stmt *sql.Statement) bool { + id := make(rangesync.KeyBytes, stmt.ColumnLen(0)) + stmt.ColumnBytes(0, id) + *ids = append(*ids, id.String()) + return true + } + } + + loadCount := func(sts *sqlstore.SyncedTableSnapshot) int { + count, err := sts.LoadCount(db) + require.NoError(t, err) + return count + } + + loadIDs := func(sts *sqlstore.SyncedTableSnapshot) []string { + var ids []string + require.NoError(t, sts.Load(db, mkDecode(&ids))) + return ids + } + + loadIDsSince := func(stsNew, stsOld *sqlstore.SyncedTableSnapshot) []string { + var ids []string + require.NoError(t, stsNew.LoadSinceSnapshot(db, stsOld, mkDecode(&ids))) + return ids + } + + loadIDRange := func(sts *sqlstore.SyncedTableSnapshot, from rangesync.KeyBytes, limit int) []string { + var ids []string + require.NoError(t, sts.LoadRange(db, from, limit, mkDecode(&ids))) + return ids + } + + loadRecentCount := func( + sts *sqlstore.SyncedTableSnapshot, + ts int64, + ) int { + count, err := sts.LoadRecentCount(db, ts) + require.NoError(t, err) + return count + } + + loadRecent := func( + sts *sqlstore.SyncedTableSnapshot, + from rangesync.KeyBytes, + limit int, + ts int64, + ) []string { + var ids []string + require.NoError(t, sts.LoadRecent(db, from, limit, ts, mkDecode(&ids))) + return ids + } + + t.Run("no filter", func(t *testing.T) { + initDB() + + st := &sqlstore.SyncedTable{ + TableName: "atxs", + IDColumn: "id", + TimestampColumn: "received", + } + + 
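+		// Snapshot the table at its current max rowid; rows inserted after
+		// this point must not be visible via sts1.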
sts1, err := st.Snapshot(db) + require.NoError(t, err) + + require.Equal(t, 7, loadCount(sts1)) + require.ElementsMatch(t, + []string{ + "0451cd036aff0367b07590032da827b516b63a4c1b36ea9a253dcf9a7e084980", + "0e75d10a8e98a4307dd9d0427dc1d2ebf9e45b602d159ef62c5da95197159844", + "18040e78f834b879a9585fba90f6f5e7394dc3bb27f20829baf6bfc9e1bfe44b", + "1a9b743abdabe7970041ba2006c0e8bb51a27b1dbfd1a8c70ef5e7703ddeaa55", + "1b49b5a17161995cc288523637bd63af5bed99f4f7188effb702da8a7a4beee1", + "2023eee75bec75da61ad7644bd43f02b9397a72cf489565cb53a4337975a290b", + "24b31a6acc8cd13b119dd5aa81a6c3803250a8a79eb32231f16b09e0971f1b23", + }, + loadIDs(sts1)) + + fromID := util.FromHex("1a9b743abdabe7970041ba2006c0e8bb51a27b1dbfd1a8c70ef5e7703ddeaa55") + require.ElementsMatch(t, + []string{ + "1a9b743abdabe7970041ba2006c0e8bb51a27b1dbfd1a8c70ef5e7703ddeaa55", + "1b49b5a17161995cc288523637bd63af5bed99f4f7188effb702da8a7a4beee1", + "2023eee75bec75da61ad7644bd43f02b9397a72cf489565cb53a4337975a290b", + "24b31a6acc8cd13b119dd5aa81a6c3803250a8a79eb32231f16b09e0971f1b23", + }, + loadIDRange(sts1, fromID, 100)) + + require.ElementsMatch(t, + []string{ + "1a9b743abdabe7970041ba2006c0e8bb51a27b1dbfd1a8c70ef5e7703ddeaa55", + "1b49b5a17161995cc288523637bd63af5bed99f4f7188effb702da8a7a4beee1", + }, loadIDRange(sts1, fromID, 2)) + require.ElementsMatch(t, + []string{ + "1b49b5a17161995cc288523637bd63af5bed99f4f7188effb702da8a7a4beee1", + "2023eee75bec75da61ad7644bd43f02b9397a72cf489565cb53a4337975a290b", + "24b31a6acc8cd13b119dd5aa81a6c3803250a8a79eb32231f16b09e0971f1b23", + }, loadRecent(sts1, fromID, 3, 180)) + require.Equal(t, 3, loadRecentCount(sts1, 180)) + + insertRows([]row{ + {"2664e267650ee22dee7d8c987b5cf44ba5596c78df3db5b99fb0ce79cc649d69", 2, 300}, + }) + + // the new row is not included in the first snapshot + require.Equal(t, 7, loadCount(sts1)) + require.ElementsMatch(t, + []string{ + "0451cd036aff0367b07590032da827b516b63a4c1b36ea9a253dcf9a7e084980", + "0e75d10a8e98a4307dd9d0427dc1d2ebf9e45b602d159ef62c5da95197159844", + "18040e78f834b879a9585fba90f6f5e7394dc3bb27f20829baf6bfc9e1bfe44b", + "1a9b743abdabe7970041ba2006c0e8bb51a27b1dbfd1a8c70ef5e7703ddeaa55", + "1b49b5a17161995cc288523637bd63af5bed99f4f7188effb702da8a7a4beee1", + "2023eee75bec75da61ad7644bd43f02b9397a72cf489565cb53a4337975a290b", + "24b31a6acc8cd13b119dd5aa81a6c3803250a8a79eb32231f16b09e0971f1b23", + }, loadIDs(sts1)) + require.ElementsMatch(t, + []string{ + "1a9b743abdabe7970041ba2006c0e8bb51a27b1dbfd1a8c70ef5e7703ddeaa55", + "1b49b5a17161995cc288523637bd63af5bed99f4f7188effb702da8a7a4beee1", + "2023eee75bec75da61ad7644bd43f02b9397a72cf489565cb53a4337975a290b", + "24b31a6acc8cd13b119dd5aa81a6c3803250a8a79eb32231f16b09e0971f1b23", + }, + loadIDRange(sts1, fromID, 100)) + require.ElementsMatch(t, + []string{ + "1b49b5a17161995cc288523637bd63af5bed99f4f7188effb702da8a7a4beee1", + "2023eee75bec75da61ad7644bd43f02b9397a72cf489565cb53a4337975a290b", + "24b31a6acc8cd13b119dd5aa81a6c3803250a8a79eb32231f16b09e0971f1b23", + }, + loadRecent(sts1, fromID, 3, 180)) + + sts2, err := st.Snapshot(db) + require.NoError(t, err) + + require.Equal(t, 8, loadCount(sts2)) + require.ElementsMatch(t, + []string{ + "0451cd036aff0367b07590032da827b516b63a4c1b36ea9a253dcf9a7e084980", + "0e75d10a8e98a4307dd9d0427dc1d2ebf9e45b602d159ef62c5da95197159844", + "18040e78f834b879a9585fba90f6f5e7394dc3bb27f20829baf6bfc9e1bfe44b", + "1a9b743abdabe7970041ba2006c0e8bb51a27b1dbfd1a8c70ef5e7703ddeaa55", + "1b49b5a17161995cc288523637bd63af5bed99f4f7188effb702da8a7a4beee1", + 
"2023eee75bec75da61ad7644bd43f02b9397a72cf489565cb53a4337975a290b", + "24b31a6acc8cd13b119dd5aa81a6c3803250a8a79eb32231f16b09e0971f1b23", + "2664e267650ee22dee7d8c987b5cf44ba5596c78df3db5b99fb0ce79cc649d69", + }, + loadIDs(sts2)) + require.ElementsMatch(t, + []string{ + "1a9b743abdabe7970041ba2006c0e8bb51a27b1dbfd1a8c70ef5e7703ddeaa55", + "1b49b5a17161995cc288523637bd63af5bed99f4f7188effb702da8a7a4beee1", + "2023eee75bec75da61ad7644bd43f02b9397a72cf489565cb53a4337975a290b", + "24b31a6acc8cd13b119dd5aa81a6c3803250a8a79eb32231f16b09e0971f1b23", + "2664e267650ee22dee7d8c987b5cf44ba5596c78df3db5b99fb0ce79cc649d69", + }, + loadIDRange(sts2, fromID, 100)) + require.ElementsMatch(t, + []string{ + "2664e267650ee22dee7d8c987b5cf44ba5596c78df3db5b99fb0ce79cc649d69", + }, + loadIDsSince(sts2, sts1)) + require.ElementsMatch(t, + []string{ + "1b49b5a17161995cc288523637bd63af5bed99f4f7188effb702da8a7a4beee1", + "2023eee75bec75da61ad7644bd43f02b9397a72cf489565cb53a4337975a290b", + "24b31a6acc8cd13b119dd5aa81a6c3803250a8a79eb32231f16b09e0971f1b23", + "2664e267650ee22dee7d8c987b5cf44ba5596c78df3db5b99fb0ce79cc649d69", + }, + loadRecent(sts2, fromID, 4, 180)) + require.ElementsMatch(t, + []string{ + "2023eee75bec75da61ad7644bd43f02b9397a72cf489565cb53a4337975a290b", + "24b31a6acc8cd13b119dd5aa81a6c3803250a8a79eb32231f16b09e0971f1b23", + "2664e267650ee22dee7d8c987b5cf44ba5596c78df3db5b99fb0ce79cc649d69", + }, + loadRecent(sts2, + util.FromHex("2023eee75bec75da61ad7644bd43f02b9397a72cf489565cb53a4337975a290b"), + 4, 180)) + require.ElementsMatch(t, + []string{ + "2023eee75bec75da61ad7644bd43f02b9397a72cf489565cb53a4337975a290b", + "24b31a6acc8cd13b119dd5aa81a6c3803250a8a79eb32231f16b09e0971f1b23", + }, + loadRecent(sts2, + util.FromHex("2023eee75bec75da61ad7644bd43f02b9397a72cf489565cb53a4337975a290b"), + 2, 180)) + }) + + t.Run("filter", func(t *testing.T) { + initDB() + st := &sqlstore.SyncedTable{ + TableName: "atxs", + IDColumn: "id", + TimestampColumn: "received", + Filter: expr.MustParse("epoch = ?"), + Binder: func(stmt *sql.Statement) { + stmt.BindInt64(1, 2) + }, + } + + sts1, err := st.Snapshot(db) + require.NoError(t, err) + + require.Equal(t, 3, loadCount(sts1)) + require.ElementsMatch(t, + []string{ + "18040e78f834b879a9585fba90f6f5e7394dc3bb27f20829baf6bfc9e1bfe44b", + "1a9b743abdabe7970041ba2006c0e8bb51a27b1dbfd1a8c70ef5e7703ddeaa55", + "1b49b5a17161995cc288523637bd63af5bed99f4f7188effb702da8a7a4beee1", + }, + loadIDs(sts1)) + + fromID := util.FromHex("1a9b743abdabe7970041ba2006c0e8bb51a27b1dbfd1a8c70ef5e7703ddeaa55") + require.ElementsMatch(t, + []string{ + "1a9b743abdabe7970041ba2006c0e8bb51a27b1dbfd1a8c70ef5e7703ddeaa55", + "1b49b5a17161995cc288523637bd63af5bed99f4f7188effb702da8a7a4beee1", + }, + loadIDRange(sts1, fromID, 100)) + require.ElementsMatch(t, + []string{ + "1a9b743abdabe7970041ba2006c0e8bb51a27b1dbfd1a8c70ef5e7703ddeaa55", + }, loadIDRange(sts1, fromID, 1)) + require.ElementsMatch(t, + []string{ + "1b49b5a17161995cc288523637bd63af5bed99f4f7188effb702da8a7a4beee1", + }, + loadRecent(sts1, fromID, 1, 180)) + + insertRows([]row{ + {"2664e267650ee22dee7d8c987b5cf44ba5596c78df3db5b99fb0ce79cc649d69", 2, 300}, + }) + + // the new row is not included in the first snapshot + require.Equal(t, 3, loadCount(sts1)) + require.ElementsMatch(t, + []string{ + "18040e78f834b879a9585fba90f6f5e7394dc3bb27f20829baf6bfc9e1bfe44b", + "1a9b743abdabe7970041ba2006c0e8bb51a27b1dbfd1a8c70ef5e7703ddeaa55", + "1b49b5a17161995cc288523637bd63af5bed99f4f7188effb702da8a7a4beee1", + }, + loadIDs(sts1)) + 
require.ElementsMatch(t, + []string{ + "1a9b743abdabe7970041ba2006c0e8bb51a27b1dbfd1a8c70ef5e7703ddeaa55", + "1b49b5a17161995cc288523637bd63af5bed99f4f7188effb702da8a7a4beee1", + }, + loadIDRange(sts1, fromID, 100)) + require.ElementsMatch(t, + []string{ + "1b49b5a17161995cc288523637bd63af5bed99f4f7188effb702da8a7a4beee1", + }, + loadRecent(sts1, fromID, 1, 180)) + + sts2, err := st.Snapshot(db) + require.NoError(t, err) + + require.Equal(t, 4, loadCount(sts2)) + require.ElementsMatch(t, + []string{ + "18040e78f834b879a9585fba90f6f5e7394dc3bb27f20829baf6bfc9e1bfe44b", + "1a9b743abdabe7970041ba2006c0e8bb51a27b1dbfd1a8c70ef5e7703ddeaa55", + "1b49b5a17161995cc288523637bd63af5bed99f4f7188effb702da8a7a4beee1", + "2664e267650ee22dee7d8c987b5cf44ba5596c78df3db5b99fb0ce79cc649d69", + }, + loadIDs(sts2)) + require.ElementsMatch(t, + []string{ + "1a9b743abdabe7970041ba2006c0e8bb51a27b1dbfd1a8c70ef5e7703ddeaa55", + "1b49b5a17161995cc288523637bd63af5bed99f4f7188effb702da8a7a4beee1", + "2664e267650ee22dee7d8c987b5cf44ba5596c78df3db5b99fb0ce79cc649d69", + }, + loadIDRange(sts2, fromID, 100)) + require.ElementsMatch(t, + []string{ + "2664e267650ee22dee7d8c987b5cf44ba5596c78df3db5b99fb0ce79cc649d69", + }, + loadIDsSince(sts2, sts1)) + require.ElementsMatch(t, + []string{ + "1b49b5a17161995cc288523637bd63af5bed99f4f7188effb702da8a7a4beee1", + "2664e267650ee22dee7d8c987b5cf44ba5596c78df3db5b99fb0ce79cc649d69", + }, + loadRecent(sts2, fromID, 2, 180)) + require.ElementsMatch(t, + []string{ + "1b49b5a17161995cc288523637bd63af5bed99f4f7188effb702da8a7a4beee1", + }, + loadRecent(sts2, fromID, 1, 180)) + }) +} diff --git a/sync2/sqlstore/testdb.go b/sync2/sqlstore/testdb.go new file mode 100644 index 0000000000..5f5372e635 --- /dev/null +++ b/sync2/sqlstore/testdb.go @@ -0,0 +1,46 @@ +package sqlstore + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sync2/rangesync" +) + +// CreateDB creates a test database. It is only used for testing. +func CreateDB(t *testing.T, keyLen int) sql.Database { + db := sql.InMemoryTest(t) + _, err := db.Exec( + fmt.Sprintf("create table foo(id char(%d) not null primary key)", keyLen), nil, nil) + require.NoError(t, err) + return db +} + +// InsertDBItems inserts items into a test database. It is only used for testing. +func InsertDBItems(t *testing.T, db sql.Database, content []rangesync.KeyBytes) { + err := db.WithTx(context.Background(), func(tx sql.Transaction) error { + for _, id := range content { + _, err := tx.Exec( + "insert into foo(id) values(?)", + func(stmt *sql.Statement) { + stmt.BindBytes(1, id) + }, nil) + if err != nil { + return err + } + } + return nil + }) + require.NoError(t, err) +} + +// PopulateDB creates a test database and inserts items into it. It is only used for testing. +func PopulateDB(t *testing.T, keyLen int, content []rangesync.KeyBytes) sql.Database { + db := CreateDB(t, keyLen) + InsertDBItems(t, db, content) + return db +}