Skip to content

Commit

Permalink
Merge pull request #74 from oschwald/greg/custom-deserializer
Browse files Browse the repository at this point in the history
Support custom deserializer
  • Loading branch information
horgh authored Aug 20, 2020
2 parents a1069d8 + 1f1e288 commit 52f6238
Show file tree
Hide file tree
Showing 6 changed files with 294 additions and 8 deletions.
135 changes: 135 additions & 0 deletions decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,26 @@ func (d *decoder) decode(offset uint, result reflect.Value, depth int) (uint, er
return d.decodeFromType(typeNum, size, newOffset, result, depth+1)
}

// decodeToDeserializer decodes the value stored at offset and reports it to
// dser through the deserializer callbacks.
//
// The control data at offset is decoded first. dser.ShouldSkip is then asked
// about the value's offset; when it answers true the result of
// d.nextValueOffset(offset, 1) is returned and the value itself is not
// delivered. Otherwise decoding continues one nesting level deeper in
// decodeFromTypeToDeserializer.
//
// depth is the current nesting level. Going past maximumDataStructureDepth is
// reported as an invalid database error. On any error the returned offset
// is 0.
func (d *decoder) decodeToDeserializer(offset uint, dser deserializer, depth int) (uint, error) {
	if depth > maximumDataStructureDepth {
		return 0, newInvalidDatabaseError("exceeded maximum data structure depth; database is likely corrupt")
	}

	dtype, size, dataOffset, err := d.decodeCtrlData(offset)
	if err != nil {
		return 0, err
	}

	shouldSkip, err := dser.ShouldSkip(uintptr(offset))
	switch {
	case err != nil:
		return 0, err
	case shouldSkip:
		return d.nextValueOffset(offset, 1)
	default:
		return d.decodeFromTypeToDeserializer(dtype, size, dataOffset, dser, depth+1)
	}
}

