diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml new file mode 100644 index 00000000..f8786ef5 --- /dev/null +++ b/.github/workflows/codeql-analysis.yml @@ -0,0 +1,71 @@ +# For most projects, this workflow file will not need changing; you simply need +# to commit it to your repository. +# +# You may wish to alter this file to override the set of languages analyzed, +# or to provide custom queries or build logic. +name: "CodeQL" + +on: + push: + branches: [master] + pull_request: + # The branches below must be a subset of the branches above + branches: [master] + schedule: + - cron: '0 10 * * 2' + +jobs: + analyze: + name: Analyze + runs-on: ubuntu-latest + + strategy: + fail-fast: false + matrix: + # Override automatic language detection by changing the below list + # Supported options are ['csharp', 'cpp', 'go', 'java', 'javascript', 'python'] + language: ['go'] + # Learn more... + # https://docs.github.com/en/github/finding-security-vulnerabilities-and-errors-in-your-code/configuring-code-scanning#overriding-automatic-language-detection + + steps: + - name: Checkout repository + uses: actions/checkout@v2 + with: + # We must fetch at least the immediate parents so that if this is + # a pull request then we can checkout the head. + fetch-depth: 2 + + # If this run was triggered by a pull request event, then checkout + # the head of the pull request instead of the merge commit. + - run: git checkout HEAD^2 + if: ${{ github.event_name == 'pull_request' }} + + # Initializes the CodeQL tools for scanning. + - name: Initialize CodeQL + uses: github/codeql-action/init@v1 + with: + languages: ${{ matrix.language }} + # If you wish to specify custom queries, you can do so here or in a config file. + # By default, queries listed here will override any specified in a config file. + # Prefix the list here with "+" to use these queries and those in the config file. + # queries: ./path/to/local/query, your-org/your-repo/queries@main + + # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). + # If this step fails, then you should remove it and run the build manually (see below) + - name: Autobuild + uses: github/codeql-action/autobuild@v1 + + # ℹī¸ Command-line programs to run using the OS shell. + # 📚 https://git.io/JvXDl + + # ✏ī¸ If the Autobuild fails above, remove it and uncomment the following three lines + # and modify them (or add more) to build your code if your project + # uses a compiled language + + #- run: | + # make bootstrap + # make release + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v1 diff --git a/README.md b/README.md index 54cb5798..5825103b 100644 --- a/README.md +++ b/README.md @@ -174,6 +174,7 @@ OIDC Provider: --providers.oidc.client-id= Client ID [$PROVIDERS_OIDC_CLIENT_ID] --providers.oidc.client-secret= Client Secret [$PROVIDERS_OIDC_CLIENT_SECRET] --providers.oidc.resource= Optional resource indicator [$PROVIDERS_OIDC_RESOURCE] + --providers.oidc.pkce_required= Optional pkce required indicator [$PROVIDERS_OIDC_PKCE_REQUIRED] Generic OAuth2 Provider: --providers.generic-oauth.auth-url= Auth/Login URL [$PROVIDERS_GENERIC_OAUTH_AUTH_URL] @@ -321,6 +322,7 @@ All options can be supplied in any of the following ways, in the following prece - `action` - same usage as [`default-action`](#default-action), supported values: - `auth` (default) - `allow` + - `domains` - optional, same usage as [`domain`](#domain) - `provider` - same usage as [`default-provider`](#default-provider), supported values: - `google` - `oidc` @@ -333,6 +335,7 @@ All options can be supplied in any of the following ways, in the following prece - ``Path(`path`, `/articles/{category}/{id:[0-9]+}`, ...)`` - ``PathPrefix(`/products/`, `/articles/{category}/{id:[0-9]+}`)`` - ``Query(`foo=bar`, `bar=baz`)`` + - `whitelist` - optional, same usage as whitelist`](#whitelist) For example: ``` @@ -348,6 +351,11 @@ All options can be supplied in any of the following ways, in the following prece rule.oidc.action = auth rule.oidc.provider = oidc rule.oidc.rule = PathPrefix(`/github`) + + # Allow jane@example.com to `/janes-eyes-only` + rule.two.action = allow + rule.two.rule = Path(`/janes-eyes-only`) + rule.two.whitelist = jane@example.com ``` Note: It is possible to break your redirect flow with rules, please be careful not to create an `allow` rule that matches your redirect_uri unless you know what you're doing. This limitation is being tracked in in #101 and the behaviour will change in future releases. @@ -361,7 +369,7 @@ You can restrict who can login with the following parameters: * `domain` - Use this to limit logins to a specific domain, e.g. test.com only * `whitelist` - Use this to only allow specific users to login e.g. thom@test.com only -Note, if you pass both `whitelist` and `domain`, then the default behaviour is for only `whitelist` to be used and `domain` will be effectively ignored. You can allow users matching *either* `whitelist` or `domain` by passing the `match-whitelist-or-domain` parameter (this will be the default behaviour in v3). +Note, if you pass both `whitelist` and `domain`, then the default behaviour is for only `whitelist` to be used and `domain` will be effectively ignored. You can allow users matching *either* `whitelist` or `domain` by passing the `match-whitelist-or-domain` parameter (this will be the default behaviour in v3). If you set `domains` or `whitelist` on a rule, the global configuration is ignored. ### Forwarded Headers diff --git a/internal/auth.go b/internal/auth.go index cefd2315..0b0a9676 100644 --- a/internal/auth.go +++ b/internal/auth.go @@ -59,18 +59,28 @@ func ValidateCookie(r *http.Request, c *http.Cookie) (string, error) { // ValidateEmail checks if the given email address matches either a whitelisted // email address, as defined by the "whitelist" config parameter. Or is part of // a permitted domain, as defined by the "domains" config parameter -func ValidateEmail(email string) bool { +func ValidateEmail(email, ruleName string) bool { + // Use global config by default + whitelist := config.Whitelist + domains := config.Domains + + if rule, ok := config.Rules[ruleName]; ok { + // Override with rule config if found + if len(rule.Whitelist) > 0 || len(rule.Domains) > 0 { + whitelist = rule.Whitelist + domains = rule.Domains + } + } + // Do we have any validation to perform? - if len(config.Whitelist) == 0 && len(config.Domains) == 0 { + if len(whitelist) == 0 && len(domains) == 0 { return true } // Email whitelist validation - if len(config.Whitelist) > 0 { - for _, whitelist := range config.Whitelist { - if email == whitelist { - return true - } + if len(whitelist) > 0 { + if ValidateWhitelist(email, whitelist) { + return true } // If we're not matching *either*, stop here @@ -80,18 +90,34 @@ func ValidateEmail(email string) bool { } // Domain validation - if len(config.Domains) > 0 { - parts := strings.Split(email, "@") - if len(parts) < 2 { - return false - } - for _, domain := range config.Domains { - if domain == parts[1] { - return true - } + if len(domains) > 0 && ValidateDomains(email, domains) { + return true + } + + return false +} + +// ValidateWhitelist checks if the email is in whitelist +func ValidateWhitelist(email string, whitelist CommaSeparatedList) bool { + for _, whitelist := range whitelist { + if email == whitelist { + return true } } + return false +} +// ValidateDomains checks if the email matches a whitelisted domain +func ValidateDomains(email string, domains CommaSeparatedList) bool { + parts := strings.Split(email, "@") + if len(parts) < 2 { + return false + } + for _, domain := range domains { + if domain == parts[1] { + return true + } + } return false } @@ -170,23 +196,31 @@ func ClearCookie(r *http.Request) *http.Cookie { } } +func buildCSRFCookieName(nonce string) string { + return config.CSRFCookieName + "_" + nonce[:6] +} + // MakeCSRFCookie makes a csrf cookie (used during login only) +// +// Note, CSRF cookies live shorter than auth cookies, a fixed 1h. +// That's because some CSRF cookies may belong to auth flows that don't complete +// and thus may not get cleared by ClearCookie. func MakeCSRFCookie(r *http.Request, nonce string) *http.Cookie { return &http.Cookie{ - Name: config.CSRFCookieName, + Name: buildCSRFCookieName(nonce), Value: nonce, Path: "/", Domain: csrfCookieDomain(r), HttpOnly: true, Secure: !config.InsecureCookie, - Expires: cookieExpiry(), + Expires: time.Now().Local().Add(time.Hour * 1), } } // ClearCSRFCookie makes an expired csrf cookie to clear csrf cookie -func ClearCSRFCookie(r *http.Request) *http.Cookie { +func ClearCSRFCookie(r *http.Request, c *http.Cookie) *http.Cookie { return &http.Cookie{ - Name: config.CSRFCookieName, + Name: c.Name, Value: "", Path: "/", Domain: csrfCookieDomain(r), @@ -196,18 +230,18 @@ func ClearCSRFCookie(r *http.Request) *http.Cookie { } } -// ValidateCSRFCookie validates the csrf cookie against state -func ValidateCSRFCookie(r *http.Request, c *http.Cookie) (valid bool, provider string, redirect string, err error) { - state := r.URL.Query().Get("state") +// FindCSRFCookie extracts the CSRF cookie from the request based on state. +func FindCSRFCookie(r *http.Request, state string) (c *http.Cookie, err error) { + // Check for CSRF cookie + return r.Cookie(buildCSRFCookieName(state)) +} +// ValidateCSRFCookie validates the csrf cookie against state +func ValidateCSRFCookie(c *http.Cookie, state string) (valid bool, provider string, redirect string, err error) { if len(c.Value) != 32 { return false, "", "", errors.New("Invalid CSRF cookie value") } - if len(state) < 34 { - return false, "", "", errors.New("Invalid CSRF state value") - } - // Check nonce match if c.Value != state[:32] { return false, "", "", errors.New("CSRF cookie does not match state") @@ -229,6 +263,14 @@ func MakeState(r *http.Request, p provider.Provider, nonce string) string { return fmt.Sprintf("%s:%s:%s", nonce, p.Name(), returnUrl(r)) } +// ValidateState checks whether the state is of right length. +func ValidateState(state string) error { + if len(state) < 34 { + return errors.New("Invalid CSRF state value") + } + return nil +} + // Nonce generates a random nonce func Nonce() (error, string) { nonce := make([]byte, 16) diff --git a/internal/auth_test.go b/internal/auth_test.go index 14ee1ce6..5b0bedaf 100644 --- a/internal/auth_test.go +++ b/internal/auth_test.go @@ -1,7 +1,6 @@ package tfa import ( - "fmt" "net/http" "net/url" "strings" @@ -66,32 +65,25 @@ func TestAuthValidateEmail(t *testing.T) { assert := assert.New(t) config, _ = NewConfig([]string{}) - // Should allow any - v := ValidateEmail("test@test.com") + // Should allow any with no whitelist/domain is specified + v := ValidateEmail("test@test.com", "default") assert.True(v, "should allow any domain if email domain is not defined") - v = ValidateEmail("one@two.com") + v = ValidateEmail("one@two.com", "default") assert.True(v, "should allow any domain if email domain is not defined") - // Should block non matching domain - config.Domains = []string{"test.com"} - v = ValidateEmail("one@two.com") - assert.False(v, "should not allow user from another domain") - // Should allow matching domain config.Domains = []string{"test.com"} - v = ValidateEmail("test@test.com") + v = ValidateEmail("one@two.com", "default") + assert.False(v, "should not allow user from another domain") + v = ValidateEmail("test@test.com", "default") assert.True(v, "should allow user from allowed domain") - // Should block non whitelisted email address - config.Domains = []string{} - config.Whitelist = []string{"test@test.com"} - v = ValidateEmail("one@two.com") - assert.False(v, "should not allow user not in whitelist") - // Should allow matching whitelisted email address config.Domains = []string{} config.Whitelist = []string{"test@test.com"} - v = ValidateEmail("test@test.com") + v = ValidateEmail("one@two.com", "default") + assert.False(v, "should not allow user not in whitelist") + v = ValidateEmail("test@test.com", "default") assert.True(v, "should allow user in whitelist") // Should allow only matching email address when @@ -99,24 +91,106 @@ func TestAuthValidateEmail(t *testing.T) { config.Domains = []string{"example.com"} config.Whitelist = []string{"test@test.com"} config.MatchWhitelistOrDomain = false - v = ValidateEmail("test@test.com") - assert.True(v, "should allow user in whitelist") - v = ValidateEmail("test@example.com") - assert.False(v, "should not allow user from valid domain") - v = ValidateEmail("one@two.com") + v = ValidateEmail("one@two.com", "default") assert.False(v, "should not allow user not in either") + v = ValidateEmail("test@example.com", "default") + assert.False(v, "should not allow user from allowed domain") + v = ValidateEmail("test@test.com", "default") + assert.True(v, "should allow user in whitelist") // Should allow either matching domain or email address when // MatchWhitelistOrDomain is enabled config.Domains = []string{"example.com"} config.Whitelist = []string{"test@test.com"} config.MatchWhitelistOrDomain = true - v = ValidateEmail("test@test.com") + v = ValidateEmail("one@two.com", "default") + assert.False(v, "should not allow user not in either") + v = ValidateEmail("test@example.com", "default") + assert.True(v, "should allow user from allowed domain") + v = ValidateEmail("test@test.com", "default") assert.True(v, "should allow user in whitelist") - v = ValidateEmail("test@example.com") - assert.True(v, "should allow user from valid domain") - v = ValidateEmail("one@two.com") + + // Rule testing + + // Should use global whitelist/domain when not specified on rule + config.Domains = []string{"example.com"} + config.Whitelist = []string{"test@test.com"} + config.Rules = map[string]*Rule{"test": NewRule()} + config.MatchWhitelistOrDomain = true + v = ValidateEmail("one@two.com", "test") + assert.False(v, "should not allow user not in either") + v = ValidateEmail("test@example.com", "test") + assert.True(v, "should allow user from allowed global domain") + v = ValidateEmail("test@test.com", "test") + assert.True(v, "should allow user in global whitelist") + + // Should allow matching domain in rule + config.Domains = []string{"testglobal.com"} + config.Whitelist = []string{} + rule := NewRule() + config.Rules = map[string]*Rule{"test": rule} + rule.Domains = []string{"testrule.com"} + config.MatchWhitelistOrDomain = false + v = ValidateEmail("one@two.com", "test") + assert.False(v, "should not allow user from another domain") + v = ValidateEmail("one@testglobal.com", "test") + assert.False(v, "should not allow user from global domain") + v = ValidateEmail("test@testrule.com", "test") + assert.True(v, "should allow user from allowed domain") + + // Should allow matching whitelist in rule + config.Domains = []string{} + config.Whitelist = []string{"test@testglobal.com"} + rule = NewRule() + config.Rules = map[string]*Rule{"test": rule} + rule.Whitelist = []string{"test@testrule.com"} + config.MatchWhitelistOrDomain = false + v = ValidateEmail("one@two.com", "test") + assert.False(v, "should not allow user from another domain") + v = ValidateEmail("test@testglobal.com", "test") + assert.False(v, "should not allow user from global domain") + v = ValidateEmail("test@testrule.com", "test") + assert.True(v, "should allow user from allowed domain") + + // Should allow only matching email address when + // MatchWhitelistOrDomain is disabled + config.Domains = []string{"exampleglobal.com"} + config.Whitelist = []string{"test@testglobal.com"} + rule = NewRule() + config.Rules = map[string]*Rule{"test": rule} + rule.Domains = []string{"examplerule.com"} + rule.Whitelist = []string{"test@testrule.com"} + config.MatchWhitelistOrDomain = false + v = ValidateEmail("one@two.com", "test") assert.False(v, "should not allow user not in either") + v = ValidateEmail("test@testglobal.com", "test") + assert.False(v, "should not allow user in global whitelist") + v = ValidateEmail("test@exampleglobal.com", "test") + assert.False(v, "should not allow user from global domain") + v = ValidateEmail("test@examplerule.com", "test") + assert.False(v, "should not allow user from allowed domain") + v = ValidateEmail("test@testrule.com", "test") + assert.True(v, "should allow user in whitelist") + + // Should allow either matching domain or email address when + // MatchWhitelistOrDomain is enabled + config.Domains = []string{"exampleglobal.com"} + config.Whitelist = []string{"test@testglobal.com"} + rule = NewRule() + config.Rules = map[string]*Rule{"test": rule} + rule.Domains = []string{"examplerule.com"} + rule.Whitelist = []string{"test@testrule.com"} + config.MatchWhitelistOrDomain = true + v = ValidateEmail("one@two.com", "test") + assert.False(v, "should not allow user not in either") + v = ValidateEmail("test@testglobal.com", "test") + assert.False(v, "should not allow user in global whitelist") + v = ValidateEmail("test@exampleglobal.com", "test") + assert.False(v, "should not allow user from global domain") + v = ValidateEmail("test@examplerule.com", "test") + assert.True(v, "should allow user from allowed domain") + v = ValidateEmail("test@testrule.com", "test") + assert.True(v, "should allow user in whitelist") } func TestRedirectUri(t *testing.T) { @@ -217,29 +291,30 @@ func TestAuthMakeCSRFCookie(t *testing.T) { // No cookie domain or auth url c := MakeCSRFCookie(r, "12345678901234567890123456789012") + assert.Equal("_forward_auth_csrf_123456", c.Name) assert.Equal("app.example.com", c.Domain) // With cookie domain but no auth url - config = &Config{ - CookieDomains: []CookieDomain{*NewCookieDomain("example.com")}, - } - c = MakeCSRFCookie(r, "12345678901234567890123456789012") + config.CookieDomains = []CookieDomain{*NewCookieDomain("example.com")} + c = MakeCSRFCookie(r, "12222278901234567890123456789012") + assert.Equal("_forward_auth_csrf_122222", c.Name) assert.Equal("app.example.com", c.Domain) // With cookie domain and auth url - config = &Config{ - AuthHost: "auth.example.com", - CookieDomains: []CookieDomain{*NewCookieDomain("example.com")}, - } - c = MakeCSRFCookie(r, "12345678901234567890123456789012") + config.AuthHost = "auth.example.com" + config.CookieDomains = []CookieDomain{*NewCookieDomain("example.com")} + c = MakeCSRFCookie(r, "12333378901234567890123456789012") + assert.Equal("_forward_auth_csrf_123333", c.Name) assert.Equal("example.com", c.Domain) } func TestAuthClearCSRFCookie(t *testing.T) { + assert := assert.New(t) config, _ = NewConfig([]string{}) r, _ := http.NewRequest("GET", "http://example.com", nil) - c := ClearCSRFCookie(r) + c := ClearCSRFCookie(r, &http.Cookie{Name: "someCsrfCookie"}) + assert.Equal("someCsrfCookie", c.Name) if c.Value != "" { t.Error("ClearCSRFCookie should create cookie with empty value") } @@ -249,56 +324,57 @@ func TestAuthValidateCSRFCookie(t *testing.T) { assert := assert.New(t) config, _ = NewConfig([]string{}) c := &http.Cookie{} - - newCsrfRequest := func(state string) *http.Request { - u := fmt.Sprintf("http://example.com?state=%s", state) - r, _ := http.NewRequest("GET", u, nil) - return r - } + state := "" // Should require 32 char string - r := newCsrfRequest("") + state = "" c.Value = "" - valid, _, _, err := ValidateCSRFCookie(r, c) + valid, _, _, err := ValidateCSRFCookie(c, state) assert.False(valid) if assert.Error(err) { assert.Equal("Invalid CSRF cookie value", err.Error()) } c.Value = "123456789012345678901234567890123" - valid, _, _, err = ValidateCSRFCookie(r, c) + valid, _, _, err = ValidateCSRFCookie(c, state) assert.False(valid) if assert.Error(err) { assert.Equal("Invalid CSRF cookie value", err.Error()) } - // Should require valid state - r = newCsrfRequest("12345678901234567890123456789012:") - c.Value = "12345678901234567890123456789012" - valid, _, _, err = ValidateCSRFCookie(r, c) - assert.False(valid) - if assert.Error(err) { - assert.Equal("Invalid CSRF state value", err.Error()) - } - // Should require provider - r = newCsrfRequest("12345678901234567890123456789012:99") + state = "12345678901234567890123456789012:99" c.Value = "12345678901234567890123456789012" - valid, _, _, err = ValidateCSRFCookie(r, c) + valid, _, _, err = ValidateCSRFCookie(c, state) assert.False(valid) if assert.Error(err) { assert.Equal("Invalid CSRF state format", err.Error()) } // Should allow valid state - r = newCsrfRequest("12345678901234567890123456789012:p99:url123") + state = "12345678901234567890123456789012:p99:url123" c.Value = "12345678901234567890123456789012" - valid, provider, redirect, err := ValidateCSRFCookie(r, c) + valid, provider, redirect, err := ValidateCSRFCookie(c, state) assert.True(valid, "valid request should return valid") assert.Nil(err, "valid request should not return an error") assert.Equal("p99", provider, "valid request should return correct provider") assert.Equal("url123", redirect, "valid request should return correct redirect") } +func TestValidateState(t *testing.T) { + assert := assert.New(t) + + // Should require valid state + state := "12345678901234567890123456789012:" + err := ValidateState(state) + if assert.Error(err) { + assert.Equal("Invalid CSRF state value", err.Error()) + } + // Should pass this state + state = "12345678901234567890123456789012:p99:url123" + err = ValidateState(state) + assert.Nil(err, "valid request should not return an error") +} + func TestMakeState(t *testing.T) { assert := assert.New(t) diff --git a/internal/config.go b/internal/config.go index e35a732c..40546121 100644 --- a/internal/config.go +++ b/internal/config.go @@ -210,6 +210,14 @@ func (c *Config) parseUnknownFlag(option string, arg flags.SplitArgument, args [ rule.Rule = val case "provider": rule.Provider = val + case "whitelist": + list := CommaSeparatedList{} + list.UnmarshalFlag(val) + rule.Whitelist = list + case "domains": + list := CommaSeparatedList{} + list.UnmarshalFlag(val) + rule.Domains = list default: return args, fmt.Errorf("invalid route param: %v", option) } @@ -327,9 +335,11 @@ func (c *Config) setupProvider(name string) error { // Rule holds defined rules type Rule struct { - Action string - Rule string - Provider string + Action string + Rule string + Provider string + Whitelist CommaSeparatedList + Domains CommaSeparatedList } // NewRule creates a new rule object diff --git a/internal/pkce/verifier.go b/internal/pkce/verifier.go new file mode 100644 index 00000000..4bf8d98d --- /dev/null +++ b/internal/pkce/verifier.go @@ -0,0 +1,52 @@ +package pkce + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "fmt" + "io" +) + +type CodeVerifier struct { + Value string +} + +func CreateCodeVerifier() (*CodeVerifier, error) { + secureRandomString, err := generateSecureRandomString(32) + if err != nil { + return nil, err + } + return &CodeVerifier{ + Value: secureRandomString, + }, nil +} + +func (v *CodeVerifier) String() string { + return v.Value +} + +func (v *CodeVerifier) CodeChallengeS256() string { + h := sha256.New() + h.Write([]byte(v.Value)) + hash := h.Sum(nil) + + return encode(hash) +} + +func GenerateNonce() (string, error) { + return generateSecureRandomString(32) +} + +func generateSecureRandomString(length int) (string, error) { + bytes := make([]byte, length) + if _, err := io.ReadFull(rand.Reader, bytes); err != nil { + return "", fmt.Errorf("failed to generate secure random string: %w", err) + } + return base64.RawURLEncoding.EncodeToString(bytes), nil +} + +func encode(msg []byte) string { + encoded := base64.RawURLEncoding.EncodeToString(msg) + return encoded +} diff --git a/internal/provider/generic_oauth.go b/internal/provider/generic_oauth.go index a6bba510..4a8c465f 100644 --- a/internal/provider/generic_oauth.go +++ b/internal/provider/generic_oauth.go @@ -52,8 +52,8 @@ func (o *GenericOAuth) Setup() error { } // GetLoginURL provides the login url for the given redirect uri and state -func (o *GenericOAuth) GetLoginURL(redirectURI, state string) string { - return o.OAuthGetLoginURL(redirectURI, state) +func (o *GenericOAuth) GetLoginURL(redirectURI, state string) (string, error) { + return o.OAuthGetLoginURL(redirectURI, state), nil } // ExchangeCode exchanges the given redirect uri and code for a token diff --git a/internal/provider/generic_oauth_test.go b/internal/provider/generic_oauth_test.go index 1c2f2899..cc0f3c9f 100644 --- a/internal/provider/generic_oauth_test.go +++ b/internal/provider/generic_oauth_test.go @@ -53,7 +53,8 @@ func TestGenericOAuthGetLoginURL(t *testing.T) { } // Check url - uri, err := url.Parse(p.GetLoginURL("http://example.com/_oauth", "state")) + loginURL, _ := p.GetLoginURL("http://example.com/_oauth", "state") + uri, err := url.Parse(loginURL) assert.Nil(err) assert.Equal("https", uri.Scheme) assert.Equal("provider.com", uri.Host) diff --git a/internal/provider/google.go b/internal/provider/google.go index 1c0d6d10..20826e58 100644 --- a/internal/provider/google.go +++ b/internal/provider/google.go @@ -53,7 +53,7 @@ func (g *Google) Setup() error { } // GetLoginURL provides the login url for the given redirect uri and state -func (g *Google) GetLoginURL(redirectURI, state string) string { +func (g *Google) GetLoginURL(redirectURI, state string) (string, error) { q := url.Values{} q.Set("client_id", g.ClientID) q.Set("response_type", "code") @@ -68,7 +68,7 @@ func (g *Google) GetLoginURL(redirectURI, state string) string { u = *g.LoginURL u.RawQuery = q.Encode() - return u.String() + return u.String(), nil } // ExchangeCode exchanges the given redirect uri and code for a token diff --git a/internal/provider/google_test.go b/internal/provider/google_test.go index 64243fc0..d4fafda5 100644 --- a/internal/provider/google_test.go +++ b/internal/provider/google_test.go @@ -68,7 +68,8 @@ func TestGoogleGetLoginURL(t *testing.T) { } // Check url - uri, err := url.Parse(p.GetLoginURL("http://example.com/_oauth", "state")) + loginUrl, _ := p.GetLoginURL("http://example.com/_oauth", "state") + uri, err := url.Parse(loginUrl) assert.Nil(err) assert.Equal("https", uri.Scheme) assert.Equal("google.com", uri.Host) diff --git a/internal/provider/oidc.go b/internal/provider/oidc.go index 5e17a580..147c9e34 100644 --- a/internal/provider/oidc.go +++ b/internal/provider/oidc.go @@ -3,8 +3,10 @@ package provider import ( "context" "errors" + "strings" "github.com/coreos/go-oidc" + "github.com/thomseddon/traefik-forward-auth/internal/pkce" "golang.org/x/oauth2" ) @@ -13,11 +15,14 @@ type OIDC struct { IssuerURL string `long:"issuer-url" env:"ISSUER_URL" description:"Issuer URL"` ClientID string `long:"client-id" env:"CLIENT_ID" description:"Client ID"` ClientSecret string `long:"client-secret" env:"CLIENT_SECRET" description:"Client Secret" json:"-"` + PkceRequired bool `long:"pkce-required" env:"PKCE_REQUIRED" description:"Optional pkce required indicator"` OAuthProvider - provider *oidc.Provider - verifier *oidc.IDTokenVerifier + provider *oidc.Provider + verifier *oidc.IDTokenVerifier + pkceVerifier *pkce.CodeVerifier + nonce string } // Name returns the name of the provider @@ -27,9 +32,9 @@ func (o *OIDC) Name() string { // Setup performs validation and setup func (o *OIDC) Setup() error { - // Check parms - if o.IssuerURL == "" || o.ClientID == "" || o.ClientSecret == "" { - return errors.New("providers.oidc.issuer-url, providers.oidc.client-id, providers.oidc.client-secret must be set") + // Check params + if err := o.checkParams(); err != nil { + return err } var err error @@ -60,13 +65,38 @@ func (o *OIDC) Setup() error { } // GetLoginURL provides the login url for the given redirect uri and state -func (o *OIDC) GetLoginURL(redirectURI, state string) string { - return o.OAuthGetLoginURL(redirectURI, state) +func (o *OIDC) GetLoginURL(redirectURI, state string) (string, error) { + var opts []oauth2.AuthCodeOption + + // Generate and store nonce + var err error + o.nonce, err = pkce.GenerateNonce() + if err != nil { + return "", err + } + + opts = append(opts, oauth2.SetAuthURLParam("nonce", o.nonce)) + + if o.PkceRequired { + o.pkceVerifier, err = pkce.CreateCodeVerifier() + if err != nil { + return "", err + } + opts = append(opts, oauth2.SetAuthURLParam("code_challenge_method", "S256")) + opts = append(opts, oauth2.SetAuthURLParam("code_challenge", o.pkceVerifier.CodeChallengeS256())) + } + return o.OAuthGetLoginURL(redirectURI, state, opts...), nil } // ExchangeCode exchanges the given redirect uri and code for a token func (o *OIDC) ExchangeCode(redirectURI, code string) (string, error) { - token, err := o.OAuthExchangeCode(redirectURI, code) + var opts []oauth2.AuthCodeOption + + if o.PkceRequired { + opts = append(opts, oauth2.SetAuthURLParam("code_verifier", o.pkceVerifier.String())) + } + + token, err := o.OAuthExchangeCode(redirectURI, code, opts...) if err != nil { return "", err } @@ -77,6 +107,15 @@ func (o *OIDC) ExchangeCode(redirectURI, code string) (string, error) { return "", errors.New("Missing id_token") } + // Verify nonce + idToken, err := o.verifier.Verify(o.ctx, rawIDToken) + if err != nil { + return "", err + } + if idToken.Nonce != o.nonce { + return "", errors.New("nonce verification failed") + } + return rawIDToken, nil } @@ -97,3 +136,25 @@ func (o *OIDC) GetUser(token string) (User, error) { return user, nil } + +func (o *OIDC) checkParams() error { + if o.IssuerURL == "" || o.ClientID == "" || (o.ClientSecret == "" && !o.PkceRequired) { + var emptyFields []string + + if o.IssuerURL == "" { + emptyFields = append(emptyFields, "providers.oidc.issuer-url") + } + + if o.ClientID == "" { + emptyFields = append(emptyFields, "providers.oidc.client-id") + } + + if o.ClientSecret == "" && !o.PkceRequired { + emptyFields = append(emptyFields, "providers.oidc.client-secret") + } + + return errors.New(strings.Join(emptyFields, ", ") + " must be set") + } + + return nil +} diff --git a/internal/provider/oidc_test.go b/internal/provider/oidc_test.go index d514d37c..80e75af1 100644 --- a/internal/provider/oidc_test.go +++ b/internal/provider/oidc_test.go @@ -31,6 +31,20 @@ func TestOIDCSetup(t *testing.T) { if assert.Error(err) { assert.Equal("providers.oidc.issuer-url, providers.oidc.client-id, providers.oidc.client-secret must be set", err.Error()) } + + p.IssuerURL = "url" + + err = p.Setup() + if assert.Error(err) { + assert.Equal("providers.oidc.client-id, providers.oidc.client-secret must be set", err.Error()) + } + + p.ClientID = "id" + + err = p.Setup() + if assert.Error(err) { + assert.Equal("providers.oidc.client-secret must be set", err.Error()) + } } func TestOIDCGetLoginURL(t *testing.T) { @@ -40,7 +54,8 @@ func TestOIDCGetLoginURL(t *testing.T) { defer server.Close() // Check url - uri, err := url.Parse(provider.GetLoginURL("http://example.com/_oauth", "state")) + loginUrl, _ := provider.GetLoginURL("http://example.com/_oauth", "state") + uri, err := url.Parse(loginUrl) assert.Nil(err) assert.Equal(serverURL.Scheme, uri.Scheme) assert.Equal(serverURL.Host, uri.Host) @@ -54,6 +69,37 @@ func TestOIDCGetLoginURL(t *testing.T) { "response_type": []string{"code"}, "scope": []string{"openid profile email"}, "state": []string{"state"}, + "nonce": []string{provider.nonce}, + } + assert.Equal(expectedQs, qs) + + // Calling the method should not modify the underlying config + assert.Equal("", provider.Config.RedirectURL) + + // + // Test with PkceRequired config option + // + provider.PkceRequired = true + + // Check url + loginUrl, _ = provider.GetLoginURL("http://example.com/_oauth", "state") + uri, err = url.Parse(loginUrl) + assert.Nil(err) + assert.Equal(serverURL.Scheme, uri.Scheme) + assert.Equal(serverURL.Host, uri.Host) + assert.Equal("/auth", uri.Path) + + // Check query string + qs = uri.Query() + expectedQs = url.Values{ + "client_id": []string{"idtest"}, + "code_challenge": []string{provider.pkceVerifier.CodeChallengeS256()}, + "code_challenge_method": []string{"S256"}, + "redirect_uri": []string{"http://example.com/_oauth"}, + "response_type": []string{"code"}, + "scope": []string{"openid profile email"}, + "state": []string{"state"}, + "nonce": []string{provider.nonce}, } assert.Equal(expectedQs, qs) @@ -66,7 +112,8 @@ func TestOIDCGetLoginURL(t *testing.T) { provider.Resource = "resourcetest" // Check url - uri, err = url.Parse(provider.GetLoginURL("http://example.com/_oauth", "state")) + loginUrl, _ = provider.GetLoginURL("http://example.com/_oauth", "state") + uri, err = url.Parse(loginUrl) assert.Nil(err) assert.Equal(serverURL.Scheme, uri.Scheme) assert.Equal(serverURL.Host, uri.Host) diff --git a/internal/provider/providers.go b/internal/provider/providers.go index ac863df3..d139688b 100644 --- a/internal/provider/providers.go +++ b/internal/provider/providers.go @@ -17,7 +17,7 @@ type Providers struct { // Provider is used to authenticate users type Provider interface { Name() string - GetLoginURL(redirectURI, state string) string + GetLoginURL(redirectURI, state string) (string, error) ExchangeCode(redirectURI, code string) (string, error) GetUser(token string) (User, error) Setup() error @@ -49,18 +49,18 @@ func (p *OAuthProvider) ConfigCopy(redirectURI string) oauth2.Config { } // OAuthGetLoginURL provides a base "GetLoginURL" for proiders using OAauth2 -func (p *OAuthProvider) OAuthGetLoginURL(redirectURI, state string) string { +func (p *OAuthProvider) OAuthGetLoginURL(redirectURI, state string, opts ...oauth2.AuthCodeOption) string { config := p.ConfigCopy(redirectURI) if p.Resource != "" { return config.AuthCodeURL(state, oauth2.SetAuthURLParam("resource", p.Resource)) } - return config.AuthCodeURL(state) + return config.AuthCodeURL(state, opts...) } // OAuthExchangeCode provides a base "ExchangeCode" for proiders using OAauth2 -func (p *OAuthProvider) OAuthExchangeCode(redirectURI, code string) (*oauth2.Token, error) { +func (p *OAuthProvider) OAuthExchangeCode(redirectURI, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) { config := p.ConfigCopy(redirectURI) - return config.Exchange(p.ctx, code) + return config.Exchange(p.ctx, code, opts...) } diff --git a/internal/server.go b/internal/server.go index 8ac03131..d2d38c09 100644 --- a/internal/server.go +++ b/internal/server.go @@ -101,7 +101,7 @@ func (s *Server) AuthHandler(providerName, rule string) http.HandlerFunc { } // Validate user - valid := ValidateEmail(email) + valid := ValidateEmail(email, rule) if !valid { logger.WithField("email", email).Warn("Invalid email") http.Error(w, "Not authorized", 401) @@ -121,16 +121,26 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc { // Logging setup logger := s.logger(r, "AuthCallback", "default", "Handling callback") + // Check state + state := r.URL.Query().Get("state") + if err := ValidateState(state); err != nil { + logger.WithFields(logrus.Fields{ + "error": err, + }).Warn("Error validating state") + http.Error(w, "Not authorized", 401) + return + } + // Check for CSRF cookie - c, err := r.Cookie(config.CSRFCookieName) + c, err := FindCSRFCookie(r, state) if err != nil { logger.Info("Missing csrf cookie") http.Error(w, "Not authorized", 401) return } - // Validate state - valid, providerName, redirect, err := ValidateCSRFCookie(r, c) + // Validate CSRF cookie against state + valid, providerName, redirect, err := ValidateCSRFCookie(c, state) if !valid { logger.WithFields(logrus.Fields{ "error": err, @@ -153,7 +163,7 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc { } // Clear CSRF cookie - http.SetCookie(w, ClearCSRFCookie(r)) + http.SetCookie(w, ClearCSRFCookie(r, c)) // Exchange code for token token, err := p.ExchangeCode(redirectUri(r), r.URL.Query().Get("code")) @@ -221,7 +231,14 @@ func (s *Server) authRedirect(logger *logrus.Entry, w http.ResponseWriter, r *ht } // Forward them on - loginURL := p.GetLoginURL(redirectUri(r), MakeState(r, p, nonce)) + loginURL, err := p.GetLoginURL(redirectUri(r), MakeState(r, p, nonce)) + + if err != nil { + logger.WithField("error", err).Error("Get login url failed") + http.Error(w, "Service unavailable", 503) + return + } + http.Redirect(w, r, loginURL, http.StatusTemporaryRedirect) logger.WithFields(logrus.Fields{ diff --git a/internal/server_test.go b/internal/server_test.go index 2e543400..8ec0f01d 100644 --- a/internal/server_test.go +++ b/internal/server_test.go @@ -98,7 +98,7 @@ func TestServerAuthHandlerExpired(t *testing.T) { // Check for CSRF cookie var cookie *http.Cookie for _, c := range res.Cookies() { - if c.Name == config.CSRFCookieName { + if strings.HasPrefix(c.Name, config.CSRFCookieName) { cookie = c } }