-
Notifications
You must be signed in to change notification settings - Fork 8
/
scan.go
248 lines (195 loc) · 5.66 KB
/
scan.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
package crud
import (
"fmt"
"time"
"reflect"
"database/sql"
)
/*
Scan uses tag metadata and column names to extract values from a sql.Rows into one or more objects.

Scan inspects all of the passed arguments and creates a mapping from SQL column
name to the fields of the passed objects by inspecting the struct tag metadata.
It then constructs an appropriate call to rows.Scan, passing in pointers as the
mapping dictates.

Each non-string argument should be a pointer to a struct, because Scan takes
the address of (or assigns to) its fields through reflection.

Any string passed in the arguments list is considered a "prefix" for the SQL
names of each field in the object that follows it. The prefix is cleared again
once that object has been processed.

Fields are bound as follows:

  - A field whose metadata has Unix set is read as an integer number of
    seconds and converted with time.Unix. The field must be a time.Time or a
    *time.Time (a nil *time.Time is allocated first); for any other type an
    error is returned when a non-NULL value is read.
  - A pointer field whose element kind is int8, int16, int32, int64, float32,
    float64, bool or string is read through the matching sql.Null* type. A
    non-NULL value is stored in a freshly allocated element; a NULL value
    leaves the field untouched.
  - Every other field is handed to rows.Scan by address.

If two objects have fields that map to the same column name, only the field of
the later argument is assigned. If two columns have the same SQL name, the same
destination is passed to rows.Scan for both columns (and which value ends up
bound is undefined). If there is a SQL column which does not map to a Go field
(or vice versa), it is ignored silently.

Scan returns the error from the field-mapping step or from rows.Scan, or an
error describing a unix-time field of the wrong type.
*/
func Scan(rows *sql.Rows, args ...interface{}) error {
	timeType := reflect.TypeOf(time.Time{})

	prefix := ""
	// Column name -> destination handed to rows.Scan.
	writeBackMap := make(map[string]interface{})
	// Struct field -> temporary nullable holder that is copied back after the scan.
	intRemap := make(map[reflect.Value]*sql.NullInt64)
	floatRemap := make(map[reflect.Value]*sql.NullFloat64)
	boolRemap := make(map[reflect.Value]*sql.NullBool)
	stringRemap := make(map[reflect.Value]*sql.NullString)
	unixTimeRemap := make(map[reflect.Value]*sql.NullInt64)

	for _, arg := range args {
		val := indirectV(reflect.ValueOf(arg))
		ty := val.Type()

		if ty.Kind() == reflect.String {
			// Read the prefix via reflection rather than arg.(string) so that a
			// string-kind argument of a named or indirected type cannot panic.
			prefix = val.String()
			continue
		}

		fieldMap, er := sqlToGoFields(ty)
		if er != nil {
			return er
		}

		for sqlName, meta := range fieldMap {
			sqlName = prefix + sqlName
			field := val.FieldByName(meta.GoName)
			fieldType := field.Type()

			if meta.Unix {
				nullInt := new(sql.NullInt64)
				writeBackMap[sqlName] = nullInt
				unixTimeRemap[field] = nullInt
				continue
			}

			if fieldType.Kind() == reflect.Ptr {
				switch fieldType.Elem().Kind() {
				case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
					nullInt := new(sql.NullInt64)
					writeBackMap[sqlName] = nullInt
					intRemap[field] = nullInt
					continue
				case reflect.Float32, reflect.Float64:
					nullFloat := new(sql.NullFloat64)
					writeBackMap[sqlName] = nullFloat
					floatRemap[field] = nullFloat
					continue
				case reflect.Bool:
					nullBool := new(sql.NullBool)
					writeBackMap[sqlName] = nullBool
					boolRemap[field] = nullBool
					continue
				case reflect.String:
					nullString := new(sql.NullString)
					writeBackMap[sqlName] = nullString
					stringRemap[field] = nullString
					continue
				}
			}

			// Everything else is scanned straight into the field.
			writeBackMap[sqlName] = field.Addr().Interface()
		}

		// A prefix only applies to the single object that follows it.
		prefix = ""
	}

	cols, er := rows.Columns()
	if er != nil {
		return er
	}

	writeBack := make([]interface{}, len(cols))
	for i, col := range cols {
		if target, ok := writeBackMap[col]; ok {
			writeBack[i] = target
		} else {
			// Unmapped column: scan into a throwaway destination.
			writeBack[i] = new(interface{})
		}
	}

	if er := rows.Scan(writeBack...); er != nil {
		// NOTE(review): this prints diagnostics to stdout from library code;
		// kept as-is so the returned error value stays unchanged for callers.
		fmt.Printf("Error encountered, columns: %#v\n", cols)
		return er
	}

	// Copy the nullable holders back into their fields. The element is
	// allocated from the field's own element type so that pointers to named
	// types (e.g. *MyInt) are assignable; SetInt/SetFloat truncate to the
	// element's width exactly like an explicit conversion would.
	for field, nullInt := range intRemap {
		if !nullInt.Valid {
			continue
		}
		ptr := reflect.New(field.Type().Elem())
		ptr.Elem().SetInt(nullInt.Int64)
		field.Set(ptr)
	}

	for field, nullFloat := range floatRemap {
		if !nullFloat.Valid {
			continue
		}
		ptr := reflect.New(field.Type().Elem())
		ptr.Elem().SetFloat(nullFloat.Float64)
		field.Set(ptr)
	}

	for field, nullBool := range boolRemap {
		if !nullBool.Valid {
			continue
		}
		ptr := reflect.New(field.Type().Elem())
		ptr.Elem().SetBool(nullBool.Bool)
		field.Set(ptr)
	}

	for field, nullString := range stringRemap {
		if !nullString.Valid {
			continue
		}
		ptr := reflect.New(field.Type().Elem())
		ptr.Elem().SetString(nullString.String)
		field.Set(ptr)
	}

	for field, nullInt := range unixTimeRemap {
		if !nullInt.Valid {
			continue
		}
		t := time.Unix(nullInt.Int64, 0)

		// For a *time.Time field, make sure it points at something and then
		// write through the pointer.
		if field.Kind() == reflect.Ptr && field.Type().Elem() == timeType {
			if field.IsNil() {
				field.Set(reflect.New(timeType))
			}
			field = field.Elem()
		}

		if field.Type() != timeType {
			return fmt.Errorf("Cannot map a unix time to a non-time field (%T)", field.Interface())
		}
		field.Set(reflect.ValueOf(t))
	}

	return nil
}
/*
ScanAll accepts a pointer to a slice of a type and fills it with repeated calls to Scan.

ScanAll only works if you're trying to extract a single object from each row
of the query results, and the slice's element type must be a struct.
Additionally, it closes the passed sql.Rows object, whether or not an error
occurs. ScanAll effectively replaces this code

	// old code
	defer rows.Close()
	objs := []Object{}
	for rows.Next() {
		var obj Object
		Scan(rows, &obj)
		objs = append(objs, obj)
	}

With simply

	// new code
	objs := []Object{}
	ScanAll(rows, &objs)

ScanAll returns an error if slicePtr is not a non-nil pointer to a slice of
structs, the first error returned by Scan, or the error reported by rows.Err
once iteration stops. Rows appended before an error remain in the slice.
*/
func ScanAll(rows *sql.Rows, slicePtr interface{}) error {
	defer rows.Close()

	// Validate the pointer up front: calling Elem on a non-pointer or nil
	// value would panic instead of reporting a usable error.
	ptrVal := reflect.ValueOf(slicePtr)
	if ptrVal.Kind() != reflect.Ptr || ptrVal.IsNil() {
		return fmt.Errorf("Argument to crud.ScanAll must be a non-nil pointer to a slice")
	}

	sliceVal := ptrVal.Elem()
	if sliceVal.Kind() != reflect.Slice {
		return fmt.Errorf("Argument to crud.ScanAll is not a slice")
	}

	elemType := sliceVal.Type().Elem()
	if elemType.Kind() != reflect.Struct {
		return fmt.Errorf("Argument to crud.ScanAll must be a slice of structs")
	}

	for rows.Next() {
		newVal := reflect.New(elemType)
		if er := Scan(rows, newVal.Interface()); er != nil {
			return er
		}
		sliceVal.Set(reflect.Append(sliceVal, newVal.Elem()))
	}

	// rows.Next returns false both at the end of the result set and on
	// failure; rows.Err distinguishes the two.
	return rows.Err()
}