From a47681f2394360c3c357c5224bc8013c421a3209 Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Tue, 11 Jun 2024 18:05:35 -0400 Subject: [PATCH 01/66] GODRIVER-2911: Initial attempted to untie the Gordian not, this will not go live, I'm sure --- internal/authutil/oidc.go | 80 ++++++++++++++ internal/authutil/sasl.go | 220 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 300 insertions(+) create mode 100644 internal/authutil/oidc.go create mode 100644 internal/authutil/sasl.go diff --git a/internal/authutil/oidc.go b/internal/authutil/oidc.go new file mode 100644 index 0000000000..71c0ac20e9 --- /dev/null +++ b/internal/authutil/oidc.go @@ -0,0 +1,80 @@ +package authutil + +import ( + "context" + "sync" + "time" + + "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" +) + +const oidcMech = "MONGODB-OIDC" +const tokenResourceProp = "TOKEN_RESOURCE" +const environmentProp = "ENVIRONMENT" +const principalProp = "PRINCIPAL" +const allowedHostsProp = "ALLOWED_HOSTS" +const azureEnvironmentValue = "azure" +const gcpEnvironmentValue = "gcp" + +// OIDCAuthenticator is synchronized and handles caching of the access token, refreshToken, +// and IDPInfo. It also provides a mechanism to refresh the access token, but this functionality +// is only for the OIDC Human flow. +type OIDCAuthenticator struct { + mu sync.Mutex // Guards all of the info in the OIDCAuthenticator struct. + + AuthMechanismProperties map[string]string + + accessToken string + refreshToken *string + idpInfo *IDPInfo +} + +type IDPInfo struct { + Issuer string `bson:"issuer"` + ClientID string `bson:"clientId"` + RequestScopes []string `bson:"requestScopes"` +} + +type OIDCCallback func(context.Context, *OIDCAargs) (*OIDCCredential, error) + +type OIDCAargs struct { + Version int + IDPInfo *IDPInfo + RefreshToken *string +} + +type OIDCCredential struct { + AccessToken string + ExpiresAt time.Time + RefreshToken *string +} + +type oidcOneStep struct { + accessToken string +} + +func (oos *oidcOneStep) Start() (string, []byte, error) { + return oidcMech, jwtStepRequest(oos.accessToken), nil +} + +func (oos *oidcOneStep) Next(context.Context, []byte) ([]byte, error) { + return nil, newAuthError("unexpected step in OIDC machine authentication", nil) +} + +func (*oidcOneStep) Completed() bool { + return true +} + +func jwtStepRequest(accessToken string) []byte { + return bsoncore.NewDocumentBuilder(). + AppendString("jwt", accessToken). + Build() +} + +func principalStepRequest(principal string) []byte { + doc := bsoncore.NewDocumentBuilder() + if principal != "" { + doc.AppendString("n", principal) + } + return doc.Build() +} diff --git a/internal/authutil/sasl.go b/internal/authutil/sasl.go new file mode 100644 index 0000000000..0f0fe8390c --- /dev/null +++ b/internal/authutil/sasl.go @@ -0,0 +1,220 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package authutil + +import ( + "context" + "fmt" + + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo/description" + "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/operation" + "go.mongodb.org/mongo-driver/x/mongo/driver/session" +) + +const defaultAuthDB = "admin" + +// Error is an error that occurred during authentication. +type Error struct { + message string + inner error +} + +func (e *Error) Error() string { + if e.inner == nil { + return e.message + } + return fmt.Sprintf("%s: %s", e.message, e.inner) +} + +// Inner returns the wrapped error. +func (e *Error) Inner() error { + return e.inner +} + +// Unwrap returns the underlying error. +func (e *Error) Unwrap() error { + return e.inner +} + +// Message returns the message. +func (e *Error) Message() string { + return e.message +} + +func newError(err error, mech string) error { + return &Error{ + message: fmt.Sprintf("unable to authenticate using mechanism \"%s\"", mech), + inner: err, + } +} + +// SaslClient is the client piece of a sasl conversation. +type SaslClient interface { + Start() (string, []byte, error) + Next(challenge []byte) ([]byte, error) + Completed() bool +} + +// SaslClientCloser is a SaslClient that has resources to clean up. +type SaslClientCloser interface { + SaslClient + Close() +} + +// ExtraOptionsSaslClient is a SaslClient that appends options to the saslStart command. +type ExtraOptionsSaslClient interface { + StartCommandOptions() bsoncore.Document +} + +// saslConversation represents a SASL conversation. This type implements the SpeculativeConversation interface so the +// conversation can be executed in multi-step speculative fashion. +type saslConversation struct { + client SaslClient + source string + mechanism string + speculative bool +} + +func newSaslConversation(client SaslClient, source string, speculative bool) *saslConversation { + authSource := source + if authSource == "" { + authSource = defaultAuthDB + } + return &saslConversation{ + client: client, + source: authSource, + speculative: speculative, + } +} + +// FirstMessage returns the first message to be sent to the server. This message contains a "db" field so it can be used +// for speculative authentication. +func (sc *saslConversation) FirstMessage() (bsoncore.Document, error) { + var payload []byte + var err error + sc.mechanism, payload, err = sc.client.Start() + if err != nil { + return nil, err + } + + saslCmdElements := [][]byte{ + bsoncore.AppendInt32Element(nil, "saslStart", 1), + bsoncore.AppendStringElement(nil, "mechanism", sc.mechanism), + bsoncore.AppendBinaryElement(nil, "payload", 0x00, payload), + } + if sc.speculative { + // The "db" field is only appended for speculative auth because the hello command is executed against admin + // so this is needed to tell the server the user's auth source. For a non-speculative attempt, the SASL commands + // will be executed against the auth source. + saslCmdElements = append(saslCmdElements, bsoncore.AppendStringElement(nil, "db", sc.source)) + } + if extraOptionsClient, ok := sc.client.(ExtraOptionsSaslClient); ok { + optionsDoc := extraOptionsClient.StartCommandOptions() + saslCmdElements = append(saslCmdElements, bsoncore.AppendDocumentElement(nil, "options", optionsDoc)) + } + + return bsoncore.BuildDocumentFromElements(nil, saslCmdElements...), nil +} + +type saslResponse struct { + ConversationID int `bson:"conversationId"` + Code int `bson:"code"` + Done bool `bson:"done"` + Payload []byte `bson:"payload"` +} + +// Finish completes the conversation based on the first server response to authenticate the given connection. +func (sc *saslConversation) Finish(ctx context.Context, cfg *Config, firstResponse bsoncore.Document) error { + if closer, ok := sc.client.(SaslClientCloser); ok { + defer closer.Close() + } + + var saslResp saslResponse + err := bson.Unmarshal(firstResponse, &saslResp) + if err != nil { + fullErr := fmt.Errorf("unmarshal error: %w", err) + return newError(fullErr, sc.mechanism) + } + + cid := saslResp.ConversationID + var payload []byte + var rdr bsoncore.Document + for { + if saslResp.Code != 0 { + return newError(err, sc.mechanism) + } + + if saslResp.Done && sc.client.Completed() { + return nil + } + + payload, err = sc.client.Next(saslResp.Payload) + if err != nil { + return newError(err, sc.mechanism) + } + + if saslResp.Done && sc.client.Completed() { + return nil + } + + doc := bsoncore.BuildDocumentFromElements(nil, + bsoncore.AppendInt32Element(nil, "saslContinue", 1), + bsoncore.AppendInt32Element(nil, "conversationId", int32(cid)), + bsoncore.AppendBinaryElement(nil, "payload", 0x00, payload), + ) + saslContinueCmd := operation.NewCommand(doc). + Database(sc.source). + Deployment(driver.SingleConnectionDeployment{cfg.Connection}). + ClusterClock(cfg.ClusterClock). + ServerAPI(cfg.ServerAPI) + + err = saslContinueCmd.Execute(ctx) + if err != nil { + return newError(err, sc.mechanism) + } + rdr = saslContinueCmd.Result() + + err = bson.Unmarshal(rdr, &saslResp) + if err != nil { + fullErr := fmt.Errorf("unmarshal error: %w", err) + return newError(fullErr, sc.mechanism) + } + } +} + +// Config holds the information necessary to perform an authentication attempt. +type Config struct { + Description description.Server + Connection driver.Connection + ClusterClock *session.ClusterClock + HandshakeInfo driver.HandshakeInformation + ServerAPI *driver.ServerAPIOptions +} + +// ConductSaslConversation runs a full SASL conversation to authenticate the given connection. +func ConductSaslConversation(ctx context.Context, cfg *Config, authSource string, client SaslClient) error { + // Create a non-speculative SASL conversation. + conversation := newSaslConversation(client, authSource, false) + + saslStartDoc, err := conversation.FirstMessage() + if err != nil { + return newError(err, conversation.mechanism) + } + saslStartCmd := operation.NewCommand(saslStartDoc). + Database(authSource). + Deployment(driver.SingleConnectionDeployment{cfg.Connection}). + ClusterClock(cfg.ClusterClock). + ServerAPI(cfg.ServerAPI) + if err := saslStartCmd.Execute(ctx); err != nil { + return newError(err, conversation.mechanism) + } + + return conversation.Finish(ctx, cfg, saslStartCmd.Result()) +} From 279635a7e4faf937598c5b61cba02b2f1fd9522f Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Wed, 12 Jun 2024 11:24:42 -0400 Subject: [PATCH 02/66] GODRIVER-2911: We're going to have to go this way and implement some sort of OIDCSasl --- internal/authutil/sasl.go | 220 ------------------ internal/test/compilecheck/go.mod | 8 +- internal/test/compilecheck/go.sum | 19 +- {internal/authutil => x/mongo/driver}/oidc.go | 17 +- x/mongo/driver/operation.go | 4 + 5 files changed, 31 insertions(+), 237 deletions(-) delete mode 100644 internal/authutil/sasl.go rename {internal/authutil => x/mongo/driver}/oidc.go (82%) diff --git a/internal/authutil/sasl.go b/internal/authutil/sasl.go deleted file mode 100644 index 0f0fe8390c..0000000000 --- a/internal/authutil/sasl.go +++ /dev/null @@ -1,220 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package authutil - -import ( - "context" - "fmt" - - "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/mongo/description" - "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" - "go.mongodb.org/mongo-driver/x/mongo/driver" - "go.mongodb.org/mongo-driver/x/mongo/driver/operation" - "go.mongodb.org/mongo-driver/x/mongo/driver/session" -) - -const defaultAuthDB = "admin" - -// Error is an error that occurred during authentication. -type Error struct { - message string - inner error -} - -func (e *Error) Error() string { - if e.inner == nil { - return e.message - } - return fmt.Sprintf("%s: %s", e.message, e.inner) -} - -// Inner returns the wrapped error. -func (e *Error) Inner() error { - return e.inner -} - -// Unwrap returns the underlying error. -func (e *Error) Unwrap() error { - return e.inner -} - -// Message returns the message. -func (e *Error) Message() string { - return e.message -} - -func newError(err error, mech string) error { - return &Error{ - message: fmt.Sprintf("unable to authenticate using mechanism \"%s\"", mech), - inner: err, - } -} - -// SaslClient is the client piece of a sasl conversation. -type SaslClient interface { - Start() (string, []byte, error) - Next(challenge []byte) ([]byte, error) - Completed() bool -} - -// SaslClientCloser is a SaslClient that has resources to clean up. -type SaslClientCloser interface { - SaslClient - Close() -} - -// ExtraOptionsSaslClient is a SaslClient that appends options to the saslStart command. -type ExtraOptionsSaslClient interface { - StartCommandOptions() bsoncore.Document -} - -// saslConversation represents a SASL conversation. This type implements the SpeculativeConversation interface so the -// conversation can be executed in multi-step speculative fashion. -type saslConversation struct { - client SaslClient - source string - mechanism string - speculative bool -} - -func newSaslConversation(client SaslClient, source string, speculative bool) *saslConversation { - authSource := source - if authSource == "" { - authSource = defaultAuthDB - } - return &saslConversation{ - client: client, - source: authSource, - speculative: speculative, - } -} - -// FirstMessage returns the first message to be sent to the server. This message contains a "db" field so it can be used -// for speculative authentication. -func (sc *saslConversation) FirstMessage() (bsoncore.Document, error) { - var payload []byte - var err error - sc.mechanism, payload, err = sc.client.Start() - if err != nil { - return nil, err - } - - saslCmdElements := [][]byte{ - bsoncore.AppendInt32Element(nil, "saslStart", 1), - bsoncore.AppendStringElement(nil, "mechanism", sc.mechanism), - bsoncore.AppendBinaryElement(nil, "payload", 0x00, payload), - } - if sc.speculative { - // The "db" field is only appended for speculative auth because the hello command is executed against admin - // so this is needed to tell the server the user's auth source. For a non-speculative attempt, the SASL commands - // will be executed against the auth source. - saslCmdElements = append(saslCmdElements, bsoncore.AppendStringElement(nil, "db", sc.source)) - } - if extraOptionsClient, ok := sc.client.(ExtraOptionsSaslClient); ok { - optionsDoc := extraOptionsClient.StartCommandOptions() - saslCmdElements = append(saslCmdElements, bsoncore.AppendDocumentElement(nil, "options", optionsDoc)) - } - - return bsoncore.BuildDocumentFromElements(nil, saslCmdElements...), nil -} - -type saslResponse struct { - ConversationID int `bson:"conversationId"` - Code int `bson:"code"` - Done bool `bson:"done"` - Payload []byte `bson:"payload"` -} - -// Finish completes the conversation based on the first server response to authenticate the given connection. -func (sc *saslConversation) Finish(ctx context.Context, cfg *Config, firstResponse bsoncore.Document) error { - if closer, ok := sc.client.(SaslClientCloser); ok { - defer closer.Close() - } - - var saslResp saslResponse - err := bson.Unmarshal(firstResponse, &saslResp) - if err != nil { - fullErr := fmt.Errorf("unmarshal error: %w", err) - return newError(fullErr, sc.mechanism) - } - - cid := saslResp.ConversationID - var payload []byte - var rdr bsoncore.Document - for { - if saslResp.Code != 0 { - return newError(err, sc.mechanism) - } - - if saslResp.Done && sc.client.Completed() { - return nil - } - - payload, err = sc.client.Next(saslResp.Payload) - if err != nil { - return newError(err, sc.mechanism) - } - - if saslResp.Done && sc.client.Completed() { - return nil - } - - doc := bsoncore.BuildDocumentFromElements(nil, - bsoncore.AppendInt32Element(nil, "saslContinue", 1), - bsoncore.AppendInt32Element(nil, "conversationId", int32(cid)), - bsoncore.AppendBinaryElement(nil, "payload", 0x00, payload), - ) - saslContinueCmd := operation.NewCommand(doc). - Database(sc.source). - Deployment(driver.SingleConnectionDeployment{cfg.Connection}). - ClusterClock(cfg.ClusterClock). - ServerAPI(cfg.ServerAPI) - - err = saslContinueCmd.Execute(ctx) - if err != nil { - return newError(err, sc.mechanism) - } - rdr = saslContinueCmd.Result() - - err = bson.Unmarshal(rdr, &saslResp) - if err != nil { - fullErr := fmt.Errorf("unmarshal error: %w", err) - return newError(fullErr, sc.mechanism) - } - } -} - -// Config holds the information necessary to perform an authentication attempt. -type Config struct { - Description description.Server - Connection driver.Connection - ClusterClock *session.ClusterClock - HandshakeInfo driver.HandshakeInformation - ServerAPI *driver.ServerAPIOptions -} - -// ConductSaslConversation runs a full SASL conversation to authenticate the given connection. -func ConductSaslConversation(ctx context.Context, cfg *Config, authSource string, client SaslClient) error { - // Create a non-speculative SASL conversation. - conversation := newSaslConversation(client, authSource, false) - - saslStartDoc, err := conversation.FirstMessage() - if err != nil { - return newError(err, conversation.mechanism) - } - saslStartCmd := operation.NewCommand(saslStartDoc). - Database(authSource). - Deployment(driver.SingleConnectionDeployment{cfg.Connection}). - ClusterClock(cfg.ClusterClock). - ServerAPI(cfg.ServerAPI) - if err := saslStartCmd.Execute(ctx); err != nil { - return newError(err, conversation.mechanism) - } - - return conversation.Finish(ctx, cfg, saslStartCmd.Result()) -} diff --git a/internal/test/compilecheck/go.mod b/internal/test/compilecheck/go.mod index 69d192022a..cc09124838 100644 --- a/internal/test/compilecheck/go.mod +++ b/internal/test/compilecheck/go.mod @@ -9,14 +9,14 @@ replace go.mongodb.org/mongo-driver => ../../../ require go.mongodb.org/mongo-driver v1.11.7 require ( - github.com/golang/snappy v0.0.1 // indirect + github.com/golang/snappy v0.0.4 // indirect github.com/klauspost/compress v1.13.6 // indirect - github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe // indirect + github.com/montanaflynn/stats v0.7.1 // indirect github.com/xdg-go/pbkdf2 v1.0.0 // indirect github.com/xdg-go/scram v1.1.2 // indirect github.com/xdg-go/stringprep v1.0.4 // indirect github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d // indirect - golang.org/x/crypto v0.17.0 // indirect - golang.org/x/sync v0.1.0 // indirect + golang.org/x/crypto v0.22.0 // indirect + golang.org/x/sync v0.7.0 // indirect golang.org/x/text v0.14.0 // indirect ) diff --git a/internal/test/compilecheck/go.sum b/internal/test/compilecheck/go.sum index fe79e66209..802402a881 100644 --- a/internal/test/compilecheck/go.sum +++ b/internal/test/compilecheck/go.sum @@ -1,11 +1,11 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4= -github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= -github.com/google/go-cmp v0.5.2 h1:X2ev0eStA3AbceY54o37/0PQ/UWqKEiiO2dKL5OPaFM= +github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= +github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/klauspost/compress v1.13.6 h1:P76CopJELS0TiO2mebmnzgWaajssP/EszplttgQxcgc= github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= -github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe h1:iruDEfMl2E6fbMZ9s0scYfZQ84/6SPL6zC8ACM2oIL0= -github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc= +github.com/montanaflynn/stats v0.7.1 h1:etflOAAHORrCC44V+aR6Ftzort912ZU+YLiSTuV8eaE= +github.com/montanaflynn/stats v0.7.1/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow= github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY= @@ -17,16 +17,16 @@ github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d/go.mod h1:rHwXgn7Jul github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= -golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= +golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30= +golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= -golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -44,4 +44,3 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= diff --git a/internal/authutil/oidc.go b/x/mongo/driver/oidc.go similarity index 82% rename from internal/authutil/oidc.go rename to x/mongo/driver/oidc.go index 71c0ac20e9..3dad05eafd 100644 --- a/internal/authutil/oidc.go +++ b/x/mongo/driver/oidc.go @@ -1,7 +1,8 @@ -package authutil +package driver import ( "context" + "fmt" "sync" "time" @@ -29,15 +30,21 @@ type OIDCAuthenticator struct { idpInfo *IDPInfo } +func NewOIDCAuthenticator() *OIDCAuthenticator { + return &OIDCAuthenticator{ + AuthMechanismProperties: make(map[string]string), + } +} + type IDPInfo struct { Issuer string `bson:"issuer"` ClientID string `bson:"clientId"` RequestScopes []string `bson:"requestScopes"` } -type OIDCCallback func(context.Context, *OIDCAargs) (*OIDCCredential, error) +type OIDCCallback func(context.Context, *OIDCArgs) (*OIDCCredential, error) -type OIDCAargs struct { +type OIDCArgs struct { Version int IDPInfo *IDPInfo RefreshToken *string @@ -57,6 +64,10 @@ func (oos *oidcOneStep) Start() (string, []byte, error) { return oidcMech, jwtStepRequest(oos.accessToken), nil } +func newAuthError(msg string, err error) error { + return fmt.Errorf("authentication error: %s: %w", msg, err) +} + func (oos *oidcOneStep) Next(context.Context, []byte) ([]byte, error) { return nil, newAuthError("unexpected step in OIDC machine authentication", nil) } diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index 568622d616..75c02fccca 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -912,6 +912,10 @@ func (op Operation) Execute(ctx context.Context) error { operationErr.Labels = tt.Labels operationErr.Raw = tt.Raw case Error: + if tt.Code == 391 { + x := NewOIDCAuthenticator() + fmt.Println(x) + } if tt.HasErrorLabel(TransientTransactionError) || tt.HasErrorLabel(UnknownTransactionCommitResult) { if err := op.Client.ClearPinnedResources(); err != nil { return err From 9170d505f10306681dc3b7a37eadf96158b345c9 Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Wed, 12 Jun 2024 11:49:04 -0400 Subject: [PATCH 03/66] GODRIVER-2911: Ok, not great, but this will work --- x/mongo/driver/auth/auth.go | 12 +-- x/mongo/driver/oidc.go | 185 ++++++++++++++++++++++++++++++++++++ 2 files changed, 187 insertions(+), 10 deletions(-) diff --git a/x/mongo/driver/auth/auth.go b/x/mongo/driver/auth/auth.go index 6eeaf0ee01..a2c7fa7a99 100644 --- a/x/mongo/driver/auth/auth.go +++ b/x/mongo/driver/auth/auth.go @@ -114,6 +114,8 @@ func (ah *authHandshaker) GetHandshakeInformation(ctx context.Context, addr addr return ah.handshakeInfo, nil } +type Config = driver.AuthConfig + // FinishHandshake performs authentication for conn if necessary. func (ah *authHandshaker) FinishHandshake(ctx context.Context, conn driver.Connection) error { performAuth := ah.options.PerformAuthentication @@ -170,16 +172,6 @@ func Handshaker(h driver.Handshaker, options *HandshakeOptions) driver.Handshake } } -// Config holds the information necessary to perform an authentication attempt. -type Config struct { - Description description.Server - Connection driver.Connection - ClusterClock *session.ClusterClock - HandshakeInfo driver.HandshakeInformation - ServerAPI *driver.ServerAPIOptions - HTTPClient *http.Client -} - // Authenticator handles authenticating a connection. type Authenticator interface { // Auth authenticates the connection. diff --git a/x/mongo/driver/oidc.go b/x/mongo/driver/oidc.go index 3dad05eafd..53110b1e99 100644 --- a/x/mongo/driver/oidc.go +++ b/x/mongo/driver/oidc.go @@ -3,10 +3,14 @@ package driver import ( "context" "fmt" + "net/http" "sync" "time" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) const oidcMech = "MONGODB-OIDC" @@ -16,6 +20,7 @@ const principalProp = "PRINCIPAL" const allowedHostsProp = "ALLOWED_HOSTS" const azureEnvironmentValue = "azure" const gcpEnvironmentValue = "gcp" +const defaultAuthDB = "admin" // OIDCAuthenticator is synchronized and handles caching of the access token, refreshToken, // and IDPInfo. It also provides a mechanism to refresh the access token, but this functionality @@ -89,3 +94,183 @@ func principalStepRequest(principal string) []byte { } return doc.Build() } + +//// OIDC Sasl + +func newError(err error, mechanism string) error { + return fmt.Errorf("error during %s SASL conversation: %w", oidcMech, err) +} + +// SaslClient is the client piece of a sasl conversation. +type SaslClient interface { + Start() (string, []byte, error) + Next(challenge []byte) ([]byte, error) + Completed() bool +} + +// SaslClientCloser is a SaslClient that has resources to clean up. +type SaslClientCloser interface { + SaslClient + Close() +} + +// ExtraOptionsSaslClient is a SaslClient that appends options to the saslStart command. +type ExtraOptionsSaslClient interface { + StartCommandOptions() bsoncore.Document +} + +// saslConversation represents a SASL conversation. This type implements the SpeculativeConversation interface so the +// conversation can be executed in multi-step speculative fashion. +type saslConversation struct { + client SaslClient + source string + mechanism string + speculative bool +} + +func newSaslConversation(client SaslClient, source string, speculative bool) *saslConversation { + authSource := source + if authSource == "" { + authSource = defaultAuthDB + } + return &saslConversation{ + client: client, + source: authSource, + speculative: speculative, + } +} + +// FirstMessage returns the first message to be sent to the server. This message contains a "db" field so it can be used +// for speculative authentication. +func (sc *saslConversation) FirstMessage() (bsoncore.Document, error) { + var payload []byte + var err error + sc.mechanism, payload, err = sc.client.Start() + if err != nil { + return nil, err + } + + saslCmdElements := [][]byte{ + bsoncore.AppendInt32Element(nil, "saslStart", 1), + bsoncore.AppendStringElement(nil, "mechanism", sc.mechanism), + bsoncore.AppendBinaryElement(nil, "payload", 0x00, payload), + } + if sc.speculative { + // The "db" field is only appended for speculative auth because the hello command is executed against admin + // so this is needed to tell the server the user's auth source. For a non-speculative attempt, the SASL commands + // will be executed against the auth source. + saslCmdElements = append(saslCmdElements, bsoncore.AppendStringElement(nil, "db", sc.source)) + } + if extraOptionsClient, ok := sc.client.(ExtraOptionsSaslClient); ok { + optionsDoc := extraOptionsClient.StartCommandOptions() + saslCmdElements = append(saslCmdElements, bsoncore.AppendDocumentElement(nil, "options", optionsDoc)) + } + + return bsoncore.BuildDocumentFromElements(nil, saslCmdElements...), nil +} + +type saslResponse struct { + ConversationID int `bson:"conversationId"` + Code int `bson:"code"` + Done bool `bson:"done"` + Payload []byte `bson:"payload"` +} + +// AuthConfig holds the information necessary to perform an authentication attempt. +type AuthConfig struct { + Description description.Server + Connection Connection + ClusterClock *session.ClusterClock + HandshakeInfo HandshakeInformation + ServerAPI *ServerAPIOptions + HTTPClient *http.Client +} + +// finish completes the conversation based on the first server response to authenticate the given connection. +func (sc *saslConversation) finish(ctx context.Context, cfg *AuthConfig, firstResponse bsoncore.Document) error { + if closer, ok := sc.client.(SaslClientCloser); ok { + defer closer.Close() + } + + var saslResp saslResponse + err := bson.Unmarshal(firstResponse, &saslResp) + if err != nil { + fullErr := fmt.Errorf("unmarshal error: %w", err) + return newError(fullErr, sc.mechanism) + } + + cid := saslResp.ConversationID + var payload []byte + var rdr bsoncore.Document + for { + if saslResp.Code != 0 { + return newError(err, sc.mechanism) + } + + if saslResp.Done && sc.client.Completed() { + return nil + } + + payload, err = sc.client.Next(saslResp.Payload) + if err != nil { + return newError(err, sc.mechanism) + } + + if saslResp.Done && sc.client.Completed() { + return nil + } + + doc := bsoncore.BuildDocumentFromElements(nil, + bsoncore.AppendInt32Element(nil, "saslContinue", 1), + bsoncore.AppendInt32Element(nil, "conversationId", int32(cid)), + bsoncore.AppendBinaryElement(nil, "payload", 0x00, payload), + ) + + fmt.Println(rdr) + fmt.Println(doc) + + return nil + //saslContinueCmd := NewCommand(doc). + // Database(sc.source). + // Deployment(SingleConnectionDeployment{cfg.Connection}). + // ClusterClock(cfg.ClusterClock). + // ServerAPI(cfg.ServerAPI) + + // err = saslContinueCmd.Execute(ctx) + // if err != nil { + // return newError(err, sc.mechanism) + // } + // rdr = saslContinueCmd.Result() + // + // err = bson.Unmarshal(rdr, &saslResp) + // if err != nil { + // fullErr := fmt.Errorf("unmarshal error: %w", err) + // return newError(fullErr, sc.mechanism) + // } + } +} + +// conductOIDCSaslConversation runs a full SASL conversation to authenticate the given connection. +func conductOIDCSaslConversation(ctx context.Context, cfg *AuthConfig, authSource string, client SaslClient) error { + // Create a non-speculative SASL conversation. + conversation := newSaslConversation(client, authSource, false) + + saslStartDoc, err := conversation.FirstMessage() + if err != nil { + return newError(err, conversation.mechanism) + } + fmt.Println(saslStartDoc) + return nil + // saslStartCmd := NewCommand(saslStartDoc). + // + // Database(authSource). + // Deployment(SingleConnectionDeployment{cfg.Connection}). + // ClusterClock(cfg.ClusterClock). + // ServerAPI(cfg.ServerAPI) + // + // if err := saslStartCmd.Execute(ctx); err != nil { + // return newError(err, conversation.mechanism) + // } + // + // return conversation.finish(ctx, cfg, saslStartCmd.Result()) +} From 590662d28e5a5edf511ef286450e42d963f10682 Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Wed, 12 Jun 2024 15:17:31 -0400 Subject: [PATCH 04/66] GODRIVER-2911: Renaming oidc sasl --- x/mongo/driver/oidc.go | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/x/mongo/driver/oidc.go b/x/mongo/driver/oidc.go index 53110b1e99..30de5d2c00 100644 --- a/x/mongo/driver/oidc.go +++ b/x/mongo/driver/oidc.go @@ -95,40 +95,42 @@ func principalStepRequest(principal string) []byte { return doc.Build() } -//// OIDC Sasl +// OIDC Sasl. This is almost a verbatim copy of auth/sasl introduced to remove the dependency on auth package +// which causes a circular dependency when attempting to do Reauthentication in driver/operation.go. +// This could be removed with a larger refactor. func newError(err error, mechanism string) error { return fmt.Errorf("error during %s SASL conversation: %w", oidcMech, err) } -// SaslClient is the client piece of a sasl conversation. -type SaslClient interface { +// oidcSaslClient is the client piece of a sasl conversation. +type oidcSaslClient interface { Start() (string, []byte, error) Next(challenge []byte) ([]byte, error) Completed() bool } -// SaslClientCloser is a SaslClient that has resources to clean up. -type SaslClientCloser interface { - SaslClient +// oidcSaslClientCloser is a oidcSaslClient that has resources to clean up. +type oidcSaslClientCloser interface { + oidcSaslClient Close() } -// ExtraOptionsSaslClient is a SaslClient that appends options to the saslStart command. -type ExtraOptionsSaslClient interface { +// extraOptionsOIDCSaslClient is a SaslClient that appends options to the saslStart command. +type extraOptionsOIDCSaslClient interface { StartCommandOptions() bsoncore.Document } // saslConversation represents a SASL conversation. This type implements the SpeculativeConversation interface so the // conversation can be executed in multi-step speculative fashion. type saslConversation struct { - client SaslClient + client oidcSaslClient source string mechanism string speculative bool } -func newSaslConversation(client SaslClient, source string, speculative bool) *saslConversation { +func newSaslConversation(client oidcSaslClient, source string, speculative bool) *saslConversation { authSource := source if authSource == "" { authSource = defaultAuthDB @@ -161,7 +163,7 @@ func (sc *saslConversation) FirstMessage() (bsoncore.Document, error) { // will be executed against the auth source. saslCmdElements = append(saslCmdElements, bsoncore.AppendStringElement(nil, "db", sc.source)) } - if extraOptionsClient, ok := sc.client.(ExtraOptionsSaslClient); ok { + if extraOptionsClient, ok := sc.client.(extraOptionsOIDCSaslClient); ok { optionsDoc := extraOptionsClient.StartCommandOptions() saslCmdElements = append(saslCmdElements, bsoncore.AppendDocumentElement(nil, "options", optionsDoc)) } @@ -177,6 +179,8 @@ type saslResponse struct { } // AuthConfig holds the information necessary to perform an authentication attempt. +// this was moved from the auth package to avoid a circular dependency. The auth package +// reexports this under the old name to avoid breaking the public api. type AuthConfig struct { Description description.Server Connection Connection @@ -188,7 +192,7 @@ type AuthConfig struct { // finish completes the conversation based on the first server response to authenticate the given connection. func (sc *saslConversation) finish(ctx context.Context, cfg *AuthConfig, firstResponse bsoncore.Document) error { - if closer, ok := sc.client.(SaslClientCloser); ok { + if closer, ok := sc.client.(oidcSaslClientCloser); ok { defer closer.Close() } @@ -251,7 +255,7 @@ func (sc *saslConversation) finish(ctx context.Context, cfg *AuthConfig, firstRe } // conductOIDCSaslConversation runs a full SASL conversation to authenticate the given connection. -func conductOIDCSaslConversation(ctx context.Context, cfg *AuthConfig, authSource string, client SaslClient) error { +func conductOIDCSaslConversation(ctx context.Context, cfg *AuthConfig, authSource string, client oidcSaslClient) error { // Create a non-speculative SASL conversation. conversation := newSaslConversation(client, authSource, false) From 171204c7151ddf027b1d587aaa345920d56b118a Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Wed, 12 Jun 2024 16:01:35 -0400 Subject: [PATCH 05/66] GODRIVER-2911: Implement Operation based private sasl conversation for OIDC that is probably, maybe, possibly correct --- x/mongo/driver/oidc.go | 78 +++++++++++++++++++++++------------------- 1 file changed, 43 insertions(+), 35 deletions(-) diff --git a/x/mongo/driver/oidc.go b/x/mongo/driver/oidc.go index 30de5d2c00..88f9b3d389 100644 --- a/x/mongo/driver/oidc.go +++ b/x/mongo/driver/oidc.go @@ -205,7 +205,7 @@ func (sc *saslConversation) finish(ctx context.Context, cfg *AuthConfig, firstRe cid := saslResp.ConversationID var payload []byte - var rdr bsoncore.Document + var result bsoncore.Document for { if saslResp.Code != 0 { return newError(err, sc.mechanism) @@ -230,27 +230,30 @@ func (sc *saslConversation) finish(ctx context.Context, cfg *AuthConfig, firstRe bsoncore.AppendBinaryElement(nil, "payload", 0x00, payload), ) - fmt.Println(rdr) - fmt.Println(doc) + saslOp := Operation{ + CommandFn: func(dst []byte, desc description.SelectedServer) ([]byte, error) { + return append(dst, doc[4:len(doc)-1]...), nil + }, + ProcessResponseFn: func(info ResponseInfo) error { + result = info.ServerResponse + return nil + }, + Deployment: Deployment(SingleConnectionDeployment{cfg.Connection}), + Database: sc.source, + Clock: cfg.ClusterClock, + ServerAPI: cfg.ServerAPI, + } + err = saslOp.Execute(ctx) + if err != nil { + return newError(err, sc.mechanism) + } + err = bson.Unmarshal(result, &saslResp) + if err != nil { + fullErr := fmt.Errorf("unmarshal error: %w", err) + return newError(fullErr, sc.mechanism) + } return nil - //saslContinueCmd := NewCommand(doc). - // Database(sc.source). - // Deployment(SingleConnectionDeployment{cfg.Connection}). - // ClusterClock(cfg.ClusterClock). - // ServerAPI(cfg.ServerAPI) - - // err = saslContinueCmd.Execute(ctx) - // if err != nil { - // return newError(err, sc.mechanism) - // } - // rdr = saslContinueCmd.Result() - // - // err = bson.Unmarshal(rdr, &saslResp) - // if err != nil { - // fullErr := fmt.Errorf("unmarshal error: %w", err) - // return newError(fullErr, sc.mechanism) - // } } } @@ -259,22 +262,27 @@ func conductOIDCSaslConversation(ctx context.Context, cfg *AuthConfig, authSourc // Create a non-speculative SASL conversation. conversation := newSaslConversation(client, authSource, false) - saslStartDoc, err := conversation.FirstMessage() + doc, err := conversation.FirstMessage() if err != nil { return newError(err, conversation.mechanism) } - fmt.Println(saslStartDoc) - return nil - // saslStartCmd := NewCommand(saslStartDoc). - // - // Database(authSource). - // Deployment(SingleConnectionDeployment{cfg.Connection}). - // ClusterClock(cfg.ClusterClock). - // ServerAPI(cfg.ServerAPI) - // - // if err := saslStartCmd.Execute(ctx); err != nil { - // return newError(err, conversation.mechanism) - // } - // - // return conversation.finish(ctx, cfg, saslStartCmd.Result()) + var result bsoncore.Document + saslOp := Operation{ + CommandFn: func(dst []byte, desc description.SelectedServer) ([]byte, error) { + return append(dst, doc[4:len(doc)-1]...), nil + }, + ProcessResponseFn: func(info ResponseInfo) error { + result = info.ServerResponse + return nil + }, + Deployment: Deployment(SingleConnectionDeployment{cfg.Connection}), + Database: authSource, + Clock: cfg.ClusterClock, + ServerAPI: cfg.ServerAPI, + } + if err := saslOp.Execute(ctx); err != nil { + return newError(err, conversation.mechanism) + } + + return conversation.finish(ctx, cfg, result) } From dbc56993f4655e11fb1a406d535bf6eb8d03bdb4 Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Wed, 12 Jun 2024 16:04:23 -0400 Subject: [PATCH 06/66] GODRIVER-2911: Privitize all the oidc sasl api, move AuthConfig up so it's more clear --- x/mongo/driver/oidc.go | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/x/mongo/driver/oidc.go b/x/mongo/driver/oidc.go index 88f9b3d389..55b54323fd 100644 --- a/x/mongo/driver/oidc.go +++ b/x/mongo/driver/oidc.go @@ -99,6 +99,18 @@ func principalStepRequest(principal string) []byte { // which causes a circular dependency when attempting to do Reauthentication in driver/operation.go. // This could be removed with a larger refactor. +// AuthConfig holds the information necessary to perform an authentication attempt. +// this was moved from the auth package to avoid a circular dependency. The auth package +// reexports this under the old name to avoid breaking the public api. +type AuthConfig struct { + Description description.Server + Connection Connection + ClusterClock *session.ClusterClock + HandshakeInfo HandshakeInformation + ServerAPI *ServerAPIOptions + HTTPClient *http.Client +} + func newError(err error, mechanism string) error { return fmt.Errorf("error during %s SASL conversation: %w", oidcMech, err) } @@ -142,9 +154,9 @@ func newSaslConversation(client oidcSaslClient, source string, speculative bool) } } -// FirstMessage returns the first message to be sent to the server. This message contains a "db" field so it can be used +// firstMessage returns the first message to be sent to the server. This message contains a "db" field so it can be used // for speculative authentication. -func (sc *saslConversation) FirstMessage() (bsoncore.Document, error) { +func (sc *saslConversation) firstMessage() (bsoncore.Document, error) { var payload []byte var err error sc.mechanism, payload, err = sc.client.Start() @@ -178,18 +190,6 @@ type saslResponse struct { Payload []byte `bson:"payload"` } -// AuthConfig holds the information necessary to perform an authentication attempt. -// this was moved from the auth package to avoid a circular dependency. The auth package -// reexports this under the old name to avoid breaking the public api. -type AuthConfig struct { - Description description.Server - Connection Connection - ClusterClock *session.ClusterClock - HandshakeInfo HandshakeInformation - ServerAPI *ServerAPIOptions - HTTPClient *http.Client -} - // finish completes the conversation based on the first server response to authenticate the given connection. func (sc *saslConversation) finish(ctx context.Context, cfg *AuthConfig, firstResponse bsoncore.Document) error { if closer, ok := sc.client.(oidcSaslClientCloser); ok { @@ -262,7 +262,7 @@ func conductOIDCSaslConversation(ctx context.Context, cfg *AuthConfig, authSourc // Create a non-speculative SASL conversation. conversation := newSaslConversation(client, authSource, false) - doc, err := conversation.FirstMessage() + doc, err := conversation.firstMessage() if err != nil { return newError(err, conversation.mechanism) } From ff73302e44d8300877bee69f6d21813713a308e9 Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Wed, 12 Jun 2024 18:04:37 -0400 Subject: [PATCH 07/66] GODRIVER-2911: Move things as necessary for authentication registration --- x/mongo/driver/auth/auth.go | 6 ++---- x/mongo/driver/auth/cred.go | 14 +++++--------- x/mongo/driver/oidc.go | 34 ++++++++++++++++++++++++++++------ x/mongo/driver/operation.go | 2 +- 4 files changed, 36 insertions(+), 20 deletions(-) diff --git a/x/mongo/driver/auth/auth.go b/x/mongo/driver/auth/auth.go index a2c7fa7a99..1265b8a146 100644 --- a/x/mongo/driver/auth/auth.go +++ b/x/mongo/driver/auth/auth.go @@ -33,6 +33,7 @@ func init() { RegisterAuthenticatorFactory(GSSAPI, newGSSAPIAuthenticator) RegisterAuthenticatorFactory(MongoDBX509, newMongoDBX509Authenticator) RegisterAuthenticatorFactory(MongoDBAWS, newMongoDBAWSAuthenticator) + RegisterAuthenticatorFactory(driver.OIDC, driver.NewOIDCAuthenticator) } // CreateAuthenticator creates an authenticator. @@ -173,10 +174,7 @@ func Handshaker(h driver.Handshaker, options *HandshakeOptions) driver.Handshake } // Authenticator handles authenticating a connection. -type Authenticator interface { - // Auth authenticates the connection. - Auth(context.Context, *Config) error -} +type Authenticator = driver.Authenticator func newAuthError(msg string, inner error) error { return &Error{ diff --git a/x/mongo/driver/auth/cred.go b/x/mongo/driver/auth/cred.go index 7b2b8f17d0..444a58a33f 100644 --- a/x/mongo/driver/auth/cred.go +++ b/x/mongo/driver/auth/cred.go @@ -3,14 +3,10 @@ // Licensed under the Apache License, Version 2.0 (the "License"); you may // not use this file except in compliance with the License. You may obtain // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - package auth -// Cred is a user's credential. -type Cred struct { - Source string - Username string - Password string - PasswordSet bool - Props map[string]string -} +import ( + "go.mongodb.org/mongo-driver/x/mongo/driver" +) + +type Cred = driver.AuthCred diff --git a/x/mongo/driver/oidc.go b/x/mongo/driver/oidc.go index 55b54323fd..66287a526c 100644 --- a/x/mongo/driver/oidc.go +++ b/x/mongo/driver/oidc.go @@ -13,7 +13,7 @@ import ( "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) -const oidcMech = "MONGODB-OIDC" +const OIDC = "MONGODB-OIDC" const tokenResourceProp = "TOKEN_RESOURCE" const environmentProp = "ENVIRONMENT" const principalProp = "PRINCIPAL" @@ -22,6 +22,21 @@ const azureEnvironmentValue = "azure" const gcpEnvironmentValue = "gcp" const defaultAuthDB = "admin" +// Authenticator handles authenticating a connection. +type Authenticator interface { + // Auth authenticates the connection. + Auth(context.Context, *AuthConfig) error +} + +// AuthCred is a user's credential. +type AuthCred struct { + Source string + Username string + Password string + PasswordSet bool + Props map[string]string +} + // OIDCAuthenticator is synchronized and handles caching of the access token, refreshToken, // and IDPInfo. It also provides a mechanism to refresh the access token, but this functionality // is only for the OIDC Human flow. @@ -35,10 +50,15 @@ type OIDCAuthenticator struct { idpInfo *IDPInfo } -func NewOIDCAuthenticator() *OIDCAuthenticator { - return &OIDCAuthenticator{ - AuthMechanismProperties: make(map[string]string), +func NewOIDCAuthenticator(cred *AuthCred) (Authenticator, error) { + oa := &OIDCAuthenticator{ + AuthMechanismProperties: cred.Props, } + return oa, nil +} + +func (oa *OIDCAuthenticator) Auth(ctx context.Context, cfg *AuthConfig) error { + return nil } type IDPInfo struct { @@ -47,6 +67,8 @@ type IDPInfo struct { RequestScopes []string `bson:"requestScopes"` } +// OIDCCallback is the type for both Human and Machine Callback flows. RefreshToken will always be +// nil in the OIDCArgs for the Machine flow. type OIDCCallback func(context.Context, *OIDCArgs) (*OIDCCredential, error) type OIDCArgs struct { @@ -66,7 +88,7 @@ type oidcOneStep struct { } func (oos *oidcOneStep) Start() (string, []byte, error) { - return oidcMech, jwtStepRequest(oos.accessToken), nil + return OIDC, jwtStepRequest(oos.accessToken), nil } func newAuthError(msg string, err error) error { @@ -112,7 +134,7 @@ type AuthConfig struct { } func newError(err error, mechanism string) error { - return fmt.Errorf("error during %s SASL conversation: %w", oidcMech, err) + return fmt.Errorf("error during %s SASL conversation: %w", OIDC, err) } // oidcSaslClient is the client piece of a sasl conversation. diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index 75c02fccca..b0d32c671c 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -913,7 +913,7 @@ func (op Operation) Execute(ctx context.Context) error { operationErr.Raw = tt.Raw case Error: if tt.Code == 391 { - x := NewOIDCAuthenticator() + x, _ := NewOIDCAuthenticator(nil) fmt.Println(x) } if tt.HasErrorLabel(TransientTransactionError) || tt.HasErrorLabel(UnknownTransactionCommitResult) { From 0db7c3ec06d6a887029bc87d4b8bc3cb35fa3ef8 Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Wed, 12 Jun 2024 18:06:52 -0400 Subject: [PATCH 08/66] GODRIVER-2911: Let's use a bit better naming --- x/mongo/driver/auth/cred.go | 2 +- x/mongo/driver/oidc.go | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/x/mongo/driver/auth/cred.go b/x/mongo/driver/auth/cred.go index 444a58a33f..5f89f29018 100644 --- a/x/mongo/driver/auth/cred.go +++ b/x/mongo/driver/auth/cred.go @@ -9,4 +9,4 @@ import ( "go.mongodb.org/mongo-driver/x/mongo/driver" ) -type Cred = driver.AuthCred +type Cred = driver.Cred diff --git a/x/mongo/driver/oidc.go b/x/mongo/driver/oidc.go index 66287a526c..399620f4c3 100644 --- a/x/mongo/driver/oidc.go +++ b/x/mongo/driver/oidc.go @@ -28,8 +28,8 @@ type Authenticator interface { Auth(context.Context, *AuthConfig) error } -// AuthCred is a user's credential. -type AuthCred struct { +// Cred is a user's credential. +type Cred struct { Source string Username string Password string @@ -50,7 +50,7 @@ type OIDCAuthenticator struct { idpInfo *IDPInfo } -func NewOIDCAuthenticator(cred *AuthCred) (Authenticator, error) { +func NewOIDCAuthenticator(cred *Cred) (Authenticator, error) { oa := &OIDCAuthenticator{ AuthMechanismProperties: cred.Props, } From be99139fe5b800ac5682a45b3bb8991ab8b98405 Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Wed, 12 Jun 2024 19:20:59 -0400 Subject: [PATCH 09/66] GODRIVER-2911: Add Reauth to Authenticators --- x/mongo/driver/auth/default.go | 5 + x/mongo/driver/auth/mongodbaws.go | 5 + x/mongo/driver/auth/mongodbcr.go | 5 + x/mongo/driver/auth/plain.go | 5 + x/mongo/driver/auth/scram.go | 5 + x/mongo/driver/auth/x509.go | 5 + x/mongo/driver/oidc.go | 189 +++++++++++++++++++++++++----- x/mongo/driver/operation.go | 7 +- 8 files changed, 197 insertions(+), 29 deletions(-) diff --git a/x/mongo/driver/auth/default.go b/x/mongo/driver/auth/default.go index 6f2ca5224a..5e8c5b98ec 100644 --- a/x/mongo/driver/auth/default.go +++ b/x/mongo/driver/auth/default.go @@ -66,6 +66,11 @@ func (a *DefaultAuthenticator) Auth(ctx context.Context, cfg *Config) error { return actual.Auth(ctx, cfg) } +// Reauth reauthenticates the connection. +func (a *DefaultAuthenticator) Reauth(ctx context.Context) error { + return newAuthError("DefaultAuthenticator does not support reauthentication", nil) +} + // If a server provides a list of supported mechanisms, we choose // SCRAM-SHA-256 if it exists or else MUST use SCRAM-SHA-1. // Otherwise, we decide based on what is supported. diff --git a/x/mongo/driver/auth/mongodbaws.go b/x/mongo/driver/auth/mongodbaws.go index 7ae4b08998..985fc35c03 100644 --- a/x/mongo/driver/auth/mongodbaws.go +++ b/x/mongo/driver/auth/mongodbaws.go @@ -60,6 +60,11 @@ func (a *MongoDBAWSAuthenticator) Auth(ctx context.Context, cfg *Config) error { return nil } +// Reauth reauthenticates the connection. +func (a *MongoDBAWSAuthenticator) Reauth(ctx context.Context) error { + return newAuthError("AWS authentication does not support reauthentication", nil) +} + type awsSaslAdapter struct { conversation *awsConversation } diff --git a/x/mongo/driver/auth/mongodbcr.go b/x/mongo/driver/auth/mongodbcr.go index 6e2c2f4dcb..7301469921 100644 --- a/x/mongo/driver/auth/mongodbcr.go +++ b/x/mongo/driver/auth/mongodbcr.go @@ -97,6 +97,11 @@ func (a *MongoDBCRAuthenticator) Auth(ctx context.Context, cfg *Config) error { return nil } +// Reauth reauthenticates the connection. +func (a *MongoDBCRAuthenticator) Reauth(ctx context.Context) error { + return newAuthError("MONGODB-CR does not support reauthentication", nil) +} + func (a *MongoDBCRAuthenticator) createKey(nonce string) string { // Ignore gosec warning "Use of weak cryptographic primitive". We need to use MD5 here to // implement the MONGODB-CR specification. diff --git a/x/mongo/driver/auth/plain.go b/x/mongo/driver/auth/plain.go index 532d43e39f..21ffe0465e 100644 --- a/x/mongo/driver/auth/plain.go +++ b/x/mongo/driver/auth/plain.go @@ -34,6 +34,11 @@ func (a *PlainAuthenticator) Auth(ctx context.Context, cfg *Config) error { }) } +// Reauth reauthenticates the connection. +func (a *PlainAuthenticator) Reauth(ctx context.Context) error { + return newAuthError("Plain authentication does not support reauthentication", nil) +} + type plainSaslClient struct { username string password string diff --git a/x/mongo/driver/auth/scram.go b/x/mongo/driver/auth/scram.go index c1238cd6a9..8118e0a6ab 100644 --- a/x/mongo/driver/auth/scram.go +++ b/x/mongo/driver/auth/scram.go @@ -84,6 +84,11 @@ func (a *ScramAuthenticator) Auth(ctx context.Context, cfg *Config) error { return nil } +// Reauth reauthenticates the connection. +func (a *ScramAuthenticator) Reauth(ctx context.Context) error { + return newAuthError("SCRAM does not support reauthentication", nil) +} + // CreateSpeculativeConversation creates a speculative conversation for SCRAM authentication. func (a *ScramAuthenticator) CreateSpeculativeConversation() (SpeculativeConversation, error) { return newSaslConversation(a.createSaslClient(), a.source, true), nil diff --git a/x/mongo/driver/auth/x509.go b/x/mongo/driver/auth/x509.go index 03a9d750e2..4cbbc246c5 100644 --- a/x/mongo/driver/auth/x509.go +++ b/x/mongo/driver/auth/x509.go @@ -76,3 +76,8 @@ func (a *MongoDBX509Authenticator) Auth(ctx context.Context, cfg *Config) error return nil } + +// Reauth reauthenticates the connection. +func (a *MongoDBX509Authenticator) Reauth(ctx context.Context) error { + return newAuthError("X509 does not support reauthentication", nil) +} diff --git a/x/mongo/driver/oidc.go b/x/mongo/driver/oidc.go index 399620f4c3..e06135b33d 100644 --- a/x/mongo/driver/oidc.go +++ b/x/mongo/driver/oidc.go @@ -16,16 +16,17 @@ import ( const OIDC = "MONGODB-OIDC" const tokenResourceProp = "TOKEN_RESOURCE" const environmentProp = "ENVIRONMENT" -const principalProp = "PRINCIPAL" const allowedHostsProp = "ALLOWED_HOSTS" const azureEnvironmentValue = "azure" const gcpEnvironmentValue = "gcp" const defaultAuthDB = "admin" +const machineSleepTime = 100 * time.Millisecond // Authenticator handles authenticating a connection. type Authenticator interface { // Auth authenticates the connection. Auth(context.Context, *AuthConfig) error + Reauth(context.Context) error } // Cred is a user's credential. @@ -45,6 +46,7 @@ type OIDCAuthenticator struct { AuthMechanismProperties map[string]string + cfg *AuthConfig accessToken string refreshToken *string idpInfo *IDPInfo @@ -57,10 +59,6 @@ func NewOIDCAuthenticator(cred *Cred) (Authenticator, error) { return oa, nil } -func (oa *OIDCAuthenticator) Auth(ctx context.Context, cfg *AuthConfig) error { - return nil -} - type IDPInfo struct { Issuer string `bson:"issuer"` ClientID string `bson:"clientId"` @@ -79,7 +77,7 @@ type OIDCArgs struct { type OIDCCredential struct { AccessToken string - ExpiresAt time.Time + ExpiresAt *time.Time RefreshToken *string } @@ -87,21 +85,7 @@ type oidcOneStep struct { accessToken string } -func (oos *oidcOneStep) Start() (string, []byte, error) { - return OIDC, jwtStepRequest(oos.accessToken), nil -} - -func newAuthError(msg string, err error) error { - return fmt.Errorf("authentication error: %s: %w", msg, err) -} - -func (oos *oidcOneStep) Next(context.Context, []byte) ([]byte, error) { - return nil, newAuthError("unexpected step in OIDC machine authentication", nil) -} - -func (*oidcOneStep) Completed() bool { - return true -} +var _ oidcSaslClient = (*oidcOneStep)(nil) func jwtStepRequest(accessToken string) []byte { return bsoncore.NewDocumentBuilder(). @@ -117,6 +101,155 @@ func principalStepRequest(principal string) []byte { return doc.Build() } +func (oos *oidcOneStep) Start() (string, []byte, error) { + return OIDC, jwtStepRequest(oos.accessToken), nil +} + +func (oos *oidcOneStep) Next([]byte) ([]byte, error) { + return nil, fmt.Errorf("unexpected step in OIDC machine authentication") +} + +func (*oidcOneStep) Completed() bool { + return true +} + +func (oa *OIDCAuthenticator) providerCallback() (OIDCCallback, error) { + env, ok := oa.AuthMechanismProperties[environmentProp] + if !ok { + return nil, nil + } + + switch env { + // TODO GODRIVER-2728: Automatic token acquisition for Azure Identity Provider + // TODO GODRIVER-2806: Automatic token acquisition for GCP Identity Provider + } + + return nil, fmt.Errorf("%q %q not supported for MONGODB-OIDC", environmentProp, env) +} + +// This should only be called with the Mutex held. +func (oa *OIDCAuthenticator) getAccessToken( + ctx context.Context, + args *OIDCArgs, + callback OIDCCallback, +) (string, error) { + if oa.accessToken != "" { + return oa.accessToken, nil + } + + cred, err := callback(ctx, args) + if err != nil { + return "", err + } + + oa.accessToken = cred.AccessToken + if cred.RefreshToken != nil { + oa.refreshToken = cred.RefreshToken + } + return cred.AccessToken, nil +} + +// This should only be called with the Mutex held. +func (oa *OIDCAuthenticator) getAccessTokenWithRefresh( + ctx context.Context, + callback OIDCCallback, + refreshToken string, +) (string, error) { + + cred, err := callback(ctx, &OIDCArgs{ + Version: 1, + IDPInfo: oa.idpInfo, + RefreshToken: &refreshToken, + }) + if err != nil { + return "", err + } + + oa.accessToken = cred.AccessToken + return cred.AccessToken, nil +} + +// TODO: add invalidation algorithm from rust driver +func (oa *OIDCAuthenticator) invalidateAccessToken() { + oa.accessToken = "" +} + +func (oa *OIDCAuthenticator) Reauth(ctx context.Context) error { + oa.invalidateAccessToken() + return oa.Auth(ctx, oa.cfg) +} + +// Auth authenticates the connection. +func (oa *OIDCAuthenticator) Auth(ctx context.Context, cfg *AuthConfig) error { + // the Mutex must be held during the entire Auth call so that multiple racing attempts + // to authenticate will not result in multiple callbacks. The losers on the Mutex will + // retrieve the access token from the Authenticator cache. + oa.mu.Lock() + defer oa.mu.Unlock() + + oa.cfg = cfg + + if oa.accessToken != "" { + err := conductOIDCSaslConversation(ctx, cfg, "$external", &oidcOneStep{ + accessToken: oa.accessToken, + }) + if err == nil { + return nil + } + // TODO: Check error type and raise if it's not a server-side error. + oa.invalidateAccessToken() + time.Sleep(100 * time.Millisecond) + } + + if cfg.OIDCMachineCallback != nil { + accessToken, err := oa.getAccessToken(ctx, nil, cfg.OIDCMachineCallback) + if err != nil { + return err + } + + err = conductOIDCSaslConversation(ctx, cfg, "$external", &oidcOneStep{ + accessToken: accessToken, + }) + if err == nil { + return nil + } + // Clear the access token if authentication failed. + oa.invalidateAccessToken() + + time.Sleep(machineSleepTime) + accessToken, err = oa.getAccessToken(ctx, &OIDCArgs{Version: 1}, cfg.OIDCMachineCallback) + if err != nil { + return err + } + return conductOIDCSaslConversation(ctx, cfg, "$external", &oidcOneStep{ + accessToken: accessToken, + }) + } + + // TODO GODRIVER-3246: Handle Human callback here. + + callback, err := oa.providerCallback() + if err != nil { + return fmt.Errorf("error getting build-in OIDC provider: %w", err) + } + + accessToken, err := oa.getAccessToken(ctx, &OIDCArgs{Version: 1}, callback) + if err != nil { + return fmt.Errorf("error getting access token from built-in OIDC provider: %w", err) + } + + err = conductOIDCSaslConversation(ctx, cfg, "$external", &oidcOneStep{ + accessToken: accessToken, + }) + // TODO: Check error type and raise if it's not a server-side error. + if err == nil { + return nil + } + oa.invalidateAccessToken() + + return err +} + // OIDC Sasl. This is almost a verbatim copy of auth/sasl introduced to remove the dependency on auth package // which causes a circular dependency when attempting to do Reauthentication in driver/operation.go. // This could be removed with a larger refactor. @@ -125,12 +258,14 @@ func principalStepRequest(principal string) []byte { // this was moved from the auth package to avoid a circular dependency. The auth package // reexports this under the old name to avoid breaking the public api. type AuthConfig struct { - Description description.Server - Connection Connection - ClusterClock *session.ClusterClock - HandshakeInfo HandshakeInformation - ServerAPI *ServerAPIOptions - HTTPClient *http.Client + Description description.Server + Connection Connection + ClusterClock *session.ClusterClock + HandshakeInfo HandshakeInformation + ServerAPI *ServerAPIOptions + HTTPClient *http.Client + OIDCMachineCallback OIDCCallback + OIDCHumanCallback OIDCCallback } func newError(err error, mechanism string) error { diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index b0d32c671c..7c9c20bf66 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -315,6 +315,10 @@ type Operation struct { // [Operation.MaxTime]. OmitCSOTMaxTimeMS bool + // Authenticator is the authenticator to use for this operation when a reauthentication is + // required. + Authenticator Authenticator + // omitReadPreference is a boolean that indicates whether to omit the // read preference from the command. This omition includes the case // where a default read preference is used when the operation @@ -913,8 +917,7 @@ func (op Operation) Execute(ctx context.Context) error { operationErr.Raw = tt.Raw case Error: if tt.Code == 391 { - x, _ := NewOIDCAuthenticator(nil) - fmt.Println(x) + op.Authenticator.Reauth(ctx) } if tt.HasErrorLabel(TransientTransactionError) || tt.HasErrorLabel(UnknownTransactionCommitResult) { if err := op.Client.ClearPinnedResources(); err != nil { From f400d18f05ce9945a2dc6b9188a5fc9623bebc84 Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Wed, 12 Jun 2024 19:48:07 -0400 Subject: [PATCH 10/66] GODRIVER-2911: Check point --- mongo/client.go | 1 + x/mongo/driver/operation.go | 7 ++++++- x/mongo/driver/operation/command.go | 1 + 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/mongo/client.go b/mongo/client.go index 280749c7dd..de2a4b6ab6 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -79,6 +79,7 @@ type Client struct { metadataClientFLE *Client internalClientFLE *Client encryptedFieldsMap map[string]interface{} + authenticator driver.Authenticator } // Connect creates a new Client and then initializes it using the Connect method. This is equivalent to calling diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index 7c9c20bf66..25b78893e1 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -916,8 +916,13 @@ func (op Operation) Execute(ctx context.Context) error { operationErr.Labels = tt.Labels operationErr.Raw = tt.Raw case Error: + // TODO: actually make sure this Reauths if tt.Code == 391 { - op.Authenticator.Reauth(ctx) + if op.Authenticator != nil { + if err := op.Authenticator.Reauth(ctx); err != nil { + return err + } + } } if tt.HasErrorLabel(TransientTransactionError) || tt.HasErrorLabel(UnknownTransactionCommitResult) { if err := op.Client.ClearPinnedResources(); err != nil { diff --git a/x/mongo/driver/operation/command.go b/x/mongo/driver/operation/command.go index 35283794a3..14cf369340 100644 --- a/x/mongo/driver/operation/command.go +++ b/x/mongo/driver/operation/command.go @@ -38,6 +38,7 @@ type Command struct { cursorOpts driver.CursorOptions timeout *time.Duration logger *logger.Logger + authenticator driver.Authenticator } // NewCommand constructs and returns a new Command. Once the operation is executed, the result may only be accessed via From eed3dd51bb5f5da4aedca9279a5760d6c894d1f5 Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Wed, 12 Jun 2024 20:02:34 -0400 Subject: [PATCH 11/66] GODRIVER-2911: Initial plumbing, the Client Authenticator is going to need to get down to Handshake instead of creating the Authenticator in Handshake as we do now --- mongo/client.go | 10 ++++++++++ mongo/database.go | 2 +- x/mongo/driver/operation/command.go | 10 ++++++++++ 3 files changed, 21 insertions(+), 1 deletion(-) diff --git a/mongo/client.go b/mongo/client.go index de2a4b6ab6..313517a993 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -26,6 +26,7 @@ import ( "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/auth" "go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt" mcopts "go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt/options" "go.mongodb.org/mongo-driver/x/mongo/driver/operation" @@ -233,6 +234,15 @@ func NewClient(opts ...*options.ClientOptions) (*Client, error) { return nil, fmt.Errorf("invalid logger options: %w", err) } + // Create an authenticator for the client + client.authenticator, err = auth.CreateAuthenticator(clientOpt.Auth.AuthMechanism, &auth.Cred{ + Source: clientOpt.Auth.AuthSource, + Username: clientOpt.Auth.Username, + Password: clientOpt.Auth.Password, + PasswordSet: clientOpt.Auth.PasswordSet, + Props: clientOpt.Auth.AuthMechanismProperties, + }) + return client, nil } diff --git a/mongo/database.go b/mongo/database.go index 57c0186eca..0147683396 100644 --- a/mongo/database.go +++ b/mongo/database.go @@ -189,7 +189,7 @@ func (db *Database) processRunCommand(ctx context.Context, cmd interface{}, ServerSelector(readSelect).ClusterClock(db.client.clock). Database(db.name).Deployment(db.client.deployment). Crypt(db.client.cryptFLE).ReadPreference(ro.ReadPreference).ServerAPI(db.client.serverAPI). - Timeout(db.client.timeout).Logger(db.client.logger), sess, nil + Timeout(db.client.timeout).Logger(db.client.logger).Authenticator(db.Client().authenticator), sess, nil } // RunCommand executes the given command against the database. diff --git a/x/mongo/driver/operation/command.go b/x/mongo/driver/operation/command.go index 14cf369340..f1cd979e61 100644 --- a/x/mongo/driver/operation/command.go +++ b/x/mongo/driver/operation/command.go @@ -181,6 +181,16 @@ func (c *Command) ServerSelector(selector description.ServerSelector) *Command { return c } +// Authenticator sets the authenticator to use for this operation. +func (c *Command) Authenticator(authenticator driver.Authenticator) *Command { + if c == nil { + c = new(Command) + } + + c.authenticator = authenticator + return c +} + // Crypt sets the Crypt object to use for automatic encryption and decryption. func (c *Command) Crypt(crypt driver.Crypt) *Command { if c == nil { From 2ee93ccee4033691f861b4f487145972c2820b5d Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Wed, 12 Jun 2024 20:27:53 -0400 Subject: [PATCH 12/66] GODRIVER-2911: Set authenticator in topology --- mongo/client.go | 20 ++++++++--------- x/mongo/driver/topology/topology_options.go | 24 ++++++++++++++++++++- 2 files changed, 33 insertions(+), 11 deletions(-) diff --git a/mongo/client.go b/mongo/client.go index 313517a993..67ad1d2528 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -215,7 +215,16 @@ func NewClient(opts ...*options.ClientOptions) (*Client, error) { return nil, err } - cfg, err := topology.NewConfig(clientOpt, client.clock) + // Create an authenticator for the client + client.authenticator, err = auth.CreateAuthenticator(clientOpt.Auth.AuthMechanism, &auth.Cred{ + Source: clientOpt.Auth.AuthSource, + Username: clientOpt.Auth.Username, + Password: clientOpt.Auth.Password, + PasswordSet: clientOpt.Auth.PasswordSet, + Props: clientOpt.Auth.AuthMechanismProperties, + }) + + cfg, err := topology.NewConfigWithAuthenticator(clientOpt, client.clock, client.authenticator) if err != nil { return nil, err } @@ -234,15 +243,6 @@ func NewClient(opts ...*options.ClientOptions) (*Client, error) { return nil, fmt.Errorf("invalid logger options: %w", err) } - // Create an authenticator for the client - client.authenticator, err = auth.CreateAuthenticator(clientOpt.Auth.AuthMechanism, &auth.Cred{ - Source: clientOpt.Auth.AuthSource, - Username: clientOpt.Auth.Username, - Password: clientOpt.Auth.Password, - PasswordSet: clientOpt.Auth.PasswordSet, - Props: clientOpt.Auth.AuthMechanismProperties, - }) - return client, nil } diff --git a/x/mongo/driver/topology/topology_options.go b/x/mongo/driver/topology/topology_options.go index b5eb4a9729..820b62a7ff 100644 --- a/x/mongo/driver/topology/topology_options.go +++ b/x/mongo/driver/topology/topology_options.go @@ -72,8 +72,30 @@ func newLogger(opts *options.LoggerOptions) (*logger.Logger, error) { } // NewConfig will translate data from client options into a topology config for building non-default deployments. -// Server and topology options are not honored if a custom deployment is used. func NewConfig(co *options.ClientOptions, clock *session.ClusterClock) (*Config, error) { + // Auth & Database & Password & Username + if co.Auth != nil { + cred := &auth.Cred{ + Username: co.Auth.Username, + Password: co.Auth.Password, + PasswordSet: co.Auth.PasswordSet, + Props: co.Auth.AuthMechanismProperties, + Source: co.Auth.AuthSource, + } + mechanism := co.Auth.AuthMechanism + authenticator, err := auth.CreateAuthenticator(mechanism, cred) + if err != nil { + return nil, err + } + return NewConfigWithAuthenticator(co, clock, authenticator) + } + return NewConfigWithAuthenticator(co, clock, nil) +} + +// NewConfigWithAuthenticator will translate data from client options into a topology config for building non-default deployments. +// Server and topology options are not honored if a custom deployment is used. It uses a passed in +// authenticator to authenticate the connection. +func NewConfigWithAuthenticator(co *options.ClientOptions, clock *session.ClusterClock, authenticator driver.Authenticator) (*Config, error) { var serverAPI *driver.ServerAPIOptions if err := co.Validate(); err != nil { From f6def8d6e47e304cdcbbecc7cb36bf3cb06ac56d Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Wed, 12 Jun 2024 21:22:13 -0400 Subject: [PATCH 13/66] GODRIVER-2911: Set authenticator from Command to Operation --- x/mongo/driver/operation/command.go | 1 + 1 file changed, 1 insertion(+) diff --git a/x/mongo/driver/operation/command.go b/x/mongo/driver/operation/command.go index f1cd979e61..0d0472308f 100644 --- a/x/mongo/driver/operation/command.go +++ b/x/mongo/driver/operation/command.go @@ -108,6 +108,7 @@ func (c *Command) Execute(ctx context.Context) error { ServerAPI: c.serverAPI, Timeout: c.timeout, Logger: c.logger, + Authenticator: c.authenticator, }.Execute(ctx) } From bd5c9f29957c6aeeb42f01764302235c3b68cba3 Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Wed, 12 Jun 2024 21:42:07 -0400 Subject: [PATCH 14/66] GODRIVER-2911: Remove authenticator so we can readd it programmatically --- x/mongo/driver/operation/command.go | 1 - 1 file changed, 1 deletion(-) diff --git a/x/mongo/driver/operation/command.go b/x/mongo/driver/operation/command.go index 0d0472308f..19f11316a3 100644 --- a/x/mongo/driver/operation/command.go +++ b/x/mongo/driver/operation/command.go @@ -38,7 +38,6 @@ type Command struct { cursorOpts driver.CursorOptions timeout *time.Duration logger *logger.Logger - authenticator driver.Authenticator } // NewCommand constructs and returns a new Command. Once the operation is executed, the result may only be accessed via From 36ba008127d734a5331f006a06a4bd47014d997d Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Wed, 12 Jun 2024 21:42:33 -0400 Subject: [PATCH 15/66] GODRIVER-2911: Remove authenticator so we can readd it programmatically --- x/mongo/driver/operation/command.go | 1 - 1 file changed, 1 deletion(-) diff --git a/x/mongo/driver/operation/command.go b/x/mongo/driver/operation/command.go index 19f11316a3..d8f37c6144 100644 --- a/x/mongo/driver/operation/command.go +++ b/x/mongo/driver/operation/command.go @@ -107,7 +107,6 @@ func (c *Command) Execute(ctx context.Context) error { ServerAPI: c.serverAPI, Timeout: c.timeout, Logger: c.logger, - Authenticator: c.authenticator, }.Execute(ctx) } From a2a40293a350171af1c512a53110e249d598a9e2 Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Wed, 12 Jun 2024 21:53:25 -0400 Subject: [PATCH 16/66] GODRIVER-2911: Remove authenticator so we can readd it programatically --- x/mongo/driver/operation/command.go | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/x/mongo/driver/operation/command.go b/x/mongo/driver/operation/command.go index d8f37c6144..35283794a3 100644 --- a/x/mongo/driver/operation/command.go +++ b/x/mongo/driver/operation/command.go @@ -180,16 +180,6 @@ func (c *Command) ServerSelector(selector description.ServerSelector) *Command { return c } -// Authenticator sets the authenticator to use for this operation. -func (c *Command) Authenticator(authenticator driver.Authenticator) *Command { - if c == nil { - c = new(Command) - } - - c.authenticator = authenticator - return c -} - // Crypt sets the Crypt object to use for automatic encryption and decryption. func (c *Command) Crypt(crypt driver.Crypt) *Command { if c == nil { From d2c75f17b1af051af0a5a8991b3dddc1a88b9a6b Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Wed, 12 Jun 2024 22:07:39 -0400 Subject: [PATCH 17/66] GODRIVER-2911: Add all that authenticator plumbing programmatically so I'm sure it's right --- x/mongo/driver/operation/abort_transaction.go | 12 ++++ x/mongo/driver/operation/aggregate.go | 12 ++++ x/mongo/driver/operation/command.go | 13 ++++ .../driver/operation/commit_transaction.go | 12 ++++ x/mongo/driver/operation/count.go | 25 ++++++++ x/mongo/driver/operation/create.go | 13 ++++ x/mongo/driver/operation/create_indexes.go | 53 +++++++++++----- .../driver/operation/create_search_indexes.go | 60 ++++++++++++++---- x/mongo/driver/operation/delete.go | 61 +++++++++++++------ x/mongo/driver/operation/distinct.go | 23 +++++++ x/mongo/driver/operation/drop_collection.go | 47 ++++++++++---- x/mongo/driver/operation/drop_database.go | 30 ++++++--- x/mongo/driver/operation/drop_indexes.go | 51 +++++++++++----- x/mongo/driver/operation/drop_search_index.go | 49 +++++++++++---- x/mongo/driver/operation/end_sessions.go | 30 ++++++--- x/mongo/driver/operation/find.go | 12 ++++ x/mongo/driver/operation/find_and_modify.go | 34 +++++++++++ x/mongo/driver/operation/hello.go | 24 ++++++++ x/mongo/driver/operation/insert.go | 23 +++++++ x/mongo/driver/operation/listDatabases.go | 40 +++++++++++- x/mongo/driver/operation/list_collections.go | 12 ++++ x/mongo/driver/operation/list_indexes.go | 38 ++++++++---- x/mongo/driver/operation/update.go | 38 +++++++++++- .../driver/operation/update_search_index.go | 51 +++++++++++----- 24 files changed, 627 insertions(+), 136 deletions(-) diff --git a/x/mongo/driver/operation/abort_transaction.go b/x/mongo/driver/operation/abort_transaction.go index 9413727130..f03949d15c 100644 --- a/x/mongo/driver/operation/abort_transaction.go +++ b/x/mongo/driver/operation/abort_transaction.go @@ -21,6 +21,7 @@ import ( // AbortTransaction performs an abortTransaction operation. type AbortTransaction struct { + authenticator driver.Authenticator recoveryToken bsoncore.Document session *session.Client clock *session.ClusterClock @@ -66,6 +67,7 @@ func (at *AbortTransaction) Execute(ctx context.Context) error { WriteConcern: at.writeConcern, ServerAPI: at.serverAPI, Name: driverutil.AbortTransactionOp, + Authenticator: at.authenticator, }.Execute(ctx) } @@ -199,3 +201,13 @@ func (at *AbortTransaction) ServerAPI(serverAPI *driver.ServerAPIOptions) *Abort at.serverAPI = serverAPI return at } + +// Authenticator sets the authenticator to use for this operation. +func (a *AbortTransaction) Authenticator(authenticator driver.Authenticator) *AbortTransaction { + if a == nil { + a = new(AbortTransaction) + } + + a.authenticator = authenticator + return a +} diff --git a/x/mongo/driver/operation/aggregate.go b/x/mongo/driver/operation/aggregate.go index 44467df8fd..df6b8fa9dd 100644 --- a/x/mongo/driver/operation/aggregate.go +++ b/x/mongo/driver/operation/aggregate.go @@ -25,6 +25,7 @@ import ( // Aggregate represents an aggregate operation. type Aggregate struct { + authenticator driver.Authenticator allowDiskUse *bool batchSize *int32 bypassDocumentValidation *bool @@ -115,6 +116,7 @@ func (a *Aggregate) Execute(ctx context.Context) error { Timeout: a.timeout, Name: driverutil.AggregateOp, OmitCSOTMaxTimeMS: a.omitCSOTMaxTimeMS, + Authenticator: a.authenticator, }.Execute(ctx) } @@ -433,3 +435,13 @@ func (a *Aggregate) OmitCSOTMaxTimeMS(omit bool) *Aggregate { a.omitCSOTMaxTimeMS = omit return a } + +// Authenticator sets the authenticator to use for this operation. +func (a *Aggregate) Authenticator(authenticator driver.Authenticator) *Aggregate { + if a == nil { + a = new(Aggregate) + } + + a.authenticator = authenticator + return a +} diff --git a/x/mongo/driver/operation/command.go b/x/mongo/driver/operation/command.go index 35283794a3..92fb250cf0 100644 --- a/x/mongo/driver/operation/command.go +++ b/x/mongo/driver/operation/command.go @@ -22,6 +22,7 @@ import ( // Command is used to run a generic operation. type Command struct { + authenticator driver.Authenticator command bsoncore.Document database string deployment driver.Deployment @@ -107,6 +108,7 @@ func (c *Command) Execute(ctx context.Context) error { ServerAPI: c.serverAPI, Timeout: c.timeout, Logger: c.logger, + Authenticator: c.authenticator, }.Execute(ctx) } @@ -219,3 +221,14 @@ func (c *Command) Logger(logger *logger.Logger) *Command { c.logger = logger return c } + +// Authenticator sets the authenticator to use for this operation. +func (c *Command) Authenticator(authenticator driver.Authenticator) *Command { + if c == nil { + c = new(Command) + } + + c.authenticator = authenticator + return c +} + diff --git a/x/mongo/driver/operation/commit_transaction.go b/x/mongo/driver/operation/commit_transaction.go index 11c6f69ddf..b6414f60b6 100644 --- a/x/mongo/driver/operation/commit_transaction.go +++ b/x/mongo/driver/operation/commit_transaction.go @@ -22,6 +22,7 @@ import ( // CommitTransaction attempts to commit a transaction. type CommitTransaction struct { + authenticator driver.Authenticator maxTime *time.Duration recoveryToken bsoncore.Document session *session.Client @@ -68,6 +69,7 @@ func (ct *CommitTransaction) Execute(ctx context.Context) error { WriteConcern: ct.writeConcern, ServerAPI: ct.serverAPI, Name: driverutil.CommitTransactionOp, + Authenticator: ct.authenticator, }.Execute(ctx) } @@ -201,3 +203,13 @@ func (ct *CommitTransaction) ServerAPI(serverAPI *driver.ServerAPIOptions) *Comm ct.serverAPI = serverAPI return ct } + +// Authenticator sets the authenticator to use for this operation. +func (c *CommitTransaction) Authenticator(authenticator driver.Authenticator) *CommitTransaction { + if c == nil { + c = new(CommitTransaction) + } + + c.authenticator = authenticator + return c +} diff --git a/x/mongo/driver/operation/count.go b/x/mongo/driver/operation/count.go index 8de1e9f8d9..0729db996b 100644 --- a/x/mongo/driver/operation/count.go +++ b/x/mongo/driver/operation/count.go @@ -25,6 +25,7 @@ import ( // Count represents a count operation. type Count struct { + authenticator driver.Authenticator maxTime *time.Duration query bsoncore.Document session *session.Client @@ -46,6 +47,7 @@ type Count struct { // CountResult represents a count result returned by the server. type CountResult struct { + authenticator driver.Authenticator // The number of documents found N int64 } @@ -128,6 +130,7 @@ func (c *Count) Execute(ctx context.Context) error { ServerAPI: c.serverAPI, Timeout: c.timeout, Name: driverutil.CountOp, + Authenticator: c.authenticator, }.Execute(ctx) // Swallow error if NamespaceNotFound(26) is returned from aggregate on non-existent namespace @@ -311,3 +314,25 @@ func (c *Count) Timeout(timeout *time.Duration) *Count { c.timeout = timeout return c } + +// Authenticator sets the authenticator to use for this operation. +func (c *Count) Authenticator(authenticator driver.Authenticator) *Count { + if c == nil { + c = new(Count) + } + + c.authenticator = authenticator + return c +} + + +// Authenticator sets the authenticator to use for this operation. +func (c *CountResult) Authenticator(authenticator driver.Authenticator) *CountResult { + if c == nil { + c = new(CountResult) + } + + c.authenticator = authenticator + return c +} + diff --git a/x/mongo/driver/operation/create.go b/x/mongo/driver/operation/create.go index 45b26cb707..394d47676b 100644 --- a/x/mongo/driver/operation/create.go +++ b/x/mongo/driver/operation/create.go @@ -20,6 +20,7 @@ import ( // Create represents a create operation. type Create struct { + authenticator driver.Authenticator capped *bool collation bsoncore.Document changeStreamPreAndPostImages bsoncore.Document @@ -77,6 +78,7 @@ func (c *Create) Execute(ctx context.Context) error { Selector: c.selector, WriteConcern: c.writeConcern, ServerAPI: c.serverAPI, + Authenticator: c.authenticator, }.Execute(ctx) } @@ -399,3 +401,14 @@ func (c *Create) ClusteredIndex(ci bsoncore.Document) *Create { c.clusteredIndex = ci return c } + +// Authenticator sets the authenticator to use for this operation. +func (c *Create) Authenticator(authenticator driver.Authenticator) *Create { + if c == nil { + c = new(Create) + } + + c.authenticator = authenticator + return c +} + diff --git a/x/mongo/driver/operation/create_indexes.go b/x/mongo/driver/operation/create_indexes.go index 77daf676a4..502182b304 100644 --- a/x/mongo/driver/operation/create_indexes.go +++ b/x/mongo/driver/operation/create_indexes.go @@ -24,25 +24,27 @@ import ( // CreateIndexes performs a createIndexes operation. type CreateIndexes struct { - commitQuorum bsoncore.Value - indexes bsoncore.Document - maxTime *time.Duration - session *session.Client - clock *session.ClusterClock - collection string - monitor *event.CommandMonitor - crypt driver.Crypt - database string - deployment driver.Deployment - selector description.ServerSelector - writeConcern *writeconcern.WriteConcern - result CreateIndexesResult - serverAPI *driver.ServerAPIOptions - timeout *time.Duration + authenticator driver.Authenticator + commitQuorum bsoncore.Value + indexes bsoncore.Document + maxTime *time.Duration + session *session.Client + clock *session.ClusterClock + collection string + monitor *event.CommandMonitor + crypt driver.Crypt + database string + deployment driver.Deployment + selector description.ServerSelector + writeConcern *writeconcern.WriteConcern + result CreateIndexesResult + serverAPI *driver.ServerAPIOptions + timeout *time.Duration } // CreateIndexesResult represents a createIndexes result returned by the server. type CreateIndexesResult struct { + authenticator driver.Authenticator // If the collection was created automatically. CreatedCollectionAutomatically bool // The number of indexes existing after this command. @@ -119,6 +121,7 @@ func (ci *CreateIndexes) Execute(ctx context.Context) error { ServerAPI: ci.serverAPI, Timeout: ci.timeout, Name: driverutil.CreateIndexesOp, + Authenticator: ci.authenticator, }.Execute(ctx) } @@ -278,3 +281,23 @@ func (ci *CreateIndexes) Timeout(timeout *time.Duration) *CreateIndexes { ci.timeout = timeout return ci } + +// Authenticator sets the authenticator to use for this operation. +func (c *CreateIndexes) Authenticator(authenticator driver.Authenticator) *CreateIndexes { + if c == nil { + c = new(CreateIndexes) + } + + c.authenticator = authenticator + return c +} + +// Authenticator sets the authenticator to use for this operation. +func (c *CreateIndexesResult) Authenticator(authenticator driver.Authenticator) *CreateIndexesResult { + if c == nil { + c = new(CreateIndexesResult) + } + + c.authenticator = authenticator + return c +} diff --git a/x/mongo/driver/operation/create_search_indexes.go b/x/mongo/driver/operation/create_search_indexes.go index cb0d807952..2300e3df4e 100644 --- a/x/mongo/driver/operation/create_search_indexes.go +++ b/x/mongo/driver/operation/create_search_indexes.go @@ -22,27 +22,30 @@ import ( // CreateSearchIndexes performs a createSearchIndexes operation. type CreateSearchIndexes struct { - indexes bsoncore.Document - session *session.Client - clock *session.ClusterClock - collection string - monitor *event.CommandMonitor - crypt driver.Crypt - database string - deployment driver.Deployment - selector description.ServerSelector - result CreateSearchIndexesResult - serverAPI *driver.ServerAPIOptions - timeout *time.Duration + authenticator driver.Authenticator + indexes bsoncore.Document + session *session.Client + clock *session.ClusterClock + collection string + monitor *event.CommandMonitor + crypt driver.Crypt + database string + deployment driver.Deployment + selector description.ServerSelector + result CreateSearchIndexesResult + serverAPI *driver.ServerAPIOptions + timeout *time.Duration } // CreateSearchIndexResult represents a single search index result in CreateSearchIndexesResult. type CreateSearchIndexResult struct { - Name string + authenticator driver.Authenticator + Name string } // CreateSearchIndexesResult represents a createSearchIndexes result returned by the server. type CreateSearchIndexesResult struct { + authenticator driver.Authenticator IndexesCreated []CreateSearchIndexResult } @@ -116,6 +119,7 @@ func (csi *CreateSearchIndexes) Execute(ctx context.Context) error { Selector: csi.selector, ServerAPI: csi.serverAPI, Timeout: csi.timeout, + Authenticator: csi.authenticator, }.Execute(ctx) } @@ -237,3 +241,33 @@ func (csi *CreateSearchIndexes) Timeout(timeout *time.Duration) *CreateSearchInd csi.timeout = timeout return csi } + +// Authenticator sets the authenticator to use for this operation. +func (c *CreateSearchIndexes) Authenticator(authenticator driver.Authenticator) *CreateSearchIndexes { + if c == nil { + c = new(CreateSearchIndexes) + } + + c.authenticator = authenticator + return c +} + +// Authenticator sets the authenticator to use for this operation. +func (c *CreateSearchIndexResult) Authenticator(authenticator driver.Authenticator) *CreateSearchIndexResult { + if c == nil { + c = new(CreateSearchIndexResult) + } + + c.authenticator = authenticator + return c +} + +// Authenticator sets the authenticator to use for this operation. +func (c *CreateSearchIndexesResult) Authenticator(authenticator driver.Authenticator) *CreateSearchIndexesResult { + if c == nil { + c = new(CreateSearchIndexesResult) + } + + c.authenticator = authenticator + return c +} diff --git a/x/mongo/driver/operation/delete.go b/x/mongo/driver/operation/delete.go index bf95cf496d..f5a4fe7f79 100644 --- a/x/mongo/driver/operation/delete.go +++ b/x/mongo/driver/operation/delete.go @@ -25,29 +25,31 @@ import ( // Delete performs a delete operation type Delete struct { - comment bsoncore.Value - deletes []bsoncore.Document - ordered *bool - session *session.Client - clock *session.ClusterClock - collection string - monitor *event.CommandMonitor - crypt driver.Crypt - database string - deployment driver.Deployment - selector description.ServerSelector - writeConcern *writeconcern.WriteConcern - retry *driver.RetryMode - hint *bool - result DeleteResult - serverAPI *driver.ServerAPIOptions - let bsoncore.Document - timeout *time.Duration - logger *logger.Logger + authenticator driver.Authenticator + comment bsoncore.Value + deletes []bsoncore.Document + ordered *bool + session *session.Client + clock *session.ClusterClock + collection string + monitor *event.CommandMonitor + crypt driver.Crypt + database string + deployment driver.Deployment + selector description.ServerSelector + writeConcern *writeconcern.WriteConcern + retry *driver.RetryMode + hint *bool + result DeleteResult + serverAPI *driver.ServerAPIOptions + let bsoncore.Document + timeout *time.Duration + logger *logger.Logger } // DeleteResult represents a delete result returned by the server. type DeleteResult struct { + authenticator driver.Authenticator // Number of documents successfully deleted. N int64 } @@ -116,6 +118,7 @@ func (d *Delete) Execute(ctx context.Context) error { Timeout: d.timeout, Logger: d.logger, Name: driverutil.DeleteOp, + Authenticator: d.authenticator, }.Execute(ctx) } @@ -328,3 +331,23 @@ func (d *Delete) Logger(logger *logger.Logger) *Delete { return d } + +// Authenticator sets the authenticator to use for this operation. +func (d *Delete) Authenticator(authenticator driver.Authenticator) *Delete { + if d == nil { + d = new(Delete) + } + + d.authenticator = authenticator + return d +} + +// Authenticator sets the authenticator to use for this operation. +func (d *DeleteResult) Authenticator(authenticator driver.Authenticator) *DeleteResult { + if d == nil { + d = new(DeleteResult) + } + + d.authenticator = authenticator + return d +} diff --git a/x/mongo/driver/operation/distinct.go b/x/mongo/driver/operation/distinct.go index b7e675ce42..d3b2f3ce8f 100644 --- a/x/mongo/driver/operation/distinct.go +++ b/x/mongo/driver/operation/distinct.go @@ -24,6 +24,7 @@ import ( // Distinct performs a distinct operation. type Distinct struct { + authenticator driver.Authenticator collation bsoncore.Document key *string maxTime *time.Duration @@ -47,6 +48,7 @@ type Distinct struct { // DistinctResult represents a distinct result returned by the server. type DistinctResult struct { + authenticator driver.Authenticator // The distinct values for the field. Values bsoncore.Value } @@ -107,6 +109,7 @@ func (d *Distinct) Execute(ctx context.Context) error { ServerAPI: d.serverAPI, Timeout: d.timeout, Name: driverutil.DistinctOp, + Authenticator: d.authenticator, }.Execute(ctx) } @@ -311,3 +314,23 @@ func (d *Distinct) Timeout(timeout *time.Duration) *Distinct { d.timeout = timeout return d } + +// Authenticator sets the authenticator to use for this operation. +func (d *Distinct) Authenticator(authenticator driver.Authenticator) *Distinct { + if d == nil { + d = new(Distinct) + } + + d.authenticator = authenticator + return d +} + +// Authenticator sets the authenticator to use for this operation. +func (d *DistinctResult) Authenticator(authenticator driver.Authenticator) *DistinctResult { + if d == nil { + d = new(DistinctResult) + } + + d.authenticator = authenticator + return d +} diff --git a/x/mongo/driver/operation/drop_collection.go b/x/mongo/driver/operation/drop_collection.go index 8c65967564..ff6ffa181c 100644 --- a/x/mongo/driver/operation/drop_collection.go +++ b/x/mongo/driver/operation/drop_collection.go @@ -23,22 +23,24 @@ import ( // DropCollection performs a drop operation. type DropCollection struct { - session *session.Client - clock *session.ClusterClock - collection string - monitor *event.CommandMonitor - crypt driver.Crypt - database string - deployment driver.Deployment - selector description.ServerSelector - writeConcern *writeconcern.WriteConcern - result DropCollectionResult - serverAPI *driver.ServerAPIOptions - timeout *time.Duration + authenticator driver.Authenticator + session *session.Client + clock *session.ClusterClock + collection string + monitor *event.CommandMonitor + crypt driver.Crypt + database string + deployment driver.Deployment + selector description.ServerSelector + writeConcern *writeconcern.WriteConcern + result DropCollectionResult + serverAPI *driver.ServerAPIOptions + timeout *time.Duration } // DropCollectionResult represents a dropCollection result returned by the server. type DropCollectionResult struct { + authenticator driver.Authenticator // The number of indexes in the dropped collection. NIndexesWas int32 // The namespace of the dropped collection. @@ -104,6 +106,7 @@ func (dc *DropCollection) Execute(ctx context.Context) error { ServerAPI: dc.serverAPI, Timeout: dc.timeout, Name: driverutil.DropOp, + Authenticator: dc.authenticator, }.Execute(ctx) } @@ -222,3 +225,23 @@ func (dc *DropCollection) Timeout(timeout *time.Duration) *DropCollection { dc.timeout = timeout return dc } + +// Authenticator sets the authenticator to use for this operation. +func (d *DropCollection) Authenticator(authenticator driver.Authenticator) *DropCollection { + if d == nil { + d = new(DropCollection) + } + + d.authenticator = authenticator + return d +} + +// Authenticator sets the authenticator to use for this operation. +func (d *DropCollectionResult) Authenticator(authenticator driver.Authenticator) *DropCollectionResult { + if d == nil { + d = new(DropCollectionResult) + } + + d.authenticator = authenticator + return d +} diff --git a/x/mongo/driver/operation/drop_database.go b/x/mongo/driver/operation/drop_database.go index a8f9b45ba4..daecc65c5b 100644 --- a/x/mongo/driver/operation/drop_database.go +++ b/x/mongo/driver/operation/drop_database.go @@ -21,15 +21,16 @@ import ( // DropDatabase performs a dropDatabase operation type DropDatabase struct { - session *session.Client - clock *session.ClusterClock - monitor *event.CommandMonitor - crypt driver.Crypt - database string - deployment driver.Deployment - selector description.ServerSelector - writeConcern *writeconcern.WriteConcern - serverAPI *driver.ServerAPIOptions + authenticator driver.Authenticator + session *session.Client + clock *session.ClusterClock + monitor *event.CommandMonitor + crypt driver.Crypt + database string + deployment driver.Deployment + selector description.ServerSelector + writeConcern *writeconcern.WriteConcern + serverAPI *driver.ServerAPIOptions } // NewDropDatabase constructs and returns a new DropDatabase. @@ -55,6 +56,7 @@ func (dd *DropDatabase) Execute(ctx context.Context) error { WriteConcern: dd.writeConcern, ServerAPI: dd.serverAPI, Name: driverutil.DropDatabaseOp, + Authenticator: dd.authenticator, }.Execute(ctx) } @@ -154,3 +156,13 @@ func (dd *DropDatabase) ServerAPI(serverAPI *driver.ServerAPIOptions) *DropDatab dd.serverAPI = serverAPI return dd } + +// Authenticator sets the authenticator to use for this operation. +func (d *DropDatabase) Authenticator(authenticator driver.Authenticator) *DropDatabase { + if d == nil { + d = new(DropDatabase) + } + + d.authenticator = authenticator + return d +} diff --git a/x/mongo/driver/operation/drop_indexes.go b/x/mongo/driver/operation/drop_indexes.go index 0c3d459707..3b5fe7bbfe 100644 --- a/x/mongo/driver/operation/drop_indexes.go +++ b/x/mongo/driver/operation/drop_indexes.go @@ -23,24 +23,26 @@ import ( // DropIndexes performs an dropIndexes operation. type DropIndexes struct { - index *string - maxTime *time.Duration - session *session.Client - clock *session.ClusterClock - collection string - monitor *event.CommandMonitor - crypt driver.Crypt - database string - deployment driver.Deployment - selector description.ServerSelector - writeConcern *writeconcern.WriteConcern - result DropIndexesResult - serverAPI *driver.ServerAPIOptions - timeout *time.Duration + authenticator driver.Authenticator + index *string + maxTime *time.Duration + session *session.Client + clock *session.ClusterClock + collection string + monitor *event.CommandMonitor + crypt driver.Crypt + database string + deployment driver.Deployment + selector description.ServerSelector + writeConcern *writeconcern.WriteConcern + result DropIndexesResult + serverAPI *driver.ServerAPIOptions + timeout *time.Duration } // DropIndexesResult represents a dropIndexes result returned by the server. type DropIndexesResult struct { + authenticator driver.Authenticator // Number of indexes that existed before the drop was executed. NIndexesWas int32 } @@ -101,6 +103,7 @@ func (di *DropIndexes) Execute(ctx context.Context) error { ServerAPI: di.serverAPI, Timeout: di.timeout, Name: driverutil.DropIndexesOp, + Authenticator: di.authenticator, }.Execute(ctx) } @@ -242,3 +245,23 @@ func (di *DropIndexes) Timeout(timeout *time.Duration) *DropIndexes { di.timeout = timeout return di } + +// Authenticator sets the authenticator to use for this operation. +func (d *DropIndexes) Authenticator(authenticator driver.Authenticator) *DropIndexes { + if d == nil { + d = new(DropIndexes) + } + + d.authenticator = authenticator + return d +} + +// Authenticator sets the authenticator to use for this operation. +func (d *DropIndexesResult) Authenticator(authenticator driver.Authenticator) *DropIndexesResult { + if d == nil { + d = new(DropIndexesResult) + } + + d.authenticator = authenticator + return d +} diff --git a/x/mongo/driver/operation/drop_search_index.go b/x/mongo/driver/operation/drop_search_index.go index 3992c83165..0c69d5a104 100644 --- a/x/mongo/driver/operation/drop_search_index.go +++ b/x/mongo/driver/operation/drop_search_index.go @@ -21,23 +21,25 @@ import ( // DropSearchIndex performs an dropSearchIndex operation. type DropSearchIndex struct { - index string - session *session.Client - clock *session.ClusterClock - collection string - monitor *event.CommandMonitor - crypt driver.Crypt - database string - deployment driver.Deployment - selector description.ServerSelector - result DropSearchIndexResult - serverAPI *driver.ServerAPIOptions - timeout *time.Duration + authenticator driver.Authenticator + index string + session *session.Client + clock *session.ClusterClock + collection string + monitor *event.CommandMonitor + crypt driver.Crypt + database string + deployment driver.Deployment + selector description.ServerSelector + result DropSearchIndexResult + serverAPI *driver.ServerAPIOptions + timeout *time.Duration } // DropSearchIndexResult represents a dropSearchIndex result returned by the server. type DropSearchIndexResult struct { - Ok int32 + authenticator driver.Authenticator + Ok int32 } func buildDropSearchIndexResult(response bsoncore.Document) (DropSearchIndexResult, error) { @@ -93,6 +95,7 @@ func (dsi *DropSearchIndex) Execute(ctx context.Context) error { Selector: dsi.selector, ServerAPI: dsi.serverAPI, Timeout: dsi.timeout, + Authenticator: dsi.authenticator, }.Execute(ctx) } @@ -212,3 +215,23 @@ func (dsi *DropSearchIndex) Timeout(timeout *time.Duration) *DropSearchIndex { dsi.timeout = timeout return dsi } + +// Authenticator sets the authenticator to use for this operation. +func (d *DropSearchIndex) Authenticator(authenticator driver.Authenticator) *DropSearchIndex { + if d == nil { + d = new(DropSearchIndex) + } + + d.authenticator = authenticator + return d +} + +// Authenticator sets the authenticator to use for this operation. +func (d *DropSearchIndexResult) Authenticator(authenticator driver.Authenticator) *DropSearchIndexResult { + if d == nil { + d = new(DropSearchIndexResult) + } + + d.authenticator = authenticator + return d +} diff --git a/x/mongo/driver/operation/end_sessions.go b/x/mongo/driver/operation/end_sessions.go index 52f300bb7f..dae9165332 100644 --- a/x/mongo/driver/operation/end_sessions.go +++ b/x/mongo/driver/operation/end_sessions.go @@ -20,15 +20,16 @@ import ( // EndSessions performs an endSessions operation. type EndSessions struct { - sessionIDs bsoncore.Document - session *session.Client - clock *session.ClusterClock - monitor *event.CommandMonitor - crypt driver.Crypt - database string - deployment driver.Deployment - selector description.ServerSelector - serverAPI *driver.ServerAPIOptions + authenticator driver.Authenticator + sessionIDs bsoncore.Document + session *session.Client + clock *session.ClusterClock + monitor *event.CommandMonitor + crypt driver.Crypt + database string + deployment driver.Deployment + selector description.ServerSelector + serverAPI *driver.ServerAPIOptions } // NewEndSessions constructs and returns a new EndSessions. @@ -61,6 +62,7 @@ func (es *EndSessions) Execute(ctx context.Context) error { Selector: es.selector, ServerAPI: es.serverAPI, Name: driverutil.EndSessionsOp, + Authenticator: es.authenticator, }.Execute(ctx) } @@ -161,3 +163,13 @@ func (es *EndSessions) ServerAPI(serverAPI *driver.ServerAPIOptions) *EndSession es.serverAPI = serverAPI return es } + +// Authenticator sets the authenticator to use for this operation. +func (e *EndSessions) Authenticator(authenticator driver.Authenticator) *EndSessions { + if e == nil { + e = new(EndSessions) + } + + e.authenticator = authenticator + return e +} diff --git a/x/mongo/driver/operation/find.go b/x/mongo/driver/operation/find.go index 8950fde86d..c71b7d755e 100644 --- a/x/mongo/driver/operation/find.go +++ b/x/mongo/driver/operation/find.go @@ -25,6 +25,7 @@ import ( // Find performs a find operation. type Find struct { + authenticator driver.Authenticator allowDiskUse *bool allowPartialResults *bool awaitData *bool @@ -112,6 +113,7 @@ func (f *Find) Execute(ctx context.Context) error { Logger: f.logger, Name: driverutil.FindOp, OmitCSOTMaxTimeMS: f.omitCSOTMaxTimeMS, + Authenticator: f.authenticator, }.Execute(ctx) } @@ -575,3 +577,13 @@ func (f *Find) Logger(logger *logger.Logger) *Find { f.logger = logger return f } + +// Authenticator sets the authenticator to use for this operation. +func (f *Find) Authenticator(authenticator driver.Authenticator) *Find { + if f == nil { + f = new(Find) + } + + f.authenticator = authenticator + return f +} diff --git a/x/mongo/driver/operation/find_and_modify.go b/x/mongo/driver/operation/find_and_modify.go index 7faf561135..de94abeacf 100644 --- a/x/mongo/driver/operation/find_and_modify.go +++ b/x/mongo/driver/operation/find_and_modify.go @@ -25,6 +25,7 @@ import ( // FindAndModify performs a findAndModify operation. type FindAndModify struct { + authenticator driver.Authenticator arrayFilters bsoncore.Array bypassDocumentValidation *bool collation bsoncore.Document @@ -57,6 +58,7 @@ type FindAndModify struct { // LastErrorObject represents information about updates and upserts returned by the server. type LastErrorObject struct { + authenticator driver.Authenticator // True if an update modified an existing document UpdatedExisting bool // Object ID of the upserted document. @@ -65,6 +67,7 @@ type LastErrorObject struct { // FindAndModifyResult represents a findAndModify result returned by the server. type FindAndModifyResult struct { + authenticator driver.Authenticator // Either the old or modified document, depending on the value of the new parameter. Value bsoncore.Document // Contains information about updates and upserts. @@ -145,6 +148,7 @@ func (fam *FindAndModify) Execute(ctx context.Context) error { ServerAPI: fam.serverAPI, Timeout: fam.timeout, Name: driverutil.FindAndModifyOp, + Authenticator: fam.authenticator, }.Execute(ctx) } @@ -477,3 +481,33 @@ func (fam *FindAndModify) Timeout(timeout *time.Duration) *FindAndModify { fam.timeout = timeout return fam } + +// Authenticator sets the authenticator to use for this operation. +func (f *FindAndModify) Authenticator(authenticator driver.Authenticator) *FindAndModify { + if f == nil { + f = new(FindAndModify) + } + + f.authenticator = authenticator + return f +} + +// Authenticator sets the authenticator to use for this operation. +func (l *LastErrorObject) Authenticator(authenticator driver.Authenticator) *LastErrorObject { + if l == nil { + l = new(LastErrorObject) + } + + l.authenticator = authenticator + return l +} + +// Authenticator sets the authenticator to use for this operation. +func (f *FindAndModifyResult) Authenticator(authenticator driver.Authenticator) *FindAndModifyResult { + if f == nil { + f = new(FindAndModifyResult) + } + + f.authenticator = authenticator + return f +} diff --git a/x/mongo/driver/operation/hello.go b/x/mongo/driver/operation/hello.go index 16f2ebf6c0..472b5b3ce8 100644 --- a/x/mongo/driver/operation/hello.go +++ b/x/mongo/driver/operation/hello.go @@ -36,6 +36,7 @@ const driverName = "mongo-go-driver" // Hello is used to run the handshake operation. type Hello struct { + authenticator driver.Authenticator appname string compressors []string saslSupportedMechs string @@ -201,6 +202,7 @@ func getFaasEnvName() string { } type containerInfo struct { + authenticator driver.Authenticator runtime string orchestrator string } @@ -649,3 +651,25 @@ func (h *Hello) GetHandshakeInformation(ctx context.Context, _ address.Address, func (h *Hello) FinishHandshake(context.Context, driver.Connection) error { return nil } + +// Authenticator sets the authenticator to use for this operation. +func (h *Hello) Authenticator(authenticator driver.Authenticator) *Hello { + if h == nil { + h = new(Hello) + } + + h.authenticator = authenticator + return h +} + + +// Authenticator sets the authenticator to use for this operation. +func (c *containerInfo) Authenticator(authenticator driver.Authenticator) *containerInfo { + if c == nil { + c = new(containerInfo) + } + + c.authenticator = authenticator + return c +} + diff --git a/x/mongo/driver/operation/insert.go b/x/mongo/driver/operation/insert.go index 7da4b8b0fb..a091a8ccfa 100644 --- a/x/mongo/driver/operation/insert.go +++ b/x/mongo/driver/operation/insert.go @@ -25,6 +25,7 @@ import ( // Insert performs an insert operation. type Insert struct { + authenticator driver.Authenticator bypassDocumentValidation *bool comment bsoncore.Value documents []bsoncore.Document @@ -47,6 +48,7 @@ type Insert struct { // InsertResult represents an insert result returned by the server. type InsertResult struct { + authenticator driver.Authenticator // Number of documents successfully inserted. N int64 } @@ -115,6 +117,7 @@ func (i *Insert) Execute(ctx context.Context) error { Timeout: i.timeout, Logger: i.logger, Name: driverutil.InsertOp, + Authenticator: i.authenticator, }.Execute(ctx) } @@ -306,3 +309,23 @@ func (i *Insert) Logger(logger *logger.Logger) *Insert { i.logger = logger return i } + +// Authenticator sets the authenticator to use for this operation. +func (i *Insert) Authenticator(authenticator driver.Authenticator) *Insert { + if i == nil { + i = new(Insert) + } + + i.authenticator = authenticator + return i +} + +// Authenticator sets the authenticator to use for this operation. +func (i *InsertResult) Authenticator(authenticator driver.Authenticator) *InsertResult { + if i == nil { + i = new(InsertResult) + } + + i.authenticator = authenticator + return i +} diff --git a/x/mongo/driver/operation/listDatabases.go b/x/mongo/driver/operation/listDatabases.go index c70248e2a9..829858a644 100644 --- a/x/mongo/driver/operation/listDatabases.go +++ b/x/mongo/driver/operation/listDatabases.go @@ -24,6 +24,7 @@ import ( // ListDatabases performs a listDatabases operation. type ListDatabases struct { + authenticator driver.Authenticator filter bsoncore.Document authorizedDatabases *bool nameOnly *bool @@ -44,6 +45,7 @@ type ListDatabases struct { // ListDatabasesResult represents a listDatabases result returned by the server. type ListDatabasesResult struct { + authenticator driver.Authenticator // An array of documents, one document for each database Databases []databaseRecord // The sum of the size of all the database files on disk in bytes. @@ -51,9 +53,10 @@ type ListDatabasesResult struct { } type databaseRecord struct { - Name string - SizeOnDisk int64 `bson:"sizeOnDisk"` - Empty bool + authenticator driver.Authenticator + Name string + SizeOnDisk int64 `bson:"sizeOnDisk"` + Empty bool } func buildListDatabasesResult(response bsoncore.Document) (ListDatabasesResult, error) { @@ -165,6 +168,7 @@ func (ld *ListDatabases) Execute(ctx context.Context) error { ServerAPI: ld.serverAPI, Timeout: ld.timeout, Name: driverutil.ListDatabasesOp, + Authenticator: ld.authenticator, }.Execute(ctx) } @@ -327,3 +331,33 @@ func (ld *ListDatabases) Timeout(timeout *time.Duration) *ListDatabases { ld.timeout = timeout return ld } + +// Authenticator sets the authenticator to use for this operation. +func (l *ListDatabases) Authenticator(authenticator driver.Authenticator) *ListDatabases { + if l == nil { + l = new(ListDatabases) + } + + l.authenticator = authenticator + return l +} + +// Authenticator sets the authenticator to use for this operation. +func (l *ListDatabasesResult) Authenticator(authenticator driver.Authenticator) *ListDatabasesResult { + if l == nil { + l = new(ListDatabasesResult) + } + + l.authenticator = authenticator + return l +} + +// Authenticator sets the authenticator to use for this operation. +func (d *databaseRecord) Authenticator(authenticator driver.Authenticator) *databaseRecord { + if d == nil { + d = new(databaseRecord) + } + + d.authenticator = authenticator + return d +} diff --git a/x/mongo/driver/operation/list_collections.go b/x/mongo/driver/operation/list_collections.go index 6fe68fa033..da65b99088 100644 --- a/x/mongo/driver/operation/list_collections.go +++ b/x/mongo/driver/operation/list_collections.go @@ -22,6 +22,7 @@ import ( // ListCollections performs a listCollections operation. type ListCollections struct { + authenticator driver.Authenticator filter bsoncore.Document nameOnly *bool authorizedCollections *bool @@ -83,6 +84,7 @@ func (lc *ListCollections) Execute(ctx context.Context) error { ServerAPI: lc.serverAPI, Timeout: lc.timeout, Name: driverutil.ListCollectionsOp, + Authenticator: lc.authenticator, }.Execute(ctx) } @@ -259,3 +261,13 @@ func (lc *ListCollections) Timeout(timeout *time.Duration) *ListCollections { lc.timeout = timeout return lc } + +// Authenticator sets the authenticator to use for this operation. +func (l *ListCollections) Authenticator(authenticator driver.Authenticator) *ListCollections { + if l == nil { + l = new(ListCollections) + } + + l.authenticator = authenticator + return l +} diff --git a/x/mongo/driver/operation/list_indexes.go b/x/mongo/driver/operation/list_indexes.go index 79d50eca95..dc3654c884 100644 --- a/x/mongo/driver/operation/list_indexes.go +++ b/x/mongo/driver/operation/list_indexes.go @@ -21,19 +21,20 @@ import ( // ListIndexes performs a listIndexes operation. type ListIndexes struct { - batchSize *int32 - maxTime *time.Duration - session *session.Client - clock *session.ClusterClock - collection string - monitor *event.CommandMonitor - database string - deployment driver.Deployment - selector description.ServerSelector - retry *driver.RetryMode - crypt driver.Crypt - serverAPI *driver.ServerAPIOptions - timeout *time.Duration + authenticator driver.Authenticator + batchSize *int32 + maxTime *time.Duration + session *session.Client + clock *session.ClusterClock + collection string + monitor *event.CommandMonitor + database string + deployment driver.Deployment + selector description.ServerSelector + retry *driver.RetryMode + crypt driver.Crypt + serverAPI *driver.ServerAPIOptions + timeout *time.Duration result driver.CursorResponse } @@ -85,6 +86,7 @@ func (li *ListIndexes) Execute(ctx context.Context) error { ServerAPI: li.serverAPI, Timeout: li.timeout, Name: driverutil.ListIndexesOp, + Authenticator: li.authenticator, }.Execute(ctx) } @@ -233,3 +235,13 @@ func (li *ListIndexes) Timeout(timeout *time.Duration) *ListIndexes { li.timeout = timeout return li } + +// Authenticator sets the authenticator to use for this operation. +func (l *ListIndexes) Authenticator(authenticator driver.Authenticator) *ListIndexes { + if l == nil { + l = new(ListIndexes) + } + + l.authenticator = authenticator + return l +} diff --git a/x/mongo/driver/operation/update.go b/x/mongo/driver/operation/update.go index 881b1bcf7b..0da93cbf08 100644 --- a/x/mongo/driver/operation/update.go +++ b/x/mongo/driver/operation/update.go @@ -26,6 +26,7 @@ import ( // Update performs an update operation. type Update struct { + authenticator driver.Authenticator bypassDocumentValidation *bool comment bsoncore.Value ordered *bool @@ -51,12 +52,14 @@ type Update struct { // Upsert contains the information for an upsert in an Update operation. type Upsert struct { - Index int64 - ID interface{} `bson:"_id"` + authenticator driver.Authenticator + Index int64 + ID interface{} `bson:"_id"` } // UpdateResult contains information for the result of an Update operation. type UpdateResult struct { + authenticator driver.Authenticator // Number of documents matched. N int64 // Number of documents modified. @@ -167,6 +170,7 @@ func (u *Update) Execute(ctx context.Context) error { Timeout: u.timeout, Logger: u.logger, Name: driverutil.UpdateOp, + Authenticator: u.authenticator, }.Execute(ctx) } @@ -414,3 +418,33 @@ func (u *Update) Logger(logger *logger.Logger) *Update { u.logger = logger return u } + +// Authenticator sets the authenticator to use for this operation. +func (u *Update) Authenticator(authenticator driver.Authenticator) *Update { + if u == nil { + u = new(Update) + } + + u.authenticator = authenticator + return u +} + +// Authenticator sets the authenticator to use for this operation. +func (u *Upsert) Authenticator(authenticator driver.Authenticator) *Upsert { + if u == nil { + u = new(Upsert) + } + + u.authenticator = authenticator + return u +} + +// Authenticator sets the authenticator to use for this operation. +func (u *UpdateResult) Authenticator(authenticator driver.Authenticator) *UpdateResult { + if u == nil { + u = new(UpdateResult) + } + + u.authenticator = authenticator + return u +} diff --git a/x/mongo/driver/operation/update_search_index.go b/x/mongo/driver/operation/update_search_index.go index 64f2da7f6f..963ef1e1a6 100644 --- a/x/mongo/driver/operation/update_search_index.go +++ b/x/mongo/driver/operation/update_search_index.go @@ -21,24 +21,26 @@ import ( // UpdateSearchIndex performs a updateSearchIndex operation. type UpdateSearchIndex struct { - index string - definition bsoncore.Document - session *session.Client - clock *session.ClusterClock - collection string - monitor *event.CommandMonitor - crypt driver.Crypt - database string - deployment driver.Deployment - selector description.ServerSelector - result UpdateSearchIndexResult - serverAPI *driver.ServerAPIOptions - timeout *time.Duration + authenticator driver.Authenticator + index string + definition bsoncore.Document + session *session.Client + clock *session.ClusterClock + collection string + monitor *event.CommandMonitor + crypt driver.Crypt + database string + deployment driver.Deployment + selector description.ServerSelector + result UpdateSearchIndexResult + serverAPI *driver.ServerAPIOptions + timeout *time.Duration } // UpdateSearchIndexResult represents a single index in the updateSearchIndexResult result. type UpdateSearchIndexResult struct { - Ok int32 + authenticator driver.Authenticator + Ok int32 } func buildUpdateSearchIndexResult(response bsoncore.Document) (UpdateSearchIndexResult, error) { @@ -95,6 +97,7 @@ func (usi *UpdateSearchIndex) Execute(ctx context.Context) error { Selector: usi.selector, ServerAPI: usi.serverAPI, Timeout: usi.timeout, + Authenticator: usi.authenticator, }.Execute(ctx) } @@ -225,3 +228,23 @@ func (usi *UpdateSearchIndex) Timeout(timeout *time.Duration) *UpdateSearchIndex usi.timeout = timeout return usi } + +// Authenticator sets the authenticator to use for this operation. +func (u *UpdateSearchIndex) Authenticator(authenticator driver.Authenticator) *UpdateSearchIndex { + if u == nil { + u = new(UpdateSearchIndex) + } + + u.authenticator = authenticator + return u +} + +// Authenticator sets the authenticator to use for this operation. +func (u *UpdateSearchIndexResult) Authenticator(authenticator driver.Authenticator) *UpdateSearchIndexResult { + if u == nil { + u = new(UpdateSearchIndexResult) + } + + u.authenticator = authenticator + return u +} From 4070d06e6e4154e1ad29afa95286cc9cb70004ce Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Thu, 13 Jun 2024 18:51:53 -0400 Subject: [PATCH 18/66] GODRIVER-2911: Thread through Authenticator --- mongo/bulk_write.go | 7 ++++--- mongo/change_stream.go | 3 ++- mongo/client.go | 2 +- mongo/collection.go | 29 +++++++++++++++++------------ mongo/database.go | 11 ++++++----- mongo/index_view.go | 7 ++++--- mongo/search_index_view.go | 6 +++--- mongo/session.go | 5 +++-- 8 files changed, 40 insertions(+), 30 deletions(-) diff --git a/mongo/bulk_write.go b/mongo/bulk_write.go index 3fdb67b9a2..40f1181e0e 100644 --- a/mongo/bulk_write.go +++ b/mongo/bulk_write.go @@ -186,7 +186,7 @@ func (bw *bulkWrite) runInsert(ctx context.Context, batch bulkWriteBatch) (opera Database(bw.collection.db.name).Collection(bw.collection.name). Deployment(bw.collection.client.deployment).Crypt(bw.collection.client.cryptFLE). ServerAPI(bw.collection.client.serverAPI).Timeout(bw.collection.client.timeout). - Logger(bw.collection.client.logger) + Logger(bw.collection.client.logger).Authenticator(bw.collection.client.authenticator) if bw.comment != nil { comment, err := marshalValue(bw.comment, bw.collection.bsonOpts, bw.collection.registry) if err != nil { @@ -256,7 +256,7 @@ func (bw *bulkWrite) runDelete(ctx context.Context, batch bulkWriteBatch) (opera Database(bw.collection.db.name).Collection(bw.collection.name). Deployment(bw.collection.client.deployment).Crypt(bw.collection.client.cryptFLE).Hint(hasHint). ServerAPI(bw.collection.client.serverAPI).Timeout(bw.collection.client.timeout). - Logger(bw.collection.client.logger) + Logger(bw.collection.client.logger).Authenticator(bw.collection.client.authenticator) if bw.comment != nil { comment, err := marshalValue(bw.comment, bw.collection.bsonOpts, bw.collection.registry) if err != nil { @@ -387,7 +387,8 @@ func (bw *bulkWrite) runUpdate(ctx context.Context, batch bulkWriteBatch) (opera Database(bw.collection.db.name).Collection(bw.collection.name). Deployment(bw.collection.client.deployment).Crypt(bw.collection.client.cryptFLE).Hint(hasHint). ArrayFilters(hasArrayFilters).ServerAPI(bw.collection.client.serverAPI). - Timeout(bw.collection.client.timeout).Logger(bw.collection.client.logger) + Timeout(bw.collection.client.timeout).Logger(bw.collection.client.logger). + Authenticator(bw.collection.client.authenticator) if bw.comment != nil { comment, err := marshalValue(bw.comment, bw.collection.bsonOpts, bw.collection.registry) if err != nil { diff --git a/mongo/change_stream.go b/mongo/change_stream.go index 8d0a2031de..3ea8baf1f2 100644 --- a/mongo/change_stream.go +++ b/mongo/change_stream.go @@ -137,7 +137,8 @@ func newChangeStream(ctx context.Context, config changeStreamConfig, pipeline in ReadPreference(config.readPreference).ReadConcern(config.readConcern). Deployment(cs.client.deployment).ClusterClock(cs.client.clock). CommandMonitor(cs.client.monitor).Session(cs.sess).ServerSelector(cs.selector).Retry(driver.RetryNone). - ServerAPI(cs.client.serverAPI).Crypt(config.crypt).Timeout(cs.client.timeout) + ServerAPI(cs.client.serverAPI).Crypt(config.crypt).Timeout(cs.client.timeout). + Authenticator(cs.client.authenticator) if cs.options.Collation != nil { cs.aggregate.Collation(bsoncore.Document(cs.options.Collation.ToDocument())) diff --git a/mongo/client.go b/mongo/client.go index 67ad1d2528..ec72212e87 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -705,7 +705,7 @@ func (c *Client) ListDatabases(ctx context.Context, filter interface{}, opts ... op := operation.NewListDatabases(filterDoc). Session(sess).ReadPreference(c.readPreference).CommandMonitor(c.monitor). ServerSelector(selector).ClusterClock(c.clock).Database("admin").Deployment(c.deployment).Crypt(c.cryptFLE). - ServerAPI(c.serverAPI).Timeout(c.timeout) + ServerAPI(c.serverAPI).Timeout(c.timeout).Authenticator(c.authenticator) if ldo.NameOnly != nil { op = op.NameOnly(*ldo.NameOnly) diff --git a/mongo/collection.go b/mongo/collection.go index 4cf6fd1a1a..dbe238a9e3 100644 --- a/mongo/collection.go +++ b/mongo/collection.go @@ -291,7 +291,8 @@ func (coll *Collection) insert(ctx context.Context, documents []interface{}, ServerSelector(selector).ClusterClock(coll.client.clock). Database(coll.db.name).Collection(coll.name). Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE).Ordered(true). - ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout).Logger(coll.client.logger) + ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout).Logger(coll.client.logger). + Authenticator(coll.client.authenticator) imo := options.MergeInsertManyOptions(opts...) if imo.BypassDocumentValidation != nil && *imo.BypassDocumentValidation { op = op.BypassDocumentValidation(*imo.BypassDocumentValidation) @@ -471,7 +472,8 @@ func (coll *Collection) delete(ctx context.Context, filter interface{}, deleteOn ServerSelector(selector).ClusterClock(coll.client.clock). Database(coll.db.name).Collection(coll.name). Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE).Ordered(true). - ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout).Logger(coll.client.logger) + ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout).Logger(coll.client.logger). + Authenticator(coll.client.authenticator) if do.Comment != nil { comment, err := marshalValue(do.Comment, coll.bsonOpts, coll.registry) if err != nil { @@ -588,7 +590,7 @@ func (coll *Collection) updateOrReplace(ctx context.Context, filter bsoncore.Doc Database(coll.db.name).Collection(coll.name). Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE).Hint(uo.Hint != nil). ArrayFilters(uo.ArrayFilters != nil).Ordered(true).ServerAPI(coll.client.serverAPI). - Timeout(coll.client.timeout).Logger(coll.client.logger) + Timeout(coll.client.timeout).Logger(coll.client.logger).Authenticator(coll.client.authenticator) if uo.Let != nil { let, err := marshal(uo.Let, coll.bsonOpts, coll.registry) if err != nil { @@ -861,7 +863,8 @@ func aggregate(a aggregateParams) (cur *Cursor, err error) { ServerAPI(a.client.serverAPI). HasOutputStage(hasOutputStage). Timeout(a.client.timeout). - MaxTime(ao.MaxTime) + MaxTime(ao.MaxTime). + Authenticator(a.client.authenticator) // Omit "maxTimeMS" from operations that return a user-managed cursor to // prevent confusing "cursor not found" errors. To maintain existing @@ -992,7 +995,7 @@ func (coll *Collection) CountDocuments(ctx context.Context, filter interface{}, op := operation.NewAggregate(pipelineArr).Session(sess).ReadConcern(rc).ReadPreference(coll.readPreference). CommandMonitor(coll.client.monitor).ServerSelector(selector).ClusterClock(coll.client.clock).Database(coll.db.name). Collection(coll.name).Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE).ServerAPI(coll.client.serverAPI). - Timeout(coll.client.timeout).MaxTime(countOpts.MaxTime) + Timeout(coll.client.timeout).MaxTime(countOpts.MaxTime).Authenticator(coll.client.authenticator) if countOpts.Collation != nil { op.Collation(bsoncore.Document(countOpts.Collation.ToDocument())) } @@ -1077,7 +1080,7 @@ func (coll *Collection) EstimatedDocumentCount(ctx context.Context, Database(coll.db.name).Collection(coll.name).CommandMonitor(coll.client.monitor). Deployment(coll.client.deployment).ReadConcern(rc).ReadPreference(coll.readPreference). ServerSelector(selector).Crypt(coll.client.cryptFLE).ServerAPI(coll.client.serverAPI). - Timeout(coll.client.timeout).MaxTime(co.MaxTime) + Timeout(coll.client.timeout).MaxTime(co.MaxTime).Authenticator(coll.client.authenticator) if co.Comment != nil { comment, err := marshalValue(co.Comment, coll.bsonOpts, coll.registry) @@ -1144,7 +1147,7 @@ func (coll *Collection) Distinct(ctx context.Context, fieldName string, filter i Database(coll.db.name).Collection(coll.name).CommandMonitor(coll.client.monitor). Deployment(coll.client.deployment).ReadConcern(rc).ReadPreference(coll.readPreference). ServerSelector(selector).Crypt(coll.client.cryptFLE).ServerAPI(coll.client.serverAPI). - Timeout(coll.client.timeout).MaxTime(option.MaxTime) + Timeout(coll.client.timeout).MaxTime(option.MaxTime).Authenticator(coll.client.authenticator) if option.Collation != nil { op.Collation(bsoncore.Document(option.Collation.ToDocument())) @@ -1257,7 +1260,7 @@ func (coll *Collection) find( ClusterClock(coll.client.clock).Database(coll.db.name).Collection(coll.name). Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE).ServerAPI(coll.client.serverAPI). Timeout(coll.client.timeout).MaxTime(fo.MaxTime).Logger(coll.client.logger). - OmitCSOTMaxTimeMS(omitCSOTMaxTimeMS) + OmitCSOTMaxTimeMS(omitCSOTMaxTimeMS).Authenticator(coll.client.authenticator) cursorOpts := coll.client.createBaseCursorOptions() @@ -1521,7 +1524,7 @@ func (coll *Collection) FindOneAndDelete(ctx context.Context, filter interface{} } fod := options.MergeFindOneAndDeleteOptions(opts...) op := operation.NewFindAndModify(f).Remove(true).ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout). - MaxTime(fod.MaxTime) + MaxTime(fod.MaxTime).Authenticator(coll.client.authenticator) if fod.Collation != nil { op = op.Collation(bsoncore.Document(fod.Collation.ToDocument())) } @@ -1601,7 +1604,8 @@ func (coll *Collection) FindOneAndReplace(ctx context.Context, filter interface{ fo := options.MergeFindOneAndReplaceOptions(opts...) op := operation.NewFindAndModify(f).Update(bsoncore.Value{Type: bsontype.EmbeddedDocument, Data: r}). - ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout).MaxTime(fo.MaxTime) + ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout).MaxTime(fo.MaxTime).Authenticator(coll.client.authenticator) + if fo.BypassDocumentValidation != nil && *fo.BypassDocumentValidation { op = op.BypassDocumentValidation(*fo.BypassDocumentValidation) } @@ -1688,7 +1692,7 @@ func (coll *Collection) FindOneAndUpdate(ctx context.Context, filter interface{} fo := options.MergeFindOneAndUpdateOptions(opts...) op := operation.NewFindAndModify(f).ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout). - MaxTime(fo.MaxTime) + MaxTime(fo.MaxTime).Authenticator(coll.client.authenticator) u, err := marshalUpdateValue(update, coll.bsonOpts, coll.registry, true) if err != nil { @@ -1894,7 +1898,8 @@ func (coll *Collection) drop(ctx context.Context) error { ServerSelector(selector).ClusterClock(coll.client.clock). Database(coll.db.name).Collection(coll.name). Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE). - ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout) + ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout). + Authenticator(coll.client.authenticator) err = op.Execute(ctx) // ignore namespace not found errors diff --git a/mongo/database.go b/mongo/database.go index 0147683396..5344c9641e 100644 --- a/mongo/database.go +++ b/mongo/database.go @@ -189,7 +189,7 @@ func (db *Database) processRunCommand(ctx context.Context, cmd interface{}, ServerSelector(readSelect).ClusterClock(db.client.clock). Database(db.name).Deployment(db.client.deployment). Crypt(db.client.cryptFLE).ReadPreference(ro.ReadPreference).ServerAPI(db.client.serverAPI). - Timeout(db.client.timeout).Logger(db.client.logger).Authenticator(db.Client().authenticator), sess, nil + Timeout(db.client.timeout).Logger(db.client.logger).Authenticator(db.client.authenticator), sess, nil } // RunCommand executes the given command against the database. @@ -308,7 +308,7 @@ func (db *Database) Drop(ctx context.Context) error { Session(sess).WriteConcern(wc).CommandMonitor(db.client.monitor). ServerSelector(selector).ClusterClock(db.client.clock). Database(db.name).Deployment(db.client.deployment).Crypt(db.client.cryptFLE). - ServerAPI(db.client.serverAPI) + ServerAPI(db.client.serverAPI).Authenticator(db.client.authenticator) err = op.Execute(ctx) @@ -402,7 +402,7 @@ func (db *Database) ListCollections(ctx context.Context, filter interface{}, opt Session(sess).ReadPreference(db.readPreference).CommandMonitor(db.client.monitor). ServerSelector(selector).ClusterClock(db.client.clock). Database(db.name).Deployment(db.client.deployment).Crypt(db.client.cryptFLE). - ServerAPI(db.client.serverAPI).Timeout(db.client.timeout) + ServerAPI(db.client.serverAPI).Timeout(db.client.timeout).Authenticator(db.client.authenticator) cursorOpts := db.client.createBaseCursorOptions() @@ -679,7 +679,7 @@ func (db *Database) createCollection(ctx context.Context, name string, opts ...* func (db *Database) createCollectionOperation(name string, opts ...*options.CreateCollectionOptions) (*operation.Create, error) { cco := options.MergeCreateCollectionOptions(opts...) - op := operation.NewCreate(name).ServerAPI(db.client.serverAPI) + op := operation.NewCreate(name).ServerAPI(db.client.serverAPI).Authenticator(db.client.authenticator) if cco.Capped != nil { op.Capped(*cco.Capped) @@ -805,7 +805,8 @@ func (db *Database) CreateView(ctx context.Context, viewName, viewOn string, pip op := operation.NewCreate(viewName). ViewOn(viewOn). Pipeline(pipelineArray). - ServerAPI(db.client.serverAPI) + ServerAPI(db.client.serverAPI). + Authenticator(db.client.authenticator) cvo := options.MergeCreateViewOptions(opts...) if cvo.Collation != nil { op.Collation(bsoncore.Document(cvo.Collation.ToDocument())) diff --git a/mongo/index_view.go b/mongo/index_view.go index 8d3555d0b0..b7e7234339 100644 --- a/mongo/index_view.go +++ b/mongo/index_view.go @@ -94,7 +94,7 @@ func (iv IndexView) List(ctx context.Context, opts ...*options.ListIndexesOption ServerSelector(selector).ClusterClock(iv.coll.client.clock). Database(iv.coll.db.name).Collection(iv.coll.name). Deployment(iv.coll.client.deployment).ServerAPI(iv.coll.client.serverAPI). - Timeout(iv.coll.client.timeout) + Timeout(iv.coll.client.timeout).Authenticator(iv.coll.client.authenticator) cursorOpts := iv.coll.client.createBaseCursorOptions() @@ -262,7 +262,7 @@ func (iv IndexView) CreateMany(ctx context.Context, models []IndexModel, opts .. Session(sess).WriteConcern(wc).ClusterClock(iv.coll.client.clock). Database(iv.coll.db.name).Collection(iv.coll.name).CommandMonitor(iv.coll.client.monitor). Deployment(iv.coll.client.deployment).ServerSelector(selector).ServerAPI(iv.coll.client.serverAPI). - Timeout(iv.coll.client.timeout).MaxTime(option.MaxTime) + Timeout(iv.coll.client.timeout).MaxTime(option.MaxTime).Authenticator(iv.coll.client.authenticator) if option.CommitQuorum != nil { commitQuorum, err := marshalValue(option.CommitQuorum, iv.coll.bsonOpts, iv.coll.registry) if err != nil { @@ -402,7 +402,8 @@ func (iv IndexView) drop(ctx context.Context, name string, opts ...*options.Drop ServerSelector(selector).ClusterClock(iv.coll.client.clock). Database(iv.coll.db.name).Collection(iv.coll.name). Deployment(iv.coll.client.deployment).ServerAPI(iv.coll.client.serverAPI). - Timeout(iv.coll.client.timeout).MaxTime(dio.MaxTime) + Timeout(iv.coll.client.timeout).MaxTime(dio.MaxTime). + Authenticator(iv.coll.client.authenticator) err = op.Execute(ctx) if err != nil { diff --git a/mongo/search_index_view.go b/mongo/search_index_view.go index 73fe8534ed..3253a73a2b 100644 --- a/mongo/search_index_view.go +++ b/mongo/search_index_view.go @@ -143,7 +143,7 @@ func (siv SearchIndexView) CreateMany( ServerSelector(selector).ClusterClock(siv.coll.client.clock). Collection(siv.coll.name).Database(siv.coll.db.name). Deployment(siv.coll.client.deployment).ServerAPI(siv.coll.client.serverAPI). - Timeout(siv.coll.client.timeout) + Timeout(siv.coll.client.timeout).Authenticator(siv.coll.client.authenticator) err = op.Execute(ctx) if err != nil { @@ -198,7 +198,7 @@ func (siv SearchIndexView) DropOne( ServerSelector(selector).ClusterClock(siv.coll.client.clock). Collection(siv.coll.name).Database(siv.coll.db.name). Deployment(siv.coll.client.deployment).ServerAPI(siv.coll.client.serverAPI). - Timeout(siv.coll.client.timeout) + Timeout(siv.coll.client.timeout).Authenticator(siv.coll.client.authenticator) err = op.Execute(ctx) if de, ok := err.(driver.Error); ok && de.NamespaceNotFound() { @@ -252,7 +252,7 @@ func (siv SearchIndexView) UpdateOne( ServerSelector(selector).ClusterClock(siv.coll.client.clock). Collection(siv.coll.name).Database(siv.coll.db.name). Deployment(siv.coll.client.deployment).ServerAPI(siv.coll.client.serverAPI). - Timeout(siv.coll.client.timeout) + Timeout(siv.coll.client.timeout).Authenticator(siv.coll.client.authenticator) return op.Execute(ctx) } diff --git a/mongo/session.go b/mongo/session.go index 8f1e029b95..77be4ab6db 100644 --- a/mongo/session.go +++ b/mongo/session.go @@ -296,7 +296,8 @@ func (s *sessionImpl) AbortTransaction(ctx context.Context) error { _ = operation.NewAbortTransaction().Session(s.clientSession).ClusterClock(s.client.clock).Database("admin"). Deployment(s.deployment).WriteConcern(s.clientSession.CurrentWc).ServerSelector(selector). Retry(driver.RetryOncePerCommand).CommandMonitor(s.client.monitor). - RecoveryToken(bsoncore.Document(s.clientSession.RecoveryToken)).ServerAPI(s.client.serverAPI).Execute(ctx) + RecoveryToken(bsoncore.Document(s.clientSession.RecoveryToken)).ServerAPI(s.client.serverAPI). + Authenticator(s.client.authenticator).Execute(ctx) s.clientSession.Aborting = false _ = s.clientSession.AbortTransaction() @@ -328,7 +329,7 @@ func (s *sessionImpl) CommitTransaction(ctx context.Context) error { Session(s.clientSession).ClusterClock(s.client.clock).Database("admin").Deployment(s.deployment). WriteConcern(s.clientSession.CurrentWc).ServerSelector(selector).Retry(driver.RetryOncePerCommand). CommandMonitor(s.client.monitor).RecoveryToken(bsoncore.Document(s.clientSession.RecoveryToken)). - ServerAPI(s.client.serverAPI).MaxTime(s.clientSession.CurrentMct) + ServerAPI(s.client.serverAPI).MaxTime(s.clientSession.CurrentMct).Authenticator(s.client.authenticator) err = op.Execute(ctx) // Return error without updating transaction state if it is a timeout, as the transaction has not From 4a4409098a626fa0e5ec2374b0062369b9a3dc1a Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Fri, 14 Jun 2024 09:35:29 -0400 Subject: [PATCH 19/66] GODRIVER-2911: Move OIDC back to auth package, yay --- x/mongo/driver/auth/auth.go | 2 +- x/mongo/driver/driver.go | 54 +++++ x/mongo/driver/oidc.go | 445 ------------------------------------ 3 files changed, 55 insertions(+), 446 deletions(-) delete mode 100644 x/mongo/driver/oidc.go diff --git a/x/mongo/driver/auth/auth.go b/x/mongo/driver/auth/auth.go index 1265b8a146..01fa082bbd 100644 --- a/x/mongo/driver/auth/auth.go +++ b/x/mongo/driver/auth/auth.go @@ -33,7 +33,7 @@ func init() { RegisterAuthenticatorFactory(GSSAPI, newGSSAPIAuthenticator) RegisterAuthenticatorFactory(MongoDBX509, newMongoDBX509Authenticator) RegisterAuthenticatorFactory(MongoDBAWS, newMongoDBAWSAuthenticator) - RegisterAuthenticatorFactory(driver.OIDC, driver.NewOIDCAuthenticator) + RegisterAuthenticatorFactory(OIDC, newOIDCAuthenticator) } // CreateAuthenticator creates an authenticator. diff --git a/x/mongo/driver/driver.go b/x/mongo/driver/driver.go index 900729bf87..2556dc06d8 100644 --- a/x/mongo/driver/driver.go +++ b/x/mongo/driver/driver.go @@ -15,6 +15,7 @@ package driver // import "go.mongodb.org/mongo-driver/x/mongo/driver" import ( "context" + "net/http" "time" "go.mongodb.org/mongo-driver/internal/csot" @@ -24,6 +25,59 @@ import ( "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) +// AuthConfig holds the information necessary to perform an authentication attempt. +// this was moved from the auth package to avoid a circular dependency. The auth package +// reexports this under the old name to avoid breaking the public api. +type AuthConfig struct { + Description description.Server + Connection Connection + ClusterClock *session.ClusterClock + HandshakeInfo HandshakeInformation + ServerAPI *ServerAPIOptions + HTTPClient *http.Client + OIDCMachineCallback OIDCCallback + OIDCHumanCallback OIDCCallback +} + +// OIDCCallback is the type for both Human and Machine Callback flows. RefreshToken will always be +// nil in the OIDCArgs for the Machine flow. +type OIDCCallback func(context.Context, *OIDCArgs) (*OIDCCredential, error) + +type OIDCArgs struct { + Version int + IDPInfo *IDPInfo + RefreshToken *string +} + +type OIDCCredential struct { + AccessToken string + ExpiresAt *time.Time + RefreshToken *string +} + +type IDPInfo struct { + Issuer string `bson:"issuer"` + ClientID string `bson:"clientId"` + RequestScopes []string `bson:"requestScopes"` +} + +// Authenticator handles authenticating a connection. The implementors of this interface +// are all in the auth package. +type Authenticator interface { + // Auth authenticates the connection. + Auth(context.Context, *AuthConfig) error + Reauth(context.Context) error +} + +// Cred is a user's credential. +type Cred struct { + Source string + Username string + Password string + PasswordSet bool + Props map[string]string +} + // Deployment is implemented by types that can select a server from a deployment. type Deployment interface { SelectServer(context.Context, description.ServerSelector) (Server, error) diff --git a/x/mongo/driver/oidc.go b/x/mongo/driver/oidc.go deleted file mode 100644 index e06135b33d..0000000000 --- a/x/mongo/driver/oidc.go +++ /dev/null @@ -1,445 +0,0 @@ -package driver - -import ( - "context" - "fmt" - "net/http" - "sync" - "time" - - "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/mongo/description" - "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" - "go.mongodb.org/mongo-driver/x/mongo/driver/session" -) - -const OIDC = "MONGODB-OIDC" -const tokenResourceProp = "TOKEN_RESOURCE" -const environmentProp = "ENVIRONMENT" -const allowedHostsProp = "ALLOWED_HOSTS" -const azureEnvironmentValue = "azure" -const gcpEnvironmentValue = "gcp" -const defaultAuthDB = "admin" -const machineSleepTime = 100 * time.Millisecond - -// Authenticator handles authenticating a connection. -type Authenticator interface { - // Auth authenticates the connection. - Auth(context.Context, *AuthConfig) error - Reauth(context.Context) error -} - -// Cred is a user's credential. -type Cred struct { - Source string - Username string - Password string - PasswordSet bool - Props map[string]string -} - -// OIDCAuthenticator is synchronized and handles caching of the access token, refreshToken, -// and IDPInfo. It also provides a mechanism to refresh the access token, but this functionality -// is only for the OIDC Human flow. -type OIDCAuthenticator struct { - mu sync.Mutex // Guards all of the info in the OIDCAuthenticator struct. - - AuthMechanismProperties map[string]string - - cfg *AuthConfig - accessToken string - refreshToken *string - idpInfo *IDPInfo -} - -func NewOIDCAuthenticator(cred *Cred) (Authenticator, error) { - oa := &OIDCAuthenticator{ - AuthMechanismProperties: cred.Props, - } - return oa, nil -} - -type IDPInfo struct { - Issuer string `bson:"issuer"` - ClientID string `bson:"clientId"` - RequestScopes []string `bson:"requestScopes"` -} - -// OIDCCallback is the type for both Human and Machine Callback flows. RefreshToken will always be -// nil in the OIDCArgs for the Machine flow. -type OIDCCallback func(context.Context, *OIDCArgs) (*OIDCCredential, error) - -type OIDCArgs struct { - Version int - IDPInfo *IDPInfo - RefreshToken *string -} - -type OIDCCredential struct { - AccessToken string - ExpiresAt *time.Time - RefreshToken *string -} - -type oidcOneStep struct { - accessToken string -} - -var _ oidcSaslClient = (*oidcOneStep)(nil) - -func jwtStepRequest(accessToken string) []byte { - return bsoncore.NewDocumentBuilder(). - AppendString("jwt", accessToken). - Build() -} - -func principalStepRequest(principal string) []byte { - doc := bsoncore.NewDocumentBuilder() - if principal != "" { - doc.AppendString("n", principal) - } - return doc.Build() -} - -func (oos *oidcOneStep) Start() (string, []byte, error) { - return OIDC, jwtStepRequest(oos.accessToken), nil -} - -func (oos *oidcOneStep) Next([]byte) ([]byte, error) { - return nil, fmt.Errorf("unexpected step in OIDC machine authentication") -} - -func (*oidcOneStep) Completed() bool { - return true -} - -func (oa *OIDCAuthenticator) providerCallback() (OIDCCallback, error) { - env, ok := oa.AuthMechanismProperties[environmentProp] - if !ok { - return nil, nil - } - - switch env { - // TODO GODRIVER-2728: Automatic token acquisition for Azure Identity Provider - // TODO GODRIVER-2806: Automatic token acquisition for GCP Identity Provider - } - - return nil, fmt.Errorf("%q %q not supported for MONGODB-OIDC", environmentProp, env) -} - -// This should only be called with the Mutex held. -func (oa *OIDCAuthenticator) getAccessToken( - ctx context.Context, - args *OIDCArgs, - callback OIDCCallback, -) (string, error) { - if oa.accessToken != "" { - return oa.accessToken, nil - } - - cred, err := callback(ctx, args) - if err != nil { - return "", err - } - - oa.accessToken = cred.AccessToken - if cred.RefreshToken != nil { - oa.refreshToken = cred.RefreshToken - } - return cred.AccessToken, nil -} - -// This should only be called with the Mutex held. -func (oa *OIDCAuthenticator) getAccessTokenWithRefresh( - ctx context.Context, - callback OIDCCallback, - refreshToken string, -) (string, error) { - - cred, err := callback(ctx, &OIDCArgs{ - Version: 1, - IDPInfo: oa.idpInfo, - RefreshToken: &refreshToken, - }) - if err != nil { - return "", err - } - - oa.accessToken = cred.AccessToken - return cred.AccessToken, nil -} - -// TODO: add invalidation algorithm from rust driver -func (oa *OIDCAuthenticator) invalidateAccessToken() { - oa.accessToken = "" -} - -func (oa *OIDCAuthenticator) Reauth(ctx context.Context) error { - oa.invalidateAccessToken() - return oa.Auth(ctx, oa.cfg) -} - -// Auth authenticates the connection. -func (oa *OIDCAuthenticator) Auth(ctx context.Context, cfg *AuthConfig) error { - // the Mutex must be held during the entire Auth call so that multiple racing attempts - // to authenticate will not result in multiple callbacks. The losers on the Mutex will - // retrieve the access token from the Authenticator cache. - oa.mu.Lock() - defer oa.mu.Unlock() - - oa.cfg = cfg - - if oa.accessToken != "" { - err := conductOIDCSaslConversation(ctx, cfg, "$external", &oidcOneStep{ - accessToken: oa.accessToken, - }) - if err == nil { - return nil - } - // TODO: Check error type and raise if it's not a server-side error. - oa.invalidateAccessToken() - time.Sleep(100 * time.Millisecond) - } - - if cfg.OIDCMachineCallback != nil { - accessToken, err := oa.getAccessToken(ctx, nil, cfg.OIDCMachineCallback) - if err != nil { - return err - } - - err = conductOIDCSaslConversation(ctx, cfg, "$external", &oidcOneStep{ - accessToken: accessToken, - }) - if err == nil { - return nil - } - // Clear the access token if authentication failed. - oa.invalidateAccessToken() - - time.Sleep(machineSleepTime) - accessToken, err = oa.getAccessToken(ctx, &OIDCArgs{Version: 1}, cfg.OIDCMachineCallback) - if err != nil { - return err - } - return conductOIDCSaslConversation(ctx, cfg, "$external", &oidcOneStep{ - accessToken: accessToken, - }) - } - - // TODO GODRIVER-3246: Handle Human callback here. - - callback, err := oa.providerCallback() - if err != nil { - return fmt.Errorf("error getting build-in OIDC provider: %w", err) - } - - accessToken, err := oa.getAccessToken(ctx, &OIDCArgs{Version: 1}, callback) - if err != nil { - return fmt.Errorf("error getting access token from built-in OIDC provider: %w", err) - } - - err = conductOIDCSaslConversation(ctx, cfg, "$external", &oidcOneStep{ - accessToken: accessToken, - }) - // TODO: Check error type and raise if it's not a server-side error. - if err == nil { - return nil - } - oa.invalidateAccessToken() - - return err -} - -// OIDC Sasl. This is almost a verbatim copy of auth/sasl introduced to remove the dependency on auth package -// which causes a circular dependency when attempting to do Reauthentication in driver/operation.go. -// This could be removed with a larger refactor. - -// AuthConfig holds the information necessary to perform an authentication attempt. -// this was moved from the auth package to avoid a circular dependency. The auth package -// reexports this under the old name to avoid breaking the public api. -type AuthConfig struct { - Description description.Server - Connection Connection - ClusterClock *session.ClusterClock - HandshakeInfo HandshakeInformation - ServerAPI *ServerAPIOptions - HTTPClient *http.Client - OIDCMachineCallback OIDCCallback - OIDCHumanCallback OIDCCallback -} - -func newError(err error, mechanism string) error { - return fmt.Errorf("error during %s SASL conversation: %w", OIDC, err) -} - -// oidcSaslClient is the client piece of a sasl conversation. -type oidcSaslClient interface { - Start() (string, []byte, error) - Next(challenge []byte) ([]byte, error) - Completed() bool -} - -// oidcSaslClientCloser is a oidcSaslClient that has resources to clean up. -type oidcSaslClientCloser interface { - oidcSaslClient - Close() -} - -// extraOptionsOIDCSaslClient is a SaslClient that appends options to the saslStart command. -type extraOptionsOIDCSaslClient interface { - StartCommandOptions() bsoncore.Document -} - -// saslConversation represents a SASL conversation. This type implements the SpeculativeConversation interface so the -// conversation can be executed in multi-step speculative fashion. -type saslConversation struct { - client oidcSaslClient - source string - mechanism string - speculative bool -} - -func newSaslConversation(client oidcSaslClient, source string, speculative bool) *saslConversation { - authSource := source - if authSource == "" { - authSource = defaultAuthDB - } - return &saslConversation{ - client: client, - source: authSource, - speculative: speculative, - } -} - -// firstMessage returns the first message to be sent to the server. This message contains a "db" field so it can be used -// for speculative authentication. -func (sc *saslConversation) firstMessage() (bsoncore.Document, error) { - var payload []byte - var err error - sc.mechanism, payload, err = sc.client.Start() - if err != nil { - return nil, err - } - - saslCmdElements := [][]byte{ - bsoncore.AppendInt32Element(nil, "saslStart", 1), - bsoncore.AppendStringElement(nil, "mechanism", sc.mechanism), - bsoncore.AppendBinaryElement(nil, "payload", 0x00, payload), - } - if sc.speculative { - // The "db" field is only appended for speculative auth because the hello command is executed against admin - // so this is needed to tell the server the user's auth source. For a non-speculative attempt, the SASL commands - // will be executed against the auth source. - saslCmdElements = append(saslCmdElements, bsoncore.AppendStringElement(nil, "db", sc.source)) - } - if extraOptionsClient, ok := sc.client.(extraOptionsOIDCSaslClient); ok { - optionsDoc := extraOptionsClient.StartCommandOptions() - saslCmdElements = append(saslCmdElements, bsoncore.AppendDocumentElement(nil, "options", optionsDoc)) - } - - return bsoncore.BuildDocumentFromElements(nil, saslCmdElements...), nil -} - -type saslResponse struct { - ConversationID int `bson:"conversationId"` - Code int `bson:"code"` - Done bool `bson:"done"` - Payload []byte `bson:"payload"` -} - -// finish completes the conversation based on the first server response to authenticate the given connection. -func (sc *saslConversation) finish(ctx context.Context, cfg *AuthConfig, firstResponse bsoncore.Document) error { - if closer, ok := sc.client.(oidcSaslClientCloser); ok { - defer closer.Close() - } - - var saslResp saslResponse - err := bson.Unmarshal(firstResponse, &saslResp) - if err != nil { - fullErr := fmt.Errorf("unmarshal error: %w", err) - return newError(fullErr, sc.mechanism) - } - - cid := saslResp.ConversationID - var payload []byte - var result bsoncore.Document - for { - if saslResp.Code != 0 { - return newError(err, sc.mechanism) - } - - if saslResp.Done && sc.client.Completed() { - return nil - } - - payload, err = sc.client.Next(saslResp.Payload) - if err != nil { - return newError(err, sc.mechanism) - } - - if saslResp.Done && sc.client.Completed() { - return nil - } - - doc := bsoncore.BuildDocumentFromElements(nil, - bsoncore.AppendInt32Element(nil, "saslContinue", 1), - bsoncore.AppendInt32Element(nil, "conversationId", int32(cid)), - bsoncore.AppendBinaryElement(nil, "payload", 0x00, payload), - ) - - saslOp := Operation{ - CommandFn: func(dst []byte, desc description.SelectedServer) ([]byte, error) { - return append(dst, doc[4:len(doc)-1]...), nil - }, - ProcessResponseFn: func(info ResponseInfo) error { - result = info.ServerResponse - return nil - }, - Deployment: Deployment(SingleConnectionDeployment{cfg.Connection}), - Database: sc.source, - Clock: cfg.ClusterClock, - ServerAPI: cfg.ServerAPI, - } - - err = saslOp.Execute(ctx) - if err != nil { - return newError(err, sc.mechanism) - } - err = bson.Unmarshal(result, &saslResp) - if err != nil { - fullErr := fmt.Errorf("unmarshal error: %w", err) - return newError(fullErr, sc.mechanism) - } - return nil - } -} - -// conductOIDCSaslConversation runs a full SASL conversation to authenticate the given connection. -func conductOIDCSaslConversation(ctx context.Context, cfg *AuthConfig, authSource string, client oidcSaslClient) error { - // Create a non-speculative SASL conversation. - conversation := newSaslConversation(client, authSource, false) - - doc, err := conversation.firstMessage() - if err != nil { - return newError(err, conversation.mechanism) - } - var result bsoncore.Document - saslOp := Operation{ - CommandFn: func(dst []byte, desc description.SelectedServer) ([]byte, error) { - return append(dst, doc[4:len(doc)-1]...), nil - }, - ProcessResponseFn: func(info ResponseInfo) error { - result = info.ServerResponse - return nil - }, - Deployment: Deployment(SingleConnectionDeployment{cfg.Connection}), - Database: authSource, - Clock: cfg.ClusterClock, - ServerAPI: cfg.ServerAPI, - } - if err := saslOp.Execute(ctx); err != nil { - return newError(err, conversation.mechanism) - } - - return conversation.finish(ctx, cfg, result) -} From 4ea9b9c2a44b32bf9c4e31f6f380629917ca0f1f Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Fri, 14 Jun 2024 09:41:02 -0400 Subject: [PATCH 20/66] GODRIVER-2911: Move Config = AuthConfig to top of the file --- x/mongo/driver/auth/auth.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/x/mongo/driver/auth/auth.go b/x/mongo/driver/auth/auth.go index 01fa082bbd..ba2d296b11 100644 --- a/x/mongo/driver/auth/auth.go +++ b/x/mongo/driver/auth/auth.go @@ -19,6 +19,8 @@ import ( "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) +type Config = driver.AuthConfig + // AuthenticatorFactory constructs an authenticator. type AuthenticatorFactory func(cred *Cred) (Authenticator, error) @@ -115,8 +117,6 @@ func (ah *authHandshaker) GetHandshakeInformation(ctx context.Context, addr addr return ah.handshakeInfo, nil } -type Config = driver.AuthConfig - // FinishHandshake performs authentication for conn if necessary. func (ah *authHandshaker) FinishHandshake(ctx context.Context, conn driver.Connection) error { performAuth := ah.options.PerformAuthentication From 2b5cde63e2f5d2b7f6ee02e9de128cf4b1d10cc4 Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Fri, 14 Jun 2024 09:43:57 -0400 Subject: [PATCH 21/66] GODRIVER-2911: Update comment --- x/mongo/driver/driver.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/x/mongo/driver/driver.go b/x/mongo/driver/driver.go index 2556dc06d8..299eb67142 100644 --- a/x/mongo/driver/driver.go +++ b/x/mongo/driver/driver.go @@ -62,7 +62,9 @@ type IDPInfo struct { } // Authenticator handles authenticating a connection. The implementors of this interface -// are all in the auth package. +// are all in the auth package. Most authentication mechanisms to not allow for Reauth, +// but this is included in the interface so that whenever a new mechanism is added, it +// must be explicitly considered. type Authenticator interface { // Auth authenticates the connection. Auth(context.Context, *AuthConfig) error From 368cedd5b8a8b9eadfcefb1ced195de00e1f13b3 Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Fri, 14 Jun 2024 10:33:30 -0400 Subject: [PATCH 22/66] GODRIVER-2911: Some implementation --- x/mongo/driver/auth/auth.go | 2 +- x/mongo/driver/driver.go | 1 + x/mongo/driver/operation.go | 12 ++++++++++-- 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/x/mongo/driver/auth/auth.go b/x/mongo/driver/auth/auth.go index ba2d296b11..c5bcd90f75 100644 --- a/x/mongo/driver/auth/auth.go +++ b/x/mongo/driver/auth/auth.go @@ -35,7 +35,7 @@ func init() { RegisterAuthenticatorFactory(GSSAPI, newGSSAPIAuthenticator) RegisterAuthenticatorFactory(MongoDBX509, newMongoDBX509Authenticator) RegisterAuthenticatorFactory(MongoDBAWS, newMongoDBAWSAuthenticator) - RegisterAuthenticatorFactory(OIDC, newOIDCAuthenticator) + RegisterAuthenticatorFactory(MongoDBOIDC, newOIDCAuthenticator) } // CreateAuthenticator creates an authenticator. diff --git a/x/mongo/driver/driver.go b/x/mongo/driver/driver.go index 299eb67142..28dadcb286 100644 --- a/x/mongo/driver/driver.go +++ b/x/mongo/driver/driver.go @@ -45,6 +45,7 @@ type OIDCCallback func(context.Context, *OIDCArgs) (*OIDCCredential, error) type OIDCArgs struct { Version int + Timeout time.Time IDPInfo *IDPInfo RefreshToken *string } diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index 25b78893e1..9ce319d8c9 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -916,12 +916,20 @@ func (op Operation) Execute(ctx context.Context) error { operationErr.Labels = tt.Labels operationErr.Raw = tt.Raw case Error: - // TODO: actually make sure this Reauths + // 391 is the reauthentication required error code, so we will attempt a reauth and + // retry the operation, if it is successful. if tt.Code == 391 { if op.Authenticator != nil { if err := op.Authenticator.Reauth(ctx); err != nil { - return err + return fmt.Errorf("error reauthenticating: %w", err) } + if op.Client != nil && op.Client.Committing { + // Apply majority write concern for retries + op.Client.UpdateCommitTransactionWriteConcern() + op.WriteConcern = op.Client.CurrentWc + } + resetForRetry(tt) + continue } } if tt.HasErrorLabel(TransientTransactionError) || tt.HasErrorLabel(UnknownTransactionCommitResult) { From e00e05715726355f793201672a69262e2c105b05 Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Fri, 14 Jun 2024 10:51:03 -0400 Subject: [PATCH 23/66] GODRIVER-2911: Add OIDCTokenGenID to Connection interface --- mongo/integration/mtest/opmsg_deployment.go | 7 +++++++ x/mongo/driver/driver.go | 2 ++ x/mongo/driver/session/client_session.go | 2 ++ x/mongo/driver/topology/connection.go | 22 +++++++++++++++++++++ 4 files changed, 33 insertions(+) diff --git a/mongo/integration/mtest/opmsg_deployment.go b/mongo/integration/mtest/opmsg_deployment.go index 2215f84b38..2ddc23c413 100644 --- a/mongo/integration/mtest/opmsg_deployment.go +++ b/mongo/integration/mtest/opmsg_deployment.go @@ -61,6 +61,13 @@ func (c *connection) WriteWireMessage(context.Context, []byte) error { return nil } +func (c *connection) OIDCTokenGenID() uint64 { + return 0 +} + +func (c *connection) SetOIDCTokenGenID(uint64) { +} + // ReadWireMessage returns the next response in the connection's list of responses. func (c *connection) ReadWireMessage(_ context.Context) ([]byte, error) { var dst []byte diff --git a/x/mongo/driver/driver.go b/x/mongo/driver/driver.go index 28dadcb286..3758f29972 100644 --- a/x/mongo/driver/driver.go +++ b/x/mongo/driver/driver.go @@ -136,6 +136,8 @@ type Connection interface { DriverConnectionID() uint64 // TODO(GODRIVER-2824): change type to int64. Address() address.Address Stale() bool + OIDCTokenGenID() uint64 + SetOIDCTokenGenID(uint64) } // RTTMonitor represents a round-trip-time monitor. diff --git a/x/mongo/driver/session/client_session.go b/x/mongo/driver/session/client_session.go index 8dac0932de..4a6be9c5e4 100644 --- a/x/mongo/driver/session/client_session.go +++ b/x/mongo/driver/session/client_session.go @@ -90,6 +90,8 @@ type LoadBalancedTransactionConnection interface { DriverConnectionID() uint64 // TODO(GODRIVER-2824): change type to int64. Address() address.Address Stale() bool + OIDCTokenGenID() uint64 + SetOIDCTokenGenID(uint64) // Functions copied over from driver.PinnedConnection that are not part of Connection or Expirable. PinToCursor() error diff --git a/x/mongo/driver/topology/connection.go b/x/mongo/driver/topology/connection.go index 649e87b3d1..acdd7973ea 100644 --- a/x/mongo/driver/topology/connection.go +++ b/x/mongo/driver/topology/connection.go @@ -82,6 +82,10 @@ type connection struct { // awaitingResponse indicates that the server response was not completely // read before returning the connection to the pool. awaitingResponse bool + + // oidcTokenGenID is the monotonic generation ID for OIDC tokens, used to invalidate + // accessTokens in the OIDC authenticator cache. + oidcTokenGenID uint64 } // newConnection handles the creation of a connection. It does not connect the connection. @@ -606,6 +610,8 @@ type Connection struct { refCount int cleanupPoolFn func() + oidcTokenGenID uint64 + // cleanupServerFn resets the server state when a connection is returned to the connection pool // via Close() or expired via Expire(). cleanupServerFn func() @@ -860,6 +866,14 @@ func configureTLS(ctx context.Context, return client, nil } +func (c *Connection) OIDCTokenGenID() uint64 { + return c.oidcTokenGenID +} + +func (c *Connection) SetOIDCTokenGenID(genId uint64) { + c.oidcTokenGenID = genId +} + // TODO: Naming? // cancellListener listens for context cancellation and notifies listeners via a @@ -903,3 +917,11 @@ func (c *cancellListener) StopListening() bool { c.done <- struct{}{} return c.aborted } + +func (c *connection) OIDCTokenGenID() uint64 { + return c.oidcTokenGenID +} + +func (c *connection) SetOIDCTokenGenID(genId uint64) { + c.oidcTokenGenID = genId +} From 1666c6c2c8bb2846abbee4e8acd4c1bd33fdaff2 Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Fri, 14 Jun 2024 10:53:56 -0400 Subject: [PATCH 24/66] GODRIVER-2911: Add OIDCTokenGenID to Connection interface for types in test packages --- x/mongo/driver/drivertest/channel_conn.go | 7 +++++++ x/mongo/driver/operation_test.go | 2 ++ 2 files changed, 9 insertions(+) diff --git a/x/mongo/driver/drivertest/channel_conn.go b/x/mongo/driver/drivertest/channel_conn.go index 27be4c264d..874f100ef1 100644 --- a/x/mongo/driver/drivertest/channel_conn.go +++ b/x/mongo/driver/drivertest/channel_conn.go @@ -26,6 +26,13 @@ type ChannelConn struct { Desc description.Server } +func (c *ChannelConn) OIDCTokenGenID() uint64 { + return 0 +} + +func (c *ChannelConn) SetOIDCTokenGenID(uint64) { +} + // WriteWireMessage implements the driver.Connection interface. func (c *ChannelConn) WriteWireMessage(ctx context.Context, wm []byte) error { // Copy wm in case it came from a buffer pool. diff --git a/x/mongo/driver/operation_test.go b/x/mongo/driver/operation_test.go index 6445c9d0f6..27ef3a090d 100644 --- a/x/mongo/driver/operation_test.go +++ b/x/mongo/driver/operation_test.go @@ -789,6 +789,8 @@ func (m *mockConnection) SupportsStreaming() bool { return m.rCanStream func (m *mockConnection) CurrentlyStreaming() bool { return m.rStreaming } func (m *mockConnection) SetStreaming(streaming bool) { m.rStreaming = streaming } func (m *mockConnection) Stale() bool { return false } +func (m *mockConnection) OIDCTokenGenID() uint64 { return 0 } +func (m *mockConnection) SetOIDCTokenGenID(uint64) {} // TODO:(GODRIVER-2824) replace return type with int64. func (m *mockConnection) DriverConnectionID() uint64 { return 0 } From d90ee3fe2244acf369ef782bf8e187402f2a395c Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Fri, 14 Jun 2024 11:17:53 -0400 Subject: [PATCH 25/66] GODRIVER-2911: Actually add the oidc file, whoops --- x/mongo/driver/auth/oidc.go | 261 ++++++++++++++++++++++++++++++++++++ 1 file changed, 261 insertions(+) create mode 100644 x/mongo/driver/auth/oidc.go diff --git a/x/mongo/driver/auth/oidc.go b/x/mongo/driver/auth/oidc.go new file mode 100644 index 0000000000..3d212a9bc6 --- /dev/null +++ b/x/mongo/driver/auth/oidc.go @@ -0,0 +1,261 @@ +package auth + +import ( + "context" + "fmt" + "sync" + "time" + + "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/x/mongo/driver" +) + +const MongoDBOIDC = "MONGODB-OIDC" +const tokenResourceProp = "TOKEN_RESOURCE" +const environmentProp = "ENVIRONMENT" +const allowedHostsProp = "ALLOWED_HOSTS" + +const azureEnvironmentValue = "azure" +const gcpEnvironmentValue = "gcp" + +const apiVersion = 1 +const invalidateSleepTimeout = 100 * time.Millisecond +const machineCallbackTimeout = 60 * time.Second + +var defaultAllowedHosts = []string{ + "*.mongodb.net", + "*.mongodb-qa.net", + "*.mongodb-dev.net", + "*.mongodbgov.net", + "localhost", + "127.0.0.1", + "::1", +} + +type OIDCCallback = driver.OIDCCallback +type OIDCArgs = driver.OIDCArgs +type OIDCCredential = driver.OIDCCredential +type IDPInfo = driver.IDPInfo + +var _ driver.Authenticator = (*OIDCAuthenticator)(nil) +var _ SpeculativeAuthenticator = (*OIDCAuthenticator)(nil) +var _ SaslClient = (*oidcOneStep)(nil) + +// OIDCAuthenticator is synchronized and handles caching of the access token, refreshToken, +// and IDPInfo. It also provides a mechanism to refresh the access token, but this functionality +// is only for the OIDC Human flow. +type OIDCAuthenticator struct { + mu sync.Mutex // Guards all of the info in the OIDCAuthenticator struct. + + AuthMechanismProperties map[string]string + + cfg *Config + accessToken string + refreshToken *string + idpInfo *IDPInfo + tokenGenID uint64 +} + +func newOIDCAuthenticator(cred *Cred) (Authenticator, error) { + oa := &OIDCAuthenticator{ + AuthMechanismProperties: cred.Props, + } + return oa, nil +} + +type oidcOneStep struct { + accessToken string +} + +func jwtStepRequest(accessToken string) []byte { + return bsoncore.NewDocumentBuilder(). + AppendString("jwt", accessToken). + Build() +} + +func principalStepRequest(principal string) []byte { + doc := bsoncore.NewDocumentBuilder() + if principal != "" { + doc.AppendString("n", principal) + } + return doc.Build() +} + +func (oos *oidcOneStep) Start() (string, []byte, error) { + return MongoDBOIDC, jwtStepRequest(oos.accessToken), nil +} + +func (oos *oidcOneStep) Next([]byte) ([]byte, error) { + return nil, newAuthError("unexpected step in OIDC authentication", nil) +} + +func (*oidcOneStep) Completed() bool { + return true +} + +func (oa *OIDCAuthenticator) providerCallback() (OIDCCallback, error) { + env, ok := oa.AuthMechanismProperties[environmentProp] + if !ok { + return nil, nil + } + + switch env { + // TODO GODRIVER-2728: Automatic token acquisition for Azure Identity Provider + // TODO GODRIVER-2806: Automatic token acquisition for GCP Identity Provider + } + + return nil, fmt.Errorf("%q %q not supported for MONGODB-OIDC", environmentProp, env) +} + +// This should only be called with the Mutex held. +func (oa *OIDCAuthenticator) getAccessToken( + ctx context.Context, + args *OIDCArgs, + callback OIDCCallback, +) (string, error) { + if oa.accessToken != "" { + return oa.accessToken, nil + } + + cred, err := callback(ctx, args) + if err != nil { + return "", err + } + + oa.accessToken = cred.AccessToken + oa.tokenGenID += 1 + oa.cfg.Connection.SetOIDCTokenGenID(oa.tokenGenID) + if cred.RefreshToken != nil { + oa.refreshToken = cred.RefreshToken + } + return cred.AccessToken, nil +} + +// This should only be called with the Mutex held. +func (oa *OIDCAuthenticator) getAccessTokenWithRefresh( + ctx context.Context, + callback OIDCCallback, + refreshToken string, +) (string, error) { + + cred, err := callback(ctx, &OIDCArgs{ + Version: 1, + IDPInfo: oa.idpInfo, + RefreshToken: &refreshToken, + }) + if err != nil { + return "", err + } + + oa.accessToken = cred.AccessToken + oa.tokenGenID += 1 + oa.cfg.Connection.SetOIDCTokenGenID(oa.tokenGenID) + return cred.AccessToken, nil +} + +// invalidateAccessToken invalidates the access token, if the force flag is set to true (which is +// only on a Reauth call) or if the tokenGenID of the connection is greater than or equal to the +// tokenGenID of the OIDCAuthenticator. It should never actually be greater than, but only equal, +// but this is a safety check, since extra invalidation is only a performance impact, not a +// correctness impact. +func (oa *OIDCAuthenticator) invalidateAccessToken(force bool) { + oa.mu.Lock() + defer oa.mu.Unlock() + tokenGenID := oa.cfg.Connection.OIDCTokenGenID() + if force || tokenGenID >= oa.tokenGenID { + oa.accessToken = "" + oa.cfg.Connection.SetOIDCTokenGenID(0) + } +} + +func (oa *OIDCAuthenticator) Reauth(ctx context.Context) error { + oa.invalidateAccessToken(true) + // it should be impossible to get a Reauth when an Auth has never occurred, + // so we assume cfg was properly set. There is nothing to enforce this, however, + // other than the current driver code flow. If cfg is nil, Auth will return an error. + return oa.Auth(ctx, oa.cfg) +} + +// Auth authenticates the connection. +func (oa *OIDCAuthenticator) Auth(ctx context.Context, cfg *Config) error { + // the Mutex must be held during the entire Auth call so that multiple racing attempts + // to authenticate will not result in multiple callbacks. The losers on the Mutex will + // retrieve the access token from the Authenticator cache. + oa.mu.Lock() + defer oa.mu.Unlock() + var err error + + if cfg == nil { + return newAuthError(fmt.Sprintf("config must be set for %q authentication", MongoDBOIDC), nil) + } + oa.cfg = cfg + + if oa.accessToken != "" { + err = ConductSaslConversation(ctx, cfg, "$external", &oidcOneStep{ + accessToken: oa.accessToken, + }) + if err == nil { + return nil + } + oa.invalidateAccessToken(false) + time.Sleep(invalidateSleepTimeout) + } + + if cfg.OIDCHumanCallback != nil { + return oa.doAuthHuman(ctx, cfg, cfg.OIDCHumanCallback) + } + + // Handle user provided or automatic provider machine callback. + var machineCallback OIDCCallback + if cfg.OIDCMachineCallback != nil { + machineCallback = cfg.OIDCMachineCallback + } else { + machineCallback, err = oa.providerCallback() + if err != nil { + return fmt.Errorf("error getting built-in OIDC provider: %w", err) + } + } + + if machineCallback != nil { + return oa.doAuthMachine(ctx, cfg, machineCallback) + } + return newAuthError("no OIDC callback provided", nil) +} + +func (oa *OIDCAuthenticator) doAuthHuman(ctx context.Context, cfg *Config, humanCallback OIDCCallback) error { + // TODO GODRIVER-3246: Implement OIDC human flow + return newAuthError("OIDC human flow not implemented yet", nil) +} + +func (oa *OIDCAuthenticator) doAuthMachine(ctx context.Context, cfg *Config, machineCallback OIDCCallback) error { + accessToken, err := oa.getAccessToken(ctx, + &OIDCArgs{Version: 1, + Timeout: time.Now().Add(machineCallbackTimeout), + // idpInfo is nil for machine callbacks in the current spec. + IDPInfo: nil, + RefreshToken: nil, + }, + machineCallback) + if err != nil { + return err + } + err = ConductSaslConversation(ctx, cfg, "$external", &oidcOneStep{ + accessToken: accessToken, + }) + if err == nil { + return nil + } + return nil +} + +// CreateSpeculativeConversation creates a speculative conversation for SCRAM authentication. +func (oa *OIDCAuthenticator) CreateSpeculativeConversation() (SpeculativeConversation, error) { + oa.mu.Lock() + defer oa.mu.Unlock() + accessToken := oa.accessToken + if accessToken == "" { + return nil, nil // Skip speculative auth. + } + + return newSaslConversation(&oidcOneStep{accessToken: accessToken}, "$external", true), nil +} From 19ed2613023dbbe17778c903cb015c57f36d62a4 Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Fri, 14 Jun 2024 11:31:58 -0400 Subject: [PATCH 26/66] GODRIVER-2911: Fix nil pointer error --- mongo/client.go | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/mongo/client.go b/mongo/client.go index ec72212e87..cda768cf0d 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -215,14 +215,16 @@ func NewClient(opts ...*options.ClientOptions) (*Client, error) { return nil, err } - // Create an authenticator for the client - client.authenticator, err = auth.CreateAuthenticator(clientOpt.Auth.AuthMechanism, &auth.Cred{ - Source: clientOpt.Auth.AuthSource, - Username: clientOpt.Auth.Username, - Password: clientOpt.Auth.Password, - PasswordSet: clientOpt.Auth.PasswordSet, - Props: clientOpt.Auth.AuthMechanismProperties, - }) + if clientOpt.Auth != nil { + // Create an authenticator for the client + client.authenticator, err = auth.CreateAuthenticator(clientOpt.Auth.AuthMechanism, &auth.Cred{ + Source: clientOpt.Auth.AuthSource, + Username: clientOpt.Auth.Username, + Password: clientOpt.Auth.Password, + PasswordSet: clientOpt.Auth.PasswordSet, + Props: clientOpt.Auth.AuthMechanismProperties, + }) + } cfg, err := topology.NewConfigWithAuthenticator(clientOpt, client.clock, client.authenticator) if err != nil { From 41122087ee5aa0ce41662dc6eba24c702322b131 Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Fri, 14 Jun 2024 11:59:28 -0400 Subject: [PATCH 27/66] GODRIVER-2911: Fix fmt --- x/mongo/driver/operation/command.go | 3 +-- x/mongo/driver/operation/count.go | 6 ++---- x/mongo/driver/operation/create.go | 5 ++--- x/mongo/driver/operation/hello.go | 8 +++----- 4 files changed, 8 insertions(+), 14 deletions(-) diff --git a/x/mongo/driver/operation/command.go b/x/mongo/driver/operation/command.go index 92fb250cf0..9dd10f3cb0 100644 --- a/x/mongo/driver/operation/command.go +++ b/x/mongo/driver/operation/command.go @@ -22,7 +22,7 @@ import ( // Command is used to run a generic operation. type Command struct { - authenticator driver.Authenticator + authenticator driver.Authenticator command bsoncore.Document database string deployment driver.Deployment @@ -231,4 +231,3 @@ func (c *Command) Authenticator(authenticator driver.Authenticator) *Command { c.authenticator = authenticator return c } - diff --git a/x/mongo/driver/operation/count.go b/x/mongo/driver/operation/count.go index 0729db996b..462236059d 100644 --- a/x/mongo/driver/operation/count.go +++ b/x/mongo/driver/operation/count.go @@ -25,7 +25,7 @@ import ( // Count represents a count operation. type Count struct { - authenticator driver.Authenticator + authenticator driver.Authenticator maxTime *time.Duration query bsoncore.Document session *session.Client @@ -130,7 +130,7 @@ func (c *Count) Execute(ctx context.Context) error { ServerAPI: c.serverAPI, Timeout: c.timeout, Name: driverutil.CountOp, - Authenticator: c.authenticator, + Authenticator: c.authenticator, }.Execute(ctx) // Swallow error if NamespaceNotFound(26) is returned from aggregate on non-existent namespace @@ -325,7 +325,6 @@ func (c *Count) Authenticator(authenticator driver.Authenticator) *Count { return c } - // Authenticator sets the authenticator to use for this operation. func (c *CountResult) Authenticator(authenticator driver.Authenticator) *CountResult { if c == nil { @@ -335,4 +334,3 @@ func (c *CountResult) Authenticator(authenticator driver.Authenticator) *CountRe c.authenticator = authenticator return c } - diff --git a/x/mongo/driver/operation/create.go b/x/mongo/driver/operation/create.go index 394d47676b..4878e2c777 100644 --- a/x/mongo/driver/operation/create.go +++ b/x/mongo/driver/operation/create.go @@ -20,7 +20,7 @@ import ( // Create represents a create operation. type Create struct { - authenticator driver.Authenticator + authenticator driver.Authenticator capped *bool collation bsoncore.Document changeStreamPreAndPostImages bsoncore.Document @@ -78,7 +78,7 @@ func (c *Create) Execute(ctx context.Context) error { Selector: c.selector, WriteConcern: c.writeConcern, ServerAPI: c.serverAPI, - Authenticator: c.authenticator, + Authenticator: c.authenticator, }.Execute(ctx) } @@ -411,4 +411,3 @@ func (c *Create) Authenticator(authenticator driver.Authenticator) *Create { c.authenticator = authenticator return c } - diff --git a/x/mongo/driver/operation/hello.go b/x/mongo/driver/operation/hello.go index 472b5b3ce8..2757dc0a1f 100644 --- a/x/mongo/driver/operation/hello.go +++ b/x/mongo/driver/operation/hello.go @@ -36,7 +36,7 @@ const driverName = "mongo-go-driver" // Hello is used to run the handshake operation. type Hello struct { - authenticator driver.Authenticator + authenticator driver.Authenticator appname string compressors []string saslSupportedMechs string @@ -203,8 +203,8 @@ func getFaasEnvName() string { type containerInfo struct { authenticator driver.Authenticator - runtime string - orchestrator string + runtime string + orchestrator string } // getContainerEnvInfo returns runtime and orchestrator of a container. @@ -662,7 +662,6 @@ func (h *Hello) Authenticator(authenticator driver.Authenticator) *Hello { return h } - // Authenticator sets the authenticator to use for this operation. func (c *containerInfo) Authenticator(authenticator driver.Authenticator) *containerInfo { if c == nil { @@ -672,4 +671,3 @@ func (c *containerInfo) Authenticator(authenticator driver.Authenticator) *conta c.authenticator = authenticator return c } - From 03c4c0897e97263f35ecad2131b39c2ae370aedb Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Fri, 14 Jun 2024 12:01:02 -0400 Subject: [PATCH 28/66] GODRIVER-2911: Fix build failure --- x/mongo/driver/auth/gssapi.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/x/mongo/driver/auth/gssapi.go b/x/mongo/driver/auth/gssapi.go index 4b860ba63f..413a2e5661 100644 --- a/x/mongo/driver/auth/gssapi.go +++ b/x/mongo/driver/auth/gssapi.go @@ -57,3 +57,8 @@ func (a *GSSAPIAuthenticator) Auth(ctx context.Context, cfg *Config) error { } return ConductSaslConversation(ctx, cfg, "$external", client) } + +// Reauth reauthenticates the connection. +func (a *GSSAPIAuthenticator) Reauth(ctx context.Context, cfg *Config) error { + return newAuthError("GSSAPI does not support reauthentication", nil) +} From dac0468cfac3b2204de6ac535a4cb1af755b9d5f Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Fri, 14 Jun 2024 16:00:53 -0400 Subject: [PATCH 29/66] GODRIVER-2911: well, that was silly --- x/mongo/driver/auth/gssapi.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x/mongo/driver/auth/gssapi.go b/x/mongo/driver/auth/gssapi.go index 413a2e5661..45c0547e55 100644 --- a/x/mongo/driver/auth/gssapi.go +++ b/x/mongo/driver/auth/gssapi.go @@ -59,6 +59,6 @@ func (a *GSSAPIAuthenticator) Auth(ctx context.Context, cfg *Config) error { } // Reauth reauthenticates the connection. -func (a *GSSAPIAuthenticator) Reauth(ctx context.Context, cfg *Config) error { +func (a *GSSAPIAuthenticator) Reauth(ctx context.Context) error { return newAuthError("GSSAPI does not support reauthentication", nil) } From 651af6689ae3be34067cc5332e7fa42dc848d6fb Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Fri, 14 Jun 2024 17:22:50 -0400 Subject: [PATCH 30/66] GODRIVER-2911: Add licenses and fix comment --- x/mongo/driver/auth/oidc.go | 6 ++++++ x/mongo/driver/driver.go | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/x/mongo/driver/auth/oidc.go b/x/mongo/driver/auth/oidc.go index 3d212a9bc6..eccee35b40 100644 --- a/x/mongo/driver/auth/oidc.go +++ b/x/mongo/driver/auth/oidc.go @@ -1,3 +1,9 @@ +// Copyright (C) MongoDB, Inc. 2024-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + package auth import ( diff --git a/x/mongo/driver/driver.go b/x/mongo/driver/driver.go index 3758f29972..626a307b61 100644 --- a/x/mongo/driver/driver.go +++ b/x/mongo/driver/driver.go @@ -63,7 +63,7 @@ type IDPInfo struct { } // Authenticator handles authenticating a connection. The implementors of this interface -// are all in the auth package. Most authentication mechanisms to not allow for Reauth, +// are all in the auth package. Most authentication mechanisms do not allow for Reauth, // but this is included in the interface so that whenever a new mechanism is added, it // must be explicitly considered. type Authenticator interface { From 6b16e91dc23cb01d1d7b5575475cf703de0fd886 Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Fri, 14 Jun 2024 18:06:05 -0400 Subject: [PATCH 31/66] GODRIVER-2911: Fix receiver names and remove authenticator field from structs that were automatically updated when not needed --- x/mongo/driver/auth/oidc.go | 4 +-- x/mongo/driver/operation/abort_transaction.go | 10 +++--- .../driver/operation/commit_transaction.go | 10 +++--- x/mongo/driver/operation/count.go | 11 ------ x/mongo/driver/operation/create_indexes.go | 21 +++--------- .../driver/operation/create_search_indexes.go | 34 ++++--------------- x/mongo/driver/operation/delete.go | 11 ------ x/mongo/driver/operation/drop_collection.go | 21 +++--------- x/mongo/driver/operation/drop_database.go | 10 +++--- x/mongo/driver/operation/drop_indexes.go | 11 ------ x/mongo/driver/operation/drop_search_index.go | 23 ++++--------- x/mongo/driver/operation/end_sessions.go | 10 +++--- x/mongo/driver/operation/find_and_modify.go | 32 +++-------------- x/mongo/driver/operation/hello.go | 15 ++------ x/mongo/driver/operation/listDatabases.go | 17 ++-------- x/mongo/driver/operation/list_collections.go | 10 +++--- x/mongo/driver/operation/list_indexes.go | 10 +++--- x/mongo/driver/operation/update.go | 26 ++------------ .../driver/operation/update_search_index.go | 23 ++++--------- 19 files changed, 72 insertions(+), 237 deletions(-) diff --git a/x/mongo/driver/auth/oidc.go b/x/mongo/driver/auth/oidc.go index eccee35b40..0217ca1a96 100644 --- a/x/mongo/driver/auth/oidc.go +++ b/x/mongo/driver/auth/oidc.go @@ -129,7 +129,7 @@ func (oa *OIDCAuthenticator) getAccessToken( } oa.accessToken = cred.AccessToken - oa.tokenGenID += 1 + oa.tokenGenID++ oa.cfg.Connection.SetOIDCTokenGenID(oa.tokenGenID) if cred.RefreshToken != nil { oa.refreshToken = cred.RefreshToken @@ -154,7 +154,7 @@ func (oa *OIDCAuthenticator) getAccessTokenWithRefresh( } oa.accessToken = cred.AccessToken - oa.tokenGenID += 1 + oa.tokenGenID++ oa.cfg.Connection.SetOIDCTokenGenID(oa.tokenGenID) return cred.AccessToken, nil } diff --git a/x/mongo/driver/operation/abort_transaction.go b/x/mongo/driver/operation/abort_transaction.go index f03949d15c..aeee533533 100644 --- a/x/mongo/driver/operation/abort_transaction.go +++ b/x/mongo/driver/operation/abort_transaction.go @@ -203,11 +203,11 @@ func (at *AbortTransaction) ServerAPI(serverAPI *driver.ServerAPIOptions) *Abort } // Authenticator sets the authenticator to use for this operation. -func (a *AbortTransaction) Authenticator(authenticator driver.Authenticator) *AbortTransaction { - if a == nil { - a = new(AbortTransaction) +func (at *AbortTransaction) Authenticator(authenticator driver.Authenticator) *AbortTransaction { + if at == nil { + at = new(AbortTransaction) } - a.authenticator = authenticator - return a + at.authenticator = authenticator + return at } diff --git a/x/mongo/driver/operation/commit_transaction.go b/x/mongo/driver/operation/commit_transaction.go index b6414f60b6..6b402bdf63 100644 --- a/x/mongo/driver/operation/commit_transaction.go +++ b/x/mongo/driver/operation/commit_transaction.go @@ -205,11 +205,11 @@ func (ct *CommitTransaction) ServerAPI(serverAPI *driver.ServerAPIOptions) *Comm } // Authenticator sets the authenticator to use for this operation. -func (c *CommitTransaction) Authenticator(authenticator driver.Authenticator) *CommitTransaction { - if c == nil { - c = new(CommitTransaction) +func (ct *CommitTransaction) Authenticator(authenticator driver.Authenticator) *CommitTransaction { + if ct == nil { + ct = new(CommitTransaction) } - c.authenticator = authenticator - return c + ct.authenticator = authenticator + return ct } diff --git a/x/mongo/driver/operation/count.go b/x/mongo/driver/operation/count.go index 462236059d..eaafc9a244 100644 --- a/x/mongo/driver/operation/count.go +++ b/x/mongo/driver/operation/count.go @@ -47,7 +47,6 @@ type Count struct { // CountResult represents a count result returned by the server. type CountResult struct { - authenticator driver.Authenticator // The number of documents found N int64 } @@ -324,13 +323,3 @@ func (c *Count) Authenticator(authenticator driver.Authenticator) *Count { c.authenticator = authenticator return c } - -// Authenticator sets the authenticator to use for this operation. -func (c *CountResult) Authenticator(authenticator driver.Authenticator) *CountResult { - if c == nil { - c = new(CountResult) - } - - c.authenticator = authenticator - return c -} diff --git a/x/mongo/driver/operation/create_indexes.go b/x/mongo/driver/operation/create_indexes.go index 502182b304..464c1762de 100644 --- a/x/mongo/driver/operation/create_indexes.go +++ b/x/mongo/driver/operation/create_indexes.go @@ -44,7 +44,6 @@ type CreateIndexes struct { // CreateIndexesResult represents a createIndexes result returned by the server. type CreateIndexesResult struct { - authenticator driver.Authenticator // If the collection was created automatically. CreatedCollectionAutomatically bool // The number of indexes existing after this command. @@ -283,21 +282,11 @@ func (ci *CreateIndexes) Timeout(timeout *time.Duration) *CreateIndexes { } // Authenticator sets the authenticator to use for this operation. -func (c *CreateIndexes) Authenticator(authenticator driver.Authenticator) *CreateIndexes { - if c == nil { - c = new(CreateIndexes) - } - - c.authenticator = authenticator - return c -} - -// Authenticator sets the authenticator to use for this operation. -func (c *CreateIndexesResult) Authenticator(authenticator driver.Authenticator) *CreateIndexesResult { - if c == nil { - c = new(CreateIndexesResult) +func (ci *CreateIndexes) Authenticator(authenticator driver.Authenticator) *CreateIndexes { + if ci == nil { + ci = new(CreateIndexes) } - c.authenticator = authenticator - return c + ci.authenticator = authenticator + return ci } diff --git a/x/mongo/driver/operation/create_search_indexes.go b/x/mongo/driver/operation/create_search_indexes.go index 2300e3df4e..8185d27fe1 100644 --- a/x/mongo/driver/operation/create_search_indexes.go +++ b/x/mongo/driver/operation/create_search_indexes.go @@ -39,13 +39,11 @@ type CreateSearchIndexes struct { // CreateSearchIndexResult represents a single search index result in CreateSearchIndexesResult. type CreateSearchIndexResult struct { - authenticator driver.Authenticator - Name string + Name string } // CreateSearchIndexesResult represents a createSearchIndexes result returned by the server. type CreateSearchIndexesResult struct { - authenticator driver.Authenticator IndexesCreated []CreateSearchIndexResult } @@ -243,31 +241,11 @@ func (csi *CreateSearchIndexes) Timeout(timeout *time.Duration) *CreateSearchInd } // Authenticator sets the authenticator to use for this operation. -func (c *CreateSearchIndexes) Authenticator(authenticator driver.Authenticator) *CreateSearchIndexes { - if c == nil { - c = new(CreateSearchIndexes) - } - - c.authenticator = authenticator - return c -} - -// Authenticator sets the authenticator to use for this operation. -func (c *CreateSearchIndexResult) Authenticator(authenticator driver.Authenticator) *CreateSearchIndexResult { - if c == nil { - c = new(CreateSearchIndexResult) - } - - c.authenticator = authenticator - return c -} - -// Authenticator sets the authenticator to use for this operation. -func (c *CreateSearchIndexesResult) Authenticator(authenticator driver.Authenticator) *CreateSearchIndexesResult { - if c == nil { - c = new(CreateSearchIndexesResult) +func (csi *CreateSearchIndexes) Authenticator(authenticator driver.Authenticator) *CreateSearchIndexes { + if csi == nil { + csi = new(CreateSearchIndexes) } - c.authenticator = authenticator - return c + csi.authenticator = authenticator + return csi } diff --git a/x/mongo/driver/operation/delete.go b/x/mongo/driver/operation/delete.go index f5a4fe7f79..298ec44196 100644 --- a/x/mongo/driver/operation/delete.go +++ b/x/mongo/driver/operation/delete.go @@ -49,7 +49,6 @@ type Delete struct { // DeleteResult represents a delete result returned by the server. type DeleteResult struct { - authenticator driver.Authenticator // Number of documents successfully deleted. N int64 } @@ -341,13 +340,3 @@ func (d *Delete) Authenticator(authenticator driver.Authenticator) *Delete { d.authenticator = authenticator return d } - -// Authenticator sets the authenticator to use for this operation. -func (d *DeleteResult) Authenticator(authenticator driver.Authenticator) *DeleteResult { - if d == nil { - d = new(DeleteResult) - } - - d.authenticator = authenticator - return d -} diff --git a/x/mongo/driver/operation/drop_collection.go b/x/mongo/driver/operation/drop_collection.go index ff6ffa181c..5a32c2f8d4 100644 --- a/x/mongo/driver/operation/drop_collection.go +++ b/x/mongo/driver/operation/drop_collection.go @@ -40,7 +40,6 @@ type DropCollection struct { // DropCollectionResult represents a dropCollection result returned by the server. type DropCollectionResult struct { - authenticator driver.Authenticator // The number of indexes in the dropped collection. NIndexesWas int32 // The namespace of the dropped collection. @@ -227,21 +226,11 @@ func (dc *DropCollection) Timeout(timeout *time.Duration) *DropCollection { } // Authenticator sets the authenticator to use for this operation. -func (d *DropCollection) Authenticator(authenticator driver.Authenticator) *DropCollection { - if d == nil { - d = new(DropCollection) - } - - d.authenticator = authenticator - return d -} - -// Authenticator sets the authenticator to use for this operation. -func (d *DropCollectionResult) Authenticator(authenticator driver.Authenticator) *DropCollectionResult { - if d == nil { - d = new(DropCollectionResult) +func (dc *DropCollection) Authenticator(authenticator driver.Authenticator) *DropCollection { + if dc == nil { + dc = new(DropCollection) } - d.authenticator = authenticator - return d + dc.authenticator = authenticator + return dc } diff --git a/x/mongo/driver/operation/drop_database.go b/x/mongo/driver/operation/drop_database.go index daecc65c5b..19956210d1 100644 --- a/x/mongo/driver/operation/drop_database.go +++ b/x/mongo/driver/operation/drop_database.go @@ -158,11 +158,11 @@ func (dd *DropDatabase) ServerAPI(serverAPI *driver.ServerAPIOptions) *DropDatab } // Authenticator sets the authenticator to use for this operation. -func (d *DropDatabase) Authenticator(authenticator driver.Authenticator) *DropDatabase { - if d == nil { - d = new(DropDatabase) +func (dd *DropDatabase) Authenticator(authenticator driver.Authenticator) *DropDatabase { + if dd == nil { + dd = new(DropDatabase) } - d.authenticator = authenticator - return d + dd.authenticator = authenticator + return dd } diff --git a/x/mongo/driver/operation/drop_indexes.go b/x/mongo/driver/operation/drop_indexes.go index 3b5fe7bbfe..dfa5bfc267 100644 --- a/x/mongo/driver/operation/drop_indexes.go +++ b/x/mongo/driver/operation/drop_indexes.go @@ -42,7 +42,6 @@ type DropIndexes struct { // DropIndexesResult represents a dropIndexes result returned by the server. type DropIndexesResult struct { - authenticator driver.Authenticator // Number of indexes that existed before the drop was executed. NIndexesWas int32 } @@ -255,13 +254,3 @@ func (d *DropIndexes) Authenticator(authenticator driver.Authenticator) *DropInd d.authenticator = authenticator return d } - -// Authenticator sets the authenticator to use for this operation. -func (d *DropIndexesResult) Authenticator(authenticator driver.Authenticator) *DropIndexesResult { - if d == nil { - d = new(DropIndexesResult) - } - - d.authenticator = authenticator - return d -} diff --git a/x/mongo/driver/operation/drop_search_index.go b/x/mongo/driver/operation/drop_search_index.go index 0c69d5a104..3d273434d5 100644 --- a/x/mongo/driver/operation/drop_search_index.go +++ b/x/mongo/driver/operation/drop_search_index.go @@ -38,8 +38,7 @@ type DropSearchIndex struct { // DropSearchIndexResult represents a dropSearchIndex result returned by the server. type DropSearchIndexResult struct { - authenticator driver.Authenticator - Ok int32 + Ok int32 } func buildDropSearchIndexResult(response bsoncore.Document) (DropSearchIndexResult, error) { @@ -217,21 +216,11 @@ func (dsi *DropSearchIndex) Timeout(timeout *time.Duration) *DropSearchIndex { } // Authenticator sets the authenticator to use for this operation. -func (d *DropSearchIndex) Authenticator(authenticator driver.Authenticator) *DropSearchIndex { - if d == nil { - d = new(DropSearchIndex) - } - - d.authenticator = authenticator - return d -} - -// Authenticator sets the authenticator to use for this operation. -func (d *DropSearchIndexResult) Authenticator(authenticator driver.Authenticator) *DropSearchIndexResult { - if d == nil { - d = new(DropSearchIndexResult) +func (dsi *DropSearchIndex) Authenticator(authenticator driver.Authenticator) *DropSearchIndex { + if dsi == nil { + dsi = new(DropSearchIndex) } - d.authenticator = authenticator - return d + dsi.authenticator = authenticator + return dsi } diff --git a/x/mongo/driver/operation/end_sessions.go b/x/mongo/driver/operation/end_sessions.go index dae9165332..8b24b3d8c2 100644 --- a/x/mongo/driver/operation/end_sessions.go +++ b/x/mongo/driver/operation/end_sessions.go @@ -165,11 +165,11 @@ func (es *EndSessions) ServerAPI(serverAPI *driver.ServerAPIOptions) *EndSession } // Authenticator sets the authenticator to use for this operation. -func (e *EndSessions) Authenticator(authenticator driver.Authenticator) *EndSessions { - if e == nil { - e = new(EndSessions) +func (es *EndSessions) Authenticator(authenticator driver.Authenticator) *EndSessions { + if es == nil { + es = new(EndSessions) } - e.authenticator = authenticator - return e + es.authenticator = authenticator + return es } diff --git a/x/mongo/driver/operation/find_and_modify.go b/x/mongo/driver/operation/find_and_modify.go index de94abeacf..ea365ccb23 100644 --- a/x/mongo/driver/operation/find_and_modify.go +++ b/x/mongo/driver/operation/find_and_modify.go @@ -58,7 +58,6 @@ type FindAndModify struct { // LastErrorObject represents information about updates and upserts returned by the server. type LastErrorObject struct { - authenticator driver.Authenticator // True if an update modified an existing document UpdatedExisting bool // Object ID of the upserted document. @@ -67,7 +66,6 @@ type LastErrorObject struct { // FindAndModifyResult represents a findAndModify result returned by the server. type FindAndModifyResult struct { - authenticator driver.Authenticator // Either the old or modified document, depending on the value of the new parameter. Value bsoncore.Document // Contains information about updates and upserts. @@ -483,31 +481,11 @@ func (fam *FindAndModify) Timeout(timeout *time.Duration) *FindAndModify { } // Authenticator sets the authenticator to use for this operation. -func (f *FindAndModify) Authenticator(authenticator driver.Authenticator) *FindAndModify { - if f == nil { - f = new(FindAndModify) - } - - f.authenticator = authenticator - return f -} - -// Authenticator sets the authenticator to use for this operation. -func (l *LastErrorObject) Authenticator(authenticator driver.Authenticator) *LastErrorObject { - if l == nil { - l = new(LastErrorObject) - } - - l.authenticator = authenticator - return l -} - -// Authenticator sets the authenticator to use for this operation. -func (f *FindAndModifyResult) Authenticator(authenticator driver.Authenticator) *FindAndModifyResult { - if f == nil { - f = new(FindAndModifyResult) +func (fam *FindAndModify) Authenticator(authenticator driver.Authenticator) *FindAndModify { + if fam == nil { + fam = new(FindAndModify) } - f.authenticator = authenticator - return f + fam.authenticator = authenticator + return fam } diff --git a/x/mongo/driver/operation/hello.go b/x/mongo/driver/operation/hello.go index 2757dc0a1f..60c99f063d 100644 --- a/x/mongo/driver/operation/hello.go +++ b/x/mongo/driver/operation/hello.go @@ -202,9 +202,8 @@ func getFaasEnvName() string { } type containerInfo struct { - authenticator driver.Authenticator - runtime string - orchestrator string + runtime string + orchestrator string } // getContainerEnvInfo returns runtime and orchestrator of a container. @@ -661,13 +660,3 @@ func (h *Hello) Authenticator(authenticator driver.Authenticator) *Hello { h.authenticator = authenticator return h } - -// Authenticator sets the authenticator to use for this operation. -func (c *containerInfo) Authenticator(authenticator driver.Authenticator) *containerInfo { - if c == nil { - c = new(containerInfo) - } - - c.authenticator = authenticator - return c -} diff --git a/x/mongo/driver/operation/listDatabases.go b/x/mongo/driver/operation/listDatabases.go index 829858a644..b2b9b27cf9 100644 --- a/x/mongo/driver/operation/listDatabases.go +++ b/x/mongo/driver/operation/listDatabases.go @@ -53,10 +53,9 @@ type ListDatabasesResult struct { } type databaseRecord struct { - authenticator driver.Authenticator - Name string - SizeOnDisk int64 `bson:"sizeOnDisk"` - Empty bool + Name string + SizeOnDisk int64 `bson:"sizeOnDisk"` + Empty bool } func buildListDatabasesResult(response bsoncore.Document) (ListDatabasesResult, error) { @@ -351,13 +350,3 @@ func (l *ListDatabasesResult) Authenticator(authenticator driver.Authenticator) l.authenticator = authenticator return l } - -// Authenticator sets the authenticator to use for this operation. -func (d *databaseRecord) Authenticator(authenticator driver.Authenticator) *databaseRecord { - if d == nil { - d = new(databaseRecord) - } - - d.authenticator = authenticator - return d -} diff --git a/x/mongo/driver/operation/list_collections.go b/x/mongo/driver/operation/list_collections.go index da65b99088..1e39f5bfbe 100644 --- a/x/mongo/driver/operation/list_collections.go +++ b/x/mongo/driver/operation/list_collections.go @@ -263,11 +263,11 @@ func (lc *ListCollections) Timeout(timeout *time.Duration) *ListCollections { } // Authenticator sets the authenticator to use for this operation. -func (l *ListCollections) Authenticator(authenticator driver.Authenticator) *ListCollections { - if l == nil { - l = new(ListCollections) +func (lc *ListCollections) Authenticator(authenticator driver.Authenticator) *ListCollections { + if lc == nil { + lc = new(ListCollections) } - l.authenticator = authenticator - return l + lc.authenticator = authenticator + return lc } diff --git a/x/mongo/driver/operation/list_indexes.go b/x/mongo/driver/operation/list_indexes.go index dc3654c884..433344f307 100644 --- a/x/mongo/driver/operation/list_indexes.go +++ b/x/mongo/driver/operation/list_indexes.go @@ -237,11 +237,11 @@ func (li *ListIndexes) Timeout(timeout *time.Duration) *ListIndexes { } // Authenticator sets the authenticator to use for this operation. -func (l *ListIndexes) Authenticator(authenticator driver.Authenticator) *ListIndexes { - if l == nil { - l = new(ListIndexes) +func (li *ListIndexes) Authenticator(authenticator driver.Authenticator) *ListIndexes { + if li == nil { + li = new(ListIndexes) } - l.authenticator = authenticator - return l + li.authenticator = authenticator + return li } diff --git a/x/mongo/driver/operation/update.go b/x/mongo/driver/operation/update.go index 0da93cbf08..1070e7ca70 100644 --- a/x/mongo/driver/operation/update.go +++ b/x/mongo/driver/operation/update.go @@ -52,14 +52,12 @@ type Update struct { // Upsert contains the information for an upsert in an Update operation. type Upsert struct { - authenticator driver.Authenticator - Index int64 - ID interface{} `bson:"_id"` + Index int64 + ID interface{} `bson:"_id"` } // UpdateResult contains information for the result of an Update operation. type UpdateResult struct { - authenticator driver.Authenticator // Number of documents matched. N int64 // Number of documents modified. @@ -428,23 +426,3 @@ func (u *Update) Authenticator(authenticator driver.Authenticator) *Update { u.authenticator = authenticator return u } - -// Authenticator sets the authenticator to use for this operation. -func (u *Upsert) Authenticator(authenticator driver.Authenticator) *Upsert { - if u == nil { - u = new(Upsert) - } - - u.authenticator = authenticator - return u -} - -// Authenticator sets the authenticator to use for this operation. -func (u *UpdateResult) Authenticator(authenticator driver.Authenticator) *UpdateResult { - if u == nil { - u = new(UpdateResult) - } - - u.authenticator = authenticator - return u -} diff --git a/x/mongo/driver/operation/update_search_index.go b/x/mongo/driver/operation/update_search_index.go index 963ef1e1a6..4ed9946c69 100644 --- a/x/mongo/driver/operation/update_search_index.go +++ b/x/mongo/driver/operation/update_search_index.go @@ -39,8 +39,7 @@ type UpdateSearchIndex struct { // UpdateSearchIndexResult represents a single index in the updateSearchIndexResult result. type UpdateSearchIndexResult struct { - authenticator driver.Authenticator - Ok int32 + Ok int32 } func buildUpdateSearchIndexResult(response bsoncore.Document) (UpdateSearchIndexResult, error) { @@ -230,21 +229,11 @@ func (usi *UpdateSearchIndex) Timeout(timeout *time.Duration) *UpdateSearchIndex } // Authenticator sets the authenticator to use for this operation. -func (u *UpdateSearchIndex) Authenticator(authenticator driver.Authenticator) *UpdateSearchIndex { - if u == nil { - u = new(UpdateSearchIndex) - } - - u.authenticator = authenticator - return u -} - -// Authenticator sets the authenticator to use for this operation. -func (u *UpdateSearchIndexResult) Authenticator(authenticator driver.Authenticator) *UpdateSearchIndexResult { - if u == nil { - u = new(UpdateSearchIndexResult) +func (usi *UpdateSearchIndex) Authenticator(authenticator driver.Authenticator) *UpdateSearchIndex { + if usi == nil { + usi = new(UpdateSearchIndex) } - u.authenticator = authenticator - return u + usi.authenticator = authenticator + return usi } From 26412ae88e5cd03ec4071a89fe1bfee50ff57f0c Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Fri, 14 Jun 2024 18:30:47 -0400 Subject: [PATCH 32/66] GODRIVER-2911: Fix many lints. Linter not running for me locally --- mongo/client.go | 3 + x/mongo/driver/auth/auth.go | 1 + x/mongo/driver/auth/cred.go | 2 + x/mongo/driver/auth/default.go | 2 +- x/mongo/driver/auth/mongodbcr.go | 2 +- x/mongo/driver/auth/oidc.go | 102 ++++++++++++-------- x/mongo/driver/auth/plain.go | 2 +- x/mongo/driver/driver.go | 3 + x/mongo/driver/drivertest/channel_conn.go | 7 +- x/mongo/driver/topology/connection.go | 10 +- x/mongo/driver/topology/topology_options.go | 5 - 11 files changed, 86 insertions(+), 53 deletions(-) diff --git a/mongo/client.go b/mongo/client.go index cda768cf0d..1560ace918 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -224,6 +224,9 @@ func NewClient(opts ...*options.ClientOptions) (*Client, error) { PasswordSet: clientOpt.Auth.PasswordSet, Props: clientOpt.Auth.AuthMechanismProperties, }) + if err != nil { + return nil, err + } } cfg, err := topology.NewConfigWithAuthenticator(clientOpt, client.clock, client.authenticator) diff --git a/x/mongo/driver/auth/auth.go b/x/mongo/driver/auth/auth.go index c5bcd90f75..bbfdcfe935 100644 --- a/x/mongo/driver/auth/auth.go +++ b/x/mongo/driver/auth/auth.go @@ -19,6 +19,7 @@ import ( "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) +// Config contains the configuration for an Authenticator. type Config = driver.AuthConfig // AuthenticatorFactory constructs an authenticator. diff --git a/x/mongo/driver/auth/cred.go b/x/mongo/driver/auth/cred.go index 5f89f29018..a9685f6ed8 100644 --- a/x/mongo/driver/auth/cred.go +++ b/x/mongo/driver/auth/cred.go @@ -3,10 +3,12 @@ // Licensed under the Apache License, Version 2.0 (the "License"); you may // not use this file except in compliance with the License. You may obtain // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + package auth import ( "go.mongodb.org/mongo-driver/x/mongo/driver" ) +// Cred is the type of user credential type Cred = driver.Cred diff --git a/x/mongo/driver/auth/default.go b/x/mongo/driver/auth/default.go index 5e8c5b98ec..a07fba1faa 100644 --- a/x/mongo/driver/auth/default.go +++ b/x/mongo/driver/auth/default.go @@ -67,7 +67,7 @@ func (a *DefaultAuthenticator) Auth(ctx context.Context, cfg *Config) error { } // Reauth reauthenticates the connection. -func (a *DefaultAuthenticator) Reauth(ctx context.Context) error { +func (a *DefaultAuthenticator) Reauth(_ context.Context) error { return newAuthError("DefaultAuthenticator does not support reauthentication", nil) } diff --git a/x/mongo/driver/auth/mongodbcr.go b/x/mongo/driver/auth/mongodbcr.go index 7301469921..41e21d2dea 100644 --- a/x/mongo/driver/auth/mongodbcr.go +++ b/x/mongo/driver/auth/mongodbcr.go @@ -98,7 +98,7 @@ func (a *MongoDBCRAuthenticator) Auth(ctx context.Context, cfg *Config) error { } // Reauth reauthenticates the connection. -func (a *MongoDBCRAuthenticator) Reauth(ctx context.Context) error { +func (a *MongoDBCRAuthenticator) Reauth(_ context.Context) error { return newAuthError("MONGODB-CR does not support reauthentication", nil) } diff --git a/x/mongo/driver/auth/oidc.go b/x/mongo/driver/auth/oidc.go index 0217ca1a96..cce6c68fc7 100644 --- a/x/mongo/driver/auth/oidc.go +++ b/x/mongo/driver/auth/oidc.go @@ -16,9 +16,14 @@ import ( "go.mongodb.org/mongo-driver/x/mongo/driver" ) +// MongoDBOIDC is the string constant for the MONGODB-OIDC authentication mechanism. const MongoDBOIDC = "MONGODB-OIDC" -const tokenResourceProp = "TOKEN_RESOURCE" + +// TODO GODRIVER-2728: Automatic token acquisition for Azure Identity Provider +// const tokenResourceProp = "TOKEN_RESOURCE" const environmentProp = "ENVIRONMENT" + +// GODRIVER-3249 OIDC: Handle all possible OIDC configuration errors const allowedHostsProp = "ALLOWED_HOSTS" const azureEnvironmentValue = "azure" @@ -28,19 +33,27 @@ const apiVersion = 1 const invalidateSleepTimeout = 100 * time.Millisecond const machineCallbackTimeout = 60 * time.Second -var defaultAllowedHosts = []string{ - "*.mongodb.net", - "*.mongodb-qa.net", - "*.mongodb-dev.net", - "*.mongodbgov.net", - "localhost", - "127.0.0.1", - "::1", -} - +//GODRIVER-3246 OIDC: Implement Human Callback Mechanism +//var defaultAllowedHosts = []string{ +// "*.mongodb.net", +// "*.mongodb-qa.net", +// "*.mongodb-dev.net", +// "*.mongodbgov.net", +// "localhost", +// "127.0.0.1", +// "::1", +//} + +// OIDCCallback is a function that takes a context and OIDCArgs and returns an OIDCCredential. type OIDCCallback = driver.OIDCCallback + +// OIDCArgs contains the arguments for the OIDC callback. type OIDCArgs = driver.OIDCArgs + +// OIDCCredential contains the access token and refresh token. type OIDCCredential = driver.OIDCCredential + +// IDPInfo contains the information needed to perform OIDC authentication with an Identity Provider. type IDPInfo = driver.IDPInfo var _ driver.Authenticator = (*OIDCAuthenticator)(nil) @@ -79,13 +92,14 @@ func jwtStepRequest(accessToken string) []byte { Build() } -func principalStepRequest(principal string) []byte { - doc := bsoncore.NewDocumentBuilder() - if principal != "" { - doc.AppendString("n", principal) - } - return doc.Build() -} +// TODO GODRIVER-3246: Implement OIDC human flow +//func principalStepRequest(principal string) []byte { +// doc := bsoncore.NewDocumentBuilder() +// if principal != "" { +// doc.AppendString("n", principal) +// } +// return doc.Build() +//} func (oos *oidcOneStep) Start() (string, []byte, error) { return MongoDBOIDC, jwtStepRequest(oos.accessToken), nil @@ -108,6 +122,11 @@ func (oa *OIDCAuthenticator) providerCallback() (OIDCCallback, error) { switch env { // TODO GODRIVER-2728: Automatic token acquisition for Azure Identity Provider // TODO GODRIVER-2806: Automatic token acquisition for GCP Identity Provider + // This is here just to pass the linter, it will be fixed in one of the above tickets. + case azureEnvironmentValue, gcpEnvironmentValue: + return func(ctx context.Context, args *OIDCArgs) (*OIDCCredential, error) { + return nil, fmt.Errorf("automatic token acquisition for %q not implemented yet", env) + }, fmt.Errorf("automatic token acquisition for %q not implemented yet", env) } return nil, fmt.Errorf("%q %q not supported for MONGODB-OIDC", environmentProp, env) @@ -137,27 +156,28 @@ func (oa *OIDCAuthenticator) getAccessToken( return cred.AccessToken, nil } +// TODO GODRIVER-3246: Implement OIDC human flow // This should only be called with the Mutex held. -func (oa *OIDCAuthenticator) getAccessTokenWithRefresh( - ctx context.Context, - callback OIDCCallback, - refreshToken string, -) (string, error) { - - cred, err := callback(ctx, &OIDCArgs{ - Version: 1, - IDPInfo: oa.idpInfo, - RefreshToken: &refreshToken, - }) - if err != nil { - return "", err - } - - oa.accessToken = cred.AccessToken - oa.tokenGenID++ - oa.cfg.Connection.SetOIDCTokenGenID(oa.tokenGenID) - return cred.AccessToken, nil -} +//func (oa *OIDCAuthenticator) getAccessTokenWithRefresh( +// ctx context.Context, +// callback OIDCCallback, +// refreshToken string, +//) (string, error) { +// +// cred, err := callback(ctx, &OIDCArgs{ +// Version: apiVersion, +// IDPInfo: oa.idpInfo, +// RefreshToken: &refreshToken, +// }) +// if err != nil { +// return "", err +// } +// +// oa.accessToken = cred.AccessToken +// oa.tokenGenID++ +// oa.cfg.Connection.SetOIDCTokenGenID(oa.tokenGenID) +// return cred.AccessToken, nil +//} // invalidateAccessToken invalidates the access token, if the force flag is set to true (which is // only on a Reauth call) or if the tokenGenID of the connection is greater than or equal to the @@ -174,6 +194,8 @@ func (oa *OIDCAuthenticator) invalidateAccessToken(force bool) { } } +// Reauth reauthenticates the connection when the server returns a 391 code. Reauth is part of the +// driver.Authenticator interface. func (oa *OIDCAuthenticator) Reauth(ctx context.Context) error { oa.invalidateAccessToken(true) // it should be impossible to get a Reauth when an Auth has never occurred, @@ -230,12 +252,14 @@ func (oa *OIDCAuthenticator) Auth(ctx context.Context, cfg *Config) error { func (oa *OIDCAuthenticator) doAuthHuman(ctx context.Context, cfg *Config, humanCallback OIDCCallback) error { // TODO GODRIVER-3246: Implement OIDC human flow + // Println is for linter + fmt.Println("OIDC human flow not implemented yet", oa.idpInfo) return newAuthError("OIDC human flow not implemented yet", nil) } func (oa *OIDCAuthenticator) doAuthMachine(ctx context.Context, cfg *Config, machineCallback OIDCCallback) error { accessToken, err := oa.getAccessToken(ctx, - &OIDCArgs{Version: 1, + &OIDCArgs{Version: apiVersion, Timeout: time.Now().Add(machineCallbackTimeout), // idpInfo is nil for machine callbacks in the current spec. IDPInfo: nil, diff --git a/x/mongo/driver/auth/plain.go b/x/mongo/driver/auth/plain.go index 21ffe0465e..a11f0d6e06 100644 --- a/x/mongo/driver/auth/plain.go +++ b/x/mongo/driver/auth/plain.go @@ -35,7 +35,7 @@ func (a *PlainAuthenticator) Auth(ctx context.Context, cfg *Config) error { } // Reauth reauthenticates the connection. -func (a *PlainAuthenticator) Reauth(ctx context.Context) error { +func (a *PlainAuthenticator) Reauth(_ context.Context) error { return newAuthError("Plain authentication does not support reauthentication", nil) } diff --git a/x/mongo/driver/driver.go b/x/mongo/driver/driver.go index 626a307b61..cb5bf57b21 100644 --- a/x/mongo/driver/driver.go +++ b/x/mongo/driver/driver.go @@ -43,6 +43,7 @@ type AuthConfig struct { // nil in the OIDCArgs for the Machine flow. type OIDCCallback func(context.Context, *OIDCArgs) (*OIDCCredential, error) +// OIDCArgs contains the arguments for the OIDC callback. type OIDCArgs struct { Version int Timeout time.Time @@ -50,12 +51,14 @@ type OIDCArgs struct { RefreshToken *string } +// OIDCCredential contains the access token and refresh token. type OIDCCredential struct { AccessToken string ExpiresAt *time.Time RefreshToken *string } +// IDPInfo contains the information needed to perform OIDC authentication with an Identity Provider. type IDPInfo struct { Issuer string `bson:"issuer"` ClientID string `bson:"clientId"` diff --git a/x/mongo/driver/drivertest/channel_conn.go b/x/mongo/driver/drivertest/channel_conn.go index 874f100ef1..e2ca9642e7 100644 --- a/x/mongo/driver/drivertest/channel_conn.go +++ b/x/mongo/driver/drivertest/channel_conn.go @@ -26,12 +26,15 @@ type ChannelConn struct { Desc description.Server } +// OIDCTokenGenID implements the driver.Connection interface by returning the OIDCToken generation +// (which is always 0) func (c *ChannelConn) OIDCTokenGenID() uint64 { return 0 } -func (c *ChannelConn) SetOIDCTokenGenID(uint64) { -} +// OIDCTokenGenID implements the driver.Connection interface by setting the OIDCToken generation +// (which is always 0) +func (c *ChannelConn) SetOIDCTokenGenID(uint64) {} // WriteWireMessage implements the driver.Connection interface. func (c *ChannelConn) WriteWireMessage(ctx context.Context, wm []byte) error { diff --git a/x/mongo/driver/topology/connection.go b/x/mongo/driver/topology/connection.go index acdd7973ea..49a613aef8 100644 --- a/x/mongo/driver/topology/connection.go +++ b/x/mongo/driver/topology/connection.go @@ -866,12 +866,14 @@ func configureTLS(ctx context.Context, return client, nil } +// OIDCTokenGenID returns the OIDC token generation ID. func (c *Connection) OIDCTokenGenID() uint64 { return c.oidcTokenGenID } -func (c *Connection) SetOIDCTokenGenID(genId uint64) { - c.oidcTokenGenID = genId +// SetOIDCTokenGenID sets the OIDC token generation ID. +func (c *Connection) SetOIDCTokenGenID(genID uint64) { + c.oidcTokenGenID = genID } // TODO: Naming? @@ -922,6 +924,6 @@ func (c *connection) OIDCTokenGenID() uint64 { return c.oidcTokenGenID } -func (c *connection) SetOIDCTokenGenID(genId uint64) { - c.oidcTokenGenID = genId +func (c *connection) SetOIDCTokenGenID(genID uint64) { + c.oidcTokenGenID = genID } diff --git a/x/mongo/driver/topology/topology_options.go b/x/mongo/driver/topology/topology_options.go index 820b62a7ff..7d8dd62dc0 100644 --- a/x/mongo/driver/topology/topology_options.go +++ b/x/mongo/driver/topology/topology_options.go @@ -202,11 +202,6 @@ func NewConfigWithAuthenticator(co *options.ClientOptions, clock *session.Cluste } } - authenticator, err := auth.CreateAuthenticator(mechanism, cred) - if err != nil { - return nil, err - } - handshakeOpts := &auth.HandshakeOptions{ AppName: appName, Authenticator: authenticator, From 98e8cbe75160300ef1abbbc12560aad19d4248b3 Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Sat, 15 Jun 2024 16:19:00 -0400 Subject: [PATCH 33/66] GODRIVER-2911: Fix lints --- x/mongo/driver/auth/mongodbaws.go | 2 +- x/mongo/driver/auth/oidc.go | 4 ++-- x/mongo/driver/auth/scram.go | 2 +- x/mongo/driver/auth/x509.go | 2 +- x/mongo/driver/drivertest/channel_conn.go | 2 +- x/mongo/driver/operation/drop_indexes.go | 10 +++++----- x/mongo/driver/operation/listDatabases.go | 21 +++++---------------- 7 files changed, 16 insertions(+), 27 deletions(-) diff --git a/x/mongo/driver/auth/mongodbaws.go b/x/mongo/driver/auth/mongodbaws.go index 985fc35c03..af9f0f0b18 100644 --- a/x/mongo/driver/auth/mongodbaws.go +++ b/x/mongo/driver/auth/mongodbaws.go @@ -61,7 +61,7 @@ func (a *MongoDBAWSAuthenticator) Auth(ctx context.Context, cfg *Config) error { } // Reauth reauthenticates the connection. -func (a *MongoDBAWSAuthenticator) Reauth(ctx context.Context) error { +func (a *MongoDBAWSAuthenticator) Reauth(_ context.Context) error { return newAuthError("AWS authentication does not support reauthentication", nil) } diff --git a/x/mongo/driver/auth/oidc.go b/x/mongo/driver/auth/oidc.go index cce6c68fc7..b17a687b76 100644 --- a/x/mongo/driver/auth/oidc.go +++ b/x/mongo/driver/auth/oidc.go @@ -24,7 +24,7 @@ const MongoDBOIDC = "MONGODB-OIDC" const environmentProp = "ENVIRONMENT" // GODRIVER-3249 OIDC: Handle all possible OIDC configuration errors -const allowedHostsProp = "ALLOWED_HOSTS" +//const allowedHostsProp = "ALLOWED_HOSTS" const azureEnvironmentValue = "azure" const gcpEnvironmentValue = "gcp" @@ -250,7 +250,7 @@ func (oa *OIDCAuthenticator) Auth(ctx context.Context, cfg *Config) error { return newAuthError("no OIDC callback provided", nil) } -func (oa *OIDCAuthenticator) doAuthHuman(ctx context.Context, cfg *Config, humanCallback OIDCCallback) error { +func (oa *OIDCAuthenticator) doAuthHuman(_ context.Context, _ *Config, _ OIDCCallback) error { // TODO GODRIVER-3246: Implement OIDC human flow // Println is for linter fmt.Println("OIDC human flow not implemented yet", oa.idpInfo) diff --git a/x/mongo/driver/auth/scram.go b/x/mongo/driver/auth/scram.go index 8118e0a6ab..96bd0b89be 100644 --- a/x/mongo/driver/auth/scram.go +++ b/x/mongo/driver/auth/scram.go @@ -85,7 +85,7 @@ func (a *ScramAuthenticator) Auth(ctx context.Context, cfg *Config) error { } // Reauth reauthenticates the connection. -func (a *ScramAuthenticator) Reauth(ctx context.Context) error { +func (a *ScramAuthenticator) Reauth(_ context.Context) error { return newAuthError("SCRAM does not support reauthentication", nil) } diff --git a/x/mongo/driver/auth/x509.go b/x/mongo/driver/auth/x509.go index 4cbbc246c5..06773551a7 100644 --- a/x/mongo/driver/auth/x509.go +++ b/x/mongo/driver/auth/x509.go @@ -78,6 +78,6 @@ func (a *MongoDBX509Authenticator) Auth(ctx context.Context, cfg *Config) error } // Reauth reauthenticates the connection. -func (a *MongoDBX509Authenticator) Reauth(ctx context.Context) error { +func (a *MongoDBX509Authenticator) Reauth(_ context.Context) error { return newAuthError("X509 does not support reauthentication", nil) } diff --git a/x/mongo/driver/drivertest/channel_conn.go b/x/mongo/driver/drivertest/channel_conn.go index e2ca9642e7..d002398a5b 100644 --- a/x/mongo/driver/drivertest/channel_conn.go +++ b/x/mongo/driver/drivertest/channel_conn.go @@ -32,7 +32,7 @@ func (c *ChannelConn) OIDCTokenGenID() uint64 { return 0 } -// OIDCTokenGenID implements the driver.Connection interface by setting the OIDCToken generation +// SetOIDCTokenGenID implements the driver.Connection interface by setting the OIDCToken generation // (which is always 0) func (c *ChannelConn) SetOIDCTokenGenID(uint64) {} diff --git a/x/mongo/driver/operation/drop_indexes.go b/x/mongo/driver/operation/drop_indexes.go index dfa5bfc267..e4f924e4e1 100644 --- a/x/mongo/driver/operation/drop_indexes.go +++ b/x/mongo/driver/operation/drop_indexes.go @@ -246,11 +246,11 @@ func (di *DropIndexes) Timeout(timeout *time.Duration) *DropIndexes { } // Authenticator sets the authenticator to use for this operation. -func (d *DropIndexes) Authenticator(authenticator driver.Authenticator) *DropIndexes { - if d == nil { - d = new(DropIndexes) +func (di *DropIndexes) Authenticator(authenticator driver.Authenticator) *DropIndexes { + if di == nil { + di = new(DropIndexes) } - d.authenticator = authenticator - return d + di.authenticator = authenticator + return di } diff --git a/x/mongo/driver/operation/listDatabases.go b/x/mongo/driver/operation/listDatabases.go index b2b9b27cf9..3df171e37a 100644 --- a/x/mongo/driver/operation/listDatabases.go +++ b/x/mongo/driver/operation/listDatabases.go @@ -45,7 +45,6 @@ type ListDatabases struct { // ListDatabasesResult represents a listDatabases result returned by the server. type ListDatabasesResult struct { - authenticator driver.Authenticator // An array of documents, one document for each database Databases []databaseRecord // The sum of the size of all the database files on disk in bytes. @@ -332,21 +331,11 @@ func (ld *ListDatabases) Timeout(timeout *time.Duration) *ListDatabases { } // Authenticator sets the authenticator to use for this operation. -func (l *ListDatabases) Authenticator(authenticator driver.Authenticator) *ListDatabases { - if l == nil { - l = new(ListDatabases) - } - - l.authenticator = authenticator - return l -} - -// Authenticator sets the authenticator to use for this operation. -func (l *ListDatabasesResult) Authenticator(authenticator driver.Authenticator) *ListDatabasesResult { - if l == nil { - l = new(ListDatabasesResult) +func (ld *ListDatabases) Authenticator(authenticator driver.Authenticator) *ListDatabases { + if ld == nil { + ld = new(ListDatabases) } - l.authenticator = authenticator - return l + ld.authenticator = authenticator + return ld } From 78fa2175b4ca06e79ecb1f053a7e4a7e7d5e5b89 Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Sat, 15 Jun 2024 16:33:30 -0400 Subject: [PATCH 34/66] GODRIVER-2911: Fix spelling error --- x/mongo/driver/driver.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x/mongo/driver/driver.go b/x/mongo/driver/driver.go index cb5bf57b21..1ae6588f0e 100644 --- a/x/mongo/driver/driver.go +++ b/x/mongo/driver/driver.go @@ -65,7 +65,7 @@ type IDPInfo struct { RequestScopes []string `bson:"requestScopes"` } -// Authenticator handles authenticating a connection. The implementors of this interface +// Authenticator handles authenticating a connection. The implementers of this interface // are all in the auth package. Most authentication mechanisms do not allow for Reauth, // but this is included in the interface so that whenever a new mechanism is added, it // must be explicitly considered. From 46fa6f3c525b5b65c59d6c3db616f2325c18944a Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Tue, 18 Jun 2024 16:32:34 -0400 Subject: [PATCH 35/66] GODRIVER-2911: Testing checkpoint --- .evergreen/config.yml | 56 +++++++++++++++++++++++++ Makefile | 4 ++ cmd/testoidcauth/main.go | 76 ++++++++++++++++++++++++++++++++++ etc/run-oidc-test.sh | 33 +++++++++++++++ mongo/client.go | 12 +++--- mongo/options/clientoptions.go | 2 + x/mongo/driver/auth/oidc.go | 12 ++++-- x/mongo/driver/driver.go | 26 ++++++------ 8 files changed, 199 insertions(+), 22 deletions(-) create mode 100644 cmd/testoidcauth/main.go create mode 100644 etc/run-oidc-test.sh diff --git a/.evergreen/config.yml b/.evergreen/config.yml index b078af8066..b9ba851124 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -350,6 +350,23 @@ functions: chmod +x $i done + assume-ec2-role: + - command: ec2.assume_role + params: + role_arn: ${aws_test_secrets_role} + + run-oidc-auth-test-with-test-credentials: + - command: shell.exec + type: test + params: + working_dir: src/go.mongodb.org/mongo-driver + shell: bash + include_expansions_in_env: ["DRIVERS_TOOLS", "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"] + script: | + ${PREPARE_SHELL} + export OIDC="oidc" + bash ${PROJECT_DIRECTORY}/etc/run-oidc-test.sh + run-make: - command: shell.exec type: test @@ -1954,6 +1971,10 @@ tasks: popd ./.evergreen/run-deployed-lambda-aws-tests.sh + - name: "oidc-auth-test-latest" + commands: + - func: "run-oidc-auth-test-with-test-credentials" + - name: "test-search-index" commands: - func: "bootstrap-mongo-orchestration" @@ -2247,6 +2268,31 @@ task_groups: tasks: - testazurekms-task + - name: testoidc_task_group + setup_group: + - func: fetch-source + - func: prepare-resources + - func: fix-absolute-paths + - func: make-files-executable + - func: assume-ec2-role + - command: shell.exec + params: + shell: bash + include_expansions_in_env: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"] + script: | + ${PREPARE_SHELL} + ${DRIVERS_TOOLS}/.evergreen/auth_oidc/setup.sh + teardown_task: + - command: subprocess.exec + params: + binary: bash + args: + - ${DRIVERS_TOOLS}/.evergreen/auth_oidc/teardown.sh + setup_group_can_fail_task: true + setup_group_timeout_secs: 1800 + tasks: + - oidc-auth-test-latest + - name: test-aws-lambda-task-group setup_group: - func: fetch-source @@ -2561,3 +2607,13 @@ buildvariants: - name: testazurekms_task_group batchtime: 20160 # Use a batchtime of 14 days as suggested by the CSFLE test README - testazurekms-fail-task + + - name: testoidc-variant + display_name: "OIDC" + run_on: + - rhel8.7-small + expansions: + GO_DIST: "/opt/golang/go1.20" + tasks: + - name: testoidc_task_group + batchtime: 20160 # Use a batchtime of 14 days as suggested by the CSFLE test README diff --git a/Makefile b/Makefile index 88bc756390..3ee9c4b912 100644 --- a/Makefile +++ b/Makefile @@ -132,6 +132,10 @@ evg-test-atlas-data-lake: evg-test-enterprise-auth: go run -tags gssapi ./cmd/testentauth/main.go +.PHONY: evg-test-oidc-auth +evg-test-oidc-auth: + go run -tags oidc ./cmd/testoidcauth/main.go + .PHONY: evg-test-kmip evg-test-kmip: go test -exec "env PKG_CONFIG_PATH=$(PKG_CONFIG_PATH) LD_LIBRARY_PATH=$(LD_LIBRARY_PATH) DYLD_LIBRARY_PATH=$(MACOS_LIBRARY_PATH)" $(BUILD_TAGS) -v -timeout $(TEST_TIMEOUT)s ./mongo/integration -run TestClientSideEncryptionSpec/kmipKMS >> test.suite diff --git a/cmd/testoidcauth/main.go b/cmd/testoidcauth/main.go new file mode 100644 index 0000000000..a8d114ca79 --- /dev/null +++ b/cmd/testoidcauth/main.go @@ -0,0 +1,76 @@ +// Copyright (C) MongoDB, Inc. 2022-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package main + +import ( + "context" + "fmt" + "log" + "os" + "path" + "sync" + "time" + + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" + "go.mongodb.org/mongo-driver/x/mongo/driver" +) + +var uriAdmin = os.Getenv("MONGODB_URI") +var uriSingle = os.Getenv("MONGODB_URI_SINGLE") +var uriMulti = os.Getenv("MONGODB_URI_MULTI") +var oidcTokenDir = path.Join(os.Getenv("OIDC_TOKEN_DIR"), "tmp", "tokens") +var noUserTokenFile = os.Getenv("OIDC_TOKEN_FILE") +var oidcDomain = os.Getenv("OIDC_DOMAIN") + +func explicitUser(user string) string { + return fmt.Sprintf("%s@%s", user, oidcDomain) +} + +func main() { + machine_1_1_callbackIsCalled() +} + +func machine_1_1_callbackIsCalled() { + callbackCount := 0 + countMutex := sync.Mutex{} + + opts := options.Client().ApplyURI(uriSingle) + fmt.Println("machine_1_1_callbackIsCalled: uriSingle: ", uriSingle) + opts.Auth.OIDCMachineCallback = func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + accessToken, err := os.ReadFile(noUserTokenFile) + if err != nil { + log.Fatalf("machine_1_1_callbackIsCalled: failed reading token file: %v", err) + } + return &driver.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + } + + client, err := mongo.Connect( + context.Background(), + options.Client().ApplyURI(uriSingle)) + if err != nil { + log.Fatalf("Error connecting client: %v", err) + } + + coll := client.Database("test").Collection("test") + + res := coll.FindOne(context.Background(), nil) + if res == nil || res.Err() != nil { + log.Fatalf("machine_1_1_callbackIsCalled: failed executing FindOne: %v", err) + } + if callbackCount != 1 { + log.Fatalf("machine_1_1_callbackIsCalled: expected callback count to be 1, got %d", callbackCount) + } +} diff --git a/etc/run-oidc-test.sh b/etc/run-oidc-test.sh new file mode 100644 index 0000000000..755e162dd1 --- /dev/null +++ b/etc/run-oidc-test.sh @@ -0,0 +1,33 @@ +#!/usr/bin/env bash +# run-enterprise-gssapi-test +# Runs the enterprise auth tests with gssapi credentials. +set -eu + +echo "Running MONGODB-OIDC authentication tests" + +OIDC_ENV="${OIDC_ENV:-"test"}" + +if [ $OIDC_ENV == "test" ]; then + # Make sure DRIVERS_TOOLS is set. + if [ -z "$DRIVERS_TOOLS" ]; then + echo "Must specify DRIVERS_TOOLS" + exit 1 + fi + source ${DRIVERS_TOOLS}/.evergreen/auth_oidc/secrets-export.sh + +elif [ $OIDC_ENV == "azure" ]; then + source ./env.sh + +elif [ $OIDC_ENV == "gcp" ]; then + source ./secrets-export.sh + +else + echo "Unrecognized OIDC_ENV $OIDC_ENV" + exit 1 +fi + +export TEST_AUTH_OIDC=1 +export COVERAGE=1 +export AUTH="auth" + +make -s evg-test-oidc-auth diff --git a/mongo/client.go b/mongo/client.go index 1560ace918..082554adbb 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -218,11 +218,13 @@ func NewClient(opts ...*options.ClientOptions) (*Client, error) { if clientOpt.Auth != nil { // Create an authenticator for the client client.authenticator, err = auth.CreateAuthenticator(clientOpt.Auth.AuthMechanism, &auth.Cred{ - Source: clientOpt.Auth.AuthSource, - Username: clientOpt.Auth.Username, - Password: clientOpt.Auth.Password, - PasswordSet: clientOpt.Auth.PasswordSet, - Props: clientOpt.Auth.AuthMechanismProperties, + Source: clientOpt.Auth.AuthSource, + Username: clientOpt.Auth.Username, + Password: clientOpt.Auth.Password, + PasswordSet: clientOpt.Auth.PasswordSet, + Props: clientOpt.Auth.AuthMechanismProperties, + OIDCMachineCallback: clientOpt.Auth.OIDCMachineCallback, + OIDCHumanCallback: clientOpt.Auth.OIDCHumanCallback, }) if err != nil { return nil, err diff --git a/mongo/options/clientoptions.go b/mongo/options/clientoptions.go index db56745919..1fd7765a30 100644 --- a/mongo/options/clientoptions.go +++ b/mongo/options/clientoptions.go @@ -110,6 +110,8 @@ type Credential struct { Username string Password string PasswordSet bool + OIDCMachineCallback driver.OIDCCallback + OIDCHumanCallback driver.OIDCCallback } // BSONOptions are optional BSON marshaling and unmarshaling behaviors. diff --git a/x/mongo/driver/auth/oidc.go b/x/mongo/driver/auth/oidc.go index b17a687b76..e65e966500 100644 --- a/x/mongo/driver/auth/oidc.go +++ b/x/mongo/driver/auth/oidc.go @@ -67,6 +67,8 @@ type OIDCAuthenticator struct { mu sync.Mutex // Guards all of the info in the OIDCAuthenticator struct. AuthMechanismProperties map[string]string + OIDCMachineCallback OIDCCallback + OIDCHumanCallback OIDCCallback cfg *Config accessToken string @@ -78,6 +80,8 @@ type OIDCAuthenticator struct { func newOIDCAuthenticator(cred *Cred) (Authenticator, error) { oa := &OIDCAuthenticator{ AuthMechanismProperties: cred.Props, + OIDCMachineCallback: cred.OIDCMachineCallback, + OIDCHumanCallback: cred.OIDCHumanCallback, } return oa, nil } @@ -229,14 +233,14 @@ func (oa *OIDCAuthenticator) Auth(ctx context.Context, cfg *Config) error { time.Sleep(invalidateSleepTimeout) } - if cfg.OIDCHumanCallback != nil { - return oa.doAuthHuman(ctx, cfg, cfg.OIDCHumanCallback) + if oa.OIDCHumanCallback != nil { + return oa.doAuthHuman(ctx, cfg, oa.OIDCHumanCallback) } // Handle user provided or automatic provider machine callback. var machineCallback OIDCCallback - if cfg.OIDCMachineCallback != nil { - machineCallback = cfg.OIDCMachineCallback + if oa.OIDCMachineCallback != nil { + machineCallback = oa.OIDCMachineCallback } else { machineCallback, err = oa.providerCallback() if err != nil { diff --git a/x/mongo/driver/driver.go b/x/mongo/driver/driver.go index 1ae6588f0e..59f6f85ab4 100644 --- a/x/mongo/driver/driver.go +++ b/x/mongo/driver/driver.go @@ -29,14 +29,12 @@ import ( // this was moved from the auth package to avoid a circular dependency. The auth package // reexports this under the old name to avoid breaking the public api. type AuthConfig struct { - Description description.Server - Connection Connection - ClusterClock *session.ClusterClock - HandshakeInfo HandshakeInformation - ServerAPI *ServerAPIOptions - HTTPClient *http.Client - OIDCMachineCallback OIDCCallback - OIDCHumanCallback OIDCCallback + Description description.Server + Connection Connection + ClusterClock *session.ClusterClock + HandshakeInfo HandshakeInformation + ServerAPI *ServerAPIOptions + HTTPClient *http.Client } // OIDCCallback is the type for both Human and Machine Callback flows. RefreshToken will always be @@ -77,11 +75,13 @@ type Authenticator interface { // Cred is a user's credential. type Cred struct { - Source string - Username string - Password string - PasswordSet bool - Props map[string]string + Source string + Username string + Password string + PasswordSet bool + Props map[string]string + OIDCMachineCallback OIDCCallback + OIDCHumanCallback OIDCCallback } // Deployment is implemented by types that can select a server from a deployment. From c6d23de5bbd68cde47e669131d2bba5b40c9ca15 Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Tue, 18 Jun 2024 18:24:30 -0400 Subject: [PATCH 36/66] GODRIVER-2911: Fix config, fix spec auth --- Makefile | 2 +- cmd/testoidcauth/main.go | 18 +++++++++++------- mongo/collection.go | 1 + x/mongo/driver/auth/auth.go | 15 ++++++++++----- x/mongo/driver/auth/oidc.go | 1 + x/mongo/driver/connstring/connstring.go | 24 ++++++++++++++++++++++++ 6 files changed, 48 insertions(+), 13 deletions(-) diff --git a/Makefile b/Makefile index 3ee9c4b912..0c144db30e 100644 --- a/Makefile +++ b/Makefile @@ -134,7 +134,7 @@ evg-test-enterprise-auth: .PHONY: evg-test-oidc-auth evg-test-oidc-auth: - go run -tags oidc ./cmd/testoidcauth/main.go + go run ./cmd/testoidcauth/main.go .PHONY: evg-test-kmip evg-test-kmip: diff --git a/cmd/testoidcauth/main.go b/cmd/testoidcauth/main.go index a8d114ca79..8c5173e2c3 100644 --- a/cmd/testoidcauth/main.go +++ b/cmd/testoidcauth/main.go @@ -15,6 +15,7 @@ import ( "sync" "time" + "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/x/mongo/driver" @@ -24,13 +25,16 @@ var uriAdmin = os.Getenv("MONGODB_URI") var uriSingle = os.Getenv("MONGODB_URI_SINGLE") var uriMulti = os.Getenv("MONGODB_URI_MULTI") var oidcTokenDir = path.Join(os.Getenv("OIDC_TOKEN_DIR"), "tmp", "tokens") -var noUserTokenFile = os.Getenv("OIDC_TOKEN_FILE") var oidcDomain = os.Getenv("OIDC_DOMAIN") func explicitUser(user string) string { return fmt.Sprintf("%s@%s", user, oidcDomain) } +func tokenFile(user string) string { + return path.Join(oidcTokenDir, user) +} + func main() { machine_1_1_callbackIsCalled() } @@ -40,13 +44,15 @@ func machine_1_1_callbackIsCalled() { countMutex := sync.Mutex{} opts := options.Client().ApplyURI(uriSingle) - fmt.Println("machine_1_1_callbackIsCalled: uriSingle: ", uriSingle) + opts.Auth.OIDCMachineCallback = func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { countMutex.Lock() defer countMutex.Unlock() callbackCount++ t := time.Now().Add(time.Hour) - accessToken, err := os.ReadFile(noUserTokenFile) + tokenFile := tokenFile("test_user1") + fmt.Println(tokenFile) + accessToken, err := os.ReadFile(tokenFile) if err != nil { log.Fatalf("machine_1_1_callbackIsCalled: failed reading token file: %v", err) } @@ -57,16 +63,14 @@ func machine_1_1_callbackIsCalled() { }, nil } - client, err := mongo.Connect( - context.Background(), - options.Client().ApplyURI(uriSingle)) + client, err := mongo.Connect(context.Background(), opts) if err != nil { log.Fatalf("Error connecting client: %v", err) } coll := client.Database("test").Collection("test") - res := coll.FindOne(context.Background(), nil) + res := coll.FindOne(context.Background(), bson.D{}) if res == nil || res.Err() != nil { log.Fatalf("machine_1_1_callbackIsCalled: failed executing FindOne: %v", err) } diff --git a/mongo/collection.go b/mongo/collection.go index dbe238a9e3..8a0a054d5e 100644 --- a/mongo/collection.go +++ b/mongo/collection.go @@ -1227,6 +1227,7 @@ func (coll *Collection) find( f, err := marshal(filter, coll.bsonOpts, coll.registry) if err != nil { + fmt.Println(err) return nil, err } diff --git a/x/mongo/driver/auth/auth.go b/x/mongo/driver/auth/auth.go index bbfdcfe935..c9e49756f2 100644 --- a/x/mongo/driver/auth/auth.go +++ b/x/mongo/driver/auth/auth.go @@ -101,12 +101,17 @@ func (ah *authHandshaker) GetHandshakeInformation(ctx context.Context, addr addr return driver.HandshakeInformation{}, newAuthError("failed to create conversation", err) } - firstMsg, err := ah.conversation.FirstMessage() - if err != nil { - return driver.HandshakeInformation{}, newAuthError("failed to create speculative authentication message", err) + // It is possible for the speculative conversation to be nil even without error if the authenticator + // cannot perform speculative authentication. An example of this is MONGODB-OIDC when there is + // no AccessToken in the cache. + if ah.conversation != nil { + firstMsg, err := ah.conversation.FirstMessage() + if err != nil { + return driver.HandshakeInformation{}, newAuthError("failed to create speculative authentication message", err) + } + + op = op.SpeculativeAuthenticate(firstMsg) } - - op = op.SpeculativeAuthenticate(firstMsg) } } diff --git a/x/mongo/driver/auth/oidc.go b/x/mongo/driver/auth/oidc.go index e65e966500..7d968b642a 100644 --- a/x/mongo/driver/auth/oidc.go +++ b/x/mongo/driver/auth/oidc.go @@ -210,6 +210,7 @@ func (oa *OIDCAuthenticator) Reauth(ctx context.Context) error { // Auth authenticates the connection. func (oa *OIDCAuthenticator) Auth(ctx context.Context, cfg *Config) error { + fmt.Println("OIDC Auth!!!") // the Mutex must be held during the entire Auth call so that multiple racing attempts // to authenticate will not result in multiple callbacks. The losers on the Mutex will // retrieve the access token from the Authenticator cache. diff --git a/x/mongo/driver/connstring/connstring.go b/x/mongo/driver/connstring/connstring.go index 686458e292..8562214a74 100644 --- a/x/mongo/driver/connstring/connstring.go +++ b/x/mongo/driver/connstring/connstring.go @@ -302,6 +302,13 @@ func (u *ConnString) setDefaultAuthParams(dbName string) error { u.AuthSource = "admin" } } + case "mongodb-oidc": + if u.AuthSource == "" { + u.AuthSource = dbName + if u.AuthSource == "" { + u.AuthSource = "$external" + } + } case "": // Only set auth source if there is a request for authentication via non-empty credentials. if u.AuthSource == "" && (u.AuthMechanismProperties != nil || u.Username != "" || u.PasswordSet) { @@ -781,6 +788,23 @@ func (u *ConnString) validateAuth() error { if u.AuthMechanismProperties != nil { return fmt.Errorf("SCRAM-SHA-256 cannot have mechanism properties") } + case "mongodb-oidc": + if u.Password != "" { + return fmt.Errorf("password cannot be specified for MONGODB-OIDC") + } + if u.AuthMechanismProperties != nil { + if env, ok := u.AuthMechanismProperties["ENVIRONMENT"]; ok { + switch strings.ToLower(env) { + case "azure": + fallthrough + case "gcp": + if _, ok := u.AuthMechanismProperties["DOMAIN"]; !ok { + return fmt.Errorf("DOMAIN must be specified for %s environment", env) + } + } + } + } + case "": if u.UsernameSet && u.Username == "" { return fmt.Errorf("username required if URI contains user info") From c1373999147930c47c19697b2f3064e15d08f929 Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Tue, 18 Jun 2024 21:24:50 -0400 Subject: [PATCH 37/66] GODRIVER-2911: Checkpoint --- cmd/testoidcauth/main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/testoidcauth/main.go b/cmd/testoidcauth/main.go index 8c5173e2c3..68989114a4 100644 --- a/cmd/testoidcauth/main.go +++ b/cmd/testoidcauth/main.go @@ -24,7 +24,7 @@ import ( var uriAdmin = os.Getenv("MONGODB_URI") var uriSingle = os.Getenv("MONGODB_URI_SINGLE") var uriMulti = os.Getenv("MONGODB_URI_MULTI") -var oidcTokenDir = path.Join(os.Getenv("OIDC_TOKEN_DIR"), "tmp", "tokens") +var oidcTokenDir = os.Getenv("OIDC_TOKEN_DIR") var oidcDomain = os.Getenv("OIDC_DOMAIN") func explicitUser(user string) string { From 1be9498cbc796cfa0666f851407c3ce85a9e294a Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Wed, 19 Jun 2024 12:45:08 -0400 Subject: [PATCH 38/66] GODRIVER-2911: OIDC working --- .evergreen/config.yml | 2 +- cmd/testoidcauth/main.go | 6 ++++-- x/mongo/driver/auth/oidc.go | 16 ++++++++-------- x/mongo/driver/auth/sasl.go | 11 +++++++++-- 4 files changed, 22 insertions(+), 13 deletions(-) diff --git a/.evergreen/config.yml b/.evergreen/config.yml index b9ba851124..8cbfcb6c3d 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -2611,7 +2611,7 @@ buildvariants: - name: testoidc-variant display_name: "OIDC" run_on: - - rhel8.7-small + - ubuntu2204-large expansions: GO_DIST: "/opt/golang/go1.20" tasks: diff --git a/cmd/testoidcauth/main.go b/cmd/testoidcauth/main.go index 68989114a4..d5f62f36a9 100644 --- a/cmd/testoidcauth/main.go +++ b/cmd/testoidcauth/main.go @@ -70,10 +70,12 @@ func machine_1_1_callbackIsCalled() { coll := client.Database("test").Collection("test") - res := coll.FindOne(context.Background(), bson.D{}) - if res == nil || res.Err() != nil { + _, err = coll.Find(context.Background(), bson.D{}) + if err != nil { log.Fatalf("machine_1_1_callbackIsCalled: failed executing FindOne: %v", err) } + countMutex.Lock() + defer countMutex.Unlock() if callbackCount != 1 { log.Fatalf("machine_1_1_callbackIsCalled: expected callback count to be 1, got %d", callbackCount) } diff --git a/x/mongo/driver/auth/oidc.go b/x/mongo/driver/auth/oidc.go index 7d968b642a..12d4c914bf 100644 --- a/x/mongo/driver/auth/oidc.go +++ b/x/mongo/driver/auth/oidc.go @@ -70,6 +70,7 @@ type OIDCAuthenticator struct { OIDCMachineCallback OIDCCallback OIDCHumanCallback OIDCCallback + userName string cfg *Config accessToken string refreshToken *string @@ -79,6 +80,7 @@ type OIDCAuthenticator struct { func newOIDCAuthenticator(cred *Cred) (Authenticator, error) { oa := &OIDCAuthenticator{ + userName: cred.Username, AuthMechanismProperties: cred.Props, OIDCMachineCallback: cred.OIDCMachineCallback, OIDCHumanCallback: cred.OIDCHumanCallback, @@ -87,6 +89,7 @@ func newOIDCAuthenticator(cred *Cred) (Authenticator, error) { } type oidcOneStep struct { + userName string accessToken string } @@ -210,7 +213,6 @@ func (oa *OIDCAuthenticator) Reauth(ctx context.Context) error { // Auth authenticates the connection. func (oa *OIDCAuthenticator) Auth(ctx context.Context, cfg *Config) error { - fmt.Println("OIDC Auth!!!") // the Mutex must be held during the entire Auth call so that multiple racing attempts // to authenticate will not result in multiple callbacks. The losers on the Mutex will // retrieve the access token from the Authenticator cache. @@ -225,6 +227,7 @@ func (oa *OIDCAuthenticator) Auth(ctx context.Context, cfg *Config) error { if oa.accessToken != "" { err = ConductSaslConversation(ctx, cfg, "$external", &oidcOneStep{ + userName: oa.userName, accessToken: oa.accessToken, }) if err == nil { @@ -274,13 +277,10 @@ func (oa *OIDCAuthenticator) doAuthMachine(ctx context.Context, cfg *Config, mac if err != nil { return err } - err = ConductSaslConversation(ctx, cfg, "$external", &oidcOneStep{ - accessToken: accessToken, - }) - if err == nil { - return nil - } - return nil + return runSaslConversation(ctx, + cfg, + newSaslConversation(&oidcOneStep{accessToken: accessToken}, "$external", false), + ) } // CreateSpeculativeConversation creates a speculative conversation for SCRAM authentication. diff --git a/x/mongo/driver/auth/sasl.go b/x/mongo/driver/auth/sasl.go index 2a84b53a64..18a6fb3f42 100644 --- a/x/mongo/driver/auth/sasl.go +++ b/x/mongo/driver/auth/sasl.go @@ -105,6 +105,7 @@ func (sc *saslConversation) Finish(ctx context.Context, cfg *Config, firstRespon fullErr := fmt.Errorf("unmarshal error: %w", err) return newError(fullErr, sc.mechanism) } + fmt.Println("resp", saslResp) cid := saslResp.ConversationID var payload []byte @@ -152,17 +153,23 @@ func (sc *saslConversation) Finish(ctx context.Context, cfg *Config, firstRespon } } -// ConductSaslConversation runs a full SASL conversation to authenticate the given connection. +// ConductSaslConversation runs a full SASL conversation to authenticate the given connection, given +// sasl arguments. func ConductSaslConversation(ctx context.Context, cfg *Config, authSource string, client SaslClient) error { // Create a non-speculative SASL conversation. conversation := newSaslConversation(client, authSource, false) + return runSaslConversation(ctx, cfg, conversation) +} +// runSaslConversation runs a SASL conversation to authenticate the given connection, given a +// pre-built saslConversation. +func runSaslConversation(ctx context.Context, cfg *Config, conversation *saslConversation) error { saslStartDoc, err := conversation.FirstMessage() if err != nil { return newError(err, conversation.mechanism) } saslStartCmd := operation.NewCommand(saslStartDoc). - Database(authSource). + Database(conversation.source). Deployment(driver.SingleConnectionDeployment{cfg.Connection}). ClusterClock(cfg.ClusterClock). ServerAPI(cfg.ServerAPI) From 8542f76a9ddb1ad247f99d1806266d7d9f177d4c Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Wed, 19 Jun 2024 12:54:42 -0400 Subject: [PATCH 39/66] GODRIVER-2911: add machine_1_2 --- cmd/testoidcauth/main.go | 70 +++++++++++++++++++++++++++++++++------- 1 file changed, 59 insertions(+), 11 deletions(-) diff --git a/cmd/testoidcauth/main.go b/cmd/testoidcauth/main.go index d5f62f36a9..cff72df265 100644 --- a/cmd/testoidcauth/main.go +++ b/cmd/testoidcauth/main.go @@ -35,6 +35,17 @@ func tokenFile(user string) string { return path.Join(oidcTokenDir, user) } +func connectWithMachineCB(uri string, cb driver.OIDCCallback) *mongo.Client { + opts := options.Client().ApplyURI(uri) + + opts.Auth.OIDCMachineCallback = cb + client, err := mongo.Connect(context.Background(), opts) + if err != nil { + log.Fatalf("Error connecting client: %v", err) + } + return client +} + func main() { machine_1_1_callbackIsCalled() } @@ -43,9 +54,7 @@ func machine_1_1_callbackIsCalled() { callbackCount := 0 countMutex := sync.Mutex{} - opts := options.Client().ApplyURI(uriSingle) - - opts.Auth.OIDCMachineCallback = func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { + client := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { countMutex.Lock() defer countMutex.Unlock() callbackCount++ @@ -61,22 +70,61 @@ func machine_1_1_callbackIsCalled() { ExpiresAt: &t, RefreshToken: nil, }, nil - } + }) - client, err := mongo.Connect(context.Background(), opts) + coll := client.Database("test").Collection("test") + + _, err := coll.Find(context.Background(), bson.D{}) if err != nil { - log.Fatalf("Error connecting client: %v", err) + log.Fatalf("machine_1_1: failed executing FindOne: %v", err) + } + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 1 { + log.Fatalf("machine_1_1: expected callback count to be 1, got %d", callbackCount) } +} - coll := client.Database("test").Collection("test") +func machine_1_2_callbackIsCalledOnlyOneForMultipleConnections() { + callbackCount := 0 + countMutex := sync.Mutex{} - _, err = coll.Find(context.Background(), bson.D{}) - if err != nil { - log.Fatalf("machine_1_1_callbackIsCalled: failed executing FindOne: %v", err) + client := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + tokenFile := tokenFile("test_user1") + fmt.Println(tokenFile) + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + log.Fatalf("machine_1_2: failed reading token file: %v", err) + } + return &driver.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }) + + var wg sync.WaitGroup + + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + coll := client.Database("test").Collection("test") + _, err := coll.Find(context.Background(), bson.D{}) + if err != nil { + log.Fatalf("machine_1_2: failed executing FindOne: %v", err) + } + }() } + + wg.Wait() countMutex.Lock() defer countMutex.Unlock() if callbackCount != 1 { - log.Fatalf("machine_1_1_callbackIsCalled: expected callback count to be 1, got %d", callbackCount) + log.Fatalf("machine_1_2: expected callback count to be 1, got %d", callbackCount) } } From 286525f8b41bce96347585f82a69117e594991db Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Wed, 19 Jun 2024 12:55:21 -0400 Subject: [PATCH 40/66] GODRIVER-2911: add machine_1_2, actually helps to call it --- cmd/testoidcauth/main.go | 1 + 1 file changed, 1 insertion(+) diff --git a/cmd/testoidcauth/main.go b/cmd/testoidcauth/main.go index cff72df265..7ba5d49c1e 100644 --- a/cmd/testoidcauth/main.go +++ b/cmd/testoidcauth/main.go @@ -48,6 +48,7 @@ func connectWithMachineCB(uri string, cb driver.OIDCCallback) *mongo.Client { func main() { machine_1_1_callbackIsCalled() + machine_1_2_callbackIsCalledOnlyOneForMultipleConnections() } func machine_1_1_callbackIsCalled() { From 4013ccbfdd3241e198037322f31b6fb9f4dd63cf Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Wed, 19 Jun 2024 13:15:19 -0400 Subject: [PATCH 41/66] GODRIVER-2911: Remove unneeded debugging --- cmd/testoidcauth/main.go | 2 -- x/mongo/driver/auth/sasl.go | 1 - 2 files changed, 3 deletions(-) diff --git a/cmd/testoidcauth/main.go b/cmd/testoidcauth/main.go index 7ba5d49c1e..a3806c8222 100644 --- a/cmd/testoidcauth/main.go +++ b/cmd/testoidcauth/main.go @@ -61,7 +61,6 @@ func machine_1_1_callbackIsCalled() { callbackCount++ t := time.Now().Add(time.Hour) tokenFile := tokenFile("test_user1") - fmt.Println(tokenFile) accessToken, err := os.ReadFile(tokenFile) if err != nil { log.Fatalf("machine_1_1_callbackIsCalled: failed reading token file: %v", err) @@ -96,7 +95,6 @@ func machine_1_2_callbackIsCalledOnlyOneForMultipleConnections() { callbackCount++ t := time.Now().Add(time.Hour) tokenFile := tokenFile("test_user1") - fmt.Println(tokenFile) accessToken, err := os.ReadFile(tokenFile) if err != nil { log.Fatalf("machine_1_2: failed reading token file: %v", err) diff --git a/x/mongo/driver/auth/sasl.go b/x/mongo/driver/auth/sasl.go index 18a6fb3f42..2b9fe386d7 100644 --- a/x/mongo/driver/auth/sasl.go +++ b/x/mongo/driver/auth/sasl.go @@ -105,7 +105,6 @@ func (sc *saslConversation) Finish(ctx context.Context, cfg *Config, firstRespon fullErr := fmt.Errorf("unmarshal error: %w", err) return newError(fullErr, sc.mechanism) } - fmt.Println("resp", saslResp) cid := saslResp.ConversationID var payload []byte From 83ffaa7241673cc3c160708284b04206bcf038d1 Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Thu, 20 Jun 2024 12:21:22 -0400 Subject: [PATCH 42/66] GODRIVER-2911: Add more tests --- cmd/testoidcauth/main.go | 188 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 177 insertions(+), 11 deletions(-) diff --git a/cmd/testoidcauth/main.go b/cmd/testoidcauth/main.go index a3806c8222..8a7cd21c45 100644 --- a/cmd/testoidcauth/main.go +++ b/cmd/testoidcauth/main.go @@ -41,18 +41,47 @@ func connectWithMachineCB(uri string, cb driver.OIDCCallback) *mongo.Client { opts.Auth.OIDCMachineCallback = cb client, err := mongo.Connect(context.Background(), opts) if err != nil { - log.Fatalf("Error connecting client: %v", err) + fmt.Printf("Error connecting client: %v", err) + } + return client +} + +func connectWithMachineCBAndProperties(uri string, cb driver.OIDCCallback, props map[string]string) *mongo.Client { + opts := options.Client().ApplyURI(uri) + + opts.Auth.OIDCMachineCallback = cb + opts.Auth.AuthMechanismProperties = props + client, err := mongo.Connect(context.Background(), opts) + if err != nil { + fmt.Printf("Error connecting client: %v", err) } return client } func main() { - machine_1_1_callbackIsCalled() - machine_1_2_callbackIsCalledOnlyOneForMultipleConnections() + hasError := false + aux := func(test_name string, f func() bool) { + fmt.Printf("%s...", test_name) + testResult := f() + if testResult { + fmt.Println("...Failed") + } else { + fmt.Println("...Ok") + } + hasError = hasError || testResult + } + aux("machine_1_1_callbackIsCalled", machine_1_1_callbackIsCalled) + aux("machine_1_2_callbackIsCalledOnlyOneForMultipleConnections", machine_1_2_callbackIsCalledOnlyOneForMultipleConnections) + aux("machine_2_1_validCallbackInputs", machine_2_1_validCallbackInputs) + aux("machine_2_3_oidcCallbackReturnMissingData", machine_2_3_oidcCallbackReturnMissingData) + if hasError { + log.Fatal("One or more tests failed") + } } -func machine_1_1_callbackIsCalled() { +func machine_1_1_callbackIsCalled() bool { callbackCount := 0 + callbackFailed := false countMutex := sync.Mutex{} client := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { @@ -63,7 +92,8 @@ func machine_1_1_callbackIsCalled() { tokenFile := tokenFile("test_user1") accessToken, err := os.ReadFile(tokenFile) if err != nil { - log.Fatalf("machine_1_1_callbackIsCalled: failed reading token file: %v", err) + fmt.Printf("machine_1_1: failed reading token file: %v\n", err) + callbackFailed = true } return &driver.OIDCCredential{ AccessToken: string(accessToken), @@ -76,17 +106,21 @@ func machine_1_1_callbackIsCalled() { _, err := coll.Find(context.Background(), bson.D{}) if err != nil { - log.Fatalf("machine_1_1: failed executing FindOne: %v", err) + fmt.Printf("machine_1_1: failed executing Find: %v", err) + return true } countMutex.Lock() defer countMutex.Unlock() if callbackCount != 1 { - log.Fatalf("machine_1_1: expected callback count to be 1, got %d", callbackCount) + fmt.Printf("machine_1_1: expected callback count to be 1, got %d\n", callbackCount) + return true } + return callbackFailed } -func machine_1_2_callbackIsCalledOnlyOneForMultipleConnections() { +func machine_1_2_callbackIsCalledOnlyOneForMultipleConnections() bool { callbackCount := 0 + callbackFailed := false countMutex := sync.Mutex{} client := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { @@ -97,7 +131,8 @@ func machine_1_2_callbackIsCalledOnlyOneForMultipleConnections() { tokenFile := tokenFile("test_user1") accessToken, err := os.ReadFile(tokenFile) if err != nil { - log.Fatalf("machine_1_2: failed reading token file: %v", err) + fmt.Printf("machine_1_2: failed reading token file: %v\n", err) + callbackFailed = true } return &driver.OIDCCredential{ AccessToken: string(accessToken), @@ -108,6 +143,7 @@ func machine_1_2_callbackIsCalledOnlyOneForMultipleConnections() { var wg sync.WaitGroup + findFailed := false for i := 0; i < 10; i++ { wg.Add(1) go func() { @@ -115,7 +151,8 @@ func machine_1_2_callbackIsCalledOnlyOneForMultipleConnections() { coll := client.Database("test").Collection("test") _, err := coll.Find(context.Background(), bson.D{}) if err != nil { - log.Fatalf("machine_1_2: failed executing FindOne: %v", err) + fmt.Printf("machine_1_2: failed executing Find: %v\n", err) + findFailed = true } }() } @@ -124,6 +161,135 @@ func machine_1_2_callbackIsCalledOnlyOneForMultipleConnections() { countMutex.Lock() defer countMutex.Unlock() if callbackCount != 1 { - log.Fatalf("machine_1_2: expected callback count to be 1, got %d", callbackCount) + fmt.Printf("machine_1_2: expected callback count to be 1, got %d\n", callbackCount) + return true + } + return callbackFailed || findFailed +} + +func machine_2_1_validCallbackInputs() bool { + callbackCount := 0 + callbackFailed := false + countMutex := sync.Mutex{} + + client := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { + if args.RefreshToken != nil { + fmt.Printf("machine_2_1: expected RefreshToken to be nil, got %v\n", args.RefreshToken) + callbackFailed = true + } + if args.Timeout.Before(time.Now()) { + fmt.Printf("machine_2_1: expected timeout to be in the future, got %v\n", args.Timeout) + callbackFailed = true + } + if args.Version < 1 { + fmt.Printf("machine_2_1: expected Version to be at least 1, got %d\n", args.Version) + callbackFailed = true + } + if args.IDPInfo != nil { + fmt.Printf("machine_2_1: expected IdpID to be nil for Machine flow, got %v\n", args.IDPInfo) + callbackFailed = true + } + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + fmt.Printf("machine_2_1: failed reading token file: %v\n", err) + } + return &driver.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }) + + coll := client.Database("test").Collection("test") + + _, err := coll.Find(context.Background(), bson.D{}) + if err != nil { + fmt.Printf("machine_2_1: failed executing Find: %v", err) + return true + } + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 1 { + fmt.Printf("machine_2_1: expected callback count to be 1, got %d\n", callbackCount) + return true + } + return callbackFailed +} + +func machine_2_3_oidcCallbackReturnMissingData() bool { + callbackCount := 0 + countMutex := sync.Mutex{} + + client := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + return &driver.OIDCCredential{ + AccessToken: "", + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }) + + coll := client.Database("test").Collection("test") + + _, err := coll.Find(context.Background(), bson.D{}) + if err == nil { + fmt.Println("machine_2_3: should have failed to executed Find, but succeeded") + return true + } + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 1 { + fmt.Printf("machine_2_3: expected callback count to be 1, got %d\n", callbackCount) + return true + } + return true +} + +func machine_2_4_invalidClientConfigurationWithCallback() bool { + callbackCount := 0 + callbackFailed := false + countMutex := sync.Mutex{} + + client := connectWithMachineCBAndProperties(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + fmt.Printf("machine_2_4: failed reading token file: %v\n", err) + callbackFailed = true + } + return &driver.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }, + map[string]string{"ENVIRONMENT": "test"}, + ) + + coll := client.Database("test").Collection("test") + + _, err := coll.Find(context.Background(), bson.D{}) + if err == nil { + fmt.Println("machine_2_4: succeeded executing Find when it should fail") + return true + } + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 1 { + fmt.Printf("machine_2_4: expected callback count to be 1, got %d\n", callbackCount) + return true } + return callbackFailed } From f33dca7dc09934035b8b455b8d1334c1b924faf8 Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Thu, 20 Jun 2024 13:03:35 -0400 Subject: [PATCH 43/66] GODRIVER-2911: Updates --- cmd/testoidcauth/main.go | 41 +++++-------------------- x/mongo/driver/auth/oidc.go | 20 ++++++++++++ x/mongo/driver/connstring/connstring.go | 13 -------- 3 files changed, 28 insertions(+), 46 deletions(-) diff --git a/cmd/testoidcauth/main.go b/cmd/testoidcauth/main.go index 8a7cd21c45..b12acfc371 100644 --- a/cmd/testoidcauth/main.go +++ b/cmd/testoidcauth/main.go @@ -46,16 +46,12 @@ func connectWithMachineCB(uri string, cb driver.OIDCCallback) *mongo.Client { return client } -func connectWithMachineCBAndProperties(uri string, cb driver.OIDCCallback, props map[string]string) *mongo.Client { +func connectWithMachineCBAndProperties(uri string, cb driver.OIDCCallback, props map[string]string) (*mongo.Client, error) { opts := options.Client().ApplyURI(uri) opts.Auth.OIDCMachineCallback = cb opts.Auth.AuthMechanismProperties = props - client, err := mongo.Connect(context.Background(), opts) - if err != nil { - fmt.Printf("Error connecting client: %v", err) - } - return client + return mongo.Connect(context.Background(), opts) } func main() { @@ -68,12 +64,14 @@ func main() { } else { fmt.Println("...Ok") } + fmt.Println("hasError: ", hasError, "testResult: ", testResult) hasError = hasError || testResult } aux("machine_1_1_callbackIsCalled", machine_1_1_callbackIsCalled) aux("machine_1_2_callbackIsCalledOnlyOneForMultipleConnections", machine_1_2_callbackIsCalledOnlyOneForMultipleConnections) aux("machine_2_1_validCallbackInputs", machine_2_1_validCallbackInputs) aux("machine_2_3_oidcCallbackReturnMissingData", machine_2_3_oidcCallbackReturnMissingData) + aux("machine_2_4_invalidClientConfigurationWithCallback", machine_2_4_invalidClientConfigurationWithCallback) if hasError { log.Fatal("One or more tests failed") } @@ -254,42 +252,19 @@ func machine_2_3_oidcCallbackReturnMissingData() bool { } func machine_2_4_invalidClientConfigurationWithCallback() bool { - callbackCount := 0 - callbackFailed := false - countMutex := sync.Mutex{} - - client := connectWithMachineCBAndProperties(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { - countMutex.Lock() - defer countMutex.Unlock() - callbackCount++ + _, err := connectWithMachineCBAndProperties(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { t := time.Now().Add(time.Hour) - tokenFile := tokenFile("test_user1") - accessToken, err := os.ReadFile(tokenFile) - if err != nil { - fmt.Printf("machine_2_4: failed reading token file: %v\n", err) - callbackFailed = true - } return &driver.OIDCCredential{ - AccessToken: string(accessToken), + AccessToken: "", ExpiresAt: &t, RefreshToken: nil, }, nil }, map[string]string{"ENVIRONMENT": "test"}, ) - - coll := client.Database("test").Collection("test") - - _, err := coll.Find(context.Background(), bson.D{}) if err == nil { - fmt.Println("machine_2_4: succeeded executing Find when it should fail") + fmt.Println("machine_2_4: succeeded building client when it should fail") return true } - countMutex.Lock() - defer countMutex.Unlock() - if callbackCount != 1 { - fmt.Printf("machine_2_4: expected callback count to be 1, got %d\n", callbackCount) - return true - } - return callbackFailed + return false } diff --git a/x/mongo/driver/auth/oidc.go b/x/mongo/driver/auth/oidc.go index 12d4c914bf..16aa613101 100644 --- a/x/mongo/driver/auth/oidc.go +++ b/x/mongo/driver/auth/oidc.go @@ -9,6 +9,7 @@ package auth import ( "context" "fmt" + "strings" "sync" "time" @@ -23,6 +24,8 @@ const MongoDBOIDC = "MONGODB-OIDC" // const tokenResourceProp = "TOKEN_RESOURCE" const environmentProp = "ENVIRONMENT" +const resourceProp = "TOKEN_RESOURCE" + // GODRIVER-3249 OIDC: Handle all possible OIDC configuration errors //const allowedHostsProp = "ALLOWED_HOSTS" @@ -79,6 +82,23 @@ type OIDCAuthenticator struct { } func newOIDCAuthenticator(cred *Cred) (Authenticator, error) { + if cred.Props != nil { + if env, ok := cred.Props[environmentProp]; ok { + switch strings.ToLower(env) { + case "azure": + fallthrough + case "gcp": + if _, ok := cred.Props[resourceProp]; !ok { + return nil, fmt.Errorf("%s must be specified for %s %s", resourceProp, env, environmentProp) + } + fallthrough + case "test": + if cred.OIDCMachineCallback != nil || cred.OIDCHumanCallback != nil { + return nil, fmt.Errorf("OIDC callbacks are not allowed for %s %s", env, environmentProp) + } + } + } + } oa := &OIDCAuthenticator{ userName: cred.Username, AuthMechanismProperties: cred.Props, diff --git a/x/mongo/driver/connstring/connstring.go b/x/mongo/driver/connstring/connstring.go index 8562214a74..a8adafb8f8 100644 --- a/x/mongo/driver/connstring/connstring.go +++ b/x/mongo/driver/connstring/connstring.go @@ -792,19 +792,6 @@ func (u *ConnString) validateAuth() error { if u.Password != "" { return fmt.Errorf("password cannot be specified for MONGODB-OIDC") } - if u.AuthMechanismProperties != nil { - if env, ok := u.AuthMechanismProperties["ENVIRONMENT"]; ok { - switch strings.ToLower(env) { - case "azure": - fallthrough - case "gcp": - if _, ok := u.AuthMechanismProperties["DOMAIN"]; !ok { - return fmt.Errorf("DOMAIN must be specified for %s environment", env) - } - } - } - } - case "": if u.UsernameSet && u.Username == "" { return fmt.Errorf("username required if URI contains user info") From 3c00307d4f2ed6303eef0d6387a44753f59170cd Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Thu, 20 Jun 2024 13:52:43 -0400 Subject: [PATCH 44/66] GODRIVER-2911: Change to using errors --- cmd/testoidcauth/main.go | 120 ++++++++++++++++++------------------ x/mongo/driver/auth/oidc.go | 22 +++---- 2 files changed, 71 insertions(+), 71 deletions(-) diff --git a/cmd/testoidcauth/main.go b/cmd/testoidcauth/main.go index b12acfc371..5459dda410 100644 --- a/cmd/testoidcauth/main.go +++ b/cmd/testoidcauth/main.go @@ -35,15 +35,11 @@ func tokenFile(user string) string { return path.Join(oidcTokenDir, user) } -func connectWithMachineCB(uri string, cb driver.OIDCCallback) *mongo.Client { +func connectWithMachineCB(uri string, cb driver.OIDCCallback) (*mongo.Client, error) { opts := options.Client().ApplyURI(uri) opts.Auth.OIDCMachineCallback = cb - client, err := mongo.Connect(context.Background(), opts) - if err != nil { - fmt.Printf("Error connecting client: %v", err) - } - return client + return mongo.Connect(context.Background(), opts) } func connectWithMachineCBAndProperties(uri string, cb driver.OIDCCallback, props map[string]string) (*mongo.Client, error) { @@ -56,16 +52,16 @@ func connectWithMachineCBAndProperties(uri string, cb driver.OIDCCallback, props func main() { hasError := false - aux := func(test_name string, f func() bool) { + aux := func(test_name string, f func() error) { fmt.Printf("%s...", test_name) - testResult := f() - if testResult { + err := f() + if err != nil { + fmt.Println("Test Error: ", err) fmt.Println("...Failed") + hasError = true } else { fmt.Println("...Ok") } - fmt.Println("hasError: ", hasError, "testResult: ", testResult) - hasError = hasError || testResult } aux("machine_1_1_callbackIsCalled", machine_1_1_callbackIsCalled) aux("machine_1_2_callbackIsCalledOnlyOneForMultipleConnections", machine_1_2_callbackIsCalledOnlyOneForMultipleConnections) @@ -77,12 +73,12 @@ func main() { } } -func machine_1_1_callbackIsCalled() bool { +func machine_1_1_callbackIsCalled() error { callbackCount := 0 - callbackFailed := false + var callbackFailed error = nil countMutex := sync.Mutex{} - client := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { countMutex.Lock() defer countMutex.Unlock() callbackCount++ @@ -90,8 +86,7 @@ func machine_1_1_callbackIsCalled() bool { tokenFile := tokenFile("test_user1") accessToken, err := os.ReadFile(tokenFile) if err != nil { - fmt.Printf("machine_1_1: failed reading token file: %v\n", err) - callbackFailed = true + callbackFailed = fmt.Errorf("machine_1_1: failed reading token file: %v\n", err) } return &driver.OIDCCredential{ AccessToken: string(accessToken), @@ -100,28 +95,30 @@ func machine_1_1_callbackIsCalled() bool { }, nil }) + if err != nil { + return fmt.Errorf("machine_1_1: failed connecting client: %v", err) + } + coll := client.Database("test").Collection("test") - _, err := coll.Find(context.Background(), bson.D{}) + _, err = coll.Find(context.Background(), bson.D{}) if err != nil { - fmt.Printf("machine_1_1: failed executing Find: %v", err) - return true + return fmt.Errorf("machine_1_1: failed executing Find: %v", err) } countMutex.Lock() defer countMutex.Unlock() if callbackCount != 1 { - fmt.Printf("machine_1_1: expected callback count to be 1, got %d\n", callbackCount) - return true + return fmt.Errorf("machine_1_1: expected callback count to be 1, got %d\n", callbackCount) } return callbackFailed } -func machine_1_2_callbackIsCalledOnlyOneForMultipleConnections() bool { +func machine_1_2_callbackIsCalledOnlyOneForMultipleConnections() error { callbackCount := 0 - callbackFailed := false + var callbackFailed error = nil countMutex := sync.Mutex{} - client := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { countMutex.Lock() defer countMutex.Unlock() callbackCount++ @@ -129,8 +126,7 @@ func machine_1_2_callbackIsCalledOnlyOneForMultipleConnections() bool { tokenFile := tokenFile("test_user1") accessToken, err := os.ReadFile(tokenFile) if err != nil { - fmt.Printf("machine_1_2: failed reading token file: %v\n", err) - callbackFailed = true + callbackFailed = fmt.Errorf("machine_1_2: failed reading token file: %v\n", err) } return &driver.OIDCCredential{ AccessToken: string(accessToken), @@ -139,9 +135,13 @@ func machine_1_2_callbackIsCalledOnlyOneForMultipleConnections() bool { }, nil }) + if err != nil { + return fmt.Errorf("machine_1_2: failed connecting client: %v", err) + } + var wg sync.WaitGroup - findFailed := false + var findFailed error = nil for i := 0; i < 10; i++ { wg.Add(1) go func() { @@ -149,8 +149,7 @@ func machine_1_2_callbackIsCalledOnlyOneForMultipleConnections() bool { coll := client.Database("test").Collection("test") _, err := coll.Find(context.Background(), bson.D{}) if err != nil { - fmt.Printf("machine_1_2: failed executing Find: %v\n", err) - findFailed = true + findFailed = fmt.Errorf("machine_1_2: failed executing Find: %v\n", err) } }() } @@ -159,33 +158,31 @@ func machine_1_2_callbackIsCalledOnlyOneForMultipleConnections() bool { countMutex.Lock() defer countMutex.Unlock() if callbackCount != 1 { - fmt.Printf("machine_1_2: expected callback count to be 1, got %d\n", callbackCount) - return true + return fmt.Errorf("machine_1_2: expected callback count to be 1, got %d\n", callbackCount) + } + if callbackFailed != nil { + return callbackFailed } - return callbackFailed || findFailed + return findFailed } -func machine_2_1_validCallbackInputs() bool { +func machine_2_1_validCallbackInputs() error { callbackCount := 0 - callbackFailed := false + var callbackFailed error = nil countMutex := sync.Mutex{} - client := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { if args.RefreshToken != nil { - fmt.Printf("machine_2_1: expected RefreshToken to be nil, got %v\n", args.RefreshToken) - callbackFailed = true + callbackFailed = fmt.Errorf("machine_2_1: expected RefreshToken to be nil, got %v\n", args.RefreshToken) } if args.Timeout.Before(time.Now()) { - fmt.Printf("machine_2_1: expected timeout to be in the future, got %v\n", args.Timeout) - callbackFailed = true + callbackFailed = fmt.Errorf("machine_2_1: expected timeout to be in the future, got %v\n", args.Timeout) } if args.Version < 1 { - fmt.Printf("machine_2_1: expected Version to be at least 1, got %d\n", args.Version) - callbackFailed = true + callbackFailed = fmt.Errorf("machine_2_1: expected Version to be at least 1, got %d\n", args.Version) } if args.IDPInfo != nil { - fmt.Printf("machine_2_1: expected IdpID to be nil for Machine flow, got %v\n", args.IDPInfo) - callbackFailed = true + callbackFailed = fmt.Errorf("machine_2_1: expected IdpID to be nil for Machine flow, got %v\n", args.IDPInfo) } countMutex.Lock() defer countMutex.Unlock() @@ -203,27 +200,29 @@ func machine_2_1_validCallbackInputs() bool { }, nil }) + if err != nil { + return fmt.Errorf("machine_2_1: failed connecting client: %v", err) + } + coll := client.Database("test").Collection("test") - _, err := coll.Find(context.Background(), bson.D{}) + _, err = coll.Find(context.Background(), bson.D{}) if err != nil { - fmt.Printf("machine_2_1: failed executing Find: %v", err) - return true + return fmt.Errorf("machine_2_1: failed executing Find: %v", err) } countMutex.Lock() defer countMutex.Unlock() if callbackCount != 1 { - fmt.Printf("machine_2_1: expected callback count to be 1, got %d\n", callbackCount) - return true + return fmt.Errorf("machine_2_1: expected callback count to be 1, got %d\n", callbackCount) } return callbackFailed } -func machine_2_3_oidcCallbackReturnMissingData() bool { +func machine_2_3_oidcCallbackReturnMissingData() error { callbackCount := 0 countMutex := sync.Mutex{} - client := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { countMutex.Lock() defer countMutex.Unlock() callbackCount++ @@ -235,23 +234,25 @@ func machine_2_3_oidcCallbackReturnMissingData() bool { }, nil }) + if err != nil { + return fmt.Errorf("machine_2_3: failed connecting client: %v\n", err) + } + coll := client.Database("test").Collection("test") - _, err := coll.Find(context.Background(), bson.D{}) + _, err = coll.Find(context.Background(), bson.D{}) if err == nil { - fmt.Println("machine_2_3: should have failed to executed Find, but succeeded") - return true + return fmt.Errorf("machine_2_3: should have failed to executed Find, but succeeded") } countMutex.Lock() defer countMutex.Unlock() if callbackCount != 1 { - fmt.Printf("machine_2_3: expected callback count to be 1, got %d\n", callbackCount) - return true + return fmt.Errorf("machine_2_3: expected callback count to be 1, got %d\n", callbackCount) } - return true + return nil } -func machine_2_4_invalidClientConfigurationWithCallback() bool { +func machine_2_4_invalidClientConfigurationWithCallback() error { _, err := connectWithMachineCBAndProperties(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { t := time.Now().Add(time.Hour) return &driver.OIDCCredential{ @@ -263,8 +264,7 @@ func machine_2_4_invalidClientConfigurationWithCallback() bool { map[string]string{"ENVIRONMENT": "test"}, ) if err == nil { - fmt.Println("machine_2_4: succeeded building client when it should fail") - return true + return fmt.Errorf("machine_2_4: succeeded building client when it should fail") } - return false + return nil } diff --git a/x/mongo/driver/auth/oidc.go b/x/mongo/driver/auth/oidc.go index 16aa613101..5138ae4402 100644 --- a/x/mongo/driver/auth/oidc.go +++ b/x/mongo/driver/auth/oidc.go @@ -31,6 +31,7 @@ const resourceProp = "TOKEN_RESOURCE" const azureEnvironmentValue = "azure" const gcpEnvironmentValue = "gcp" +const testEnvironmentValue = "test" const apiVersion = 1 const invalidateSleepTimeout = 100 * time.Millisecond @@ -82,19 +83,22 @@ type OIDCAuthenticator struct { } func newOIDCAuthenticator(cred *Cred) (Authenticator, error) { + if cred.Password != "" { + return nil, fmt.Errorf("password cannot be specified for %q", MongoDBOIDC) + } if cred.Props != nil { if env, ok := cred.Props[environmentProp]; ok { switch strings.ToLower(env) { - case "azure": + case azureEnvironmentValue: fallthrough - case "gcp": + case gcpEnvironmentValue: if _, ok := cred.Props[resourceProp]; !ok { - return nil, fmt.Errorf("%s must be specified for %s %s", resourceProp, env, environmentProp) + return nil, fmt.Errorf("%q must be specified for %q %q", resourceProp, env, environmentProp) } fallthrough - case "test": + case testEnvironmentValue: if cred.OIDCMachineCallback != nil || cred.OIDCHumanCallback != nil { - return nil, fmt.Errorf("OIDC callbacks are not allowed for %s %s", env, environmentProp) + return nil, fmt.Errorf("OIDC callbacks are not allowed for %q %q", env, environmentProp) } } } @@ -146,15 +150,11 @@ func (oa *OIDCAuthenticator) providerCallback() (OIDCCallback, error) { return nil, nil } - switch env { + //switch env { // TODO GODRIVER-2728: Automatic token acquisition for Azure Identity Provider // TODO GODRIVER-2806: Automatic token acquisition for GCP Identity Provider // This is here just to pass the linter, it will be fixed in one of the above tickets. - case azureEnvironmentValue, gcpEnvironmentValue: - return func(ctx context.Context, args *OIDCArgs) (*OIDCCredential, error) { - return nil, fmt.Errorf("automatic token acquisition for %q not implemented yet", env) - }, fmt.Errorf("automatic token acquisition for %q not implemented yet", env) - } + //} return nil, fmt.Errorf("%q %q not supported for MONGODB-OIDC", environmentProp, env) } From 590a3c8bfc28bef326d55186e46c0155f0ce742e Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Thu, 20 Jun 2024 14:13:34 -0400 Subject: [PATCH 45/66] GODRIVER-2911: Add more tests that do not require fail points --- cmd/testoidcauth/main.go | 80 +++++++++++++++++++++++++++++++++++++ mongo/client.go | 5 +++ x/mongo/driver/auth/oidc.go | 8 ++++ 3 files changed, 93 insertions(+) diff --git a/cmd/testoidcauth/main.go b/cmd/testoidcauth/main.go index 5459dda410..71f67bf41b 100644 --- a/cmd/testoidcauth/main.go +++ b/cmd/testoidcauth/main.go @@ -19,6 +19,7 @@ import ( "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/auth" ) var uriAdmin = os.Getenv("MONGODB_URI") @@ -68,6 +69,7 @@ func main() { aux("machine_2_1_validCallbackInputs", machine_2_1_validCallbackInputs) aux("machine_2_3_oidcCallbackReturnMissingData", machine_2_3_oidcCallbackReturnMissingData) aux("machine_2_4_invalidClientConfigurationWithCallback", machine_2_4_invalidClientConfigurationWithCallback) + aux("machine_3_1_failureWithCachedTokensFetchANewTokenAndRetryAuth", machine_3_1_failureWithCachedTokensFetchANewTokenAndRetryAuth) if hasError { log.Fatal("One or more tests failed") } @@ -268,3 +270,81 @@ func machine_2_4_invalidClientConfigurationWithCallback() error { } return nil } + +func machine_3_1_failureWithCachedTokensFetchANewTokenAndRetryAuth() error { + callbackCount := 0 + var callbackFailed error = nil + countMutex := sync.Mutex{} + + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("machine_3_1: failed reading token file: %v\n", err) + } + return &driver.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }) + + if err != nil { + return fmt.Errorf("machine_3_1: failed connecting client: %v", err) + } + + // Poison the cache with a random token + client.GetAuthenticator().(*auth.OIDCAuthenticator).SetAccessToken("some random happy sunshine string") + + coll := client.Database("test").Collection("test") + + _, err = coll.Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("machine_3_1: failed executing Find: %v", err) + } + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 1 { + return fmt.Errorf("machine_3_1: expected callback count to be 1, got %d\n", callbackCount) + } + return callbackFailed +} + +func machine_3_2_authFailuresWithoutCachedTokensReturnsAnError() error { + callbackCount := 0 + var callbackFailed error = nil + countMutex := sync.Mutex{} + + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + return &driver.OIDCCredential{ + AccessToken: "this is a bad, bad token", + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }) + + if err != nil { + return fmt.Errorf("machine_3_2: failed connecting client: %v", err) + } + + coll := client.Database("test").Collection("test") + + _, err = coll.Find(context.Background(), bson.D{}) + if err == nil { + return fmt.Errorf("machine_3_2: failed succeeded Find when it should fail") + } + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 1 { + return fmt.Errorf("machine_3_2: expected callback count to be 1, got %d\n", callbackCount) + } + return callbackFailed +} diff --git a/mongo/client.go b/mongo/client.go index 082554adbb..fec2c45287 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -83,6 +83,11 @@ type Client struct { authenticator driver.Authenticator } +// GetAuthenticator returns the authenticator for the client, used for testing purposes. +func (c *Client) GetAuthenticator() driver.Authenticator { + return c.authenticator +} + // Connect creates a new Client and then initializes it using the Connect method. This is equivalent to calling // NewClient followed by Client.Connect. // diff --git a/x/mongo/driver/auth/oidc.go b/x/mongo/driver/auth/oidc.go index 5138ae4402..6414d238ef 100644 --- a/x/mongo/driver/auth/oidc.go +++ b/x/mongo/driver/auth/oidc.go @@ -82,6 +82,14 @@ type OIDCAuthenticator struct { tokenGenID uint64 } +// SetAccessToken allows for manually setting the access token for the OIDCAuthenticator, this is +// only for testing purposes. +func (oa *OIDCAuthenticator) SetAccessToken(accessToken string) { + oa.mu.Lock() + defer oa.mu.Unlock() + oa.accessToken = accessToken +} + func newOIDCAuthenticator(cred *Cred) (Authenticator, error) { if cred.Password != "" { return nil, fmt.Errorf("password cannot be specified for %q", MongoDBOIDC) From e88ebe7e98b9fbce7a1224cb3e898fed25dcb4f6 Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Thu, 20 Jun 2024 15:26:35 -0400 Subject: [PATCH 46/66] GODRIVER-2911: See if it fails with 10 tries --- cmd/testoidcauth/main.go | 96 ++++++++++++++++++++++++++++++++++++- x/mongo/driver/auth/oidc.go | 5 +- 2 files changed, 97 insertions(+), 4 deletions(-) diff --git a/cmd/testoidcauth/main.go b/cmd/testoidcauth/main.go index 71f67bf41b..e69a3a3ef6 100644 --- a/cmd/testoidcauth/main.go +++ b/cmd/testoidcauth/main.go @@ -17,6 +17,7 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/integration/mtest" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/mongo/driver/auth" @@ -36,6 +37,10 @@ func tokenFile(user string) string { return path.Join(oidcTokenDir, user) } +func connectAdminClinet() (*mongo.Client, error) { + return mongo.Connect(context.Background(), options.Client().ApplyURI(uriAdmin)) +} + func connectWithMachineCB(uri string, cb driver.OIDCCallback) (*mongo.Client, error) { opts := options.Client().ApplyURI(uri) @@ -70,6 +75,8 @@ func main() { aux("machine_2_3_oidcCallbackReturnMissingData", machine_2_3_oidcCallbackReturnMissingData) aux("machine_2_4_invalidClientConfigurationWithCallback", machine_2_4_invalidClientConfigurationWithCallback) aux("machine_3_1_failureWithCachedTokensFetchANewTokenAndRetryAuth", machine_3_1_failureWithCachedTokensFetchANewTokenAndRetryAuth) + aux("machine_3_2_authFailuresWithoutCachedTokensReturnsAnError", machine_3_2_authFailuresWithoutCachedTokensReturnsAnError) + aux("machine_3_3_UnexpectedErrorCodeDoesNotClearTheCache", machine_3_3_UnexpectedErrorCodeDoesNotClearTheCache) if hasError { log.Fatal("One or more tests failed") } @@ -97,6 +104,8 @@ func machine_1_1_callbackIsCalled() error { }, nil }) + defer client.Disconnect(context.Background()) + if err != nil { return fmt.Errorf("machine_1_1: failed connecting client: %v", err) } @@ -137,6 +146,8 @@ func machine_1_2_callbackIsCalledOnlyOneForMultipleConnections() error { }, nil }) + defer client.Disconnect(context.Background()) + if err != nil { return fmt.Errorf("machine_1_2: failed connecting client: %v", err) } @@ -202,6 +213,8 @@ func machine_2_1_validCallbackInputs() error { }, nil }) + defer client.Disconnect(context.Background()) + if err != nil { return fmt.Errorf("machine_2_1: failed connecting client: %v", err) } @@ -236,6 +249,8 @@ func machine_2_3_oidcCallbackReturnMissingData() error { }, nil }) + defer client.Disconnect(context.Background()) + if err != nil { return fmt.Errorf("machine_2_3: failed connecting client: %v\n", err) } @@ -293,6 +308,8 @@ func machine_3_1_failureWithCachedTokensFetchANewTokenAndRetryAuth() error { }, nil }) + defer client.Disconnect(context.Background()) + if err != nil { return fmt.Errorf("machine_3_1: failed connecting client: %v", err) } @@ -331,15 +348,16 @@ func machine_3_2_authFailuresWithoutCachedTokensReturnsAnError() error { }, nil }) + defer client.Disconnect(context.Background()) + if err != nil { return fmt.Errorf("machine_3_2: failed connecting client: %v", err) } coll := client.Database("test").Collection("test") - _, err = coll.Find(context.Background(), bson.D{}) if err == nil { - return fmt.Errorf("machine_3_2: failed succeeded Find when it should fail") + return fmt.Errorf("machine_3_2: Find ucceeded when it should fail") } countMutex.Lock() defer countMutex.Unlock() @@ -348,3 +366,77 @@ func machine_3_2_authFailuresWithoutCachedTokensReturnsAnError() error { } return callbackFailed } + +func machine_3_3_UnexpectedErrorCodeDoesNotClearTheCache() error { + callbackCount := 0 + var callbackFailed error = nil + countMutex := sync.Mutex{} + + adminClient, err := connectAdminClinet() + + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("machine_3_3: failed reading token file: %v\n", err) + } + return &driver.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }) + + defer client.Disconnect(context.Background()) + + if err != nil { + return fmt.Errorf("machine_3_3: failed connecting client: %v", err) + } + + coll := client.Database("test").Collection("test") + + err = mtest.SetFailPoint( + mtest.FailPoint{ + ConfigureFailPoint: "failCommand", + Mode: mtest.FailPointMode{ + Times: 10, + }, + Data: mtest.FailPointData{ + FailCommands: []string{"saslStart"}, + AppName: "go-oidc", + ErrorCode: 20, + }, + }, + adminClient, + ) + + if err != nil { + return fmt.Errorf("machine_3_3: failed setting failpoint: %v", err) + } + + _, err = coll.Find(context.Background(), bson.D{}) + if err == nil { + return fmt.Errorf("machine_3_3: Find succeeded when it should fail") + } + + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 1 { + return fmt.Errorf("machine_3_3: expected callback count to be 1, got %d\n", callbackCount) + } + + _, err = coll.Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("machine_3_3: failed executing Find: %v", err) + } + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 1 { + return fmt.Errorf("machine_3_3: expected callback count to be 1, got %d\n", callbackCount) + } + return callbackFailed +} diff --git a/x/mongo/driver/auth/oidc.go b/x/mongo/driver/auth/oidc.go index 6414d238ef..b2c271a6d0 100644 --- a/x/mongo/driver/auth/oidc.go +++ b/x/mongo/driver/auth/oidc.go @@ -219,9 +219,8 @@ func (oa *OIDCAuthenticator) getAccessToken( // tokenGenID of the OIDCAuthenticator. It should never actually be greater than, but only equal, // but this is a safety check, since extra invalidation is only a performance impact, not a // correctness impact. +// This must only be called with the lock held func (oa *OIDCAuthenticator) invalidateAccessToken(force bool) { - oa.mu.Lock() - defer oa.mu.Unlock() tokenGenID := oa.cfg.Connection.OIDCTokenGenID() if force || tokenGenID >= oa.tokenGenID { oa.accessToken = "" @@ -232,7 +231,9 @@ func (oa *OIDCAuthenticator) invalidateAccessToken(force bool) { // Reauth reauthenticates the connection when the server returns a 391 code. Reauth is part of the // driver.Authenticator interface. func (oa *OIDCAuthenticator) Reauth(ctx context.Context) error { + oa.mu.Lock() oa.invalidateAccessToken(true) + oa.mu.Unlock() // it should be impossible to get a Reauth when an Auth has never occurred, // so we assume cfg was properly set. There is nothing to enforce this, however, // other than the current driver code flow. If cfg is nil, Auth will return an error. From 58f0f420354c979e9ea6e9decc66090b5bac9d19 Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Thu, 20 Jun 2024 15:32:33 -0400 Subject: [PATCH 47/66] GODRIVER-2911: Not sure how to get fail points working --- cmd/testoidcauth/main.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/cmd/testoidcauth/main.go b/cmd/testoidcauth/main.go index e69a3a3ef6..d65d69771d 100644 --- a/cmd/testoidcauth/main.go +++ b/cmd/testoidcauth/main.go @@ -76,7 +76,8 @@ func main() { aux("machine_2_4_invalidClientConfigurationWithCallback", machine_2_4_invalidClientConfigurationWithCallback) aux("machine_3_1_failureWithCachedTokensFetchANewTokenAndRetryAuth", machine_3_1_failureWithCachedTokensFetchANewTokenAndRetryAuth) aux("machine_3_2_authFailuresWithoutCachedTokensReturnsAnError", machine_3_2_authFailuresWithoutCachedTokensReturnsAnError) - aux("machine_3_3_UnexpectedErrorCodeDoesNotClearTheCache", machine_3_3_UnexpectedErrorCodeDoesNotClearTheCache) + // fail points do not seem to be working, or I'm using them wrongly + //aux("machine_3_3_UnexpectedErrorCodeDoesNotClearTheCache", machine_3_3_UnexpectedErrorCodeDoesNotClearTheCache) if hasError { log.Fatal("One or more tests failed") } @@ -403,7 +404,7 @@ func machine_3_3_UnexpectedErrorCodeDoesNotClearTheCache() error { mtest.FailPoint{ ConfigureFailPoint: "failCommand", Mode: mtest.FailPointMode{ - Times: 10, + Times: 1, }, Data: mtest.FailPointData{ FailCommands: []string{"saslStart"}, From 1be1e137c0af29c3ded2c33aac22c0c7c5934797 Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Thu, 20 Jun 2024 15:42:07 -0400 Subject: [PATCH 48/66] GODRIVER-2911: Appease linter --- cmd/testoidcauth/main.go | 97 +++++++++++++++++++------------------ x/mongo/driver/auth/oidc.go | 8 ++- 2 files changed, 57 insertions(+), 48 deletions(-) diff --git a/cmd/testoidcauth/main.go b/cmd/testoidcauth/main.go index d65d69771d..abc25a9746 100644 --- a/cmd/testoidcauth/main.go +++ b/cmd/testoidcauth/main.go @@ -25,13 +25,15 @@ import ( var uriAdmin = os.Getenv("MONGODB_URI") var uriSingle = os.Getenv("MONGODB_URI_SINGLE") -var uriMulti = os.Getenv("MONGODB_URI_MULTI") + +// var uriMulti = os.Getenv("MONGODB_URI_MULTI") var oidcTokenDir = os.Getenv("OIDC_TOKEN_DIR") -var oidcDomain = os.Getenv("OIDC_DOMAIN") -func explicitUser(user string) string { - return fmt.Sprintf("%s@%s", user, oidcDomain) -} +//var oidcDomain = os.Getenv("OIDC_DOMAIN") + +//func explicitUser(user string) string { +// return fmt.Sprintf("%s@%s", user, oidcDomain) +//} func tokenFile(user string) string { return path.Join(oidcTokenDir, user) @@ -57,6 +59,9 @@ func connectWithMachineCBAndProperties(uri string, cb driver.OIDCCallback, props } func main() { + // be quiet linter + _ = tokenFile("test_user2") + hasError := false aux := func(test_name string, f func() error) { fmt.Printf("%s...", test_name) @@ -69,23 +74,23 @@ func main() { fmt.Println("...Ok") } } - aux("machine_1_1_callbackIsCalled", machine_1_1_callbackIsCalled) - aux("machine_1_2_callbackIsCalledOnlyOneForMultipleConnections", machine_1_2_callbackIsCalledOnlyOneForMultipleConnections) - aux("machine_2_1_validCallbackInputs", machine_2_1_validCallbackInputs) - aux("machine_2_3_oidcCallbackReturnMissingData", machine_2_3_oidcCallbackReturnMissingData) - aux("machine_2_4_invalidClientConfigurationWithCallback", machine_2_4_invalidClientConfigurationWithCallback) - aux("machine_3_1_failureWithCachedTokensFetchANewTokenAndRetryAuth", machine_3_1_failureWithCachedTokensFetchANewTokenAndRetryAuth) - aux("machine_3_2_authFailuresWithoutCachedTokensReturnsAnError", machine_3_2_authFailuresWithoutCachedTokensReturnsAnError) + aux("machine_1_1_callbackIsCalled", machine11callbackIsCalled) + aux("machine_1_2_callbackIsCalledOnlyOneForMultipleConnections", machine12callbackIsCalledOnlyOneForMultipleConnections) + aux("machine_2_1_validCallbackInputs", machine21validCallbackInputs) + aux("machine_2_3_oidcCallbackReturnMissingData", machine23oidcCallbackReturnMissingData) + aux("machine_2_4_invalidClientConfigurationWithCallback", machine24invalidClientConfigurationWithCallback) + aux("machine_3_1_failureWithCachedTokensFetchANewTokenAndRetryAuth", machine31failureWithCachedTokensFetchANewTokenAndRetryAuth) + aux("machine_3_2_authFailuresWithoutCachedTokensReturnsAnError", machine32authFailuresWithoutCachedTokensReturnsAnError) // fail points do not seem to be working, or I'm using them wrongly - //aux("machine_3_3_UnexpectedErrorCodeDoesNotClearTheCache", machine_3_3_UnexpectedErrorCodeDoesNotClearTheCache) + aux("machine_3_3_UnexpectedErrorCodeDoesNotClearTheCache", machine33UnexpectedErrorCodeDoesNotClearTheCache) if hasError { log.Fatal("One or more tests failed") } } -func machine_1_1_callbackIsCalled() error { +func machine11callbackIsCalled() error { callbackCount := 0 - var callbackFailed error = nil + var callbackFailed error countMutex := sync.Mutex{} client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { @@ -96,7 +101,7 @@ func machine_1_1_callbackIsCalled() error { tokenFile := tokenFile("test_user1") accessToken, err := os.ReadFile(tokenFile) if err != nil { - callbackFailed = fmt.Errorf("machine_1_1: failed reading token file: %v\n", err) + callbackFailed = fmt.Errorf("machine_1_1: failed reading token file: %v", err) } return &driver.OIDCCredential{ AccessToken: string(accessToken), @@ -120,14 +125,14 @@ func machine_1_1_callbackIsCalled() error { countMutex.Lock() defer countMutex.Unlock() if callbackCount != 1 { - return fmt.Errorf("machine_1_1: expected callback count to be 1, got %d\n", callbackCount) + return fmt.Errorf("machine_1_1: expected callback count to be 1, got %d", callbackCount) } return callbackFailed } -func machine_1_2_callbackIsCalledOnlyOneForMultipleConnections() error { +func machine12callbackIsCalledOnlyOneForMultipleConnections() error { callbackCount := 0 - var callbackFailed error = nil + var callbackFailed error countMutex := sync.Mutex{} client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { @@ -138,7 +143,7 @@ func machine_1_2_callbackIsCalledOnlyOneForMultipleConnections() error { tokenFile := tokenFile("test_user1") accessToken, err := os.ReadFile(tokenFile) if err != nil { - callbackFailed = fmt.Errorf("machine_1_2: failed reading token file: %v\n", err) + callbackFailed = fmt.Errorf("machine_1_2: failed reading token file: %v", err) } return &driver.OIDCCredential{ AccessToken: string(accessToken), @@ -163,7 +168,7 @@ func machine_1_2_callbackIsCalledOnlyOneForMultipleConnections() error { coll := client.Database("test").Collection("test") _, err := coll.Find(context.Background(), bson.D{}) if err != nil { - findFailed = fmt.Errorf("machine_1_2: failed executing Find: %v\n", err) + findFailed = fmt.Errorf("machine_1_2: failed executing Find: %v", err) } }() } @@ -172,7 +177,7 @@ func machine_1_2_callbackIsCalledOnlyOneForMultipleConnections() error { countMutex.Lock() defer countMutex.Unlock() if callbackCount != 1 { - return fmt.Errorf("machine_1_2: expected callback count to be 1, got %d\n", callbackCount) + return fmt.Errorf("machine_1_2: expected callback count to be 1, got %d", callbackCount) } if callbackFailed != nil { return callbackFailed @@ -180,23 +185,23 @@ func machine_1_2_callbackIsCalledOnlyOneForMultipleConnections() error { return findFailed } -func machine_2_1_validCallbackInputs() error { +func machine21validCallbackInputs() error { callbackCount := 0 - var callbackFailed error = nil + var callbackFailed error countMutex := sync.Mutex{} client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { if args.RefreshToken != nil { - callbackFailed = fmt.Errorf("machine_2_1: expected RefreshToken to be nil, got %v\n", args.RefreshToken) + callbackFailed = fmt.Errorf("machine_2_1: expected RefreshToken to be nil, got %v", args.RefreshToken) } if args.Timeout.Before(time.Now()) { - callbackFailed = fmt.Errorf("machine_2_1: expected timeout to be in the future, got %v\n", args.Timeout) + callbackFailed = fmt.Errorf("machine_2_1: expected timeout to be in the future, got %v", args.Timeout) } if args.Version < 1 { - callbackFailed = fmt.Errorf("machine_2_1: expected Version to be at least 1, got %d\n", args.Version) + callbackFailed = fmt.Errorf("machine_2_1: expected Version to be at least 1, got %d", args.Version) } if args.IDPInfo != nil { - callbackFailed = fmt.Errorf("machine_2_1: expected IdpID to be nil for Machine flow, got %v\n", args.IDPInfo) + callbackFailed = fmt.Errorf("machine_2_1: expected IdpID to be nil for Machine flow, got %v", args.IDPInfo) } countMutex.Lock() defer countMutex.Unlock() @@ -205,7 +210,7 @@ func machine_2_1_validCallbackInputs() error { tokenFile := tokenFile("test_user1") accessToken, err := os.ReadFile(tokenFile) if err != nil { - fmt.Printf("machine_2_1: failed reading token file: %v\n", err) + fmt.Printf("machine_2_1: failed reading token file: %v", err) } return &driver.OIDCCredential{ AccessToken: string(accessToken), @@ -229,12 +234,12 @@ func machine_2_1_validCallbackInputs() error { countMutex.Lock() defer countMutex.Unlock() if callbackCount != 1 { - return fmt.Errorf("machine_2_1: expected callback count to be 1, got %d\n", callbackCount) + return fmt.Errorf("machine_2_1: expected callback count to be 1, got %d", callbackCount) } return callbackFailed } -func machine_2_3_oidcCallbackReturnMissingData() error { +func machine23oidcCallbackReturnMissingData() error { callbackCount := 0 countMutex := sync.Mutex{} @@ -253,7 +258,7 @@ func machine_2_3_oidcCallbackReturnMissingData() error { defer client.Disconnect(context.Background()) if err != nil { - return fmt.Errorf("machine_2_3: failed connecting client: %v\n", err) + return fmt.Errorf("machine_2_3: failed connecting client: %v", err) } coll := client.Database("test").Collection("test") @@ -265,12 +270,12 @@ func machine_2_3_oidcCallbackReturnMissingData() error { countMutex.Lock() defer countMutex.Unlock() if callbackCount != 1 { - return fmt.Errorf("machine_2_3: expected callback count to be 1, got %d\n", callbackCount) + return fmt.Errorf("machine_2_3: expected callback count to be 1, got %d", callbackCount) } return nil } -func machine_2_4_invalidClientConfigurationWithCallback() error { +func machine24invalidClientConfigurationWithCallback() error { _, err := connectWithMachineCBAndProperties(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { t := time.Now().Add(time.Hour) return &driver.OIDCCredential{ @@ -287,9 +292,9 @@ func machine_2_4_invalidClientConfigurationWithCallback() error { return nil } -func machine_3_1_failureWithCachedTokensFetchANewTokenAndRetryAuth() error { +func machine31failureWithCachedTokensFetchANewTokenAndRetryAuth() error { callbackCount := 0 - var callbackFailed error = nil + var callbackFailed error countMutex := sync.Mutex{} client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { @@ -300,7 +305,7 @@ func machine_3_1_failureWithCachedTokensFetchANewTokenAndRetryAuth() error { tokenFile := tokenFile("test_user1") accessToken, err := os.ReadFile(tokenFile) if err != nil { - callbackFailed = fmt.Errorf("machine_3_1: failed reading token file: %v\n", err) + callbackFailed = fmt.Errorf("machine_3_1: failed reading token file: %v", err) } return &driver.OIDCCredential{ AccessToken: string(accessToken), @@ -327,14 +332,14 @@ func machine_3_1_failureWithCachedTokensFetchANewTokenAndRetryAuth() error { countMutex.Lock() defer countMutex.Unlock() if callbackCount != 1 { - return fmt.Errorf("machine_3_1: expected callback count to be 1, got %d\n", callbackCount) + return fmt.Errorf("machine_3_1: expected callback count to be 1, got %d", callbackCount) } return callbackFailed } -func machine_3_2_authFailuresWithoutCachedTokensReturnsAnError() error { +func machine32authFailuresWithoutCachedTokensReturnsAnError() error { callbackCount := 0 - var callbackFailed error = nil + var callbackFailed error countMutex := sync.Mutex{} client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { @@ -363,14 +368,14 @@ func machine_3_2_authFailuresWithoutCachedTokensReturnsAnError() error { countMutex.Lock() defer countMutex.Unlock() if callbackCount != 1 { - return fmt.Errorf("machine_3_2: expected callback count to be 1, got %d\n", callbackCount) + return fmt.Errorf("machine_3_2: expected callback count to be 1, got %d", callbackCount) } return callbackFailed } -func machine_3_3_UnexpectedErrorCodeDoesNotClearTheCache() error { +func machine33UnexpectedErrorCodeDoesNotClearTheCache() error { callbackCount := 0 - var callbackFailed error = nil + var callbackFailed error countMutex := sync.Mutex{} adminClient, err := connectAdminClinet() @@ -383,7 +388,7 @@ func machine_3_3_UnexpectedErrorCodeDoesNotClearTheCache() error { tokenFile := tokenFile("test_user1") accessToken, err := os.ReadFile(tokenFile) if err != nil { - callbackFailed = fmt.Errorf("machine_3_3: failed reading token file: %v\n", err) + callbackFailed = fmt.Errorf("machine_3_3: failed reading token file: %v", err) } return &driver.OIDCCredential{ AccessToken: string(accessToken), @@ -427,7 +432,7 @@ func machine_3_3_UnexpectedErrorCodeDoesNotClearTheCache() error { countMutex.Lock() defer countMutex.Unlock() if callbackCount != 1 { - return fmt.Errorf("machine_3_3: expected callback count to be 1, got %d\n", callbackCount) + return fmt.Errorf("machine_3_3: expected callback count to be 1, got %d", callbackCount) } _, err = coll.Find(context.Background(), bson.D{}) @@ -437,7 +442,7 @@ func machine_3_3_UnexpectedErrorCodeDoesNotClearTheCache() error { countMutex.Lock() defer countMutex.Unlock() if callbackCount != 1 { - return fmt.Errorf("machine_3_3: expected callback count to be 1, got %d\n", callbackCount) + return fmt.Errorf("machine_3_3: expected callback count to be 1, got %d", callbackCount) } return callbackFailed } diff --git a/x/mongo/driver/auth/oidc.go b/x/mongo/driver/auth/oidc.go index b2c271a6d0..f25f2a4850 100644 --- a/x/mongo/driver/auth/oidc.go +++ b/x/mongo/driver/auth/oidc.go @@ -158,11 +158,15 @@ func (oa *OIDCAuthenticator) providerCallback() (OIDCCallback, error) { return nil, nil } - //switch env { + switch env { // TODO GODRIVER-2728: Automatic token acquisition for Azure Identity Provider // TODO GODRIVER-2806: Automatic token acquisition for GCP Identity Provider // This is here just to pass the linter, it will be fixed in one of the above tickets. - //} + case azureEnvironmentValue, gcpEnvironmentValue: + return func(ctx context.Context, args *OIDCArgs) (*OIDCCredential, error) { + return nil, fmt.Errorf("automatic token acquisition for %q not implemented yet", env) + }, fmt.Errorf("automatic token acquisition for %q not implemented yet", env) + } return nil, fmt.Errorf("%q %q not supported for MONGODB-OIDC", environmentProp, env) } From 6e1fd3ab3f3052c4b68c2dea6cfc645ad2183f84 Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Thu, 20 Jun 2024 16:04:11 -0400 Subject: [PATCH 49/66] GODRIVER-2911: Appease linter --- cmd/testoidcauth/main.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/cmd/testoidcauth/main.go b/cmd/testoidcauth/main.go index abc25a9746..f63804e3de 100644 --- a/cmd/testoidcauth/main.go +++ b/cmd/testoidcauth/main.go @@ -160,7 +160,7 @@ func machine12callbackIsCalledOnlyOneForMultipleConnections() error { var wg sync.WaitGroup - var findFailed error = nil + var findFailed error for i := 0; i < 10; i++ { wg.Add(1) go func() { @@ -379,6 +379,11 @@ func machine33UnexpectedErrorCodeDoesNotClearTheCache() error { countMutex := sync.Mutex{} adminClient, err := connectAdminClinet() + defer adminClient.Disconnect(context.Background()) + + if err != nil { + return fmt.Errorf("machine_3_3: failed connecting admin client: %v", err) + } client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { countMutex.Lock() From 640907d47479c06a5c38ef1cee13622c57127bc9 Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Fri, 21 Jun 2024 16:02:14 -0400 Subject: [PATCH 50/66] GODRIVER-2911: Change 3_3 to use fail on find, add 4_1 --- cmd/testoidcauth/main.go | 73 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 71 insertions(+), 2 deletions(-) diff --git a/cmd/testoidcauth/main.go b/cmd/testoidcauth/main.go index f63804e3de..3dfb2a9530 100644 --- a/cmd/testoidcauth/main.go +++ b/cmd/testoidcauth/main.go @@ -81,8 +81,8 @@ func main() { aux("machine_2_4_invalidClientConfigurationWithCallback", machine24invalidClientConfigurationWithCallback) aux("machine_3_1_failureWithCachedTokensFetchANewTokenAndRetryAuth", machine31failureWithCachedTokensFetchANewTokenAndRetryAuth) aux("machine_3_2_authFailuresWithoutCachedTokensReturnsAnError", machine32authFailuresWithoutCachedTokensReturnsAnError) - // fail points do not seem to be working, or I'm using them wrongly aux("machine_3_3_UnexpectedErrorCodeDoesNotClearTheCache", machine33UnexpectedErrorCodeDoesNotClearTheCache) + aux("machine_4_1_reauthenticationSucceeds", machine41ReauthenticationSucceeds) if hasError { log.Fatal("One or more tests failed") } @@ -417,7 +417,8 @@ func machine33UnexpectedErrorCodeDoesNotClearTheCache() error { Times: 1, }, Data: mtest.FailPointData{ - FailCommands: []string{"saslStart"}, + // saslStart failPoint is not causing find to fail, using find failPoint instead + FailCommands: []string{"find"}, AppName: "go-oidc", ErrorCode: 20, }, @@ -451,3 +452,71 @@ func machine33UnexpectedErrorCodeDoesNotClearTheCache() error { } return callbackFailed } + +func machine41ReauthenticationSucceeds() error { + callbackCount := 0 + var callbackFailed error + countMutex := sync.Mutex{} + + adminClient, err := connectAdminClinet() + defer adminClient.Disconnect(context.Background()) + + if err != nil { + return fmt.Errorf("machine_4_1: failed connecting admin client: %v", err) + } + + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("machine_4_1: failed reading token file: %v", err) + } + return &driver.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }) + + defer client.Disconnect(context.Background()) + + if err != nil { + return fmt.Errorf("machine_4_1: failed connecting client: %v", err) + } + + coll := client.Database("test").Collection("test") + + err = mtest.SetFailPoint( + mtest.FailPoint{ + ConfigureFailPoint: "failCommand", + Mode: mtest.FailPointMode{ + Times: 1, + }, + Data: mtest.FailPointData{ + FailCommands: []string{"find"}, + AppName: "go-oidc", + ErrorCode: 391, + }, + }, + adminClient, + ) + + if err != nil { + return fmt.Errorf("machine_4_1: failed setting failpoint: %v", err) + } + + _, err = coll.Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("machine_4_1: failed executing Find: %v", err) + } + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 2 { + return fmt.Errorf("machine_4_1: expected callback count to be 2, got %d", callbackCount) + } + return callbackFailed +} From 4d30705d6ec2ad02105747c7a00286425b22c54c Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Sun, 23 Jun 2024 14:37:40 -0400 Subject: [PATCH 51/66] GODRIVER-2911: Manually create fail points --- cmd/testoidcauth/main.go | 102 +++++++++++++++++++++++++----------- x/mongo/driver/operation.go | 3 ++ 2 files changed, 73 insertions(+), 32 deletions(-) diff --git a/cmd/testoidcauth/main.go b/cmd/testoidcauth/main.go index 3dfb2a9530..ac190e6b97 100644 --- a/cmd/testoidcauth/main.go +++ b/cmd/testoidcauth/main.go @@ -17,7 +17,6 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" - "go.mongodb.org/mongo-driver/mongo/integration/mtest" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/mongo/driver/auth" @@ -88,6 +87,47 @@ func main() { } } +func test() error { + + aclient, err := connectAdminClinet() + if err != nil { + return fmt.Errorf("failed connecting admin client: %v", err) + } + defer aclient.Disconnect(context.Background()) + + client, err := connectAdminClinet() + if err != nil { + return fmt.Errorf("failed connecting admin client: %v", err) + } + defer client.Disconnect(context.Background()) + + db := aclient.Database("admin") + res := db.RunCommand(context.Background(), bson.D{ + {Key: "configureFailPoint", Value: "failCommand"}, + {Key: "mode", Value: bson.D{ + {Key: "times", Value: 1}, + }}, + {Key: "data", Value: bson.D{ + {Key: "failCommands", Value: bson.A{ + "find", + }}, + {Key: "errorCode", Value: 391}, // IllegalOperation + }}, + }) + + if res.Err() != nil { + return fmt.Errorf("test: failed setting failpoint: %v", res.Err()) + } + + coll := client.Database("test").Collection("test") + + _, err = coll.Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("failed executing Find: %v", err) + } + return nil +} + func machine11callbackIsCalled() error { callbackCount := 0 var callbackFailed error @@ -410,24 +450,21 @@ func machine33UnexpectedErrorCodeDoesNotClearTheCache() error { coll := client.Database("test").Collection("test") - err = mtest.SetFailPoint( - mtest.FailPoint{ - ConfigureFailPoint: "failCommand", - Mode: mtest.FailPointMode{ - Times: 1, - }, - Data: mtest.FailPointData{ - // saslStart failPoint is not causing find to fail, using find failPoint instead - FailCommands: []string{"find"}, - AppName: "go-oidc", - ErrorCode: 20, - }, - }, - adminClient, - ) + res := adminClient.Database("admin").RunCommand(context.Background(), bson.D{ + {Key: "configureFailPoint", Value: "failCommand"}, + {Key: "mode", Value: bson.D{ + {Key: "times", Value: 1}, + }}, + {Key: "data", Value: bson.D{ + {Key: "failCommands", Value: bson.A{ + "saslStart", + }}, + {Key: "errorCode", Value: 20}, + }}, + }) - if err != nil { - return fmt.Errorf("machine_3_3: failed setting failpoint: %v", err) + if res.Err() != nil { + return fmt.Errorf("machine_3_3: failed setting failpoint: %v", res.Err()) } _, err = coll.Find(context.Background(), bson.D{}) @@ -489,21 +526,22 @@ func machine41ReauthenticationSucceeds() error { } coll := client.Database("test").Collection("test") + res := adminClient.Database("admin").RunCommand(context.Background(), bson.D{ + {Key: "configureFailPoint", Value: "failCommand"}, + {Key: "mode", Value: bson.D{ + {Key: "times", Value: 1}, + }}, + {Key: "data", Value: bson.D{ + {Key: "failCommands", Value: bson.A{ + "find", + }}, + {Key: "errorCode", Value: 391}, + }}, + }) - err = mtest.SetFailPoint( - mtest.FailPoint{ - ConfigureFailPoint: "failCommand", - Mode: mtest.FailPointMode{ - Times: 1, - }, - Data: mtest.FailPointData{ - FailCommands: []string{"find"}, - AppName: "go-oidc", - ErrorCode: 391, - }, - }, - adminClient, - ) + if res.Err() != nil { + return fmt.Errorf("machine_4_1: failed setting failpoint: %v", res.Err()) + } if err != nil { return fmt.Errorf("machine_4_1: failed setting failpoint: %v", err) diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index 9ce319d8c9..af5be3c2be 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -916,9 +916,12 @@ func (op Operation) Execute(ctx context.Context) error { operationErr.Labels = tt.Labels operationErr.Raw = tt.Raw case Error: + fmt.Println("!!!!") // 391 is the reauthentication required error code, so we will attempt a reauth and // retry the operation, if it is successful. + fmt.Println("code", tt.Code) if tt.Code == 391 { + fmt.Println("!!!!") if op.Authenticator != nil { if err := op.Authenticator.Reauth(ctx); err != nil { return fmt.Errorf("error reauthenticating: %w", err) From 9dd40c9776c7070af3095a60fededb027abf6f13 Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Sun, 23 Jun 2024 15:06:45 -0400 Subject: [PATCH 52/66] GODRIVER-2911: This is working except 3_3 seems to be hanging --- cmd/testoidcauth/main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/testoidcauth/main.go b/cmd/testoidcauth/main.go index ac190e6b97..fda2e4946f 100644 --- a/cmd/testoidcauth/main.go +++ b/cmd/testoidcauth/main.go @@ -80,7 +80,7 @@ func main() { aux("machine_2_4_invalidClientConfigurationWithCallback", machine24invalidClientConfigurationWithCallback) aux("machine_3_1_failureWithCachedTokensFetchANewTokenAndRetryAuth", machine31failureWithCachedTokensFetchANewTokenAndRetryAuth) aux("machine_3_2_authFailuresWithoutCachedTokensReturnsAnError", machine32authFailuresWithoutCachedTokensReturnsAnError) - aux("machine_3_3_UnexpectedErrorCodeDoesNotClearTheCache", machine33UnexpectedErrorCodeDoesNotClearTheCache) + //aux("machine_3_3_UnexpectedErrorCodeDoesNotClearTheCache", machine33UnexpectedErrorCodeDoesNotClearTheCache) aux("machine_4_1_reauthenticationSucceeds", machine41ReauthenticationSucceeds) if hasError { log.Fatal("One or more tests failed") From b343ebbcdd454d1b81d52895ba935282ffec4069 Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Sun, 23 Jun 2024 17:59:57 -0400 Subject: [PATCH 53/66] GODRIVER-2911: Tests all passing --- cmd/testoidcauth/main.go | 170 ++++++++++++++++++++++++++++++++++-- x/mongo/driver/operation.go | 3 - 2 files changed, 164 insertions(+), 9 deletions(-) diff --git a/cmd/testoidcauth/main.go b/cmd/testoidcauth/main.go index fda2e4946f..540568b996 100644 --- a/cmd/testoidcauth/main.go +++ b/cmd/testoidcauth/main.go @@ -80,8 +80,10 @@ func main() { aux("machine_2_4_invalidClientConfigurationWithCallback", machine24invalidClientConfigurationWithCallback) aux("machine_3_1_failureWithCachedTokensFetchANewTokenAndRetryAuth", machine31failureWithCachedTokensFetchANewTokenAndRetryAuth) aux("machine_3_2_authFailuresWithoutCachedTokensReturnsAnError", machine32authFailuresWithoutCachedTokensReturnsAnError) - //aux("machine_3_3_UnexpectedErrorCodeDoesNotClearTheCache", machine33UnexpectedErrorCodeDoesNotClearTheCache) + aux("machine_3_3_UnexpectedErrorCodeDoesNotClearTheCache", machine33UnexpectedErrorCodeDoesNotClearTheCache) aux("machine_4_1_reauthenticationSucceeds", machine41ReauthenticationSucceeds) + aux("machine_4_2_readCommandsFailIfReauthenticationFails", machine42ReadCommandsFailIfReauthenticationFails) + aux("machine_4_3_writeCommandsFailIfReauthenticationFails", machine43WriteCommandsFailIfReauthenticationFails) if hasError { log.Fatal("One or more tests failed") } @@ -482,8 +484,6 @@ func machine33UnexpectedErrorCodeDoesNotClearTheCache() error { if err != nil { return fmt.Errorf("machine_3_3: failed executing Find: %v", err) } - countMutex.Lock() - defer countMutex.Unlock() if callbackCount != 1 { return fmt.Errorf("machine_3_3: expected callback count to be 1, got %d", callbackCount) } @@ -543,18 +543,176 @@ func machine41ReauthenticationSucceeds() error { return fmt.Errorf("machine_4_1: failed setting failpoint: %v", res.Err()) } + _, err = coll.Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("machine_4_1: failed executing Find: %v", err) + } + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 2 { + return fmt.Errorf("machine_4_1: expected callback count to be 2, got %d", callbackCount) + } + return callbackFailed +} + +func machine42ReadCommandsFailIfReauthenticationFails() error { + callbackCount := 0 + var callbackFailed error + firstCall := true + countMutex := sync.Mutex{} + + adminClient, err := connectAdminClinet() + defer adminClient.Disconnect(context.Background()) + + if err != nil { + return fmt.Errorf("machine_4_2: failed connecting admin client: %v", err) + } + + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + if firstCall { + firstCall = false + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("machine_4_2: failed reading token file: %v", err) + } + return &driver.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + } else { + return &driver.OIDCCredential{ + AccessToken: "this is a bad, bad token", + ExpiresAt: &t, + RefreshToken: nil, + }, nil + } + }) + + defer client.Disconnect(context.Background()) + if err != nil { - return fmt.Errorf("machine_4_1: failed setting failpoint: %v", err) + return fmt.Errorf("machine_4_2: failed connecting client: %v", err) } + coll := client.Database("test").Collection("test") _, err = coll.Find(context.Background(), bson.D{}) if err != nil { - return fmt.Errorf("machine_4_1: failed executing Find: %v", err) + return fmt.Errorf("machine_4_2: failed executing Find: %v", err) + } + + res := adminClient.Database("admin").RunCommand(context.Background(), bson.D{ + {Key: "configureFailPoint", Value: "failCommand"}, + {Key: "mode", Value: bson.D{ + {Key: "times", Value: 1}, + }}, + {Key: "data", Value: bson.D{ + {Key: "failCommands", Value: bson.A{ + "find", + }}, + {Key: "errorCode", Value: 391}, + }}, + }) + + if res.Err() != nil { + return fmt.Errorf("machine_4_2: failed setting failpoint: %v", res.Err()) + } + + _, err = coll.Find(context.Background(), bson.D{}) + if err == nil { + return fmt.Errorf("machine_4_2: Find succeeded when it should fail") } + countMutex.Lock() defer countMutex.Unlock() if callbackCount != 2 { - return fmt.Errorf("machine_4_1: expected callback count to be 2, got %d", callbackCount) + return fmt.Errorf("machine_4_2: expected callback count to be 2, got %d", callbackCount) + } + return callbackFailed +} + +func machine43WriteCommandsFailIfReauthenticationFails() error { + callbackCount := 0 + var callbackFailed error + firstCall := true + countMutex := sync.Mutex{} + + adminClient, err := connectAdminClinet() + defer adminClient.Disconnect(context.Background()) + + if err != nil { + return fmt.Errorf("machine_4_3: failed connecting admin client: %v", err) + } + + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + if firstCall { + firstCall = false + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("machine_4_3: failed reading token file: %v", err) + } + return &driver.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + } else { + return &driver.OIDCCredential{ + AccessToken: "this is a bad, bad token", + ExpiresAt: &t, + RefreshToken: nil, + }, nil + } + }) + + defer client.Disconnect(context.Background()) + + if err != nil { + return fmt.Errorf("machine_4_3: failed connecting client: %v", err) + } + + coll := client.Database("test").Collection("test") + _, err = coll.InsertOne(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("machine_4_3: failed executing Insert: %v", err) + } + + res := adminClient.Database("admin").RunCommand(context.Background(), bson.D{ + {Key: "configureFailPoint", Value: "failCommand"}, + {Key: "mode", Value: bson.D{ + {Key: "times", Value: 1}, + }}, + {Key: "data", Value: bson.D{ + {Key: "failCommands", Value: bson.A{ + "insert", + }}, + {Key: "errorCode", Value: 391}, + }}, + }) + + if res.Err() != nil { + return fmt.Errorf("machine_4_3: failed setting failpoint: %v", res.Err()) + } + + _, err = coll.InsertOne(context.Background(), bson.D{}) + if err == nil { + return fmt.Errorf("machine_4_3: Insert succeeded when it should fail") + } + + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 2 { + return fmt.Errorf("machine_4_3: expected callback count to be 2, got %d", callbackCount) } return callbackFailed } diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index af5be3c2be..9ce319d8c9 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -916,12 +916,9 @@ func (op Operation) Execute(ctx context.Context) error { operationErr.Labels = tt.Labels operationErr.Raw = tt.Raw case Error: - fmt.Println("!!!!") // 391 is the reauthentication required error code, so we will attempt a reauth and // retry the operation, if it is successful. - fmt.Println("code", tt.Code) if tt.Code == 391 { - fmt.Println("!!!!") if op.Authenticator != nil { if err := op.Authenticator.Reauth(ctx); err != nil { return fmt.Errorf("error reauthenticating: %w", err) From 5240a91dedafd3d7e20b13ca2d5d166da0e45309 Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Sun, 23 Jun 2024 18:20:17 -0400 Subject: [PATCH 54/66] GODRIVER-2911: Appease linter --- cmd/testoidcauth/main.go | 23 +++++++++++------------ x/mongo/driver/auth/oidc.go | 8 ++++---- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/cmd/testoidcauth/main.go b/cmd/testoidcauth/main.go index 540568b996..68927ab220 100644 --- a/cmd/testoidcauth/main.go +++ b/cmd/testoidcauth/main.go @@ -585,13 +585,13 @@ func machine42ReadCommandsFailIfReauthenticationFails() error { ExpiresAt: &t, RefreshToken: nil, }, nil - } else { - return &driver.OIDCCredential{ - AccessToken: "this is a bad, bad token", - ExpiresAt: &t, - RefreshToken: nil, - }, nil } + return &driver.OIDCCredential{ + AccessToken: "this is a bad, bad token", + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }) defer client.Disconnect(context.Background()) @@ -666,13 +666,12 @@ func machine43WriteCommandsFailIfReauthenticationFails() error { ExpiresAt: &t, RefreshToken: nil, }, nil - } else { - return &driver.OIDCCredential{ - AccessToken: "this is a bad, bad token", - ExpiresAt: &t, - RefreshToken: nil, - }, nil } + return &driver.OIDCCredential{ + AccessToken: "this is a bad, bad token", + ExpiresAt: &t, + RefreshToken: nil, + }, nil }) defer client.Disconnect(context.Background()) diff --git a/x/mongo/driver/auth/oidc.go b/x/mongo/driver/auth/oidc.go index f25f2a4850..cbffbc2b0c 100644 --- a/x/mongo/driver/auth/oidc.go +++ b/x/mongo/driver/auth/oidc.go @@ -224,9 +224,9 @@ func (oa *OIDCAuthenticator) getAccessToken( // but this is a safety check, since extra invalidation is only a performance impact, not a // correctness impact. // This must only be called with the lock held -func (oa *OIDCAuthenticator) invalidateAccessToken(force bool) { +func (oa *OIDCAuthenticator) invalidateAccessToken() { tokenGenID := oa.cfg.Connection.OIDCTokenGenID() - if force || tokenGenID >= oa.tokenGenID { + if tokenGenID >= oa.tokenGenID { oa.accessToken = "" oa.cfg.Connection.SetOIDCTokenGenID(0) } @@ -236,7 +236,7 @@ func (oa *OIDCAuthenticator) invalidateAccessToken(force bool) { // driver.Authenticator interface. func (oa *OIDCAuthenticator) Reauth(ctx context.Context) error { oa.mu.Lock() - oa.invalidateAccessToken(true) + oa.invalidateAccessToken() oa.mu.Unlock() // it should be impossible to get a Reauth when an Auth has never occurred, // so we assume cfg was properly set. There is nothing to enforce this, however, @@ -266,7 +266,7 @@ func (oa *OIDCAuthenticator) Auth(ctx context.Context, cfg *Config) error { if err == nil { return nil } - oa.invalidateAccessToken(false) + oa.invalidateAccessToken() time.Sleep(invalidateSleepTimeout) } From 0cdd7a2b3621f3b78a1211bba3eacd731b78418f Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Sun, 23 Jun 2024 18:20:43 -0400 Subject: [PATCH 55/66] GODRIVER-2911: Remove test func that is unneeded --- cmd/testoidcauth/main.go | 41 ---------------------------------------- 1 file changed, 41 deletions(-) diff --git a/cmd/testoidcauth/main.go b/cmd/testoidcauth/main.go index 68927ab220..79929fbf29 100644 --- a/cmd/testoidcauth/main.go +++ b/cmd/testoidcauth/main.go @@ -89,47 +89,6 @@ func main() { } } -func test() error { - - aclient, err := connectAdminClinet() - if err != nil { - return fmt.Errorf("failed connecting admin client: %v", err) - } - defer aclient.Disconnect(context.Background()) - - client, err := connectAdminClinet() - if err != nil { - return fmt.Errorf("failed connecting admin client: %v", err) - } - defer client.Disconnect(context.Background()) - - db := aclient.Database("admin") - res := db.RunCommand(context.Background(), bson.D{ - {Key: "configureFailPoint", Value: "failCommand"}, - {Key: "mode", Value: bson.D{ - {Key: "times", Value: 1}, - }}, - {Key: "data", Value: bson.D{ - {Key: "failCommands", Value: bson.A{ - "find", - }}, - {Key: "errorCode", Value: 391}, // IllegalOperation - }}, - }) - - if res.Err() != nil { - return fmt.Errorf("test: failed setting failpoint: %v", res.Err()) - } - - coll := client.Database("test").Collection("test") - - _, err = coll.Find(context.Background(), bson.D{}) - if err != nil { - return fmt.Errorf("failed executing Find: %v", err) - } - return nil -} - func machine11callbackIsCalled() error { callbackCount := 0 var callbackFailed error From 4613c5f85ea673140b2b83b64463e5881ee3dd3c Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Wed, 26 Jun 2024 11:07:29 -0400 Subject: [PATCH 56/66] Update x/mongo/driver/auth/oidc.go Co-authored-by: Matt Dale <9760375+matthewdale@users.noreply.github.com> --- x/mongo/driver/auth/oidc.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/x/mongo/driver/auth/oidc.go b/x/mongo/driver/auth/oidc.go index cbffbc2b0c..5c972ed70c 100644 --- a/x/mongo/driver/auth/oidc.go +++ b/x/mongo/driver/auth/oidc.go @@ -300,7 +300,8 @@ func (oa *OIDCAuthenticator) doAuthHuman(_ context.Context, _ *Config, _ OIDCCal func (oa *OIDCAuthenticator) doAuthMachine(ctx context.Context, cfg *Config, machineCallback OIDCCallback) error { accessToken, err := oa.getAccessToken(ctx, - &OIDCArgs{Version: apiVersion, + &OIDCArgs{ + Version: apiVersion, Timeout: time.Now().Add(machineCallbackTimeout), // idpInfo is nil for machine callbacks in the current spec. IDPInfo: nil, From 40998b618af2a30c14f7d2b3f90fedcbeef730ef Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Wed, 26 Jun 2024 13:10:57 -0400 Subject: [PATCH 57/66] SQL-1937: Remove spurious authenticators, move mutex --- x/mongo/driver/auth/oidc.go | 14 +++++--------- x/mongo/driver/operation/distinct.go | 11 ----------- x/mongo/driver/operation/insert.go | 11 ----------- 3 files changed, 5 insertions(+), 31 deletions(-) diff --git a/x/mongo/driver/auth/oidc.go b/x/mongo/driver/auth/oidc.go index 5c972ed70c..c9ba94f434 100644 --- a/x/mongo/driver/auth/oidc.go +++ b/x/mongo/driver/auth/oidc.go @@ -171,12 +171,14 @@ func (oa *OIDCAuthenticator) providerCallback() (OIDCCallback, error) { return nil, fmt.Errorf("%q %q not supported for MONGODB-OIDC", environmentProp, env) } -// This should only be called with the Mutex held. func (oa *OIDCAuthenticator) getAccessToken( ctx context.Context, args *OIDCArgs, callback OIDCCallback, ) (string, error) { + oa.mu.Lock() + defer oa.mu.Unlock() + if oa.accessToken != "" { return oa.accessToken, nil } @@ -223,8 +225,9 @@ func (oa *OIDCAuthenticator) getAccessToken( // tokenGenID of the OIDCAuthenticator. It should never actually be greater than, but only equal, // but this is a safety check, since extra invalidation is only a performance impact, not a // correctness impact. -// This must only be called with the lock held func (oa *OIDCAuthenticator) invalidateAccessToken() { + oa.mu.Lock() + defer oa.mu.Unlock() tokenGenID := oa.cfg.Connection.OIDCTokenGenID() if tokenGenID >= oa.tokenGenID { oa.accessToken = "" @@ -235,9 +238,7 @@ func (oa *OIDCAuthenticator) invalidateAccessToken() { // Reauth reauthenticates the connection when the server returns a 391 code. Reauth is part of the // driver.Authenticator interface. func (oa *OIDCAuthenticator) Reauth(ctx context.Context) error { - oa.mu.Lock() oa.invalidateAccessToken() - oa.mu.Unlock() // it should be impossible to get a Reauth when an Auth has never occurred, // so we assume cfg was properly set. There is nothing to enforce this, however, // other than the current driver code flow. If cfg is nil, Auth will return an error. @@ -246,11 +247,6 @@ func (oa *OIDCAuthenticator) Reauth(ctx context.Context) error { // Auth authenticates the connection. func (oa *OIDCAuthenticator) Auth(ctx context.Context, cfg *Config) error { - // the Mutex must be held during the entire Auth call so that multiple racing attempts - // to authenticate will not result in multiple callbacks. The losers on the Mutex will - // retrieve the access token from the Authenticator cache. - oa.mu.Lock() - defer oa.mu.Unlock() var err error if cfg == nil { diff --git a/x/mongo/driver/operation/distinct.go b/x/mongo/driver/operation/distinct.go index d3b2f3ce8f..484d96b66b 100644 --- a/x/mongo/driver/operation/distinct.go +++ b/x/mongo/driver/operation/distinct.go @@ -48,7 +48,6 @@ type Distinct struct { // DistinctResult represents a distinct result returned by the server. type DistinctResult struct { - authenticator driver.Authenticator // The distinct values for the field. Values bsoncore.Value } @@ -324,13 +323,3 @@ func (d *Distinct) Authenticator(authenticator driver.Authenticator) *Distinct { d.authenticator = authenticator return d } - -// Authenticator sets the authenticator to use for this operation. -func (d *DistinctResult) Authenticator(authenticator driver.Authenticator) *DistinctResult { - if d == nil { - d = new(DistinctResult) - } - - d.authenticator = authenticator - return d -} diff --git a/x/mongo/driver/operation/insert.go b/x/mongo/driver/operation/insert.go index a091a8ccfa..f5afe31169 100644 --- a/x/mongo/driver/operation/insert.go +++ b/x/mongo/driver/operation/insert.go @@ -48,7 +48,6 @@ type Insert struct { // InsertResult represents an insert result returned by the server. type InsertResult struct { - authenticator driver.Authenticator // Number of documents successfully inserted. N int64 } @@ -319,13 +318,3 @@ func (i *Insert) Authenticator(authenticator driver.Authenticator) *Insert { i.authenticator = authenticator return i } - -// Authenticator sets the authenticator to use for this operation. -func (i *InsertResult) Authenticator(authenticator driver.Authenticator) *InsertResult { - if i == nil { - i = new(InsertResult) - } - - i.authenticator = authenticator - return i -} From 2d09cc577fa93639cbcb8cb889ccf81abb8b5a13 Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Wed, 26 Jun 2024 14:49:34 -0400 Subject: [PATCH 58/66] SQL-1937: Change Reauth interface --- cmd/testoidcauth/main.go | 17 +++++++++++--- mongo/client.go | 5 ---- x/mongo/driver/auth/default.go | 4 +++- x/mongo/driver/auth/gssapi.go | 3 ++- x/mongo/driver/auth/mongodbaws.go | 3 ++- x/mongo/driver/auth/mongodbcr.go | 2 +- x/mongo/driver/auth/oidc.go | 38 ++++++++++++++++++++----------- x/mongo/driver/auth/plain.go | 4 +++- x/mongo/driver/auth/scram.go | 3 ++- x/mongo/driver/auth/x509.go | 2 +- x/mongo/driver/driver.go | 3 +-- x/mongo/driver/operation.go | 8 ++++++- 12 files changed, 61 insertions(+), 31 deletions(-) diff --git a/cmd/testoidcauth/main.go b/cmd/testoidcauth/main.go index 79929fbf29..c86161a131 100644 --- a/cmd/testoidcauth/main.go +++ b/cmd/testoidcauth/main.go @@ -12,8 +12,10 @@ import ( "log" "os" "path" + "reflect" "sync" "time" + "unsafe" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" @@ -195,8 +197,12 @@ func machine21validCallbackInputs() error { if args.RefreshToken != nil { callbackFailed = fmt.Errorf("machine_2_1: expected RefreshToken to be nil, got %v", args.RefreshToken) } - if args.Timeout.Before(time.Now()) { - callbackFailed = fmt.Errorf("machine_2_1: expected timeout to be in the future, got %v", args.Timeout) + timeout, ok := ctx.Deadline() + if !ok { + callbackFailed = fmt.Errorf("machine_2_1: expected context to have deadline, got %v", ctx) + } + if timeout.Before(time.Now()) { + callbackFailed = fmt.Errorf("machine_2_1: expected timeout to be in the future, got %v", timeout) } if args.Version < 1 { callbackFailed = fmt.Errorf("machine_2_1: expected Version to be at least 1, got %d", args.Version) @@ -322,7 +328,12 @@ func machine31failureWithCachedTokensFetchANewTokenAndRetryAuth() error { } // Poison the cache with a random token - client.GetAuthenticator().(*auth.OIDCAuthenticator).SetAccessToken("some random happy sunshine string") + clientElem := reflect.ValueOf(client).Elem() + authenticatorField := clientElem.FieldByName("authenticator") + authenticatorField = reflect.NewAt( + authenticatorField.Type(), + unsafe.Pointer(authenticatorField.UnsafeAddr())).Elem() + authenticatorField.Interface().(*auth.OIDCAuthenticator).SetAccessToken("some random happy sunshine string") coll := client.Database("test").Collection("test") diff --git a/mongo/client.go b/mongo/client.go index fec2c45287..082554adbb 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -83,11 +83,6 @@ type Client struct { authenticator driver.Authenticator } -// GetAuthenticator returns the authenticator for the client, used for testing purposes. -func (c *Client) GetAuthenticator() driver.Authenticator { - return c.authenticator -} - // Connect creates a new Client and then initializes it using the Connect method. This is equivalent to calling // NewClient followed by Client.Connect. // diff --git a/x/mongo/driver/auth/default.go b/x/mongo/driver/auth/default.go index a07fba1faa..0813675c66 100644 --- a/x/mongo/driver/auth/default.go +++ b/x/mongo/driver/auth/default.go @@ -9,6 +9,8 @@ package auth import ( "context" "fmt" + + "go.mongodb.org/mongo-driver/x/mongo/driver" ) func newDefaultAuthenticator(cred *Cred) (Authenticator, error) { @@ -67,7 +69,7 @@ func (a *DefaultAuthenticator) Auth(ctx context.Context, cfg *Config) error { } // Reauth reauthenticates the connection. -func (a *DefaultAuthenticator) Reauth(_ context.Context) error { +func (a *DefaultAuthenticator) Reauth(_ context.Context, _ *driver.AuthConfig) error { return newAuthError("DefaultAuthenticator does not support reauthentication", nil) } diff --git a/x/mongo/driver/auth/gssapi.go b/x/mongo/driver/auth/gssapi.go index 45c0547e55..84bdf49e78 100644 --- a/x/mongo/driver/auth/gssapi.go +++ b/x/mongo/driver/auth/gssapi.go @@ -15,6 +15,7 @@ import ( "fmt" "net" + "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/mongo/driver/auth/internal/gssapi" ) @@ -59,6 +60,6 @@ func (a *GSSAPIAuthenticator) Auth(ctx context.Context, cfg *Config) error { } // Reauth reauthenticates the connection. -func (a *GSSAPIAuthenticator) Reauth(ctx context.Context) error { +func (a *GSSAPIAuthenticator) Reauth(_ context.Context, _ *driver.AuthConfig) error { return newAuthError("GSSAPI does not support reauthentication", nil) } diff --git a/x/mongo/driver/auth/mongodbaws.go b/x/mongo/driver/auth/mongodbaws.go index af9f0f0b18..832378e89c 100644 --- a/x/mongo/driver/auth/mongodbaws.go +++ b/x/mongo/driver/auth/mongodbaws.go @@ -12,6 +12,7 @@ import ( "go.mongodb.org/mongo-driver/internal/aws/credentials" "go.mongodb.org/mongo-driver/internal/credproviders" + "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/mongo/driver/auth/creds" ) @@ -61,7 +62,7 @@ func (a *MongoDBAWSAuthenticator) Auth(ctx context.Context, cfg *Config) error { } // Reauth reauthenticates the connection. -func (a *MongoDBAWSAuthenticator) Reauth(_ context.Context) error { +func (a *MongoDBAWSAuthenticator) Reauth(_ context.Context, _ *driver.AuthConfig) error { return newAuthError("AWS authentication does not support reauthentication", nil) } diff --git a/x/mongo/driver/auth/mongodbcr.go b/x/mongo/driver/auth/mongodbcr.go index 41e21d2dea..416010f638 100644 --- a/x/mongo/driver/auth/mongodbcr.go +++ b/x/mongo/driver/auth/mongodbcr.go @@ -98,7 +98,7 @@ func (a *MongoDBCRAuthenticator) Auth(ctx context.Context, cfg *Config) error { } // Reauth reauthenticates the connection. -func (a *MongoDBCRAuthenticator) Reauth(_ context.Context) error { +func (a *MongoDBCRAuthenticator) Reauth(_ context.Context, _ *driver.AuthConfig) error { return newAuthError("MONGODB-CR does not support reauthentication", nil) } diff --git a/x/mongo/driver/auth/oidc.go b/x/mongo/driver/auth/oidc.go index c9ba94f434..a083a45198 100644 --- a/x/mongo/driver/auth/oidc.go +++ b/x/mongo/driver/auth/oidc.go @@ -9,6 +9,7 @@ package auth import ( "context" "fmt" + "net/http" "strings" "sync" "time" @@ -75,7 +76,7 @@ type OIDCAuthenticator struct { OIDCHumanCallback OIDCCallback userName string - cfg *Config + httpClient *http.Client accessToken string refreshToken *string idpInfo *IDPInfo @@ -173,6 +174,7 @@ func (oa *OIDCAuthenticator) providerCallback() (OIDCCallback, error) { func (oa *OIDCAuthenticator) getAccessToken( ctx context.Context, + conn driver.Connection, args *OIDCArgs, callback OIDCCallback, ) (string, error) { @@ -190,7 +192,7 @@ func (oa *OIDCAuthenticator) getAccessToken( oa.accessToken = cred.AccessToken oa.tokenGenID++ - oa.cfg.Connection.SetOIDCTokenGenID(oa.tokenGenID) + conn.SetOIDCTokenGenID(oa.tokenGenID) if cred.RefreshToken != nil { oa.refreshToken = cred.RefreshToken } @@ -225,24 +227,31 @@ func (oa *OIDCAuthenticator) getAccessToken( // tokenGenID of the OIDCAuthenticator. It should never actually be greater than, but only equal, // but this is a safety check, since extra invalidation is only a performance impact, not a // correctness impact. -func (oa *OIDCAuthenticator) invalidateAccessToken() { +func (oa *OIDCAuthenticator) invalidateAccessToken(conn driver.Connection) { oa.mu.Lock() defer oa.mu.Unlock() - tokenGenID := oa.cfg.Connection.OIDCTokenGenID() - if tokenGenID >= oa.tokenGenID { + tokenGenID := conn.OIDCTokenGenID() + // If the connection used in a Reauth is a new connection it will not have a correct tokenGenID, + // it will instead be set to 0. In the absence of information, the only safe thing to do is to + // invalidate the cached accessToken. + if tokenGenID == 0 || tokenGenID >= oa.tokenGenID { oa.accessToken = "" - oa.cfg.Connection.SetOIDCTokenGenID(0) + conn.SetOIDCTokenGenID(0) } } // Reauth reauthenticates the connection when the server returns a 391 code. Reauth is part of the // driver.Authenticator interface. -func (oa *OIDCAuthenticator) Reauth(ctx context.Context) error { - oa.invalidateAccessToken() +func (oa *OIDCAuthenticator) Reauth(ctx context.Context, cfg *Config) error { + oa.invalidateAccessToken(cfg.Connection) + // The HTTPClient argument of the cfg will be nil on a Reauth call, so we populate + // it from the one stored in the Authenticator at Auth time, since the HTTPClient is only + // configured on driver startup. The HTTPClient will be needed for builtin provider callbacks + cfg.HTTPClient = oa.httpClient // it should be impossible to get a Reauth when an Auth has never occurred, // so we assume cfg was properly set. There is nothing to enforce this, however, // other than the current driver code flow. If cfg is nil, Auth will return an error. - return oa.Auth(ctx, oa.cfg) + return oa.Auth(ctx, cfg) } // Auth authenticates the connection. @@ -252,7 +261,8 @@ func (oa *OIDCAuthenticator) Auth(ctx context.Context, cfg *Config) error { if cfg == nil { return newAuthError(fmt.Sprintf("config must be set for %q authentication", MongoDBOIDC), nil) } - oa.cfg = cfg + oa.httpClient = cfg.HTTPClient + conn := cfg.Connection if oa.accessToken != "" { err = ConductSaslConversation(ctx, cfg, "$external", &oidcOneStep{ @@ -262,7 +272,7 @@ func (oa *OIDCAuthenticator) Auth(ctx context.Context, cfg *Config) error { if err == nil { return nil } - oa.invalidateAccessToken() + oa.invalidateAccessToken(conn) time.Sleep(invalidateSleepTimeout) } @@ -295,15 +305,17 @@ func (oa *OIDCAuthenticator) doAuthHuman(_ context.Context, _ *Config, _ OIDCCal } func (oa *OIDCAuthenticator) doAuthMachine(ctx context.Context, cfg *Config, machineCallback OIDCCallback) error { - accessToken, err := oa.getAccessToken(ctx, + subCtx, cancel := context.WithTimeout(ctx, machineCallbackTimeout) + accessToken, err := oa.getAccessToken(subCtx, + cfg.Connection, &OIDCArgs{ Version: apiVersion, - Timeout: time.Now().Add(machineCallbackTimeout), // idpInfo is nil for machine callbacks in the current spec. IDPInfo: nil, RefreshToken: nil, }, machineCallback) + cancel() if err != nil { return err } diff --git a/x/mongo/driver/auth/plain.go b/x/mongo/driver/auth/plain.go index a11f0d6e06..15067dc8f3 100644 --- a/x/mongo/driver/auth/plain.go +++ b/x/mongo/driver/auth/plain.go @@ -8,6 +8,8 @@ package auth import ( "context" + + "go.mongodb.org/mongo-driver/x/mongo/driver" ) // PLAIN is the mechanism name for PLAIN. @@ -35,7 +37,7 @@ func (a *PlainAuthenticator) Auth(ctx context.Context, cfg *Config) error { } // Reauth reauthenticates the connection. -func (a *PlainAuthenticator) Reauth(_ context.Context) error { +func (a *PlainAuthenticator) Reauth(_ context.Context, _ *driver.AuthConfig) error { return newAuthError("Plain authentication does not support reauthentication", nil) } diff --git a/x/mongo/driver/auth/scram.go b/x/mongo/driver/auth/scram.go index 96bd0b89be..c9fbf89f3a 100644 --- a/x/mongo/driver/auth/scram.go +++ b/x/mongo/driver/auth/scram.go @@ -18,6 +18,7 @@ import ( "github.com/xdg-go/scram" "github.com/xdg-go/stringprep" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/x/mongo/driver" ) const ( @@ -85,7 +86,7 @@ func (a *ScramAuthenticator) Auth(ctx context.Context, cfg *Config) error { } // Reauth reauthenticates the connection. -func (a *ScramAuthenticator) Reauth(_ context.Context) error { +func (a *ScramAuthenticator) Reauth(_ context.Context, _ *driver.AuthConfig) error { return newAuthError("SCRAM does not support reauthentication", nil) } diff --git a/x/mongo/driver/auth/x509.go b/x/mongo/driver/auth/x509.go index 06773551a7..a624b4b53f 100644 --- a/x/mongo/driver/auth/x509.go +++ b/x/mongo/driver/auth/x509.go @@ -78,6 +78,6 @@ func (a *MongoDBX509Authenticator) Auth(ctx context.Context, cfg *Config) error } // Reauth reauthenticates the connection. -func (a *MongoDBX509Authenticator) Reauth(_ context.Context) error { +func (a *MongoDBX509Authenticator) Reauth(_ context.Context, _ *driver.AuthConfig) error { return newAuthError("X509 does not support reauthentication", nil) } diff --git a/x/mongo/driver/driver.go b/x/mongo/driver/driver.go index 59f6f85ab4..60bbd375e0 100644 --- a/x/mongo/driver/driver.go +++ b/x/mongo/driver/driver.go @@ -44,7 +44,6 @@ type OIDCCallback func(context.Context, *OIDCArgs) (*OIDCCredential, error) // OIDCArgs contains the arguments for the OIDC callback. type OIDCArgs struct { Version int - Timeout time.Time IDPInfo *IDPInfo RefreshToken *string } @@ -70,7 +69,7 @@ type IDPInfo struct { type Authenticator interface { // Auth authenticates the connection. Auth(context.Context, *AuthConfig) error - Reauth(context.Context) error + Reauth(context.Context, *AuthConfig) error } // Cred is a user's credential. diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index 9ce319d8c9..52298ee63d 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -920,7 +920,13 @@ func (op Operation) Execute(ctx context.Context) error { // retry the operation, if it is successful. if tt.Code == 391 { if op.Authenticator != nil { - if err := op.Authenticator.Reauth(ctx); err != nil { + cfg := AuthConfig{ + Description: conn.Description(), + Connection: conn, + ClusterClock: op.Clock, + ServerAPI: op.ServerAPI, + } + if err := op.Authenticator.Reauth(ctx, &cfg); err != nil { return fmt.Errorf("error reauthenticating: %w", err) } if op.Client != nil && op.Client.Committing { From d45c7e4e5869585a6ff36a28f6f7715924514081 Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Mon, 1 Jul 2024 16:46:20 -0400 Subject: [PATCH 59/66] Update Makefile Co-authored-by: Matt Dale <9760375+matthewdale@users.noreply.github.com> --- Makefile | 1 + 1 file changed, 1 insertion(+) diff --git a/Makefile b/Makefile index 0c144db30e..b38bb4b6f0 100644 --- a/Makefile +++ b/Makefile @@ -135,6 +135,7 @@ evg-test-enterprise-auth: .PHONY: evg-test-oidc-auth evg-test-oidc-auth: go run ./cmd/testoidcauth/main.go + go run -race ./cmd/testoidcauth/main.go .PHONY: evg-test-kmip evg-test-kmip: From ae9c34fc787d79500386ffc66594ec083c8d8daa Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Mon, 1 Jul 2024 16:49:12 -0400 Subject: [PATCH 60/66] GODRIVER-2911: Apply httpclient patch --- mongo/client.go | 2 +- x/mongo/driver/auth/auth.go | 8 +++----- x/mongo/driver/auth/auth_test.go | 3 ++- x/mongo/driver/auth/default.go | 14 +++++++++----- x/mongo/driver/auth/gssapi.go | 3 ++- x/mongo/driver/auth/gssapi_not_enabled.go | 4 +++- x/mongo/driver/auth/gssapi_not_supported.go | 3 ++- x/mongo/driver/auth/mongodbaws.go | 14 ++++++++------ x/mongo/driver/auth/mongodbcr.go | 3 ++- x/mongo/driver/auth/oidc.go | 11 ++--------- x/mongo/driver/auth/plain.go | 3 ++- x/mongo/driver/auth/scram.go | 5 +++-- x/mongo/driver/auth/scram_test.go | 15 +++++++++------ x/mongo/driver/auth/speculative_scram_test.go | 5 +++-- x/mongo/driver/auth/speculative_x509_test.go | 5 +++-- x/mongo/driver/auth/x509.go | 3 ++- x/mongo/driver/driver.go | 2 -- x/mongo/driver/topology/topology_options.go | 3 +-- 18 files changed, 57 insertions(+), 49 deletions(-) diff --git a/mongo/client.go b/mongo/client.go index 082554adbb..3dbcf13eb3 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -225,7 +225,7 @@ func NewClient(opts ...*options.ClientOptions) (*Client, error) { Props: clientOpt.Auth.AuthMechanismProperties, OIDCMachineCallback: clientOpt.Auth.OIDCMachineCallback, OIDCHumanCallback: clientOpt.Auth.OIDCHumanCallback, - }) + }, clientOpt.HTTPClient) if err != nil { return nil, err } diff --git a/x/mongo/driver/auth/auth.go b/x/mongo/driver/auth/auth.go index c9e49756f2..f6471cea26 100644 --- a/x/mongo/driver/auth/auth.go +++ b/x/mongo/driver/auth/auth.go @@ -23,7 +23,7 @@ import ( type Config = driver.AuthConfig // AuthenticatorFactory constructs an authenticator. -type AuthenticatorFactory func(cred *Cred) (Authenticator, error) +type AuthenticatorFactory func(*Cred, *http.Client) (Authenticator, error) var authFactories = make(map[string]AuthenticatorFactory) @@ -40,9 +40,9 @@ func init() { } // CreateAuthenticator creates an authenticator. -func CreateAuthenticator(name string, cred *Cred) (Authenticator, error) { +func CreateAuthenticator(name string, cred *Cred, httpClient *http.Client) (Authenticator, error) { if f, ok := authFactories[name]; ok { - return f(cred) + return f(cred, httpClient) } return nil, newAuthError(fmt.Sprintf("unknown authenticator: %s", name), nil) @@ -65,7 +65,6 @@ type HandshakeOptions struct { ClusterClock *session.ClusterClock ServerAPI *driver.ServerAPIOptions LoadBalanced bool - HTTPClient *http.Client } type authHandshaker struct { @@ -141,7 +140,6 @@ func (ah *authHandshaker) FinishHandshake(ctx context.Context, conn driver.Conne ClusterClock: ah.options.ClusterClock, HandshakeInfo: ah.handshakeInfo, ServerAPI: ah.options.ServerAPI, - HTTPClient: ah.options.HTTPClient, } if err := ah.authenticate(ctx, cfg); err != nil { diff --git a/x/mongo/driver/auth/auth_test.go b/x/mongo/driver/auth/auth_test.go index 9145a21595..3c07ed2cd8 100644 --- a/x/mongo/driver/auth/auth_test.go +++ b/x/mongo/driver/auth/auth_test.go @@ -7,6 +7,7 @@ package auth_test import ( + "net/http" "testing" "github.com/google/go-cmp/cmp" @@ -39,7 +40,7 @@ func TestCreateAuthenticator(t *testing.T) { PasswordSet: true, } - a, err := CreateAuthenticator(test.name, cred) + a, err := CreateAuthenticator(test.name, cred, &http.Client{}) require.NoError(t, err) require.IsType(t, test.auth, a) }) diff --git a/x/mongo/driver/auth/default.go b/x/mongo/driver/auth/default.go index 0813675c66..785a41951d 100644 --- a/x/mongo/driver/auth/default.go +++ b/x/mongo/driver/auth/default.go @@ -9,12 +9,13 @@ package auth import ( "context" "fmt" + "net/http" "go.mongodb.org/mongo-driver/x/mongo/driver" ) -func newDefaultAuthenticator(cred *Cred) (Authenticator, error) { - scram, err := newScramSHA256Authenticator(cred) +func newDefaultAuthenticator(cred *Cred, httpClient *http.Client) (Authenticator, error) { + scram, err := newScramSHA256Authenticator(cred, httpClient) if err != nil { return nil, newAuthError("failed to create internal authenticator", err) } @@ -27,6 +28,7 @@ func newDefaultAuthenticator(cred *Cred) (Authenticator, error) { return &DefaultAuthenticator{ Cred: cred, speculativeAuthenticator: speculative, + httpClient: httpClient, }, nil } @@ -38,6 +40,8 @@ type DefaultAuthenticator struct { // The authenticator to use for speculative authentication. Because the correct auth mechanism is unknown when doing // the initial hello, SCRAM-SHA-256 is used for the speculative attempt. speculativeAuthenticator SpeculativeAuthenticator + + httpClient *http.Client } var _ SpeculativeAuthenticator = (*DefaultAuthenticator)(nil) @@ -54,11 +58,11 @@ func (a *DefaultAuthenticator) Auth(ctx context.Context, cfg *Config) error { switch chooseAuthMechanism(cfg) { case SCRAMSHA256: - actual, err = newScramSHA256Authenticator(a.Cred) + actual, err = newScramSHA256Authenticator(a.Cred, a.httpClient) case SCRAMSHA1: - actual, err = newScramSHA1Authenticator(a.Cred) + actual, err = newScramSHA1Authenticator(a.Cred, a.httpClient) default: - actual, err = newMongoDBCRAuthenticator(a.Cred) + actual, err = newMongoDBCRAuthenticator(a.Cred, a.httpClient) } if err != nil { diff --git a/x/mongo/driver/auth/gssapi.go b/x/mongo/driver/auth/gssapi.go index 84bdf49e78..037c944eb7 100644 --- a/x/mongo/driver/auth/gssapi.go +++ b/x/mongo/driver/auth/gssapi.go @@ -14,6 +14,7 @@ import ( "context" "fmt" "net" + "net/http" "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/mongo/driver/auth/internal/gssapi" @@ -22,7 +23,7 @@ import ( // GSSAPI is the mechanism name for GSSAPI. const GSSAPI = "GSSAPI" -func newGSSAPIAuthenticator(cred *Cred) (Authenticator, error) { +func newGSSAPIAuthenticator(cred *Cred, _ *http.Client) (Authenticator, error) { if cred.Source != "" && cred.Source != "$external" { return nil, newAuthError("GSSAPI source must be empty or $external", nil) } diff --git a/x/mongo/driver/auth/gssapi_not_enabled.go b/x/mongo/driver/auth/gssapi_not_enabled.go index 7ba5fe860c..e50553c7a1 100644 --- a/x/mongo/driver/auth/gssapi_not_enabled.go +++ b/x/mongo/driver/auth/gssapi_not_enabled.go @@ -9,9 +9,11 @@ package auth +import "net/http" + // GSSAPI is the mechanism name for GSSAPI. const GSSAPI = "GSSAPI" -func newGSSAPIAuthenticator(*Cred) (Authenticator, error) { +func newGSSAPIAuthenticator(*Cred, *http.Client) (Authenticator, error) { return nil, newAuthError("GSSAPI support not enabled during build (-tags gssapi)", nil) } diff --git a/x/mongo/driver/auth/gssapi_not_supported.go b/x/mongo/driver/auth/gssapi_not_supported.go index 10312c228e..12046ff67c 100644 --- a/x/mongo/driver/auth/gssapi_not_supported.go +++ b/x/mongo/driver/auth/gssapi_not_supported.go @@ -11,12 +11,13 @@ package auth import ( "fmt" + "net/http" "runtime" ) // GSSAPI is the mechanism name for GSSAPI. const GSSAPI = "GSSAPI" -func newGSSAPIAuthenticator(cred *Cred) (Authenticator, error) { +func newGSSAPIAuthenticator(*Cred, *http.Client) (Authenticator, error) { return nil, newAuthError(fmt.Sprintf("GSSAPI is not supported on %s", runtime.GOOS), nil) } diff --git a/x/mongo/driver/auth/mongodbaws.go b/x/mongo/driver/auth/mongodbaws.go index 832378e89c..2245bdb6fe 100644 --- a/x/mongo/driver/auth/mongodbaws.go +++ b/x/mongo/driver/auth/mongodbaws.go @@ -9,6 +9,7 @@ package auth import ( "context" "errors" + "net/http" "go.mongodb.org/mongo-driver/internal/aws/credentials" "go.mongodb.org/mongo-driver/internal/credproviders" @@ -19,10 +20,13 @@ import ( // MongoDBAWS is the mechanism name for MongoDBAWS. const MongoDBAWS = "MONGODB-AWS" -func newMongoDBAWSAuthenticator(cred *Cred) (Authenticator, error) { +func newMongoDBAWSAuthenticator(cred *Cred, httpClient *http.Client) (Authenticator, error) { if cred.Source != "" && cred.Source != "$external" { return nil, newAuthError("MONGODB-AWS source must be empty or $external", nil) } + if httpClient == nil { + return nil, errors.New("httpClient must not be nil") + } return &MongoDBAWSAuthenticator{ source: cred.Source, credentials: &credproviders.StaticProvider{ @@ -33,6 +37,7 @@ func newMongoDBAWSAuthenticator(cred *Cred) (Authenticator, error) { SessionToken: cred.Props["AWS_SESSION_TOKEN"], }, }, + httpClient: httpClient, }, nil } @@ -40,15 +45,12 @@ func newMongoDBAWSAuthenticator(cred *Cred) (Authenticator, error) { type MongoDBAWSAuthenticator struct { source string credentials *credproviders.StaticProvider + httpClient *http.Client } // Auth authenticates the connection. func (a *MongoDBAWSAuthenticator) Auth(ctx context.Context, cfg *Config) error { - httpClient := cfg.HTTPClient - if httpClient == nil { - return errors.New("cfg.HTTPClient must not be nil") - } - providers := creds.NewAWSCredentialProvider(httpClient, a.credentials) + providers := creds.NewAWSCredentialProvider(a.httpClient, a.credentials) adapter := &awsSaslAdapter{ conversation: &awsConversation{ credentials: providers.Cred, diff --git a/x/mongo/driver/auth/mongodbcr.go b/x/mongo/driver/auth/mongodbcr.go index 416010f638..a988011b36 100644 --- a/x/mongo/driver/auth/mongodbcr.go +++ b/x/mongo/driver/auth/mongodbcr.go @@ -10,6 +10,7 @@ import ( "context" "fmt" "io" + "net/http" // Ignore gosec warning "Blocklisted import crypto/md5: weak cryptographic primitive". We need // to use MD5 here to implement the MONGODB-CR specification. @@ -28,7 +29,7 @@ import ( // MongoDB 4.0. const MONGODBCR = "MONGODB-CR" -func newMongoDBCRAuthenticator(cred *Cred) (Authenticator, error) { +func newMongoDBCRAuthenticator(cred *Cred, _ *http.Client) (Authenticator, error) { return &MongoDBCRAuthenticator{ DB: cred.Source, Username: cred.Username, diff --git a/x/mongo/driver/auth/oidc.go b/x/mongo/driver/auth/oidc.go index a083a45198..f117bf8e63 100644 --- a/x/mongo/driver/auth/oidc.go +++ b/x/mongo/driver/auth/oidc.go @@ -91,7 +91,7 @@ func (oa *OIDCAuthenticator) SetAccessToken(accessToken string) { oa.accessToken = accessToken } -func newOIDCAuthenticator(cred *Cred) (Authenticator, error) { +func newOIDCAuthenticator(cred *Cred, httpClient *http.Client) (Authenticator, error) { if cred.Password != "" { return nil, fmt.Errorf("password cannot be specified for %q", MongoDBOIDC) } @@ -114,6 +114,7 @@ func newOIDCAuthenticator(cred *Cred) (Authenticator, error) { } oa := &OIDCAuthenticator{ userName: cred.Username, + httpClient: httpClient, AuthMechanismProperties: cred.Props, OIDCMachineCallback: cred.OIDCMachineCallback, OIDCHumanCallback: cred.OIDCHumanCallback, @@ -244,13 +245,6 @@ func (oa *OIDCAuthenticator) invalidateAccessToken(conn driver.Connection) { // driver.Authenticator interface. func (oa *OIDCAuthenticator) Reauth(ctx context.Context, cfg *Config) error { oa.invalidateAccessToken(cfg.Connection) - // The HTTPClient argument of the cfg will be nil on a Reauth call, so we populate - // it from the one stored in the Authenticator at Auth time, since the HTTPClient is only - // configured on driver startup. The HTTPClient will be needed for builtin provider callbacks - cfg.HTTPClient = oa.httpClient - // it should be impossible to get a Reauth when an Auth has never occurred, - // so we assume cfg was properly set. There is nothing to enforce this, however, - // other than the current driver code flow. If cfg is nil, Auth will return an error. return oa.Auth(ctx, cfg) } @@ -261,7 +255,6 @@ func (oa *OIDCAuthenticator) Auth(ctx context.Context, cfg *Config) error { if cfg == nil { return newAuthError(fmt.Sprintf("config must be set for %q authentication", MongoDBOIDC), nil) } - oa.httpClient = cfg.HTTPClient conn := cfg.Connection if oa.accessToken != "" { diff --git a/x/mongo/driver/auth/plain.go b/x/mongo/driver/auth/plain.go index 15067dc8f3..3e4c5b4eb3 100644 --- a/x/mongo/driver/auth/plain.go +++ b/x/mongo/driver/auth/plain.go @@ -8,6 +8,7 @@ package auth import ( "context" + "net/http" "go.mongodb.org/mongo-driver/x/mongo/driver" ) @@ -15,7 +16,7 @@ import ( // PLAIN is the mechanism name for PLAIN. const PLAIN = "PLAIN" -func newPlainAuthenticator(cred *Cred) (Authenticator, error) { +func newPlainAuthenticator(cred *Cred, _ *http.Client) (Authenticator, error) { return &PlainAuthenticator{ Username: cred.Username, Password: cred.Password, diff --git a/x/mongo/driver/auth/scram.go b/x/mongo/driver/auth/scram.go index c9fbf89f3a..291492e6ff 100644 --- a/x/mongo/driver/auth/scram.go +++ b/x/mongo/driver/auth/scram.go @@ -14,6 +14,7 @@ package auth import ( "context" + "net/http" "github.com/xdg-go/scram" "github.com/xdg-go/stringprep" @@ -36,7 +37,7 @@ var ( ) ) -func newScramSHA1Authenticator(cred *Cred) (Authenticator, error) { +func newScramSHA1Authenticator(cred *Cred, _ *http.Client) (Authenticator, error) { passdigest := mongoPasswordDigest(cred.Username, cred.Password) client, err := scram.SHA1.NewClientUnprepped(cred.Username, passdigest, "") if err != nil { @@ -50,7 +51,7 @@ func newScramSHA1Authenticator(cred *Cred) (Authenticator, error) { }, nil } -func newScramSHA256Authenticator(cred *Cred) (Authenticator, error) { +func newScramSHA256Authenticator(cred *Cred, _ *http.Client) (Authenticator, error) { passprep, err := stringprep.SASLprep.Prepare(cred.Password) if err != nil { return nil, newAuthError("error SASLprepping password", err) diff --git a/x/mongo/driver/auth/scram_test.go b/x/mongo/driver/auth/scram_test.go index ef30a07364..0a745885ee 100644 --- a/x/mongo/driver/auth/scram_test.go +++ b/x/mongo/driver/auth/scram_test.go @@ -8,6 +8,7 @@ package auth import ( "context" + "net/http" "testing" "go.mongodb.org/mongo-driver/internal/assert" @@ -38,7 +39,7 @@ func TestSCRAM(t *testing.T) { t.Run("conversation", func(t *testing.T) { testCases := []struct { name string - createAuthenticatorFn func(*Cred) (Authenticator, error) + createAuthenticatorFn func(*Cred, *http.Client) (Authenticator, error) payloads [][]byte nonce string }{ @@ -49,11 +50,13 @@ func TestSCRAM(t *testing.T) { } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - authenticator, err := tc.createAuthenticatorFn(&Cred{ - Username: "user", - Password: "pencil", - Source: "admin", - }) + authenticator, err := tc.createAuthenticatorFn( + &Cred{ + Username: "user", + Password: "pencil", + Source: "admin", + }, + &http.Client{}) assert.Nil(t, err, "error creating authenticator: %v", err) sa, _ := authenticator.(*ScramAuthenticator) sa.client = sa.client.WithNonceGenerator(func() string { diff --git a/x/mongo/driver/auth/speculative_scram_test.go b/x/mongo/driver/auth/speculative_scram_test.go index a159891adc..9108fe1d21 100644 --- a/x/mongo/driver/auth/speculative_scram_test.go +++ b/x/mongo/driver/auth/speculative_scram_test.go @@ -9,6 +9,7 @@ package auth import ( "bytes" "context" + "net/http" "testing" "go.mongodb.org/mongo-driver/bson" @@ -63,7 +64,7 @@ func TestSpeculativeSCRAM(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Create a SCRAM authenticator and overwrite the nonce generator to make the conversation // deterministic. - authenticator, err := CreateAuthenticator(tc.mechanism, cred) + authenticator, err := CreateAuthenticator(tc.mechanism, cred, &http.Client{}) assert.Nil(t, err, "CreateAuthenticator error: %v", err) setNonce(t, authenticator, tc.nonce) @@ -148,7 +149,7 @@ func TestSpeculativeSCRAM(t *testing.T) { for _, tc := range testCases { t.Run(tc.mechanism, func(t *testing.T) { - authenticator, err := CreateAuthenticator(tc.mechanism, cred) + authenticator, err := CreateAuthenticator(tc.mechanism, cred, &http.Client{}) assert.Nil(t, err, "CreateAuthenticator error: %v", err) setNonce(t, authenticator, tc.nonce) diff --git a/x/mongo/driver/auth/speculative_x509_test.go b/x/mongo/driver/auth/speculative_x509_test.go index 85bd93191b..e26b448e79 100644 --- a/x/mongo/driver/auth/speculative_x509_test.go +++ b/x/mongo/driver/auth/speculative_x509_test.go @@ -9,6 +9,7 @@ package auth import ( "bytes" "context" + "net/http" "testing" "go.mongodb.org/mongo-driver/bson" @@ -32,7 +33,7 @@ func TestSpeculativeX509(t *testing.T) { // Tests for X509 when the hello response contains a reply to the speculative authentication attempt. The // driver should not send any more commands after the hello. - authenticator, err := CreateAuthenticator("MONGODB-X509", &Cred{}) + authenticator, err := CreateAuthenticator("MONGODB-X509", &Cred{}, &http.Client{}) assert.Nil(t, err, "CreateAuthenticator error: %v", err) handshaker := Handshaker(nil, &HandshakeOptions{ Authenticator: authenticator, @@ -76,7 +77,7 @@ func TestSpeculativeX509(t *testing.T) { // Tests for X509 when the hello response does not contain a reply to the speculative authentication attempt. // The driver should send an authenticate command after the hello. - authenticator, err := CreateAuthenticator("MONGODB-X509", &Cred{}) + authenticator, err := CreateAuthenticator("MONGODB-X509", &Cred{}, &http.Client{}) assert.Nil(t, err, "CreateAuthenticator error: %v", err) handshaker := Handshaker(nil, &HandshakeOptions{ Authenticator: authenticator, diff --git a/x/mongo/driver/auth/x509.go b/x/mongo/driver/auth/x509.go index a624b4b53f..3e84f516f8 100644 --- a/x/mongo/driver/auth/x509.go +++ b/x/mongo/driver/auth/x509.go @@ -8,6 +8,7 @@ package auth import ( "context" + "net/http" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" @@ -17,7 +18,7 @@ import ( // MongoDBX509 is the mechanism name for MongoDBX509. const MongoDBX509 = "MONGODB-X509" -func newMongoDBX509Authenticator(cred *Cred) (Authenticator, error) { +func newMongoDBX509Authenticator(cred *Cred, _ *http.Client) (Authenticator, error) { return &MongoDBX509Authenticator{User: cred.Username}, nil } diff --git a/x/mongo/driver/driver.go b/x/mongo/driver/driver.go index 60bbd375e0..363f4d6be3 100644 --- a/x/mongo/driver/driver.go +++ b/x/mongo/driver/driver.go @@ -15,7 +15,6 @@ package driver // import "go.mongodb.org/mongo-driver/x/mongo/driver" import ( "context" - "net/http" "time" "go.mongodb.org/mongo-driver/internal/csot" @@ -34,7 +33,6 @@ type AuthConfig struct { ClusterClock *session.ClusterClock HandshakeInfo HandshakeInformation ServerAPI *ServerAPIOptions - HTTPClient *http.Client } // OIDCCallback is the type for both Human and Machine Callback flows. RefreshToken will always be diff --git a/x/mongo/driver/topology/topology_options.go b/x/mongo/driver/topology/topology_options.go index 7d8dd62dc0..0563e5524e 100644 --- a/x/mongo/driver/topology/topology_options.go +++ b/x/mongo/driver/topology/topology_options.go @@ -83,7 +83,7 @@ func NewConfig(co *options.ClientOptions, clock *session.ClusterClock) (*Config, Source: co.Auth.AuthSource, } mechanism := co.Auth.AuthMechanism - authenticator, err := auth.CreateAuthenticator(mechanism, cred) + authenticator, err := auth.CreateAuthenticator(mechanism, cred, co.HTTPClient) if err != nil { return nil, err } @@ -209,7 +209,6 @@ func NewConfigWithAuthenticator(co *options.ClientOptions, clock *session.Cluste ServerAPI: serverAPI, LoadBalanced: loadBalanced, ClusterClock: clock, - HTTPClient: co.HTTPClient, } if mechanism == "" { From 1d8691477aff3733093cf0a82dd6cdcefc1bbd09 Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Mon, 1 Jul 2024 17:05:45 -0400 Subject: [PATCH 61/66] GODRIVER-2911: Fix races --- x/mongo/driver/auth/oidc.go | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/x/mongo/driver/auth/oidc.go b/x/mongo/driver/auth/oidc.go index f117bf8e63..af935f9516 100644 --- a/x/mongo/driver/auth/oidc.go +++ b/x/mongo/driver/auth/oidc.go @@ -257,14 +257,21 @@ func (oa *OIDCAuthenticator) Auth(ctx context.Context, cfg *Config) error { } conn := cfg.Connection - if oa.accessToken != "" { + oa.mu.Lock() + cachedAccessToken := oa.accessToken + oa.mu.Unlock() + + if cachedAccessToken != "" { err = ConductSaslConversation(ctx, cfg, "$external", &oidcOneStep{ userName: oa.userName, - accessToken: oa.accessToken, + accessToken: cachedAccessToken, }) if err == nil { return nil } + // this seems like it could be incorrect since we could be inavlidating an access token that + // has already been replaced by a different auth attempt, but the TokenGenID will prevernt + // that from happening. oa.invalidateAccessToken(conn) time.Sleep(invalidateSleepTimeout) } From 30ed4c4b70639d2d41068a6dadc0fe6f3f41300a Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Tue, 2 Jul 2024 10:53:15 -0400 Subject: [PATCH 62/66] GODRIVER-2911: Back out changes to sasl, add comment, remove Println in favor of using oa.idpInfo in error message --- x/mongo/driver/auth/oidc.go | 15 ++++++++++----- x/mongo/driver/auth/sasl.go | 11 ++--------- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/x/mongo/driver/auth/oidc.go b/x/mongo/driver/auth/oidc.go index af935f9516..91748598d3 100644 --- a/x/mongo/driver/auth/oidc.go +++ b/x/mongo/driver/auth/oidc.go @@ -36,6 +36,11 @@ const testEnvironmentValue = "test" const apiVersion = 1 const invalidateSleepTimeout = 100 * time.Millisecond + +// The CSOT specification says to apply a 1-minute timeout if "CSOT is not applied". That's +// ambiguous for the v1.x Go Driver because it could mean either "no timeout provided" or "CSOT not +// enabled". Always use a maximum timeout duration of 1 minute, allowing us to ignore the ambiguity. +// Contexts with a shorter timeout are unaffected. const machineCallbackTimeout = 60 * time.Second //GODRIVER-3246 OIDC: Implement Human Callback Mechanism @@ -299,9 +304,7 @@ func (oa *OIDCAuthenticator) Auth(ctx context.Context, cfg *Config) error { func (oa *OIDCAuthenticator) doAuthHuman(_ context.Context, _ *Config, _ OIDCCallback) error { // TODO GODRIVER-3246: Implement OIDC human flow - // Println is for linter - fmt.Println("OIDC human flow not implemented yet", oa.idpInfo) - return newAuthError("OIDC human flow not implemented yet", nil) + return newAuthError("OIDC", fmt.Errorf("human flow not implemented yet, %v", oa.idpInfo)) } func (oa *OIDCAuthenticator) doAuthMachine(ctx context.Context, cfg *Config, machineCallback OIDCCallback) error { @@ -319,9 +322,11 @@ func (oa *OIDCAuthenticator) doAuthMachine(ctx context.Context, cfg *Config, mac if err != nil { return err } - return runSaslConversation(ctx, + return ConductSaslConversation( + ctx, cfg, - newSaslConversation(&oidcOneStep{accessToken: accessToken}, "$external", false), + "$external", + &oidcOneStep{accessToken: accessToken}, ) } diff --git a/x/mongo/driver/auth/sasl.go b/x/mongo/driver/auth/sasl.go index 2b9fe386d7..75f0c411bf 100644 --- a/x/mongo/driver/auth/sasl.go +++ b/x/mongo/driver/auth/sasl.go @@ -152,23 +152,16 @@ func (sc *saslConversation) Finish(ctx context.Context, cfg *Config, firstRespon } } -// ConductSaslConversation runs a full SASL conversation to authenticate the given connection, given -// sasl arguments. +// ConductSaslConversation runs a full SASL conversation to authenticate the given connection. func ConductSaslConversation(ctx context.Context, cfg *Config, authSource string, client SaslClient) error { // Create a non-speculative SASL conversation. conversation := newSaslConversation(client, authSource, false) - return runSaslConversation(ctx, cfg, conversation) -} - -// runSaslConversation runs a SASL conversation to authenticate the given connection, given a -// pre-built saslConversation. -func runSaslConversation(ctx context.Context, cfg *Config, conversation *saslConversation) error { saslStartDoc, err := conversation.FirstMessage() if err != nil { return newError(err, conversation.mechanism) } saslStartCmd := operation.NewCommand(saslStartDoc). - Database(conversation.source). + Database(authSource). Deployment(driver.SingleConnectionDeployment{cfg.Connection}). ClusterClock(cfg.ClusterClock). ServerAPI(cfg.ServerAPI) From 519205c784d9d20c61d73c102db84a2dd11a6616 Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Tue, 2 Jul 2024 17:17:01 -0400 Subject: [PATCH 63/66] GODRIVER-2911: Move public OIDC configuration types into public, non-experimental options package --- cmd/testoidcauth/main.go | 55 +++++++++++++++++----------------- mongo/options/clientoptions.go | 19 ++++++++++-- 2 files changed, 45 insertions(+), 29 deletions(-) diff --git a/cmd/testoidcauth/main.go b/cmd/testoidcauth/main.go index c86161a131..82e95f1db1 100644 --- a/cmd/testoidcauth/main.go +++ b/cmd/testoidcauth/main.go @@ -20,7 +20,6 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" - "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/mongo/driver/auth" ) @@ -44,14 +43,14 @@ func connectAdminClinet() (*mongo.Client, error) { return mongo.Connect(context.Background(), options.Client().ApplyURI(uriAdmin)) } -func connectWithMachineCB(uri string, cb driver.OIDCCallback) (*mongo.Client, error) { +func connectWithMachineCB(uri string, cb options.OIDCCallback) (*mongo.Client, error) { opts := options.Client().ApplyURI(uri) opts.Auth.OIDCMachineCallback = cb return mongo.Connect(context.Background(), opts) } -func connectWithMachineCBAndProperties(uri string, cb driver.OIDCCallback, props map[string]string) (*mongo.Client, error) { +func connectWithMachineCBAndProperties(uri string, cb options.OIDCCallback, props map[string]string) (*mongo.Client, error) { opts := options.Client().ApplyURI(uri) opts.Auth.OIDCMachineCallback = cb @@ -96,7 +95,7 @@ func machine11callbackIsCalled() error { var callbackFailed error countMutex := sync.Mutex{} - client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { countMutex.Lock() defer countMutex.Unlock() callbackCount++ @@ -106,7 +105,7 @@ func machine11callbackIsCalled() error { if err != nil { callbackFailed = fmt.Errorf("machine_1_1: failed reading token file: %v", err) } - return &driver.OIDCCredential{ + return &options.OIDCCredential{ AccessToken: string(accessToken), ExpiresAt: &t, RefreshToken: nil, @@ -138,7 +137,7 @@ func machine12callbackIsCalledOnlyOneForMultipleConnections() error { var callbackFailed error countMutex := sync.Mutex{} - client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { countMutex.Lock() defer countMutex.Unlock() callbackCount++ @@ -148,7 +147,7 @@ func machine12callbackIsCalledOnlyOneForMultipleConnections() error { if err != nil { callbackFailed = fmt.Errorf("machine_1_2: failed reading token file: %v", err) } - return &driver.OIDCCredential{ + return &options.OIDCCredential{ AccessToken: string(accessToken), ExpiresAt: &t, RefreshToken: nil, @@ -193,7 +192,7 @@ func machine21validCallbackInputs() error { var callbackFailed error countMutex := sync.Mutex{} - client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { if args.RefreshToken != nil { callbackFailed = fmt.Errorf("machine_2_1: expected RefreshToken to be nil, got %v", args.RefreshToken) } @@ -219,7 +218,7 @@ func machine21validCallbackInputs() error { if err != nil { fmt.Printf("machine_2_1: failed reading token file: %v", err) } - return &driver.OIDCCredential{ + return &options.OIDCCredential{ AccessToken: string(accessToken), ExpiresAt: &t, RefreshToken: nil, @@ -250,12 +249,12 @@ func machine23oidcCallbackReturnMissingData() error { callbackCount := 0 countMutex := sync.Mutex{} - client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { countMutex.Lock() defer countMutex.Unlock() callbackCount++ t := time.Now().Add(time.Hour) - return &driver.OIDCCredential{ + return &options.OIDCCredential{ AccessToken: "", ExpiresAt: &t, RefreshToken: nil, @@ -283,9 +282,9 @@ func machine23oidcCallbackReturnMissingData() error { } func machine24invalidClientConfigurationWithCallback() error { - _, err := connectWithMachineCBAndProperties(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { + _, err := connectWithMachineCBAndProperties(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { t := time.Now().Add(time.Hour) - return &driver.OIDCCredential{ + return &options.OIDCCredential{ AccessToken: "", ExpiresAt: &t, RefreshToken: nil, @@ -304,7 +303,7 @@ func machine31failureWithCachedTokensFetchANewTokenAndRetryAuth() error { var callbackFailed error countMutex := sync.Mutex{} - client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { countMutex.Lock() defer countMutex.Unlock() callbackCount++ @@ -314,7 +313,7 @@ func machine31failureWithCachedTokensFetchANewTokenAndRetryAuth() error { if err != nil { callbackFailed = fmt.Errorf("machine_3_1: failed reading token file: %v", err) } - return &driver.OIDCCredential{ + return &options.OIDCCredential{ AccessToken: string(accessToken), ExpiresAt: &t, RefreshToken: nil, @@ -333,6 +332,8 @@ func machine31failureWithCachedTokensFetchANewTokenAndRetryAuth() error { authenticatorField = reflect.NewAt( authenticatorField.Type(), unsafe.Pointer(authenticatorField.UnsafeAddr())).Elem() + // this is the only usage of the x packages in the test, showing the the public interface is + // correct. authenticatorField.Interface().(*auth.OIDCAuthenticator).SetAccessToken("some random happy sunshine string") coll := client.Database("test").Collection("test") @@ -354,12 +355,12 @@ func machine32authFailuresWithoutCachedTokensReturnsAnError() error { var callbackFailed error countMutex := sync.Mutex{} - client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { countMutex.Lock() defer countMutex.Unlock() callbackCount++ t := time.Now().Add(time.Hour) - return &driver.OIDCCredential{ + return &options.OIDCCredential{ AccessToken: "this is a bad, bad token", ExpiresAt: &t, RefreshToken: nil, @@ -397,7 +398,7 @@ func machine33UnexpectedErrorCodeDoesNotClearTheCache() error { return fmt.Errorf("machine_3_3: failed connecting admin client: %v", err) } - client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { countMutex.Lock() defer countMutex.Unlock() callbackCount++ @@ -407,7 +408,7 @@ func machine33UnexpectedErrorCodeDoesNotClearTheCache() error { if err != nil { callbackFailed = fmt.Errorf("machine_3_3: failed reading token file: %v", err) } - return &driver.OIDCCredential{ + return &options.OIDCCredential{ AccessToken: string(accessToken), ExpiresAt: &t, RefreshToken: nil, @@ -472,7 +473,7 @@ func machine41ReauthenticationSucceeds() error { return fmt.Errorf("machine_4_1: failed connecting admin client: %v", err) } - client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { countMutex.Lock() defer countMutex.Unlock() callbackCount++ @@ -482,7 +483,7 @@ func machine41ReauthenticationSucceeds() error { if err != nil { callbackFailed = fmt.Errorf("machine_4_1: failed reading token file: %v", err) } - return &driver.OIDCCredential{ + return &options.OIDCCredential{ AccessToken: string(accessToken), ExpiresAt: &t, RefreshToken: nil, @@ -538,7 +539,7 @@ func machine42ReadCommandsFailIfReauthenticationFails() error { return fmt.Errorf("machine_4_2: failed connecting admin client: %v", err) } - client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { countMutex.Lock() defer countMutex.Unlock() callbackCount++ @@ -550,13 +551,13 @@ func machine42ReadCommandsFailIfReauthenticationFails() error { if err != nil { callbackFailed = fmt.Errorf("machine_4_2: failed reading token file: %v", err) } - return &driver.OIDCCredential{ + return &options.OIDCCredential{ AccessToken: string(accessToken), ExpiresAt: &t, RefreshToken: nil, }, nil } - return &driver.OIDCCredential{ + return &options.OIDCCredential{ AccessToken: "this is a bad, bad token", ExpiresAt: &t, RefreshToken: nil, @@ -619,7 +620,7 @@ func machine43WriteCommandsFailIfReauthenticationFails() error { return fmt.Errorf("machine_4_3: failed connecting admin client: %v", err) } - client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { countMutex.Lock() defer countMutex.Unlock() callbackCount++ @@ -631,13 +632,13 @@ func machine43WriteCommandsFailIfReauthenticationFails() error { if err != nil { callbackFailed = fmt.Errorf("machine_4_3: failed reading token file: %v", err) } - return &driver.OIDCCredential{ + return &options.OIDCCredential{ AccessToken: string(accessToken), ExpiresAt: &t, RefreshToken: nil, }, nil } - return &driver.OIDCCredential{ + return &options.OIDCCredential{ AccessToken: "this is a bad, bad token", ExpiresAt: &t, RefreshToken: nil, diff --git a/mongo/options/clientoptions.go b/mongo/options/clientoptions.go index 1fd7765a30..838e38f9fa 100644 --- a/mongo/options/clientoptions.go +++ b/mongo/options/clientoptions.go @@ -110,10 +110,25 @@ type Credential struct { Username string Password string PasswordSet bool - OIDCMachineCallback driver.OIDCCallback - OIDCHumanCallback driver.OIDCCallback + OIDCMachineCallback OIDCCallback + OIDCHumanCallback OIDCCallback } +// OIDCCallback is the type for both Human and Machine Callback flows. RefreshToken will always be +// nil in the OIDCArgs for the Machine flow. +type OIDCCallback = driver.OIDCCallback + +// OIDCArgs contains the arguments for OIDC authentication. +type OIDCArgs = driver.OIDCArgs + +// OIDCCredential contains the OIDC access and refresh tokens, and is the type returned by an +// OIDCCallback. +type OIDCCredential = driver.OIDCCredential + +// IDPInfo contains the information needed to perform OIDC human flow authentication with an Idetity +// Provider (IDP). +type IDPInfo = driver.IDPInfo + // BSONOptions are optional BSON marshaling and unmarshaling behaviors. type BSONOptions struct { // UseJSONStructTags causes the driver to fall back to using the "json" From 5f0c68d313ba7e05eb707a3f2290f992760f23fa Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Tue, 2 Jul 2024 17:19:35 -0400 Subject: [PATCH 64/66] GODRIVER-2911: Improve comment --- mongo/options/clientoptions.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mongo/options/clientoptions.go b/mongo/options/clientoptions.go index 838e38f9fa..9cfaad86e9 100644 --- a/mongo/options/clientoptions.go +++ b/mongo/options/clientoptions.go @@ -118,7 +118,8 @@ type Credential struct { // nil in the OIDCArgs for the Machine flow. type OIDCCallback = driver.OIDCCallback -// OIDCArgs contains the arguments for OIDC authentication. +// OIDCArgs contains the arguments for OIDC authentication that are passed to the OIDCCallback +// function. type OIDCArgs = driver.OIDCArgs // OIDCCredential contains the OIDC access and refresh tokens, and is the type returned by an From 6a3af5a86c3f54298409267b3ffc74b92e2ab8b6 Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Wed, 3 Jul 2024 18:12:49 -0400 Subject: [PATCH 65/66] GODRIVER-2911: Update script comment --- etc/run-oidc-test.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/etc/run-oidc-test.sh b/etc/run-oidc-test.sh index 755e162dd1..bc5eb99758 100644 --- a/etc/run-oidc-test.sh +++ b/etc/run-oidc-test.sh @@ -1,6 +1,6 @@ #!/usr/bin/env bash -# run-enterprise-gssapi-test -# Runs the enterprise auth tests with gssapi credentials. +# run-oidc-test +# Runs oidc auth tests. set -eu echo "Running MONGODB-OIDC authentication tests" From 7340dddaf4a7c90b48b8e420ef9dce712ade5fde Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Wed, 3 Jul 2024 21:10:10 -0400 Subject: [PATCH 66/66] GODRIVER-2911: Use conversion functions instead of type redeclarations --- mongo/client.go | 33 ++++++++++++++- mongo/client_test.go | 76 ++++++++++++++++++++++++++++++++++ mongo/options/clientoptions.go | 36 ++++++++++------ 3 files changed, 130 insertions(+), 15 deletions(-) diff --git a/mongo/client.go b/mongo/client.go index 3dbcf13eb3..1f79c18507 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -216,6 +216,22 @@ func NewClient(opts ...*options.ClientOptions) (*Client, error) { } if clientOpt.Auth != nil { + var oidcMachineCallback auth.OIDCCallback + if clientOpt.Auth.OIDCMachineCallback != nil { + oidcMachineCallback = func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { + cred, err := clientOpt.Auth.OIDCMachineCallback(ctx, convertOIDCArgs(args)) + return (*driver.OIDCCredential)(cred), err + } + } + + var oidcHumanCallback auth.OIDCCallback + if clientOpt.Auth.OIDCHumanCallback != nil { + oidcHumanCallback = func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { + cred, err := clientOpt.Auth.OIDCHumanCallback(ctx, convertOIDCArgs(args)) + return (*driver.OIDCCredential)(cred), err + } + } + // Create an authenticator for the client client.authenticator, err = auth.CreateAuthenticator(clientOpt.Auth.AuthMechanism, &auth.Cred{ Source: clientOpt.Auth.AuthSource, @@ -223,8 +239,8 @@ func NewClient(opts ...*options.ClientOptions) (*Client, error) { Password: clientOpt.Auth.Password, PasswordSet: clientOpt.Auth.PasswordSet, Props: clientOpt.Auth.AuthMechanismProperties, - OIDCMachineCallback: clientOpt.Auth.OIDCMachineCallback, - OIDCHumanCallback: clientOpt.Auth.OIDCHumanCallback, + OIDCMachineCallback: oidcMachineCallback, + OIDCHumanCallback: oidcHumanCallback, }, clientOpt.HTTPClient) if err != nil { return nil, err @@ -253,6 +269,19 @@ func NewClient(opts ...*options.ClientOptions) (*Client, error) { return client, nil } +// convertOIDCArgs converts the internal *driver.OIDCArgs into the equivalent +// public type *options.OIDCArgs. +func convertOIDCArgs(args *driver.OIDCArgs) *options.OIDCArgs { + if args == nil { + return nil + } + return &options.OIDCArgs{ + Version: args.Version, + IDPInfo: (*options.IDPInfo)(args.IDPInfo), + RefreshToken: args.RefreshToken, + } +} + // Connect initializes the Client by starting background monitoring goroutines. // If the Client was created using the NewClient function, this method must be called before a Client can be used. // diff --git a/mongo/client_test.go b/mongo/client_test.go index 013c1ae6bb..0a96e54501 100644 --- a/mongo/client_test.go +++ b/mongo/client_test.go @@ -11,6 +11,7 @@ import ( "errors" "math" "os" + "reflect" "testing" "time" @@ -18,11 +19,13 @@ import ( "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal/assert" "go.mongodb.org/mongo-driver/internal/integtest" + "go.mongodb.org/mongo-driver/internal/require" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/readconcern" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/tag" + "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt" "go.mongodb.org/mongo-driver/x/mongo/driver/session" "go.mongodb.org/mongo-driver/x/mongo/driver/topology" @@ -502,3 +505,76 @@ func TestClient(t *testing.T) { } }) } + +// Test that convertOIDCArgs exhaustively copies all fields of a driver.OIDCArgs +// into an options.OIDCArgs. +func TestConvertOIDCArgs(t *testing.T) { + refreshToken := "test refresh token" + + testCases := []struct { + desc string + args *driver.OIDCArgs + }{ + { + desc: "populated args", + args: &driver.OIDCArgs{ + Version: 9, + IDPInfo: &driver.IDPInfo{ + Issuer: "test issuer", + ClientID: "test client ID", + RequestScopes: []string{"test scope 1", "test scope 2"}, + }, + RefreshToken: &refreshToken, + }, + }, + { + desc: "nil", + args: nil, + }, + { + desc: "nil IDPInfo and RefreshToken", + args: &driver.OIDCArgs{ + Version: 9, + IDPInfo: nil, + RefreshToken: nil, + }, + }, + } + + for _, tc := range testCases { + tc := tc // Capture range variable. + + t.Run(tc.desc, func(t *testing.T) { + t.Parallel() + + got := convertOIDCArgs(tc.args) + + if tc.args == nil { + assert.Nil(t, got, "expected nil when input is nil") + return + } + + require.Equal(t, + 3, + reflect.ValueOf(*tc.args).NumField(), + "expected the driver.OIDCArgs struct to have exactly 3 fields") + require.Equal(t, + 3, + reflect.ValueOf(*got).NumField(), + "expected the options.OIDCArgs struct to have exactly 3 fields") + + assert.Equal(t, + tc.args.Version, + got.Version, + "expected Version field to be equal") + assert.EqualValues(t, + tc.args.IDPInfo, + got.IDPInfo, + "expected IDPInfo field to be convertible to equal values") + assert.Equal(t, + tc.args.RefreshToken, + got.RefreshToken, + "expected RefreshToken field to be equal") + }) + } +} diff --git a/mongo/options/clientoptions.go b/mongo/options/clientoptions.go index 9cfaad86e9..19322e5dcd 100644 --- a/mongo/options/clientoptions.go +++ b/mongo/options/clientoptions.go @@ -114,21 +114,31 @@ type Credential struct { OIDCHumanCallback OIDCCallback } -// OIDCCallback is the type for both Human and Machine Callback flows. RefreshToken will always be -// nil in the OIDCArgs for the Machine flow. -type OIDCCallback = driver.OIDCCallback - -// OIDCArgs contains the arguments for OIDC authentication that are passed to the OIDCCallback -// function. -type OIDCArgs = driver.OIDCArgs +// OIDCCallback is the type for both Human and Machine Callback flows. +// RefreshToken will always be nil in the OIDCArgs for the Machine flow. +type OIDCCallback func(context.Context, *OIDCArgs) (*OIDCCredential, error) + +// OIDCArgs contains the arguments for the OIDC callback. +type OIDCArgs struct { + Version int + IDPInfo *IDPInfo + RefreshToken *string +} -// OIDCCredential contains the OIDC access and refresh tokens, and is the type returned by an -// OIDCCallback. -type OIDCCredential = driver.OIDCCredential +// OIDCCredential contains the access token and refresh token. +type OIDCCredential struct { + AccessToken string + ExpiresAt *time.Time + RefreshToken *string +} -// IDPInfo contains the information needed to perform OIDC human flow authentication with an Idetity -// Provider (IDP). -type IDPInfo = driver.IDPInfo +// IDPInfo contains the information needed to perform OIDC authentication with +// an Identity Provider. +type IDPInfo struct { + Issuer string + ClientID string + RequestScopes []string +} // BSONOptions are optional BSON marshaling and unmarshaling behaviors. type BSONOptions struct {