go mod vendor

+ move k8s.io/apimachinery fork from go.work to go.mod
(and include it in vendor)
This commit is contained in:
2022-11-07 00:16:27 +02:00
parent d08bbf250a
commit e45bf4739b
1366 changed files with 469062 additions and 45 deletions

View File

@@ -0,0 +1,229 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
import (
"context"
"errors"
"fmt"
"net/http"
"go.mongodb.org/mongo-driver/mongo/address"
"go.mongodb.org/mongo-driver/mongo/description"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
)
// AuthenticatorFactory constructs an authenticator.
type AuthenticatorFactory func(cred *Cred) (Authenticator, error)
var authFactories = make(map[string]AuthenticatorFactory)
func init() {
RegisterAuthenticatorFactory("", newDefaultAuthenticator)
RegisterAuthenticatorFactory(SCRAMSHA1, newScramSHA1Authenticator)
RegisterAuthenticatorFactory(SCRAMSHA256, newScramSHA256Authenticator)
RegisterAuthenticatorFactory(MONGODBCR, newMongoDBCRAuthenticator)
RegisterAuthenticatorFactory(PLAIN, newPlainAuthenticator)
RegisterAuthenticatorFactory(GSSAPI, newGSSAPIAuthenticator)
RegisterAuthenticatorFactory(MongoDBX509, newMongoDBX509Authenticator)
RegisterAuthenticatorFactory(MongoDBAWS, newMongoDBAWSAuthenticator)
}
// CreateAuthenticator creates an authenticator.
func CreateAuthenticator(name string, cred *Cred) (Authenticator, error) {
if f, ok := authFactories[name]; ok {
return f(cred)
}
return nil, newAuthError(fmt.Sprintf("unknown authenticator: %s", name), nil)
}
// RegisterAuthenticatorFactory registers the authenticator factory.
func RegisterAuthenticatorFactory(name string, factory AuthenticatorFactory) {
authFactories[name] = factory
}
// HandshakeOptions packages options that can be passed to the Handshaker()
// function. DBUser is optional but must be of the form <dbname.username>;
// if non-empty, then the connection will do SASL mechanism negotiation.
type HandshakeOptions struct {
AppName string
Authenticator Authenticator
Compressors []string
DBUser string
PerformAuthentication func(description.Server) bool
ClusterClock *session.ClusterClock
ServerAPI *driver.ServerAPIOptions
LoadBalanced bool
HTTPClient *http.Client
}
type authHandshaker struct {
wrapped driver.Handshaker
options *HandshakeOptions
handshakeInfo driver.HandshakeInformation
conversation SpeculativeConversation
}
var _ driver.Handshaker = (*authHandshaker)(nil)
// GetHandshakeInformation performs the initial MongoDB handshake to retrieve the required information for the provided
// connection.
func (ah *authHandshaker) GetHandshakeInformation(ctx context.Context, addr address.Address, conn driver.Connection) (driver.HandshakeInformation, error) {
if ah.wrapped != nil {
return ah.wrapped.GetHandshakeInformation(ctx, addr, conn)
}
op := operation.NewHello().
AppName(ah.options.AppName).
Compressors(ah.options.Compressors).
SASLSupportedMechs(ah.options.DBUser).
ClusterClock(ah.options.ClusterClock).
ServerAPI(ah.options.ServerAPI).
LoadBalanced(ah.options.LoadBalanced)
if ah.options.Authenticator != nil {
if speculativeAuth, ok := ah.options.Authenticator.(SpeculativeAuthenticator); ok {
var err error
ah.conversation, err = speculativeAuth.CreateSpeculativeConversation()
if err != nil {
return driver.HandshakeInformation{}, newAuthError("failed to create conversation", err)
}
firstMsg, err := ah.conversation.FirstMessage()
if err != nil {
return driver.HandshakeInformation{}, newAuthError("failed to create speculative authentication message", err)
}
op = op.SpeculativeAuthenticate(firstMsg)
}
}
var err error
ah.handshakeInfo, err = op.GetHandshakeInformation(ctx, addr, conn)
if err != nil {
return driver.HandshakeInformation{}, newAuthError("handshake failure", err)
}
return ah.handshakeInfo, nil
}
// FinishHandshake performs authentication for conn if necessary.
func (ah *authHandshaker) FinishHandshake(ctx context.Context, conn driver.Connection) error {
performAuth := ah.options.PerformAuthentication
if performAuth == nil {
performAuth = func(serv description.Server) bool {
// Authentication is possible against all server types except arbiters
return serv.Kind != description.RSArbiter
}
}
desc := conn.Description()
if performAuth(desc) && ah.options.Authenticator != nil {
cfg := &Config{
Description: desc,
Connection: conn,
ClusterClock: ah.options.ClusterClock,
HandshakeInfo: ah.handshakeInfo,
ServerAPI: ah.options.ServerAPI,
HTTPClient: ah.options.HTTPClient,
}
if err := ah.authenticate(ctx, cfg); err != nil {
return newAuthError("auth error", err)
}
}
if ah.wrapped == nil {
return nil
}
return ah.wrapped.FinishHandshake(ctx, conn)
}
func (ah *authHandshaker) authenticate(ctx context.Context, cfg *Config) error {
// If the initial hello reply included a response to the speculative authentication attempt, we only need to
// conduct the remainder of the conversation.
if speculativeResponse := ah.handshakeInfo.SpeculativeAuthenticate; speculativeResponse != nil {
// Defensively ensure that the server did not include a response if speculative auth was not attempted.
if ah.conversation == nil {
return errors.New("speculative auth was not attempted but the server included a response")
}
return ah.conversation.Finish(ctx, cfg, speculativeResponse)
}
// If the server does not support speculative authentication or the first attempt was not successful, we need to
// perform authentication from scratch.
return ah.options.Authenticator.Auth(ctx, cfg)
}
// Handshaker creates a connection handshaker for the given authenticator.
func Handshaker(h driver.Handshaker, options *HandshakeOptions) driver.Handshaker {
return &authHandshaker{
wrapped: h,
options: options,
}
}
// Config holds the information necessary to perform an authentication attempt.
type Config struct {
Description description.Server
Connection driver.Connection
ClusterClock *session.ClusterClock
HandshakeInfo driver.HandshakeInformation
ServerAPI *driver.ServerAPIOptions
HTTPClient *http.Client
}
// Authenticator handles authenticating a connection.
type Authenticator interface {
// Auth authenticates the connection.
Auth(context.Context, *Config) error
}
func newAuthError(msg string, inner error) error {
return &Error{
message: msg,
inner: inner,
}
}
func newError(err error, mech string) error {
return &Error{
message: fmt.Sprintf("unable to authenticate using mechanism \"%s\"", mech),
inner: err,
}
}
// Error is an error that occurred during authentication.
type Error struct {
message string
inner error
}
func (e *Error) Error() string {
if e.inner == nil {
return e.message
}
return fmt.Sprintf("%s: %s", e.message, e.inner)
}
// Inner returns the wrapped error.
func (e *Error) Inner() error {
return e.inner
}
// Unwrap returns the underlying error.
func (e *Error) Unwrap() error {
return e.inner
}
// Message returns the message.
func (e *Error) Message() string {
return e.message
}

View File

@@ -0,0 +1,348 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
import (
"bytes"
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
"os"
"strings"
"time"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver/auth/internal/awsv4"
)
type clientState int
const (
clientStarting clientState = iota
clientFirst
clientFinal
clientDone
)
type awsConversation struct {
state clientState
valid bool
nonce []byte
username string
password string
token string
httpClient *http.Client
}
type serverMessage struct {
Nonce primitive.Binary `bson:"s"`
Host string `bson:"h"`
}
type ecsResponse struct {
AccessKeyID string `json:"AccessKeyId"`
SecretAccessKey string `json:"SecretAccessKey"`
Token string `json:"Token"`
}
const (
amzDateFormat = "20060102T150405Z"
awsRelativeURI = "http://169.254.170.2/"
awsEC2URI = "http://169.254.169.254/"
awsEC2RolePath = "latest/meta-data/iam/security-credentials/"
awsEC2TokenPath = "latest/api/token"
defaultRegion = "us-east-1"
maxHostLength = 255
defaultHTTPTimeout = 10 * time.Second
responceNonceLength = 64
)
// Step takes a string provided from a server (or just an empty string for the
// very first conversation step) and attempts to move the authentication
// conversation forward. It returns a string to be sent to the server or an
// error if the server message is invalid. Calling Step after a conversation
// completes is also an error.
func (ac *awsConversation) Step(challenge []byte) (response []byte, err error) {
switch ac.state {
case clientStarting:
ac.state = clientFirst
response = ac.firstMsg()
case clientFirst:
ac.state = clientFinal
response, err = ac.finalMsg(challenge)
case clientFinal:
ac.state = clientDone
ac.valid = true
default:
response, err = nil, errors.New("Conversation already completed")
}
return
}
// Done returns true if the conversation is completed or has errored.
func (ac *awsConversation) Done() bool {
return ac.state == clientDone
}
// Valid returns true if the conversation successfully authenticated with the
// server, including counter-validation that the server actually has the
// user's stored credentials.
func (ac *awsConversation) Valid() bool {
return ac.valid
}
func getRegion(host string) (string, error) {
region := defaultRegion
if len(host) == 0 {
return "", errors.New("invalid STS host: empty")
}
if len(host) > maxHostLength {
return "", errors.New("invalid STS host: too large")
}
// The implicit region for sts.amazonaws.com is us-east-1
if host == "sts.amazonaws.com" {
return region, nil
}
if strings.HasPrefix(host, ".") || strings.HasSuffix(host, ".") || strings.Contains(host, "..") {
return "", errors.New("invalid STS host: empty part")
}
// If the host has multiple parts, the second part is the region
parts := strings.Split(host, ".")
if len(parts) >= 2 {
region = parts[1]
}
return region, nil
}
func (ac *awsConversation) validateAndMakeCredentials() (*awsv4.StaticProvider, error) {
if ac.username != "" && ac.password == "" {
return nil, errors.New("ACCESS_KEY_ID is set, but SECRET_ACCESS_KEY is missing")
}
if ac.username == "" && ac.password != "" {
return nil, errors.New("SECRET_ACCESS_KEY is set, but ACCESS_KEY_ID is missing")
}
if ac.username == "" && ac.password == "" && ac.token != "" {
return nil, errors.New("AWS_SESSION_TOKEN is set, but ACCESS_KEY_ID and SECRET_ACCESS_KEY are missing")
}
if ac.username != "" || ac.password != "" || ac.token != "" {
return &awsv4.StaticProvider{Value: awsv4.Value{
AccessKeyID: ac.username,
SecretAccessKey: ac.password,
SessionToken: ac.token,
}}, nil
}
return nil, nil
}
func executeAWSHTTPRequest(httpClient *http.Client, req *http.Request) ([]byte, error) {
ctx, cancel := context.WithTimeout(context.Background(), defaultHTTPTimeout)
defer cancel()
resp, err := httpClient.Do(req.WithContext(ctx))
if err != nil {
return nil, err
}
defer resp.Body.Close()
return ioutil.ReadAll(resp.Body)
}
func (ac *awsConversation) getEC2Credentials() (*awsv4.StaticProvider, error) {
// get token
req, err := http.NewRequest("PUT", awsEC2URI+awsEC2TokenPath, nil)
if err != nil {
return nil, err
}
req.Header.Set("X-aws-ec2-metadata-token-ttl-seconds", "30")
token, err := executeAWSHTTPRequest(ac.httpClient, req)
if err != nil {
return nil, err
}
if len(token) == 0 {
return nil, errors.New("unable to retrieve token from EC2 metadata")
}
tokenStr := string(token)
// get role name
req, err = http.NewRequest("GET", awsEC2URI+awsEC2RolePath, nil)
if err != nil {
return nil, err
}
req.Header.Set("X-aws-ec2-metadata-token", tokenStr)
role, err := executeAWSHTTPRequest(ac.httpClient, req)
if err != nil {
return nil, err
}
if len(role) == 0 {
return nil, errors.New("unable to retrieve role_name from EC2 metadata")
}
// get credentials
pathWithRole := awsEC2URI + awsEC2RolePath + string(role)
req, err = http.NewRequest("GET", pathWithRole, nil)
if err != nil {
return nil, err
}
req.Header.Set("X-aws-ec2-metadata-token", tokenStr)
creds, err := executeAWSHTTPRequest(ac.httpClient, req)
if err != nil {
return nil, err
}
var es2Resp ecsResponse
err = json.Unmarshal(creds, &es2Resp)
if err != nil {
return nil, err
}
ac.username = es2Resp.AccessKeyID
ac.password = es2Resp.SecretAccessKey
ac.token = es2Resp.Token
return ac.validateAndMakeCredentials()
}
func (ac *awsConversation) getCredentials() (*awsv4.StaticProvider, error) {
// Credentials passed through URI
creds, err := ac.validateAndMakeCredentials()
if creds != nil || err != nil {
return creds, err
}
// Credentials from environment variables
ac.username = os.Getenv("AWS_ACCESS_KEY_ID")
ac.password = os.Getenv("AWS_SECRET_ACCESS_KEY")
ac.token = os.Getenv("AWS_SESSION_TOKEN")
creds, err = ac.validateAndMakeCredentials()
if creds != nil || err != nil {
return creds, err
}
// Credentials from ECS metadata
relativeEcsURI := os.Getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI")
if len(relativeEcsURI) > 0 {
fullURI := awsRelativeURI + relativeEcsURI
req, err := http.NewRequest("GET", fullURI, nil)
if err != nil {
return nil, err
}
body, err := executeAWSHTTPRequest(ac.httpClient, req)
if err != nil {
return nil, err
}
var espResp ecsResponse
err = json.Unmarshal(body, &espResp)
if err != nil {
return nil, err
}
ac.username = espResp.AccessKeyID
ac.password = espResp.SecretAccessKey
ac.token = espResp.Token
creds, err = ac.validateAndMakeCredentials()
if creds != nil || err != nil {
return creds, err
}
}
// Credentials from EC2 metadata
creds, err = ac.getEC2Credentials()
if creds == nil && err == nil {
return nil, errors.New("unable to get credentials")
}
return creds, err
}
func (ac *awsConversation) firstMsg() []byte {
// Values are cached for use in final message parameters
ac.nonce = make([]byte, 32)
_, _ = rand.Read(ac.nonce)
idx, msg := bsoncore.AppendDocumentStart(nil)
msg = bsoncore.AppendInt32Element(msg, "p", 110)
msg = bsoncore.AppendBinaryElement(msg, "r", 0x00, ac.nonce)
msg, _ = bsoncore.AppendDocumentEnd(msg, idx)
return msg
}
func (ac *awsConversation) finalMsg(s1 []byte) ([]byte, error) {
var sm serverMessage
err := bson.Unmarshal(s1, &sm)
if err != nil {
return nil, err
}
// Check nonce prefix
if sm.Nonce.Subtype != 0x00 {
return nil, errors.New("server reply contained unexpected binary subtype")
}
if len(sm.Nonce.Data) != responceNonceLength {
return nil, fmt.Errorf("server reply nonce was not %v bytes", responceNonceLength)
}
if !bytes.HasPrefix(sm.Nonce.Data, ac.nonce) {
return nil, errors.New("server nonce did not extend client nonce")
}
region, err := getRegion(sm.Host)
if err != nil {
return nil, err
}
creds, err := ac.getCredentials()
if err != nil {
return nil, err
}
currentTime := time.Now().UTC()
body := "Action=GetCallerIdentity&Version=2011-06-15"
// Create http.Request
req, _ := http.NewRequest("POST", "/", strings.NewReader(body))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Content-Length", "43")
req.Host = sm.Host
req.Header.Set("X-Amz-Date", currentTime.Format(amzDateFormat))
if len(ac.token) > 0 {
req.Header.Set("X-Amz-Security-Token", ac.token)
}
req.Header.Set("X-MongoDB-Server-Nonce", base64.StdEncoding.EncodeToString(sm.Nonce.Data))
req.Header.Set("X-MongoDB-GS2-CB-Flag", "n")
// Create signer with credentials
signer := awsv4.NewSigner(creds)
// Get signed header
_, err = signer.Sign(req, strings.NewReader(body), "sts", region, currentTime)
if err != nil {
return nil, err
}
// create message
idx, msg := bsoncore.AppendDocumentStart(nil)
msg = bsoncore.AppendStringElement(msg, "a", req.Header.Get("Authorization"))
msg = bsoncore.AppendStringElement(msg, "d", req.Header.Get("X-Amz-Date"))
if len(ac.token) > 0 {
msg = bsoncore.AppendStringElement(msg, "t", ac.token)
}
msg, _ = bsoncore.AppendDocumentEnd(msg, idx)
return msg, nil
}

