diff --git a/pkg/jms-sdk-go/model/platform.go b/pkg/jms-sdk-go/model/platform.go index 5aa4a9c6..22aaa690 100644 --- a/pkg/jms-sdk-go/model/platform.go +++ b/pkg/jms-sdk-go/model/platform.go @@ -62,7 +62,8 @@ type ProtocolSetting struct { TelnetSuccessPrompt string `json:"success_prompt"` // for mongodb - AuthSource string `json:"auth_source"` + AuthSource string `json:"auth_source"` + ConnectionOpts string `json:"connection_options"` } type Protocol struct { diff --git a/pkg/proxy/server.go b/pkg/proxy/server.go index ce3679f5..97152acc 100644 --- a/pkg/proxy/server.go +++ b/pkg/proxy/server.go @@ -553,6 +553,7 @@ func (s *Server) getMongoDBConn(localTunnelAddr *net.TCPAddr) (srvConn *srvconn. platform := s.connOpts.authInfo.Platform protocolSetting := platform.GetProtocol("mongodb") authSource := protocolSetting.Setting.AuthSource + connectionOpts := protocolSetting.Setting.ConnectionOpts srvConn, err = srvconn.NewMongoDBConnection( srvconn.SqlHost(host), srvconn.SqlPort(port), @@ -564,6 +565,7 @@ func (s *Server) getMongoDBConn(localTunnelAddr *net.TCPAddr) (srvConn *srvconn. srvconn.SqlCertKey(asset.SecretInfo.ClientKey), srvconn.SqlAllowInvalidCert(asset.SpecInfo.AllowInvalidCert), srvconn.SqlAuthSource(authSource), + srvconn.SqlConnectionOptions(connectionOpts), srvconn.SqlPtyWin(srvconn.Windows{ Width: s.UserConn.Pty().Window.Width, Height: s.UserConn.Pty().Window.Height, diff --git a/pkg/srvconn/conn_mongodb.go b/pkg/srvconn/conn_mongodb.go index bbccfa62..44f28c11 100644 --- a/pkg/srvconn/conn_mongodb.go +++ b/pkg/srvconn/conn_mongodb.go @@ -6,6 +6,7 @@ import ( "net/url" "os" "strconv" + "strings" "time" "github.com/jumpserver/koko/pkg/logger" @@ -139,12 +140,43 @@ func (opt *sqlOption) GetAuthSource() string { return opt.AuthSource } -func (opt *sqlOption) MongoDBCommandArgs() []string { - host := net.JoinHostPort(opt.Host, strconv.Itoa(opt.Port)) - params := map[string]string{ +func (opt *sqlOption) GetConnectionOptions() map[string]string { + if opt.ConnectionOptions == "" { + return nil + } + opts := strings.Split(opt.ConnectionOptions, "&") + if len(opts) == 0 { + return nil + } + optMap := make(map[string]string, len(opts)) + for _, item := range opts { + kv := strings.Split(item, "=") + if len(kv) != 2 { + continue + } + optMap[kv[0]] = kv[1] + + } + return optMap +} + +func (opt *sqlOption) GetParams() (params map[string]string) { + params = map[string]string{ "authSource": opt.GetAuthSource(), } + connectionOpts := opt.GetConnectionOptions() + if len(connectionOpts) > 0 { + for k, v := range connectionOpts { + params[k] = v + } + } addMongoParamsWithSSL(opt, params) + return +} + +func (opt *sqlOption) MongoDBCommandArgs() []string { + host := net.JoinHostPort(opt.Host, strconv.Itoa(opt.Port)) + params := opt.GetParams() uri := BuildMongoDBURI( MongoHost(host), MongoDBName(opt.DBName), @@ -158,11 +190,7 @@ func (opt *sqlOption) MongoDBCommandArgs() []string { func checkMongoDBAccount(args *sqlOption) error { host := net.JoinHostPort(args.Host, strconv.Itoa(args.Port)) - params := map[string]string{ - "authSource": args.GetAuthSource(), - "connect": "direct", - } - addMongoParamsWithSSL(args, params) + params := args.GetParams() uri := BuildMongoDBURI( MongoHost(host), MongoAuth(args.Username, args.Password), diff --git a/pkg/srvconn/conn_sql_opt.go b/pkg/srvconn/conn_sql_opt.go index 3bb96508..97e01a8a 100644 --- a/pkg/srvconn/conn_sql_opt.go +++ b/pkg/srvconn/conn_sql_opt.go @@ -24,7 +24,8 @@ type sqlOption struct { disableMySQLAutoRehash bool - AuthSource string + AuthSource string + ConnectionOptions string } type SqlOption func(*sqlOption) @@ -101,6 +102,12 @@ func SqlAuthSource(authSource string) SqlOption { } } +func SqlConnectionOptions(options string) SqlOption { + return func(args *sqlOption) { + args.ConnectionOptions = options + } +} + const ( maxSQLConnCount = 1 maxIdleTime = time.Second * 15