diff --git a/cmd/saml2aws/commands/configure.go b/cmd/saml2aws/commands/configure.go index a8f9fc7c2..29bf3a684 100644 --- a/cmd/saml2aws/commands/configure.go +++ b/cmd/saml2aws/commands/configure.go @@ -6,7 +6,7 @@ import ( "path" "github.com/pkg/errors" - "github.com/versent/saml2aws/v2" + saml2aws "github.com/versent/saml2aws/v2" "github.com/versent/saml2aws/v2/helper/credentials" "github.com/versent/saml2aws/v2/pkg/cfg" "github.com/versent/saml2aws/v2/pkg/flags" @@ -68,14 +68,14 @@ func storeCredentials(configFlags *flags.CommonFlags, account *cfg.IDPAccount) e return nil } if configFlags.Password != "" { - if err := credentials.SaveCredentials(account.URL, account.Username, configFlags.Password); err != nil { + if err := credentials.SaveCredentials(account.Name, account.URL, account.Username, configFlags.Password); err != nil { return errors.Wrap(err, "error storing password in keychain") } } else { password := prompter.Password("Password") if password != "" { if confirmPassword := prompter.Password("Confirm"); confirmPassword == password { - if err := credentials.SaveCredentials(account.URL, account.Username, password); err != nil { + if err := credentials.SaveCredentials(account.Name, account.URL, account.Username, password); err != nil { return errors.Wrap(err, "error storing password in keychain") } } else { @@ -91,7 +91,8 @@ func storeCredentials(configFlags *flags.CommonFlags, account *cfg.IDPAccount) e log.Println("OneLogin provider requires --client-id and --client-secret flags to be set.") os.Exit(1) } - if err := credentials.SaveCredentials(path.Join(account.URL, OneLoginOAuthPath), configFlags.ClientID, configFlags.ClientSecret); err != nil { + // we store the OneLogin token in a different secret (idpName + the one login suffix) + if err := credentials.SaveCredentials(account.Name+credentials.OneLoginTokenSuffix, path.Join(account.URL, OneLoginOAuthPath), configFlags.ClientID, configFlags.ClientSecret); err != nil { return errors.Wrap(err, "error storing client_id and client_secret in keychain") } } diff --git a/cmd/saml2aws/commands/list_roles.go b/cmd/saml2aws/commands/list_roles.go index a8a20b925..5fe5fd0a3 100644 --- a/cmd/saml2aws/commands/list_roles.go +++ b/cmd/saml2aws/commands/list_roles.go @@ -8,7 +8,7 @@ import ( "github.com/pkg/errors" "github.com/sirupsen/logrus" - "github.com/versent/saml2aws/v2" + saml2aws "github.com/versent/saml2aws/v2" "github.com/versent/saml2aws/v2/helper/credentials" "github.com/versent/saml2aws/v2/pkg/flags" "github.com/versent/saml2aws/v2/pkg/samlcache" @@ -83,7 +83,7 @@ func ListRoles(loginFlags *flags.LoginExecFlags) error { } if !loginFlags.CommonFlags.DisableKeychain { - err = credentials.SaveCredentials(loginDetails.URL, loginDetails.Username, loginDetails.Password) + err = credentials.SaveCredentials(loginDetails.IdpName, loginDetails.URL, loginDetails.Username, loginDetails.Password) if err != nil { return errors.Wrap(err, "error storing password in keychain") } diff --git a/cmd/saml2aws/commands/login.go b/cmd/saml2aws/commands/login.go index 0a7766a4e..981063bea 100644 --- a/cmd/saml2aws/commands/login.go +++ b/cmd/saml2aws/commands/login.go @@ -14,7 +14,7 @@ import ( "github.com/aws/aws-sdk-go/service/sts" "github.com/pkg/errors" "github.com/sirupsen/logrus" - "github.com/versent/saml2aws/v2" + saml2aws "github.com/versent/saml2aws/v2" "github.com/versent/saml2aws/v2/helper/credentials" "github.com/versent/saml2aws/v2/pkg/awsconfig" "github.com/versent/saml2aws/v2/pkg/cfg" @@ -122,7 +122,7 @@ func Login(loginFlags *flags.LoginExecFlags) error { } if !loginFlags.CommonFlags.DisableKeychain { - err = credentials.SaveCredentials(loginDetails.URL, loginDetails.Username, loginDetails.Password) + err = credentials.SaveCredentials(loginDetails.IdpName, loginDetails.URL, loginDetails.Username, loginDetails.Password) if err != nil { return errors.Wrap(err, "Error storing password in keychain.") } @@ -174,15 +174,20 @@ func buildIdpAccount(loginFlags *flags.LoginExecFlags) (*cfg.IDPAccount, error) func resolveLoginDetails(account *cfg.IDPAccount, loginFlags *flags.LoginExecFlags) (*creds.LoginDetails, error) { - // log.Printf("loginFlags %+v", loginFlags) - - loginDetails := &creds.LoginDetails{URL: account.URL, Username: account.Username, MFAToken: loginFlags.CommonFlags.MFAToken, DuoMFAOption: loginFlags.DuoMFAOption} + loginDetails := &creds.LoginDetails{ + URL: account.URL, + Username: account.Username, + MFAToken: loginFlags.CommonFlags.MFAToken, + DuoMFAOption: loginFlags.DuoMFAOption, + IdpName: account.Name, + IdpProvider: account.Provider, + } log.Printf("Using IdP Account %s to access %s %s", loginFlags.CommonFlags.IdpAccount, account.Provider, account.URL) var err error if !loginFlags.CommonFlags.DisableKeychain { - err = credentials.LookupCredentials(loginDetails, account.Provider) + err = credentials.LookupCredentials(loginDetails) if err != nil { if !credentials.IsErrCredentialsNotFound(err) { return nil, errors.Wrap(err, "Error loading saved password.") diff --git a/cmd/saml2aws/commands/login_test.go b/cmd/saml2aws/commands/login_test.go index bca0442cf..2a9446318 100644 --- a/cmd/saml2aws/commands/login_test.go +++ b/cmd/saml2aws/commands/login_test.go @@ -6,7 +6,8 @@ import ( "time" "github.com/stretchr/testify/assert" - "github.com/versent/saml2aws/v2" + + saml2aws "github.com/versent/saml2aws/v2" "github.com/versent/saml2aws/v2/pkg/awsconfig" "github.com/versent/saml2aws/v2/pkg/cfg" "github.com/versent/saml2aws/v2/pkg/creds" @@ -14,11 +15,18 @@ import ( ) func TestResolveLoginDetailsWithFlags(t *testing.T) { + commonFlags := &flags.CommonFlags{ + URL: "https://id.example.com", + Username: "wolfeidau", + Password: "testtestlol", + MFAToken: "123456", + SkipPrompt: true, + } - commonFlags := &flags.CommonFlags{URL: "https://id.example.com", Username: "wolfeidau", Password: "testtestlol", MFAIPAddress: "127.0.0.1", MFAToken: "123456", SkipPrompt: true} loginFlags := &flags.LoginExecFlags{CommonFlags: commonFlags} idpa := &cfg.IDPAccount{ + Name: "AccountName", URL: "https://id.example.com", MFA: "none", Provider: "Ping", @@ -27,16 +35,30 @@ func TestResolveLoginDetailsWithFlags(t *testing.T) { loginDetails, err := resolveLoginDetails(idpa, loginFlags) assert.Empty(t, err) - assert.Equal(t, &creds.LoginDetails{Username: "wolfeidau", Password: "testtestlol", URL: "https://id.example.com", MFAToken: "123456", MFAIPAddress: "127.0.0.1"}, loginDetails) + assert.Equal(t, + &creds.LoginDetails{ + IdpName: "AccountName", + IdpProvider: "Ping", + Username: "wolfeidau", + Password: "testtestlol", + URL: "https://id.example.com", + MFAToken: "123456", + }, loginDetails) } func TestOktaResolveLoginDetailsWithFlags(t *testing.T) { - // Default state - user did not supply values for DisableSessions and DisableSessions - commonFlags := &flags.CommonFlags{URL: "https://id.example.com", Username: "testuser", Password: "testtestlol", MFAToken: "123456", SkipPrompt: true} + commonFlags := &flags.CommonFlags{ + URL: "https://id.example.com", + Username: "testuser", + Password: "testtestlol", + MFAToken: "123456", + SkipPrompt: true, + } loginFlags := &flags.LoginExecFlags{CommonFlags: commonFlags} idpa := &cfg.IDPAccount{ + Name: "AnotherAccountName", URL: "https://id.example.com", MFA: "none", Provider: "Okta", @@ -47,11 +69,26 @@ func TestOktaResolveLoginDetailsWithFlags(t *testing.T) { assert.Nil(t, err) assert.False(t, idpa.DisableSessions, fmt.Errorf("default state, DisableSessions should be false")) assert.False(t, idpa.DisableRememberDevice, fmt.Errorf("default state, DisableRememberDevice should be false")) - assert.Equal(t, &creds.LoginDetails{Username: "testuser", Password: "testtestlol", URL: "https://id.example.com", MFAToken: "123456"}, loginDetails) + assert.Equal(t, + &creds.LoginDetails{ + IdpName: "AnotherAccountName", + IdpProvider: "Okta", + Username: "testuser", + Password: "testtestlol", + URL: "https://id.example.com", + MFAToken: "123456", + }, loginDetails) // User disabled keychain, resolveLoginDetails should set the account's DisableSessions and DisableSessions fields to true - commonFlags = &flags.CommonFlags{URL: "https://id.example.com", Username: "testuser", Password: "testtestlol", MFAToken: "123456", SkipPrompt: true, DisableKeychain: true} + commonFlags = &flags.CommonFlags{ + URL: "https://id.example.com", + Username: "testuser", + Password: "testtestlol", + MFAToken: "123456", + SkipPrompt: true, + DisableKeychain: true, + } loginFlags = &flags.LoginExecFlags{CommonFlags: commonFlags} loginDetails, err = resolveLoginDetails(idpa, loginFlags) @@ -59,12 +96,18 @@ func TestOktaResolveLoginDetailsWithFlags(t *testing.T) { assert.Nil(t, err) assert.True(t, idpa.DisableSessions, fmt.Errorf("user disabled keychain, DisableSessions should be true")) assert.True(t, idpa.DisableRememberDevice, fmt.Errorf("user disabled keychain, DisableRememberDevice should be true")) - assert.Equal(t, &creds.LoginDetails{Username: "testuser", Password: "testtestlol", URL: "https://id.example.com", MFAToken: "123456"}, loginDetails) - + assert.Equal(t, + &creds.LoginDetails{ + IdpName: "AnotherAccountName", + IdpProvider: "Okta", + Username: "testuser", + Password: "testtestlol", + URL: "https://id.example.com", + MFAToken: "123456", + }, loginDetails) } func TestResolveRoleSingleEntry(t *testing.T) { - adminRole := &saml2aws.AWSRole{ Name: "admin", RoleARN: "arn:aws:iam::456456456456:saml-provider/example-idp,arn:aws:iam::456456456456:role/admin", @@ -81,7 +124,6 @@ func TestResolveRoleSingleEntry(t *testing.T) { } func TestCredentialsToCredentialProcess(t *testing.T) { - aws_creds := &awsconfig.AWSCredentials{ AWSAccessKey: "someawsaccesskey", AWSSecretKey: "somesecretkey", diff --git a/helper/credentials/credentials.go b/helper/credentials/credentials.go index bebce84f7..847baca26 100644 --- a/helper/credentials/credentials.go +++ b/helper/credentials/credentials.go @@ -2,6 +2,7 @@ package credentials import ( "errors" + "fmt" ) var ( @@ -14,25 +15,38 @@ var ( // Credentials holds the information shared between saml2aws and the credentials store. type Credentials struct { + IdpName string ServerURL string Username string Secret string } -// CredsLabel saml2aws credentials should be labeled as such in credentials stores that allow labelling. -// That label allows to filter out non-Docker credentials too at lookup/search in macOS keychain, -// Windows credentials manager and Linux libsecret. Default value is "saml2aws Credentials" -var CredsLabel = "saml2aws Credentials" +const ( + // CredsLabel saml2aws credentials should be labeled as such in credentials stores that allow labelling. + // That label allows to filter out non-Docker credentials too at lookup/search in macOS keychain, + // Windows credentials manager and Linux libsecret. Default value is "saml2aws Credentials" + CredsLabel = "saml2aws Credentials" + CredsKeyPrefix = "saml2aws_credentials" + OktaSessionCookieSuffix = "_okta_session" + OneLoginTokenSuffix = "_onelogin_token" +) + +func GetKeyFromAccount(accountName string) string { + return fmt.Sprintf("%s_%s", CredsKeyPrefix, accountName) +} // Helper is the interface a credentials store helper must implement. type Helper interface { // Add appends credentials to the store. Add(*Credentials) error // Delete removes credentials from the store. - Delete(serverURL string) error + Delete(keyName string) error // Get retrieves credentials from the store. // It returns username and secret as strings. - Get(serverURL string) (string, string, error) + Get(keyName string) (string, string, error) + // Legacy Get retrieves previously stored credentials + // this function is preserved for backward compatibility + LegacyGet(serverURL string) (string, string, error) // SupportsCredentialStorage returns true or false if there is credential storage. SupportsCredentialStorage() bool } @@ -49,11 +63,15 @@ func (defaultHelper) Add(*Credentials) error { return nil } -func (defaultHelper) Delete(serverURL string) error { +func (defaultHelper) Delete(keyName string) error { return nil } -func (defaultHelper) Get(serverURL string) (string, string, error) { +func (defaultHelper) Get(keyName string) (string, string, error) { + return "", "", ErrCredentialsNotFound +} + +func (defaultHelper) LegacyGet(serverURL string) (string, string, error) { return "", "", ErrCredentialsNotFound } diff --git a/helper/credentials/saml.go b/helper/credentials/saml.go index 0f3ba65a8..74c599898 100644 --- a/helper/credentials/saml.go +++ b/helper/credentials/saml.go @@ -1,34 +1,57 @@ package credentials import ( + "errors" "path" "github.com/versent/saml2aws/v2/pkg/creds" ) // LookupCredentials lookup an existing set of credentials and validate it. -func LookupCredentials(loginDetails *creds.LoginDetails, provider string) error { +func LookupCredentials(loginDetails *creds.LoginDetails) error { + var username, password string + var err, err2 error - username, password, err := CurrentHelper.Get(loginDetails.URL) + username, password, err = CurrentHelper.Get(GetKeyFromAccount(loginDetails.IdpName)) if err != nil { - return err + // the credential keyname has changed from server URL to Identity Provider (#762) + // Falling back to old key name to preserve backward compatibility + username, password, err2 = CurrentHelper.LegacyGet(loginDetails.URL) + if err2 != nil { + // return the error from the current key name, not the historical one + return err + } } loginDetails.Username = username loginDetails.Password = password // If the provider is Okta, check for existing Okta Session Cookie (sid) - if provider == "Okta" { - _, oktaSessionCookie, err := CurrentHelper.Get(loginDetails.URL + "/sessionCookie") - if err == nil { - loginDetails.OktaSessionCookie = oktaSessionCookie + if loginDetails.IdpProvider == "Okta" { + // load up the Okta token from a different secret (idp name + Okta suffix) + var oktaSessionCookie string + + _, oktaSessionCookie, err = CurrentHelper.Get(GetKeyFromAccount(loginDetails.IdpName + OktaSessionCookieSuffix)) + if err != nil { + // the credential keyname has changed from server URL to Identity Provider (#762) + // Falling back to old key name to preserve backward compatibility + _, oktaSessionCookie, _ = CurrentHelper.LegacyGet(loginDetails.URL + "/sessionCookie") } + loginDetails.OktaSessionCookie = oktaSessionCookie } - if provider == "OneLogin" { - id, secret, err := CurrentHelper.Get(path.Join(loginDetails.URL, "/auth/oauth2/v2/token")) + if loginDetails.IdpProvider == "OneLogin" { + var id, secret string + + // load up the one login token from a different secret (idp name + one login suffix) + id, secret, err = CurrentHelper.Get(GetKeyFromAccount(loginDetails.IdpName + OneLoginTokenSuffix)) if err != nil { - return err + // the credential keyname has changed from server URL to Identity Provider (#762) + // Falling back to old key name to preserve backward compatibility + id, secret, err2 = CurrentHelper.LegacyGet(path.Join(loginDetails.URL, "/auth/oauth2/v2/token")) + if err2 != nil { + return err + } } loginDetails.ClientID = id loginDetails.ClientSecret = secret @@ -37,14 +60,17 @@ func LookupCredentials(loginDetails *creds.LoginDetails, provider string) error } // SaveCredentials save the user credentials. -func SaveCredentials(url, username, password string) error { - +func SaveCredentials(idpName, url, username, password string) error { creds := &Credentials{ + IdpName: idpName, ServerURL: url, Username: username, Secret: password, } + if idpName == "" { + return errors.New("idpName is empty") + } return CurrentHelper.Add(creds) } diff --git a/helper/credentials/saml_test.go b/helper/credentials/saml_test.go new file mode 100644 index 000000000..3cbdc99df --- /dev/null +++ b/helper/credentials/saml_test.go @@ -0,0 +1,310 @@ +package credentials + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/versent/saml2aws/v2/pkg/creds" +) + +type MockHelper struct { + Credentials map[string]*Credentials + AddFailError error + DeleteFailError error +} + +func NewMockHelper() *MockHelper { + return &MockHelper{ + Credentials: make(map[string]*Credentials), + } +} + +func (m *MockHelper) Add(c *Credentials) error { + if m.AddFailError != nil { + return m.AddFailError + } + m.Credentials[GetKeyFromAccount(c.IdpName)] = c + return nil +} + +func (m *MockHelper) Delete(keyname string) error { + if m.DeleteFailError != nil { + return m.DeleteFailError + } + if _, ok := m.Credentials[keyname]; !ok { + return fmt.Errorf("%s not found in credential set (keychain mock)", keyname) + } + delete(m.Credentials, keyname) + return nil +} + +func (m *MockHelper) Get(keyName string) (string, string, error) { + if _, ok := m.Credentials[keyName]; !ok { + return "", "", fmt.Errorf("%s not found in credential set (keychain mock)", keyName) + } + return m.Credentials[keyName].Username, m.Credentials[keyName].Secret, nil +} + +func (m *MockHelper) LegacyGet(serverURL string) (string, string, error) { + return m.Get(serverURL) +} + +func (m *MockHelper) SupportsCredentialStorage() bool { + return true +} + +func TestLookupCredentials(t *testing.T) { + oldHelper := CurrentHelper + + testCases := []struct { + CaseName string + initialCredentials map[string]*Credentials + loginDetails creds.LoginDetails + expectedError bool + expectedUsername string + expectedPassword string + expectedOktaCookie string + expectedClientID string + expectedClientSecret string + }{ + { + CaseName: "CredentialsFound", + loginDetails: creds.LoginDetails{ + IdpName: "test", + IdpProvider: "ADFS", + URL: "https://someurl.com/", + }, + initialCredentials: map[string]*Credentials{ + "saml2aws_credentials_test": { + Username: "user1", + Secret: "password1", + }, + }, + expectedUsername: "user1", + expectedPassword: "password1", + }, + { + CaseName: "CredentialsNotFound", + loginDetails: creds.LoginDetails{ + IdpName: "test", + IdpProvider: "ADFS", + URL: "https://someurl.com/", + }, + initialCredentials: map[string]*Credentials{}, + expectedError: true, + }, + { + CaseName: "CredentialsFoundButFallBack", + loginDetails: creds.LoginDetails{ + IdpName: "test", + IdpProvider: "ADFS", + URL: "https://someurl.com/", + }, + initialCredentials: map[string]*Credentials{ + "https://someurl.com/": { + Username: "user1", + Secret: "password1", + }, + }, + expectedUsername: "user1", + expectedPassword: "password1", + }, + // for Okta + { + CaseName: "CredentialsWorkForOkta", + loginDetails: creds.LoginDetails{ + IdpName: "test", + IdpProvider: "Okta", + URL: "https://someurl.com/", + }, + initialCredentials: map[string]*Credentials{ + "saml2aws_credentials_test": { + Username: "user1", + Secret: "password1", + }, + "saml2aws_credentials_test_okta_session": { + Secret: "cookie1", + }, + }, + expectedUsername: "user1", + expectedPassword: "password1", + expectedOktaCookie: "cookie1", + }, + { + CaseName: "CredentialsFallbackWorkForOkta", + loginDetails: creds.LoginDetails{ + IdpName: "test", + IdpProvider: "Okta", + URL: "https://someurl.com/", + }, + initialCredentials: map[string]*Credentials{ + "https://someurl.com/": { + Username: "user2", + Secret: "password2", + }, + "https://someurl.com//sessionCookie": { + Secret: "cookie2", + }, + }, + expectedUsername: "user2", + expectedPassword: "password2", + expectedOktaCookie: "cookie2", + }, + { + CaseName: "CredentialsWorkForOktaButCookieFails", + loginDetails: creds.LoginDetails{ + IdpName: "test", + IdpProvider: "Okta", + URL: "https://someurl.com/", + }, + initialCredentials: map[string]*Credentials{ + "saml2aws_credentials_test": { + Username: "user3", + Secret: "password3", + }, + }, + expectedUsername: "user3", + expectedPassword: "password3", + expectedOktaCookie: "", + expectedError: false, + }, + + // For OneLogin + { + CaseName: "CredentialsWorkForOneLogin", + loginDetails: creds.LoginDetails{ + IdpName: "test", + IdpProvider: "OneLogin", + URL: "https://someurl.com/", + }, + initialCredentials: map[string]*Credentials{ + "saml2aws_credentials_test": { + Username: "user4", + Secret: "password4", + }, + "saml2aws_credentials_test_onelogin_token": { + Username: "clientId1", + Secret: "clientSecret1", + }, + }, + expectedUsername: "user4", + expectedPassword: "password4", + expectedClientID: "clientId1", + expectedClientSecret: "clientSecret1", + }, + { + CaseName: "CredentialsWorksButFailForOneLoginToken", + loginDetails: creds.LoginDetails{ + IdpName: "test", + IdpProvider: "OneLogin", + URL: "https://someurl.com/", + }, + initialCredentials: map[string]*Credentials{ + "saml2aws_credentials_test": { + Username: "user5", + Secret: "password5", + }, + }, + expectedError: true, + }, + { + CaseName: "CredentialsFallbackWorkForOneLogin", + loginDetails: creds.LoginDetails{ + IdpName: "test", + IdpProvider: "OneLogin", + URL: "https://someurl.com/", + }, + initialCredentials: map[string]*Credentials{ + "https://someurl.com/": { + Username: "user6", + Secret: "password6", + }, + "https:/someurl.com/auth/oauth2/v2/token": { + Username: "clientId2", + Secret: "clientSecret2", + }, + }, + expectedUsername: "user6", + expectedPassword: "password6", + expectedClientID: "clientId2", + expectedClientSecret: "clientSecret2", + }, + } + + for _, testCase := range testCases { + t.Run(testCase.CaseName, func(t *testing.T) { + t.Log(testCase.CaseName) + m := NewMockHelper() + CurrentHelper = m + m.Credentials = testCase.initialCredentials + t.Log(testCase.initialCredentials) + err := LookupCredentials(&testCase.loginDetails) + if testCase.expectedError { + assert.NotNil(t, err) + } else { + assert.Nil(t, err) + assert.EqualValues(t, testCase.expectedUsername, testCase.loginDetails.Username) + assert.EqualValues(t, testCase.expectedPassword, testCase.loginDetails.Password) + assert.EqualValues(t, testCase.expectedOktaCookie, testCase.loginDetails.OktaSessionCookie) + assert.EqualValues(t, testCase.expectedClientID, testCase.loginDetails.ClientID) + assert.EqualValues(t, testCase.expectedClientSecret, testCase.loginDetails.ClientSecret) + } + }) + } + + // restoring the old Helper + CurrentHelper = oldHelper +} + +func TestSaveCredentials(t *testing.T) { + oldHelper := CurrentHelper + + testCases := []struct { + CaseName string + IdpName string + URL string + Username string + Password string + expectedCredentialKeyName string + expectedError bool + }{ + { + CaseName: "SaveCredentials", + IdpName: "test", + URL: "http://test.com/", + Username: "user1", + Password: "password1", + expectedCredentialKeyName: "saml2aws_credentials_test", + }, + { + CaseName: "EmptyIdpNameRaisesError", + IdpName: "", + URL: "http://test.com/", + Username: "user2", + Password: "password2", + expectedError: true, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.CaseName, func(t *testing.T) { + m := NewMockHelper() + CurrentHelper = m + err := SaveCredentials(testCase.IdpName, testCase.URL, testCase.Username, testCase.Password) + if testCase.expectedError { + assert.NotNil(t, err) + } else { + assert.Nil(t, err) + _, ok := m.Credentials[testCase.expectedCredentialKeyName] + assert.True(t, ok) + assert.EqualValues(t, testCase.Username, m.Credentials[testCase.expectedCredentialKeyName].Username) + assert.EqualValues(t, testCase.Password, m.Credentials[testCase.expectedCredentialKeyName].Secret) + } + }) + } + + // restoring the old Helper + CurrentHelper = oldHelper +} diff --git a/helper/linuxkeyring/linuxkeyring_linux.go b/helper/linuxkeyring/linuxkeyring_linux.go index 123e7c787..0224e3bc7 100644 --- a/helper/linuxkeyring/linuxkeyring_linux.go +++ b/helper/linuxkeyring/linuxkeyring_linux.go @@ -5,6 +5,7 @@ import ( "github.com/99designs/keyring" "github.com/sirupsen/logrus" + "github.com/versent/saml2aws/v2/helper/credentials" ) @@ -35,7 +36,6 @@ func NewKeyringHelper(config Configuration) (*KeyringHelper, error) { } kr, err := keyring.Open(c) - if err != nil { return nil, err } @@ -52,19 +52,19 @@ func (kr *KeyringHelper) Add(creds *credentials.Credentials) error { } return kr.keyring.Set(keyring.Item{ - Key: creds.ServerURL, + Key: credentials.GetKeyFromAccount(creds.IdpName), Label: credentials.CredsLabel, Data: encoded, KeychainNotTrustApplication: false, }) } -func (kr *KeyringHelper) Delete(serverURL string) error { - return kr.keyring.Remove(serverURL) +func (kr *KeyringHelper) Delete(keyName string) error { + return kr.keyring.Remove(keyName) } -func (kr *KeyringHelper) Get(serverURL string) (string, string, error) { - item, err := kr.keyring.Get(serverURL) +func (kr *KeyringHelper) Get(keyName string) (string, string, error) { + item, err := kr.keyring.Get(keyName) if err != nil { logger.WithField("err", err).Error("keychain Get returned error") return "", "", credentials.ErrCredentialsNotFound @@ -78,6 +78,11 @@ func (kr *KeyringHelper) Get(serverURL string) (string, string, error) { return creds.Username, creds.Secret, nil } +// this function is preserved for backward compatibility reasons +func (kr *KeyringHelper) LegacyGet(serverURL string) (string, string, error) { + return kr.Get(serverURL) +} + func (KeyringHelper) SupportsCredentialStorage() bool { return true } diff --git a/helper/osxkeychain/osxkeychain.go b/helper/osxkeychain/osxkeychain.go index 93e2bcdf3..2f3a6698f 100644 --- a/helper/osxkeychain/osxkeychain.go +++ b/helper/osxkeychain/osxkeychain.go @@ -9,6 +9,7 @@ import ( "github.com/keybase/go-keychain" "github.com/sirupsen/logrus" + "github.com/versent/saml2aws/v2/helper/credentials" ) @@ -19,14 +20,15 @@ type Osxkeychain struct{} // Add adds new credentials to the keychain. func (h Osxkeychain) Add(creds *credentials.Credentials) error { - err := h.Delete(creds.ServerURL) + err := h.Delete(creds.IdpName) if err != nil { logger.WithError(err).Debug("delete of existing keychain entry failed") } item := keychain.NewItem() item.SetSecClass(keychain.SecClassInternetPassword) - item.SetLabel(credentials.CredsLabel) + item.SetLabel(credentials.GetKeyFromAccount(creds.IdpName)) + item.SetString("Purpose", credentials.CredsLabel) item.SetAccount(creds.Username) item.SetData([]byte(creds.Secret)) err = splitServer3(creds.ServerURL, item) @@ -44,26 +46,43 @@ func (h Osxkeychain) Add(creds *credentials.Credentials) error { } // Delete removes credentials from the keychain. -func (h Osxkeychain) Delete(serverURL string) error { - +func (h Osxkeychain) Delete(keyName string) error { item := keychain.NewItem() item.SetSecClass(keychain.SecClassInternetPassword) - err := splitServer3(serverURL, item) + item.SetLabel(keyName) + return keychain.DeleteItem(item) +} + +// Get returns the username and secret to use for a given keyName +func (h Osxkeychain) Get(keyName string) (string, string, error) { + logger.WithField("Credential Key", keyName).Debug("Get credentials") + + query := keychain.NewItem() + query.SetSecClass(keychain.SecClassInternetPassword) + + // only search on the idp name + query.SetLabel(keyName) + query.SetMatchLimit(keychain.MatchLimitOne) + query.SetReturnAttributes(true) + query.SetReturnData(true) + + results, err := keychain.QueryItem(query) if err != nil { - return err + return "", "", err } - err = keychain.DeleteItem(item) - if err != nil { - return err + if len(results) == 0 { + return "", "", credentials.ErrCredentialsNotFound } - return nil -} + logger.WithField("user", results[0].Account).Debug("Get credentials") -// Get returns the username and secret to use for a given registry server URL. -func (h Osxkeychain) Get(serverURL string) (string, string, error) { + return results[0].Account, string(results[0].Data), nil +} +// Legacy Get returns the username and secret to use for a given registry server URL. +// this function is preserved for backward compatibility reasons +func (h Osxkeychain) LegacyGet(serverURL string) (string, string, error) { logger.WithField("serverURL", serverURL).Debug("Get credentials") query := keychain.NewItem() diff --git a/helper/osxkeychain/osxkeychain_test.go b/helper/osxkeychain/osxkeychain_test.go index 78b85337c..84311c23f 100644 --- a/helper/osxkeychain/osxkeychain_test.go +++ b/helper/osxkeychain/osxkeychain_test.go @@ -29,16 +29,19 @@ import ( "testing" "github.com/stretchr/testify/require" + "github.com/versent/saml2aws/v2/helper/credentials" ) func TestOSXKeychainHelper(t *testing.T) { creds := &credentials.Credentials{ + IdpName: "creds_idpName", ServerURL: "https://foobar.docker.io:2376/v1", Username: "foobar", Secret: "foobarbaz", } creds1 := &credentials.Credentials{ + IdpName: "creds1_idpName", ServerURL: "https://foobar.docker.io:2376/v2", Username: "foobarbaz", Secret: "foobar", @@ -48,7 +51,9 @@ func TestOSXKeychainHelper(t *testing.T) { t.Fatal(err) } - username, secret, err := helper.Get(creds.ServerURL) + credsKey, creds1Key := credentials.GetKeyFromAccount(creds.IdpName), credentials.GetKeyFromAccount(creds1.IdpName) + + username, secret, err := helper.Get(credsKey) if err != nil { t.Fatal(err) } @@ -65,10 +70,10 @@ func TestOSXKeychainHelper(t *testing.T) { require.Nil(t, err) defer func() { - _ = helper.Delete(creds1.ServerURL) + _ = helper.Delete(creds1Key) }() - if err := helper.Delete(creds.ServerURL); err != nil { + if err := helper.Delete(credsKey); err != nil { t.Fatal(err) } } diff --git a/helper/wincred/wincred_windows.go b/helper/wincred/wincred_windows.go index 38d707521..421053f86 100644 --- a/helper/wincred/wincred_windows.go +++ b/helper/wincred/wincred_windows.go @@ -27,6 +27,7 @@ import ( "strings" winc "github.com/danieljoos/wincred" + "github.com/versent/saml2aws/v2/helper/credentials" ) @@ -35,7 +36,7 @@ type Wincred struct{} // Add adds new credentials to the windows credentials manager. func (h Wincred) Add(creds *credentials.Credentials) error { - g := winc.NewGenericCredential(creds.ServerURL) + g := winc.NewGenericCredential(credentials.GetKeyFromAccount(creds.IdpName)) g.UserName = creds.Username g.CredentialBlob = []byte(creds.Secret) g.Persist = winc.PersistLocalMachine @@ -45,8 +46,8 @@ func (h Wincred) Add(creds *credentials.Credentials) error { } // Delete removes credentials from the windows credentials manager. -func (h Wincred) Delete(serverURL string) error { - g, err := winc.GetGenericCredential(serverURL) +func (h Wincred) Delete(keyName string) error { + g, err := winc.GetGenericCredential(keyName) if g == nil { return nil } @@ -57,8 +58,8 @@ func (h Wincred) Delete(serverURL string) error { } // Get retrieves credentials from the windows credentials manager. -func (h Wincred) Get(serverURL string) (string, string, error) { - g, _ := winc.GetGenericCredential(serverURL) +func (h Wincred) Get(keyname string) (string, string, error) { + g, _ := winc.GetGenericCredential(keyname) if g == nil { return "", "", credentials.ErrCredentialsNotFound } @@ -72,6 +73,12 @@ func (h Wincred) Get(serverURL string) (string, string, error) { return "", "", credentials.ErrCredentialsNotFound } +// Legacy Get retrieves credentials from the windows credentials manager. +// this function is preserved for backward compatibility reasons +func (h Wincred) LegacyGet(serverURL string) (string, string, error) { + return h.Get(serverURL) +} + // List returns the stored URLs and corresponding usernames for a given credentials label. func (h Wincred) List() (map[string]string, error) { creds, err := winc.List() diff --git a/pkg/creds/creds.go b/pkg/creds/creds.go index 5aa697103..c86724128 100644 --- a/pkg/creds/creds.go +++ b/pkg/creds/creds.go @@ -2,6 +2,8 @@ package creds // LoginDetails used to authenticate type LoginDetails struct { + IdpName string // the IDP name for those login Details, required for the credential + IdpProvider string // the IDP provider, required to populate Okta and OneLogin ClientID string // used by OneLogin ClientSecret string // used by OneLogin DownloadBrowser bool // used by Browser diff --git a/pkg/provider/okta/okta.go b/pkg/provider/okta/okta.go index 52e9ba808..31f2b9cfc 100644 --- a/pkg/provider/okta/okta.go +++ b/pkg/provider/okta/okta.go @@ -201,7 +201,7 @@ func (oc *Client) createSession(loginDetails *creds.LoginDetails, sessionToken s oktaSessionCookie := gjson.Get(resp, "id").String() - err = credentials.SaveCredentials(loginDetails.URL+"/sessionCookie", loginDetails.Username, oktaSessionCookie) + err = credentials.SaveCredentials(loginDetails.IdpName+credentials.OktaSessionCookieSuffix, loginDetails.URL+"/sessionCookie", loginDetails.Username, oktaSessionCookie) if err != nil { return "", "", fmt.Errorf("error storing okta session token | err: %v", err) } @@ -570,7 +570,7 @@ func (oc *Client) follow(ctx context.Context, req *http.Request, loginDetails *c if handler == nil { html, _ := doc.Selection.Html() logger.WithField("doc", html).Debug("Unknown document type") - return "", fmt.Errorf("Unknown document type") + return "", fmt.Errorf("unknown document type") } ctx, req, err = handler(ctx, doc) diff --git a/pkg/provider/okta/okta_test.go b/pkg/provider/okta/okta_test.go index e2d8f1fa4..318f24a85 100644 --- a/pkg/provider/okta/okta_test.go +++ b/pkg/provider/okta/okta_test.go @@ -209,6 +209,7 @@ func setupTestClient(t *testing.T, ts *httptest.Server, mfa string) (*Client, *c func TestSetDeviceTokenCookie(t *testing.T) { idpAccount := cfg.NewIDPAccount() + idpAccount.Name = "myOktaProvider" idpAccount.URL = "https://idp.example.com/abcd" idpAccount.Username = "user@example.com"