View File

@@ -0,0 +1,31 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
import (
"context"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)
// SpeculativeConversation represents an authentication conversation that can be merged with the initial connection
// handshake.
//
// FirstMessage method returns the first message to be sent to the server. This message will be included in the initial
// hello command.
//
// Finish takes the server response to the initial message and conducts the remainder of the conversation to
// authenticate the provided connection.
type SpeculativeConversation interface {
FirstMessage() (bsoncore.Document, error)
Finish(ctx context.Context, cfg *Config, firstResponse bsoncore.Document) error
}
// SpeculativeAuthenticator represents an authenticator that supports speculative authentication.
type SpeculativeAuthenticator interface {
CreateSpeculativeConversation() (SpeculativeConversation, error)
}

View File

@@ -0,0 +1,16 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
// Cred is a user's credential.
type Cred struct {
Source string
Username string
Password string
PasswordSet bool
Props map[string]string
}

View File

@@ -0,0 +1,98 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
import (
"context"
"fmt"
"go.mongodb.org/mongo-driver/mongo/description"
)
func newDefaultAuthenticator(cred *Cred) (Authenticator, error) {
scram, err := newScramSHA256Authenticator(cred)
if err != nil {
return nil, newAuthError("failed to create internal authenticator", err)
}
speculative, ok := scram.(SpeculativeAuthenticator)
if !ok {
typeErr := fmt.Errorf("expected SCRAM authenticator to be SpeculativeAuthenticator but got %T", scram)
return nil, newAuthError("failed to create internal authenticator", typeErr)
}
return &DefaultAuthenticator{
Cred: cred,
speculativeAuthenticator: speculative,
}, nil
}
// DefaultAuthenticator uses SCRAM-SHA-1 or MONGODB-CR depending
// on the server version.
type DefaultAuthenticator struct {
Cred *Cred
// The authenticator to use for speculative authentication. Because the correct auth mechanism is unknown when doing
// the initial hello, SCRAM-SHA-256 is used for the speculative attempt.
speculativeAuthenticator SpeculativeAuthenticator
}
var _ SpeculativeAuthenticator = (*DefaultAuthenticator)(nil)
// CreateSpeculativeConversation creates a speculative conversation for SCRAM authentication.
func (a *DefaultAuthenticator) CreateSpeculativeConversation() (SpeculativeConversation, error) {
return a.speculativeAuthenticator.CreateSpeculativeConversation()
}
// Auth authenticates the connection.
func (a *DefaultAuthenticator) Auth(ctx context.Context, cfg *Config) error {
var actual Authenticator
var err error
switch chooseAuthMechanism(cfg) {
case SCRAMSHA256:
actual, err = newScramSHA256Authenticator(a.Cred)
case SCRAMSHA1:
actual, err = newScramSHA1Authenticator(a.Cred)
default:
actual, err = newMongoDBCRAuthenticator(a.Cred)
}
if err != nil {
return newAuthError("error creating authenticator", err)
}
return actual.Auth(ctx, cfg)
}
// If a server provides a list of supported mechanisms, we choose
// SCRAM-SHA-256 if it exists or else MUST use SCRAM-SHA-1.
// Otherwise, we decide based on what is supported.
func chooseAuthMechanism(cfg *Config) string {
if saslSupportedMechs := cfg.HandshakeInfo.SaslSupportedMechs; saslSupportedMechs != nil {
for _, v := range saslSupportedMechs {
if v == SCRAMSHA256 {
return v
}
}
return SCRAMSHA1
}
if err := scramSHA1Supported(cfg.HandshakeInfo.Description.WireVersion); err == nil {
return SCRAMSHA1
}
return MONGODBCR
}
// scramSHA1Supported returns an error if the given server version does not support scram-sha-1.
func scramSHA1Supported(wireVersion *description.VersionRange) error {
if wireVersion != nil && wireVersion.Max < 3 {
return fmt.Errorf("SCRAM-SHA-1 is only supported for servers 3.0 or newer")
}
return nil
}

View File

@@ -0,0 +1,23 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Package auth is not for public use.
//
// The API for packages in the 'private' directory have no stability
// guarantee.
//
// The packages within the 'private' directory would normally be put into an
// 'internal' directory to prohibit their use outside the 'mongo' directory.
// However, some MongoDB tools require very low-level access to the building
// blocks of a driver, so we have placed them under 'private' to allow these
// packages to be imported by projects that need them.
//
// These package APIs may be modified in backwards-incompatible ways at any
// time.
//
// You are strongly discouraged from directly using any packages
// under 'private'.
package auth

View File

@@ -0,0 +1,59 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//go:build gssapi && (windows || linux || darwin)
// +build gssapi
// +build windows linux darwin
package auth
import (
"context"
"fmt"
"net"
"go.mongodb.org/mongo-driver/x/mongo/driver/auth/internal/gssapi"
)
// GSSAPI is the mechanism name for GSSAPI.
const GSSAPI = "GSSAPI"
func newGSSAPIAuthenticator(cred *Cred) (Authenticator, error) {
if cred.Source != "" && cred.Source != "$external" {
return nil, newAuthError("GSSAPI source must be empty or $external", nil)
}
return &GSSAPIAuthenticator{
Username: cred.Username,
Password: cred.Password,
PasswordSet: cred.PasswordSet,
Props: cred.Props,
}, nil
}
// GSSAPIAuthenticator uses the GSSAPI algorithm over SASL to authenticate a connection.
type GSSAPIAuthenticator struct {
Username string
Password string
PasswordSet bool
Props map[string]string
}
// Auth authenticates the connection.
func (a *GSSAPIAuthenticator) Auth(ctx context.Context, cfg *Config) error {
target := cfg.Description.Addr.String()
hostname, _, err := net.SplitHostPort(target)
if err != nil {
return newAuthError(fmt.Sprintf("invalid endpoint (%s) specified: %s", target, err), nil)
}
client, err := gssapi.New(hostname, a.Username, a.Password, a.PasswordSet, a.Props)
if err != nil {
return newAuthError("error creating gssapi", err)
}
return ConductSaslConversation(ctx, cfg, "$external", client)
}

View File

@@ -0,0 +1,17 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//go:build !gssapi
// +build !gssapi
package auth
// GSSAPI is the mechanism name for GSSAPI.
const GSSAPI = "GSSAPI"
func newGSSAPIAuthenticator(cred *Cred) (Authenticator, error) {
return nil, newAuthError("GSSAPI support not enabled during build (-tags gssapi)", nil)
}

View File

@@ -0,0 +1,22 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//go:build gssapi && !windows && !linux && !darwin
// +build gssapi,!windows,!linux,!darwin
package auth
import (
"fmt"
"runtime"
)
// GSSAPI is the mechanism name for GSSAPI.
const GSSAPI = "GSSAPI"
func newGSSAPIAuthenticator(cred *Cred) (Authenticator, error) {
return nil, newAuthError(fmt.Sprintf("GSSAPI is not supported on %s", runtime.GOOS), nil)
}

View File

@@ -0,0 +1,63 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
// - github.com/aws/aws-sdk-go/blob/v1.34.28/aws/credentials/static_provider.go
// - github.com/aws/aws-sdk-go/blob/v1.34.28/aws/credentials/credentials.go
// See THIRD-PARTY-NOTICES for original license terms
package awsv4
import (
"errors"
)
// StaticProviderName provides a name of Static provider
const StaticProviderName = "StaticProvider"
var (
// ErrStaticCredentialsEmpty is emitted when static credentials are empty.
ErrStaticCredentialsEmpty = errors.New("EmptyStaticCreds: static credentials are empty")
)
// A Value is the AWS credentials value for individual credential fields.
type Value struct {
// AWS Access key ID
AccessKeyID string
// AWS Secret Access Key
SecretAccessKey string
// AWS Session Token
SessionToken string
// Provider used to get credentials
ProviderName string
}
// HasKeys returns if the credentials Value has both AccessKeyID and
// SecretAccessKey value set.
func (v Value) HasKeys() bool {
return len(v.AccessKeyID) != 0 && len(v.SecretAccessKey) != 0
}
// A StaticProvider is a set of credentials which are set programmatically,
// and will never expire.
type StaticProvider struct {
Value
}
// Retrieve returns the credentials or error if the credentials are invalid.
func (s *StaticProvider) Retrieve() (Value, error) {
if s.AccessKeyID == "" || s.SecretAccessKey == "" {
return Value{ProviderName: StaticProviderName}, ErrStaticCredentialsEmpty
}
if len(s.Value.ProviderName) == 0 {
s.Value.ProviderName = StaticProviderName
}
return s.Value, nil
}

