logmower-shipper/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/auth/aws_conv.go

349 lines
9.3 KiB
Go

// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
import (
"bytes"
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
"os"
"strings"
"time"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver/auth/internal/awsv4"
)
// clientState identifies the step an awsConversation is currently at.
type clientState int
// Conversation states, listed in the order Step advances through them.
const (
// clientStarting is the initial state: no message has been produced yet.
clientStarting clientState = iota
// clientFirst is entered when Step generates the client-first message.
clientFirst
// clientFinal is entered when Step builds the client-final message from the
// server's reply.
clientFinal
// clientDone is the terminal state; calling Step again returns an error.
clientDone
)
// awsConversation holds the client-side state of one AWS authentication
// exchange. It is advanced one message at a time by Step.
type awsConversation struct {
// state is the current position in the conversation.
state clientState
// valid is set to true by Step when the conversation completes.
valid bool
// nonce is the random client nonce generated by firstMsg; finalMsg checks
// that the server nonce starts with it.
nonce []byte
// username is the AWS access key ID.
username string
// password is the AWS secret access key.
password string
// token is the AWS session token; it may be empty.
token string
// httpClient performs the ECS/EC2 metadata requests.
httpClient *http.Client
}
// serverMessage is the server's reply to the client-first message, as decoded
// from BSON by finalMsg.
type serverMessage struct {
// Nonce is the server nonce (BSON key "s"). finalMsg requires binary subtype
// 0x00, a length of responceNonceLength bytes, and the client nonce as prefix.
Nonce primitive.Binary `bson:"s"`
// Host is the STS host name (BSON key "h"); the signing region is derived
// from it and it is used as the Host of the signed request.
Host string `bson:"h"`
}
// ecsResponse is the JSON credentials document returned by the ECS and EC2
// metadata endpoints. Only the fields read by this file are declared.
type ecsResponse struct {
// AccessKeyID is copied into awsConversation.username.
AccessKeyID string `json:"AccessKeyId"`
// SecretAccessKey is copied into awsConversation.password.
SecretAccessKey string `json:"SecretAccessKey"`
// Token is the session token, copied into awsConversation.token.
Token string `json:"Token"`
}
// Constants used by the AWS authentication conversation.
const (
// amzDateFormat is the time layout used for the X-Amz-Date header.
amzDateFormat = "20060102T150405Z"
// awsRelativeURI is the base URL that the value of
// AWS_CONTAINER_CREDENTIALS_RELATIVE_URI is appended to.
awsRelativeURI = "http://169.254.170.2/"
// awsEC2URI is the base URL of the EC2 metadata endpoint.
awsEC2URI = "http://169.254.169.254/"
// awsEC2RolePath lists the role name; appending the role name to it yields
// the credentials document.
awsEC2RolePath = "latest/meta-data/iam/security-credentials/"
// awsEC2TokenPath is the path the metadata session token is requested from.
awsEC2TokenPath = "latest/api/token"
// defaultRegion is used for sts.amazonaws.com and for single-label hosts.
defaultRegion = "us-east-1"
// maxHostLength is the longest STS host name getRegion accepts, in bytes.
maxHostLength = 255
// defaultHTTPTimeout bounds every metadata HTTP request.
defaultHTTPTimeout = 10 * time.Second
// responceNonceLength is the exact length, in bytes, required of the server
// nonce. (The identifier's spelling is kept because other code refers to it.)
responceNonceLength = 64
)
// Step takes the payload received from the server (nil or empty for the very
// first step) and moves the authentication conversation forward. It returns
// the payload to send to the server next, or an error.
//
// The conversation has three steps:
//
//  1. clientStarting: produce the client-first message.
//  2. clientFirst: validate the server reply and produce the signed
//     client-final message.
//  3. clientFinal: mark the conversation as done and valid.
//
// If building the client-final message fails, the conversation moves straight
// to clientDone without becoming valid, so that Done reports true and a later
// Step call cannot mark a failed exchange as successful. Calling Step after
// the conversation is done returns an error.
func (ac *awsConversation) Step(challenge []byte) (response []byte, err error) {
	switch ac.state {
	case clientStarting:
		ac.state = clientFirst
		response = ac.firstMsg()
	case clientFirst:
		ac.state = clientFinal
		response, err = ac.finalMsg(challenge)
		if err != nil {
			// A failed exchange is terminal: never allow the next Step to
			// reach the branch that sets ac.valid.
			ac.state = clientDone
		}
	case clientFinal:
		ac.state = clientDone
		ac.valid = true
	default:
		response, err = nil, errors.New("Conversation already completed")
	}
	return response, err
}
// Done reports whether the conversation has reached its terminal state
// (clientDone). Step moves the conversation there once it is called after the
// client-final message has been produced.
func (ac *awsConversation) Done() bool {
return ac.state == clientDone
}
// Valid reports whether the conversation ran to completion. The flag is set by
// Step when it is called in the clientFinal state; the payload passed to that
// last Step call is not inspected by this type.
func (ac *awsConversation) Valid() bool {
return ac.valid
}
func getRegion(host string) (string, error) {
region := defaultRegion
if len(host) == 0 {
return "", errors.New("invalid STS host: empty")
}
if len(host) > maxHostLength {
return "", errors.New("invalid STS host: too large")
}
// The implicit region for sts.amazonaws.com is us-east-1
if host == "sts.amazonaws.com" {
return region, nil
}
if strings.HasPrefix(host, ".") || strings.HasSuffix(host, ".") || strings.Contains(host, "..") {
return "", errors.New("invalid STS host: empty part")
}
// If the host has multiple parts, the second part is the region
parts := strings.Split(host, ".")
if len(parts) >= 2 {
region = parts[1]
}
return region, nil
}
// validateAndMakeCredentials checks that the access key ID, secret access key
// and session token currently stored on the conversation form a usable
// combination and wraps them in a static credentials provider.
//
// It returns an error when only part of the key pair is present, or when a
// session token is present without the key pair. It returns (nil, nil) when
// nothing is set at all, which lets the caller move on to the next credential
// source.
func (ac *awsConversation) validateAndMakeCredentials() (*awsv4.StaticProvider, error) {
	hasKeyID := ac.username != ""
	hasSecret := ac.password != ""
	hasToken := ac.token != ""

	switch {
	case hasKeyID && !hasSecret:
		return nil, errors.New("ACCESS_KEY_ID is set, but SECRET_ACCESS_KEY is missing")
	case !hasKeyID && hasSecret:
		return nil, errors.New("SECRET_ACCESS_KEY is set, but ACCESS_KEY_ID is missing")
	case !hasKeyID && !hasSecret && hasToken:
		return nil, errors.New("AWS_SESSION_TOKEN is set, but ACCESS_KEY_ID and SECRET_ACCESS_KEY are missing")
	case !hasKeyID && !hasSecret && !hasToken:
		// Nothing configured: not an error, just no credentials from this source.
		return nil, nil
	}

	value := awsv4.Value{
		AccessKeyID:     ac.username,
		SecretAccessKey: ac.password,
		SessionToken:    ac.token,
	}
	return &awsv4.StaticProvider{Value: value}, nil
}
// executeAWSHTTPRequest sends req with httpClient, bounded by
// defaultHTTPTimeout, and returns the complete response body.
//
// An error is returned when the request cannot be performed, when the server
// answers with a non-2xx status, or when reading the body fails. Rejecting
// non-2xx responses keeps an error page from being mistaken for a metadata
// token, a role name or a credentials document by the callers.
func executeAWSHTTPRequest(httpClient *http.Client, req *http.Request) ([]byte, error) {
	ctx, cancel := context.WithTimeout(context.Background(), defaultHTTPTimeout)
	defer cancel()

	resp, err := httpClient.Do(req.WithContext(ctx))
	if err != nil {
		return nil, err
	}
	// The body is only read here; a Close error carries no useful information.
	defer func() { _ = resp.Body.Close() }()

	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		// Only the method and host are reported: the request path may embed
		// a credentials identifier.
		return nil, fmt.Errorf("%s %s failed: %s", req.Method, req.URL.Host, resp.Status)
	}

	return ioutil.ReadAll(resp.Body)
}
// getEC2Credentials fetches credentials from the EC2 metadata endpoint using
// three sequential requests:
//
//  1. PUT the token path to obtain a metadata session token,
//  2. GET the role path to learn the role name,
//  3. GET the role path plus role name to obtain the credentials document.
//
// The decoded access key ID, secret access key and session token are stored
// on the conversation and then validated by validateAndMakeCredentials, whose
// result is returned. Any request or decoding failure is returned as is.
func (ac *awsConversation) getEC2Credentials() (*awsv4.StaticProvider, error) {
	// call issues one body-less metadata request carrying a single header.
	call := func(method, url, headerName, headerValue string) ([]byte, error) {
		req, err := http.NewRequest(method, url, nil)
		if err != nil {
			return nil, err
		}
		req.Header.Set(headerName, headerValue)
		return executeAWSHTTPRequest(ac.httpClient, req)
	}

	const tokenHeader = "X-aws-ec2-metadata-token"
	roleURL := awsEC2URI + awsEC2RolePath

	// Step 1: session token.
	rawToken, err := call("PUT", awsEC2URI+awsEC2TokenPath, "X-aws-ec2-metadata-token-ttl-seconds", "30")
	if err != nil {
		return nil, err
	}
	if len(rawToken) == 0 {
		return nil, errors.New("unable to retrieve token from EC2 metadata")
	}
	sessionToken := string(rawToken)

	// Step 2: role name.
	roleName, err := call("GET", roleURL, tokenHeader, sessionToken)
	if err != nil {
		return nil, err
	}
	if len(roleName) == 0 {
		return nil, errors.New("unable to retrieve role_name from EC2 metadata")
	}

	// Step 3: credentials document for that role.
	payload, err := call("GET", roleURL+string(roleName), tokenHeader, sessionToken)
	if err != nil {
		return nil, err
	}

	var doc ecsResponse
	if err := json.Unmarshal(payload, &doc); err != nil {
		return nil, err
	}

	ac.username, ac.password, ac.token = doc.AccessKeyID, doc.SecretAccessKey, doc.Token
	return ac.validateAndMakeCredentials()
}
// getCredentials resolves AWS credentials by trying each source in order and
// stopping at the first one that yields either credentials or an error:
//
//  1. the values already stored on the conversation,
//  2. the AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY / AWS_SESSION_TOKEN
//     environment variables,
//  3. the ECS endpoint, when AWS_CONTAINER_CREDENTIALS_RELATIVE_URI is set,
//  4. the EC2 metadata endpoint.
//
// Each source overwrites the username, password and token stored on the
// conversation. If the last source yields neither credentials nor an error,
// "unable to get credentials" is returned.
func (ac *awsConversation) getCredentials() (*awsv4.StaticProvider, error) {
	// Source 1: whatever the conversation was constructed with.
	if provider, err := ac.validateAndMakeCredentials(); provider != nil || err != nil {
		return provider, err
	}

	// Source 2: environment variables.
	ac.username = os.Getenv("AWS_ACCESS_KEY_ID")
	ac.password = os.Getenv("AWS_SECRET_ACCESS_KEY")
	ac.token = os.Getenv("AWS_SESSION_TOKEN")
	if provider, err := ac.validateAndMakeCredentials(); provider != nil || err != nil {
		return provider, err
	}

	// Source 3: ECS endpoint, only consulted when the relative URI is set.
	if ecsPath := os.Getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI"); ecsPath != "" {
		req, err := http.NewRequest("GET", awsRelativeURI+ecsPath, nil)
		if err != nil {
			return nil, err
		}

		payload, err := executeAWSHTTPRequest(ac.httpClient, req)
		if err != nil {
			return nil, err
		}

		var doc ecsResponse
		if err := json.Unmarshal(payload, &doc); err != nil {
			return nil, err
		}

		ac.username, ac.password, ac.token = doc.AccessKeyID, doc.SecretAccessKey, doc.Token
		if provider, err := ac.validateAndMakeCredentials(); provider != nil || err != nil {
			return provider, err
		}
	}

	// Source 4: EC2 metadata endpoint, the last resort.
	provider, err := ac.getEC2Credentials()
	if provider == nil && err == nil {
		return nil, errors.New("unable to get credentials")
	}
	return provider, err
}
// firstMsg builds the client-first message: a BSON document holding the
// int32 field "p" (110) and the binary field "r" with a fresh 32-byte random
// nonce. The nonce is kept on the conversation so that finalMsg can check the
// server nonce against it.
func (ac *awsConversation) firstMsg() []byte {
	clientNonce := make([]byte, 32)
	// The signature has no error result, so a failure of the random source
	// cannot be reported from here.
	_, _ = rand.Read(clientNonce)
	ac.nonce = clientNonce

	start, doc := bsoncore.AppendDocumentStart(nil)
	// NOTE(review): 110 is presumably ASCII 'n', matching the "n" sent in the
	// X-MongoDB-GS2-CB-Flag header by finalMsg — confirm against the spec.
	doc = bsoncore.AppendInt32Element(doc, "p", 110)
	doc = bsoncore.AppendBinaryElement(doc, "r", 0x00, clientNonce)
	doc, _ = bsoncore.AppendDocumentEnd(doc, start)
	return doc
}
// finalMsg validates the server's reply s1 and builds the client-final
// message.
//
// The reply is decoded from BSON; its nonce must have binary subtype 0x00, be
// exactly responceNonceLength bytes long and start with the client nonce. The
// signing region is derived from the reply's host and credentials are
// resolved through getCredentials. An STS GetCallerIdentity POST request is
// then assembled in memory and signed; it is not sent from here. The returned
// BSON document carries the Authorization header ("a"), the X-Amz-Date header
// ("d") and, when a session token is in use, the token ("t").
func (ac *awsConversation) finalMsg(s1 []byte) ([]byte, error) {
	var reply serverMessage
	if err := bson.Unmarshal(s1, &reply); err != nil {
		return nil, err
	}

	// Validate the server nonce before doing any credential lookups.
	serverNonce := reply.Nonce
	switch {
	case serverNonce.Subtype != 0x00:
		return nil, errors.New("server reply contained unexpected binary subtype")
	case len(serverNonce.Data) != responceNonceLength:
		return nil, fmt.Errorf("server reply nonce was not %v bytes", responceNonceLength)
	case !bytes.HasPrefix(serverNonce.Data, ac.nonce):
		return nil, errors.New("server nonce did not extend client nonce")
	}

	region, err := getRegion(reply.Host)
	if err != nil {
		return nil, err
	}

	provider, err := ac.getCredentials()
	if err != nil {
		return nil, err
	}

	signTime := time.Now().UTC()
	const payload = "Action=GetCallerIdentity&Version=2011-06-15"

	// The request only serves as input to the signer. NewRequest cannot fail
	// for this fixed method and URL, hence the discarded error.
	req, _ := http.NewRequest("POST", "/", strings.NewReader(payload))
	req.Host = reply.Host
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	// 43 is len(payload).
	req.Header.Set("Content-Length", "43")
	req.Header.Set("X-Amz-Date", signTime.Format(amzDateFormat))
	req.Header.Set("X-MongoDB-Server-Nonce", base64.StdEncoding.EncodeToString(serverNonce.Data))
	req.Header.Set("X-MongoDB-GS2-CB-Flag", "n")
	if ac.token != "" {
		req.Header.Set("X-Amz-Security-Token", ac.token)
	}

	// Signing writes the Authorization header onto req.
	if _, err := awsv4.NewSigner(provider).Sign(req, strings.NewReader(payload), "sts", region, signTime); err != nil {
		return nil, err
	}

	// Headers are read back from the request after signing so that the message
	// carries exactly what was signed.
	start, doc := bsoncore.AppendDocumentStart(nil)
	doc = bsoncore.AppendStringElement(doc, "a", req.Header.Get("Authorization"))
	doc = bsoncore.AppendStringElement(doc, "d", req.Header.Get("X-Amz-Date"))
	if ac.token != "" {
		doc = bsoncore.AppendStringElement(doc, "t", ac.token)
	}
	doc, _ = bsoncore.AppendDocumentEnd(doc, start)
	return doc, nil
}