diff --git a/handlers.go b/handlers.go index 7f207ee..0ad1735 100644 --- a/handlers.go +++ b/handlers.go @@ -52,7 +52,8 @@ func (s Server) resourceGetHandler(w http.ResponseWriter, r *http.Request, id st return } - raw, err := json.Marshal(resource.response(resourceType)) + location := resourceLocation(resourceType, id, s.baseURL) + raw, err := json.Marshal(resource.response(resourceType, location)) if err != nil { s.errorHandler(w, &errors.ScimErrorInternal) s.log.Error( @@ -97,7 +98,8 @@ func (s Server) resourcePatchHandler(w http.ResponseWriter, r *http.Request, id return } - raw, err := json.Marshal(resource.response(resourceType)) + location := resourceLocation(resourceType, id, s.baseURL) + raw, err := json.Marshal(resource.response(resourceType, location)) if err != nil { s.errorHandler(w, &errors.ScimErrorInternal) s.log.Error( @@ -141,7 +143,8 @@ func (s Server) resourcePostHandler(w http.ResponseWriter, r *http.Request, reso return } - raw, err := json.Marshal(resource.response(resourceType)) + location := resourceLocation(resourceType, resource.ID, s.baseURL) + raw, err := json.Marshal(resource.response(resourceType, location)) if err != nil { s.errorHandler(w, &errors.ScimErrorInternal) s.log.Error( @@ -152,6 +155,7 @@ func (s Server) resourcePostHandler(w http.ResponseWriter, r *http.Request, reso return } + w.Header().Set("Location", location) if resource.Meta.Version != "" { w.Header().Set("Etag", resource.Meta.Version) } @@ -185,7 +189,8 @@ func (s Server) resourcePutHandler(w http.ResponseWriter, r *http.Request, id st return } - raw, err := json.Marshal(resource.response(resourceType)) + location := resourceLocation(resourceType, id, s.baseURL) + raw, err := json.Marshal(resource.response(resourceType, location)) if err != nil { s.errorHandler(w, &errors.ScimErrorInternal) s.log.Error( @@ -305,7 +310,7 @@ func (s Server) resourcesGetHandler(w http.ResponseWriter, r *http.Request, reso lr := listResponse{ TotalResults: page.TotalResults, - Resources: page.resources(resourceType), + Resources: page.resources(resourceType, s.baseURL), StartIndex: params.StartIndex, ItemsPerPage: params.Count, } diff --git a/handlers_test.go b/handlers_test.go index 112cd96..f040a1c 100644 --- a/handlers_test.go +++ b/handlers_test.go @@ -279,6 +279,22 @@ func TestServerResourceGetHandlerNotFound(t *testing.T) { assertEqualSCIMErrors(t, expectedError, scimErr) } +func TestServerResourceGetHandlerWithBaseURL(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/Users/0001", nil) + rr := httptest.NewRecorder() + newTestServerWithBaseURL(t).ServeHTTP(rr, req) + + assertEqualStatusCode(t, http.StatusOK, rr.Code) + + var resource map[string]interface{} + assertUnmarshalNoError(t, json.Unmarshal(rr.Body.Bytes(), &resource)) + + meta, ok := resource["meta"].(map[string]interface{}) + assertTypeOk(t, ok, "object") + + assertEqual(t, "https://example.com/v2/Users/0001", meta["location"]) +} + func TestServerResourcePatchHandlerFailOnBadType(t *testing.T) { req := httptest.NewRequest(http.MethodPatch, "/Users/0001", strings.NewReader(`{ "schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"], @@ -529,6 +545,31 @@ func TestServerResourcePatchHandlerValidRemoveOp(t *testing.T) { assertEqualStatusCode(t, http.StatusNoContent, rr.Code) } +func TestServerResourcePatchHandlerWithBaseURL(t *testing.T) { + req := httptest.NewRequest(http.MethodPatch, "/Users/0001", strings.NewReader(`{ + "schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"], + "Operations":[ + { + "op":"replace", + "path":"active", + "value":false + } + ] + }`)) + rr := httptest.NewRecorder() + newTestServerWithBaseURL(t).ServeHTTP(rr, req) + + assertEqualStatusCode(t, http.StatusOK, rr.Code) + + var resource map[string]interface{} + assertUnmarshalNoError(t, json.Unmarshal(rr.Body.Bytes(), &resource)) + + meta, ok := resource["meta"].(map[string]interface{}) + assertTypeOk(t, ok, "object") + + assertEqual(t, "https://example.com/v2/Users/0001", meta["location"]) +} + func TestServerResourcePostHandlerMissingSchemas(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/Users", strings.NewReader(`{"userName": "test1"}`)) rr := httptest.NewRecorder() @@ -599,10 +640,33 @@ func TestServerResourcePostHandlerValid(t *testing.T) { assertEqual(t, fmt.Sprintf("v%s", resource["id"]), meta["version"]) // ETag and version needs to be the same. assertEqual(t, rr.Header().Get("Etag"), meta["version"]) + // Location header must match meta.location (RFC 7644 Section 3.3). + assertEqual(t, meta["location"], rr.Header().Get("Location")) }) } } +func TestServerResourcePostHandlerWithBaseURL(t *testing.T) { + body := `{"userName": "test1", "schemas":["urn:ietf:params:scim:schemas:core:2.0:User"]}` + req := httptest.NewRequest(http.MethodPost, "/Users", strings.NewReader(body)) + rr := httptest.NewRecorder() + newTestServerWithBaseURL(t).ServeHTTP(rr, req) + + assertEqualStatusCode(t, http.StatusCreated, rr.Code) + + var resource map[string]interface{} + assertUnmarshalNoError(t, json.Unmarshal(rr.Body.Bytes(), &resource)) + + meta, ok := resource["meta"].(map[string]interface{}) + assertTypeOk(t, ok, "object") + + location, ok := meta["location"].(string) + assertTypeOk(t, ok, "string") + + assertEqual(t, fmt.Sprintf("https://example.com/v2/Users/%s", resource["id"]), location) + assertEqual(t, location, rr.Header().Get("Location")) +} + func TestServerResourcePostHandlerWithExtension(t *testing.T) { body := `{ "schemas": ["urn:ietf:params:scim:schemas:core:2.0:User", "urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"], @@ -701,6 +765,23 @@ func TestServerResourcePutHandlerValid(t *testing.T) { } } +func TestServerResourcePutHandlerWithBaseURL(t *testing.T) { + body := `{"userName": "test1", "schemas":["urn:ietf:params:scim:schemas:core:2.0:User"]}` + req := httptest.NewRequest(http.MethodPut, "/Users/0001", strings.NewReader(body)) + rr := httptest.NewRecorder() + newTestServerWithBaseURL(t).ServeHTTP(rr, req) + + assertEqualStatusCode(t, http.StatusOK, rr.Code) + + var resource map[string]interface{} + assertUnmarshalNoError(t, json.Unmarshal(rr.Body.Bytes(), &resource)) + + meta, ok := resource["meta"].(map[string]interface{}) + assertTypeOk(t, ok, "object") + + assertEqual(t, "https://example.com/v2/Users/0001", meta["location"]) +} + func TestServerResourceTypeHandlerValid(t *testing.T) { tests := []struct { name string @@ -856,6 +937,29 @@ func TestServerResourcesGetHandlerPagination(t *testing.T) { assertEqual(t, 20, response.TotalResults) } +func TestServerResourcesGetHandlerWithBaseURL(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/Users?count=2&startIndex=1", nil) + rr := httptest.NewRecorder() + newTestServerWithBaseURL(t).ServeHTTP(rr, req) + + assertEqualStatusCode(t, http.StatusOK, rr.Code) + + var response struct { + Resources []map[string]interface{} + } + assertUnmarshalNoError(t, json.Unmarshal(rr.Body.Bytes(), &response)) + + for _, resource := range response.Resources { + meta, ok := resource["meta"].(map[string]interface{}) + assertTypeOk(t, ok, "object") + + location, ok := meta["location"].(string) + assertTypeOk(t, ok, "string") + + assertTrue(t, strings.HasPrefix(location, "https://example.com/v2/Users/")) + } +} + func TestServerSchemaEndpointValid(t *testing.T) { tests := []struct { name string @@ -1135,6 +1239,29 @@ func newTestServer(t *testing.T) Server { return s } +func newTestServerWithBaseURL(t *testing.T) Server { + userSchema := getUserSchema() + s, err := NewServer( + &ServerArgs{ + ServiceProviderConfig: &ServiceProviderConfig{}, + ResourceTypes: []ResourceType{ + { + ID: optional.NewString("User"), + Name: "User", + Endpoint: "/Users", + Schema: userSchema, + Handler: newTestResourceHandler(), + }, + }, + }, + WithBaseURL("https://example.com/v2"), + ) + if err != nil { + t.Fatal(err) + } + return s +} + // statusRecordingResponseWriter wraps an http.ResponseWriter and records // whether WriteHeader was called explicitly, simulating logging middleware. type statusRecordingResponseWriter struct { diff --git a/list_response.go b/list_response.go index 56448f1..99828ee 100644 --- a/list_response.go +++ b/list_response.go @@ -12,7 +12,7 @@ type Page struct { Resources []Resource } -func (p Page) resources(resourceType ResourceType) []interface{} { +func (p Page) resources(resourceType ResourceType, baseURL string) []interface{} { // If the page.Resources is nil, then it will also be represented as a `null` in the response. // Otherwise is it is an empty slice then it will result in an empty array `[]`. if len(p.Resources) == 0 { @@ -24,9 +24,10 @@ func (p Page) resources(resourceType ResourceType) []interface{} { var resources []interface{} for _, v := range p.Resources { + location := resourceLocation(resourceType, v.ID, baseURL) resources = append( resources, - v.response(resourceType), + v.response(resourceType, location), ) } return resources diff --git a/resource_handler.go b/resource_handler.go index 1d0cd15..b5aa131 100644 --- a/resource_handler.go +++ b/resource_handler.go @@ -1,9 +1,7 @@ package scim import ( - "fmt" "net/http" - "net/url" "time" "github.com/elimity-com/scim/filter" @@ -47,7 +45,7 @@ type Resource struct { Meta Meta } -func (r Resource) response(resourceType ResourceType) ResourceAttributes { +func (r Resource) response(resourceType ResourceType, location string) ResourceAttributes { response := r.Attributes if response == nil { response = ResourceAttributes{} @@ -66,7 +64,7 @@ func (r Resource) response(resourceType ResourceType) ResourceAttributes { m := meta{ ResourceType: resourceType.Name, - Location: fmt.Sprintf("%s/%s", resourceType.Endpoint[1:], url.PathEscape(r.ID)), + Location: location, } if r.Meta.Created != nil { diff --git a/server.go b/server.go index a13f5b0..66f8d88 100644 --- a/server.go +++ b/server.go @@ -52,12 +52,26 @@ func parseIdentifier(path, endpoint string) (string, error) { return url.PathUnescape(strings.TrimPrefix(path, endpoint+"/")) } +func resourceLocation(resourceType ResourceType, id string, baseURL string) string { + relativePath := resourceType.Endpoint[1:] + "/" + url.PathEscape(id) + if baseURL == "" { + return relativePath + } + u, err := url.Parse(baseURL) + if err != nil { + return relativePath + } + u.Path = u.Path + "/" + relativePath + return u.String() +} + // Server represents a SCIM server which implements the HTTP-based SCIM protocol // that makes managing identities in multi-domain scenarios easier to support via a standardized service. type Server struct { config ServiceProviderConfig resourceTypes []ResourceType log Logger + baseURL string } func NewServer(args *ServerArgs, opts ...ServerOption) (Server, error) { @@ -250,6 +264,15 @@ type ServerArgs struct { type ServerOption func(*Server) +// WithBaseURL configures the server to use absolute URIs for resource +// locations. The base URL is prepended to all meta.location values and +// Location headers. For example, "https://example.com/v2". +func WithBaseURL(baseURL string) ServerOption { + return func(s *Server) { + s.baseURL = strings.TrimRight(baseURL, "/") + } +} + // WithLogger sets the logger for the server. func WithLogger(logger Logger) ServerOption { return func(s *Server) {