View File

@@ -0,0 +1,15 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on github.com/aws/aws-sdk-go v1.34.28 by Amazon.com, Inc.
// See THIRD-PARTY-NOTICES for original license terms
// Package awsv4 implements signing for AWS V4 signer with static credentials,
// and is based on and modified from code in the package aws-sdk-go. The
// modifications remove non-static credentials, support for non-sts services,
// and the options for v4.Signer. They also reduce the number of non-Go
// library dependencies.
package awsv4

View File

@@ -0,0 +1,80 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
// - github.com/aws/aws-sdk-go/blob/v1.34.28/aws/request/request.go
// See THIRD-PARTY-NOTICES for original license terms
package awsv4
import (
"net/http"
"strings"
)
// Returns host from request
func getHost(r *http.Request) string {
if r.Host != "" {
return r.Host
}
if r.URL == nil {
return ""
}
return r.URL.Host
}
// Hostname returns u.Host, without any port number.
//
// If Host is an IPv6 literal with a port number, Hostname returns the
// IPv6 literal without the square brackets. IPv6 literals may include
// a zone identifier.
//
// Copied from the Go 1.8 standard library (net/url)
func stripPort(hostport string) string {
colon := strings.IndexByte(hostport, ':')
if colon == -1 {
return hostport
}
if i := strings.IndexByte(hostport, ']'); i != -1 {
return strings.TrimPrefix(hostport[:i], "[")
}
return hostport[:colon]
}
// Port returns the port part of u.Host, without the leading colon.
// If u.Host doesn't contain a port, Port returns an empty string.
//
// Copied from the Go 1.8 standard library (net/url)
func portOnly(hostport string) string {
colon := strings.IndexByte(hostport, ':')
if colon == -1 {
return ""
}
if i := strings.Index(hostport, "]:"); i != -1 {
return hostport[i+len("]:"):]
}
if strings.Contains(hostport, "]") {
return ""
}
return hostport[colon+len(":"):]
}
// Returns true if the specified URI is using the standard port
// (i.e. port 80 for HTTP URIs or 443 for HTTPS URIs)
func isDefaultPort(scheme, port string) bool {
if port == "" {
return true
}
lowerCaseScheme := strings.ToLower(scheme)
if (lowerCaseScheme == "http" && port == "80") || (lowerCaseScheme == "https" && port == "443") {
return true
}
return false
}

View File

@@ -0,0 +1,46 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
// - github.com/aws/aws-sdk-go/blob/v1.34.28/private/protocol/rest/build.go
// See THIRD-PARTY-NOTICES for original license terms
package awsv4
import (
"bytes"
"fmt"
)
// Whether the byte value can be sent without escaping in AWS URLs
var noEscape [256]bool
func init() {
for i := 0; i < len(noEscape); i++ {
// AWS expects every character except these to be escaped
noEscape[i] = (i >= 'A' && i <= 'Z') ||
(i >= 'a' && i <= 'z') ||
(i >= '0' && i <= '9') ||
i == '-' ||
i == '.' ||
i == '_' ||
i == '~'
}
}
// EscapePath escapes part of a URL path in Amazon style
func EscapePath(path string, encodeSep bool) string {
var buf bytes.Buffer
for i := 0; i < len(path); i++ {
c := path[i]
if noEscape[c] || (c == '/' && !encodeSep) {
buf.WriteByte(c)
} else {
fmt.Fprintf(&buf, "%%%02X", c)
}
}
return buf.String()
}

View File

@@ -0,0 +1,98 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
// - github.com/aws/aws-sdk-go/blob/v1.34.28/aws/signer/v4/header_rules.go
// - github.com/aws/aws-sdk-go/blob/v1.34.28/internal/strings/strings.go
// See THIRD-PARTY-NOTICES for original license terms
package awsv4
import (
"strings"
)
// validator houses a set of rule needed for validation of a
// string value
type rules []rule
// rule interface allows for more flexible rules and just simply
// checks whether or not a value adheres to that rule
type rule interface {
IsValid(value string) bool
}
// IsValid will iterate through all rules and see if any rules
// apply to the value and supports nested rules
func (r rules) IsValid(value string) bool {
for _, rule := range r {
if rule.IsValid(value) {
return true
}
}
return false
}
// mapRule generic rule for maps
type mapRule map[string]struct{}
// IsValid for the map rule satisfies whether it exists in the map
func (m mapRule) IsValid(value string) bool {
_, ok := m[value]
return ok
}
// allowlist is a generic rule for allowlisting
type allowlist struct {
rule
}
// IsValid for allowlist checks if the value is within the allowlist
func (a allowlist) IsValid(value string) bool {
return a.rule.IsValid(value)
}
// denylist is a generic rule for denylisting
type denylist struct {
rule
}
// IsValid for allowlist checks if the value is within the allowlist
func (d denylist) IsValid(value string) bool {
return !d.rule.IsValid(value)
}
type patterns []string
// hasPrefixFold tests whether the string s begins with prefix, interpreted as UTF-8 strings,
// under Unicode case-folding.
func hasPrefixFold(s, prefix string) bool {
return len(s) >= len(prefix) && strings.EqualFold(s[0:len(prefix)], prefix)
}
// IsValid for patterns checks each pattern and returns if a match has
// been found
func (p patterns) IsValid(value string) bool {
for _, pattern := range p {
if hasPrefixFold(value, pattern) {
return true
}
}
return false
}
// inclusiveRules rules allow for rules to depend on one another
type inclusiveRules []rule
// IsValid will return true if all rules are true
func (r inclusiveRules) IsValid(value string) bool {
for _, rule := range r {
if !rule.IsValid(value) {
return false
}
}
return true
}

View File

