From 02d7791617df573d817897488fdee64ae474cdbd Mon Sep 17 00:00:00 2001 From: Suleiman Dibirov Date: Tue, 15 Sep 2020 18:47:30 +0300 Subject: [PATCH 1/9] Add oidc pkce support --- README.md | 1 + go.mod | 1 + go.sum | 2 ++ internal/provider/oidc.go | 23 +++++++++++++++++++---- internal/provider/oidc_test.go | 28 ++++++++++++++++++++++++++++ internal/provider/providers.go | 8 ++++---- 6 files changed, 55 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 54cb5798..ad39c608 100644 --- a/README.md +++ b/README.md @@ -174,6 +174,7 @@ OIDC Provider: --providers.oidc.client-id= Client ID [$PROVIDERS_OIDC_CLIENT_ID] --providers.oidc.client-secret= Client Secret [$PROVIDERS_OIDC_CLIENT_SECRET] --providers.oidc.resource= Optional resource indicator [$PROVIDERS_OIDC_RESOURCE] + --providers.oidc.pkce_required= Optional pkce required indicator [$PROVIDERS_OIDC_PKCE_REQUIRED] Generic OAuth2 Provider: --providers.generic-oauth.auth-url= Auth/Login URL [$PROVIDERS_GENERIC_OAUTH_AUTH_URL] diff --git a/go.mod b/go.mod index 2c6eb8a8..fb7d08b3 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.13 require ( github.com/containous/traefik/v2 v2.1.2 github.com/coreos/go-oidc v2.1.0+incompatible + github.com/nirasan/go-oauth-pkce-code-verifier v0.0.0-20170819232839-0fbfe93532da github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect github.com/sirupsen/logrus v1.4.2 github.com/stretchr/testify v1.4.0 diff --git a/go.sum b/go.sum index fd583712..e4531f57 100644 --- a/go.sum +++ b/go.sum @@ -281,6 +281,8 @@ github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRW github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw= github.com/namedotcom/go v0.0.0-20180403034216-08470befbe04/go.mod h1:5sN+Lt1CaY4wsPvgQH/jsuJi4XO2ssZbdsIizr4CVC8= github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32/go.mod h1:9wM+0iRr9ahx58uYLpLIr5fm8diHn0JbqRycJi6w0Ms= +github.com/nirasan/go-oauth-pkce-code-verifier v0.0.0-20170819232839-0fbfe93532da h1:qiPWuGGr+1GQE6s9NPSK8iggR/6x/V+0snIoOPYsBgc= +github.com/nirasan/go-oauth-pkce-code-verifier v0.0.0-20170819232839-0fbfe93532da/go.mod h1:DvuJJ/w1Y59rG8UTDxsMk5U+UJXJwuvUgbiJSm9yhX8= github.com/nrdcg/auroradns v1.0.0/go.mod h1:6JPXKzIRzZzMqtTDgueIhTi6rFf1QvYE/HzqidhOhjw= github.com/nrdcg/goinwx v0.6.1/go.mod h1:XPiut7enlbEdntAqalBIqcYcTEVhpv/dKWgDCX2SwKQ= github.com/nrdcg/namesilo v0.2.1/go.mod h1:lwMvfQTyYq+BbjJd30ylEG4GPSS6PII0Tia4rRpRiyw= diff --git a/internal/provider/oidc.go b/internal/provider/oidc.go index 5e17a580..f403be1f 100644 --- a/internal/provider/oidc.go +++ b/internal/provider/oidc.go @@ -3,6 +3,7 @@ package provider import ( "context" "errors" + pkce "github.com/nirasan/go-oauth-pkce-code-verifier" "github.com/coreos/go-oidc" "golang.org/x/oauth2" @@ -13,11 +14,13 @@ type OIDC struct { IssuerURL string `long:"issuer-url" env:"ISSUER_URL" description:"Issuer URL"` ClientID string `long:"client-id" env:"CLIENT_ID" description:"Client ID"` ClientSecret string `long:"client-secret" env:"CLIENT_SECRET" description:"Client Secret" json:"-"` + PkceRequired bool `long:"pkce-required" env:"PKCE_REQUIRED" description:"Optional pkce required indicator"` OAuthProvider - provider *oidc.Provider - verifier *oidc.IDTokenVerifier + provider *oidc.Provider + verifier *oidc.IDTokenVerifier + pkceVerifier *pkce.CodeVerifier } // Name returns the name of the provider @@ -61,12 +64,24 @@ func (o *OIDC) Setup() error { // GetLoginURL provides the login url for the given redirect uri and state func (o *OIDC) GetLoginURL(redirectURI, state string) string { - return o.OAuthGetLoginURL(redirectURI, state) + var opts []oauth2.AuthCodeOption + if o.PkceRequired { + o.pkceVerifier, _ = pkce.CreateCodeVerifier() + opts = append(opts, oauth2.SetAuthURLParam("code_challenge_method", "S256")) + opts = append(opts, oauth2.SetAuthURLParam("code_challenge", o.pkceVerifier.CodeChallengeS256())) + } + return o.OAuthGetLoginURL(redirectURI, state, opts...) } // ExchangeCode exchanges the given redirect uri and code for a token func (o *OIDC) ExchangeCode(redirectURI, code string) (string, error) { - token, err := o.OAuthExchangeCode(redirectURI, code) + var opts []oauth2.AuthCodeOption + + if o.PkceRequired { + opts = append(opts, oauth2.SetAuthURLParam("code_verifier", o.pkceVerifier.String())) + } + + token, err := o.OAuthExchangeCode(redirectURI, code, opts...) if err != nil { return "", err } diff --git a/internal/provider/oidc_test.go b/internal/provider/oidc_test.go index d514d37c..d7044c54 100644 --- a/internal/provider/oidc_test.go +++ b/internal/provider/oidc_test.go @@ -60,6 +60,34 @@ func TestOIDCGetLoginURL(t *testing.T) { // Calling the method should not modify the underlying config assert.Equal("", provider.Config.RedirectURL) + // + // Test with PkceRequired config option + // + provider.PkceRequired = true + + // Check url + uri, err = url.Parse(provider.GetLoginURL("http://example.com/_oauth", "state")) + assert.Nil(err) + assert.Equal(serverURL.Scheme, uri.Scheme) + assert.Equal(serverURL.Host, uri.Host) + assert.Equal("/auth", uri.Path) + + // Check query string + qs = uri.Query() + expectedQs = url.Values{ + "client_id": []string{"idtest"}, + "code_challenge": []string{provider.pkceVerifier.CodeChallengeS256()}, + "code_challenge_method": []string{"S256"}, + "redirect_uri": []string{"http://example.com/_oauth"}, + "response_type": []string{"code"}, + "scope": []string{"openid profile email"}, + "state": []string{"state"}, + } + assert.Equal(expectedQs, qs) + + // Calling the method should not modify the underlying config + assert.Equal("", provider.Config.RedirectURL) + // // Test with resource config option // diff --git a/internal/provider/providers.go b/internal/provider/providers.go index ac863df3..3e0d6bd1 100644 --- a/internal/provider/providers.go +++ b/internal/provider/providers.go @@ -49,18 +49,18 @@ func (p *OAuthProvider) ConfigCopy(redirectURI string) oauth2.Config { } // OAuthGetLoginURL provides a base "GetLoginURL" for proiders using OAauth2 -func (p *OAuthProvider) OAuthGetLoginURL(redirectURI, state string) string { +func (p *OAuthProvider) OAuthGetLoginURL(redirectURI, state string, opts ...oauth2.AuthCodeOption) string { config := p.ConfigCopy(redirectURI) if p.Resource != "" { return config.AuthCodeURL(state, oauth2.SetAuthURLParam("resource", p.Resource)) } - return config.AuthCodeURL(state) + return config.AuthCodeURL(state, opts...) } // OAuthExchangeCode provides a base "ExchangeCode" for proiders using OAauth2 -func (p *OAuthProvider) OAuthExchangeCode(redirectURI, code string) (*oauth2.Token, error) { +func (p *OAuthProvider) OAuthExchangeCode(redirectURI, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) { config := p.ConfigCopy(redirectURI) - return config.Exchange(p.ctx, code) + return config.Exchange(p.ctx, code, opts...) } From e9d57fb7f4a621dcc95dfa4026a03585b4750aee Mon Sep 17 00:00:00 2001 From: Suleiman Dibirov Date: Wed, 16 Sep 2020 10:35:24 +0300 Subject: [PATCH 2/9] Remove oidc client secret params if pkce required --- internal/provider/oidc.go | 10 +++++++--- internal/provider/oidc_test.go | 10 +++++++++- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/internal/provider/oidc.go b/internal/provider/oidc.go index f403be1f..f1719c80 100644 --- a/internal/provider/oidc.go +++ b/internal/provider/oidc.go @@ -30,9 +30,13 @@ func (o *OIDC) Name() string { // Setup performs validation and setup func (o *OIDC) Setup() error { - // Check parms - if o.IssuerURL == "" || o.ClientID == "" || o.ClientSecret == "" { - return errors.New("providers.oidc.issuer-url, providers.oidc.client-id, providers.oidc.client-secret must be set") + // Check params + if o.IssuerURL == "" || o.ClientID == "" { + return errors.New("providers.oidc.issuer-url, providers.oidc.client-id must be set") + } + + if o.ClientSecret == "" && !o.PkceRequired { + return errors.New("providers.oidc.client-secret must be set if pkce not required") } var err error diff --git a/internal/provider/oidc_test.go b/internal/provider/oidc_test.go index d7044c54..2d7e14ad 100644 --- a/internal/provider/oidc_test.go +++ b/internal/provider/oidc_test.go @@ -29,7 +29,15 @@ func TestOIDCSetup(t *testing.T) { err := p.Setup() if assert.Error(err) { - assert.Equal("providers.oidc.issuer-url, providers.oidc.client-id, providers.oidc.client-secret must be set", err.Error()) + assert.Equal("providers.oidc.issuer-url, providers.oidc.client-id must be set", err.Error()) + } + + p.IssuerURL = "test" + p.ClientID = "test" + + err = p.Setup() + if assert.Error(err) { + assert.Equal("providers.oidc.client-secret must be set if pkce not required", err.Error()) } } From 37aa42686444cee3aed2b5236734981070196537 Mon Sep 17 00:00:00 2001 From: Suleiman Dibirov Date: Thu, 17 Sep 2020 15:42:19 +0300 Subject: [PATCH 3/9] Add oidc checkParams func --- internal/provider/oidc.go | 31 +++++++++++++++++++++++++------ internal/provider/oidc_test.go | 14 ++++++++++---- 2 files changed, 35 insertions(+), 10 deletions(-) diff --git a/internal/provider/oidc.go b/internal/provider/oidc.go index f1719c80..26f3ff54 100644 --- a/internal/provider/oidc.go +++ b/internal/provider/oidc.go @@ -4,6 +4,7 @@ import ( "context" "errors" pkce "github.com/nirasan/go-oauth-pkce-code-verifier" + "strings" "github.com/coreos/go-oidc" "golang.org/x/oauth2" @@ -31,12 +32,8 @@ func (o *OIDC) Name() string { // Setup performs validation and setup func (o *OIDC) Setup() error { // Check params - if o.IssuerURL == "" || o.ClientID == "" { - return errors.New("providers.oidc.issuer-url, providers.oidc.client-id must be set") - } - - if o.ClientSecret == "" && !o.PkceRequired { - return errors.New("providers.oidc.client-secret must be set if pkce not required") + if err := o.checkParams(); err != nil { + return err } var err error @@ -116,3 +113,25 @@ func (o *OIDC) GetUser(token string) (User, error) { return user, nil } + +func (o *OIDC) checkParams() error { + if o.IssuerURL == "" || o.ClientID == "" || (o.ClientSecret == "" && !o.PkceRequired) { + var emptyFields []string + + if o.IssuerURL == "" { + emptyFields = append(emptyFields, "providers.oidc.issuer-url") + } + + if o.ClientID == "" { + emptyFields = append(emptyFields, "providers.oidc.client-id") + } + + if o.ClientSecret == "" && !o.PkceRequired { + emptyFields = append(emptyFields, "providers.oidc.client-secret") + } + + return errors.New(strings.Join(emptyFields, ", ") + " must be set") + } + + return nil +} diff --git a/internal/provider/oidc_test.go b/internal/provider/oidc_test.go index 2d7e14ad..5c9dceaa 100644 --- a/internal/provider/oidc_test.go +++ b/internal/provider/oidc_test.go @@ -29,15 +29,21 @@ func TestOIDCSetup(t *testing.T) { err := p.Setup() if assert.Error(err) { - assert.Equal("providers.oidc.issuer-url, providers.oidc.client-id must be set", err.Error()) + assert.Equal("providers.oidc.issuer-url, providers.oidc.client-id, providers.oidc.client-secret must be set", err.Error()) } - p.IssuerURL = "test" - p.ClientID = "test" + p.IssuerURL = "url" err = p.Setup() if assert.Error(err) { - assert.Equal("providers.oidc.client-secret must be set if pkce not required", err.Error()) + assert.Equal("providers.oidc.client-id, providers.oidc.client-secret must be set", err.Error()) + } + + p.ClientID = "id" + + err = p.Setup() + if assert.Error(err) { + assert.Equal("providers.oidc.client-secret must be set", err.Error()) } } From 18c28ba4106939ce9e8fcc9d5c25442246f23651 Mon Sep 17 00:00:00 2001 From: Suleiman Dibirov Date: Thu, 17 Sep 2020 16:08:08 +0300 Subject: [PATCH 4/9] Add pkce verifier.go and remove external dependency --- go.mod | 1 - go.sum | 2 -- internal/pkce/verifier.go | 50 +++++++++++++++++++++++++++++++++++++++ internal/provider/oidc.go | 4 ++-- 4 files changed, 52 insertions(+), 5 deletions(-) create mode 100644 internal/pkce/verifier.go diff --git a/go.mod b/go.mod index fb7d08b3..2c6eb8a8 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,6 @@ go 1.13 require ( github.com/containous/traefik/v2 v2.1.2 github.com/coreos/go-oidc v2.1.0+incompatible - github.com/nirasan/go-oauth-pkce-code-verifier v0.0.0-20170819232839-0fbfe93532da github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect github.com/sirupsen/logrus v1.4.2 github.com/stretchr/testify v1.4.0 diff --git a/go.sum b/go.sum index e4531f57..fd583712 100644 --- a/go.sum +++ b/go.sum @@ -281,8 +281,6 @@ github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRW github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw= github.com/namedotcom/go v0.0.0-20180403034216-08470befbe04/go.mod h1:5sN+Lt1CaY4wsPvgQH/jsuJi4XO2ssZbdsIizr4CVC8= github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32/go.mod h1:9wM+0iRr9ahx58uYLpLIr5fm8diHn0JbqRycJi6w0Ms= -github.com/nirasan/go-oauth-pkce-code-verifier v0.0.0-20170819232839-0fbfe93532da h1:qiPWuGGr+1GQE6s9NPSK8iggR/6x/V+0snIoOPYsBgc= -github.com/nirasan/go-oauth-pkce-code-verifier v0.0.0-20170819232839-0fbfe93532da/go.mod h1:DvuJJ/w1Y59rG8UTDxsMk5U+UJXJwuvUgbiJSm9yhX8= github.com/nrdcg/auroradns v1.0.0/go.mod h1:6JPXKzIRzZzMqtTDgueIhTi6rFf1QvYE/HzqidhOhjw= github.com/nrdcg/goinwx v0.6.1/go.mod h1:XPiut7enlbEdntAqalBIqcYcTEVhpv/dKWgDCX2SwKQ= github.com/nrdcg/namesilo v0.2.1/go.mod h1:lwMvfQTyYq+BbjJd30ylEG4GPSS6PII0Tia4rRpRiyw= diff --git a/internal/pkce/verifier.go b/internal/pkce/verifier.go new file mode 100644 index 00000000..97c7efb0 --- /dev/null +++ b/internal/pkce/verifier.go @@ -0,0 +1,50 @@ +package pkce + +import ( + "crypto/sha256" + "encoding/base64" + "math/rand" + "strings" + "time" +) + +type CodeVerifier struct { + Value string +} + +func CreateCodeVerifier() *CodeVerifier { + return &CodeVerifier{ + Value: encode([]byte(randomString(32))), + } +} + +func (v *CodeVerifier) String() string { + return v.Value +} + +func (v *CodeVerifier) CodeChallengeS256() string { + h := sha256.New() + h.Write([]byte(v.Value)) + + return encode(h.Sum(nil)) +} + +func randomString(n int) string { + var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") + rand.Seed(time.Now().UnixNano()) + + b := make([]rune, n) + for i := range b { + b[i] = letters[rand.Intn(len(letters))] + } + + return string(b) +} + +func encode(msg []byte) string { + encoded := base64.StdEncoding.EncodeToString(msg) + encoded = strings.Replace(encoded, "+", "-", -1) + encoded = strings.Replace(encoded, "/", "_", -1) + encoded = strings.Replace(encoded, "=", "", -1) + return encoded +} diff --git a/internal/provider/oidc.go b/internal/provider/oidc.go index 26f3ff54..480a8aab 100644 --- a/internal/provider/oidc.go +++ b/internal/provider/oidc.go @@ -3,7 +3,7 @@ package provider import ( "context" "errors" - pkce "github.com/nirasan/go-oauth-pkce-code-verifier" + "github.com/thomseddon/traefik-forward-auth/internal/pkce" "strings" "github.com/coreos/go-oidc" @@ -67,7 +67,7 @@ func (o *OIDC) Setup() error { func (o *OIDC) GetLoginURL(redirectURI, state string) string { var opts []oauth2.AuthCodeOption if o.PkceRequired { - o.pkceVerifier, _ = pkce.CreateCodeVerifier() + o.pkceVerifier = pkce.CreateCodeVerifier() opts = append(opts, oauth2.SetAuthURLParam("code_challenge_method", "S256")) opts = append(opts, oauth2.SetAuthURLParam("code_challenge", o.pkceVerifier.CodeChallengeS256())) } From 41560feaa75b1a37094d90dbd0628809c63121c4 Mon Sep 17 00:00:00 2001 From: Thom Seddon Date: Wed, 23 Sep 2020 14:48:04 +0100 Subject: [PATCH 5/9] Support concurrent CSRF cookies by using a prefix of nonce (#187) * Support concurrent CSRF cookies by using a prefix of nonce. * Move ValidateState out and make CSRF cookies last 1h * add tests to check csrf cookie nam + minor tweaks Co-authored-by: Michal Witkowski --- internal/auth.go | 38 ++++++++++++++++------- internal/auth_test.go | 67 +++++++++++++++++++++-------------------- internal/server.go | 18 ++++++++--- internal/server_test.go | 2 +- 4 files changed, 76 insertions(+), 49 deletions(-) diff --git a/internal/auth.go b/internal/auth.go index cefd2315..64a2dc10 100644 --- a/internal/auth.go +++ b/internal/auth.go @@ -170,23 +170,31 @@ func ClearCookie(r *http.Request) *http.Cookie { } } +func buildCSRFCookieName(nonce string) string { + return config.CSRFCookieName + "_" + nonce[:6] +} + // MakeCSRFCookie makes a csrf cookie (used during login only) +// +// Note, CSRF cookies live shorter than auth cookies, a fixed 1h. +// That's because some CSRF cookies may belong to auth flows that don't complete +// and thus may not get cleared by ClearCookie. func MakeCSRFCookie(r *http.Request, nonce string) *http.Cookie { return &http.Cookie{ - Name: config.CSRFCookieName, + Name: buildCSRFCookieName(nonce), Value: nonce, Path: "/", Domain: csrfCookieDomain(r), HttpOnly: true, Secure: !config.InsecureCookie, - Expires: cookieExpiry(), + Expires: time.Now().Local().Add(time.Hour * 1), } } // ClearCSRFCookie makes an expired csrf cookie to clear csrf cookie -func ClearCSRFCookie(r *http.Request) *http.Cookie { +func ClearCSRFCookie(r *http.Request, c *http.Cookie) *http.Cookie { return &http.Cookie{ - Name: config.CSRFCookieName, + Name: c.Name, Value: "", Path: "/", Domain: csrfCookieDomain(r), @@ -196,18 +204,18 @@ func ClearCSRFCookie(r *http.Request) *http.Cookie { } } -// ValidateCSRFCookie validates the csrf cookie against state -func ValidateCSRFCookie(r *http.Request, c *http.Cookie) (valid bool, provider string, redirect string, err error) { - state := r.URL.Query().Get("state") +// FindCSRFCookie extracts the CSRF cookie from the request based on state. +func FindCSRFCookie(r *http.Request, state string) (c *http.Cookie, err error) { + // Check for CSRF cookie + return r.Cookie(buildCSRFCookieName(state)) +} +// ValidateCSRFCookie validates the csrf cookie against state +func ValidateCSRFCookie(c *http.Cookie, state string) (valid bool, provider string, redirect string, err error) { if len(c.Value) != 32 { return false, "", "", errors.New("Invalid CSRF cookie value") } - if len(state) < 34 { - return false, "", "", errors.New("Invalid CSRF state value") - } - // Check nonce match if c.Value != state[:32] { return false, "", "", errors.New("CSRF cookie does not match state") @@ -229,6 +237,14 @@ func MakeState(r *http.Request, p provider.Provider, nonce string) string { return fmt.Sprintf("%s:%s:%s", nonce, p.Name(), returnUrl(r)) } +// ValidateState checks whether the state is of right length. +func ValidateState(state string) error { + if len(state) < 34 { + return errors.New("Invalid CSRF state value") + } + return nil +} + // Nonce generates a random nonce func Nonce() (error, string) { nonce := make([]byte, 16) diff --git a/internal/auth_test.go b/internal/auth_test.go index 14ee1ce6..d013dfc5 100644 --- a/internal/auth_test.go +++ b/internal/auth_test.go @@ -1,7 +1,6 @@ package tfa import ( - "fmt" "net/http" "net/url" "strings" @@ -217,29 +216,30 @@ func TestAuthMakeCSRFCookie(t *testing.T) { // No cookie domain or auth url c := MakeCSRFCookie(r, "12345678901234567890123456789012") + assert.Equal("_forward_auth_csrf_123456", c.Name) assert.Equal("app.example.com", c.Domain) // With cookie domain but no auth url - config = &Config{ - CookieDomains: []CookieDomain{*NewCookieDomain("example.com")}, - } - c = MakeCSRFCookie(r, "12345678901234567890123456789012") + config.CookieDomains = []CookieDomain{*NewCookieDomain("example.com")} + c = MakeCSRFCookie(r, "12222278901234567890123456789012") + assert.Equal("_forward_auth_csrf_122222", c.Name) assert.Equal("app.example.com", c.Domain) // With cookie domain and auth url - config = &Config{ - AuthHost: "auth.example.com", - CookieDomains: []CookieDomain{*NewCookieDomain("example.com")}, - } - c = MakeCSRFCookie(r, "12345678901234567890123456789012") + config.AuthHost = "auth.example.com" + config.CookieDomains = []CookieDomain{*NewCookieDomain("example.com")} + c = MakeCSRFCookie(r, "12333378901234567890123456789012") + assert.Equal("_forward_auth_csrf_123333", c.Name) assert.Equal("example.com", c.Domain) } func TestAuthClearCSRFCookie(t *testing.T) { + assert := assert.New(t) config, _ = NewConfig([]string{}) r, _ := http.NewRequest("GET", "http://example.com", nil) - c := ClearCSRFCookie(r) + c := ClearCSRFCookie(r, &http.Cookie{Name: "someCsrfCookie"}) + assert.Equal("someCsrfCookie", c.Name) if c.Value != "" { t.Error("ClearCSRFCookie should create cookie with empty value") } @@ -249,56 +249,57 @@ func TestAuthValidateCSRFCookie(t *testing.T) { assert := assert.New(t) config, _ = NewConfig([]string{}) c := &http.Cookie{} - - newCsrfRequest := func(state string) *http.Request { - u := fmt.Sprintf("http://example.com?state=%s", state) - r, _ := http.NewRequest("GET", u, nil) - return r - } + state := "" // Should require 32 char string - r := newCsrfRequest("") + state = "" c.Value = "" - valid, _, _, err := ValidateCSRFCookie(r, c) + valid, _, _, err := ValidateCSRFCookie(c, state) assert.False(valid) if assert.Error(err) { assert.Equal("Invalid CSRF cookie value", err.Error()) } c.Value = "123456789012345678901234567890123" - valid, _, _, err = ValidateCSRFCookie(r, c) + valid, _, _, err = ValidateCSRFCookie(c, state) assert.False(valid) if assert.Error(err) { assert.Equal("Invalid CSRF cookie value", err.Error()) } - // Should require valid state - r = newCsrfRequest("12345678901234567890123456789012:") - c.Value = "12345678901234567890123456789012" - valid, _, _, err = ValidateCSRFCookie(r, c) - assert.False(valid) - if assert.Error(err) { - assert.Equal("Invalid CSRF state value", err.Error()) - } - // Should require provider - r = newCsrfRequest("12345678901234567890123456789012:99") + state = "12345678901234567890123456789012:99" c.Value = "12345678901234567890123456789012" - valid, _, _, err = ValidateCSRFCookie(r, c) + valid, _, _, err = ValidateCSRFCookie(c, state) assert.False(valid) if assert.Error(err) { assert.Equal("Invalid CSRF state format", err.Error()) } // Should allow valid state - r = newCsrfRequest("12345678901234567890123456789012:p99:url123") + state = "12345678901234567890123456789012:p99:url123" c.Value = "12345678901234567890123456789012" - valid, provider, redirect, err := ValidateCSRFCookie(r, c) + valid, provider, redirect, err := ValidateCSRFCookie(c, state) assert.True(valid, "valid request should return valid") assert.Nil(err, "valid request should not return an error") assert.Equal("p99", provider, "valid request should return correct provider") assert.Equal("url123", redirect, "valid request should return correct redirect") } +func TestValidateState(t *testing.T) { + assert := assert.New(t) + + // Should require valid state + state := "12345678901234567890123456789012:" + err := ValidateState(state) + if assert.Error(err) { + assert.Equal("Invalid CSRF state value", err.Error()) + } + // Should pass this state + state = "12345678901234567890123456789012:p99:url123" + err = ValidateState(state) + assert.Nil(err, "valid request should not return an error") +} + func TestMakeState(t *testing.T) { assert := assert.New(t) diff --git a/internal/server.go b/internal/server.go index 8ac03131..ad7e8301 100644 --- a/internal/server.go +++ b/internal/server.go @@ -121,16 +121,26 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc { // Logging setup logger := s.logger(r, "AuthCallback", "default", "Handling callback") + // Check state + state := r.URL.Query().Get("state") + if err := ValidateState(state); err != nil { + logger.WithFields(logrus.Fields{ + "error": err, + }).Warn("Error validating state") + http.Error(w, "Not authorized", 401) + return + } + // Check for CSRF cookie - c, err := r.Cookie(config.CSRFCookieName) + c, err := FindCSRFCookie(r, state) if err != nil { logger.Info("Missing csrf cookie") http.Error(w, "Not authorized", 401) return } - // Validate state - valid, providerName, redirect, err := ValidateCSRFCookie(r, c) + // Validate CSRF cookie against state + valid, providerName, redirect, err := ValidateCSRFCookie(c, state) if !valid { logger.WithFields(logrus.Fields{ "error": err, @@ -153,7 +163,7 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc { } // Clear CSRF cookie - http.SetCookie(w, ClearCSRFCookie(r)) + http.SetCookie(w, ClearCSRFCookie(r, c)) // Exchange code for token token, err := p.ExchangeCode(redirectUri(r), r.URL.Query().Get("code")) diff --git a/internal/server_test.go b/internal/server_test.go index 2e543400..8ec0f01d 100644 --- a/internal/server_test.go +++ b/internal/server_test.go @@ -98,7 +98,7 @@ func TestServerAuthHandlerExpired(t *testing.T) { // Check for CSRF cookie var cookie *http.Cookie for _, c := range res.Cookies() { - if c.Name == config.CSRFCookieName { + if strings.HasPrefix(c.Name, config.CSRFCookieName) { cookie = c } } From 04f5499f0bb38c7dffbda8e244cb125e9ac09e36 Mon Sep 17 00:00:00 2001 From: Thom Seddon Date: Wed, 23 Sep 2020 14:50:15 +0100 Subject: [PATCH 6/9] Allow override of domains and whitelist in rules (#169) Co-authored-by: Mathieu Cantin Co-authored-by: Pete Shaw --- README.md | 9 ++- internal/auth.go | 58 ++++++++++++++------ internal/auth_test.go | 125 +++++++++++++++++++++++++++++++++--------- internal/config.go | 16 +++++- internal/server.go | 2 +- 5 files changed, 164 insertions(+), 46 deletions(-) diff --git a/README.md b/README.md index 54cb5798..b0c5444b 100644 --- a/README.md +++ b/README.md @@ -321,6 +321,7 @@ All options can be supplied in any of the following ways, in the following prece - `action` - same usage as [`default-action`](#default-action), supported values: - `auth` (default) - `allow` + - `domains` - optional, same usage as [`domain`](#domain) - `provider` - same usage as [`default-provider`](#default-provider), supported values: - `google` - `oidc` @@ -333,6 +334,7 @@ All options can be supplied in any of the following ways, in the following prece - ``Path(`path`, `/articles/{category}/{id:[0-9]+}`, ...)`` - ``PathPrefix(`/products/`, `/articles/{category}/{id:[0-9]+}`)`` - ``Query(`foo=bar`, `bar=baz`)`` + - `whitelist` - optional, same usage as whitelist`](#whitelist) For example: ``` @@ -348,6 +350,11 @@ All options can be supplied in any of the following ways, in the following prece rule.oidc.action = auth rule.oidc.provider = oidc rule.oidc.rule = PathPrefix(`/github`) + + # Allow jane@example.com to `/janes-eyes-only` + rule.two.action = allow + rule.two.rule = Path(`/janes-eyes-only`) + rule.two.whitelist = jane@example.com ``` Note: It is possible to break your redirect flow with rules, please be careful not to create an `allow` rule that matches your redirect_uri unless you know what you're doing. This limitation is being tracked in in #101 and the behaviour will change in future releases. @@ -361,7 +368,7 @@ You can restrict who can login with the following parameters: * `domain` - Use this to limit logins to a specific domain, e.g. test.com only * `whitelist` - Use this to only allow specific users to login e.g. thom@test.com only -Note, if you pass both `whitelist` and `domain`, then the default behaviour is for only `whitelist` to be used and `domain` will be effectively ignored. You can allow users matching *either* `whitelist` or `domain` by passing the `match-whitelist-or-domain` parameter (this will be the default behaviour in v3). +Note, if you pass both `whitelist` and `domain`, then the default behaviour is for only `whitelist` to be used and `domain` will be effectively ignored. You can allow users matching *either* `whitelist` or `domain` by passing the `match-whitelist-or-domain` parameter (this will be the default behaviour in v3). If you set `domains` or `whitelist` on a rule, the global configuration is ignored. ### Forwarded Headers diff --git a/internal/auth.go b/internal/auth.go index 64a2dc10..0b0a9676 100644 --- a/internal/auth.go +++ b/internal/auth.go @@ -59,18 +59,28 @@ func ValidateCookie(r *http.Request, c *http.Cookie) (string, error) { // ValidateEmail checks if the given email address matches either a whitelisted // email address, as defined by the "whitelist" config parameter. Or is part of // a permitted domain, as defined by the "domains" config parameter -func ValidateEmail(email string) bool { +func ValidateEmail(email, ruleName string) bool { + // Use global config by default + whitelist := config.Whitelist + domains := config.Domains + + if rule, ok := config.Rules[ruleName]; ok { + // Override with rule config if found + if len(rule.Whitelist) > 0 || len(rule.Domains) > 0 { + whitelist = rule.Whitelist + domains = rule.Domains + } + } + // Do we have any validation to perform? - if len(config.Whitelist) == 0 && len(config.Domains) == 0 { + if len(whitelist) == 0 && len(domains) == 0 { return true } // Email whitelist validation - if len(config.Whitelist) > 0 { - for _, whitelist := range config.Whitelist { - if email == whitelist { - return true - } + if len(whitelist) > 0 { + if ValidateWhitelist(email, whitelist) { + return true } // If we're not matching *either*, stop here @@ -80,18 +90,34 @@ func ValidateEmail(email string) bool { } // Domain validation - if len(config.Domains) > 0 { - parts := strings.Split(email, "@") - if len(parts) < 2 { - return false - } - for _, domain := range config.Domains { - if domain == parts[1] { - return true - } + if len(domains) > 0 && ValidateDomains(email, domains) { + return true + } + + return false +} + +// ValidateWhitelist checks if the email is in whitelist +func ValidateWhitelist(email string, whitelist CommaSeparatedList) bool { + for _, whitelist := range whitelist { + if email == whitelist { + return true } } + return false +} +// ValidateDomains checks if the email matches a whitelisted domain +func ValidateDomains(email string, domains CommaSeparatedList) bool { + parts := strings.Split(email, "@") + if len(parts) < 2 { + return false + } + for _, domain := range domains { + if domain == parts[1] { + return true + } + } return false } diff --git a/internal/auth_test.go b/internal/auth_test.go index d013dfc5..5b0bedaf 100644 --- a/internal/auth_test.go +++ b/internal/auth_test.go @@ -65,32 +65,25 @@ func TestAuthValidateEmail(t *testing.T) { assert := assert.New(t) config, _ = NewConfig([]string{}) - // Should allow any - v := ValidateEmail("test@test.com") + // Should allow any with no whitelist/domain is specified + v := ValidateEmail("test@test.com", "default") assert.True(v, "should allow any domain if email domain is not defined") - v = ValidateEmail("one@two.com") + v = ValidateEmail("one@two.com", "default") assert.True(v, "should allow any domain if email domain is not defined") - // Should block non matching domain - config.Domains = []string{"test.com"} - v = ValidateEmail("one@two.com") - assert.False(v, "should not allow user from another domain") - // Should allow matching domain config.Domains = []string{"test.com"} - v = ValidateEmail("test@test.com") + v = ValidateEmail("one@two.com", "default") + assert.False(v, "should not allow user from another domain") + v = ValidateEmail("test@test.com", "default") assert.True(v, "should allow user from allowed domain") - // Should block non whitelisted email address - config.Domains = []string{} - config.Whitelist = []string{"test@test.com"} - v = ValidateEmail("one@two.com") - assert.False(v, "should not allow user not in whitelist") - // Should allow matching whitelisted email address config.Domains = []string{} config.Whitelist = []string{"test@test.com"} - v = ValidateEmail("test@test.com") + v = ValidateEmail("one@two.com", "default") + assert.False(v, "should not allow user not in whitelist") + v = ValidateEmail("test@test.com", "default") assert.True(v, "should allow user in whitelist") // Should allow only matching email address when @@ -98,24 +91,106 @@ func TestAuthValidateEmail(t *testing.T) { config.Domains = []string{"example.com"} config.Whitelist = []string{"test@test.com"} config.MatchWhitelistOrDomain = false - v = ValidateEmail("test@test.com") - assert.True(v, "should allow user in whitelist") - v = ValidateEmail("test@example.com") - assert.False(v, "should not allow user from valid domain") - v = ValidateEmail("one@two.com") + v = ValidateEmail("one@two.com", "default") assert.False(v, "should not allow user not in either") + v = ValidateEmail("test@example.com", "default") + assert.False(v, "should not allow user from allowed domain") + v = ValidateEmail("test@test.com", "default") + assert.True(v, "should allow user in whitelist") // Should allow either matching domain or email address when // MatchWhitelistOrDomain is enabled config.Domains = []string{"example.com"} config.Whitelist = []string{"test@test.com"} config.MatchWhitelistOrDomain = true - v = ValidateEmail("test@test.com") + v = ValidateEmail("one@two.com", "default") + assert.False(v, "should not allow user not in either") + v = ValidateEmail("test@example.com", "default") + assert.True(v, "should allow user from allowed domain") + v = ValidateEmail("test@test.com", "default") assert.True(v, "should allow user in whitelist") - v = ValidateEmail("test@example.com") - assert.True(v, "should allow user from valid domain") - v = ValidateEmail("one@two.com") + + // Rule testing + + // Should use global whitelist/domain when not specified on rule + config.Domains = []string{"example.com"} + config.Whitelist = []string{"test@test.com"} + config.Rules = map[string]*Rule{"test": NewRule()} + config.MatchWhitelistOrDomain = true + v = ValidateEmail("one@two.com", "test") assert.False(v, "should not allow user not in either") + v = ValidateEmail("test@example.com", "test") + assert.True(v, "should allow user from allowed global domain") + v = ValidateEmail("test@test.com", "test") + assert.True(v, "should allow user in global whitelist") + + // Should allow matching domain in rule + config.Domains = []string{"testglobal.com"} + config.Whitelist = []string{} + rule := NewRule() + config.Rules = map[string]*Rule{"test": rule} + rule.Domains = []string{"testrule.com"} + config.MatchWhitelistOrDomain = false + v = ValidateEmail("one@two.com", "test") + assert.False(v, "should not allow user from another domain") + v = ValidateEmail("one@testglobal.com", "test") + assert.False(v, "should not allow user from global domain") + v = ValidateEmail("test@testrule.com", "test") + assert.True(v, "should allow user from allowed domain") + + // Should allow matching whitelist in rule + config.Domains = []string{} + config.Whitelist = []string{"test@testglobal.com"} + rule = NewRule() + config.Rules = map[string]*Rule{"test": rule} + rule.Whitelist = []string{"test@testrule.com"} + config.MatchWhitelistOrDomain = false + v = ValidateEmail("one@two.com", "test") + assert.False(v, "should not allow user from another domain") + v = ValidateEmail("test@testglobal.com", "test") + assert.False(v, "should not allow user from global domain") + v = ValidateEmail("test@testrule.com", "test") + assert.True(v, "should allow user from allowed domain") + + // Should allow only matching email address when + // MatchWhitelistOrDomain is disabled + config.Domains = []string{"exampleglobal.com"} + config.Whitelist = []string{"test@testglobal.com"} + rule = NewRule() + config.Rules = map[string]*Rule{"test": rule} + rule.Domains = []string{"examplerule.com"} + rule.Whitelist = []string{"test@testrule.com"} + config.MatchWhitelistOrDomain = false + v = ValidateEmail("one@two.com", "test") + assert.False(v, "should not allow user not in either") + v = ValidateEmail("test@testglobal.com", "test") + assert.False(v, "should not allow user in global whitelist") + v = ValidateEmail("test@exampleglobal.com", "test") + assert.False(v, "should not allow user from global domain") + v = ValidateEmail("test@examplerule.com", "test") + assert.False(v, "should not allow user from allowed domain") + v = ValidateEmail("test@testrule.com", "test") + assert.True(v, "should allow user in whitelist") + + // Should allow either matching domain or email address when + // MatchWhitelistOrDomain is enabled + config.Domains = []string{"exampleglobal.com"} + config.Whitelist = []string{"test@testglobal.com"} + rule = NewRule() + config.Rules = map[string]*Rule{"test": rule} + rule.Domains = []string{"examplerule.com"} + rule.Whitelist = []string{"test@testrule.com"} + config.MatchWhitelistOrDomain = true + v = ValidateEmail("one@two.com", "test") + assert.False(v, "should not allow user not in either") + v = ValidateEmail("test@testglobal.com", "test") + assert.False(v, "should not allow user in global whitelist") + v = ValidateEmail("test@exampleglobal.com", "test") + assert.False(v, "should not allow user from global domain") + v = ValidateEmail("test@examplerule.com", "test") + assert.True(v, "should allow user from allowed domain") + v = ValidateEmail("test@testrule.com", "test") + assert.True(v, "should allow user in whitelist") } func TestRedirectUri(t *testing.T) { diff --git a/internal/config.go b/internal/config.go index e35a732c..40546121 100644 --- a/internal/config.go +++ b/internal/config.go @@ -210,6 +210,14 @@ func (c *Config) parseUnknownFlag(option string, arg flags.SplitArgument, args [ rule.Rule = val case "provider": rule.Provider = val + case "whitelist": + list := CommaSeparatedList{} + list.UnmarshalFlag(val) + rule.Whitelist = list + case "domains": + list := CommaSeparatedList{} + list.UnmarshalFlag(val) + rule.Domains = list default: return args, fmt.Errorf("invalid route param: %v", option) } @@ -327,9 +335,11 @@ func (c *Config) setupProvider(name string) error { // Rule holds defined rules type Rule struct { - Action string - Rule string - Provider string + Action string + Rule string + Provider string + Whitelist CommaSeparatedList + Domains CommaSeparatedList } // NewRule creates a new rule object diff --git a/internal/server.go b/internal/server.go index ad7e8301..4bce9635 100644 --- a/internal/server.go +++ b/internal/server.go @@ -101,7 +101,7 @@ func (s *Server) AuthHandler(providerName, rule string) http.HandlerFunc { } // Validate user - valid := ValidateEmail(email) + valid := ValidateEmail(email, rule) if !valid { logger.WithField("email", email).Warn("Invalid email") http.Error(w, "Not authorized", 401) From c19f622fbdac74a959477f8f2b8bdc712754e1c9 Mon Sep 17 00:00:00 2001 From: Thom Seddon Date: Thu, 1 Oct 2020 09:29:36 +0100 Subject: [PATCH 7/9] Create codeql-analysis.yml --- .github/workflows/codeql-analysis.yml | 71 +++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 .github/workflows/codeql-analysis.yml diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml new file mode 100644 index 00000000..f8786ef5 --- /dev/null +++ b/.github/workflows/codeql-analysis.yml @@ -0,0 +1,71 @@ +# For most projects, this workflow file will not need changing; you simply need +# to commit it to your repository. +# +# You may wish to alter this file to override the set of languages analyzed, +# or to provide custom queries or build logic. +name: "CodeQL" + +on: + push: + branches: [master] + pull_request: + # The branches below must be a subset of the branches above + branches: [master] + schedule: + - cron: '0 10 * * 2' + +jobs: + analyze: + name: Analyze + runs-on: ubuntu-latest + + strategy: + fail-fast: false + matrix: + # Override automatic language detection by changing the below list + # Supported options are ['csharp', 'cpp', 'go', 'java', 'javascript', 'python'] + language: ['go'] + # Learn more... + # https://docs.github.com/en/github/finding-security-vulnerabilities-and-errors-in-your-code/configuring-code-scanning#overriding-automatic-language-detection + + steps: + - name: Checkout repository + uses: actions/checkout@v2 + with: + # We must fetch at least the immediate parents so that if this is + # a pull request then we can checkout the head. + fetch-depth: 2 + + # If this run was triggered by a pull request event, then checkout + # the head of the pull request instead of the merge commit. + - run: git checkout HEAD^2 + if: ${{ github.event_name == 'pull_request' }} + + # Initializes the CodeQL tools for scanning. + - name: Initialize CodeQL + uses: github/codeql-action/init@v1 + with: + languages: ${{ matrix.language }} + # If you wish to specify custom queries, you can do so here or in a config file. + # By default, queries listed here will override any specified in a config file. + # Prefix the list here with "+" to use these queries and those in the config file. + # queries: ./path/to/local/query, your-org/your-repo/queries@main + + # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). + # If this step fails, then you should remove it and run the build manually (see below) + - name: Autobuild + uses: github/codeql-action/autobuild@v1 + + # ℹī¸ Command-line programs to run using the OS shell. + # 📚 https://git.io/JvXDl + + # ✏ī¸ If the Autobuild fails above, remove it and uncomment the following three lines + # and modify them (or add more) to build your code if your project + # uses a compiled language + + #- run: | + # make bootstrap + # make release + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v1 From 1fcc0b62d2dccd2eed3ef18fd6d8ced661d9cf4c Mon Sep 17 00:00:00 2001 From: Suleiman Dibirov Date: Sat, 8 Jun 2024 17:19:43 +0300 Subject: [PATCH 8/9] Updated implementation --- internal/pkce/verifier.go | 40 ++++++++++++++++++---------------- internal/provider/oidc.go | 31 ++++++++++++++++++++++---- internal/provider/providers.go | 2 +- internal/server.go | 9 +++++++- 4 files changed, 57 insertions(+), 25 deletions(-) diff --git a/internal/pkce/verifier.go b/internal/pkce/verifier.go index 97c7efb0..4bf8d98d 100644 --- a/internal/pkce/verifier.go +++ b/internal/pkce/verifier.go @@ -1,21 +1,25 @@ package pkce import ( + "crypto/rand" "crypto/sha256" "encoding/base64" - "math/rand" - "strings" - "time" + "fmt" + "io" ) type CodeVerifier struct { Value string } -func CreateCodeVerifier() *CodeVerifier { - return &CodeVerifier{ - Value: encode([]byte(randomString(32))), +func CreateCodeVerifier() (*CodeVerifier, error) { + secureRandomString, err := generateSecureRandomString(32) + if err != nil { + return nil, err } + return &CodeVerifier{ + Value: secureRandomString, + }, nil } func (v *CodeVerifier) String() string { @@ -25,26 +29,24 @@ func (v *CodeVerifier) String() string { func (v *CodeVerifier) CodeChallengeS256() string { h := sha256.New() h.Write([]byte(v.Value)) + hash := h.Sum(nil) - return encode(h.Sum(nil)) + return encode(hash) } -func randomString(n int) string { - var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") - rand.Seed(time.Now().UnixNano()) +func GenerateNonce() (string, error) { + return generateSecureRandomString(32) +} - b := make([]rune, n) - for i := range b { - b[i] = letters[rand.Intn(len(letters))] +func generateSecureRandomString(length int) (string, error) { + bytes := make([]byte, length) + if _, err := io.ReadFull(rand.Reader, bytes); err != nil { + return "", fmt.Errorf("failed to generate secure random string: %w", err) } - - return string(b) + return base64.RawURLEncoding.EncodeToString(bytes), nil } func encode(msg []byte) string { - encoded := base64.StdEncoding.EncodeToString(msg) - encoded = strings.Replace(encoded, "+", "-", -1) - encoded = strings.Replace(encoded, "/", "_", -1) - encoded = strings.Replace(encoded, "=", "", -1) + encoded := base64.RawURLEncoding.EncodeToString(msg) return encoded } diff --git a/internal/provider/oidc.go b/internal/provider/oidc.go index 480a8aab..147c9e34 100644 --- a/internal/provider/oidc.go +++ b/internal/provider/oidc.go @@ -3,10 +3,10 @@ package provider import ( "context" "errors" - "github.com/thomseddon/traefik-forward-auth/internal/pkce" "strings" "github.com/coreos/go-oidc" + "github.com/thomseddon/traefik-forward-auth/internal/pkce" "golang.org/x/oauth2" ) @@ -22,6 +22,7 @@ type OIDC struct { provider *oidc.Provider verifier *oidc.IDTokenVerifier pkceVerifier *pkce.CodeVerifier + nonce string } // Name returns the name of the provider @@ -64,14 +65,27 @@ func (o *OIDC) Setup() error { } // GetLoginURL provides the login url for the given redirect uri and state -func (o *OIDC) GetLoginURL(redirectURI, state string) string { +func (o *OIDC) GetLoginURL(redirectURI, state string) (string, error) { var opts []oauth2.AuthCodeOption + + // Generate and store nonce + var err error + o.nonce, err = pkce.GenerateNonce() + if err != nil { + return "", err + } + + opts = append(opts, oauth2.SetAuthURLParam("nonce", o.nonce)) + if o.PkceRequired { - o.pkceVerifier = pkce.CreateCodeVerifier() + o.pkceVerifier, err = pkce.CreateCodeVerifier() + if err != nil { + return "", err + } opts = append(opts, oauth2.SetAuthURLParam("code_challenge_method", "S256")) opts = append(opts, oauth2.SetAuthURLParam("code_challenge", o.pkceVerifier.CodeChallengeS256())) } - return o.OAuthGetLoginURL(redirectURI, state, opts...) + return o.OAuthGetLoginURL(redirectURI, state, opts...), nil } // ExchangeCode exchanges the given redirect uri and code for a token @@ -93,6 +107,15 @@ func (o *OIDC) ExchangeCode(redirectURI, code string) (string, error) { return "", errors.New("Missing id_token") } + // Verify nonce + idToken, err := o.verifier.Verify(o.ctx, rawIDToken) + if err != nil { + return "", err + } + if idToken.Nonce != o.nonce { + return "", errors.New("nonce verification failed") + } + return rawIDToken, nil } diff --git a/internal/provider/providers.go b/internal/provider/providers.go index 3e0d6bd1..d139688b 100644 --- a/internal/provider/providers.go +++ b/internal/provider/providers.go @@ -17,7 +17,7 @@ type Providers struct { // Provider is used to authenticate users type Provider interface { Name() string - GetLoginURL(redirectURI, state string) string + GetLoginURL(redirectURI, state string) (string, error) ExchangeCode(redirectURI, code string) (string, error) GetUser(token string) (User, error) Setup() error diff --git a/internal/server.go b/internal/server.go index 4bce9635..d2d38c09 100644 --- a/internal/server.go +++ b/internal/server.go @@ -231,7 +231,14 @@ func (s *Server) authRedirect(logger *logrus.Entry, w http.ResponseWriter, r *ht } // Forward them on - loginURL := p.GetLoginURL(redirectUri(r), MakeState(r, p, nonce)) + loginURL, err := p.GetLoginURL(redirectUri(r), MakeState(r, p, nonce)) + + if err != nil { + logger.WithField("error", err).Error("Get login url failed") + http.Error(w, "Service unavailable", 503) + return + } + http.Redirect(w, r, loginURL, http.StatusTemporaryRedirect) logger.WithFields(logrus.Fields{ From bc243e04ab4a5f1d5255684a4c31834ba97f202f Mon Sep 17 00:00:00 2001 From: Suleiman Dibirov Date: Sat, 8 Jun 2024 17:27:13 +0300 Subject: [PATCH 9/9] Fixes --- internal/provider/generic_oauth.go | 4 ++-- internal/provider/generic_oauth_test.go | 3 ++- internal/provider/google.go | 4 ++-- internal/provider/google_test.go | 3 ++- internal/provider/oidc_test.go | 11 ++++++++--- 5 files changed, 16 insertions(+), 9 deletions(-) diff --git a/internal/provider/generic_oauth.go b/internal/provider/generic_oauth.go index a6bba510..4a8c465f 100644 --- a/internal/provider/generic_oauth.go +++ b/internal/provider/generic_oauth.go @@ -52,8 +52,8 @@ func (o *GenericOAuth) Setup() error { } // GetLoginURL provides the login url for the given redirect uri and state -func (o *GenericOAuth) GetLoginURL(redirectURI, state string) string { - return o.OAuthGetLoginURL(redirectURI, state) +func (o *GenericOAuth) GetLoginURL(redirectURI, state string) (string, error) { + return o.OAuthGetLoginURL(redirectURI, state), nil } // ExchangeCode exchanges the given redirect uri and code for a token diff --git a/internal/provider/generic_oauth_test.go b/internal/provider/generic_oauth_test.go index 1c2f2899..cc0f3c9f 100644 --- a/internal/provider/generic_oauth_test.go +++ b/internal/provider/generic_oauth_test.go @@ -53,7 +53,8 @@ func TestGenericOAuthGetLoginURL(t *testing.T) { } // Check url - uri, err := url.Parse(p.GetLoginURL("http://example.com/_oauth", "state")) + loginURL, _ := p.GetLoginURL("http://example.com/_oauth", "state") + uri, err := url.Parse(loginURL) assert.Nil(err) assert.Equal("https", uri.Scheme) assert.Equal("provider.com", uri.Host) diff --git a/internal/provider/google.go b/internal/provider/google.go index 1c0d6d10..20826e58 100644 --- a/internal/provider/google.go +++ b/internal/provider/google.go @@ -53,7 +53,7 @@ func (g *Google) Setup() error { } // GetLoginURL provides the login url for the given redirect uri and state -func (g *Google) GetLoginURL(redirectURI, state string) string { +func (g *Google) GetLoginURL(redirectURI, state string) (string, error) { q := url.Values{} q.Set("client_id", g.ClientID) q.Set("response_type", "code") @@ -68,7 +68,7 @@ func (g *Google) GetLoginURL(redirectURI, state string) string { u = *g.LoginURL u.RawQuery = q.Encode() - return u.String() + return u.String(), nil } // ExchangeCode exchanges the given redirect uri and code for a token diff --git a/internal/provider/google_test.go b/internal/provider/google_test.go index 64243fc0..d4fafda5 100644 --- a/internal/provider/google_test.go +++ b/internal/provider/google_test.go @@ -68,7 +68,8 @@ func TestGoogleGetLoginURL(t *testing.T) { } // Check url - uri, err := url.Parse(p.GetLoginURL("http://example.com/_oauth", "state")) + loginUrl, _ := p.GetLoginURL("http://example.com/_oauth", "state") + uri, err := url.Parse(loginUrl) assert.Nil(err) assert.Equal("https", uri.Scheme) assert.Equal("google.com", uri.Host) diff --git a/internal/provider/oidc_test.go b/internal/provider/oidc_test.go index 5c9dceaa..80e75af1 100644 --- a/internal/provider/oidc_test.go +++ b/internal/provider/oidc_test.go @@ -54,7 +54,8 @@ func TestOIDCGetLoginURL(t *testing.T) { defer server.Close() // Check url - uri, err := url.Parse(provider.GetLoginURL("http://example.com/_oauth", "state")) + loginUrl, _ := provider.GetLoginURL("http://example.com/_oauth", "state") + uri, err := url.Parse(loginUrl) assert.Nil(err) assert.Equal(serverURL.Scheme, uri.Scheme) assert.Equal(serverURL.Host, uri.Host) @@ -68,6 +69,7 @@ func TestOIDCGetLoginURL(t *testing.T) { "response_type": []string{"code"}, "scope": []string{"openid profile email"}, "state": []string{"state"}, + "nonce": []string{provider.nonce}, } assert.Equal(expectedQs, qs) @@ -80,7 +82,8 @@ func TestOIDCGetLoginURL(t *testing.T) { provider.PkceRequired = true // Check url - uri, err = url.Parse(provider.GetLoginURL("http://example.com/_oauth", "state")) + loginUrl, _ = provider.GetLoginURL("http://example.com/_oauth", "state") + uri, err = url.Parse(loginUrl) assert.Nil(err) assert.Equal(serverURL.Scheme, uri.Scheme) assert.Equal(serverURL.Host, uri.Host) @@ -96,6 +99,7 @@ func TestOIDCGetLoginURL(t *testing.T) { "response_type": []string{"code"}, "scope": []string{"openid profile email"}, "state": []string{"state"}, + "nonce": []string{provider.nonce}, } assert.Equal(expectedQs, qs) @@ -108,7 +112,8 @@ func TestOIDCGetLoginURL(t *testing.T) { provider.Resource = "resourcetest" // Check url - uri, err = url.Parse(provider.GetLoginURL("http://example.com/_oauth", "state")) + loginUrl, _ = provider.GetLoginURL("http://example.com/_oauth", "state") + uri, err = url.Parse(loginUrl) assert.Nil(err) assert.Equal(serverURL.Scheme, uri.Scheme) assert.Equal(serverURL.Host, uri.Host)