diff --git a/internal/converter/testdata/oneof.go b/internal/converter/testdata/oneof.go index a1824c40..95f63a49 100644 --- a/internal/converter/testdata/oneof.go +++ b/internal/converter/testdata/oneof.go @@ -5,6 +5,9 @@ const OneOf = `{ "$ref": "#/definitions/OneOf", "definitions": { "OneOf": { + "required": [ + "something" + ], "properties": { "bar": { "$ref": "#/definitions/samples.OneOf.Bar", @@ -20,15 +23,35 @@ const OneOf = `{ }, "additionalProperties": true, "type": "object", - "oneOf": [ - { - "required": [ - "bar" - ] - }, + "allOf": [ { - "required": [ - "baz" + "oneOf": [ + { + "not": { + "anyOf": [ + { + "required": [ + "bar" + ] + }, + { + "required": [ + "baz" + ] + } + ] + } + }, + { + "required": [ + "bar" + ] + }, + { + "required": [ + "baz" + ] + } ] } ], diff --git a/internal/converter/types.go b/internal/converter/types.go index 50490dc9..0b137087 100644 --- a/internal/converter/types.go +++ b/internal/converter/types.go @@ -625,13 +625,21 @@ func (c *Converter) recursiveConvertMessageType(curPkg *ProtoPackage, msgDesc *d c.logger.WithField("field_name", fieldDesc.GetName()).WithField("type", recursedJSONSchemaType.Type).Trace("Converted field") // If this field is part of a OneOf declaration then build that here: - if c.Flags.EnforceOneOf && fieldDesc.OneofIndex != nil { + if c.Flags.EnforceOneOf && fieldDesc.OneofIndex != nil && !fieldDesc.GetProto3Optional() { + for { + if *fieldDesc.OneofIndex < int32(len(jsonSchemaType.AllOf)) { + break + } + var notAnyOf = &jsonschema.Type{Not: &jsonschema.Type{AnyOf: []*jsonschema.Type{}}} + jsonSchemaType.AllOf = append(jsonSchemaType.AllOf, &jsonschema.Type{OneOf: []*jsonschema.Type{notAnyOf}}) + } if c.Flags.UseJSONFieldnamesOnly { - jsonSchemaType.OneOf = append(jsonSchemaType.OneOf, &jsonschema.Type{Required: []string{fieldDesc.GetJsonName()}}) + jsonSchemaType.AllOf[*fieldDesc.OneofIndex].OneOf = append(jsonSchemaType.AllOf[*fieldDesc.OneofIndex].OneOf, &jsonschema.Type{Required: []string{fieldDesc.GetJsonName()}}) + jsonSchemaType.AllOf[*fieldDesc.OneofIndex].OneOf[0].Not.AnyOf = append(jsonSchemaType.AllOf[*fieldDesc.OneofIndex].OneOf[0].Not.AnyOf, &jsonschema.Type{Required: []string{fieldDesc.GetJsonName()}}) } else { - jsonSchemaType.OneOf = append(jsonSchemaType.OneOf, &jsonschema.Type{Required: []string{fieldDesc.GetName()}}) + jsonSchemaType.AllOf[*fieldDesc.OneofIndex].OneOf = append(jsonSchemaType.AllOf[*fieldDesc.OneofIndex].OneOf, &jsonschema.Type{Required: []string{fieldDesc.GetName()}}) + jsonSchemaType.AllOf[*fieldDesc.OneofIndex].OneOf[0].Not.AnyOf = append(jsonSchemaType.AllOf[*fieldDesc.OneofIndex].OneOf[0].Not.AnyOf, &jsonschema.Type{Required: []string{fieldDesc.GetName()}}) } - } // Figure out which field names we want to use: @@ -646,9 +654,13 @@ func (c *Converter) recursiveConvertMessageType(curPkg *ProtoPackage, msgDesc *d } // Enforce all_fields_required: - if messageFlags.AllFieldsRequired && len(jsonSchemaType.OneOf) == 0 && jsonSchemaType.Properties != nil { - for _, property := range jsonSchemaType.Properties.Keys() { - jsonSchemaType.Required = append(jsonSchemaType.Required, property) + if messageFlags.AllFieldsRequired { + if fieldDesc.OneofIndex == nil && !fieldDesc.GetProto3Optional() { + if c.Flags.UseJSONFieldnamesOnly { + jsonSchemaType.Required = append(jsonSchemaType.Required, fieldDesc.GetJsonName()) + } else { + jsonSchemaType.Required = append(jsonSchemaType.Required, fieldDesc.GetName()) + } } } diff --git a/jsonschemas/OneOf.json b/jsonschemas/OneOf.json index 298bdd1a..bc6f3c38 100644 --- a/jsonschemas/OneOf.json +++ b/jsonschemas/OneOf.json @@ -18,15 +18,35 @@ }, "additionalProperties": true, "type": "object", - "oneOf": [ + "allOf": [ { - "required": [ - "bar" - ] - }, - { - "required": [ - "baz" + "oneOf": [ + { + "not": { + "anyOf": [ + { + "required": [ + "bar" + ] + }, + { + "required": [ + "baz" + ] + } + ] + } + }, + { + "required": [ + "bar" + ] + }, + { + "required": [ + "baz" + ] + } ] } ],