@@ -0,0 +1,472 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
// - github.com/aws/aws-sdk-go/blob/v1.34.28/aws/request/request.go
// - github.com/aws/aws-sdk-go/blob/v1.34.28/aws/signer/v4/v4.go
// - github.com/aws/aws-sdk-go/blob/v1.34.28/aws/signer/v4/uri_path.go
// - github.com/aws/aws-sdk-go/blob/v1.34.28/aws/types.go
// See THIRD-PARTY-NOTICES for original license terms
package awsv4
import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"sort"
"strings"
"time"
)
const (
authorizationHeader = "Authorization"
authHeaderSignatureElem = "Signature="
authHeaderPrefix = "AWS4-HMAC-SHA256"
timeFormat = "20060102T150405Z"
shortTimeFormat = "20060102"
awsV4Request = "aws4_request"
// emptyStringSHA256 is a SHA256 of an empty string
emptyStringSHA256 = `e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855`
)
var ignoredHeaders = rules{
denylist{
mapRule{
authorizationHeader: struct{}{},
"User-Agent": struct{}{},
"X-Amzn-Trace-Id": struct{}{},
},
},
}
// Signer applies AWS v4 signing to given request. Use this to sign requests
// that need to be signed with AWS V4 Signatures.
type Signer struct {
Credentials *StaticProvider
}
// NewSigner returns a Signer pointer configured with the credentials and optional
// option values provided. If not options are provided the Signer will use its
// default configuration.
func NewSigner(credentials *StaticProvider) *Signer {
v4 := &Signer{
Credentials: credentials,
}
return v4
}
type signingCtx struct {
ServiceName string
Region string
Request *http.Request
Body io.ReadSeeker
Query url.Values
Time time.Time
SignedHeaderVals http.Header
credValues Value
bodyDigest string
signedHeaders string
canonicalHeaders string
canonicalString string
credentialString string
stringToSign string
signature string
authorization string
}
// Sign signs AWS v4 requests with the provided body, service name, region the
// request is made to, and time the request is signed at. The signTime allows
// you to specify that a request is signed for the future, and cannot be
// used until then.
//
// Returns a list of HTTP headers that were included in the signature or an
// error if signing the request failed. Generally for signed requests this value
// is not needed as the full request context will be captured by the http.Request
// value. It is included for reference though.
//
// Sign will set the request's Body to be the `body` parameter passed in. If
// the body is not already an io.ReadCloser, it will be wrapped within one. If
// a `nil` body parameter passed to Sign, the request's Body field will be
// also set to nil. Its important to note that this functionality will not
// change the request's ContentLength of the request.
//
// Sign differs from Presign in that it will sign the request using HTTP
// header values. This type of signing is intended for http.Request values that
// will not be shared, or are shared in a way the header values on the request
// will not be lost.
//
// The requests body is an io.ReadSeeker so the SHA256 of the body can be
// generated. To bypass the signer computing the hash you can set the
// "X-Amz-Content-Sha256" header with a precomputed value. The signer will
// only compute the hash if the request header value is empty.
func (v4 Signer) Sign(r *http.Request, body io.ReadSeeker, service, region string, signTime time.Time) (http.Header, error) {
return v4.signWithBody(r, body, service, region, signTime)
}
func (v4 Signer) signWithBody(r *http.Request, body io.ReadSeeker, service, region string, signTime time.Time) (http.Header, error) {
ctx := &signingCtx{
Request: r,
Body: body,
Query: r.URL.Query(),
Time: signTime,
ServiceName: service,
Region: region,
}
for key := range ctx.Query {
sort.Strings(ctx.Query[key])
}
if ctx.isRequestSigned() {
ctx.Time = time.Now()
}
var err error
ctx.credValues, err = v4.Credentials.Retrieve()
if err != nil {
return http.Header{}, err
}
ctx.sanitizeHostForHeader()
ctx.assignAmzQueryValues()
if err := ctx.build(); err != nil {
return nil, err
}
var reader io.ReadCloser
if body != nil {
var ok bool
if reader, ok = body.(io.ReadCloser); !ok {
reader = ioutil.NopCloser(body)
}
}
r.Body = reader
return ctx.SignedHeaderVals, nil
}
// sanitizeHostForHeader removes default port from host and updates request.Host
func (ctx *signingCtx) sanitizeHostForHeader() {
r := ctx.Request
host := getHost(r)
port := portOnly(host)
if port != "" && isDefaultPort(r.URL.Scheme, port) {
r.Host = stripPort(host)
}
}
func (ctx *signingCtx) assignAmzQueryValues() {
if ctx.credValues.SessionToken != "" {
ctx.Request.Header.Set("X-Amz-Security-Token", ctx.credValues.SessionToken)
}
}
func (ctx *signingCtx) build() error {
ctx.buildTime() // no depends
ctx.buildCredentialString() // no depends
if err := ctx.buildBodyDigest(); err != nil {
return err
}
unsignedHeaders := ctx.Request.Header
ctx.buildCanonicalHeaders(ignoredHeaders, unsignedHeaders)
ctx.buildCanonicalString() // depends on canon headers / signed headers
ctx.buildStringToSign() // depends on canon string
ctx.buildSignature() // depends on string to sign
parts := []string{
authHeaderPrefix + " Credential=" + ctx.credValues.AccessKeyID + "/" + ctx.credentialString,
"SignedHeaders=" + ctx.signedHeaders,
authHeaderSignatureElem + ctx.signature,
}
ctx.Request.Header.Set(authorizationHeader, strings.Join(parts, ", "))
return nil
}
// GetSignedRequestSignature attempts to extract the signature of the request.
// Returning an error if the request is unsigned, or unable to extract the
// signature.
func GetSignedRequestSignature(r *http.Request) ([]byte, error) {
if auth := r.Header.Get(authorizationHeader); len(auth) != 0 {
ps := strings.Split(auth, ", ")
for _, p := range ps {
if idx := strings.Index(p, authHeaderSignatureElem); idx >= 0 {
sig := p[len(authHeaderSignatureElem):]
if len(sig) == 0 {
return nil, fmt.Errorf("invalid request signature authorization header")
}
return hex.DecodeString(sig)
}
}
}
if sig := r.URL.Query().Get("X-Amz-Signature"); len(sig) != 0 {
return hex.DecodeString(sig)
}
return nil, fmt.Errorf("request not signed")
}
func (ctx *signingCtx) buildTime() {
ctx.Request.Header.Set("X-Amz-Date", formatTime(ctx.Time))
}
func (ctx *signingCtx) buildCredentialString() {
ctx.credentialString = buildSigningScope(ctx.Region, ctx.ServiceName, ctx.Time)
}
func (ctx *signingCtx) buildCanonicalHeaders(r rule, header http.Header) {
headers := make([]string, 0, len(header))
headers = append(headers, "host")
for k, v := range header {
if !r.IsValid(k) {
continue // ignored header
}
if ctx.SignedHeaderVals == nil {
ctx.SignedHeaderVals = make(http.Header)
}
lowerCaseKey := strings.ToLower(k)
if _, ok := ctx.SignedHeaderVals[lowerCaseKey]; ok {
// include additional values
ctx.SignedHeaderVals[lowerCaseKey] = append(ctx.SignedHeaderVals[lowerCaseKey], v...)
continue
}
headers = append(headers, lowerCaseKey)
ctx.SignedHeaderVals[lowerCaseKey] = v
}
sort.Strings(headers)
ctx.signedHeaders = strings.Join(headers, ";")
headerValues := make([]string, len(headers))
for i, k := range headers {
if k == "host" {
if ctx.Request.Host != "" {
headerValues[i] = "host:" + ctx.Request.Host
} else {
headerValues[i] = "host:" + ctx.Request.URL.Host
}
} else {
headerValues[i] = k + ":" +
strings.Join(ctx.SignedHeaderVals[k], ",")
}
}
stripExcessSpaces(headerValues)
ctx.canonicalHeaders = strings.Join(headerValues, "\n")
}
func getURIPath(u *url.URL) string {
var uri string
if len(u.Opaque) > 0 {
uri = "/" + strings.Join(strings.Split(u.Opaque, "/")[3:], "/")
} else {
uri = u.EscapedPath()
}
if len(uri) == 0 {
uri = "/"
}
return uri
}
func (ctx *signingCtx) buildCanonicalString() {
ctx.Request.URL.RawQuery = strings.Replace(ctx.Query.Encode(), "+", "%20", -1)
uri := getURIPath(ctx.Request.URL)
uri = EscapePath(uri, false)
ctx.canonicalString = strings.Join([]string{
ctx.Request.Method,
uri,
ctx.Request.URL.RawQuery,
ctx.canonicalHeaders + "\n",
ctx.signedHeaders,
ctx.bodyDigest,
}, "\n")
}
func (ctx *signingCtx) buildStringToSign() {
ctx.stringToSign = strings.Join([]string{
authHeaderPrefix,
formatTime(ctx.Time),
ctx.credentialString,
hex.EncodeToString(hashSHA256([]byte(ctx.canonicalString))),
}, "\n")
}
func (ctx *signingCtx) buildSignature() {
creds := deriveSigningKey(ctx.Region, ctx.ServiceName, ctx.credValues.SecretAccessKey, ctx.Time)
signature := hmacSHA256(creds, []byte(ctx.stringToSign))
ctx.signature = hex.EncodeToString(signature)
}
func (ctx *signingCtx) buildBodyDigest() error {
hash := ctx.Request.Header.Get("X-Amz-Content-Sha256")
if hash == "" {
if ctx.Body == nil {
hash = emptyStringSHA256
} else {
hashBytes, err := makeSha256Reader(ctx.Body)
if err != nil {
return err
}
hash = hex.EncodeToString(hashBytes)
}
}
ctx.bodyDigest = hash
return nil
}
// isRequestSigned returns if the request is currently signed or presigned
func (ctx *signingCtx) isRequestSigned() bool {
return ctx.Request.Header.Get("Authorization") != ""
}
func hmacSHA256(key []byte, data []byte) []byte {
hash := hmac.New(sha256.New, key)
hash.Write(data)
return hash.Sum(nil)
}
func hashSHA256(data []byte) []byte {
hash := sha256.New()
hash.Write(data)
return hash.Sum(nil)
}
// seekerLen attempts to get the number of bytes remaining at the seeker's
// current position. Returns the number of bytes remaining or error.
func seekerLen(s io.Seeker) (int64, error) {
curOffset, err := s.Seek(0, io.SeekCurrent)
if err != nil {
return 0, err
}
endOffset, err := s.Seek(0, io.SeekEnd)
if err != nil {
return 0, err
}
_, err = s.Seek(curOffset, io.SeekStart)
if err != nil {
return 0, err
}
return endOffset - curOffset, nil
}
func makeSha256Reader(reader io.ReadSeeker) (hashBytes []byte, err error) {
hash := sha256.New()
start, err := reader.Seek(0, io.SeekCurrent)
if err != nil {
return nil, err
}
defer func() {
// ensure error is return if unable to seek back to start of payload.
_, err = reader.Seek(start, io.SeekStart)
}()
// Use CopyN to avoid allocating the 32KB buffer in io.Copy for bodies
// smaller than 32KB. Fall back to io.Copy if we fail to determine the size.
size, err := seekerLen(reader)
if err != nil {
_, _ = io.Copy(hash, reader)
} else {
_, _ = io.CopyN(hash, reader, size)
}
return hash.Sum(nil), nil
}
const doubleSpace = " "
// stripExcessSpaces will rewrite the passed in slice's string values to not
// contain multiple side-by-side spaces.
func stripExcessSpaces(vals []string) {
var j, k, l, m, spaces int
for i, str := range vals {
// Trim trailing spaces
for j = len(str) - 1; j >= 0 && str[j] == ' '; j-- {
}
// Trim leading spaces
for k = 0; k < j && str[k] == ' '; k++ {
}
str = str[k : j+1]
// Strip multiple spaces.
j = strings.Index(str, doubleSpace)
if j < 0 {
vals[i] = str
continue
}
buf := []byte(str)
for k, m, l = j, j, len(buf); k < l; k++ {
if buf[k] == ' ' {
if spaces == 0 {
// First space.
buf[m] = buf[k]
m++
}
spaces++
} else {
// End of multiple spaces.
spaces = 0
buf[m] = buf[k]
m++
}
}
vals[i] = string(buf[:m])
}
}
func buildSigningScope(region, service string, dt time.Time) string {
return strings.Join([]string{
formatShortTime(dt),
region,
service,
awsV4Request,
}, "/")
}
func deriveSigningKey(region, service, secretKey string, dt time.Time) []byte {
keyDate := hmacSHA256([]byte("AWS4"+secretKey), []byte(formatShortTime(dt)))
keyRegion := hmacSHA256(keyDate, []byte(region))
keyService := hmacSHA256(keyRegion, []byte(service))
signingKey := hmacSHA256(keyService, []byte(awsV4Request))
return signingKey
}
func formatShortTime(dt time.Time) string {
return dt.UTC().Format(shortTimeFormat)
}
func formatTime(dt time.Time) string {
return dt.UTC().Format(timeFormat)
}

View File

@@ -0,0 +1,167 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//go:build gssapi && (linux || darwin)
// +build gssapi
// +build linux darwin
package gssapi
/*
#cgo linux CFLAGS: -DGOOS_linux
#cgo linux LDFLAGS: -lgssapi_krb5 -lkrb5
#cgo darwin CFLAGS: -DGOOS_darwin
#cgo darwin LDFLAGS: -framework GSS
#include "gss_wrapper.h"
*/
import "C"
import (
"fmt"
"runtime"
"strings"
"unsafe"
)
// New creates a new SaslClient. The target parameter should be a hostname with no port.
func New(target, username, password string, passwordSet bool, props map[string]string) (*SaslClient, error) {
serviceName := "mongodb"
for key, value := range props {
switch strings.ToUpper(key) {
case "CANONICALIZE_HOST_NAME":
return nil, fmt.Errorf("CANONICALIZE_HOST_NAME is not supported when using gssapi on %s", runtime.GOOS)
case "SERVICE_REALM":
return nil, fmt.Errorf("SERVICE_REALM is not supported when using gssapi on %s", runtime.GOOS)
case "SERVICE_NAME":
serviceName = value
case "SERVICE_HOST":
target = value
default:
return nil, fmt.Errorf("unknown mechanism property %s", key)
}
}
servicePrincipalName := fmt.Sprintf("%s@%s", serviceName, target)
return &SaslClient{
servicePrincipalName: servicePrincipalName,
username: username,
password: password,
passwordSet: passwordSet,
}, nil
}
type SaslClient struct {
servicePrincipalName string
username string
password string
passwordSet bool
// state
state C.gssapi_client_state
contextComplete bool
done bool
}
func (sc *SaslClient) Close() {
C.gssapi_client_destroy(&sc.state)
}
func (sc *SaslClient) Start() (string, []byte, error) {
const mechName = "GSSAPI"
cservicePrincipalName := C.CString(sc.servicePrincipalName)
defer C.free(unsafe.Pointer(cservicePrincipalName))
var cusername *C.char
var cpassword *C.char
if sc.username != "" {
cusername = C.CString(sc.username)
defer C.free(unsafe.Pointer(cusername))
if sc.passwordSet {
cpassword = C.CString(sc.password)
defer C.free(unsafe.Pointer(cpassword))
}
}
status := C.gssapi_client_init(&sc.state, cservicePrincipalName, cusername, cpassword)
if status != C.GSSAPI_OK {
return mechName, nil, sc.getError("unable to initialize client")
}
payload, err := sc.Next(nil)
return mechName, payload, err
}
func (sc *SaslClient) Next(challenge []byte) ([]byte, error) {
var buf unsafe.Pointer
var bufLen C.size_t
var outBuf unsafe.Pointer
var outBufLen C.size_t
if sc.contextComplete {
if sc.username == "" {
var cusername *C.char
status := C.gssapi_client_username(&sc.state, &cusername)
if status != C.GSSAPI_OK {
return nil, sc.getError("unable to acquire username")
}
defer C.free(unsafe.Pointer(cusername))
sc.username = C.GoString((*C.char)(unsafe.Pointer(cusername)))
}
bytes := append([]byte{1, 0, 0, 0}, []byte(sc.username)...)
buf = unsafe.Pointer(&bytes[0])
bufLen = C.size_t(len(bytes))
status := C.gssapi_client_wrap_msg(&sc.state, buf, bufLen, &outBuf, &outBufLen)
if status != C.GSSAPI_OK {
return nil, sc.getError("unable to wrap authz")
}
sc.done = true
} else {
if len(challenge) > 0 {
buf = unsafe.Pointer(&challenge[0])
bufLen = C.size_t(len(challenge))
}
status := C.gssapi_client_negotiate(&sc.state, buf, bufLen, &outBuf, &outBufLen)
switch status {
case C.GSSAPI_OK:
sc.contextComplete = true
case C.GSSAPI_CONTINUE:
default:
return nil, sc.getError("unable to negotiate with server")
}
}
if outBuf != nil {
defer C.free(outBuf)
}
return C.GoBytes(outBuf, C.int(outBufLen)), nil
}
func (sc *SaslClient) Completed() bool {
return sc.done
}
func (sc *SaslClient) getError(prefix string) error {
var desc *C.char
status := C.gssapi_error_desc(sc.state.maj_stat, sc.state.min_stat, &desc)
if status != C.GSSAPI_OK {
if desc != nil {
C.free(unsafe.Pointer(desc))
}
return fmt.Errorf("%s: (%v, %v)", prefix, sc.state.maj_stat, sc.state.min_stat)
}
defer C.free(unsafe.Pointer(desc))
return fmt.Errorf("%s: %v(%v,%v)", prefix, C.GoString(desc), int32(sc.state.maj_stat), int32(sc.state.min_stat))
}

View File

