Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ func (s Server) resourceGetHandler(w http.ResponseWriter, r *http.Request, id st
return
}

raw, err := json.Marshal(resource.response(resourceType))
location := resourceLocation(resourceType, id, s.baseURL)
raw, err := json.Marshal(resource.response(resourceType, location))
if err != nil {
s.errorHandler(w, &errors.ScimErrorInternal)
s.log.Error(
Expand Down Expand Up @@ -97,7 +98,8 @@ func (s Server) resourcePatchHandler(w http.ResponseWriter, r *http.Request, id
return
}

raw, err := json.Marshal(resource.response(resourceType))
location := resourceLocation(resourceType, id, s.baseURL)
raw, err := json.Marshal(resource.response(resourceType, location))
if err != nil {
s.errorHandler(w, &errors.ScimErrorInternal)
s.log.Error(
Expand Down Expand Up @@ -141,7 +143,8 @@ func (s Server) resourcePostHandler(w http.ResponseWriter, r *http.Request, reso
return
}

raw, err := json.Marshal(resource.response(resourceType))
location := resourceLocation(resourceType, resource.ID, s.baseURL)
raw, err := json.Marshal(resource.response(resourceType, location))
if err != nil {
s.errorHandler(w, &errors.ScimErrorInternal)
s.log.Error(
Expand All @@ -152,6 +155,7 @@ func (s Server) resourcePostHandler(w http.ResponseWriter, r *http.Request, reso
return
}

w.Header().Set("Location", location)
if resource.Meta.Version != "" {
w.Header().Set("Etag", resource.Meta.Version)
}
Expand Down Expand Up @@ -185,7 +189,8 @@ func (s Server) resourcePutHandler(w http.ResponseWriter, r *http.Request, id st
return
}

raw, err := json.Marshal(resource.response(resourceType))
location := resourceLocation(resourceType, id, s.baseURL)
raw, err := json.Marshal(resource.response(resourceType, location))
if err != nil {
s.errorHandler(w, &errors.ScimErrorInternal)
s.log.Error(
Expand Down Expand Up @@ -305,7 +310,7 @@ func (s Server) resourcesGetHandler(w http.ResponseWriter, r *http.Request, reso

lr := listResponse{
TotalResults: page.TotalResults,
Resources: page.resources(resourceType),
Resources: page.resources(resourceType, s.baseURL),
StartIndex: params.StartIndex,
ItemsPerPage: params.Count,
}
Expand Down
127 changes: 127 additions & 0 deletions handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,22 @@ func TestServerResourceGetHandlerNotFound(t *testing.T) {
assertEqualSCIMErrors(t, expectedError, scimErr)
}

func TestServerResourceGetHandlerWithBaseURL(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/Users/0001", nil)
rr := httptest.NewRecorder()
newTestServerWithBaseURL(t).ServeHTTP(rr, req)

assertEqualStatusCode(t, http.StatusOK, rr.Code)

var resource map[string]interface{}
assertUnmarshalNoError(t, json.Unmarshal(rr.Body.Bytes(), &resource))

meta, ok := resource["meta"].(map[string]interface{})
assertTypeOk(t, ok, "object")

assertEqual(t, "https://example.com/v2/Users/0001", meta["location"])
}

func TestServerResourcePatchHandlerFailOnBadType(t *testing.T) {
req := httptest.NewRequest(http.MethodPatch, "/Users/0001", strings.NewReader(`{
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
Expand Down Expand Up @@ -529,6 +545,31 @@ func TestServerResourcePatchHandlerValidRemoveOp(t *testing.T) {
assertEqualStatusCode(t, http.StatusNoContent, rr.Code)
}

func TestServerResourcePatchHandlerWithBaseURL(t *testing.T) {
req := httptest.NewRequest(http.MethodPatch, "/Users/0001", strings.NewReader(`{
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
"Operations":[
{
"op":"replace",
"path":"active",
"value":false
}
]
}`))
rr := httptest.NewRecorder()
newTestServerWithBaseURL(t).ServeHTTP(rr, req)

assertEqualStatusCode(t, http.StatusOK, rr.Code)

var resource map[string]interface{}
assertUnmarshalNoError(t, json.Unmarshal(rr.Body.Bytes(), &resource))

meta, ok := resource["meta"].(map[string]interface{})
assertTypeOk(t, ok, "object")

assertEqual(t, "https://example.com/v2/Users/0001", meta["location"])
}

func TestServerResourcePostHandlerMissingSchemas(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "/Users", strings.NewReader(`{"userName": "test1"}`))
rr := httptest.NewRecorder()
Expand Down Expand Up @@ -599,10 +640,33 @@ func TestServerResourcePostHandlerValid(t *testing.T) {
assertEqual(t, fmt.Sprintf("v%s", resource["id"]), meta["version"])
// ETag and version needs to be the same.
assertEqual(t, rr.Header().Get("Etag"), meta["version"])
// Location header must match meta.location (RFC 7644 Section 3.3).
assertEqual(t, meta["location"], rr.Header().Get("Location"))
})
}
}

func TestServerResourcePostHandlerWithBaseURL(t *testing.T) {
body := `{"userName": "test1", "schemas":["urn:ietf:params:scim:schemas:core:2.0:User"]}`
req := httptest.NewRequest(http.MethodPost, "/Users", strings.NewReader(body))
rr := httptest.NewRecorder()
newTestServerWithBaseURL(t).ServeHTTP(rr, req)

assertEqualStatusCode(t, http.StatusCreated, rr.Code)

var resource map[string]interface{}
assertUnmarshalNoError(t, json.Unmarshal(rr.Body.Bytes(), &resource))

meta, ok := resource["meta"].(map[string]interface{})
assertTypeOk(t, ok, "object")

location, ok := meta["location"].(string)
assertTypeOk(t, ok, "string")

assertEqual(t, fmt.Sprintf("https://example.com/v2/Users/%s", resource["id"]), location)
assertEqual(t, location, rr.Header().Get("Location"))
}

func TestServerResourcePostHandlerWithExtension(t *testing.T) {
body := `{
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:User", "urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"],
Expand Down Expand Up @@ -701,6 +765,23 @@ func TestServerResourcePutHandlerValid(t *testing.T) {
}
}

func TestServerResourcePutHandlerWithBaseURL(t *testing.T) {
body := `{"userName": "test1", "schemas":["urn:ietf:params:scim:schemas:core:2.0:User"]}`
req := httptest.NewRequest(http.MethodPut, "/Users/0001", strings.NewReader(body))
rr := httptest.NewRecorder()
newTestServerWithBaseURL(t).ServeHTTP(rr, req)

assertEqualStatusCode(t, http.StatusOK, rr.Code)

var resource map[string]interface{}
assertUnmarshalNoError(t, json.Unmarshal(rr.Body.Bytes(), &resource))

meta, ok := resource["meta"].(map[string]interface{})
assertTypeOk(t, ok, "object")

assertEqual(t, "https://example.com/v2/Users/0001", meta["location"])
}

func TestServerResourceTypeHandlerValid(t *testing.T) {
tests := []struct {
name string
Expand Down Expand Up @@ -856,6 +937,29 @@ func TestServerResourcesGetHandlerPagination(t *testing.T) {
assertEqual(t, 20, response.TotalResults)
}

func TestServerResourcesGetHandlerWithBaseURL(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/Users?count=2&startIndex=1", nil)
rr := httptest.NewRecorder()
newTestServerWithBaseURL(t).ServeHTTP(rr, req)

assertEqualStatusCode(t, http.StatusOK, rr.Code)

var response struct {
Resources []map[string]interface{}
}
assertUnmarshalNoError(t, json.Unmarshal(rr.Body.Bytes(), &response))

for _, resource := range response.Resources {
meta, ok := resource["meta"].(map[string]interface{})
assertTypeOk(t, ok, "object")

location, ok := meta["location"].(string)
assertTypeOk(t, ok, "string")

assertTrue(t, strings.HasPrefix(location, "https://example.com/v2/Users/"))
}
}

func TestServerSchemaEndpointValid(t *testing.T) {
tests := []struct {
name string
Expand Down Expand Up @@ -1135,6 +1239,29 @@ func newTestServer(t *testing.T) Server {
return s
}

func newTestServerWithBaseURL(t *testing.T) Server {
userSchema := getUserSchema()
s, err := NewServer(
&ServerArgs{
ServiceProviderConfig: &ServiceProviderConfig{},
ResourceTypes: []ResourceType{
{
ID: optional.NewString("User"),
Name: "User",
Endpoint: "/Users",
Schema: userSchema,
Handler: newTestResourceHandler(),
},
},
},
WithBaseURL("https://example.com/v2"),
)
if err != nil {
t.Fatal(err)
}
return s
}

// statusRecordingResponseWriter wraps an http.ResponseWriter and records
// whether WriteHeader was called explicitly, simulating logging middleware.
type statusRecordingResponseWriter struct {
Expand Down
5 changes: 3 additions & 2 deletions list_response.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ type Page struct {
Resources []Resource
}

func (p Page) resources(resourceType ResourceType) []interface{} {
func (p Page) resources(resourceType ResourceType, baseURL string) []interface{} {
// If the page.Resources is nil, then it will also be represented as a `null` in the response.
// Otherwise is it is an empty slice then it will result in an empty array `[]`.
if len(p.Resources) == 0 {
Expand All @@ -24,9 +24,10 @@ func (p Page) resources(resourceType ResourceType) []interface{} {

var resources []interface{}
for _, v := range p.Resources {
location := resourceLocation(resourceType, v.ID, baseURL)
resources = append(
resources,
v.response(resourceType),
v.response(resourceType, location),
)
}
return resources
Expand Down
6 changes: 2 additions & 4 deletions resource_handler.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
package scim

import (
"fmt"
"net/http"
"net/url"
"time"

"github.com/elimity-com/scim/filter"
Expand Down Expand Up @@ -47,7 +45,7 @@ type Resource struct {
Meta Meta
}

func (r Resource) response(resourceType ResourceType) ResourceAttributes {
func (r Resource) response(resourceType ResourceType, location string) ResourceAttributes {
response := r.Attributes
if response == nil {
response = ResourceAttributes{}
Expand All @@ -66,7 +64,7 @@ func (r Resource) response(resourceType ResourceType) ResourceAttributes {

m := meta{
ResourceType: resourceType.Name,
Location: fmt.Sprintf("%s/%s", resourceType.Endpoint[1:], url.PathEscape(r.ID)),
Location: location,
}

if r.Meta.Created != nil {
Expand Down
23 changes: 23 additions & 0 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,26 @@ func parseIdentifier(path, endpoint string) (string, error) {
return url.PathUnescape(strings.TrimPrefix(path, endpoint+"/"))
}

func resourceLocation(resourceType ResourceType, id string, baseURL string) string {
relativePath := resourceType.Endpoint[1:] + "/" + url.PathEscape(id)
if baseURL == "" {
return relativePath
}
u, err := url.Parse(baseURL)
if err != nil {
return relativePath
}
u.Path = u.Path + "/" + relativePath
return u.String()
}

// Server represents a SCIM server which implements the HTTP-based SCIM protocol
// that makes managing identities in multi-domain scenarios easier to support via a standardized service.
type Server struct {
config ServiceProviderConfig
resourceTypes []ResourceType
log Logger
baseURL string
}

func NewServer(args *ServerArgs, opts ...ServerOption) (Server, error) {
Expand Down Expand Up @@ -250,6 +264,15 @@ type ServerArgs struct {

type ServerOption func(*Server)

// WithBaseURL configures the server to use absolute URIs for resource
// locations. The base URL is prepended to all meta.location values and
// Location headers. For example, "https://example.com/v2".
func WithBaseURL(baseURL string) ServerOption {
return func(s *Server) {
s.baseURL = strings.TrimRight(baseURL, "/")
}
}

// WithLogger sets the logger for the server.
func WithLogger(logger Logger) ServerOption {
return func(s *Server) {
Expand Down
Loading