Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ COPY . .

RUN make release-binary

FROM alpine:3.15.0 AS gomplate
FROM alpine:3.15.1 AS gomplate

ARG TARGETOS
ARG TARGETARCH
Expand All @@ -33,7 +33,7 @@ RUN wget -O /usr/local/bin/gomplate \
&& chmod +x /usr/local/bin/gomplate


FROM alpine:3.15.0
FROM alpine:3.15.1

# Dex connectors, such as GitHub and Google logins require root certificates.
# Proper installations should manage those certificates, but it's a bad user
Expand Down
39 changes: 35 additions & 4 deletions connector/openshift/openshift.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ import (
"github.com/dexidp/dex/storage/kubernetes/k8sapi"
)

const (
wellKnownURLPath = "/.well-known/oauth-authorization-server"
usersURLPath = "/apis/user.openshift.io/v1/users/~"
)

// Config holds configuration options for OpenShift login
type Config struct {
Issuer string `json:"issuer"`
Expand All @@ -32,7 +37,10 @@ type Config struct {
RootCA string `json:"rootCA"`
}

var _ connector.CallbackConnector = (*openshiftConnector)(nil)
var (
_ connector.CallbackConnector = (*openshiftConnector)(nil)
_ connector.RefreshConnector = (*openshiftConnector)(nil)
)

type openshiftConnector struct {
apiURL string
Expand Down Expand Up @@ -61,7 +69,7 @@ type user struct {
func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, err error) {
ctx, cancel := context.WithCancel(context.Background())

wellKnownURL := strings.TrimSuffix(c.Issuer, "/") + "/.well-known/oauth-authorization-server"
wellKnownURL := strings.TrimSuffix(c.Issuer, "/") + wellKnownURLPath
req, err := http.NewRequest(http.MethodGet, wellKnownURL, nil)

openshiftConnector := openshiftConnector{
Expand Down Expand Up @@ -154,8 +162,23 @@ func (c *openshiftConnector) HandleCallback(s connector.Scopes, r *http.Request)
return identity, fmt.Errorf("oidc: failed to get token: %v", err)
}

client := c.oauth2Config.Client(ctx, token)
return c.identity(ctx, s, token)
}

func (c *openshiftConnector) Refresh(ctx context.Context, s connector.Scopes, oldID connector.Identity) (connector.Identity, error) {
var token oauth2.Token
err := json.Unmarshal(oldID.ConnectorData, &token)
if err != nil {
return connector.Identity{}, fmt.Errorf("parsing token: %w", err)
}
if c.httpClient != nil {
ctx = context.WithValue(ctx, oauth2.HTTPClient, c.httpClient)
}
return c.identity(ctx, s, &token)
}

func (c *openshiftConnector) identity(ctx context.Context, s connector.Scopes, token *oauth2.Token) (identity connector.Identity, err error) {
client := c.oauth2Config.Client(ctx, token)
user, err := c.user(ctx, client)
if err != nil {
return identity, fmt.Errorf("openshift: get user: %v", err)
Expand All @@ -177,12 +200,20 @@ func (c *openshiftConnector) HandleCallback(s connector.Scopes, r *http.Request)
Groups: user.Groups,
}

if s.OfflineAccess {
connData, err := json.Marshal(token)
if err != nil {
return identity, fmt.Errorf("marshal connector data: %v", err)
}
identity.ConnectorData = connData
}

return identity, nil
}

// user function returns the OpenShift user associated with the authenticated user
func (c *openshiftConnector) user(ctx context.Context, client *http.Client) (u user, err error) {
url := c.apiURL + "/apis/user.openshift.io/v1/users/~"
url := c.apiURL + usersURLPath

req, err := http.NewRequest("GET", url, nil)
if err != nil {
Expand Down
79 changes: 79 additions & 0 deletions connector/openshift/openshift_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"net/url"
"reflect"
"testing"
"time"

"github.com/sirupsen/logrus"
"golang.org/x/oauth2"
Expand Down Expand Up @@ -184,6 +185,78 @@ func TestCallbackIdentity(t *testing.T) {
expectEquals(t, identity.Groups[0], "users")
}

func TestRefreshIdentity(t *testing.T) {
s := newTestServer(map[string]interface{}{
usersURLPath: user{
ObjectMeta: k8sapi.ObjectMeta{
Name: "jdoe",
UID: "12345",
},
FullName: "John Doe",
Groups: []string{"users"},
},
})
defer s.Close()

h, err := newHTTPClient(true, "")
expectNil(t, err)

oc := openshiftConnector{apiURL: s.URL, httpClient: h, oauth2Config: &oauth2.Config{
Endpoint: oauth2.Endpoint{
AuthURL: fmt.Sprintf("%s/oauth/authorize", s.URL),
TokenURL: fmt.Sprintf("%s/oauth/token", s.URL),
},
}}

data, err := json.Marshal(oauth2.Token{AccessToken: "fFAGRNJru1FTz70BzhT3Zg"})
expectNil(t, err)

oldID := connector.Identity{ConnectorData: data}

identity, err := oc.Refresh(context.Background(), connector.Scopes{Groups: true}, oldID)

expectNil(t, err)
expectEquals(t, identity.UserID, "12345")
expectEquals(t, identity.Username, "jdoe")
expectEquals(t, identity.PreferredUsername, "jdoe")
expectEquals(t, identity.Email, "jdoe")
expectEquals(t, len(identity.Groups), 1)
expectEquals(t, identity.Groups[0], "users")
}

func TestRefreshIdentityFailure(t *testing.T) {
s := newTestServer(map[string]interface{}{
usersURLPath: user{
ObjectMeta: k8sapi.ObjectMeta{
Name: "jdoe",
UID: "12345",
},
FullName: "John Doe",
Groups: []string{"users"},
},
})
defer s.Close()

h, err := newHTTPClient(true, "")
expectNil(t, err)

oc := openshiftConnector{apiURL: s.URL, httpClient: h, oauth2Config: &oauth2.Config{
Endpoint: oauth2.Endpoint{
AuthURL: fmt.Sprintf("%s/oauth/authorize", s.URL),
TokenURL: fmt.Sprintf("%s/oauth/token", s.URL),
},
}}

data, err := json.Marshal(oauth2.Token{AccessToken: "oRzxVjCnohYRHEYEhZshkmakKmoyVoTjfUGC", Expiry: time.Now().Add(-time.Hour)})
expectNil(t, err)

oldID := connector.Identity{ConnectorData: data}

identity, err := oc.Refresh(context.Background(), connector.Scopes{Groups: true}, oldID)
expectNotNil(t, err)
expectEquals(t, connector.Identity{}, identity)
}

func newTestServer(responses map[string]interface{}) *httptest.Server {
var s *httptest.Server
s = httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -216,3 +289,9 @@ func expectEquals(t *testing.T, a interface{}, b interface{}) {
t.Errorf("Expected %+v to equal %+v", a, b)
}
}

func expectNotNil(t *testing.T, a interface{}) {
if a == nil {
t.Errorf("Expected %+v to not equal nil", a)
}
}