@@ -0,0 +1,254 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//+build gssapi
//+build linux darwin
#include <string.h>
#include <stdio.h>
#include "gss_wrapper.h"
OM_uint32 gssapi_canonicalize_name(
OM_uint32* minor_status,
char *input_name,
gss_OID input_name_type,
gss_name_t *output_name
)
{
OM_uint32 major_status;
gss_name_t imported_name = GSS_C_NO_NAME;
gss_buffer_desc buffer = GSS_C_EMPTY_BUFFER;
buffer.value = input_name;
buffer.length = strlen(input_name);
major_status = gss_import_name(minor_status, &buffer, input_name_type, &imported_name);
if (GSS_ERROR(major_status)) {
return major_status;
}
major_status = gss_canonicalize_name(minor_status, imported_name, (gss_OID)gss_mech_krb5, output_name);
if (imported_name != GSS_C_NO_NAME) {
OM_uint32 ignored;
gss_release_name(&ignored, &imported_name);
}
return major_status;
}
int gssapi_error_desc(
OM_uint32 maj_stat,
OM_uint32 min_stat,
char **desc
)
{
OM_uint32 stat = maj_stat;
int stat_type = GSS_C_GSS_CODE;
if (min_stat != 0) {
stat = min_stat;
stat_type = GSS_C_MECH_CODE;
}
OM_uint32 local_maj_stat, local_min_stat;
OM_uint32 msg_ctx = 0;
gss_buffer_desc desc_buffer;
do
{
local_maj_stat = gss_display_status(
&local_min_stat,
stat,
stat_type,
GSS_C_NO_OID,
&msg_ctx,
&desc_buffer
);
if (GSS_ERROR(local_maj_stat)) {
return GSSAPI_ERROR;
}
if (*desc) {
free(*desc);
}
*desc = malloc(desc_buffer.length+1);
memcpy(*desc, desc_buffer.value, desc_buffer.length+1);
gss_release_buffer(&local_min_stat, &desc_buffer);
}
while(msg_ctx != 0);
return GSSAPI_OK;
}
int gssapi_client_init(
gssapi_client_state *client,
char* spn,
char* username,
char* password
)
{
client->cred = GSS_C_NO_CREDENTIAL;
client->ctx = GSS_C_NO_CONTEXT;
client->maj_stat = gssapi_canonicalize_name(&client->min_stat, spn, GSS_C_NT_HOSTBASED_SERVICE, &client->spn);
if (GSS_ERROR(client->maj_stat)) {
return GSSAPI_ERROR;
}
if (username) {
gss_name_t name;
client->maj_stat = gssapi_canonicalize_name(&client->min_stat, username, GSS_C_NT_USER_NAME, &name);
if (GSS_ERROR(client->maj_stat)) {
return GSSAPI_ERROR;
}
if (password) {
gss_buffer_desc password_buffer;
password_buffer.value = password;
password_buffer.length = strlen(password);
client->maj_stat = gss_acquire_cred_with_password(&client->min_stat, name, &password_buffer, GSS_C_INDEFINITE, GSS_C_NO_OID_SET, GSS_C_INITIATE, &client->cred, NULL, NULL);
} else {
client->maj_stat = gss_acquire_cred(&client->min_stat, name, GSS_C_INDEFINITE, GSS_C_NO_OID_SET, GSS_C_INITIATE, &client->cred, NULL, NULL);
}
if (GSS_ERROR(client->maj_stat)) {
return GSSAPI_ERROR;
}
OM_uint32 ignored;
gss_release_name(&ignored, &name);
}
return GSSAPI_OK;
}
int gssapi_client_username(
gssapi_client_state *client,
char** username
)
{
OM_uint32 ignored;
gss_name_t name = GSS_C_NO_NAME;
client->maj_stat = gss_inquire_context(&client->min_stat, client->ctx, &name, NULL, NULL, NULL, NULL, NULL, NULL);
if (GSS_ERROR(client->maj_stat)) {
return GSSAPI_ERROR;
}
gss_buffer_desc name_buffer;
client->maj_stat = gss_display_name(&client->min_stat, name, &name_buffer, NULL);
if (GSS_ERROR(client->maj_stat)) {
gss_release_name(&ignored, &name);
return GSSAPI_ERROR;
}
*username = malloc(name_buffer.length+1);
memcpy(*username, name_buffer.value, name_buffer.length+1);
gss_release_buffer(&ignored, &name_buffer);
gss_release_name(&ignored, &name);
return GSSAPI_OK;
}
int gssapi_client_negotiate(
gssapi_client_state *client,
void* input,
size_t input_length,
void** output,
size_t* output_length
)
{
gss_buffer_desc input_buffer = GSS_C_EMPTY_BUFFER;
gss_buffer_desc output_buffer = GSS_C_EMPTY_BUFFER;
if (input) {
input_buffer.value = input;
input_buffer.length = input_length;
}
client->maj_stat = gss_init_sec_context(
&client->min_stat,
client->cred,
&client->ctx,
client->spn,
GSS_C_NO_OID,
GSS_C_MUTUAL_FLAG | GSS_C_SEQUENCE_FLAG,
0,
GSS_C_NO_CHANNEL_BINDINGS,
&input_buffer,
NULL,
&output_buffer,
NULL,
NULL
);
if (output_buffer.length) {
*output = malloc(output_buffer.length);
*output_length = output_buffer.length;
memcpy(*output, output_buffer.value, output_buffer.length);
OM_uint32 ignored;
gss_release_buffer(&ignored, &output_buffer);
}
if (GSS_ERROR(client->maj_stat)) {
return GSSAPI_ERROR;
} else if (client->maj_stat == GSS_S_CONTINUE_NEEDED) {
return GSSAPI_CONTINUE;
}
return GSSAPI_OK;
}
int gssapi_client_wrap_msg(
gssapi_client_state *client,
void* input,
size_t input_length,
void** output,
size_t* output_length
)
{
gss_buffer_desc input_buffer = GSS_C_EMPTY_BUFFER;
gss_buffer_desc output_buffer = GSS_C_EMPTY_BUFFER;
input_buffer.value = input;
input_buffer.length = input_length;
client->maj_stat = gss_wrap(&client->min_stat, client->ctx, 0, GSS_C_QOP_DEFAULT, &input_buffer, NULL, &output_buffer);
if (output_buffer.length) {
*output = malloc(output_buffer.length);
*output_length = output_buffer.length;
memcpy(*output, output_buffer.value, output_buffer.length);
gss_release_buffer(&client->min_stat, &output_buffer);
}
if (GSS_ERROR(client->maj_stat)) {
return GSSAPI_ERROR;
}
return GSSAPI_OK;
}
int gssapi_client_destroy(
gssapi_client_state *client
)
{
OM_uint32 ignored;
if (client->ctx != GSS_C_NO_CONTEXT) {
gss_delete_sec_context(&ignored, &client->ctx, GSS_C_NO_BUFFER);
}
if (client->spn != GSS_C_NO_NAME) {
gss_release_name(&ignored, &client->spn);
}
if (client->cred != GSS_C_NO_CREDENTIAL) {
gss_release_cred(&ignored, &client->cred);
}
return GSSAPI_OK;
}

View File

@@ -0,0 +1,72 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//+build gssapi
//+build linux darwin
#ifndef GSS_WRAPPER_H
#define GSS_WRAPPER_H
#include <stdlib.h>
#ifdef GOOS_linux
#include <gssapi/gssapi.h>
#include <gssapi/gssapi_krb5.h>
#endif
#ifdef GOOS_darwin
#include <GSS/GSS.h>
#endif
#define GSSAPI_OK 0
#define GSSAPI_CONTINUE 1
#define GSSAPI_ERROR 2
typedef struct {
gss_name_t spn;
gss_cred_id_t cred;
gss_ctx_id_t ctx;
OM_uint32 maj_stat;
OM_uint32 min_stat;
} gssapi_client_state;
int gssapi_error_desc(
OM_uint32 maj_stat,
OM_uint32 min_stat,
char **desc
);
int gssapi_client_init(
gssapi_client_state *client,
char* spn,
char* username,
char* password
);
int gssapi_client_username(
gssapi_client_state *client,
char** username
);
int gssapi_client_negotiate(
gssapi_client_state *client,
void* input,
size_t input_length,
void** output,
size_t* output_length
);
int gssapi_client_wrap_msg(
gssapi_client_state *client,
void* input,
size_t input_length,
void** output,
size_t* output_length
);
int gssapi_client_destroy(
gssapi_client_state *client
);
#endif

View File

