Skip to content

Commit

Permalink
builder where & update support Raw SQL (#147)
Browse files Browse the repository at this point in the history
* where & update support Raw SQL

* fix

* fix

* builder where & update support Raw SQL

* builder where & update support Raw SQL

* add unittests
  • Loading branch information
tuweizhong authored Sep 24, 2023
1 parent da75d5e commit 19c2de2
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 39 deletions.
14 changes: 14 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,20 @@ cond, values, err := builder.BuildSelect(table, where, selectFields)
//values = []interface{}{11, 45, "234", "tx2", 5, "beijing", "shanghai", 35}

rows, err := db.Query(cond, values...)


// support builder.Raw in where & update
where := map[string]interface{}{"gmt_create <": builder.Raw("gmt_modified")}
cond, values, err := builder.BuildSelect(table, where, selectFields)
// SELECT * FROM x WHERE gmt_create < gmt_modified

update = map[string]interface{}{
"code": builder.Raw("VALUES(code)"), // mysql 8.x builder.Raw("new.code")
"name": builder.Raw("VALUES(name)"), // mysql 8.x builder.Raw("new.name")
}
cond, values, err := builder.BuildInsertOnDuplicate(table, data, update)
// INSERT INTO country (id, code, name) VALUES (?,?,?),(?,?,?),(?,?,?)
// ON DUPLICATE KEY UPDATE code=VALUES(code),name=VALUES(name)
```

In the `where` param, `in` operator is automatically added by value type(reflect.Slice).
Expand Down
13 changes: 12 additions & 1 deletion builder/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ func main() {
"country": "China",
"role": "driver",
"age >": 45,
"gmt_create <": builder.Raw("gmt_modified"),
"_or": []map[string]interface{}{
{
"x1": 11,
Expand All @@ -45,7 +46,7 @@ func main() {
}
cond,vals,err := qb.BuildSelect("tableName", where, []string{"name", "count(price) as total", "age"})

//cond: SELECT name,count(price) as total,age FROM tableName WHERE (((x1=? AND x2>=?) OR (x3=? AND x4!=?)) AND country=? AND role=? AND age>?) GROUP BY name HAVING (total>? AND total<=?) ORDER BY age DESC
//cond: SELECT name,count(price) as total,age FROM tableName WHERE (((x1=? AND x2>=?) OR (x3=? AND x4!=?)) AND country=? AND gmt_create < gmt_modified AND role=? AND age>?) GROUP BY name HAVING (total>? AND total<=?) ORDER BY age DESC
//vals: []interface{}{11, 45, "234", "tx2", "China", "driver", 45, 1000, 50000}

if nil != err {
Expand Down Expand Up @@ -287,6 +288,16 @@ update := map[string]interface{}{
}
cond, vals, err := qb.BuildInsertOnDuplicate(table, data, update)
db.Exec(cond, vals...)


// update support builder.Raw to update when duplicate with value in insert data
update = map[string]interface{}{
"code": builder.Raw("VALUES(code)"), // mysql 8.x builder.Raw("new.code")
"name": builder.Raw("VALUES(name)"), // mysql 8.x builder.Raw("new.name")
}
cond, values, err := builder.BuildInsertOnDuplicate(table, data, update)
// INSERT INTO country (id, code, name) VALUES (?,?,?),(?,?,?),(?,?,?)
// ON DUPLICATE KEY UPDATE code=VALUES(code),name=VALUES(name)
```

#### `NamedQuery`
Expand Down
2 changes: 2 additions & 0 deletions builder/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ var (
}
)

type Raw string

type whereMapSet struct {
set map[string]map[string]interface{}
}
Expand Down
2 changes: 1 addition & 1 deletion builder/builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,7 @@ func Test_BuildSelect(t *testing.T) {
}
}

