diff --git a/optional/string.go b/optional/string.go index bc010c9..539d0e2 100644 --- a/optional/string.go +++ b/optional/string.go @@ -1,5 +1,7 @@ package optional +import "encoding/json" + // String represents an optional string value. type String struct { value string @@ -19,6 +21,19 @@ func (s String) Present() bool { return s.present } +// UnmarshalJSON implements the json.Unmarshaler interface. +func (s *String) UnmarshalJSON(data []byte) error { + var v *string + if err := json.Unmarshal(data, &v); err != nil { + return err + } + if v != nil && *v != "" { + s.value = *v + s.present = true + } + return nil +} + // Value returns the value of the optional string. func (s String) Value() string { return s.value diff --git a/schema/characteristics.go b/schema/characteristics.go index 09091d7..88b94be 100644 --- a/schema/characteristics.go +++ b/schema/characteristics.go @@ -7,15 +7,20 @@ import ( ) func checkAttributeName(name string) { - // starts w/ a A-Za-z followed by a A-Za-z0-9, a dollar sign, a hyphen or an underscore - match, err := regexp.MatchString(`^[A-Za-z][\w$-]*$`, name) - if err != nil { + if err := validateAttributeName(name); err != nil { panic(err) } +} +func validateAttributeName(name string) error { + match, err := regexp.MatchString(`^[A-Za-z][\w$-]*$`, name) + if err != nil { + return err + } if !match { - panic(fmt.Sprintf("invalid attribute name %q", name)) + return fmt.Errorf("invalid attribute name %q", name) } + return nil } // AttributeDataType is a single keyword indicating the derived data type from JSON. diff --git a/schema/core.go b/schema/core.go index ef6a00b..61d5291 100644 --- a/schema/core.go +++ b/schema/core.go @@ -45,30 +45,9 @@ type CoreAttribute struct { func ComplexCoreAttribute(params ComplexParams) CoreAttribute { checkAttributeName(params.Name) - names := map[string]int{} - var sa []CoreAttribute - - for i, a := range params.SubAttributes { - name := strings.ToLower(a.name) - if j, ok := names[name]; ok { - panic(fmt.Errorf("duplicate name %q for sub-attributes %d and %d", name, i, j)) - } - - names[name] = i - - sa = append(sa, CoreAttribute{ - canonicalValues: a.canonicalValues, - caseExact: a.caseExact, - description: a.description, - multiValued: a.multiValued, - mutability: a.mutability, - name: a.name, - referenceTypes: a.referenceTypes, - required: a.required, - returned: a.returned, - typ: a.typ, - uniqueness: a.uniqueness, - }) + sa, err := buildSubAttributes(params.SubAttributes) + if err != nil { + panic(err) } return CoreAttribute{ @@ -103,6 +82,36 @@ func SimpleCoreAttribute(params SimpleParams) CoreAttribute { } } +func buildSubAttributes(subAttributes []SimpleParams) ([]CoreAttribute, error) { + names := map[string]int{} + var sa []CoreAttribute + + for i, a := range subAttributes { + name := strings.ToLower(a.name) + if j, ok := names[name]; ok { + return nil, fmt.Errorf("duplicate name %q for sub-attributes %d and %d", name, i, j) + } + + names[name] = i + + sa = append(sa, CoreAttribute{ + canonicalValues: a.canonicalValues, + caseExact: a.caseExact, + description: a.description, + multiValued: a.multiValued, + mutability: a.mutability, + name: a.name, + referenceTypes: a.referenceTypes, + required: a.required, + returned: a.returned, + typ: a.typ, + uniqueness: a.uniqueness, + }) + } + + return sa, nil +} + // AttributeType returns the attribute type. func (a CoreAttribute) AttributeType() string { return a.typ.String() diff --git a/schema/schema.go b/schema/schema.go index b07c05e..91952bc 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -2,6 +2,7 @@ package schema import ( "encoding/json" + "fmt" "strings" "github.com/elimity-com/scim/errors" @@ -28,9 +29,32 @@ func isReadOnly(attr CoreAttribute) bool { return attr.mutability == attributeMutabilityReadOnly } +// validateUnmarshalAttributeName validates an attribute name for JSON +// unmarshaling. It applies the same rules as the constructors but also +// accepts "$ref", which RFC 7643 Section 2.4 defines as a standard +// sub-attribute despite it violating the ABNF grammar. +func validateUnmarshalAttributeName(name string) error { + if name == "$ref" { + return nil + } + return validateAttributeName(name) +} + // Attributes represent a list of Core Attributes. type Attributes []CoreAttribute +func unmarshalAttributes(rawAttrs []json.RawMessage) (Attributes, error) { + attrs := make(Attributes, 0, len(rawAttrs)) + for _, raw := range rawAttrs { + a, err := unmarshalCoreAttribute(raw) + if err != nil { + return nil, err + } + attrs = append(attrs, a) + } + return attrs, nil +} + // ContainsAttribute checks whether the list of Core Attributes contains an attribute with the given name. func (as Attributes) ContainsAttribute(name string) (CoreAttribute, bool) { for _, a := range as { @@ -41,6 +65,91 @@ func (as Attributes) ContainsAttribute(name string) (CoreAttribute, bool) { return CoreAttribute{}, false } +func unmarshalCoreAttribute(data json.RawMessage) (CoreAttribute, error) { + var raw struct { + Name string `json:"name"` + Type string `json:"type"` + Description optional.String `json:"description"` + MultiValued bool `json:"multiValued"` + Required bool `json:"required"` + CaseExact bool `json:"caseExact"` + Mutability string `json:"mutability"` + Returned string `json:"returned"` + Uniqueness string `json:"uniqueness"` + CanonicalValues []string `json:"canonicalValues"` + ReferenceTypes []string `json:"referenceTypes"` + SubAttributes []json.RawMessage `json:"subAttributes"` + } + if err := json.Unmarshal(data, &raw); err != nil { + return CoreAttribute{}, err + } + + if err := validateUnmarshalAttributeName(raw.Name); err != nil { + return CoreAttribute{}, err + } + + typ, err := parseAttributeType(raw.Type) + if err != nil { + return CoreAttribute{}, err + } + + mut, err := parseAttributeMutability(raw.Mutability) + if err != nil { + return CoreAttribute{}, err + } + + ret, err := parseAttributeReturned(raw.Returned) + if err != nil { + return CoreAttribute{}, err + } + + uniq, err := parseAttributeUniqueness(raw.Uniqueness) + if err != nil { + return CoreAttribute{}, err + } + + if typ == attributeDataTypeComplex { + subParams, err := unmarshalSimpleParams(raw.SubAttributes) + if err != nil { + return CoreAttribute{}, err + } + subAttrs, err := buildSubAttributes(subParams) + if err != nil { + return CoreAttribute{}, err + } + return CoreAttribute{ + description: raw.Description, + multiValued: raw.MultiValued, + mutability: mut, + name: raw.Name, + required: raw.Required, + returned: ret, + subAttributes: subAttrs, + typ: attributeDataTypeComplex, + uniqueness: uniq, + }, nil + } + + var refTypes []AttributeReferenceType + for _, r := range raw.ReferenceTypes { + refTypes = append(refTypes, AttributeReferenceType(r)) + } + + return CoreAttribute{ + canonicalValues: raw.CanonicalValues, + caseExact: raw.CaseExact, + description: raw.Description, + multiValued: raw.MultiValued, + mutability: mut, + name: raw.Name, + referenceTypes: refTypes, + required: raw.Required, + returned: ret, + typ: typ, + uniqueness: uniq, + }, nil +} + // Schema is a collection of attribute definitions that describe the contents of an entire or partial resource. type Schema struct { Attributes Attributes @@ -65,6 +174,30 @@ func (s Schema) ToMap() map[string]interface{} { } } +// UnmarshalJSON parses a JSON-encoded schema into the Schema struct. +func (s *Schema) UnmarshalJSON(data []byte) error { + var raw struct { + ID string `json:"id"` + Name optional.String `json:"name"` + Description optional.String `json:"description"` + Attributes []json.RawMessage `json:"attributes"` + } + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + + attrs, err := unmarshalAttributes(raw.Attributes) + if err != nil { + return err + } + + s.ID = raw.ID + s.Name = raw.Name + s.Description = raw.Description + s.Attributes = attrs + return nil +} + // Validate validates given resource based on the schema, including the // "schemas" attribute. Does NOT validate mutability. // NOTE: only used in POST and PUT requests where attributes MAY be (re)defined. @@ -203,3 +336,143 @@ func (s Schema) validateSchemaID(resource map[string]interface{}) *errors.ScimEr return nil } + +func unmarshalSimpleParam(data json.RawMessage) (SimpleParams, error) { + var raw struct { + Name string `json:"name"` + Type string `json:"type"` + Description optional.String `json:"description"` + MultiValued bool `json:"multiValued"` + Required bool `json:"required"` + CaseExact bool `json:"caseExact"` + Mutability string `json:"mutability"` + Returned string `json:"returned"` + Uniqueness string `json:"uniqueness"` + CanonicalValues []string `json:"canonicalValues"` + ReferenceTypes []string `json:"referenceTypes"` + } + if err := json.Unmarshal(data, &raw); err != nil { + return SimpleParams{}, err + } + + if err := validateUnmarshalAttributeName(raw.Name); err != nil { + return SimpleParams{}, err + } + + typ, err := parseAttributeType(raw.Type) + if err != nil { + return SimpleParams{}, err + } + + mut, err := parseAttributeMutability(raw.Mutability) + if err != nil { + return SimpleParams{}, err + } + + ret, err := parseAttributeReturned(raw.Returned) + if err != nil { + return SimpleParams{}, err + } + + uniq, err := parseAttributeUniqueness(raw.Uniqueness) + if err != nil { + return SimpleParams{}, err + } + + var refTypes []AttributeReferenceType + for _, r := range raw.ReferenceTypes { + refTypes = append(refTypes, AttributeReferenceType(r)) + } + + return SimpleParams{ + canonicalValues: raw.CanonicalValues, + caseExact: raw.CaseExact, + description: raw.Description, + multiValued: raw.MultiValued, + mutability: mut, + name: raw.Name, + referenceTypes: refTypes, + required: raw.Required, + returned: ret, + typ: typ, + uniqueness: uniq, + }, nil +} + +func unmarshalSimpleParams(rawAttrs []json.RawMessage) ([]SimpleParams, error) { + params := make([]SimpleParams, 0, len(rawAttrs)) + for _, raw := range rawAttrs { + p, err := unmarshalSimpleParam(raw) + if err != nil { + return nil, err + } + params = append(params, p) + } + return params, nil +} + +func parseAttributeMutability(s string) (attributeMutability, error) { + switch s { + case "readWrite", "": + return attributeMutabilityReadWrite, nil + case "immutable": + return attributeMutabilityImmutable, nil + case "readOnly": + return attributeMutabilityReadOnly, nil + case "writeOnly": + return attributeMutabilityWriteOnly, nil + default: + return 0, fmt.Errorf("unknown mutability: %q", s) + } +} + +func parseAttributeReturned(s string) (attributeReturned, error) { + switch s { + case "default", "": + return attributeReturnedDefault, nil + case "always": + return attributeReturnedAlways, nil + case "never": + return attributeReturnedNever, nil + case "request": + return attributeReturnedRequest, nil + default: + return 0, fmt.Errorf("unknown returned: %q", s) + } +} + +func parseAttributeType(s string) (attributeType, error) { + switch s { + case "string": + return attributeDataTypeString, nil + case "boolean": + return attributeDataTypeBoolean, nil + case "decimal": + return attributeDataTypeDecimal, nil + case "integer": + return attributeDataTypeInteger, nil + case "dateTime": + return attributeDataTypeDateTime, nil + case "reference": + return attributeDataTypeReference, nil + case "complex": + return attributeDataTypeComplex, nil + case "binary": + return attributeDataTypeBinary, nil + default: + return 0, fmt.Errorf("unknown attribute type: %q", s) + } +} + +func parseAttributeUniqueness(s string) (attributeUniqueness, error) { + switch s { + case "none", "": + return attributeUniquenessNone, nil + case "server": + return attributeUniquenessServer, nil + case "global": + return attributeUniquenessGlobal, nil + default: + return 0, fmt.Errorf("unknown uniqueness: %q", s) + } +} diff --git a/schema/schema_test.go b/schema/schema_test.go index 686cd03..28c2816 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -103,6 +103,292 @@ func TestJSONMarshalling(t *testing.T) { } } +func TestJSONUnmarshalling(t *testing.T) { + t.Run("round trip", func(t *testing.T) { + originalJSON, err := testSchema.MarshalJSON() + if err != nil { + t.Fatal(err) + } + + var got Schema + if err := json.Unmarshal(originalJSON, &got); err != nil { + t.Fatal(err) + } + + gotJSON, err := got.MarshalJSON() + if err != nil { + t.Fatal(err) + } + + normalizedOriginal, err := normalizeJSON(originalJSON) + if err != nil { + t.Fatal(err) + } + normalizedGot, err := normalizeJSON(gotJSON) + if err != nil { + t.Fatal(err) + } + + if normalizedOriginal != normalizedGot { + t.Errorf("round trip mismatch.\nWant: %s\nGot: %s", normalizedOriginal, normalizedGot) + } + }) + + t.Run("user schema round trip", func(t *testing.T) { + originalJSON, err := CoreUserSchema().MarshalJSON() + if err != nil { + t.Fatal(err) + } + + var got Schema + if err := json.Unmarshal(originalJSON, &got); err != nil { + t.Fatal(err) + } + + gotJSON, err := got.MarshalJSON() + if err != nil { + t.Fatal(err) + } + + normalizedOriginal, err := normalizeJSON(originalJSON) + if err != nil { + t.Fatal(err) + } + normalizedGot, err := normalizeJSON(gotJSON) + if err != nil { + t.Fatal(err) + } + + if normalizedOriginal != normalizedGot { + t.Errorf("round trip mismatch.\nWant: %s\nGot: %s", normalizedOriginal, normalizedGot) + } + }) + + t.Run("group schema round trip", func(t *testing.T) { + originalJSON, err := CoreGroupSchema().MarshalJSON() + if err != nil { + t.Fatal(err) + } + + var got Schema + if err := json.Unmarshal(originalJSON, &got); err != nil { + t.Fatal(err) + } + + gotJSON, err := got.MarshalJSON() + if err != nil { + t.Fatal(err) + } + + normalizedOriginal, err := normalizeJSON(originalJSON) + if err != nil { + t.Fatal(err) + } + normalizedGot, err := normalizeJSON(gotJSON) + if err != nil { + t.Fatal(err) + } + + if normalizedOriginal != normalizedGot { + t.Errorf("round trip mismatch.\nWant: %s\nGot: %s", normalizedOriginal, normalizedGot) + } + }) + + t.Run("from file", func(t *testing.T) { + data, err := os.ReadFile("./testdata/schema_test.json") + if err != nil { + t.Fatal(err) + } + + var got Schema + if err := json.Unmarshal(data, &got); err != nil { + t.Fatal(err) + } + + if got.ID != "test-schema-id" { + t.Errorf("ID: want %q, got %q", "test-schema-id", got.ID) + } + if got.Name.Value() != "test" { + t.Errorf("Name: want %q, got %q", "test", got.Name.Value()) + } + if len(got.Attributes) != 11 { + t.Errorf("Attributes: want 11, got %d", len(got.Attributes)) + } + }) + + t.Run("enterprise user extension round trip", func(t *testing.T) { + originalJSON, err := ExtensionEnterpriseUser().MarshalJSON() + if err != nil { + t.Fatal(err) + } + + var got Schema + if err := json.Unmarshal(originalJSON, &got); err != nil { + t.Fatal(err) + } + + gotJSON, err := got.MarshalJSON() + if err != nil { + t.Fatal(err) + } + + normalizedOriginal, err := normalizeJSON(originalJSON) + if err != nil { + t.Fatal(err) + } + normalizedGot, err := normalizeJSON(gotJSON) + if err != nil { + t.Fatal(err) + } + + if normalizedOriginal != normalizedGot { + t.Errorf("round trip mismatch.\nWant: %s\nGot: %s", normalizedOriginal, normalizedGot) + } + }) + + t.Run("custom schema round trip", func(t *testing.T) { + custom := Schema{ + ID: "urn:example:custom:1.0:Device", + Name: optional.NewString("Device"), + Description: optional.NewString("A custom device resource"), + Attributes: []CoreAttribute{ + SimpleCoreAttribute(SimpleStringParams(StringParams{ + Name: "serialNumber", + Required: true, + Uniqueness: AttributeUniquenessServer(), + CaseExact: true, + })), + SimpleCoreAttribute(SimpleBooleanParams(BooleanParams{ + Name: "active", + Mutability: AttributeMutabilityReadWrite(), + })), + SimpleCoreAttribute(SimpleNumberParams(NumberParams{ + Name: "firmwareVersion", + Type: AttributeTypeDecimal(), + })), + SimpleCoreAttribute(SimpleDateTimeParams(DateTimeParams{ + Name: "lastSeen", + Mutability: AttributeMutabilityReadOnly(), + Returned: AttributeReturnedAlways(), + })), + SimpleCoreAttribute(SimpleReferenceParams(ReferenceParams{ + Name: "owner", + ReferenceTypes: []AttributeReferenceType{AttributeReferenceTypeExternal, AttributeReferenceTypeURI}, + })), + SimpleCoreAttribute(SimpleStringParams(StringParams{ + Name: "status", + CanonicalValues: []string{"online", "offline", "maintenance"}, + })), + ComplexCoreAttribute(ComplexParams{ + Name: "location", + MultiValued: false, + SubAttributes: []SimpleParams{ + SimpleStringParams(StringParams{Name: "building"}), + SimpleNumberParams(NumberParams{ + Name: "floor", + Type: AttributeTypeInteger(), + }), + }, + }), + }, + } + + originalJSON, err := custom.MarshalJSON() + if err != nil { + t.Fatal(err) + } + + var got Schema + if err := json.Unmarshal(originalJSON, &got); err != nil { + t.Fatal(err) + } + + if got.ID != custom.ID { + t.Errorf("ID: want %q, got %q", custom.ID, got.ID) + } + if got.Name.Value() != custom.Name.Value() { + t.Errorf("Name: want %q, got %q", custom.Name.Value(), got.Name.Value()) + } + if got.Description.Value() != custom.Description.Value() { + t.Errorf("Description: want %q, got %q", custom.Description.Value(), got.Description.Value()) + } + + gotJSON, err := got.MarshalJSON() + if err != nil { + t.Fatal(err) + } + + normalizedOriginal, err := normalizeJSON(originalJSON) + if err != nil { + t.Fatal(err) + } + normalizedGot, err := normalizeJSON(gotJSON) + if err != nil { + t.Fatal(err) + } + + if normalizedOriginal != normalizedGot { + t.Errorf("round trip mismatch.\nWant: %s\nGot: %s", normalizedOriginal, normalizedGot) + } + }) + + t.Run("unknown type", func(t *testing.T) { + data := []byte(`{"id":"x","attributes":[{"name":"a","type":"unknown"}]}`) + var got Schema + if err := json.Unmarshal(data, &got); err == nil { + t.Error("expected error for unknown attribute type") + } + }) + + t.Run("unknown mutability", func(t *testing.T) { + data := []byte(`{"id":"x","attributes":[{"name":"a","type":"string","mutability":"unknown"}]}`) + var got Schema + if err := json.Unmarshal(data, &got); err == nil { + t.Error("expected error for unknown mutability") + } + }) + + t.Run("unknown returned", func(t *testing.T) { + data := []byte(`{"id":"x","attributes":[{"name":"a","type":"string","returned":"unknown"}]}`) + var got Schema + if err := json.Unmarshal(data, &got); err == nil { + t.Error("expected error for unknown returned") + } + }) + + t.Run("unknown uniqueness", func(t *testing.T) { + data := []byte(`{"id":"x","attributes":[{"name":"a","type":"string","uniqueness":"unknown"}]}`) + var got Schema + if err := json.Unmarshal(data, &got); err == nil { + t.Error("expected error for unknown uniqueness") + } + }) + + t.Run("invalid attribute name", func(t *testing.T) { + data := []byte(`{"id":"x","attributes":[{"name":"_invalid","type":"string"}]}`) + var got Schema + if err := json.Unmarshal(data, &got); err == nil { + t.Error("expected error for invalid attribute name") + } + }) + + t.Run("invalid sub-attribute name", func(t *testing.T) { + data := []byte(`{"id":"x","attributes":[{"name":"a","type":"complex","subAttributes":[{"name":"1bad","type":"string"}]}]}`) + var got Schema + if err := json.Unmarshal(data, &got); err == nil { + t.Error("expected error for invalid sub-attribute name") + } + }) + + t.Run("duplicate sub-attribute names", func(t *testing.T) { + data := []byte(`{"id":"x","attributes":[{"name":"a","type":"complex","subAttributes":[{"name":"b","type":"string"},{"name":"b","type":"string"}]}]}`) + var got Schema + if err := json.Unmarshal(data, &got); err == nil { + t.Error("expected error for duplicate sub-attribute names") + } + }) +} + func TestResourceInvalid(t *testing.T) { var resource interface{} if _, scimErr := testSchema.Validate(resource); scimErr == nil {