@@ -0,0 +1,353 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//go:build gssapi && windows
// +build gssapi,windows
package gssapi
// #include "sspi_wrapper.h"
import "C"
import (
"fmt"
"net"
"strconv"
"strings"
"sync"
"unsafe"
)
// New creates a new SaslClient. The target parameter should be a hostname with no port.
func New(target, username, password string, passwordSet bool, props map[string]string) (*SaslClient, error) {
initOnce.Do(initSSPI)
if initError != nil {
return nil, initError
}
var err error
serviceName := "mongodb"
serviceRealm := ""
canonicalizeHostName := false
var serviceHostSet bool
for key, value := range props {
switch strings.ToUpper(key) {
case "CANONICALIZE_HOST_NAME":
canonicalizeHostName, err = strconv.ParseBool(value)
if err != nil {
return nil, fmt.Errorf("%s must be a boolean (true, false, 0, 1) but got '%s'", key, value)
}
case "SERVICE_REALM":
serviceRealm = value
case "SERVICE_NAME":
serviceName = value
case "SERVICE_HOST":
serviceHostSet = true
target = value
}
}
if canonicalizeHostName {
// Should not canonicalize the SERVICE_HOST
if serviceHostSet {
return nil, fmt.Errorf("CANONICALIZE_HOST_NAME and SERVICE_HOST canonot both be specified")
}
names, err := net.LookupAddr(target)
if err != nil || len(names) == 0 {
return nil, fmt.Errorf("unable to canonicalize hostname: %s", err)
}
target = names[0]
if target[len(target)-1] == '.' {
target = target[:len(target)-1]
}
}
servicePrincipalName := fmt.Sprintf("%s/%s", serviceName, target)
if serviceRealm != "" {
servicePrincipalName += "@" + serviceRealm
}
return &SaslClient{
servicePrincipalName: servicePrincipalName,
username: username,
password: password,
passwordSet: passwordSet,
}, nil
}
type SaslClient struct {
servicePrincipalName string
username string
password string
passwordSet bool
// state
state C.sspi_client_state
contextComplete bool
done bool
}
func (sc *SaslClient) Close() {
C.sspi_client_destroy(&sc.state)
}
func (sc *SaslClient) Start() (string, []byte, error) {
const mechName = "GSSAPI"
var cusername *C.char
var cpassword *C.char
if sc.username != "" {
cusername = C.CString(sc.username)
defer C.free(unsafe.Pointer(cusername))
if sc.passwordSet {
cpassword = C.CString(sc.password)
defer C.free(unsafe.Pointer(cpassword))
}
}
status := C.sspi_client_init(&sc.state, cusername, cpassword)
if status != C.SSPI_OK {
return mechName, nil, sc.getError("unable to intitialize client")
}
payload, err := sc.Next(nil)
return mechName, payload, err
}
func (sc *SaslClient) Next(challenge []byte) ([]byte, error) {
var outBuf C.PVOID
var outBufLen C.ULONG
if sc.contextComplete {
if sc.username == "" {
var cusername *C.char
status := C.sspi_client_username(&sc.state, &cusername)
if status != C.SSPI_OK {
return nil, sc.getError("unable to acquire username")
}
defer C.free(unsafe.Pointer(cusername))
sc.username = C.GoString((*C.char)(unsafe.Pointer(cusername)))
}
bytes := append([]byte{1, 0, 0, 0}, []byte(sc.username)...)
buf := (C.PVOID)(unsafe.Pointer(&bytes[0]))
bufLen := C.ULONG(len(bytes))
status := C.sspi_client_wrap_msg(&sc.state, buf, bufLen, &outBuf, &outBufLen)
if status != C.SSPI_OK {
return nil, sc.getError("unable to wrap authz")
}
sc.done = true
} else {
var buf C.PVOID
var bufLen C.ULONG
if len(challenge) > 0 {
buf = (C.PVOID)(unsafe.Pointer(&challenge[0]))
bufLen = C.ULONG(len(challenge))
}
cservicePrincipalName := C.CString(sc.servicePrincipalName)
defer C.free(unsafe.Pointer(cservicePrincipalName))
status := C.sspi_client_negotiate(&sc.state, cservicePrincipalName, buf, bufLen, &outBuf, &outBufLen)
switch status {
case C.SSPI_OK:
sc.contextComplete = true
case C.SSPI_CONTINUE:
default:
return nil, sc.getError("unable to negotiate with server")
}
}
if outBuf != C.PVOID(nil) {
defer C.free(unsafe.Pointer(outBuf))
}
return C.GoBytes(unsafe.Pointer(outBuf), C.int(outBufLen)), nil
}
func (sc *SaslClient) Completed() bool {
return sc.done
}
func (sc *SaslClient) getError(prefix string) error {
return getError(prefix, sc.state.status)
}
var initOnce sync.Once
var initError error
func initSSPI() {
rc := C.sspi_init()
if rc != 0 {
initError = fmt.Errorf("error initializing sspi: %v", rc)
}
}
func getError(prefix string, status C.SECURITY_STATUS) error {
var s string
switch status {
case C.SEC_E_ALGORITHM_MISMATCH:
s = "The client and server cannot communicate because they do not possess a common algorithm."
case C.SEC_E_BAD_BINDINGS:
s = "The SSPI channel bindings supplied by the client are incorrect."
case C.SEC_E_BAD_PKGID:
s = "The requested package identifier does not exist."
case C.SEC_E_BUFFER_TOO_SMALL:
s = "The buffers supplied to the function are not large enough to contain the information."
case C.SEC_E_CANNOT_INSTALL:
s = "The security package cannot initialize successfully and should not be installed."
case C.SEC_E_CANNOT_PACK:
s = "The package is unable to pack the context."
case C.SEC_E_CERT_EXPIRED:
s = "The received certificate has expired."
case C.SEC_E_CERT_UNKNOWN:
s = "An unknown error occurred while processing the certificate."
case C.SEC_E_CERT_WRONG_USAGE:
s = "The certificate is not valid for the requested usage."
case C.SEC_E_CONTEXT_EXPIRED:
s = "The application is referencing a context that has already been closed. A properly written application should not receive this error."
case C.SEC_E_CROSSREALM_DELEGATION_FAILURE:
s = "The server attempted to make a Kerberos-constrained delegation request for a target outside the server's realm."
case C.SEC_E_CRYPTO_SYSTEM_INVALID:
s = "The cryptographic system or checksum function is not valid because a required function is unavailable."
case C.SEC_E_DECRYPT_FAILURE:
s = "The specified data could not be decrypted."
case C.SEC_E_DELEGATION_REQUIRED:
s = "The requested operation cannot be completed. The computer must be trusted for delegation"
case C.SEC_E_DOWNGRADE_DETECTED:
s = "The system detected a possible attempt to compromise security. Verify that the server that authenticated you can be contacted."
case C.SEC_E_ENCRYPT_FAILURE:
s = "The specified data could not be encrypted."
case C.SEC_E_ILLEGAL_MESSAGE:
s = "The message received was unexpected or badly formatted."
case C.SEC_E_INCOMPLETE_CREDENTIALS:
s = "The credentials supplied were not complete and could not be verified. The context could not be initialized."
case C.SEC_E_INCOMPLETE_MESSAGE:
s = "The message supplied was incomplete. The signature was not verified."
case C.SEC_E_INSUFFICIENT_MEMORY:
s = "Not enough memory is available to complete the request."
case C.SEC_E_INTERNAL_ERROR:
s = "An error occurred that did not map to an SSPI error code."
case C.SEC_E_INVALID_HANDLE:
s = "The handle passed to the function is not valid."
case C.SEC_E_INVALID_TOKEN:
s = "The token passed to the function is not valid."
case C.SEC_E_ISSUING_CA_UNTRUSTED:
s = "An untrusted certification authority (CA) was detected while processing the smart card certificate used for authentication."
case C.SEC_E_ISSUING_CA_UNTRUSTED_KDC:
s = "An untrusted CA was detected while processing the domain controller certificate used for authentication. The system event log contains additional information."
case C.SEC_E_KDC_CERT_EXPIRED:
s = "The domain controller certificate used for smart card logon has expired."
case C.SEC_E_KDC_CERT_REVOKED:
s = "The domain controller certificate used for smart card logon has been revoked."
case C.SEC_E_KDC_INVALID_REQUEST:
s = "A request that is not valid was sent to the KDC."
case C.SEC_E_KDC_UNABLE_TO_REFER:
s = "The KDC was unable to generate a referral for the service requested."
case C.SEC_E_KDC_UNKNOWN_ETYPE:
s = "The requested encryption type is not supported by the KDC."
case C.SEC_E_LOGON_DENIED:
s = "The logon has been denied"
case C.SEC_E_MAX_REFERRALS_EXCEEDED:
s = "The number of maximum ticket referrals has been exceeded."
case C.SEC_E_MESSAGE_ALTERED:
s = "The message supplied for verification has been altered."
case C.SEC_E_MULTIPLE_ACCOUNTS:
s = "The received certificate was mapped to multiple accounts."
case C.SEC_E_MUST_BE_KDC:
s = "The local computer must be a Kerberos domain controller (KDC)"
case C.SEC_E_NO_AUTHENTICATING_AUTHORITY:
s = "No authority could be contacted for authentication."
case C.SEC_E_NO_CREDENTIALS:
s = "No credentials are available."
case C.SEC_E_NO_IMPERSONATION:
s = "No impersonation is allowed for this context."
case C.SEC_E_NO_IP_ADDRESSES:
s = "Unable to accomplish the requested task because the local computer does not have any IP addresses."
case C.SEC_E_NO_KERB_KEY:
s = "No Kerberos key was found."
case C.SEC_E_NO_PA_DATA:
s = "Policy administrator (PA) data is needed to determine the encryption type"
case C.SEC_E_NO_S4U_PROT_SUPPORT:
s = "The Kerberos subsystem encountered an error. A service for user protocol request was made against a domain controller which does not support service for a user."
case C.SEC_E_NO_TGT_REPLY:
s = "The client is trying to negotiate a context and the server requires a user-to-user connection"
case C.SEC_E_NOT_OWNER:
s = "The caller of the function does not own the credentials."
case C.SEC_E_OK:
s = "The operation completed successfully."
case C.SEC_E_OUT_OF_SEQUENCE:
s = "The message supplied for verification is out of sequence."
case C.SEC_E_PKINIT_CLIENT_FAILURE:
s = "The smart card certificate used for authentication is not trusted."
case C.SEC_E_PKINIT_NAME_MISMATCH:
s = "The client certificate does not contain a valid UPN or does not match the client name in the logon request."
case C.SEC_E_QOP_NOT_SUPPORTED:
s = "The quality of protection attribute is not supported by this package."
case C.SEC_E_REVOCATION_OFFLINE_C:
s = "The revocation status of the smart card certificate used for authentication could not be determined."
case C.SEC_E_REVOCATION_OFFLINE_KDC:
s = "The revocation status of the domain controller certificate used for smart card authentication could not be determined. The system event log contains additional information."
case C.SEC_E_SECPKG_NOT_FOUND:
s = "The security package was not recognized."
case C.SEC_E_SECURITY_QOS_FAILED:
s = "The security context could not be established due to a failure in the requested quality of service (for example"
case C.SEC_E_SHUTDOWN_IN_PROGRESS:
s = "A system shutdown is in progress."
case C.SEC_E_SMARTCARD_CERT_EXPIRED:
s = "The smart card certificate used for authentication has expired."
case C.SEC_E_SMARTCARD_CERT_REVOKED:
s = "The smart card certificate used for authentication has been revoked. Additional information may exist in the event log."
case C.SEC_E_SMARTCARD_LOGON_REQUIRED:
s = "Smart card logon is required and was not used."
case C.SEC_E_STRONG_CRYPTO_NOT_SUPPORTED:
s = "The other end of the security negotiation requires strong cryptography"
case C.SEC_E_TARGET_UNKNOWN:
s = "The target was not recognized."
case C.SEC_E_TIME_SKEW:
s = "The clocks on the client and server computers do not match."
case C.SEC_E_TOO_MANY_PRINCIPALS:
s = "The KDC reply contained more than one principal name."
case C.SEC_E_UNFINISHED_CONTEXT_DELETED:
s = "A security context was deleted before the context was completed. This is considered a logon failure."
case C.SEC_E_UNKNOWN_CREDENTIALS:
s = "The credentials provided were not recognized."
case C.SEC_E_UNSUPPORTED_FUNCTION:
s = "The requested function is not supported."
case C.SEC_E_UNSUPPORTED_PREAUTH:
s = "An unsupported preauthentication mechanism was presented to the Kerberos package."
case C.SEC_E_UNTRUSTED_ROOT:
s = "The certificate chain was issued by an authority that is not trusted."
case C.SEC_E_WRONG_CREDENTIAL_HANDLE:
s = "The supplied credential handle does not match the credential associated with the security context."
case C.SEC_E_WRONG_PRINCIPAL:
s = "The target principal name is incorrect."
case C.SEC_I_COMPLETE_AND_CONTINUE:
s = "The function completed successfully"
case C.SEC_I_COMPLETE_NEEDED:
s = "The function completed successfully"
case C.SEC_I_CONTEXT_EXPIRED:
s = "The message sender has finished using the connection and has initiated a shutdown. For information about initiating or recognizing a shutdown"
case C.SEC_I_CONTINUE_NEEDED:
s = "The function completed successfully"
case C.SEC_I_INCOMPLETE_CREDENTIALS:
s = "The credentials supplied were not complete and could not be verified. Additional information can be returned from the context."
case C.SEC_I_LOCAL_LOGON:
s = "The logon was completed"
case C.SEC_I_NO_LSA_CONTEXT:
s = "There is no LSA mode context associated with this context."
case C.SEC_I_RENEGOTIATE:
s = "The context data must be renegotiated with the peer."
default:
return fmt.Errorf("%s: 0x%x", prefix, uint32(status))
}
return fmt.Errorf("%s: %s(0x%x)", prefix, s, uint32(status))
}

View File

