Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add InterfaceCustomUnmarshaler and test case to override interface type #450

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 77 additions & 32 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,42 +24,44 @@ import (

// Decoder reads and decodes YAML values from an input stream.
type Decoder struct {
reader io.Reader
referenceReaders []io.Reader
anchorNodeMap map[string]ast.Node
anchorValueMap map[string]reflect.Value
customUnmarshalerMap map[reflect.Type]func(interface{}, []byte) error
toCommentMap CommentMap
opts []DecodeOption
referenceFiles []string
referenceDirs []string
isRecursiveDir bool
isResolvedReference bool
validator StructValidator
disallowUnknownField bool
disallowDuplicateKey bool
useOrderedMap bool
useJSONUnmarshaler bool
parsedFile *ast.File
streamIndex int
reader io.Reader
referenceReaders []io.Reader
anchorNodeMap map[string]ast.Node
anchorValueMap map[string]reflect.Value
customUnmarshalerMap map[reflect.Type]func(interface{}, []byte) error
customInterfaceUnmarshalerMap map[reflect.Type]func(interface{}, func(interface{}) error) error
toCommentMap CommentMap
opts []DecodeOption
referenceFiles []string
referenceDirs []string
isRecursiveDir bool
isResolvedReference bool
validator StructValidator
disallowUnknownField bool
disallowDuplicateKey bool
useOrderedMap bool
useJSONUnmarshaler bool
parsedFile *ast.File
streamIndex int
}

// NewDecoder returns a new decoder that reads from r.
func NewDecoder(r io.Reader, opts ...DecodeOption) *Decoder {
return &Decoder{
reader: r,
anchorNodeMap: map[string]ast.Node{},
anchorValueMap: map[string]reflect.Value{},
customUnmarshalerMap: map[reflect.Type]func(interface{}, []byte) error{},
opts: opts,
referenceReaders: []io.Reader{},
referenceFiles: []string{},
referenceDirs: []string{},
isRecursiveDir: false,
isResolvedReference: false,
disallowUnknownField: false,
disallowDuplicateKey: false,
useOrderedMap: false,
reader: r,
anchorNodeMap: map[string]ast.Node{},
anchorValueMap: map[string]reflect.Value{},
customUnmarshalerMap: map[reflect.Type]func(interface{}, []byte) error{},
customInterfaceUnmarshalerMap: map[reflect.Type]func(interface{}, func(interface{}) error) error{},
opts: opts,
referenceReaders: []io.Reader{},
referenceFiles: []string{},
referenceDirs: []string{},
isRecursiveDir: false,
isResolvedReference: false,
disallowUnknownField: false,
disallowDuplicateKey: false,
useOrderedMap: false,
}
}

Expand Down Expand Up @@ -656,7 +658,6 @@ func (d *Decoder) unmarshalerFromCustomUnmarshalerMap(t reflect.Type) (func(inte
if unmarshaler, exists := d.customUnmarshalerMap[t]; exists {
return unmarshaler, exists
}

globalCustomUnmarshalerMu.Lock()
defer globalCustomUnmarshalerMu.Unlock()
if unmarshaler, exists := globalCustomUnmarshalerMap[t]; exists {
Expand All @@ -665,11 +666,40 @@ func (d *Decoder) unmarshalerFromCustomUnmarshalerMap(t reflect.Type) (func(inte
return nil, false
}

func (d *Decoder) existsTypeInCustomInterfaceUnmarshalerMap(t reflect.Type) bool {
if _, exists := d.customInterfaceUnmarshalerMap[t]; exists {
return true
}

globalCustomInterfaceUnmarshalerMu.Lock()
defer globalCustomInterfaceUnmarshalerMu.Unlock()
if _, exists := globalCustomInterfaceUnmarshalerMap[t]; exists {
return true
}
return false
}

func (d *Decoder) unmarshalerFromCustomInterfaceUnmarshalerMap(t reflect.Type) (func(interface{}, func(interface{}) error) error, bool) {
if unmarshaler, exists := d.customInterfaceUnmarshalerMap[t]; exists {
return unmarshaler, exists
}

globalCustomInterfaceUnmarshalerMu.Lock()
defer globalCustomInterfaceUnmarshalerMu.Unlock()
if unmarshaler, exists := globalCustomInterfaceUnmarshalerMap[t]; exists {
return unmarshaler, exists
}
return nil, false
}

func (d *Decoder) canDecodeByUnmarshaler(dst reflect.Value) bool {
ptrValue := dst.Addr()
if d.existsTypeInCustomUnmarshalerMap(ptrValue.Type()) {
return true
}
if d.existsTypeInCustomInterfaceUnmarshalerMap(ptrValue.Type()) {
return true
}
iface := ptrValue.Interface()
switch iface.(type) {
case BytesUnmarshalerContext:
Expand Down Expand Up @@ -704,6 +734,21 @@ func (d *Decoder) decodeByUnmarshaler(ctx context.Context, dst reflect.Value, sr
}
return nil
}
if unmarshaler, exists := d.unmarshalerFromCustomInterfaceUnmarshalerMap(ptrValue.Type()); exists {
if err := unmarshaler(ptrValue.Interface(), func(v interface{}) error {
rv := reflect.ValueOf(v)
if rv.Type().Kind() != reflect.Ptr {
return errors.ErrDecodeRequiredPointerType
}
if err := d.decodeValue(ctx, rv.Elem(), src); err != nil {
return errors.Wrapf(err, "failed to decode value")
}
return nil
}); err != nil {
return errors.Wrapf(err, "failed to UnmarshalYAML")
}
return nil
}
iface := ptrValue.Interface()

if unmarshaler, ok := iface.(BytesUnmarshalerContext); ok {
Expand Down
107 changes: 107 additions & 0 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1940,6 +1940,113 @@ func TestDecoder_CustomUnmarshaler(t *testing.T) {
t.Fatalf("failed to switch to custom unmarshaler. got: %q", v.Foo)
}
})

t.Run("override interface type", func(t *testing.T) {
type I interface{}
type T struct {
Foo string `yaml:"foo"`
}
var i I
src := []byte(`foo: "bar"`)
if err := yaml.UnmarshalWithOptions(src, &i, yaml.CustomUnmarshaler[I](func(dst *I, b []byte) error {
var v T
if err := yaml.Unmarshal(b, &v); err != nil {
t.Fatal(err)
}
if v.Foo != "bar" {
t.Fatalf("failed to use unmarshal function. got %q", v.Foo)
}
*dst = &v
return nil
})); err != nil {
t.Fatal(err)
}
if v, ok := i.(*T); ok {
if v.Foo != "bar" {
t.Fatalf("failed to decode with custom interface unmarshaler. got: %q", v.Foo)
}
} else {
t.Fatalf("failed to switch to custom interface unmarshaler.")
}
})
}

func TestDecoder_CustomInterfaceUnmarshaler(t *testing.T) {
t.Run("override struct type", func(t *testing.T) {
type T struct {
Foo string `yaml:"foo"`
}
src := []byte(`foo: "bar"`)
var v T
if err := yaml.UnmarshalWithOptions(src, &v, yaml.CustomInterfaceUnmarshaler[T](func(dst *T, f func(interface{}) error) error {
var m map[string]string
if err := f(&m); err != nil {
t.Fatal(err)
}
if m["foo"] != "bar" {
t.Fatalf("failed to use unmarshal function. got %q", m["foo"])
}
dst.Foo = "bazbaz" // assign another value to target
return nil
})); err != nil {
t.Fatal(err)
}
if v.Foo != "bazbaz" {
t.Fatalf("failed to switch to custom interface unmarshaler. got: %v", v.Foo)
}
})
t.Run("override bytes type", func(t *testing.T) {
type T struct {
Foo []byte `yaml:"foo"`
}
src := []byte(`foo: "bar"`)
var v T
if err := yaml.UnmarshalWithOptions(src, &v, yaml.CustomInterfaceUnmarshaler[[]byte](func(dst *[]byte, f func(interface{}) error) error {
var str string
if err := f(&str); err != nil {
t.Fatal(err)
}
if str != "bar" {
t.Fatalf("failed to use unmarshal function. got %q", str)
}
*dst = []byte("bazbaz")
return nil
})); err != nil {
t.Fatal(err)
}
if !bytes.Equal(v.Foo, []byte("bazbaz")) {
t.Fatalf("failed to switch to custom interface unmarshaler. got: %q", v.Foo)
}
})

t.Run("override interface type", func(t *testing.T) {
type I interface{}
type T struct {
Foo string `yaml:"foo"`
}
var i I
src := []byte(`foo: "bar"`)
if err := yaml.UnmarshalWithOptions(src, &i, yaml.CustomInterfaceUnmarshaler[I](func(dst *I, f func(interface{}) error) error {
var v T
if err := f(&v); err != nil {
t.Fatal(err)
}
if v.Foo != "bar" {
t.Fatalf("failed to use unmarshal function. got %q", v.Foo)
}
*dst = &v
return nil
})); err != nil {
t.Fatal(err)
}
if v, ok := i.(*T); ok {
if v.Foo != "bar" {
t.Fatalf("failed to decode with custom interface unmarshaler. got: %q", v.Foo)
}
} else {
t.Fatalf("failed to switch to custom interface unmarshaler.")
}
})
}

type unmarshalContext struct {
Expand Down
82 changes: 60 additions & 22 deletions encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,21 @@ const (

// Encoder writes YAML values to an output stream.
type Encoder struct {
writer io.Writer
opts []EncodeOption
indent int
indentSequence bool
singleQuote bool
isFlowStyle bool
isJSONStyle bool
useJSONMarshaler bool
anchorCallback func(*ast.AnchorNode, interface{}) error
anchorPtrToNameMap map[uintptr]string
customMarshalerMap map[reflect.Type]func(interface{}) ([]byte, error)
useLiteralStyleIfMultiline bool
commentMap map[*Path][]*Comment
written bool
writer io.Writer
opts []EncodeOption
indent int
indentSequence bool
singleQuote bool
isFlowStyle bool
isJSONStyle bool
useJSONMarshaler bool
anchorCallback func(*ast.AnchorNode, interface{}) error
anchorPtrToNameMap map[uintptr]string
customMarshalerMap map[reflect.Type]func(interface{}) ([]byte, error)
customInterfaceMarshalerMap map[reflect.Type]func(interface{}) (interface{}, error)
useLiteralStyleIfMultiline bool
commentMap map[*Path][]*Comment
written bool

line int
column int
Expand All @@ -53,14 +54,15 @@ type Encoder struct {
// The Encoder should be closed after use to flush all data to w.
func NewEncoder(w io.Writer, opts ...EncodeOption) *Encoder {
return &Encoder{
writer: w,
opts: opts,
indent: DefaultIndentSpaces,
anchorPtrToNameMap: map[uintptr]string{},
customMarshalerMap: map[reflect.Type]func(interface{}) ([]byte, error){},
line: 1,
column: 1,
offset: 0,
writer: w,
opts: opts,
indent: DefaultIndentSpaces,
anchorPtrToNameMap: map[uintptr]string{},
customMarshalerMap: map[reflect.Type]func(interface{}) ([]byte, error){},
customInterfaceMarshalerMap: map[reflect.Type]func(interface{}) (interface{}, error){},
line: 1,
column: 1,
offset: 0,
}
}

Expand Down Expand Up @@ -301,13 +303,42 @@ func (e *Encoder) marshalerFromCustomMarshalerMap(t reflect.Type) (func(interfac
return nil, false
}

func (e *Encoder) existsTypeInCustomInterfaceMarshalerMap(t reflect.Type) bool {
if _, exists := e.customInterfaceMarshalerMap[t]; exists {
return true
}

globalCustomInterfaceMarshalerMu.Lock()
defer globalCustomInterfaceMarshalerMu.Unlock()
if _, exists := globalCustomInterfaceMarshalerMap[t]; exists {
return true
}
return false
}

func (e *Encoder) marshalerFromCustomInterfaceMarshalerMap(t reflect.Type) (func(interface{}) (interface{}, error), bool) {
if marshaler, exists := e.customInterfaceMarshalerMap[t]; exists {
return marshaler, exists
}

globalCustomInterfaceMarshalerMu.Lock()
defer globalCustomInterfaceMarshalerMu.Unlock()
if marshaler, exists := globalCustomInterfaceMarshalerMap[t]; exists {
return marshaler, exists
}
return nil, false
}

func (e *Encoder) canEncodeByMarshaler(v reflect.Value) bool {
if !v.CanInterface() {
return false
}
if e.existsTypeInCustomMarshalerMap(v.Type()) {
return true
}
if e.existsTypeInCustomInterfaceMarshalerMap(v.Type()) {
return true
}
iface := v.Interface()
switch iface.(type) {
case BytesMarshalerContext:
Expand Down Expand Up @@ -344,6 +375,13 @@ func (e *Encoder) encodeByMarshaler(ctx context.Context, v reflect.Value, column
}
return node, nil
}
if marshaler, exists := e.marshalerFromCustomInterfaceMarshalerMap(v.Type()); exists {
marshalV, err := marshaler(iface)
if err != nil {
return nil, errors.Wrapf(err, "failed to MarshalYAML")
}
return e.encodeValue(ctx, reflect.ValueOf(marshalV), column)
}

if marshaler, ok := iface.(BytesMarshalerContext); ok {
doc, err := marshaler.MarshalYAML(ctx)
Expand Down
Loading
Loading