func (d *decoder) decodeCtrlData(offset uint) (dataType, uint, uint, error) {
newOffset := offset + 1
if offset >= uint(len(d.buffer)) {
Expand Down Expand Up @@ -157,6 +177,68 @@ func (d *decoder) decodeFromType(
}
}

// decodeFromTypeToDeserializer delivers a single value, whose control data
// has already been decoded, to dser.
//
// dtype and size are the type and size taken from the control data, offset is
// the position passed on to the per-type decoders, and depth is the current
// nesting level (forwarded to nested map, slice and pointer decoding).
//
// It returns the offset following the value together with any error returned
// by the decoder or by the dser callback. Pointers are followed and the
// pointed-to value is delivered in place; the returned offset is then the one
// following the pointer itself, not the pointed-to data.
//
// The size from the control data is validated against what each type permits
// before the payload is read, so a corrupt record is reported as an invalid
// database error instead of being mis-decoded or silently truncated by the
// narrowing integer conversions below.
func (d *decoder) decodeFromTypeToDeserializer(
	dtype dataType,
	size uint,
	offset uint,
	dser deserializer,
	depth int,
) (uint, error) {
	// badSize builds the error for a control-data size that is impossible for
	// the named type. The wording matches the reflection-based decoder.
	badSize := func(typeName string) error {
		return newInvalidDatabaseError(
			"the MaxMind DB file's data section contains bad data (%s size of %v)",
			typeName,
			size,
		)
	}

	// For these types, size has a special meaning
	switch dtype {
	case _Bool:
		if size > 1 {
			return 0, badSize("bool")
		}
		v, offset := d.decodeBool(size, offset)
		return offset, dser.Bool(v)
	case _Map:
		return d.decodeMapToDeserializer(size, offset, dser, depth)
	case _Pointer:
		pointer, newOffset, err := d.decodePointer(size, offset)
		if err != nil {
			return 0, err
		}
		// Decode the pointed-to value, but resume after the pointer.
		_, err = d.decodeToDeserializer(pointer, dser, depth)
		return newOffset, err
	case _Slice:
		return d.decodeSliceToDeserializer(size, offset, dser, depth)
	}

	// For the remaining types, size is the byte size
	if offset+size > uint(len(d.buffer)) {
		return 0, newOffsetError()
	}
	switch dtype {
	case _Bytes:
		v, offset := d.decodeBytes(size, offset)
		return offset, dser.Bytes(v)
	case _Float32:
		if size != 4 {
			return 0, badSize("float32")
		}
		v, offset := d.decodeFloat32(size, offset)
		return offset, dser.Float32(v)
	case _Float64:
		if size != 8 {
			return 0, badSize("float64")
		}
		v, offset := d.decodeFloat64(size, offset)
		return offset, dser.Float64(v)
	case _Int32:
		if size > 4 {
			return 0, badSize("int32")
		}
		v, offset := d.decodeInt(size, offset)
		return offset, dser.Int32(int32(v))
	case _String:
		v, offset := d.decodeString(size, offset)
		return offset, dser.String(v)
	case _Uint16:
		if size > 2 {
			return 0, badSize("uint16")
		}
		v, offset := d.decodeUint(size, offset)
		return offset, dser.Uint16(uint16(v))
	case _Uint32:
		if size > 4 {
			return 0, badSize("uint32")
		}
		v, offset := d.decodeUint(size, offset)
		return offset, dser.Uint32(uint32(v))
	case _Uint64:
		if size > 8 {
			return 0, badSize("uint64")
		}
		v, offset := d.decodeUint(size, offset)
		return offset, dser.Uint64(v)
	case _Uint128:
		if size > 16 {
			return 0, badSize("uint128")
		}
		v, offset := d.decodeUint128(size, offset)
		return offset, dser.Uint128(v)
	default:
		return 0, newInvalidDatabaseError("unknown type: %d", dtype)
	}
}

func (d *decoder) unmarshalBool(size, offset uint, result reflect.Value) (uint, error) {
if size > 1 {
return 0, newInvalidDatabaseError("the MaxMind DB file's data section contains bad data (bool size of %v)", size)
Expand Down Expand Up @@ -199,6 +281,7 @@ func (d *decoder) indirect(result reflect.Value) reflect.Value {
if result.IsNil() {
result.Set(reflect.New(result.Type().Elem()))
}

result = result.Elem()
}
return result
Expand Down Expand Up @@ -486,6 +569,35 @@ func (d *decoder) decodeMap(
return offset, nil
}

// decodeMapToDeserializer streams a map with size entries, starting at
// offset, into dser.
//
// dser.StartMap(size) is called first. Every entry is then decoded as two
// consecutive values (the key followed by its value), each through
// decodeToDeserializer. dser.End is called after the last entry. The offset
// following the final value is returned; on any error the offset is 0.
func (d *decoder) decodeMapToDeserializer(
	size uint,
	offset uint,
	dser deserializer,
	depth int,
) (uint, error) {
	if err := dser.StartMap(size); err != nil {
		return 0, err
	}

	// TODO - implement key/value skipping?
	for remaining := size; remaining > 0; remaining-- {
		// One pass for the key, one for the value; both use the same path.
		for part := 0; part < 2; part++ {
			next, err := d.decodeToDeserializer(offset, dser, depth)
			if err != nil {
				return 0, err
			}
			offset = next
		}
	}

	if err := dser.End(); err != nil {
		return 0, err
	}
	return offset, nil
}

func (d *decoder) decodePointer(
size uint,
offset uint,
Expand Down Expand Up @@ -538,6 +650,29 @@ func (d *decoder) decodeSlice(
return offset, nil
}

// decodeSliceToDeserializer streams a slice with size elements, starting at
// offset, into dser.
//
// dser.StartSlice(size) is called first, each element is decoded in order
// through decodeToDeserializer, and dser.End is called once all elements
// have been delivered. The offset following the last element is returned; on
// any error the offset is 0.
func (d *decoder) decodeSliceToDeserializer(
	size uint,
	offset uint,
	dser deserializer,
	depth int,
) (uint, error) {
	if err := dser.StartSlice(size); err != nil {
		return 0, err
	}

	for remaining := size; remaining > 0; remaining-- {
		next, err := d.decodeToDeserializer(offset, dser, depth)
		if err != nil {
			return 0, err
		}
		offset = next
	}

	if err := dser.End(); err != nil {
		return 0, err
	}
	return offset, nil
}

func (d *decoder) decodeString(size, offset uint) (string, uint) {
newOffset := offset + size
return string(d.buffer[offset:newOffset]), newOffset
Expand Down
31 changes: 31 additions & 0 deletions deserializer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package maxminddb

import "math/big"

// deserializer is an interface for a type that deserializes a MaxMind DB
// data record to some other type. This exists as an alternative to the
// standard reflection API.
//
// This is fundamentally different from the Unmarshaler interface that
// several packages provide. A Deserializer will generally create the
// final struct or value rather than unmarshaling to itself.
//
// The decoder walks a record and reports what it finds through the methods
// below. A non-nil error returned from any of them stops decoding and is
// returned to the caller.
//
// This interface and the associated unmarshaling code is EXPERIMENTAL!
// It is not currently covered by any Semantic Versioning guarantees.
// Use at your own risk.
type deserializer interface {
// ShouldSkip is called with the offset of each value before the value is
// delivered. Returning true makes the decoder move on to the next value
// without reporting this one.
ShouldSkip(offset uintptr) (bool, error)
// StartSlice is called with the number of elements before the elements
// of a slice are delivered. End is called after the last element.
StartSlice(size uint) error
// StartMap is called with the number of entries before the entries of a
// map are delivered. Each entry arrives as two consecutive values, the
// key followed by its value. End is called after the last entry.
StartMap(size uint) error
// End is called once all values of the most recently started map or
// slice have been delivered.
End() error
// String receives a decoded string value.
String(string) error
// Float64 receives a decoded float64 value.
Float64(float64) error
// Bytes receives a decoded byte-slice value.
Bytes([]byte) error
// Uint16 receives a decoded uint16 value.
Uint16(uint16) error
// Uint32 receives a decoded uint32 value.
Uint32(uint32) error
// Int32 receives a decoded int32 value.
Int32(int32) error
// Uint64 receives a decoded uint64 value.
Uint64(uint64) error
// Uint128 receives a decoded 128-bit unsigned value as a *big.Int.
Uint128(*big.Int) error
// Bool receives a decoded boolean value.
Bool(bool) error
// Float32 receives a decoded float32 value.
Float32(float32) error
}
119 changes: 119 additions & 0 deletions deserializer.go deserializer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
package maxminddb

import (
"math/big"
"net"
"testing"

"github.com/stretchr/testify/require"
)

// TestDecodingToDeserializer looks up ::1.1.1.0 in the decoder test database
// through a testDeserializer and runs the value it rebuilt through
// checkDecodingToInterface.
func TestDecodingToDeserializer(t *testing.T) {
	db, err := Open(testFile("MaxMind-DB-test-decoder.mmdb"))
	require.NoError(t, err, "unexpected error while opening database: %v", err)

	var sink testDeserializer
	lookupErr := db.Lookup(net.ParseIP("::1.1.1.0"), &sink)
	require.NoError(t, lookupErr, "unexpected error while doing lookup: %v", lookupErr)

	checkDecodingToInterface(t, sink.rv)
}

// stackValue is one entry on testDeserializer's stack of containers that are
// still being filled.
type stackValue struct {
// value is the map[string]interface{} or []interface{} being populated.
value interface{}
// curNum is the index at which the next slice element will be stored;
// it is only advanced for slices.
curNum int
}

// testDeserializer is a deserializer implementation used by the tests. It
// rebuilds the decoded record as nested map[string]interface{} and
// []interface{} values.
type testDeserializer struct {
// stack holds the maps and slices that have been started but not yet
// ended, innermost last.
stack []*stackValue
// rv is the value received while no container was open, i.e. the root of
// the decoded record.
rv interface{}
// key is the pending map key: set when a key has been received for the
// innermost map and cleared once its value has been stored.
key *string
}

// ShouldSkip always reports false with a nil error, so no value is skipped;
// offset is ignored.
func (d *testDeserializer) ShouldSkip(offset uintptr) (bool, error) {
return false, nil
}

// StartSlice adds a new []interface{} of length size; add stores it in the
// current container (or as the root) and pushes it on the stack so that the
// following values become its elements.
func (d *testDeserializer) StartSlice(size uint) error {
return d.add(make([]interface{}, size))
}

// StartMap adds a new map[string]interface{}; add stores it in the current
// container (or as the root) and pushes it on the stack so that the
// following key/value pairs become its entries.
//
// The map is allocated with room for size entries, since the entry count is
// known up front.
func (d *testDeserializer) StartMap(size uint) error {
	return d.add(make(map[string]interface{}, size))
}

// End pops the innermost open container off the stack and always returns
// nil. It assumes a container is open: slicing an empty stack here would
// panic.
func (d *testDeserializer) End() error {
d.stack = d.stack[:len(d.stack)-1]
return nil
}

// String records v via add. Inside a map with no pending key, add treats the
// string as the key of the next entry.
func (d *testDeserializer) String(v string) error {
return d.add(v)
}

// Float64 records v unchanged via add.
func (d *testDeserializer) Float64(v float64) error {
return d.add(v)
}

// Bytes records v via add; the slice is stored as given, without copying.
func (d *testDeserializer) Bytes(v []byte) error {
return d.add(v)
}

// Uint16 records v widened to uint64 via add.
// NOTE(review): presumably widened so the result matches the types that
// checkDecodingToInterface expects — confirm against that helper.
func (d *testDeserializer) Uint16(v uint16) error {
return d.add(uint64(v))
}

// Uint32 records v widened to uint64 via add.
// NOTE(review): presumably widened so the result matches the types that
// checkDecodingToInterface expects — confirm against that helper.
func (d *testDeserializer) Uint32(v uint32) error {
return d.add(uint64(v))
}

// Int32 records v converted to int via add.
// NOTE(review): presumably converted so the result matches the types that
// checkDecodingToInterface expects — confirm against that helper.
func (d *testDeserializer) Int32(v int32) error {
return d.add(int(v))
}

// Uint64 records v unchanged via add.
func (d *testDeserializer) Uint64(v uint64) error {
return d.add(v)
}

// Uint128 records the *big.Int pointer v via add, without copying it.
func (d *testDeserializer) Uint128(v *big.Int) error {
return d.add(v)
}

// Bool records v unchanged via add.
func (d *testDeserializer) Bool(v bool) error {
return d.add(v)
}

// Float32 records v unchanged via add.
func (d *testDeserializer) Float32(v float32) error {
return d.add(v)
}

// add places v into the structure being rebuilt and always returns nil.
//
// With no container open, v becomes the root value. Otherwise v goes into the
// innermost open container: a slice stores it at its next index, while a map
// alternates between remembering v as the pending key (v must be a string,
// otherwise the type assertion panics) and storing v under that key.
//
// If v is itself a map or slice it is additionally pushed on the stack, so
// that subsequent values are added to it until End pops it again.
func (d *testDeserializer) add(v interface{}) error {
	if open := len(d.stack); open == 0 {
		d.rv = v
	} else {
		frame := d.stack[open-1]
		switch container := frame.value.(type) {
		case []interface{}:
			container[frame.curNum] = v
			frame.curNum++
		case map[string]interface{}:
			if d.key != nil {
				container[*d.key] = v
				d.key = nil
			} else {
				name := v.(string)
				d.key = &name
			}
		}
	}

	switch v.(type) {
	case map[string]interface{}, []interface{}:
		d.stack = append(d.stack, &stackValue{value: v})
	}

	return nil
}
8 changes: 0 additions & 8 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,11 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.5.0 h1:DMOzIV76tmoDNE9pX6RSN0aDtCYeCg5VueieJaAo1uw=
github.com/stretchr/testify v1.5.0/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4=
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
golang.org/x/sys v0.0.0-20191224085550-c709ea063b76 h1:Dho5nD6R3PcW2SH1or8vS0dszDaXRxIw55lBX7XiE5g=
golang.org/x/sys v0.0.0-20191224085550-c709ea063b76/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
5 changes: 5 additions & 0 deletions reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,11 @@ func (r *Reader) decode(offset uintptr, result interface{}) error {
return errors.New("result param must be a pointer")
}

if dser, ok := result.(deserializer); ok {
_, err := r.decoder.decodeToDeserializer(uint(offset), dser, 0)
return err
}

_, err := r.decoder.decode(uint(offset), rv, 0)
return err
}
Expand Down
4 changes: 4 additions & 0 deletions reader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,10 @@ func TestDecodingToInterface(t *testing.T) {
err = reader.Lookup(net.ParseIP("::1.1.1.0"), &recordInterface)
require.NoError(t, err, "unexpected error while doing lookup: %v", err)

checkDecodingToInterface(t, recordInterface)
}

func checkDecodingToInterface(t *testing.T, recordInterface interface{}) {
record := recordInterface.(map[string]interface{})
assert.Equal(t, []interface{}{uint64(1), uint64(2), uint64(3)}, record["array"])
assert.Equal(t, true, record["boolean"])
Expand Down

0 comments on commit 52f6238

Please sign in to comment.