@@ -0,0 +1,249 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//+build gssapi,windows
#include "sspi_wrapper.h"
static HINSTANCE sspi_secur32_dll = NULL;
static PSecurityFunctionTable sspi_functions = NULL;
static const LPSTR SSPI_PACKAGE_NAME = "kerberos";
int sspi_init(
)
{
// Load the secur32.dll library using its exact path. Passing the exact DLL path rather than allowing LoadLibrary to
// search in different locations removes the possibility of DLL preloading attacks. We use GetSystemDirectoryA and
// LoadLibraryA rather than the GetSystemDirectory/LoadLibrary aliases to ensure the ANSI versions are used so we
// don't have to account for variations in char sizes if UNICODE is enabled.
// Passing a 0 size will return the required buffer length to hold the path, including the null terminator.
int requiredLen = GetSystemDirectoryA(NULL, 0);
if (!requiredLen) {
return GetLastError();
}
// Allocate a buffer to hold the system directory + "\secur32.dll" (length 12, not including null terminator).
int actualLen = requiredLen + 12;
char *directoryBuffer = (char *) calloc(1, actualLen);
int directoryLen = GetSystemDirectoryA(directoryBuffer, actualLen);
if (!directoryLen) {
free(directoryBuffer);
return GetLastError();
}
// Append the DLL name to the buffer.
char *dllName = "\\secur32.dll";
strcpy_s(&(directoryBuffer[directoryLen]), actualLen - directoryLen, dllName);
sspi_secur32_dll = LoadLibraryA(directoryBuffer);
free(directoryBuffer);
if (!sspi_secur32_dll) {
return GetLastError();
}
INIT_SECURITY_INTERFACE init_security_interface = (INIT_SECURITY_INTERFACE)GetProcAddress(sspi_secur32_dll, SECURITY_ENTRYPOINT);
if (!init_security_interface) {
return -1;
}
sspi_functions = (*init_security_interface)();
if (!sspi_functions) {
return -2;
}
return SSPI_OK;
}
int sspi_client_init(
sspi_client_state *client,
char* username,
char* password
)
{
TimeStamp timestamp;
if (username) {
if (password) {
SEC_WINNT_AUTH_IDENTITY auth_identity;
#ifdef _UNICODE
auth_identity.Flags = SEC_WINNT_AUTH_IDENTITY_UNICODE;
#else
auth_identity.Flags = SEC_WINNT_AUTH_IDENTITY_ANSI;
#endif
auth_identity.User = (LPSTR) username;
auth_identity.UserLength = strlen(username);
auth_identity.Password = (LPSTR) password;
auth_identity.PasswordLength = strlen(password);
auth_identity.Domain = NULL;
auth_identity.DomainLength = 0;
client->status = sspi_functions->AcquireCredentialsHandle(NULL, SSPI_PACKAGE_NAME, SECPKG_CRED_OUTBOUND, NULL, &auth_identity, NULL, NULL, &client->cred, &timestamp);
} else {
client->status = sspi_functions->AcquireCredentialsHandle(username, SSPI_PACKAGE_NAME, SECPKG_CRED_OUTBOUND, NULL, NULL, NULL, NULL, &client->cred, &timestamp);
}
} else {
client->status = sspi_functions->AcquireCredentialsHandle(NULL, SSPI_PACKAGE_NAME, SECPKG_CRED_OUTBOUND, NULL, NULL, NULL, NULL, &client->cred, &timestamp);
}
if (client->status != SEC_E_OK) {
return SSPI_ERROR;
}
return SSPI_OK;
}
int sspi_client_username(
sspi_client_state *client,
char** username
)
{
SecPkgCredentials_Names names;
client->status = sspi_functions->QueryCredentialsAttributes(&client->cred, SECPKG_CRED_ATTR_NAMES, &names);
if (client->status != SEC_E_OK) {
return SSPI_ERROR;
}
int len = strlen(names.sUserName) + 1;
*username = malloc(len);
memcpy(*username, names.sUserName, len);
sspi_functions->FreeContextBuffer(names.sUserName);
return SSPI_OK;
}
int sspi_client_negotiate(
sspi_client_state *client,
char* spn,
PVOID input,
ULONG input_length,
PVOID* output,
ULONG* output_length
)
{
SecBufferDesc inbuf;
SecBuffer in_bufs[1];
SecBufferDesc outbuf;
SecBuffer out_bufs[1];
if (client->has_ctx > 0) {
inbuf.ulVersion = SECBUFFER_VERSION;
inbuf.cBuffers = 1;
inbuf.pBuffers = in_bufs;
in_bufs[0].pvBuffer = input;
in_bufs[0].cbBuffer = input_length;
in_bufs[0].BufferType = SECBUFFER_TOKEN;
}
outbuf.ulVersion = SECBUFFER_VERSION;
outbuf.cBuffers = 1;
outbuf.pBuffers = out_bufs;
out_bufs[0].pvBuffer = NULL;
out_bufs[0].cbBuffer = 0;
out_bufs[0].BufferType = SECBUFFER_TOKEN;
ULONG context_attr = 0;
client->status = sspi_functions->InitializeSecurityContext(
&client->cred,
client->has_ctx > 0 ? &client->ctx : NULL,
(LPSTR) spn,
ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_MUTUAL_AUTH,
0,
SECURITY_NETWORK_DREP,
client->has_ctx > 0 ? &inbuf : NULL,
0,
&client->ctx,
&outbuf,
&context_attr,
NULL);
if (client->status != SEC_E_OK && client->status != SEC_I_CONTINUE_NEEDED) {
return SSPI_ERROR;
}
client->has_ctx = 1;
*output = malloc(out_bufs[0].cbBuffer);
*output_length = out_bufs[0].cbBuffer;
memcpy(*output, out_bufs[0].pvBuffer, *output_length);
sspi_functions->FreeContextBuffer(out_bufs[0].pvBuffer);
if (client->status == SEC_I_CONTINUE_NEEDED) {
return SSPI_CONTINUE;
}
return SSPI_OK;
}
int sspi_client_wrap_msg(
sspi_client_state *client,
PVOID input,
ULONG input_length,
PVOID* output,
ULONG* output_length
)
{
SecPkgContext_Sizes sizes;
client->status = sspi_functions->QueryContextAttributes(&client->ctx, SECPKG_ATTR_SIZES, &sizes);
if (client->status != SEC_E_OK) {
return SSPI_ERROR;
}
char *msg = malloc((sizes.cbSecurityTrailer + input_length + sizes.cbBlockSize) * sizeof(char));
memcpy(&msg[sizes.cbSecurityTrailer], input, input_length);
SecBuffer wrap_bufs[3];
SecBufferDesc wrap_buf_desc;
wrap_buf_desc.cBuffers = 3;
wrap_buf_desc.pBuffers = wrap_bufs;
wrap_buf_desc.ulVersion = SECBUFFER_VERSION;
wrap_bufs[0].cbBuffer = sizes.cbSecurityTrailer;
wrap_bufs[0].BufferType = SECBUFFER_TOKEN;
wrap_bufs[0].pvBuffer = msg;
wrap_bufs[1].cbBuffer = input_length;
wrap_bufs[1].BufferType = SECBUFFER_DATA;
wrap_bufs[1].pvBuffer = msg + sizes.cbSecurityTrailer;
wrap_bufs[2].cbBuffer = sizes.cbBlockSize;
wrap_bufs[2].BufferType = SECBUFFER_PADDING;
wrap_bufs[2].pvBuffer = msg + sizes.cbSecurityTrailer + input_length;
client->status = sspi_functions->EncryptMessage(&client->ctx, SECQOP_WRAP_NO_ENCRYPT, &wrap_buf_desc, 0);
if (client->status != SEC_E_OK) {
free(msg);
return SSPI_ERROR;
}
*output_length = wrap_bufs[0].cbBuffer + wrap_bufs[1].cbBuffer + wrap_bufs[2].cbBuffer;
*output = malloc(*output_length);
memcpy(*output, wrap_bufs[0].pvBuffer, wrap_bufs[0].cbBuffer);
memcpy(*output + wrap_bufs[0].cbBuffer, wrap_bufs[1].pvBuffer, wrap_bufs[1].cbBuffer);
memcpy(*output + wrap_bufs[0].cbBuffer + wrap_bufs[1].cbBuffer, wrap_bufs[2].pvBuffer, wrap_bufs[2].cbBuffer);
free(msg);
return SSPI_OK;
}
int sspi_client_destroy(
sspi_client_state *client
)
{
if (client->has_ctx > 0) {
sspi_functions->DeleteSecurityContext(&client->ctx);
}
sspi_functions->FreeCredentialsHandle(&client->cred);
return SSPI_OK;
}

View File

@@ -0,0 +1,64 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//+build gssapi,windows
#ifndef SSPI_WRAPPER_H
#define SSPI_WRAPPER_H
#define SECURITY_WIN32 1 /* Required for SSPI */
#include <windows.h>
#include <sspi.h>
#define SSPI_OK 0
#define SSPI_CONTINUE 1
#define SSPI_ERROR 2
typedef struct {
CredHandle cred;
CtxtHandle ctx;
int has_ctx;
SECURITY_STATUS status;
} sspi_client_state;
int sspi_init();
int sspi_client_init(
sspi_client_state *client,
char* username,
char* password
);
int sspi_client_username(
sspi_client_state *client,
char** username
);
int sspi_client_negotiate(
sspi_client_state *client,
char* spn,
PVOID input,
ULONG input_length,
PVOID* output,
ULONG* output_length
);
int sspi_client_wrap_msg(
sspi_client_state *client,
PVOID input,
ULONG input_length,
PVOID* output,
ULONG* output_length
);
int sspi_client_destroy(
sspi_client_state *client
);
#endif

View File

@@ -0,0 +1,82 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
import (
"context"
"errors"
)
// MongoDBAWS is the mechanism name for MongoDBAWS.
const MongoDBAWS = "MONGODB-AWS"
func newMongoDBAWSAuthenticator(cred *Cred) (Authenticator, error) {
if cred.Source != "" && cred.Source != "$external" {
return nil, newAuthError("MONGODB-AWS source must be empty or $external", nil)
}
return &MongoDBAWSAuthenticator{
source: cred.Source,
username: cred.Username,
password: cred.Password,
sessionToken: cred.Props["AWS_SESSION_TOKEN"],
}, nil
}
// MongoDBAWSAuthenticator uses AWS-IAM credentials over SASL to authenticate a connection.
type MongoDBAWSAuthenticator struct {
source string
username string
password string
sessionToken string
}
// Auth authenticates the connection.
func (a *MongoDBAWSAuthenticator) Auth(ctx context.Context, cfg *Config) error {
httpClient := cfg.HTTPClient
if httpClient == nil {
return errors.New("cfg.HTTPClient must not be nil")
}
adapter := &awsSaslAdapter{
conversation: &awsConversation{
username: a.username,
password: a.password,
token: a.sessionToken,
httpClient: httpClient,
},
}
err := ConductSaslConversation(ctx, cfg, a.source, adapter)
if err != nil {
return newAuthError("sasl conversation error", err)
}
return nil
}
type awsSaslAdapter struct {
conversation *awsConversation
}
var _ SaslClient = (*awsSaslAdapter)(nil)
func (a *awsSaslAdapter) Start() (string, []byte, error) {
step, err := a.conversation.Step(nil)
if err != nil {
return MongoDBAWS, nil, err
}
return MongoDBAWS, step, nil
}
func (a *awsSaslAdapter) Next(challenge []byte) ([]byte, error) {
step, err := a.conversation.Step(challenge)
if err != nil {
return nil, err
}
return step, nil
}
func (a *awsSaslAdapter) Completed() bool {
return a.conversation.Done()
}

View File

@@ -0,0 +1,110 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
import (
"context"
"fmt"
"io"
// Ignore gosec warning "Blocklisted import crypto/md5: weak cryptographic primitive". We need
// to use MD5 here to implement the MONGODB-CR specification.
/* #nosec G501 */
"crypto/md5"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
)
// MONGODBCR is the mechanism name for MONGODB-CR.
//
// The MONGODB-CR authentication mechanism is deprecated in MongoDB 3.6 and removed in
// MongoDB 4.0.
const MONGODBCR = "MONGODB-CR"
func newMongoDBCRAuthenticator(cred *Cred) (Authenticator, error) {
return &MongoDBCRAuthenticator{
DB: cred.Source,
Username: cred.Username,
Password: cred.Password,
}, nil
}
// MongoDBCRAuthenticator uses the MONGODB-CR algorithm to authenticate a connection.
//
// The MONGODB-CR authentication mechanism is deprecated in MongoDB 3.6 and removed in
// MongoDB 4.0.
type MongoDBCRAuthenticator struct {
DB string
Username string
Password string
}
// Auth authenticates the connection.
//
// The MONGODB-CR authentication mechanism is deprecated in MongoDB 3.6 and removed in
// MongoDB 4.0.
func (a *MongoDBCRAuthenticator) Auth(ctx context.Context, cfg *Config) error {
db := a.DB
if db == "" {
db = defaultAuthDB
}
doc := bsoncore.BuildDocumentFromElements(nil, bsoncore.AppendInt32Element(nil, "getnonce", 1))
cmd := operation.NewCommand(doc).
Database(db).
Deployment(driver.SingleConnectionDeployment{cfg.Connection}).
ClusterClock(cfg.ClusterClock).
ServerAPI(cfg.ServerAPI)
err := cmd.Execute(ctx)
if err != nil {
return newError(err, MONGODBCR)
}
rdr := cmd.Result()
var getNonceResult struct {
Nonce string `bson:"nonce"`
}
err = bson.Unmarshal(rdr, &getNonceResult)
if err != nil {
return newAuthError("unmarshal error", err)
}
doc = bsoncore.BuildDocumentFromElements(nil,
bsoncore.AppendInt32Element(nil, "authenticate", 1),
bsoncore.AppendStringElement(nil, "user", a.Username),
bsoncore.AppendStringElement(nil, "nonce", getNonceResult.Nonce),
bsoncore.AppendStringElement(nil, "key", a.createKey(getNonceResult.Nonce)),
)
cmd = operation.NewCommand(doc).
Database(db).
Deployment(driver.SingleConnectionDeployment{cfg.Connection}).
ClusterClock(cfg.ClusterClock).
ServerAPI(cfg.ServerAPI)
err = cmd.Execute(ctx)
if err != nil {
return newError(err, MONGODBCR)
}
return nil
}
func (a *MongoDBCRAuthenticator) createKey(nonce string) string {
// Ignore gosec warning "Use of weak cryptographic primitive". We need to use MD5 here to
// implement the MONGODB-CR specification.
/* #nosec G401 */
h := md5.New()
_, _ = io.WriteString(h, nonce)
_, _ = io.WriteString(h, a.Username)
_, _ = io.WriteString(h, mongoPasswordDigest(a.Username, a.Password))
return fmt.Sprintf("%x", h.Sum(nil))
}

View File

@@ -0,0 +1,55 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
import (
"context"
)
// PLAIN is the mechanism name for PLAIN.
const PLAIN = "PLAIN"
func newPlainAuthenticator(cred *Cred) (Authenticator, error) {
return &PlainAuthenticator{
Username: cred.Username,
Password: cred.Password,
}, nil
}
// PlainAuthenticator uses the PLAIN algorithm over SASL to authenticate a connection.
type PlainAuthenticator struct {
Username string
Password string
}
// Auth authenticates the connection.
func (a *PlainAuthenticator) Auth(ctx context.Context, cfg *Config) error {
return ConductSaslConversation(ctx, cfg, "$external", &plainSaslClient{
username: a.Username,
password: a.Password,
})
}
type plainSaslClient struct {
username string
password string
}
var _ SaslClient = (*plainSaslClient)(nil)
func (c *plainSaslClient) Start() (string, []byte, error) {
b := []byte("\x00" + c.username + "\x00" + c.password)
return PLAIN, b, nil
}
func (c *plainSaslClient) Next(challenge []byte) ([]byte, error) {
return nil, newAuthError("unexpected server challenge", nil)
}
func (c *plainSaslClient) Completed() bool {
return true
}