func Test_BuildSelectMutliOr(t *testing.T) {
func Test_BuildSelectMultiOr(t *testing.T) {
type inStruct struct {
table string
where map[string]interface{}
Expand Down
40 changes: 22 additions & 18 deletions builder/dao.go
Original file line number Diff line number Diff line change
Expand Up @@ -302,15 +302,20 @@ func build(m map[string]interface{}, op string) ([]string, []interface{}) {
}
length := len(m)
cond := make([]string, length)
vals := make([]interface{}, length)
vals := make([]interface{}, 0, length)
var i int
for key := range m {
cond[i] = key
i++
}
defaultSortAlgorithm(cond)
for i = 0; i < length; i++ {
vals[i] = m[cond[i]]
v := m[cond[i]]
if raw, ok := v.(Raw); ok {
cond[i] += op + string(raw)
continue
}
vals = append(vals, v)
cond[i] = assembleExpression(cond[i], op)
}
return cond, vals
Expand All @@ -320,17 +325,6 @@ func assembleExpression(field, op string) string {
return quoteField(field) + op + "?"
}

func resolveKV(m map[string]interface{}) (keys []string, vals []interface{}) {
for k := range m {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
vals = append(vals, m[k])
}
return
}

func resolveFields(m map[string]interface{}) []string {
var fields []string
for k := range m {
Expand Down Expand Up @@ -409,13 +403,23 @@ func buildInsertOnDuplicate(table string, data []map[string]interface{}, update
return cond, vals, nil
}

func resolveUpdate(update map[string]interface{}) (string, []interface{}) {
keys, vals := resolveKV(update)
var sets string
func resolveUpdate(update map[string]interface{}) (sets string, vals []interface{}) {
keys := make([]string, 0, len(update))
for key := range update {
keys = append(keys, key)
}
defaultSortAlgorithm(keys)
var sb strings.Builder
for _, k := range keys {
sets += fmt.Sprintf("%s=?,", quoteField(k))
v := update[k]
if _, ok := v.(Raw); ok {
sb.WriteString(fmt.Sprintf("%s=%s,", k, v))
continue
}
vals = append(vals, v)
sb.WriteString(fmt.Sprintf("%s=?,", quoteField(k)))
}
sets = strings.TrimRight(sets, ",")
sets = strings.TrimRight(sb.String(), ",")
return sets, vals
}

Expand Down
67 changes: 48 additions & 19 deletions builder/dao_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,18 @@ func TestEq(t *testing.T) {
[]string{"baz=?", "foo=?", "qq=?"},
[]interface{}{1, "bar", "ttx"},
},
{
map[string]interface{}{
"gmt_create": Raw("gmt_modified"),
"status": 1,
},
[]string{"gmt_create=gmt_modified", "status=?"},
[]interface{}{1},
},
}
ass := assert.New(t)
for _, testCase := range testData {
cond, vals := Eq(testCase.in).Build()
ass.Equal(len(cond), len(vals))
ass.Equal(testCase.outCon, cond)
ass.Equal(testCase.outVal, vals)
}
Expand Down Expand Up @@ -111,33 +118,51 @@ func TestAssembleExpression(t *testing.T) {
}
}

func TestResolveKV(t *testing.T) {
func TestResolveUpdate(t *testing.T) {
var data = []struct {
in map[string]interface{}
outStr []string
outStr string
outVals []interface{}
}{
{
map[string]interface{}{
in: map[string]interface{}{
"foo": "bar",
"bar": 1,
},
[]string{"bar", "foo"},
[]interface{}{1, "bar"},
outStr: "bar=?,foo=?",
outVals: []interface{}{1, "bar"},
},
{
map[string]interface{}{
in: map[string]interface{}{
"qq": "ttt",
"some": 123,
"other": 456,
},
[]string{"other", "qq", "some"},
[]interface{}{456, "ttt", 123},
outStr: "other=?,qq=?,some=?",
outVals: []interface{}{456, "ttt", 123},
},
{
in: map[string]interface{}{ // mysql5.7
"id": 1,
"name": Raw("VALUES(name)"),
"age": Raw("VALUES(age)"),
},
outStr: "age=VALUES(age),id=?,name=VALUES(name)",
outVals: []interface{}{1},
},
{
in: map[string]interface{}{ // mysql8.0
"id": 1,
"name": Raw("new.name"),
"age": Raw("new.age"),
},
outStr: "age=new.age,id=?,name=new.name",
outVals: []interface{}{1},
},
}
ass := assert.New(t)
for _, tc := range data {
keys, vals := resolveKV(tc.in)
keys, vals := resolveUpdate(tc.in)
ass.Equal(tc.outStr, keys)
ass.Equal(tc.outVals, vals)
}
Expand Down Expand Up @@ -281,11 +306,12 @@ func TestBuildInsertOnDuplicate(t *testing.T) {
},
},
update: map[string]interface{}{
"a": Raw("VALUES(a)"),
"b": 7,
"c": 8,
},
outErr: nil,
outStr: "INSERT INTO tb (a,b,c) VALUES (?,?,?),(?,?,?) ON DUPLICATE KEY UPDATE b=?,c=?",
outStr: "INSERT INTO tb (a,b,c) VALUES (?,?,?),(?,?,?) ON DUPLICATE KEY UPDATE a=VALUES(a),b=?,c=?",
outVals: []interface{}{1, 2, 3, 4, 5, 6, 7, 8},
},
}
Expand Down Expand Up @@ -316,11 +342,12 @@ func TestBuildUpdate(t *testing.T) {
}),
},
data: map[string]interface{}{
"name": "deen",
"age": 23,
"name": "deen",
"age": 23,
"count": Raw("count+1"),
},
outErr: nil,
outStr: "UPDATE tb SET age=?,name=? WHERE (foo=? AND qq=?)",
outStr: "UPDATE tb SET age=?,count=count+1,name=? WHERE (foo=? AND qq=?)",
outVals: []interface{}{23, "deen", "bar", 1},
},
}
Expand All @@ -345,12 +372,13 @@ func TestBuildDelete(t *testing.T) {
table: "tb",
where: []Comparable{
Eq(map[string]interface{}{
"foo": 1,
"bar": 2,
"baz": "tt",
"foo": 1,
"bar": 2,
"baz": "tt",
"gmt_create": Raw("gmt_modified"),
}),
},
outStr: "DELETE FROM tb WHERE (bar=? AND baz=? AND foo=?)",
outStr: "DELETE FROM tb WHERE (bar=? AND baz=? AND foo=? AND gmt_create=gmt_modified)",
outVals: []interface{}{2, "tt", 1},
outErr: nil,
},
Expand Down Expand Up @@ -384,6 +412,7 @@ func TestBuildSelect(t *testing.T) {
Eq(map[string]interface{}{
"foo": 1,
"bar": 2,
"tm": Raw("NOW()"),
}),
In(map[string][]interface{}{
"qq": {4, 5, 6},
Expand Down Expand Up @@ -415,7 +444,7 @@ func TestBuildSelect(t *testing.T) {
},
lockMode: "exclusive",
outErr: nil,
outStr: "SELECT foo,bar FROM tb WHERE (bar=? AND foo=? AND qq IN (?,?,?) AND ((aa=? AND bb=?) OR (cc=? AND dd=?))) ORDER BY foo DESC,baz ASC LIMIT ?,? FOR UPDATE",
outStr: "SELECT foo,bar FROM tb WHERE (bar=? AND foo=? AND tm=NOW() AND qq IN (?,?,?) AND ((aa=? AND bb=?) OR (cc=? AND dd=?))) ORDER BY foo DESC,baz ASC LIMIT ?,? FOR UPDATE",
outVals: []interface{}{2, 1, 4, 5, 6, 3, 4, 7, 8, 10, 20},
},
}
Expand Down

0 comments on commit 19c2de2

Please sign in to comment.