diff --git a/example/client/app/app.go b/example/client/app/app.go index b7f28684..e7be491d 100644 --- a/example/client/app/app.go +++ b/example/client/app/app.go @@ -54,10 +54,12 @@ func main() { return uuid.New().String() } - // register the AuthURLHandler at your preferred path - // the AuthURLHandler creates the auth request and redirects the user to the auth server - // including state handling with secure cookie and the possibility to use PKCE - http.Handle("/login", rp.AuthURLHandler(state, provider)) + // register the AuthURLHandler at your preferred path. + // the AuthURLHandler creates the auth request and redirects the user to the auth server. + // including state handling with secure cookie and the possibility to use PKCE. + // Prompts can optionally be set to inform the server of + // any messages that need to be prompted back to the user. + http.Handle("/login", rp.AuthURLHandler(state, provider, rp.WithPromptURLParam("Welcome back!"))) // for demonstration purposes the returned userinfo response is written as JSON object onto response marshalUserinfo := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp rp.RelyingParty, info oidc.UserInfo) { diff --git a/pkg/client/rp/integration_test.go b/pkg/client/rp/integration_test.go index 6f5f4894..732a4bfd 100644 --- a/pkg/client/rp/integration_test.go +++ b/pkg/client/rp/integration_test.go @@ -75,7 +75,10 @@ func TestRelyingPartySession(t *testing.T) { state := "state-" + strconv.FormatInt(seed.Int63(), 25) capturedW := httptest.NewRecorder() get := httptest.NewRequest("GET", localURL.String(), nil) - rp.AuthURLHandler(func() string { return state }, provider)(capturedW, get) + rp.AuthURLHandler(func() string { return state }, provider, + rp.WithPromptURLParam("Hello, World!", "Goodbye, World!"), + rp.WithURLParam("custom", "param"), + )(capturedW, get) defer func() { if t.Failed() { @@ -84,6 +87,8 @@ func TestRelyingPartySession(t *testing.T) { }() require.GreaterOrEqual(t, capturedW.Code, 200, "captured response code") require.Less(t, capturedW.Code, 400, "captured response code") + require.Contains(t, capturedW.Body.String(), `prompt=Hello%2C+World%21+Goodbye%2C+World%21`) + require.Contains(t, capturedW.Body.String(), `custom=param`) //nolint:bodyclose resp := capturedW.Result() @@ -140,7 +145,7 @@ func TestRelyingPartySession(t *testing.T) { email = info.GetEmail() http.Redirect(w, r, targetURL, 302) } - rp.CodeExchangeHandler(rp.UserinfoCallback(redirect), provider)(capturedW, get) + rp.CodeExchangeHandler(rp.UserinfoCallback(redirect), provider, rp.WithURLParam("custom", "param"))(capturedW, get) defer func() { if t.Failed() { @@ -150,6 +155,7 @@ func TestRelyingPartySession(t *testing.T) { }() require.Less(t, capturedW.Code, 400, "token exchange response code") require.Less(t, capturedW.Code, 400, "token exchange response code") + // TODO: how to check the custom header was sent to the server? //nolint:bodyclose resp = capturedW.Result() @@ -193,6 +199,13 @@ func TestRelyingPartySession(t *testing.T) { _, err = rp.RefreshAccessToken(provider, newTokens.RefreshToken, "", "") assert.Errorf(t, err, "refresh with replacement") } + + t.Run("WithPrompt", func(t *testing.T) { + opts := rp.WithPrompt("foo", "bar")() + url := provider.OAuthConfig().AuthCodeURL("some", opts...) + + require.Contains(t, url, "prompt=foo+bar") + }) } type deferredHandler struct { diff --git a/pkg/client/rp/relying_party.go b/pkg/client/rp/relying_party.go index 86b65daf..37586010 100644 --- a/pkg/client/rp/relying_party.go +++ b/pkg/client/rp/relying_party.go @@ -255,7 +255,7 @@ func WithVerifierOpts(opts ...VerifierOption) Option { // WithClientKey specifies the path to the key.json to be used for the JWT Profile Client Authentication on the token endpoint // -//deprecated: use WithJWTProfile(SignerFromKeyPath(path)) instead +// deprecated: use WithJWTProfile(SignerFromKeyPath(path)) instead func WithClientKey(path string) Option { return WithJWTProfile(SignerFromKeyPath(path)) } @@ -304,7 +304,7 @@ func SignerFromKeyAndKeyID(key []byte, keyID string) SignerFromKey { // Discover calls the discovery endpoint of the provided issuer and returns the found endpoints // -//deprecated: use client.Discover +// deprecated: use client.Discover func Discover(issuer string, httpClient *http.Client) (Endpoints, error) { wellKnown := strings.TrimSuffix(issuer, "/") + oidc.DiscoveryEndpoint req, err := http.NewRequest("GET", wellKnown, nil) @@ -323,7 +323,7 @@ func Discover(issuer string, httpClient *http.Client) (Endpoints, error) { } // AuthURL returns the auth request url -//(wrapping the oauth2 `AuthCodeURL`) +// (wrapping the oauth2 `AuthCodeURL`) func AuthURL(state string, rp RelyingParty, opts ...AuthURLOpt) string { authOpts := make([]oauth2.AuthCodeOption, 0) for _, opt := range opts { @@ -333,10 +333,15 @@ func AuthURL(state string, rp RelyingParty, opts ...AuthURLOpt) string { } // AuthURLHandler extends the `AuthURL` method with a http redirect handler -// including handling setting cookie for secure `state` transfer -func AuthURLHandler(stateFn func() string, rp RelyingParty) http.HandlerFunc { +// including handling setting cookie for secure `state` transfer. +// Custom paramaters can optionally be set to the redirect URL. +func AuthURLHandler(stateFn func() string, rp RelyingParty, urlParam ...URLParamOpt) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - opts := make([]AuthURLOpt, 0) + opts := make([]AuthURLOpt, len(urlParam)) + for i, p := range urlParam { + opts[i] = AuthURLOpt(p) + } + state := stateFn() if err := trySetStateCookie(w, state, rp); err != nil { http.Error(w, "failed to create state cookie: "+err.Error(), http.StatusUnauthorized) @@ -350,6 +355,7 @@ func AuthURLHandler(stateFn func() string, rp RelyingParty) http.HandlerFunc { } opts = append(opts, WithCodeChallenge(codeChallenge)) } + http.Redirect(w, r, AuthURL(state, rp, opts...), http.StatusFound) } } @@ -398,8 +404,9 @@ type CodeExchangeCallback func(w http.ResponseWriter, r *http.Request, tokens *o // CodeExchangeHandler extends the `CodeExchange` method with a http handler // including cookie handling for secure `state` transfer -// and optional PKCE code verifier checking -func CodeExchangeHandler(callback CodeExchangeCallback, rp RelyingParty) http.HandlerFunc { +// and optional PKCE code verifier checking. +// Custom paramaters can optionally be set to the token URL. +func CodeExchangeHandler(callback CodeExchangeCallback, rp RelyingParty, urlParam ...URLParamOpt) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { state, err := tryReadStateCookie(w, r, rp) if err != nil { @@ -411,7 +418,11 @@ func CodeExchangeHandler(callback CodeExchangeCallback, rp RelyingParty) http.Ha rp.ErrorHandler()(w, r, params.Get("error"), params.Get("error_description"), state) return } - codeOpts := make([]CodeExchangeOpt, 0) + codeOpts := make([]CodeExchangeOpt, len(urlParam)) + for i, p := range urlParam { + codeOpts[i] = CodeExchangeOpt(p) + } + if rp.IsPKCE() { codeVerifier, err := rp.CookieHandler().CheckCookie(r, pkceCode) if err != nil { @@ -517,6 +528,37 @@ func GetEndpoints(discoveryConfig *oidc.DiscoveryConfiguration) Endpoints { } } +// withURLParam sets custom url paramaters. +// This is the generalized, unexported, function used by both +// URLParamOpt and AuthURLOpt. +func withURLParam(key, value string) func() []oauth2.AuthCodeOption { + return func() []oauth2.AuthCodeOption { + return []oauth2.AuthCodeOption{ + oauth2.SetAuthURLParam(key, value), + } + } +} + +// withPrompt sets the `prompt` params in the auth request +// This is the generalized, unexported, function used by both +// URLParamOpt and AuthURLOpt. +func withPrompt(prompt ...string) func() []oauth2.AuthCodeOption { + return withURLParam("prompt", oidc.SpaceDelimitedArray(prompt).Encode()) +} + +type URLParamOpt func() []oauth2.AuthCodeOption + +// WithURLParam allows setting custom key-vale pairs +// to an OAuth2 URL. +func WithURLParam(key, value string) URLParamOpt { + return withURLParam(key, value) +} + +// WithPromptURLParam sets the `prompt` parameter in a URL. +func WithPromptURLParam(prompt ...string) URLParamOpt { + return withPrompt(prompt...) +} + type AuthURLOpt func() []oauth2.AuthCodeOption // WithCodeChallenge sets the `code_challenge` params in the auth request @@ -531,11 +573,7 @@ func WithCodeChallenge(codeChallenge string) AuthURLOpt { // WithPrompt sets the `prompt` params in the auth request func WithPrompt(prompt ...string) AuthURLOpt { - return func() []oauth2.AuthCodeOption { - return []oauth2.AuthCodeOption{ - oauth2.SetAuthURLParam("prompt", oidc.SpaceDelimitedArray(prompt).Encode()), - } - } + return withPrompt(prompt...) } type CodeExchangeOpt func() []oauth2.AuthCodeOption