View File

@@ -0,0 +1,174 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
import (
"context"
"fmt"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
)
// SaslClient is the client piece of a sasl conversation.
type SaslClient interface {
Start() (string, []byte, error)
Next(challenge []byte) ([]byte, error)
Completed() bool
}
// SaslClientCloser is a SaslClient that has resources to clean up.
type SaslClientCloser interface {
SaslClient
Close()
}
// ExtraOptionsSaslClient is a SaslClient that appends options to the saslStart command.
type ExtraOptionsSaslClient interface {
StartCommandOptions() bsoncore.Document
}
// saslConversation represents a SASL conversation. This type implements the SpeculativeConversation interface so the
// conversation can be executed in multi-step speculative fashion.
type saslConversation struct {
client SaslClient
source string
mechanism string
speculative bool
}
var _ SpeculativeConversation = (*saslConversation)(nil)
func newSaslConversation(client SaslClient, source string, speculative bool) *saslConversation {
authSource := source
if authSource == "" {
authSource = defaultAuthDB
}
return &saslConversation{
client: client,
source: authSource,
speculative: speculative,
}
}
// FirstMessage returns the first message to be sent to the server. This message contains a "db" field so it can be used
// for speculative authentication.
func (sc *saslConversation) FirstMessage() (bsoncore.Document, error) {
var payload []byte
var err error
sc.mechanism, payload, err = sc.client.Start()
if err != nil {
return nil, err
}
saslCmdElements := [][]byte{
bsoncore.AppendInt32Element(nil, "saslStart", 1),
bsoncore.AppendStringElement(nil, "mechanism", sc.mechanism),
bsoncore.AppendBinaryElement(nil, "payload", 0x00, payload),
}
if sc.speculative {
// The "db" field is only appended for speculative auth because the hello command is executed against admin
// so this is needed to tell the server the user's auth source. For a non-speculative attempt, the SASL commands
// will be executed against the auth source.
saslCmdElements = append(saslCmdElements, bsoncore.AppendStringElement(nil, "db", sc.source))
}
if extraOptionsClient, ok := sc.client.(ExtraOptionsSaslClient); ok {
optionsDoc := extraOptionsClient.StartCommandOptions()
saslCmdElements = append(saslCmdElements, bsoncore.AppendDocumentElement(nil, "options", optionsDoc))
}
return bsoncore.BuildDocumentFromElements(nil, saslCmdElements...), nil
}
type saslResponse struct {
ConversationID int `bson:"conversationId"`
Code int `bson:"code"`
Done bool `bson:"done"`
Payload []byte `bson:"payload"`
}
// Finish completes the conversation based on the first server response to authenticate the given connection.
func (sc *saslConversation) Finish(ctx context.Context, cfg *Config, firstResponse bsoncore.Document) error {
if closer, ok := sc.client.(SaslClientCloser); ok {
defer closer.Close()
}
var saslResp saslResponse
err := bson.Unmarshal(firstResponse, &saslResp)
if err != nil {
fullErr := fmt.Errorf("unmarshal error: %v", err)
return newError(fullErr, sc.mechanism)
}
cid := saslResp.ConversationID
var payload []byte
var rdr bsoncore.Document
for {
if saslResp.Code != 0 {
return newError(err, sc.mechanism)
}
if saslResp.Done && sc.client.Completed() {
return nil
}
payload, err = sc.client.Next(saslResp.Payload)
if err != nil {
return newError(err, sc.mechanism)
}
if saslResp.Done && sc.client.Completed() {
return nil
}
doc := bsoncore.BuildDocumentFromElements(nil,
bsoncore.AppendInt32Element(nil, "saslContinue", 1),
bsoncore.AppendInt32Element(nil, "conversationId", int32(cid)),
bsoncore.AppendBinaryElement(nil, "payload", 0x00, payload),
)
saslContinueCmd := operation.NewCommand(doc).
Database(sc.source).
Deployment(driver.SingleConnectionDeployment{cfg.Connection}).
ClusterClock(cfg.ClusterClock).
ServerAPI(cfg.ServerAPI)
err = saslContinueCmd.Execute(ctx)
if err != nil {
return newError(err, sc.mechanism)
}
rdr = saslContinueCmd.Result()
err = bson.Unmarshal(rdr, &saslResp)
if err != nil {
fullErr := fmt.Errorf("unmarshal error: %v", err)
return newError(fullErr, sc.mechanism)
}
}
}
// ConductSaslConversation runs a full SASL conversation to authenticate the given connection.
func ConductSaslConversation(ctx context.Context, cfg *Config, authSource string, client SaslClient) error {
// Create a non-speculative SASL conversation.
conversation := newSaslConversation(client, authSource, false)
saslStartDoc, err := conversation.FirstMessage()
if err != nil {
return newError(err, conversation.mechanism)
}
saslStartCmd := operation.NewCommand(saslStartDoc).
Database(authSource).
Deployment(driver.SingleConnectionDeployment{cfg.Connection}).
ClusterClock(cfg.ClusterClock).
ServerAPI(cfg.ServerAPI)
if err := saslStartCmd.Execute(ctx); err != nil {
return newError(err, conversation.mechanism)
}
return conversation.Finish(ctx, cfg, saslStartCmd.Result())
}

View File

@@ -0,0 +1,130 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Copyright (C) MongoDB, Inc. 2018-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
import (
"context"
"fmt"
"github.com/xdg-go/scram"
"github.com/xdg-go/stringprep"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)
const (
// SCRAMSHA1 holds the mechanism name "SCRAM-SHA-1"
SCRAMSHA1 = "SCRAM-SHA-1"
// SCRAMSHA256 holds the mechanism name "SCRAM-SHA-256"
SCRAMSHA256 = "SCRAM-SHA-256"
)
var (
// Additional options for the saslStart command to enable a shorter SCRAM conversation
scramStartOptions bsoncore.Document = bsoncore.BuildDocumentFromElements(nil,
bsoncore.AppendBooleanElement(nil, "skipEmptyExchange", true),
)
)
func newScramSHA1Authenticator(cred *Cred) (Authenticator, error) {
passdigest := mongoPasswordDigest(cred.Username, cred.Password)
client, err := scram.SHA1.NewClientUnprepped(cred.Username, passdigest, "")
if err != nil {
return nil, newAuthError("error initializing SCRAM-SHA-1 client", err)
}
client.WithMinIterations(4096)
return &ScramAuthenticator{
mechanism: SCRAMSHA1,
source: cred.Source,
client: client,
}, nil
}
func newScramSHA256Authenticator(cred *Cred) (Authenticator, error) {
passprep, err := stringprep.SASLprep.Prepare(cred.Password)
if err != nil {
return nil, newAuthError(fmt.Sprintf("error SASLprepping password '%s'", cred.Password), err)
}
client, err := scram.SHA256.NewClientUnprepped(cred.Username, passprep, "")
if err != nil {
return nil, newAuthError("error initializing SCRAM-SHA-256 client", err)
}
client.WithMinIterations(4096)
return &ScramAuthenticator{
mechanism: SCRAMSHA256,
source: cred.Source,
client: client,
}, nil
}
// ScramAuthenticator uses the SCRAM algorithm over SASL to authenticate a connection.
type ScramAuthenticator struct {
mechanism string
source string
client *scram.Client
}
var _ SpeculativeAuthenticator = (*ScramAuthenticator)(nil)
// Auth authenticates the provided connection by conducting a full SASL conversation.
func (a *ScramAuthenticator) Auth(ctx context.Context, cfg *Config) error {
err := ConductSaslConversation(ctx, cfg, a.source, a.createSaslClient())
if err != nil {
return newAuthError("sasl conversation error", err)
}
return nil
}
// CreateSpeculativeConversation creates a speculative conversation for SCRAM authentication.
func (a *ScramAuthenticator) CreateSpeculativeConversation() (SpeculativeConversation, error) {
return newSaslConversation(a.createSaslClient(), a.source, true), nil
}
func (a *ScramAuthenticator) createSaslClient() SaslClient {
return &scramSaslAdapter{
conversation: a.client.NewConversation(),
mechanism: a.mechanism,
}
}
type scramSaslAdapter struct {
mechanism string
conversation *scram.ClientConversation
}
var _ SaslClient = (*scramSaslAdapter)(nil)
var _ ExtraOptionsSaslClient = (*scramSaslAdapter)(nil)
func (a *scramSaslAdapter) Start() (string, []byte, error) {
step, err := a.conversation.Step("")
if err != nil {
return a.mechanism, nil, err
}
return a.mechanism, []byte(step), nil
}
func (a *scramSaslAdapter) Next(challenge []byte) ([]byte, error) {
step, err := a.conversation.Step(string(challenge))
if err != nil {
return nil, err
}
return []byte(step), nil
}
func (a *scramSaslAdapter) Completed() bool {
return a.conversation.Done()
}
func (*scramSaslAdapter) StartCommandOptions() bsoncore.Document {
return scramStartOptions
}

View File

@@ -0,0 +1,30 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
import (
"fmt"
"io"
// Ignore gosec warning "Blocklisted import crypto/md5: weak cryptographic primitive". We need
// to use MD5 here to implement the SCRAM specification.
/* #nosec G501 */
"crypto/md5"
)
const defaultAuthDB = "admin"
func mongoPasswordDigest(username, password string) string {
// Ignore gosec warning "Use of weak cryptographic primitive". We need to use MD5 here to
// implement the SCRAM specification.
/* #nosec G401 */
h := md5.New()
_, _ = io.WriteString(h, username)
_, _ = io.WriteString(h, ":mongo:")
_, _ = io.WriteString(h, password)
return fmt.Sprintf("%x", h.Sum(nil))
}

View File

@@ -0,0 +1,85 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
import (
"context"
"go.mongodb.org/mongo-driver/mongo/description"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
)
// MongoDBX509 is the mechanism name for MongoDBX509.
const MongoDBX509 = "MONGODB-X509"
func newMongoDBX509Authenticator(cred *Cred) (Authenticator, error) {
return &MongoDBX509Authenticator{User: cred.Username}, nil
}
// MongoDBX509Authenticator uses X.509 certificates over TLS to authenticate a connection.
type MongoDBX509Authenticator struct {
User string
}
var _ SpeculativeAuthenticator = (*MongoDBX509Authenticator)(nil)
// x509 represents a X509 authentication conversation. This type implements the SpeculativeConversation interface so the
// conversation can be executed in multi-step speculative fashion.
type x509Conversation struct{}
var _ SpeculativeConversation = (*x509Conversation)(nil)
// FirstMessage returns the first message to be sent to the server.
func (c *x509Conversation) FirstMessage() (bsoncore.Document, error) {
return createFirstX509Message(description.Server{}, ""), nil
}
// createFirstX509Message creates the first message for the X509 conversation.
func createFirstX509Message(desc description.Server, user string) bsoncore.Document {
elements := [][]byte{
bsoncore.AppendInt32Element(nil, "authenticate", 1),
bsoncore.AppendStringElement(nil, "mechanism", MongoDBX509),
}
// Server versions < 3.4 require the username to be included in the message. Versions >= 3.4 will extract the
// username from the certificate.
if desc.WireVersion != nil && desc.WireVersion.Max < 5 {
elements = append(elements, bsoncore.AppendStringElement(nil, "user", user))
}
return bsoncore.BuildDocument(nil, elements...)
}
// Finish implements the SpeculativeConversation interface and is a no-op because an X509 conversation only has one
// step.
func (c *x509Conversation) Finish(context.Context, *Config, bsoncore.Document) error {
return nil
}
// CreateSpeculativeConversation creates a speculative conversation for X509 authentication.
func (a *MongoDBX509Authenticator) CreateSpeculativeConversation() (SpeculativeConversation, error) {
return &x509Conversation{}, nil
}
// Auth authenticates the provided connection by conducting an X509 authentication conversation.
func (a *MongoDBX509Authenticator) Auth(ctx context.Context, cfg *Config) error {
requestDoc := createFirstX509Message(cfg.Description, a.User)
authCmd := operation.
NewCommand(requestDoc).
Database("$external").
Deployment(driver.SingleConnectionDeployment{cfg.Connection}).
ClusterClock(cfg.ClusterClock).
ServerAPI(cfg.ServerAPI)
err := authCmd.Execute(ctx)
if err != nil {
return newAuthError("round trip error", err)
}
return nil
}