go mod vendor

+ move k8s.io/apimachinery fork from go.work to go.mod
(and include it in vendor)
2022-11-07 00:16:27 +02:00
parent d08bbf250a
commit e45bf4739b
1366 changed files with 469062 additions and 45 deletions

View File

@@ -0,0 +1,40 @@
# Topology Package Design
This document outlines the design for this package.
## Topology
The `Topology` type handles monitoring the state of a MongoDB deployment and selecting servers.
Updating the description is handled by a finite state machine which implements the server discovery
and monitoring specification. A `Topology` can be connected and fully disconnected, which enables
saving resources. The `Topology` type also handles server selection following the server selection
specification.
## Server
The `Server` type handles heartbeating a MongoDB server and holds a pool of connections.
## Connection
Connections are handled by two main types and an auxiliary type. The two main types are `connection`
and `Connection`. The first holds most of the logic required to actually read and write wire
messages. Instances can be created with the `newConnection` method. Inside the `newConnection`
method the auxiliary type, `initConnection` is used to perform the connection handshake. This is
required because the `connection` type does not fully implement `driver.Connection` which is
required during handshaking. The `Connection` type is what is actually returned to a consumer of the
`topology` package. This type does implement the `driver.Connection` type, holds a reference to a
`connection` instance, and exists mainly to prevent accidental continued usage of a connection after
closing it.
The connection implementations in this package are conduits for wire messages but they have no
ability to encode, decode, or validate wire messages. That must be handled by consumers.
## Pool
The `pool` type implements a connection pool. It handles caching idle connections and dialing
new ones, but it does not track a maximum number of connections. That is the responsibility of a
wrapping type, like `Server`.
The `pool` type has no concept of closing; instead it has concepts of connecting and disconnecting.
This allows a `Topology` to be disconnected, keeping the memory around so it can be reconnected later.
There is a `close` method, but it is used to close an individual connection, not the pool.
There are three methods related to getting and putting connections: `get`, `close`, and `put`. The
`get` method will either retrieve a connection from the cache or it will dial a new `connection`.
The `close` method will close the underlying socket of a `connection`. The `put` method will put a
connection into the pool, placing it in the cache if there is space, otherwise it will close it.
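As a hedged sketch (method names follow the descriptions above; the exact signatures live in the
`pool` implementation and may differ), a wrapping type drives this lifecycle roughly like so:

```go
conn, err := p.get(ctx) // a cached connection, or a newly dialed one
if err != nil {
    return err
}
// ... read and write wire messages on conn ...
p.put(conn) // cached if there is space, otherwise closed
```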

View File

@@ -0,0 +1,14 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package topology
import "context"
type cancellationListener interface {
Listen(context.Context, func())
StopListening() bool
}
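// What follows is a minimal sketch of one way to satisfy this interface
// (illustrative only; the driver's actual implementation is created via
// internal.NewCancellationListener). Listen blocks until either the context is
// cancelled, in which case it runs abortFn, or StopListening is called.
// StopListening unblocks Listen via a channel send and reports whether abortFn
// fired, so it must only be called while a Listen call is in flight.
type exampleListener struct {
aborted bool
done chan struct{}
}
func newExampleListener() *exampleListener {
return &exampleListener{done: make(chan struct{})}
}
func (l *exampleListener) Listen(ctx context.Context, abortFn func()) {
l.aborted = false
select {
case <-ctx.Done():
l.aborted = true
abortFn()
case <-l.done:
}
}
func (l *exampleListener) StopListening() bool {
l.done <- struct{}{}
return l.aborted
}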

View File

@@ -0,0 +1,825 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package topology
import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"strings"
"sync"
"sync/atomic"
"time"
"go.mongodb.org/mongo-driver/internal"
"go.mongodb.org/mongo-driver/mongo/address"
"go.mongodb.org/mongo-driver/mongo/description"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/ocsp"
"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
)
// Connection state constants.
const (
connDisconnected int64 = iota
connConnected
connInitialized
)
var globalConnectionID uint64 = 1
var (
defaultMaxMessageSize uint32 = 48000000
errResponseTooLarge = errors.New("length of read message too large")
errLoadBalancedStateMismatch = errors.New("driver attempted to initialize in load balancing mode, but the server does not support this mode")
)
func nextConnectionID() uint64 { return atomic.AddUint64(&globalConnectionID, 1) }
type connection struct {
// state must be accessed using the atomic package and should be at the beginning of the struct.
// - atomic bug: https://pkg.go.dev/sync/atomic#pkg-note-BUG
// - suggested layout: https://go101.org/article/memory-layout.html
state int64
id string
nc net.Conn // When nil, the connection is closed.
addr address.Address
idleTimeout time.Duration
idleDeadline atomic.Value // Stores a time.Time
readTimeout time.Duration
writeTimeout time.Duration
desc description.Server
helloRTT time.Duration
compressor wiremessage.CompressorID
zliblevel int
zstdLevel int
connectDone chan struct{}
config *connectionConfig
cancelConnectContext context.CancelFunc
connectContextMade chan struct{}
canStream bool
currentlyStreaming bool
connectContextMutex sync.Mutex
cancellationListener cancellationListener
serverConnectionID *int32 // the server's ID for this client's connection
// pool related fields
pool *pool
poolID uint64
generation uint64
}
// newConnection handles the creation of a connection. It does not connect the connection.
func newConnection(addr address.Address, opts ...ConnectionOption) *connection {
cfg := newConnectionConfig(opts...)
id := fmt.Sprintf("%s[-%d]", addr, nextConnectionID())
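// The resulting id combines the address with a process-wide counter, e.g. "localhost:27017[-2]".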
c := &connection{
id: id,
addr: addr,
idleTimeout: cfg.idleTimeout,
readTimeout: cfg.readTimeout,
writeTimeout: cfg.writeTimeout,
connectDone: make(chan struct{}),
config: cfg,
connectContextMade: make(chan struct{}),
cancellationListener: internal.NewCancellationListener(),
}
// Connections to non-load balanced deployments should eagerly set the generation numbers so errors encountered
// at any point during connection establishment can be processed without the connection being considered stale.
if !c.config.loadBalanced {
c.setGenerationNumber()
}
atomic.StoreInt64(&c.state, connInitialized)
return c
}
// setGenerationNumber sets the connection's generation number if a callback has been provided to do so in connection
// configuration.
func (c *connection) setGenerationNumber() {
if c.config.getGenerationFn != nil {
c.generation = c.config.getGenerationFn(c.desc.ServiceID)
}
}
// hasGenerationNumber returns true if the connection has set its generation number. If so, this indicates that the
// generationNumberFn provided via the connection options has been called exactly once.
func (c *connection) hasGenerationNumber() bool {
if !c.config.loadBalanced {
// The generation is known for all non-LB clusters once the connection object has been created.
return true
}
// For LB clusters, we set the generation after the initial handshake, so we know it's set if the connection
// description has been updated to reflect that it's behind an LB.
return c.desc.LoadBalanced()
}
// connect handles the I/O for a connection. It will dial, configure TLS, and perform initialization
// handshakes. All errors returned by connect are considered "before the handshake completes" and
// must be handled by calling the appropriate SDAM handshake error handler.
func (c *connection) connect(ctx context.Context) (err error) {
if !atomic.CompareAndSwapInt64(&c.state, connInitialized, connConnected) {
return nil
}
defer close(c.connectDone)
// If connect returns an error, set the connection status as disconnected and close the
// underlying net.Conn if it was created.
defer func() {
if err != nil {
atomic.StoreInt64(&c.state, connDisconnected)
if c.nc != nil {
_ = c.nc.Close()
}
}
}()
// Create separate contexts for dialing a connection and doing the MongoDB/auth handshakes.
//
// handshakeCtx is simply a cancellable version of ctx because there's no default timeout that needs to be applied
// to the full handshake. The cancellation allows consumers to bail out early when dialing a connection if it's no
// longer required. This is done in lock because it accesses the shared cancelConnectContext field.
//
// dialCtx is equal to handshakeCtx if connectTimeoutMS=0. Otherwise, it is derived from handshakeCtx so the
// cancellation still applies but with an added timeout to ensure the connectTimeoutMS option is applied to socket
// establishment and the TLS handshake as a whole. This is created outside of the connectContextMutex lock to avoid
// holding the lock longer than necessary.
c.connectContextMutex.Lock()
var handshakeCtx context.Context
handshakeCtx, c.cancelConnectContext = context.WithCancel(ctx)
c.connectContextMutex.Unlock()
dialCtx := handshakeCtx
var dialCancel context.CancelFunc
if c.config.connectTimeout != 0 {
dialCtx, dialCancel = context.WithTimeout(handshakeCtx, c.config.connectTimeout)
defer dialCancel()
}
defer func() {
var cancelFn context.CancelFunc
c.connectContextMutex.Lock()
cancelFn = c.cancelConnectContext
c.cancelConnectContext = nil
c.connectContextMutex.Unlock()
if cancelFn != nil {
cancelFn()
}
}()
close(c.connectContextMade)
// Assign the result of DialContext to a temporary net.Conn to ensure that c.nc is not set in an error case.
tempNc, err := c.config.dialer.DialContext(dialCtx, c.addr.Network(), c.addr.String())
if err != nil {
return ConnectionError{Wrapped: err, init: true}
}
c.nc = tempNc
if c.config.tlsConfig != nil {
tlsConfig := c.config.tlsConfig.Clone()
// Store the result of configureTLS in a variable separate from c.nc to avoid overwriting c.nc with nil in
// error cases.
ocspOpts := &ocsp.VerifyOptions{
Cache: c.config.ocspCache,
DisableEndpointChecking: c.config.disableOCSPEndpointCheck,
HTTPClient: c.config.httpClient,
}
tlsNc, err := configureTLS(dialCtx, c.config.tlsConnectionSource, c.nc, c.addr, tlsConfig, ocspOpts)
if err != nil {
return ConnectionError{Wrapped: err, init: true}
}
c.nc = tlsNc
}
// running hello and authentication is handled by a handshaker on the configuration instance.
handshaker := c.config.handshaker
if handshaker == nil {
return nil
}
var handshakeInfo driver.HandshakeInformation
handshakeStartTime := time.Now()
handshakeConn := initConnection{c}
handshakeInfo, err = handshaker.GetHandshakeInformation(handshakeCtx, c.addr, handshakeConn)
if err == nil {
// We only need to retain the Description field as the connection's description. The authentication-related
// fields in handshakeInfo are tracked by the handshaker if necessary.
c.desc = handshakeInfo.Description
c.serverConnectionID = handshakeInfo.ServerConnectionID
c.helloRTT = time.Since(handshakeStartTime)
// If the application has indicated that the cluster is load balanced, ensure the server has included serviceId
// in its handshake response to signal that it knows it's behind an LB as well.
if c.config.loadBalanced && c.desc.ServiceID == nil {
err = errLoadBalancedStateMismatch
}
}
if err == nil {
// For load-balanced connections, the generation number depends on the service ID, which isn't known until the
// initial MongoDB handshake is done. To account for this, we don't attempt to set the connection's generation
// number unless GetHandshakeInformation succeeds.
if c.config.loadBalanced {
c.setGenerationNumber()
}
// If we successfully finished the first part of the handshake and verified LB state, continue with the rest of
// the handshake.
err = handshaker.FinishHandshake(handshakeCtx, handshakeConn)
}
// We have a failed handshake here
if err != nil {
return ConnectionError{Wrapped: err, init: true}
}
if len(c.desc.Compression) > 0 {
clientMethodLoop:
for _, method := range c.config.compressors {
for _, serverMethod := range c.desc.Compression {
if method != serverMethod {
continue
}
switch strings.ToLower(method) {
case "snappy":
c.compressor = wiremessage.CompressorSnappy
case "zlib":
c.compressor = wiremessage.CompressorZLib
c.zliblevel = wiremessage.DefaultZlibLevel
if c.config.zlibLevel != nil {
c.zliblevel = *c.config.zlibLevel
}
case "zstd":
c.compressor = wiremessage.CompressorZstd
c.zstdLevel = wiremessage.DefaultZstdLevel
if c.config.zstdLevel != nil {
c.zstdLevel = *c.config.zstdLevel
}
}
break clientMethodLoop
}
}
}
return nil
}
func (c *connection) wait() {
if c.connectDone != nil {
<-c.connectDone
}
}
func (c *connection) closeConnectContext() {
<-c.connectContextMade
var cancelFn context.CancelFunc
c.connectContextMutex.Lock()
cancelFn = c.cancelConnectContext
c.cancelConnectContext = nil
c.connectContextMutex.Unlock()
if cancelFn != nil {
cancelFn()
}
}
func transformNetworkError(ctx context.Context, originalError error, contextDeadlineUsed bool) error {
if originalError == nil {
return nil
}
// If there was an error and the context was cancelled, we assume it happened due to the cancellation.
if ctx.Err() == context.Canceled {
return context.Canceled
}
// If there was a timeout error and the context deadline was used, we convert the error into
// context.DeadlineExceeded.
if !contextDeadlineUsed {
return originalError
}
if netErr, ok := originalError.(net.Error); ok && netErr.Timeout() {
return context.DeadlineExceeded
}
return originalError
}
func (c *connection) cancellationListenerCallback() {
_ = c.close()
}
func (c *connection) writeWireMessage(ctx context.Context, wm []byte) error {
var err error
if atomic.LoadInt64(&c.state) != connConnected {
return ConnectionError{ConnectionID: c.id, message: "connection is closed"}
}
var deadline time.Time
if c.writeTimeout != 0 {
deadline = time.Now().Add(c.writeTimeout)
}
var contextDeadlineUsed bool
if dl, ok := ctx.Deadline(); ok && (deadline.IsZero() || dl.Before(deadline)) {
contextDeadlineUsed = true
deadline = dl
}
if err := c.nc.SetWriteDeadline(deadline); err != nil {
return ConnectionError{ConnectionID: c.id, Wrapped: err, message: "failed to set write deadline"}
}
err = c.write(ctx, wm)
if err != nil {
c.close()
return ConnectionError{
ConnectionID: c.id,
Wrapped: transformNetworkError(ctx, err, contextDeadlineUsed),
message: "unable to write wire message to network",
}
}
return nil
}
func (c *connection) write(ctx context.Context, wm []byte) (err error) {
go c.cancellationListener.Listen(ctx, c.cancellationListenerCallback)
defer func() {
// There is a race condition between Write and StopListening. If the context is cancelled after c.nc.Write
// succeeds, the cancellation listener could fire and close the connection. In this case, the connection has
// been invalidated but the error is nil. To account for this, overwrite the error to context.Cancelled if
// the abortedForCancellation flag was set.
if aborted := c.cancellationListener.StopListening(); aborted && err == nil {
err = context.Canceled
}
}()
_, err = c.nc.Write(wm)
return err
}
// readWireMessage reads a wiremessage from the connection. The dst parameter will be overwritten.
func (c *connection) readWireMessage(ctx context.Context, dst []byte) ([]byte, error) {
if atomic.LoadInt64(&c.state) != connConnected {
return dst, ConnectionError{ConnectionID: c.id, message: "connection is closed"}
}
var deadline time.Time
if c.readTimeout != 0 {
deadline = time.Now().Add(c.readTimeout)
}
var contextDeadlineUsed bool
if dl, ok := ctx.Deadline(); ok && (deadline.IsZero() || dl.Before(deadline)) {
contextDeadlineUsed = true
deadline = dl
}
if err := c.nc.SetReadDeadline(deadline); err != nil {
return nil, ConnectionError{ConnectionID: c.id, Wrapped: err, message: "failed to set read deadline"}
}
dst, errMsg, err := c.read(ctx, dst)
if err != nil {
// We close the connection because we don't know if there are other bytes left to read.
c.close()
message := errMsg
if err == io.EOF {
message = "socket was unexpectedly closed"
}
return dst, ConnectionError{
ConnectionID: c.id,
Wrapped: transformNetworkError(ctx, err, contextDeadlineUsed),
message: message,
}
}
return dst, nil
}
func (c *connection) read(ctx context.Context, dst []byte) (bytesRead []byte, errMsg string, err error) {
go c.cancellationListener.Listen(ctx, c.cancellationListenerCallback)
defer func() {
// If the context is cancelled after we finish reading the server response, the cancellation listener could fire
// even though the socket reads succeed. To account for this, we overwrite err to be context.Canceled if the
// abortedForCancellation flag is set.
if aborted := c.cancellationListener.StopListening(); aborted && err == nil {
errMsg = "unable to read server response"
err = context.Canceled
}
}()
// We use an array here because it only costs 4 bytes on the stack and means we'll only need to
// reslice dst once instead of twice.
var sizeBuf [4]byte
// We do a ReadFull into an array here instead of doing an opportunistic ReadAtLeast into dst
// because there might be more than one wire message waiting to be read, for example when
// reading messages from an exhaust cursor.
_, err = io.ReadFull(c.nc, sizeBuf[:])
if err != nil {
return dst, "incomplete read of message header", err
}
// read the length as an int32
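// The MongoDB wire protocol is little-endian, so e.g. header bytes {0x10, 0x00, 0x00, 0x00} decode to a size of 16.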
size := (int32(sizeBuf[0])) | (int32(sizeBuf[1]) << 8) | (int32(sizeBuf[2]) << 16) | (int32(sizeBuf[3]) << 24)
// In the case of a hello response where MaxMessageSize has not yet been set, use the hard-coded
// defaultMaxMessageSize instead.
maxMessageSize := c.desc.MaxMessageSize
if maxMessageSize == 0 {
maxMessageSize = defaultMaxMessageSize
}
if uint32(size) > maxMessageSize {
return dst, errResponseTooLarge.Error(), errResponseTooLarge
}
if int(size) > cap(dst) {
// Since we can't grow this slice without allocating, just allocate an entirely new slice.
dst = make([]byte, 0, size)
}
// We need to ensure we don't accidentally read into a subsequent wire message, so we set the
// size to read exactly this wire message.
dst = dst[:size]
copy(dst, sizeBuf[:])
_, err = io.ReadFull(c.nc, dst[4:])
if err != nil {
return dst, "incomplete read of full message", err
}
return dst, "", nil
}
func (c *connection) close() error {
// Overwrite the connection state as the first step so only the first close call will execute.
if !atomic.CompareAndSwapInt64(&c.state, connConnected, connDisconnected) {
return nil
}
var err error
if c.nc != nil {
err = c.nc.Close()
}
return err
}
func (c *connection) closed() bool {
return atomic.LoadInt64(&c.state) == connDisconnected
}
func (c *connection) idleTimeoutExpired() bool {
now := time.Now()
if c.idleTimeout > 0 {
idleDeadline, ok := c.idleDeadline.Load().(time.Time)
if ok && now.After(idleDeadline) {
return true
}
}
return false
}
func (c *connection) bumpIdleDeadline() {
if c.idleTimeout > 0 {
c.idleDeadline.Store(time.Now().Add(c.idleTimeout))
}
}
func (c *connection) setCanStream(canStream bool) {
c.canStream = canStream
}
func (c *connection) supportsStreaming() bool {
return c.canStream
}
func (c *connection) setStreaming(streaming bool) {
c.currentlyStreaming = streaming
}
func (c *connection) getCurrentlyStreaming() bool {
return c.currentlyStreaming
}
func (c *connection) setSocketTimeout(timeout time.Duration) {
c.readTimeout = timeout
c.writeTimeout = timeout
}
func (c *connection) ID() string {
return c.id
}
func (c *connection) ServerConnectionID() *int32 {
return c.serverConnectionID
}
// initConnection is an adapter used during connection initialization. It has the minimum
// functionality necessary to implement the driver.Connection interface, which is required to pass a
// *connection to a Handshaker.
type initConnection struct{ *connection }
var _ driver.Connection = initConnection{}
var _ driver.StreamerConnection = initConnection{}
func (c initConnection) Description() description.Server {
if c.connection == nil {
return description.Server{}
}
return c.connection.desc
}
func (c initConnection) Close() error { return nil }
func (c initConnection) ID() string { return c.id }
func (c initConnection) Address() address.Address { return c.addr }
func (c initConnection) Stale() bool { return false }
func (c initConnection) LocalAddress() address.Address {
if c.connection == nil || c.nc == nil {
return address.Address("0.0.0.0")
}
return address.Address(c.nc.LocalAddr().String())
}
func (c initConnection) WriteWireMessage(ctx context.Context, wm []byte) error {
return c.writeWireMessage(ctx, wm)
}
func (c initConnection) ReadWireMessage(ctx context.Context, dst []byte) ([]byte, error) {
return c.readWireMessage(ctx, dst)
}
func (c initConnection) SetStreaming(streaming bool) {
c.setStreaming(streaming)
}
func (c initConnection) CurrentlyStreaming() bool {
return c.getCurrentlyStreaming()
}
func (c initConnection) SupportsStreaming() bool {
return c.supportsStreaming()
}
// Connection implements the driver.Connection interface to allow reading and writing wire
// messages and the driver.Expirable interface to allow expiring.
type Connection struct {
*connection
refCount int
cleanupPoolFn func()
// cleanupServerFn resets the server state when a connection is returned to the connection pool
// via Close() or expired via Expire().
cleanupServerFn func()
mu sync.RWMutex
}
var _ driver.Connection = (*Connection)(nil)
var _ driver.Expirable = (*Connection)(nil)
var _ driver.PinnedConnection = (*Connection)(nil)
// WriteWireMessage handles writing a wire message to the underlying connection.
func (c *Connection) WriteWireMessage(ctx context.Context, wm []byte) error {
c.mu.RLock()
defer c.mu.RUnlock()
if c.connection == nil {
return ErrConnectionClosed
}
return c.writeWireMessage(ctx, wm)
}
// ReadWireMessage handles reading a wire message from the underlying connection. The dst parameter
// will be overwritten with the new wire message.
func (c *Connection) ReadWireMessage(ctx context.Context, dst []byte) ([]byte, error) {
c.mu.RLock()
defer c.mu.RUnlock()
if c.connection == nil {
return dst, ErrConnectionClosed
}
return c.readWireMessage(ctx, dst)
}
// CompressWireMessage handles compressing the provided wire message using the underlying
// connection's compressor. The dst parameter will be overwritten with the new wire message. If
// there is no compressor set on the underlying connection, then no compression will be performed.
func (c *Connection) CompressWireMessage(src, dst []byte) ([]byte, error) {
c.mu.RLock()
defer c.mu.RUnlock()
if c.connection == nil {
return dst, ErrConnectionClosed
}
if c.connection.compressor == wiremessage.CompressorNoOp {
return append(dst, src...), nil
}
_, reqid, respto, origcode, rem, ok := wiremessage.ReadHeader(src)
if !ok {
return dst, errors.New("wiremessage is too short to compress, less than 16 bytes")
}
idx, dst := wiremessage.AppendHeaderStart(dst, reqid, respto, wiremessage.OpCompressed)
dst = wiremessage.AppendCompressedOriginalOpCode(dst, origcode)
dst = wiremessage.AppendCompressedUncompressedSize(dst, int32(len(rem)))
dst = wiremessage.AppendCompressedCompressorID(dst, c.connection.compressor)
opts := driver.CompressionOpts{
Compressor: c.connection.compressor,
ZlibLevel: c.connection.zliblevel,
ZstdLevel: c.connection.zstdLevel,
}
compressed, err := driver.CompressPayload(rem, opts)
if err != nil {
return nil, err
}
dst = wiremessage.AppendCompressedCompressedMessage(dst, compressed)
return bsoncore.UpdateLength(dst, idx, int32(len(dst[idx:]))), nil
}
// Description returns the server description of the server this connection is connected to.
func (c *Connection) Description() description.Server {
c.mu.RLock()
defer c.mu.RUnlock()
if c.connection == nil {
return description.Server{}
}
return c.desc
}
// Close returns this connection to the connection pool. This method may not close the underlying
// socket.
func (c *Connection) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
if c.connection == nil || c.refCount > 0 {
return nil
}
return c.cleanupReferences()
}
// Expire closes this connection and will close the underlying socket.
func (c *Connection) Expire() error {
c.mu.Lock()
defer c.mu.Unlock()
if c.connection == nil {
return nil
}
_ = c.close()
return c.cleanupReferences()
}
func (c *Connection) cleanupReferences() error {
err := c.pool.checkIn(c.connection)
if c.cleanupPoolFn != nil {
c.cleanupPoolFn()
c.cleanupPoolFn = nil
}
if c.cleanupServerFn != nil {
c.cleanupServerFn()
c.cleanupServerFn = nil
}
c.connection = nil
return err
}
// Alive returns if the connection is still alive.
func (c *Connection) Alive() bool {
return c.connection != nil
}
// ID returns the ID of this connection.
func (c *Connection) ID() string {
c.mu.RLock()
defer c.mu.RUnlock()
if c.connection == nil {
return "<closed>"
}
return c.id
}
// Stale returns if the connection is stale.
func (c *Connection) Stale() bool {
c.mu.RLock()
defer c.mu.RUnlock()
return c.pool.stale(c.connection)
}
// Address returns the address of this connection.
func (c *Connection) Address() address.Address {
c.mu.RLock()
defer c.mu.RUnlock()
if c.connection == nil {
return address.Address("0.0.0.0")
}
return c.addr
}
// LocalAddress returns the local address of the connection
func (c *Connection) LocalAddress() address.Address {
c.mu.RLock()
defer c.mu.RUnlock()
if c.connection == nil || c.nc == nil {
return address.Address("0.0.0.0")
}
return address.Address(c.nc.LocalAddr().String())
}
// PinToCursor updates this connection to reflect that it is pinned to a cursor.
func (c *Connection) PinToCursor() error {
return c.pin("cursor", c.pool.pinConnectionToCursor, c.pool.unpinConnectionFromCursor)
}
// PinToTransaction updates this connection to reflect that it is pinned to a transaction.
func (c *Connection) PinToTransaction() error {
return c.pin("transaction", c.pool.pinConnectionToTransaction, c.pool.unpinConnectionFromTransaction)
}
func (c *Connection) pin(reason string, updatePoolFn, cleanupPoolFn func()) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.connection == nil {
return fmt.Errorf("attempted to pin a connection for a %s, but the connection has already been returned to the pool", reason)
}
// Only use the provided callbacks for the first reference to avoid double-counting pinned connection statistics
// in the pool.
if c.refCount == 0 {
updatePoolFn()
c.cleanupPoolFn = cleanupPoolFn
}
c.refCount++
return nil
}
// UnpinFromCursor updates this connection to reflect that it is no longer pinned to a cursor.
func (c *Connection) UnpinFromCursor() error {
return c.unpin("cursor")
}
// UnpinFromTransaction updates this connection to reflect that it is no longer pinned to a transaction.
func (c *Connection) UnpinFromTransaction() error {
return c.unpin("transaction")
}
func (c *Connection) unpin(reason string) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.connection == nil {
// We don't error here because the resource could have been forcefully closed via Expire.
return nil
}
if c.refCount == 0 {
return fmt.Errorf("attempted to unpin a connection from a %s, but the connection is not pinned by any resources", reason)
}
c.refCount--
return nil
}
func configureTLS(ctx context.Context,
tlsConnSource tlsConnectionSource,
nc net.Conn,
addr address.Address,
config *tls.Config,
ocspOpts *ocsp.VerifyOptions,
) (net.Conn, error) {
// Ensure config.ServerName is always set for SNI.
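// e.g. an address of "example.com:27017" yields a ServerName of "example.com".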
if config.ServerName == "" {
hostname := addr.String()
colonPos := strings.LastIndex(hostname, ":")
if colonPos == -1 {
colonPos = len(hostname)
}
hostname = hostname[:colonPos]
config.ServerName = hostname
}
client := tlsConnSource.Client(nc, config)
if err := clientHandshake(ctx, client); err != nil {
return nil, err
}
// Only do OCSP verification if TLS verification is requested.
if !config.InsecureSkipVerify {
if ocspErr := ocsp.Verify(ctx, client.ConnectionState(), ocspOpts); ocspErr != nil {
return nil, ocspErr
}
}
return client, nil
}

View File

@@ -0,0 +1,7 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package topology

View File

@@ -0,0 +1,214 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package topology
import (
"context"
"crypto/tls"
"net"
"net/http"
"time"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/event"
"go.mongodb.org/mongo-driver/internal"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/ocsp"
)
// Dialer is used to make network connections.
type Dialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
// DialerFunc is a type implemented by functions that can be used as a Dialer.
type DialerFunc func(ctx context.Context, network, address string) (net.Conn, error)
// DialContext implements the Dialer interface.
func (df DialerFunc) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
return df(ctx, network, address)
}
// DefaultDialer is the Dialer implementation that is used by this package. Changing this
// will also change the Dialer used for this package. This should only be changed when all
// of the connections being made need to use a different Dialer. Most of the time, using a
// WithDialer option is more appropriate than changing this variable.
var DefaultDialer Dialer = &net.Dialer{}
// Handshaker is the interface implemented by types that can perform a MongoDB
// handshake over a provided driver.Connection. This is used during connection
// initialization. Implementations must be goroutine safe.
type Handshaker = driver.Handshaker
// generationNumberFn is a callback type used by a connection to fetch its generation number given its service ID.
type generationNumberFn func(serviceID *primitive.ObjectID) uint64
type connectionConfig struct {
connectTimeout time.Duration
dialer Dialer
handshaker Handshaker
idleTimeout time.Duration
cmdMonitor *event.CommandMonitor
readTimeout time.Duration
writeTimeout time.Duration
tlsConfig *tls.Config
httpClient *http.Client
compressors []string
zlibLevel *int
zstdLevel *int
ocspCache ocsp.Cache
disableOCSPEndpointCheck bool
tlsConnectionSource tlsConnectionSource
loadBalanced bool
getGenerationFn generationNumberFn
}
func newConnectionConfig(opts ...ConnectionOption) *connectionConfig {
cfg := &connectionConfig{
connectTimeout: 30 * time.Second,
dialer: nil,
tlsConnectionSource: defaultTLSConnectionSource,
httpClient: internal.DefaultHTTPClient,
}
for _, opt := range opts {
if opt == nil {
continue
}
opt(cfg)
}
if cfg.dialer == nil {
cfg.dialer = &net.Dialer{}
}
return cfg
}
// ConnectionOption is used to configure a connection.
type ConnectionOption func(*connectionConfig)
func withTLSConnectionSource(fn func(tlsConnectionSource) tlsConnectionSource) ConnectionOption {
return func(c *connectionConfig) {
c.tlsConnectionSource = fn(c.tlsConnectionSource)
}
}
// WithCompressors sets the compressors that can be used for communication.
func WithCompressors(fn func([]string) []string) ConnectionOption {
return func(c *connectionConfig) {
c.compressors = fn(c.compressors)
}
}
// WithConnectTimeout configures the maximum amount of time a dial will wait for a
// Connect to complete. The default is 30 seconds.
func WithConnectTimeout(fn func(time.Duration) time.Duration) ConnectionOption {
return func(c *connectionConfig) {
c.connectTimeout = fn(c.connectTimeout)
}
}
// WithDialer configures the Dialer to use when making a new connection to MongoDB.
func WithDialer(fn func(Dialer) Dialer) ConnectionOption {
return func(c *connectionConfig) {
c.dialer = fn(c.dialer)
}
}
// WithHandshaker configures the Handshaker that will be used to initialize newly
// dialed connections.
func WithHandshaker(fn func(Handshaker) Handshaker) ConnectionOption {
return func(c *connectionConfig) {
c.handshaker = fn(c.handshaker)
}
}
// WithIdleTimeout configures the maximum idle time to allow for a connection.
func WithIdleTimeout(fn func(time.Duration) time.Duration) ConnectionOption {
return func(c *connectionConfig) {
c.idleTimeout = fn(c.idleTimeout)
}
}
// WithReadTimeout configures the maximum read time for a connection.
func WithReadTimeout(fn func(time.Duration) time.Duration) ConnectionOption {
return func(c *connectionConfig) {
c.readTimeout = fn(c.readTimeout)
}
}
// WithWriteTimeout configures the maximum write time for a connection.
func WithWriteTimeout(fn func(time.Duration) time.Duration) ConnectionOption {
return func(c *connectionConfig) {
c.writeTimeout = fn(c.writeTimeout)
}
}
// WithTLSConfig configures the TLS options for a connection.
func WithTLSConfig(fn func(*tls.Config) *tls.Config) ConnectionOption {
return func(c *connectionConfig) {
c.tlsConfig = fn(c.tlsConfig)
}
}
// WithHTTPClient configures the HTTP client for a connection.
func WithHTTPClient(fn func(*http.Client) *http.Client) ConnectionOption {
return func(c *connectionConfig) {
c.httpClient = fn(c.httpClient)
}
}
// WithMonitor configures an event monitor for command monitoring.
func WithMonitor(fn func(*event.CommandMonitor) *event.CommandMonitor) ConnectionOption {
return func(c *connectionConfig) {
c.cmdMonitor = fn(c.cmdMonitor)
}
}
// WithZlibLevel sets the zLib compression level.
func WithZlibLevel(fn func(*int) *int) ConnectionOption {
return func(c *connectionConfig) {
c.zlibLevel = fn(c.zlibLevel)
}
}
// WithZstdLevel sets the zstd compression level.
func WithZstdLevel(fn func(*int) *int) ConnectionOption {
return func(c *connectionConfig) {
c.zstdLevel = fn(c.zstdLevel)
}
}
// WithOCSPCache specifies a cache to use for OCSP verification.
func WithOCSPCache(fn func(ocsp.Cache) ocsp.Cache) ConnectionOption {
return func(c *connectionConfig) {
c.ocspCache = fn(c.ocspCache)
}
}
// WithDisableOCSPEndpointCheck specifies whether or not the driver should perform non-stapled OCSP verification. If set
// to true, the driver will only check stapled responses and will continue the connection without reaching out to
// OCSP responders.
func WithDisableOCSPEndpointCheck(fn func(bool) bool) ConnectionOption {
return func(c *connectionConfig) {
c.disableOCSPEndpointCheck = fn(c.disableOCSPEndpointCheck)
}
}
// WithConnectionLoadBalanced specifies whether or not the connection is to a server behind a load balancer.
func WithConnectionLoadBalanced(fn func(bool) bool) ConnectionOption {
return func(c *connectionConfig) {
c.loadBalanced = fn(c.loadBalanced)
}
}
func withGenerationNumberFn(fn func(generationNumberFn) generationNumberFn) ConnectionOption {
return func(c *connectionConfig) {
c.getGenerationFn = fn(c.getGenerationFn)
}
}
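// exampleOptions is a usage sketch (hypothetical; not part of the driver's API). Each ConnectionOption
// receives the current value and returns the new one, so options can wrap or replace the defaults set
// by newConnectionConfig.
func exampleOptions() *connectionConfig {
return newConnectionConfig(
WithConnectTimeout(func(time.Duration) time.Duration { return 10 * time.Second }),
WithCompressors(func(cs []string) []string { return append(cs, "zstd") }),
)
}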

View File

@@ -0,0 +1,73 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package topology
import "go.mongodb.org/mongo-driver/mongo/description"
// hostlistDiff is the difference between a topology and a host list.
type hostlistDiff struct {
Added []string
Removed []string
}
// diffHostList compares the topology description and host list and returns the difference.
func diffHostList(t description.Topology, hostlist []string) hostlistDiff {
var diff hostlistDiff
oldServers := make(map[string]bool)
for _, s := range t.Servers {
oldServers[s.Addr.String()] = true
}
for _, addr := range hostlist {
if oldServers[addr] {
delete(oldServers, addr)
} else {
diff.Added = append(diff.Added, addr)
}
}
for addr := range oldServers {
diff.Removed = append(diff.Removed, addr)
}
return diff
}
// topologyDiff is the difference between two different topology descriptions.
type topologyDiff struct {
Added []description.Server
Removed []description.Server
}
// diffTopology compares the two topology descriptions and returns the difference.
func diffTopology(old, new description.Topology) topologyDiff {
var diff topologyDiff
oldServers := make(map[string]bool)
for _, s := range old.Servers {
oldServers[s.Addr.String()] = true
}
for _, s := range new.Servers {
addr := s.Addr.String()
if oldServers[addr] {
delete(oldServers, addr)
} else {
diff.Added = append(diff.Added, s)
}
}
for _, s := range old.Servers {
addr := s.Addr.String()
if oldServers[addr] {
diff.Removed = append(diff.Removed, s)
}
}
return diff
}
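// As a worked example (hypothetical values): given a topology containing "a:27017" and "b:27017" and
// the host list ["b:27017", "c:27017"], diffHostList reports Added = ["c:27017"] and
// Removed = ["a:27017"].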

View File

@@ -0,0 +1,111 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package topology
import (
"context"
"fmt"
"go.mongodb.org/mongo-driver/mongo/description"
)
// ConnectionError represents a connection error.
type ConnectionError struct {
ConnectionID string
Wrapped error
// init will be set to true if this error occurred during connection initialization or
// during a connection handshake.
init bool
message string
}
// Error implements the error interface.
func (e ConnectionError) Error() string {
message := e.message
if e.init {
fullMsg := "error occurred during connection handshake"
if message != "" {
fullMsg = fmt.Sprintf("%s: %s", fullMsg, message)
}
message = fullMsg
}
if e.Wrapped != nil && message != "" {
return fmt.Sprintf("connection(%s) %s: %s", e.ConnectionID, message, e.Wrapped.Error())
}
if e.Wrapped != nil {
return fmt.Sprintf("connection(%s) %s", e.ConnectionID, e.Wrapped.Error())
}
return fmt.Sprintf("connection(%s) %s", e.ConnectionID, message)
}
// Unwrap returns the underlying error.
func (e ConnectionError) Unwrap() error {
return e.Wrapped
}
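// For example (hypothetical values), a handshake failure wrapping io.EOF renders as:
//
// connection(localhost:27017[-2]) error occurred during connection handshake: EOF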
// ServerSelectionError represents a Server Selection error.
type ServerSelectionError struct {
Desc description.Topology
Wrapped error
}
// Error implements the error interface.
func (e ServerSelectionError) Error() string {
if e.Wrapped != nil {
return fmt.Sprintf("server selection error: %s, current topology: { %s }", e.Wrapped.Error(), e.Desc.String())
}
return fmt.Sprintf("server selection error: current topology: { %s }", e.Desc.String())
}
// Unwrap returns the underlying error.
func (e ServerSelectionError) Unwrap() error {
return e.Wrapped
}
// WaitQueueTimeoutError represents a timeout when requesting a connection from the pool
type WaitQueueTimeoutError struct {
Wrapped error
PinnedCursorConnections uint64
PinnedTransactionConnections uint64
maxPoolSize uint64
totalConnectionCount int
}
// Error implements the error interface.
func (w WaitQueueTimeoutError) Error() string {
errorMsg := "timed out while checking out a connection from connection pool"
switch w.Wrapped {
case nil:
case context.Canceled:
errorMsg = fmt.Sprintf(
"%s: %s",
"canceled while checking out a connection from connection pool",
w.Wrapped.Error(),
)
default:
errorMsg = fmt.Sprintf(
"%s: %s",
errorMsg,
w.Wrapped.Error(),
)
}
return fmt.Sprintf(
"%s; maxPoolSize: %d, connections in use by cursors: %d"+
", connections in use by transactions: %d, connections in use by other operations: %d",
errorMsg,
w.maxPoolSize,
w.PinnedCursorConnections,
w.PinnedTransactionConnections,
uint64(w.totalConnectionCount)-w.PinnedCursorConnections-w.PinnedTransactionConnections)
}
// Unwrap returns the underlying error.
func (w WaitQueueTimeoutError) Unwrap() error {
return w.Wrapped
}

View File

@@ -0,0 +1,438 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package topology
import (
"bytes"
"fmt"
"sync/atomic"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo/address"
"go.mongodb.org/mongo-driver/mongo/description"
)
var (
// MinSupportedMongoDBVersion is the version string for the lowest MongoDB version supported by the driver.
MinSupportedMongoDBVersion = "3.6"
// SupportedWireVersions is the range of wire versions supported by the driver.
SupportedWireVersions = description.NewVersionRange(6, 17)
)
type fsm struct {
description.Topology
maxElectionID primitive.ObjectID
maxSetVersion uint32
compatible atomic.Value
compatibilityErr error
}
func newFSM() *fsm {
f := fsm{}
f.compatible.Store(true)
return &f
}
// apply takes a new server description and modifies the FSM's topology description based on it. It returns the
// updated topology description as well as a server description. The returned server description is either the same
// one that was passed in, or a new one in the case that it had to be changed.
//
// apply should operate on immutable descriptions so we don't have to lock for the entire time we're applying the
// server description.
func (f *fsm) apply(s description.Server) (description.Topology, description.Server) {
newServers := make([]description.Server, len(f.Servers))
copy(newServers, f.Servers)
oldMinutes := f.SessionTimeoutMinutes
f.Topology = description.Topology{
Kind: f.Kind,
Servers: newServers,
SetName: f.SetName,
}
// For data bearing servers, set SessionTimeoutMinutes to the lowest among them
if oldMinutes == 0 {
// If the timeout is currently 0, check all servers to see if any still don't have a timeout.
// If they all have a timeout, pick the lowest.
timeout := s.SessionTimeoutMinutes
for _, server := range f.Servers {
if server.DataBearing() && server.SessionTimeoutMinutes < timeout {
timeout = server.SessionTimeoutMinutes
}
}
f.SessionTimeoutMinutes = timeout
} else {
if s.DataBearing() && oldMinutes > s.SessionTimeoutMinutes {
f.SessionTimeoutMinutes = s.SessionTimeoutMinutes
} else {
f.SessionTimeoutMinutes = oldMinutes
}
}
if _, ok := f.findServer(s.Addr); !ok {
return f.Topology, s
}
updatedDesc := s
switch f.Kind {
case description.Unknown:
updatedDesc = f.applyToUnknown(s)
case description.Sharded:
updatedDesc = f.applyToSharded(s)
case description.ReplicaSetNoPrimary:
updatedDesc = f.applyToReplicaSetNoPrimary(s)
case description.ReplicaSetWithPrimary:
updatedDesc = f.applyToReplicaSetWithPrimary(s)
case description.Single:
updatedDesc = f.applyToSingle(s)
}
for _, server := range f.Servers {
if server.WireVersion != nil {
if server.WireVersion.Max < SupportedWireVersions.Min {
f.compatible.Store(false)
f.compatibilityErr = fmt.Errorf(
"server at %s reports wire version %d, but this version of the Go driver requires "+
"at least %d (MongoDB %s)",
server.Addr.String(),
server.WireVersion.Max,
SupportedWireVersions.Min,
MinSupportedMongoDBVersion,
)
f.Topology.CompatibilityErr = f.compatibilityErr
return f.Topology, s
}
if server.WireVersion.Min > SupportedWireVersions.Max {
f.compatible.Store(false)
f.compatibilityErr = fmt.Errorf(
"server at %s requires wire version %d, but this version of the Go driver only supports up to %d",
server.Addr.String(),
server.WireVersion.Min,
SupportedWireVersions.Max,
)
f.Topology.CompatibilityErr = f.compatibilityErr
return f.Topology, s
}
}
}
f.compatible.Store(true)
f.compatibilityErr = nil
return f.Topology, updatedDesc
}
func (f *fsm) applyToReplicaSetNoPrimary(s description.Server) description.Server {
switch s.Kind {
case description.Standalone, description.Mongos:
f.removeServerByAddr(s.Addr)
case description.RSPrimary:
f.updateRSFromPrimary(s)
case description.RSSecondary, description.RSArbiter, description.RSMember:
f.updateRSWithoutPrimary(s)
case description.Unknown, description.RSGhost:
f.replaceServer(s)
}
return s
}
func (f *fsm) applyToReplicaSetWithPrimary(s description.Server) description.Server {
switch s.Kind {
case description.Standalone, description.Mongos:
f.removeServerByAddr(s.Addr)
f.checkIfHasPrimary()
case description.RSPrimary:
f.updateRSFromPrimary(s)
case description.RSSecondary, description.RSArbiter, description.RSMember:
f.updateRSWithPrimaryFromMember(s)
case description.Unknown, description.RSGhost:
f.replaceServer(s)
f.checkIfHasPrimary()
}
return s
}
func (f *fsm) applyToSharded(s description.Server) description.Server {
switch s.Kind {
case description.Mongos, description.Unknown:
f.replaceServer(s)
case description.Standalone, description.RSPrimary, description.RSSecondary, description.RSArbiter, description.RSMember, description.RSGhost:
f.removeServerByAddr(s.Addr)
}
return s
}
func (f *fsm) applyToSingle(s description.Server) description.Server {
switch s.Kind {
case description.Unknown:
f.replaceServer(s)
case description.Standalone, description.Mongos:
if f.SetName != "" {
f.removeServerByAddr(s.Addr)
return s
}
f.replaceServer(s)
case description.RSPrimary, description.RSSecondary, description.RSArbiter, description.RSMember, description.RSGhost:
// A replica set name can be provided when creating a direct connection. In this case, if the set name returned
// by the hello response doesn't match up with the one provided during configuration, the server description
// is replaced with a default Unknown description.
//
// We create a new server description rather than doing s.Kind = description.Unknown because the other fields,
// such as RTT, need to be cleared for Unknown descriptions as well.
if f.SetName != "" && f.SetName != s.SetName {
s = description.Server{
Addr: s.Addr,
Kind: description.Unknown,
}
}
f.replaceServer(s)
}
return s
}
func (f *fsm) applyToUnknown(s description.Server) description.Server {
switch s.Kind {
case description.Mongos:
f.setKind(description.Sharded)
f.replaceServer(s)
case description.RSPrimary:
f.updateRSFromPrimary(s)
case description.RSSecondary, description.RSArbiter, description.RSMember:
f.setKind(description.ReplicaSetNoPrimary)
f.updateRSWithoutPrimary(s)
case description.Standalone:
f.updateUnknownWithStandalone(s)
case description.Unknown, description.RSGhost:
f.replaceServer(s)
}
return s
}
func (f *fsm) checkIfHasPrimary() {
if _, ok := f.findPrimary(); ok {
f.setKind(description.ReplicaSetWithPrimary)
} else {
f.setKind(description.ReplicaSetNoPrimary)
}
}
// hasStalePrimary returns true if the topology has a primary that is "stale".
func hasStalePrimary(fsm fsm, srv description.Server) bool {
// Compare the election ID values of the server and the topology lexicographically.
compRes := bytes.Compare(srv.ElectionID[:], fsm.maxElectionID[:])
if wireVersion := srv.WireVersion; wireVersion != nil && wireVersion.Max >= 17 {
// In the post-6.0 case, a primary is considered "stale" if the server's election ID is less than the
// topology's max election ID. In these versions, the primary is also considered "stale" if the server's
// election ID is LTE to the topology's max election ID and the server's "setVersion" is less than the topology's
// max "setVersion".
return compRes == -1 || (compRes != 1 && srv.SetVersion < fsm.maxSetVersion)
}
// If the server's election ID is less than the topology's max election ID, the primary is considered
// "stale". Similarly, if the server's "setVersion" is less than the topology's max "setVersion", the
// primary is considered stale.
return compRes == -1 || fsm.maxSetVersion > srv.SetVersion
}
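// As a worked example (hypothetical values): for a post-6.0 primary whose ElectionID equals
// fsm.maxElectionID (compRes == 0) but whose SetVersion is 3 while fsm.maxSetVersion is 4,
// hasStalePrimary returns true.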
// transferEVTuple will transfer the ("ElectionID", "SetVersion") tuple from the server description to the topology.
// If the primary is stale, the tuple will not be transferred, the topology will update its "Kind" value, and this
// routine will return "false".
func transferEVTuple(srv description.Server, fsm *fsm) bool {
stalePrimary := hasStalePrimary(*fsm, srv)
if wireVersion := srv.WireVersion; wireVersion != nil && wireVersion.Max >= 17 {
if stalePrimary {
fsm.checkIfHasPrimary()
return false
}
fsm.maxElectionID = srv.ElectionID
fsm.maxSetVersion = srv.SetVersion
return true
}
if srv.SetVersion != 0 && !srv.ElectionID.IsZero() {
if stalePrimary {
fsm.replaceServer(description.Server{
Addr: srv.Addr,
LastError: fmt.Errorf(
"was a primary, but its set version or election id is stale"),
})
fsm.checkIfHasPrimary()
return false
}
fsm.maxElectionID = srv.ElectionID
}
if srv.SetVersion > fsm.maxSetVersion {
fsm.maxSetVersion = srv.SetVersion
}
return true
}
func (f *fsm) updateRSFromPrimary(srv description.Server) {
if f.SetName == "" {
f.SetName = srv.SetName
} else if f.SetName != srv.SetName {
f.removeServerByAddr(srv.Addr)
f.checkIfHasPrimary()
return
}
if ok := transferEVTuple(srv, f); !ok {
return
}
if j, ok := f.findPrimary(); ok {
f.setServer(j, description.Server{
Addr: f.Servers[j].Addr,
LastError: fmt.Errorf("was a primary, but a new primary was discovered"),
})
}
f.replaceServer(srv)
for j := len(f.Servers) - 1; j >= 0; j-- {
found := false
for _, member := range srv.Members {
if member == f.Servers[j].Addr {
found = true
break
}
}
if !found {
f.removeServer(j)
}
}
for _, member := range srv.Members {
if _, ok := f.findServer(member); !ok {
f.addServer(member)
}
}
f.checkIfHasPrimary()
}
func (f *fsm) updateRSWithPrimaryFromMember(s description.Server) {
if f.SetName != s.SetName {
f.removeServerByAddr(s.Addr)
f.checkIfHasPrimary()
return
}
if s.Addr != s.CanonicalAddr {
f.removeServerByAddr(s.Addr)
f.checkIfHasPrimary()
return
}
f.replaceServer(s)
if _, ok := f.findPrimary(); !ok {
f.setKind(description.ReplicaSetNoPrimary)
}
}
func (f *fsm) updateRSWithoutPrimary(s description.Server) {
if f.SetName == "" {
f.SetName = s.SetName
} else if f.SetName != s.SetName {
f.removeServerByAddr(s.Addr)
return
}
for _, member := range s.Members {
if _, ok := f.findServer(member); !ok {
f.addServer(member)
}
}
if s.Addr != s.CanonicalAddr {
f.removeServerByAddr(s.Addr)
return
}
f.replaceServer(s)
}
func (f *fsm) updateUnknownWithStandalone(s description.Server) {
if len(f.Servers) > 1 {
f.removeServerByAddr(s.Addr)
return
}
f.setKind(description.Single)
f.replaceServer(s)
}
func (f *fsm) addServer(addr address.Address) {
f.Servers = append(f.Servers, description.Server{
Addr: addr.Canonicalize(),
})
}
func (f *fsm) findPrimary() (int, bool) {
for i, s := range f.Servers {
if s.Kind == description.RSPrimary {
return i, true
}
}
return 0, false
}
func (f *fsm) findServer(addr address.Address) (int, bool) {
canon := addr.Canonicalize()
for i, s := range f.Servers {
if canon == s.Addr {
return i, true
}
}
return 0, false
}
func (f *fsm) removeServer(i int) {
f.Servers = append(f.Servers[:i], f.Servers[i+1:]...)
}
func (f *fsm) removeServerByAddr(addr address.Address) {
if i, ok := f.findServer(addr); ok {
f.removeServer(i)
}
}
func (f *fsm) replaceServer(s description.Server) {
if i, ok := f.findServer(s.Addr); ok {
f.setServer(i, s)
}
}
func (f *fsm) setServer(i int, s description.Server) {
f.Servers[i] = s
}
func (f *fsm) setKind(k description.TopologyKind) {
f.Kind = k
}

File diff suppressed because it is too large

View File

@@ -0,0 +1,152 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package topology
import (
"sync"
"sync/atomic"
"go.mongodb.org/mongo-driver/bson/primitive"
)
// Pool generation state constants.
const (
generationDisconnected int64 = iota
generationConnected
)
// generationStats represents the version of a pool. It tracks the generation number as well as the number of
// connections that have been created in the generation.
type generationStats struct {
generation uint64
numConns uint64
}
// poolGenerationMap tracks the version for each service ID present in a pool. For deployments that are not behind a
// load balancer, there is only one service ID: primitive.NilObjectID. For load-balanced deployments, each server behind
// the load balancer will have a unique service ID.
type poolGenerationMap struct {
// state must be accessed using the atomic package and should be at the beginning of the struct.
// - atomic bug: https://pkg.go.dev/sync/atomic#pkg-note-BUG
// - suggested layout: https://go101.org/article/memory-layout.html
state int64
generationMap map[primitive.ObjectID]*generationStats
sync.Mutex
}
func newPoolGenerationMap() *poolGenerationMap {
pgm := &poolGenerationMap{
generationMap: make(map[primitive.ObjectID]*generationStats),
}
pgm.generationMap[primitive.NilObjectID] = &generationStats{}
return pgm
}
func (p *poolGenerationMap) connect() {
atomic.StoreInt64(&p.state, generationConnected)
}
func (p *poolGenerationMap) disconnect() {
atomic.StoreInt64(&p.state, generationDisconnected)
}
// addConnection increments the connection count for the generation associated with the given service ID and returns the
// generation number for the connection.
func (p *poolGenerationMap) addConnection(serviceIDPtr *primitive.ObjectID) uint64 {
serviceID := getServiceID(serviceIDPtr)
p.Lock()
defer p.Unlock()
stats, ok := p.generationMap[serviceID]
if ok {
// If the serviceID is already being tracked, we only need to increment the connection count.
stats.numConns++
return stats.generation
}
// If the serviceID is untracked, create a new entry with a starting generation number of 0.
stats = &generationStats{
numConns: 1,
}
p.generationMap[serviceID] = stats
return 0
}
func (p *poolGenerationMap) removeConnection(serviceIDPtr *primitive.ObjectID) {
serviceID := getServiceID(serviceIDPtr)
p.Lock()
defer p.Unlock()
stats, ok := p.generationMap[serviceID]
if !ok {
return
}
// If the serviceID is being tracked, decrement the connection count and delete this serviceID to prevent the map
// from growing unboundedly. This case would happen if a server behind a load-balancer was permanently removed
// and its connections were pruned after a network error or idle timeout.
stats.numConns--
if stats.numConns == 0 {
delete(p.generationMap, serviceID)
}
}
func (p *poolGenerationMap) clear(serviceIDPtr *primitive.ObjectID) {
serviceID := getServiceID(serviceIDPtr)
p.Lock()
defer p.Unlock()
if stats, ok := p.generationMap[serviceID]; ok {
stats.generation++
}
}
func (p *poolGenerationMap) stale(serviceIDPtr *primitive.ObjectID, knownGeneration uint64) bool {
// If the map has been disconnected, all connections should be considered stale to ensure that they're closed.
if atomic.LoadInt64(&p.state) == generationDisconnected {
return true
}
serviceID := getServiceID(serviceIDPtr)
p.Lock()
defer p.Unlock()
if stats, ok := p.generationMap[serviceID]; ok {
return knownGeneration < stats.generation
}
return false
}
func (p *poolGenerationMap) getGeneration(serviceIDPtr *primitive.ObjectID) uint64 {
serviceID := getServiceID(serviceIDPtr)
p.Lock()
defer p.Unlock()
if stats, ok := p.generationMap[serviceID]; ok {
return stats.generation
}
return 0
}
func (p *poolGenerationMap) getNumConns(serviceIDPtr *primitive.ObjectID) uint64 {
serviceID := getServiceID(serviceIDPtr)
p.Lock()
defer p.Unlock()
if stats, ok := p.generationMap[serviceID]; ok {
return stats.numConns
}
return 0
}
func getServiceID(oid *primitive.ObjectID) primitive.ObjectID {
if oid == nil {
return primitive.NilObjectID
}
return *oid
}
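// A usage sketch of the expected flow (hypothetical; svcID is a *primitive.ObjectID supplied by the
// caller). A connection records its generation when it is created, and clearing the generation after
// an error marks that connection as stale:
//
// pgm := newPoolGenerationMap()
// pgm.connect()
// gen := pgm.addConnection(svcID) // 0 for a newly tracked service ID
// pgm.clear(svcID) // bump the generation to 1
// pgm.stale(svcID, gen) // true, since 0 < 1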

View File

@@ -0,0 +1,307 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package topology
import (
"context"
"fmt"
"math"
"sync"
"time"
"github.com/montanaflynn/stats"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
)
const (
rttAlphaValue = 0.2
minSamples = 10
maxSamples = 500
)
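// rttAlphaValue presumably weights an exponentially weighted moving average of RTT samples,
// i.e. avg = alpha*sample + (1-alpha)*avg (an inference from the constant's name; the update itself
// happens in addSample, below).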
type rttConfig struct {
// The minimum interval between RTT measurements. The actual interval may be greater if running
// the operation takes longer than the interval.
interval time.Duration
// The timeout applied to running the "hello" operation. If the timeout is reached while running
// the operation, the RTT sample is discarded. The default is 1 minute.
timeout time.Duration
minRTTWindow time.Duration
createConnectionFn func() *connection
createOperationFn func(driver.Connection) *operation.Hello
}
type rttMonitor struct {
mu sync.RWMutex // mu guards samples, offset, minRTT, averageRTT, and averageRTTSet
samples []time.Duration
offset int
minRTT time.Duration
rtt90 time.Duration
averageRTT time.Duration
averageRTTSet bool
closeWg sync.WaitGroup
cfg *rttConfig
ctx context.Context
cancelFn context.CancelFunc
}
var _ driver.RTTMonitor = &rttMonitor{}
func newRTTMonitor(cfg *rttConfig) *rttMonitor {
if cfg.interval <= 0 {
panic("RTT monitor interval must be greater than 0")
}
ctx, cancel := context.WithCancel(context.Background())
// Determine the number of samples we need to keep to store the minWindow of RTT durations. The
// number of samples must be between [10, 500].
numSamples := int(math.Max(minSamples, math.Min(maxSamples, float64((cfg.minRTTWindow)/cfg.interval))))
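// For example, minRTTWindow = 5m with interval = 10s yields 30 samples.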
return &rttMonitor{
samples: make([]time.Duration, numSamples),
cfg: cfg,
ctx: ctx,
cancelFn: cancel,
}
}
func (r *rttMonitor) connect() {
r.closeWg.Add(1)
go r.start()
}
func (r *rttMonitor) disconnect() {
// Signal for the routine to stop.
r.cancelFn()
r.closeWg.Wait()
}
func (r *rttMonitor) start() {
defer r.closeWg.Done()
var conn *connection
defer func() {
if conn != nil {
// If the connection exists, we need to wait for it to be connected because
// conn.connect() and conn.close() cannot be called concurrently. If the connection
// wasn't successfully opened, its state was set back to disconnected, so calling
// conn.close() will be a no-op.
conn.closeConnectContext()
conn.wait()
_ = conn.close()
}
}()
ticker := time.NewTicker(r.cfg.interval)
defer ticker.Stop()
for {
conn := r.cfg.createConnectionFn()
err := conn.connect(r.ctx)
// Add an RTT sample from the new connection handshake and start a runHellos() loop if we
// successfully established the new connection. Otherwise, close the connection and try to
// create another new connection.
if err == nil {
r.addSample(conn.helloRTT)
r.runHellos(conn)
}
// Close any connection here because we're either about to try to create another new
// connection or we're about to exit the loop.
_ = conn.close()
// Even if a connection error happens immediately, wait for the monitoring interval before
// trying to create a new connection so that connections aren't created too rapidly.
select {
case <-ticker.C:
case <-r.ctx.Done():
return
}
}
}
// runHellos runs "hello" operations in a loop using the provided connection, measuring and
// recording the operation durations as RTT samples. If it encounters any errors, it returns.
func (r *rttMonitor) runHellos(conn *connection) {
ticker := time.NewTicker(r.cfg.interval)
defer ticker.Stop()
for {
// Assume that the connection establishment recorded the first RTT sample, so wait for the
// first tick before trying to record another RTT sample.
select {
case <-ticker.C:
case <-r.ctx.Done():
return
}
// Create a Context with the operation timeout specified in the RTT monitor config. If a
// timeout is not set in the RTT monitor config, default to the connection's
// "connectTimeoutMS". The purpose of the timeout is to allow the RTT monitor to continue
// monitoring server RTTs after an operation gets stuck. An operation can get stuck if the
// server or a proxy stops responding to requests on the RTT connection but does not close
// the TCP socket, effectively creating an operation that will never complete. We expect
// that "connectTimeoutMS" provides at least enough time for a single round trip.
timeout := r.cfg.timeout
if timeout <= 0 {
timeout = conn.config.connectTimeout
}
ctx, cancel := context.WithTimeout(r.ctx, timeout)
start := time.Now()
err := r.cfg.createOperationFn(initConnection{conn}).Execute(ctx)
cancel()
if err != nil {
return
}
// Only record a sample if the "hello" operation was successful. If it was not successful,
// the operation may not have actually performed a complete round trip, so the duration may
// be artificially short.
r.addSample(time.Since(start))
}
}
// reset clears the recorded samples and sets the average, minimum, and 90th percentile RTT to 0.
// This should only be called from the server monitor when an error occurs during a server check.
// Errors in the RTT monitor should not reset the RTTs.
func (r *rttMonitor) reset() {
r.mu.Lock()
defer r.mu.Unlock()
for i := range r.samples {
r.samples[i] = 0
}
r.offset = 0
r.minRTT = 0
r.rtt90 = 0
r.averageRTT = 0
r.averageRTTSet = false
}
func (r *rttMonitor) addSample(rtt time.Duration) {
// Lock for the duration of this method. We're doing computationally inexpensive work very infrequently, so lock
// contention isn't expected.
r.mu.Lock()
defer r.mu.Unlock()
r.samples[r.offset] = rtt
r.offset = (r.offset + 1) % len(r.samples)
// Set the minRTT and 90th percentile RTT of all collected samples. Require at least 10 samples before
// setting these to prevent noisy samples on startup from artificially increasing RTT and to allow the
// calculation of a 90th percentile.
r.minRTT = min(r.samples, minSamples)
r.rtt90 = percentile(90.0, r.samples, minSamples)
if !r.averageRTTSet {
r.averageRTT = rtt
r.averageRTTSet = true
return
}
r.averageRTT = time.Duration(rttAlphaValue*float64(rtt) + (1-rttAlphaValue)*float64(r.averageRTT))
}
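// Worked example (editor's note, not part of the driver): with
// rttAlphaValue = 0.2, a current averageRTT of 10ms and a new sample of 20ms,
// the EWMA update yields 0.2*20ms + 0.8*10ms = 12ms, so recent samples shift
// the average gradually rather than replacing it.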
// min returns the minimum value of the slice of duration samples. Zero values are not considered
// samples and are ignored. If fewer than minSamples non-zero samples are found in the slice, min
// returns 0.
func min(samples []time.Duration, minSamples int) time.Duration {
count := 0
min := time.Duration(math.MaxInt64)
for _, d := range samples {
if d > 0 {
count++
}
if d > 0 && d < min {
min = d
}
}
if count == 0 || count < minSamples {
return 0
}
return min
}
// percentile returns the specified percentile value of the slice of duration samples. Zero values
// are not considered samples and are ignored. If fewer than minSamples non-zero samples are found
// in the slice, percentile returns 0.
func percentile(perc float64, samples []time.Duration, minSamples int) time.Duration {
// Convert Durations to float64s.
floatSamples := make([]float64, 0, len(samples))
for _, sample := range samples {
if sample > 0 {
floatSamples = append(floatSamples, float64(sample))
}
}
if len(floatSamples) == 0 || len(floatSamples) < minSamples {
return 0
}
p, err := stats.Percentile(floatSamples, perc)
if err != nil {
panic(fmt.Errorf("x/mongo/driver/topology: error calculating %f percentile RTT: %v for samples:\n%v", perc, err, floatSamples))
}
return time.Duration(p)
}
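// Example (editor's sketch, not part of the driver): zero entries in the ring
// buffer are treated as "no sample", so a mostly-empty buffer yields 0 until
// at least minSamples non-zero samples have been recorded.
func exampleRTTAggregates() (time.Duration, time.Duration) {
	samples := make([]time.Duration, maxSamples)
	for i := 0; i < minSamples; i++ {
		samples[i] = time.Duration(i+1) * time.Millisecond
	}
	return min(samples, minSamples), percentile(90.0, samples, minSamples)
}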
// EWMA returns the exponentially weighted moving average observed round-trip time.
func (r *rttMonitor) EWMA() time.Duration {
r.mu.RLock()
defer r.mu.RUnlock()
return r.averageRTT
}
// Min returns the minimum observed round-trip time over the window period.
func (r *rttMonitor) Min() time.Duration {
r.mu.RLock()
defer r.mu.RUnlock()
return r.minRTT
}
// P90 returns the 90th percentile observed round-trip time over the window period.
func (r *rttMonitor) P90() time.Duration {
r.mu.RLock()
defer r.mu.RUnlock()
return r.rtt90
}
// Stats returns stringified stats of the current state of the monitor.
func (r *rttMonitor) Stats() string {
r.mu.RLock()
defer r.mu.RUnlock()
// Calculate standard deviation and average (non-EWMA) of samples.
var sum float64
floatSamples := make([]float64, 0, len(r.samples))
for _, sample := range r.samples {
if sample > 0 {
floatSamples = append(floatSamples, float64(sample))
sum += float64(sample)
}
}
var avg, stdDev float64
if len(floatSamples) > 0 {
avg = sum / float64(len(floatSamples))
var err error
stdDev, err = stats.StandardDeviation(floatSamples)
if err != nil {
panic(fmt.Errorf("x/mongo/driver/topology: error calculating standard deviation RTT: %v for samples:\n%v", err, floatSamples))
}
}
return fmt.Sprintf(`Round-trip-time monitor statistics:`+"\n"+
`average RTT: %v, minimum RTT: %v, 90th percentile RTT: %v, standard dev: %v`+"\n",
time.Duration(avg), r.minRTT, r.rtt90, time.Duration(stdDev))
}

View File

@@ -0,0 +1,957 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package topology
import (
"context"
"errors"
"fmt"
"net"
"sync"
"sync/atomic"
"time"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/event"
"go.mongodb.org/mongo-driver/mongo/address"
"go.mongodb.org/mongo-driver/mongo/description"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
)
const minHeartbeatInterval = 500 * time.Millisecond
// Server state constants.
const (
serverDisconnected int64 = iota
serverDisconnecting
serverConnected
)
func serverStateString(state int64) string {
switch state {
case serverDisconnected:
return "Disconnected"
case serverDisconnecting:
return "Disconnecting"
case serverConnected:
return "Connected"
}
return ""
}
var (
// ErrServerClosed occurs when an attempt to Get a connection is made after
// the server has been closed.
ErrServerClosed = errors.New("server is closed")
// ErrServerConnected occurs when an attempt to Connect is made after a server
// has already been connected.
ErrServerConnected = errors.New("server is connected")
errCheckCancelled = errors.New("server check cancelled")
emptyDescription = description.NewDefaultServer("")
)
// SelectedServer represents a specific server that was selected during server selection.
// It contains the kind of the topology it was selected from.
type SelectedServer struct {
*Server
Kind description.TopologyKind
}
// Description returns a description of the server as of the last heartbeat.
func (ss *SelectedServer) Description() description.SelectedServer {
sdesc := ss.Server.Description()
return description.SelectedServer{
Server: sdesc,
Kind: ss.Kind,
}
}
// Server is a single server within a topology.
type Server struct {
// The following integer fields must be accessed using the atomic package and should be at the
// beginning of the struct.
// - atomic bug: https://pkg.go.dev/sync/atomic#pkg-note-BUG
// - suggested layout: https://go101.org/article/memory-layout.html
state int64
operationCount int64
cfg *serverConfig
address address.Address
// connection related fields
pool *pool
// goroutine management fields
done chan struct{}
checkNow chan struct{}
disconnecting chan struct{}
closewg sync.WaitGroup
// description related fields
desc atomic.Value // holds a description.Server
updateTopologyCallback atomic.Value
topologyID primitive.ObjectID
// subscriber related fields
subLock sync.Mutex
subscribers map[uint64]chan description.Server
currentSubscriberID uint64
subscriptionsClosed bool
// heartbeat and cancellation related fields
// globalCtx should be created in NewServer and cancelled in Disconnect to signal that the server is shutting down.
// heartbeatCtx should be used for individual heartbeats and should be a child of globalCtx so that it will be
// cancelled automatically during shutdown.
heartbeatLock sync.Mutex
conn *connection
globalCtx context.Context
globalCtxCancel context.CancelFunc
heartbeatCtx context.Context
heartbeatCtxCancel context.CancelFunc
processErrorLock sync.Mutex
rttMonitor *rttMonitor
}
// updateTopologyCallback is a callback provided when creating a server and called when the parent Topology instance
// should be updated based on a new server description. The callback must return the server description that should be
// stored by the server.
type updateTopologyCallback func(description.Server) description.Server
// ConnectServer creates a new Server and then initializes it using the
// Connect method.
func ConnectServer(addr address.Address, updateCallback updateTopologyCallback, topologyID primitive.ObjectID, opts ...ServerOption) (*Server, error) {
srvr := NewServer(addr, topologyID, opts...)
err := srvr.Connect(updateCallback)
if err != nil {
return nil, err
}
return srvr, nil
}
// NewServer creates a new server. The MongoDB server at the address will be monitored
// on an internal monitoring goroutine.
func NewServer(addr address.Address, topologyID primitive.ObjectID, opts ...ServerOption) *Server {
cfg := newServerConfig(opts...)
globalCtx, globalCtxCancel := context.WithCancel(context.Background())
s := &Server{
state: serverDisconnected,
cfg: cfg,
address: addr,
done: make(chan struct{}),
checkNow: make(chan struct{}, 1),
disconnecting: make(chan struct{}),
topologyID: topologyID,
subscribers: make(map[uint64]chan description.Server),
globalCtx: globalCtx,
globalCtxCancel: globalCtxCancel,
}
s.desc.Store(description.NewDefaultServer(addr))
rttCfg := &rttConfig{
interval: cfg.heartbeatInterval,
minRTTWindow: 5 * time.Minute,
createConnectionFn: s.createConnection,
createOperationFn: s.createBaseOperation,
}
s.rttMonitor = newRTTMonitor(rttCfg)
pc := poolConfig{
Address: addr,
MinPoolSize: cfg.minConns,
MaxPoolSize: cfg.maxConns,
MaxConnecting: cfg.maxConnecting,
MaxIdleTime: cfg.poolMaxIdleTime,
MaintainInterval: cfg.poolMaintainInterval,
PoolMonitor: cfg.poolMonitor,
handshakeErrFn: s.ProcessHandshakeError,
}
connectionOpts := copyConnectionOpts(cfg.connectionOpts)
s.pool = newPool(pc, connectionOpts...)
s.publishServerOpeningEvent(s.address)
return s
}
// Connect initializes the Server by starting background monitoring goroutines.
// This method must be called before a Server can be used.
func (s *Server) Connect(updateCallback updateTopologyCallback) error {
if !atomic.CompareAndSwapInt64(&s.state, serverDisconnected, serverConnected) {
return ErrServerConnected
}
desc := description.NewDefaultServer(s.address)
if s.cfg.loadBalanced {
// LBs automatically start off with kind LoadBalancer because there is no monitoring routine for state changes.
desc.Kind = description.LoadBalancer
}
s.desc.Store(desc)
s.updateTopologyCallback.Store(updateCallback)
if !s.cfg.monitoringDisabled && !s.cfg.loadBalanced {
s.rttMonitor.connect()
s.closewg.Add(1)
go s.update()
}
// The CMAP spec describes that pools should only be marked "ready" when the server description
// is updated to something other than "Unknown". However, we maintain the previous Server
// behavior here and immediately mark the pool as ready during Connect() to simplify and speed
// up the Client startup behavior. The risk of marking a pool as ready proactively during
// Connect() is that we could attempt to create connections to a server that was configured
// erroneously until the first server check or checkOut() failure occurs, when the SDAM error
// handler would transition the Server back to "Unknown" and set the pool to "paused".
return s.pool.ready()
}
// Disconnect closes sockets to the server referenced by this Server.
// Subscriptions to this Server will be closed. Disconnect will shut down
// any monitoring goroutines, close the idle connection pool, and will
// wait until all the in-use connections have been returned to the connection
// pool and are closed before returning. If the context expires via
// cancellation, deadline, or timeout before the in-use connections have been
// returned, the in-use connections will be closed, resulting in the failure of
// any in-flight read or write operations. If this method returns with no
// errors, all connections associated with this Server have been closed.
func (s *Server) Disconnect(ctx context.Context) error {
if !atomic.CompareAndSwapInt64(&s.state, serverConnected, serverDisconnecting) {
return ErrServerClosed
}
s.updateTopologyCallback.Store((updateTopologyCallback)(nil))
// Cancel the global context so any new contexts created from it will be automatically cancelled. Close the done
// channel so the update() routine will know that it can stop. Cancel any in-progress monitoring checks at the end.
// The done channel is closed before cancelling the check so the update() routine will immediately detect that it
// can stop rather than trying to create new connections until the read from done succeeds.
s.globalCtxCancel()
close(s.done)
s.cancelCheck()
s.rttMonitor.disconnect()
s.pool.close(ctx)
s.closewg.Wait()
atomic.StoreInt64(&s.state, serverDisconnected)
return nil
}
// Connection gets a connection to the server.
func (s *Server) Connection(ctx context.Context) (driver.Connection, error) {
if atomic.LoadInt64(&s.state) != serverConnected {
return nil, ErrServerClosed
}
// Increment the operation count before calling checkOut to make sure that all connection
// requests are included in the operation count, including those in the wait queue. If we got an
// error instead of a connection, immediately decrement the operation count.
atomic.AddInt64(&s.operationCount, 1)
conn, err := s.pool.checkOut(ctx)
if err != nil {
atomic.AddInt64(&s.operationCount, -1)
return nil, err
}
return &Connection{
connection: conn,
cleanupServerFn: func() {
// Decrement the operation count whenever the caller is done with the connection. Note
// that cleanupServerFn() is not called while the connection is pinned to a cursor or
// transaction, so the operation count is not decremented until the cursor is closed or
// the transaction is committed or aborted. Use an int64 instead of a uint64 to mitigate
// the impact of any possible bugs that could cause the uint64 to underflow, which would
// make the server much less selectable.
atomic.AddInt64(&s.operationCount, -1)
},
}, nil
}
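// Example (editor's sketch, not part of the driver): checking out a connection
// increments the server's operation count; closing the returned Connection
// runs cleanupServerFn, which decrements it again.
func exampleCheckOut(ctx context.Context, s *Server) error {
	conn, err := s.Connection(ctx)
	if err != nil {
		return err
	}
	// ... run wire-protocol operations on conn ...
	return conn.Close() // decrements the operation count via cleanupServerFn
}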
// ProcessHandshakeError implements SDAM error handling for errors that occur before a connection
// finishes handshaking.
func (s *Server) ProcessHandshakeError(err error, startingGenerationNumber uint64, serviceID *primitive.ObjectID) {
// Ignore the error if the server is behind a load balancer but the service ID is unknown. This indicates that the
// error happened when dialing the connection or during the MongoDB handshake, so we don't know the service ID to
// use for clearing the pool.
if err == nil || s.cfg.loadBalanced && serviceID == nil {
return
}
// Ignore the error if the connection is stale.
if startingGenerationNumber < s.pool.generation.getGeneration(serviceID) {
return
}
wrappedConnErr := unwrapConnectionError(err)
if wrappedConnErr == nil {
return
}
// Must hold the processErrorLock while updating the server description and clearing the pool.
// Not holding the lock leads to possible out-of-order processing of pool.clear() and
// pool.ready() calls from concurrent server description updates.
s.processErrorLock.Lock()
defer s.processErrorLock.Unlock()
// Since the only kind of ConnectionError we receive from pool.checkOut will be an initialization error, we should set
// the description.Server appropriately. The description should not have a TopologyVersion because the staleness
// checking logic above has already determined that this description is not stale.
s.updateDescription(description.NewServerFromError(s.address, wrappedConnErr, nil))
s.pool.clear(err, serviceID)
s.cancelCheck()
}
// Description returns a description of the server as of the last heartbeat.
func (s *Server) Description() description.Server {
return s.desc.Load().(description.Server)
}
// SelectedDescription returns a description.SelectedServer with a Kind of
// Single. This can be used when performing tasks like monitoring a batch
// of servers and you want to run one off commands against those servers.
func (s *Server) SelectedDescription() description.SelectedServer {
sdesc := s.Description()
return description.SelectedServer{
Server: sdesc,
Kind: description.Single,
}
}
// Subscribe returns a ServerSubscription which has a channel on which all
// updated server descriptions will be sent. The channel will have a buffer
// size of one, and will be pre-populated with the current description.
func (s *Server) Subscribe() (*ServerSubscription, error) {
if atomic.LoadInt64(&s.state) != serverConnected {
return nil, ErrSubscribeAfterClosed
}
ch := make(chan description.Server, 1)
ch <- s.desc.Load().(description.Server)
s.subLock.Lock()
defer s.subLock.Unlock()
if s.subscriptionsClosed {
return nil, ErrSubscribeAfterClosed
}
id := s.currentSubscriberID
s.subscribers[id] = ch
s.currentSubscriberID++
ss := &ServerSubscription{
C: ch,
s: s,
id: id,
}
return ss, nil
}
// RequestImmediateCheck will cause the server to send a heartbeat immediately
// instead of waiting for the heartbeat timeout.
func (s *Server) RequestImmediateCheck() {
select {
case s.checkNow <- struct{}{}:
default:
}
}
// getWriteConcernErrorForProcessing extracts a driver.WriteConcernError from the provided error. This function returns
// (*driver.WriteConcernError, true) if the error is a WriteConcernError that falls under the requirements for SDAM
// error handling and (nil, false) otherwise.
func getWriteConcernErrorForProcessing(err error) (*driver.WriteConcernError, bool) {
writeCmdErr, ok := err.(driver.WriteCommandError)
if !ok {
return nil, false
}
wcerr := writeCmdErr.WriteConcernError
if wcerr != nil && (wcerr.NodeIsRecovering() || wcerr.NotPrimary()) {
return wcerr, true
}
return nil, false
}
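// Example (editor's sketch, not part of the driver; the error code is an
// assumption): only write concern errors whose codes indicate "not primary"
// or "node is recovering" qualify for SDAM handling.
func exampleWriteConcernError() bool {
	err := driver.WriteCommandError{
		WriteConcernError: &driver.WriteConcernError{Code: 10107}, // NotWritablePrimary (assumed)
	}
	_, ok := getWriteConcernErrorForProcessing(err)
	return ok // true for a "not primary" code, false otherwise
}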
// ProcessError handles SDAM error handling and implements driver.ErrorProcessor.
func (s *Server) ProcessError(err error, conn driver.Connection) driver.ProcessErrorResult {
// ignore nil error
if err == nil {
return driver.NoChange
}
// Must hold the processErrorLock while updating the server description and clearing the pool.
// Not holding the lock leads to possible out-of-order processing of pool.clear() and
// pool.ready() calls from concurrent server description updates.
s.processErrorLock.Lock()
defer s.processErrorLock.Unlock()
// ignore stale error
if conn.Stale() {
return driver.NoChange
}
// Invalidate the server description if a "not primary" or "node is recovering" error occurs.
// These errors can be reported as a command error or a write concern error.
desc := conn.Description()
if cerr, ok := err.(driver.Error); ok && (cerr.NodeIsRecovering() || cerr.NotPrimary()) {
// ignore stale error
if desc.TopologyVersion.CompareToIncoming(cerr.TopologyVersion) >= 0 {
return driver.NoChange
}
// updates description to unknown
s.updateDescription(description.NewServerFromError(s.address, err, cerr.TopologyVersion))
s.RequestImmediateCheck()
res := driver.ServerMarkedUnknown
// If the node is shutting down or is older than 4.2, we synchronously clear the pool
if cerr.NodeIsShuttingDown() || desc.WireVersion == nil || desc.WireVersion.Max < 8 {
res = driver.ConnectionPoolCleared
s.pool.clear(err, desc.ServiceID)
}
return res
}
if wcerr, ok := getWriteConcernErrorForProcessing(err); ok {
// ignore stale error
if desc.TopologyVersion.CompareToIncoming(wcerr.TopologyVersion) >= 0 {
return driver.NoChange
}
// updates description to unknown
s.updateDescription(description.NewServerFromError(s.address, err, wcerr.TopologyVersion))
s.RequestImmediateCheck()
res := driver.ServerMarkedUnknown
// If the node is shutting down or is older than 4.2, we synchronously clear the pool
if wcerr.NodeIsShuttingDown() || desc.WireVersion == nil || desc.WireVersion.Max < 8 {
res = driver.ConnectionPoolCleared
s.pool.clear(err, desc.ServiceID)
}
return res
}
wrappedConnErr := unwrapConnectionError(err)
if wrappedConnErr == nil {
return driver.NoChange
}
// Ignore transient timeout errors.
if netErr, ok := wrappedConnErr.(net.Error); ok && netErr.Timeout() {
return driver.NoChange
}
if wrappedConnErr == context.Canceled || wrappedConnErr == context.DeadlineExceeded {
return driver.NoChange
}
// For a non-timeout network error, we clear the pool, set the description to Unknown, and cancel the in-progress
// monitoring check. The check is cancelled last to avoid a post-cancellation reconnect racing with
// updateDescription.
s.updateDescription(description.NewServerFromError(s.address, err, nil))
s.pool.clear(err, desc.ServiceID)
s.cancelCheck()
return driver.ConnectionPoolCleared
}
// update handles performing heartbeats and updating any subscribers of the
// newest description.Server retrieved.
func (s *Server) update() {
defer s.closewg.Done()
heartbeatTicker := time.NewTicker(s.cfg.heartbeatInterval)
rateLimiter := time.NewTicker(minHeartbeatInterval)
defer heartbeatTicker.Stop()
defer rateLimiter.Stop()
checkNow := s.checkNow
done := s.done
var doneOnce bool
defer func() {
if r := recover(); r != nil {
if doneOnce {
return
}
// We keep this goroutine alive attempting to read from the done channel.
<-done
}
}()
closeServer := func() {
doneOnce = true
s.subLock.Lock()
for id, c := range s.subscribers {
close(c)
delete(s.subscribers, id)
}
s.subscriptionsClosed = true
s.subLock.Unlock()
// We don't need to take s.heartbeatLock here because closeServer is called synchronously when the select checks
// below detect that the server is being closed, so we can be sure that the connection isn't being used.
if s.conn != nil {
_ = s.conn.close()
}
}
waitUntilNextCheck := func() {
// Wait until the heartbeat interval elapses, an application operation requests an immediate check, or the server
// is disconnecting.
select {
case <-heartbeatTicker.C:
case <-checkNow:
case <-done:
// Return because the next update iteration will check the done channel again and clean up.
return
}
// Ensure we only return if minHeartbeatInterval has elapsed or the server is disconnecting.
select {
case <-rateLimiter.C:
case <-done:
return
}
}
for {
// Check if the server is disconnecting. Even if waitUntilNextCheck has already read from the done channel, we
// can safely read from it again because Disconnect closes the channel.
select {
case <-done:
closeServer()
return
default:
}
previousDescription := s.Description()
// Perform the next check.
desc, err := s.check()
if err == errCheckCancelled {
if atomic.LoadInt64(&s.state) != serverConnected {
continue
}
// If the server is not disconnecting, the check was cancelled by an application operation after an error.
// Wait before running the next check.
waitUntilNextCheck()
continue
}
// Must hold the processErrorLock while updating the server description and clearing the
// pool. Not holding the lock leads to possible out-of-order processing of pool.clear() and
// pool.ready() calls from concurrent server description updates.
s.processErrorLock.Lock()
s.updateDescription(desc)
if err := desc.LastError; err != nil {
// Clear the pool once the description has been updated to Unknown. Pass in a nil service ID to clear
// because the monitoring routine only runs for non-load balanced deployments in which servers don't return
// IDs.
s.pool.clear(err, nil)
}
s.processErrorLock.Unlock()
// If the server supports streaming or we're already streaming, we want to move to streaming the next response
// without waiting. If the server has transitioned to Unknown from a network error, we want to do another
// check without waiting in case it was a transient error and the server isn't actually down.
serverSupportsStreaming := desc.Kind != description.Unknown && desc.TopologyVersion != nil
connectionIsStreaming := s.conn != nil && s.conn.getCurrentlyStreaming()
transitionedFromNetworkError := desc.LastError != nil && unwrapConnectionError(desc.LastError) != nil &&
previousDescription.Kind != description.Unknown
if serverSupportsStreaming || connectionIsStreaming || transitionedFromNetworkError {
continue
}
// The server either does not support the streamable protocol or is not in a healthy state, so we wait until
// the next check.
waitUntilNextCheck()
}
}
// updateDescription handles updating the description on the Server, notifying
// subscribers, and marking the connection pool as ready when the server
// becomes selectable.
func (s *Server) updateDescription(desc description.Server) {
if s.cfg.loadBalanced {
// In load balanced mode, there are no updates from the monitoring routine. For errors encountered in pooled
// connections, the server should not be marked Unknown to ensure that the LB remains selectable.
return
}
defer func() {
// Ignore any panic from this method racing with server shutdown (e.g. a send on a closed subscriber channel).
_ = recover()
}()
// Anytime we update the server description to something other than "unknown", set the pool to
// "ready". Do this before updating the description so that connections can be checked out as
// soon as the server is selectable. If the pool is already ready, this operation is a no-op.
// Note that this behavior is roughly consistent with the current Go driver behavior (connects
// to all servers, even non-data-bearing nodes) but deviates slightly from CMAP spec, which
// specifies a more restricted set of server descriptions and topologies that should mark the
// pool ready. We don't have access to the topology here, so prefer the current Go driver
// behavior for simplicity.
if desc.Kind != description.Unknown {
_ = s.pool.ready()
}
// Use the updateTopologyCallback to update the parent Topology and get the description that should be stored.
callback, ok := s.updateTopologyCallback.Load().(updateTopologyCallback)
if ok && callback != nil {
desc = callback(desc)
}
s.desc.Store(desc)
s.subLock.Lock()
for _, c := range s.subscribers {
select {
// drain the channel if it isn't empty
case <-c:
default:
}
c <- desc
}
s.subLock.Unlock()
}
// createConnection creates a new connection instance but does not call connect on it. The caller must call connect
// before the connection can be used for network operations.
func (s *Server) createConnection() *connection {
opts := copyConnectionOpts(s.cfg.connectionOpts)
opts = append(opts,
WithConnectTimeout(func(time.Duration) time.Duration { return s.cfg.heartbeatTimeout }),
WithReadTimeout(func(time.Duration) time.Duration { return s.cfg.heartbeatTimeout }),
WithWriteTimeout(func(time.Duration) time.Duration { return s.cfg.heartbeatTimeout }),
// We override whatever handshaker is currently attached to the options with a basic
// one because we need to make sure we don't do auth.
WithHandshaker(func(h Handshaker) Handshaker {
return operation.NewHello().AppName(s.cfg.appname).Compressors(s.cfg.compressionOpts).
ServerAPI(s.cfg.serverAPI)
}),
// Override any monitors specified in options with nil to avoid monitoring heartbeats.
WithMonitor(func(*event.CommandMonitor) *event.CommandMonitor { return nil }),
)
return newConnection(s.address, opts...)
}
func copyConnectionOpts(opts []ConnectionOption) []ConnectionOption {
optsCopy := make([]ConnectionOption, len(opts))
copy(optsCopy, opts)
return optsCopy
}
func (s *Server) setupHeartbeatConnection() error {
conn := s.createConnection()
// Take the lock when assigning the context and connection because they're accessed by cancelCheck.
s.heartbeatLock.Lock()
if s.heartbeatCtxCancel != nil {
// Ensure the previous context is cancelled to avoid a leak.
s.heartbeatCtxCancel()
}
s.heartbeatCtx, s.heartbeatCtxCancel = context.WithCancel(s.globalCtx)
s.conn = conn
s.heartbeatLock.Unlock()
return s.conn.connect(s.heartbeatCtx)
}
// cancelCheck cancels in-progress connection dials and reads. It does not set any fields on the server.
func (s *Server) cancelCheck() {
var conn *connection
// Take heartbeatLock for mutual exclusion with the checks in the update function.
s.heartbeatLock.Lock()
if s.heartbeatCtx != nil {
s.heartbeatCtxCancel()
}
conn = s.conn
s.heartbeatLock.Unlock()
if conn == nil {
return
}
// If the connection exists, we need to wait for it to be connected because conn.connect() and
// conn.close() cannot be called concurrently. If the connection wasn't successfully opened, its
// state was set back to disconnected, so calling conn.close() will be a no-op.
conn.closeConnectContext()
conn.wait()
_ = conn.close()
}
func (s *Server) checkWasCancelled() bool {
return s.heartbeatCtx.Err() != nil
}
func (s *Server) createBaseOperation(conn driver.Connection) *operation.Hello {
return operation.
NewHello().
ClusterClock(s.cfg.clock).
Deployment(driver.SingleConnectionDeployment{conn}).
ServerAPI(s.cfg.serverAPI)
}
func (s *Server) check() (description.Server, error) {
var descPtr *description.Server
var err error
var durationNanos int64
// Create a new connection if this is the first check, the connection was closed after an error during the previous
// check, or the previous check was cancelled.
if s.conn == nil || s.conn.closed() || s.checkWasCancelled() {
// Create a new connection and add its handshake RTT as a sample.
err = s.setupHeartbeatConnection()
if err == nil {
// Use the description from the connection handshake as the value for this check.
s.rttMonitor.addSample(s.conn.helloRTT)
descPtr = &s.conn.desc
}
}
if descPtr == nil && err == nil {
// An existing connection is being used. Use the server description properties to execute the right heartbeat.
// Wrap conn in a type that implements driver.StreamerConnection.
heartbeatConn := initConnection{s.conn}
baseOperation := s.createBaseOperation(heartbeatConn)
previousDescription := s.Description()
streamable := previousDescription.TopologyVersion != nil
s.publishServerHeartbeatStartedEvent(s.conn.ID(), s.conn.getCurrentlyStreaming() || streamable)
start := time.Now()
switch {
case s.conn.getCurrentlyStreaming():
// The connection is already in a streaming state, so we stream the next response.
err = baseOperation.StreamResponse(s.heartbeatCtx, heartbeatConn)
case streamable:
// The server supports the streamable protocol. Set the socket timeout to
// connectTimeoutMS+heartbeatFrequencyMS and execute an awaitable hello request. Set conn.canStream so
// the wire message will advertise streaming support to the server.
// Calculation for maxAwaitTimeMS is taken from time.Duration.Milliseconds (added in Go 1.13).
maxAwaitTimeMS := int64(s.cfg.heartbeatInterval) / 1e6
// If connectTimeoutMS=0, the socket timeout should be infinite. Otherwise, it is connectTimeoutMS +
// heartbeatFrequencyMS to account for the fact that the query will block for heartbeatFrequencyMS
// server-side.
socketTimeout := s.cfg.heartbeatTimeout
if socketTimeout != 0 {
socketTimeout += s.cfg.heartbeatInterval
}
s.conn.setSocketTimeout(socketTimeout)
baseOperation = baseOperation.TopologyVersion(previousDescription.TopologyVersion).
MaxAwaitTimeMS(maxAwaitTimeMS)
s.conn.setCanStream(true)
err = baseOperation.Execute(s.heartbeatCtx)
default:
// The server doesn't support the awaitable protocol. Set the socket timeout to connectTimeoutMS and
// execute a regular heartbeat without any additional parameters.
s.conn.setSocketTimeout(s.cfg.heartbeatTimeout)
err = baseOperation.Execute(s.heartbeatCtx)
}
durationNanos = time.Since(start).Nanoseconds()
if err == nil {
tempDesc := baseOperation.Result(s.address)
descPtr = &tempDesc
s.publishServerHeartbeatSucceededEvent(s.conn.ID(), durationNanos, tempDesc, s.conn.getCurrentlyStreaming() || streamable)
} else {
// Close the connection here rather than below so we ensure we're not closing a connection that wasn't
// successfully created.
if s.conn != nil {
_ = s.conn.close()
}
s.publishServerHeartbeatFailedEvent(s.conn.ID(), durationNanos, err, s.conn.getCurrentlyStreaming() || streamable)
}
}
if descPtr != nil {
// The check was successful. Set the average RTT and the 90th percentile RTT and return.
desc := *descPtr
desc = desc.SetAverageRTT(s.rttMonitor.EWMA())
desc.HeartbeatInterval = s.cfg.heartbeatInterval
return desc, nil
}
if s.checkWasCancelled() {
// If the previous check was cancelled, we don't want to clear the pool. Return a sentinel error so the caller
// will know that an actual error didn't occur.
return emptyDescription, errCheckCancelled
}
// An error occurred. We reset the RTT monitor for all errors and return an Unknown description. The pool must also
// be cleared, but only after the description has already been updated, so that is handled by the caller.
topologyVersion := extractTopologyVersion(err)
s.rttMonitor.reset()
return description.NewServerFromError(s.address, err, topologyVersion), nil
}
func extractTopologyVersion(err error) *description.TopologyVersion {
if ce, ok := err.(ConnectionError); ok {
err = ce.Wrapped
}
switch converted := err.(type) {
case driver.Error:
return converted.TopologyVersion
case driver.WriteCommandError:
if converted.WriteConcernError != nil {
return converted.WriteConcernError.TopologyVersion
}
}
return nil
}
// RTTMonitor returns this server's round-trip-time monitor.
func (s *Server) RTTMonitor() driver.RTTMonitor {
return s.rttMonitor
}
// OperationCount returns the current number of in-progress operations for this server.
func (s *Server) OperationCount() int64 {
return atomic.LoadInt64(&s.operationCount)
}
// String implements the Stringer interface.
func (s *Server) String() string {
desc := s.Description()
state := atomic.LoadInt64(&s.state)
str := fmt.Sprintf("Addr: %s, Type: %s, State: %s",
s.address, desc.Kind, serverStateString(state))
if len(desc.Tags) != 0 {
str += fmt.Sprintf(", Tag sets: %s", desc.Tags)
}
if state == serverConnected {
str += fmt.Sprintf(", Average RTT: %s, Min RTT: %s", desc.AverageRTT, s.RTTMonitor().Min())
}
if desc.LastError != nil {
str += fmt.Sprintf(", Last error: %s", desc.LastError)
}
return str
}
// ServerSubscription represents a subscription to the description.Server updates for
// a specific server.
type ServerSubscription struct {
C <-chan description.Server
s *Server
id uint64
}
// Unsubscribe unsubscribes this ServerSubscription from updates and closes the
// subscription channel.
func (ss *ServerSubscription) Unsubscribe() error {
ss.s.subLock.Lock()
defer ss.s.subLock.Unlock()
if ss.s.subscriptionsClosed {
return nil
}
ch, ok := ss.s.subscribers[ss.id]
if !ok {
return nil
}
close(ch)
delete(ss.s.subscribers, ss.id)
return nil
}
// publishes a ServerOpeningEvent to indicate the server is being initialized
func (s *Server) publishServerOpeningEvent(addr address.Address) {
if s == nil {
return
}
serverOpening := &event.ServerOpeningEvent{
Address: addr,
TopologyID: s.topologyID,
}
if s.cfg.serverMonitor != nil && s.cfg.serverMonitor.ServerOpening != nil {
s.cfg.serverMonitor.ServerOpening(serverOpening)
}
}
// publishes a ServerHeartbeatStartedEvent to indicate a hello command has started
func (s *Server) publishServerHeartbeatStartedEvent(connectionID string, await bool) {
serverHeartbeatStarted := &event.ServerHeartbeatStartedEvent{
ConnectionID: connectionID,
Awaited: await,
}
if s != nil && s.cfg.serverMonitor != nil && s.cfg.serverMonitor.ServerHeartbeatStarted != nil {
s.cfg.serverMonitor.ServerHeartbeatStarted(serverHeartbeatStarted)
}
}
// publishes a ServerHeartbeatSucceededEvent to indicate hello has succeeded
func (s *Server) publishServerHeartbeatSucceededEvent(connectionID string,
durationNanos int64,
desc description.Server,
await bool) {
serverHeartbeatSucceeded := &event.ServerHeartbeatSucceededEvent{
DurationNanos: durationNanos,
Reply: desc,
ConnectionID: connectionID,
Awaited: await,
}
if s != nil && s.cfg.serverMonitor != nil && s.cfg.serverMonitor.ServerHeartbeatSucceeded != nil {
s.cfg.serverMonitor.ServerHeartbeatSucceeded(serverHeartbeatSucceeded)
}
}
// publishes a ServerHeartbeatFailedEvent to indicate hello has failed
func (s *Server) publishServerHeartbeatFailedEvent(connectionID string,
durationNanos int64,
err error,
await bool) {
serverHeartbeatFailed := &event.ServerHeartbeatFailedEvent{
DurationNanos: durationNanos,
Failure: err,
ConnectionID: connectionID,
Awaited: await,
}
if s != nil && s.cfg.serverMonitor != nil && s.cfg.serverMonitor.ServerHeartbeatFailed != nil {
s.cfg.serverMonitor.ServerHeartbeatFailed(serverHeartbeatFailed)
}
}
// unwrapConnectionError returns the connection error wrapped by err, or nil if err does not wrap a connection error.
func unwrapConnectionError(err error) error {
// This is essentially an implementation of errors.As to unwrap this error until we get a ConnectionError and then
// return ConnectionError.Wrapped.
connErr, ok := err.(ConnectionError)
if ok {
return connErr.Wrapped
}
driverErr, ok := err.(driver.Error)
if !ok || !driverErr.NetworkError() {
return nil
}
connErr, ok = driverErr.Wrapped.(ConnectionError)
if ok {
return connErr.Wrapped
}
return nil
}
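// Example (editor's sketch, not part of the driver): unwrapConnectionError
// yields the wrapped error for a ConnectionError and nil for unrelated errors.
func exampleUnwrap() {
	inner := errors.New("connection reset by peer")
	_ = unwrapConnectionError(ConnectionError{Wrapped: inner}) // == inner
	_ = unwrapConnectionError(errors.New("unrelated"))         // == nil
}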

View File

@@ -0,0 +1,195 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package topology
import (
"time"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/bsoncodec"
"go.mongodb.org/mongo-driver/event"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
)
var defaultRegistry = bson.NewRegistryBuilder().Build()
type serverConfig struct {
clock *session.ClusterClock
compressionOpts []string
connectionOpts []ConnectionOption
appname string
heartbeatInterval time.Duration
heartbeatTimeout time.Duration
serverMonitor *event.ServerMonitor
registry *bsoncodec.Registry
monitoringDisabled bool
serverAPI *driver.ServerAPIOptions
loadBalanced bool
// Connection pool options.
maxConns uint64
minConns uint64
maxConnecting uint64
poolMonitor *event.PoolMonitor
poolMaxIdleTime time.Duration
poolMaintainInterval time.Duration
}
func newServerConfig(opts ...ServerOption) *serverConfig {
cfg := &serverConfig{
heartbeatInterval: 10 * time.Second,
heartbeatTimeout: 10 * time.Second,
registry: defaultRegistry,
}
for _, opt := range opts {
if opt == nil {
continue
}
opt(cfg)
}
return cfg
}
// ServerOption configures a server.
type ServerOption func(*serverConfig)
// ServerAPIFromServerOptions will return the server API options if they have been functionally set on the ServerOption
// slice.
func ServerAPIFromServerOptions(opts []ServerOption) *driver.ServerAPIOptions {
return newServerConfig(opts...).serverAPI
}
func withMonitoringDisabled(fn func(bool) bool) ServerOption {
return func(cfg *serverConfig) {
cfg.monitoringDisabled = fn(cfg.monitoringDisabled)
}
}
// WithConnectionOptions configures the server's connections.
func WithConnectionOptions(fn func(...ConnectionOption) []ConnectionOption) ServerOption {
return func(cfg *serverConfig) {
cfg.connectionOpts = fn(cfg.connectionOpts...)
}
}
// WithCompressionOptions configures the server's compressors.
func WithCompressionOptions(fn func(...string) []string) ServerOption {
return func(cfg *serverConfig) {
cfg.compressionOpts = fn(cfg.compressionOpts...)
}
}
// WithServerAppName configures the server's application name.
func WithServerAppName(fn func(string) string) ServerOption {
return func(cfg *serverConfig) {
cfg.appname = fn(cfg.appname)
}
}
// WithHeartbeatInterval configures a server's heartbeat interval.
func WithHeartbeatInterval(fn func(time.Duration) time.Duration) ServerOption {
return func(cfg *serverConfig) {
cfg.heartbeatInterval = fn(cfg.heartbeatInterval)
}
}
// WithHeartbeatTimeout configures how long to wait for a heartbeat socket to
// connect.
func WithHeartbeatTimeout(fn func(time.Duration) time.Duration) ServerOption {
return func(cfg *serverConfig) {
cfg.heartbeatTimeout = fn(cfg.heartbeatTimeout)
}
}
// WithMaxConnections configures the maximum number of connections to allow for
// a given server. If max is 0, then maximum connection pool size is not limited.
func WithMaxConnections(fn func(uint64) uint64) ServerOption {
return func(cfg *serverConfig) {
cfg.maxConns = fn(cfg.maxConns)
}
}
// WithMinConnections configures the minimum number of connections to allow for
// a given server. If min is 0, then there is no lower limit to the number of
// connections.
func WithMinConnections(fn func(uint64) uint64) ServerOption {
return func(cfg *serverConfig) {
cfg.minConns = fn(cfg.minConns)
}
}
// WithMaxConnecting configures the maximum number of connections a connection
// pool may establish simultaneously. If maxConnecting is 0, the default value
// of 2 is used.
func WithMaxConnecting(fn func(uint64) uint64) ServerOption {
return func(cfg *serverConfig) {
cfg.maxConnecting = fn(cfg.maxConnecting)
}
}
// WithConnectionPoolMaxIdleTime configures the maximum time that a connection can remain idle in the connection pool
// before being removed. If connectionPoolMaxIdleTime is 0, then no idle time is set and connections will not be
// removed because of their age.
func WithConnectionPoolMaxIdleTime(fn func(time.Duration) time.Duration) ServerOption {
return func(cfg *serverConfig) {
cfg.poolMaxIdleTime = fn(cfg.poolMaxIdleTime)
}
}
// WithConnectionPoolMaintainInterval configures the interval that the background connection pool
// maintenance goroutine runs.
func WithConnectionPoolMaintainInterval(fn func(time.Duration) time.Duration) ServerOption {
return func(cfg *serverConfig) {
cfg.poolMaintainInterval = fn(cfg.poolMaintainInterval)
}
}
// WithConnectionPoolMonitor configures the monitor for all connection pool actions
func WithConnectionPoolMonitor(fn func(*event.PoolMonitor) *event.PoolMonitor) ServerOption {
return func(cfg *serverConfig) {
cfg.poolMonitor = fn(cfg.poolMonitor)
}
}
// WithServerMonitor configures the monitor for all SDAM events for a server
func WithServerMonitor(fn func(*event.ServerMonitor) *event.ServerMonitor) ServerOption {
return func(cfg *serverConfig) {
cfg.serverMonitor = fn(cfg.serverMonitor)
}
}
// WithClock configures the ClusterClock for the server to use.
func WithClock(fn func(clock *session.ClusterClock) *session.ClusterClock) ServerOption {
return func(cfg *serverConfig) {
cfg.clock = fn(cfg.clock)
}
}
// WithRegistry configures the registry for the server to use when creating
// cursors.
func WithRegistry(fn func(*bsoncodec.Registry) *bsoncodec.Registry) ServerOption {
return func(cfg *serverConfig) {
cfg.registry = fn(cfg.registry)
}
}
// WithServerAPI configures the server API options for the server to use.
func WithServerAPI(fn func(serverAPI *driver.ServerAPIOptions) *driver.ServerAPIOptions) ServerOption {
return func(cfg *serverConfig) {
cfg.serverAPI = fn(cfg.serverAPI)
}
}
// WithServerLoadBalanced specifies whether or not the server is behind a load balancer.
func WithServerLoadBalanced(fn func(bool) bool) ServerOption {
return func(cfg *serverConfig) {
cfg.loadBalanced = fn(cfg.loadBalanced)
}
}
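// Example (editor's sketch, not part of the driver): composing ServerOptions.
// Each setter receives the current value and returns the new one, so options
// can either override or transform earlier settings.
func exampleServerOptions() *serverConfig {
	return newServerConfig(
		WithHeartbeatInterval(func(time.Duration) time.Duration { return 30 * time.Second }),
		WithMaxConnections(func(uint64) uint64 { return 100 }),
	)
}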

View File

@@ -0,0 +1,58 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//go:build !go1.17
// +build !go1.17
package topology
import (
"context"
"crypto/tls"
"net"
)
type tlsConn interface {
net.Conn
// Only require Handshake on the interface for Go 1.16 and earlier.
Handshake() error
ConnectionState() tls.ConnectionState
}
var _ tlsConn = (*tls.Conn)(nil)
type tlsConnectionSource interface {
Client(net.Conn, *tls.Config) tlsConn
}
type tlsConnectionSourceFn func(net.Conn, *tls.Config) tlsConn
var _ tlsConnectionSource = (tlsConnectionSourceFn)(nil)
func (t tlsConnectionSourceFn) Client(nc net.Conn, cfg *tls.Config) tlsConn {
return t(nc, cfg)
}
var defaultTLSConnectionSource tlsConnectionSourceFn = func(nc net.Conn, cfg *tls.Config) tlsConn {
return tls.Client(nc, cfg)
}
// clientHandshake performs the handshake in a goroutine and waits for its completion on Go 1.16 and
// earlier, where HandshakeContext is not available.
func clientHandshake(ctx context.Context, client tlsConn) error {
errChan := make(chan error, 1)
go func() {
errChan <- client.Handshake()
}()
select {
case err := <-errChan:
return err
case <-ctx.Done():
return ctx.Err()
}
}

View File

@@ -0,0 +1,47 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//go:build go1.17
// +build go1.17
package topology
import (
"context"
"crypto/tls"
"net"
)
type tlsConn interface {
net.Conn
// Require HandshakeContext on the interface for Go 1.17 and higher.
HandshakeContext(ctx context.Context) error
ConnectionState() tls.ConnectionState
}
var _ tlsConn = (*tls.Conn)(nil)
type tlsConnectionSource interface {
Client(net.Conn, *tls.Config) tlsConn
}
type tlsConnectionSourceFn func(net.Conn, *tls.Config) tlsConn
var _ tlsConnectionSource = (tlsConnectionSourceFn)(nil)
func (t tlsConnectionSourceFn) Client(nc net.Conn, cfg *tls.Config) tlsConn {
return t(nc, cfg)
}
var defaultTLSConnectionSource tlsConnectionSourceFn = func(nc net.Conn, cfg *tls.Config) tlsConn {
return tls.Client(nc, cfg)
}
// clientHandshake will perform a handshake on Go 1.17 and higher with HandshakeContext.
func clientHandshake(ctx context.Context, client tlsConn) error {
return client.HandshakeContext(ctx)
}
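// Example (editor's sketch, not part of the driver): wrapping a plain net.Conn
// and performing a context-aware TLS handshake through the package's
// indirection points.
func exampleTLSClient(ctx context.Context, nc net.Conn, cfg *tls.Config) (tlsConn, error) {
	client := defaultTLSConnectionSource.Client(nc, cfg)
	if err := clientHandshake(ctx, client); err != nil {
		return nil, err
	}
	return client, nil
}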

View File

@@ -0,0 +1,851 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Package topology contains types that handle the discovery, monitoring, and selection
// of servers. This package is designed to expose enough inner workings of service discovery
// and monitoring to allow low-level applications to have fine-grained control, while hiding
// most of the detailed implementation of the algorithms.
package topology // import "go.mongodb.org/mongo-driver/x/mongo/driver/topology"
import (
"context"
"errors"
"fmt"
"net"
"net/url"
"strings"
"sync"
"sync/atomic"
"time"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/event"
"go.mongodb.org/mongo-driver/internal/randutil"
"go.mongodb.org/mongo-driver/mongo/address"
"go.mongodb.org/mongo-driver/mongo/description"
"go.mongodb.org/mongo-driver/mongo/options"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/dns"
)
// Topology state constants.
const (
topologyDisconnected int64 = iota
topologyDisconnecting
topologyConnected
topologyConnecting
)
// ErrSubscribeAfterClosed is returned when a user attempts to subscribe to a
// closed Server or Topology.
var ErrSubscribeAfterClosed = errors.New("cannot subscribe after closeConnection")
// ErrTopologyClosed is returned when a user attempts to call a method on a
// closed Topology.
var ErrTopologyClosed = errors.New("topology is closed")
// ErrTopologyConnected is returned when a user attempts to Connect to an
// already connected Topology.
var ErrTopologyConnected = errors.New("topology is connected or connecting")
// ErrServerSelectionTimeout is returned from server selection when the server
// selection process took longer than allowed by the timeout.
var ErrServerSelectionTimeout = errors.New("server selection timeout")
// MonitorMode represents the way in which a server is monitored.
type MonitorMode uint8
// random is a package-global pseudo-random number generator.
var random = randutil.NewLockedRand()
// These constants are the available monitoring modes.
const (
AutomaticMode MonitorMode = iota
SingleMode
)
// Topology represents a MongoDB deployment.
type Topology struct {
state int64
cfg *Config
desc atomic.Value // holds a description.Topology
dnsResolver *dns.Resolver
done chan struct{}
pollingRequired bool
pollingDone chan struct{}
pollingwg sync.WaitGroup
rescanSRVInterval time.Duration
pollHeartbeatTime atomic.Value // holds a bool
updateCallback updateTopologyCallback
fsm *fsm
// This should really be encapsulated into its own type. This will likely
// require a redesign so we can share a minimum of data between the
// subscribers and the topology.
subscribers map[uint64]chan description.Topology
currentSubscriberID uint64
subscriptionsClosed bool
subLock sync.Mutex
// We should redesign how we Connect and handle individual servers. This is
// too difficult to maintain and it's rather easy to accidentally access
// the servers without acquiring the lock or checking if the servers are
// closed. This lock should also be an RWMutex.
serversLock sync.Mutex
serversClosed bool
servers map[address.Address]*Server
id primitive.ObjectID
}
var _ driver.Deployment = &Topology{}
var _ driver.Subscriber = &Topology{}
type serverSelectionState struct {
selector description.ServerSelector
timeoutChan <-chan time.Time
}
func newServerSelectionState(selector description.ServerSelector, timeoutChan <-chan time.Time) serverSelectionState {
return serverSelectionState{
selector: selector,
timeoutChan: timeoutChan,
}
}
// New creates a new topology. A "nil" config is interpreted as the default configuration.
func New(cfg *Config) (*Topology, error) {
if cfg == nil {
var err error
cfg, err = NewConfig(options.Client(), nil)
if err != nil {
return nil, err
}
}
t := &Topology{
cfg: cfg,
done: make(chan struct{}),
pollingDone: make(chan struct{}),
rescanSRVInterval: 60 * time.Second,
fsm: newFSM(),
subscribers: make(map[uint64]chan description.Topology),
servers: make(map[address.Address]*Server),
dnsResolver: dns.DefaultResolver,
id: primitive.NewObjectID(),
}
t.desc.Store(description.Topology{})
t.updateCallback = func(desc description.Server) description.Server {
return t.apply(context.TODO(), desc)
}
if t.cfg.URI != "" {
t.pollingRequired = strings.HasPrefix(t.cfg.URI, "mongodb+srv://") && !t.cfg.LoadBalanced
}
t.publishTopologyOpeningEvent()
return t, nil
}
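// Example (editor's sketch, not part of the driver): constructing and
// connecting a Topology with the default configuration.
func exampleNewTopology() (*Topology, error) {
	t, err := New(nil) // nil is interpreted as the default configuration
	if err != nil {
		return nil, err
	}
	if err := t.Connect(); err != nil {
		return nil, err
	}
	return t, nil
}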
// Connect initializes a Topology and starts the monitoring process. This function
// must be called to properly monitor the topology.
func (t *Topology) Connect() error {
if !atomic.CompareAndSwapInt64(&t.state, topologyDisconnected, topologyConnecting) {
return ErrTopologyConnected
}
t.desc.Store(description.Topology{})
var err error
t.serversLock.Lock()
// A replica set name sets the initial topology type to ReplicaSetNoPrimary unless a direct connection is also
// specified, in which case the initial type is Single.
if t.cfg.ReplicaSetName != "" {
t.fsm.SetName = t.cfg.ReplicaSetName
t.fsm.Kind = description.ReplicaSetNoPrimary
}
// A direct connection unconditionally sets the topology type to Single.
if t.cfg.Mode == SingleMode {
t.fsm.Kind = description.Single
}
for _, a := range t.cfg.SeedList {
addr := address.Address(a).Canonicalize()
t.fsm.Servers = append(t.fsm.Servers, description.NewDefaultServer(addr))
}
switch {
case t.cfg.LoadBalanced:
// In LoadBalanced mode, we mock a series of events: TopologyDescriptionChanged from Unknown to LoadBalanced,
// ServerDescriptionChanged from Unknown to LoadBalancer, and then TopologyDescriptionChanged to reflect the
// previous ServerDescriptionChanged event. We publish all of these events here because we don't start server
// monitoring routines in this mode, so we have to mock state changes.
// Transition from Unknown with no servers to LoadBalanced with a single Unknown server.
t.fsm.Kind = description.LoadBalanced
t.publishTopologyDescriptionChangedEvent(description.Topology{}, t.fsm.Topology)
addr := address.Address(t.cfg.SeedList[0]).Canonicalize()
if err := t.addServer(addr); err != nil {
t.serversLock.Unlock()
return err
}
// Transition the server from Unknown to LoadBalancer.
newServerDesc := t.servers[addr].Description()
t.publishServerDescriptionChangedEvent(t.fsm.Servers[0], newServerDesc)
// Transition from LoadBalanced with an Unknown server to LoadBalanced with a LoadBalancer.
oldDesc := t.fsm.Topology
t.fsm.Servers = []description.Server{newServerDesc}
t.desc.Store(t.fsm.Topology)
t.publishTopologyDescriptionChangedEvent(oldDesc, t.fsm.Topology)
default:
// In non-LB mode, we only publish an initial TopologyDescriptionChanged event from Unknown with no servers to
// the current state (e.g. Unknown with one or more servers if we're discovering or Single with one server if
// we're connecting directly). Other events are published when state changes occur due to responses in the
// server monitoring goroutines.
newDesc := description.Topology{
Kind: t.fsm.Kind,
Servers: t.fsm.Servers,
SessionTimeoutMinutes: t.fsm.SessionTimeoutMinutes,
}
t.desc.Store(newDesc)
t.publishTopologyDescriptionChangedEvent(description.Topology{}, t.fsm.Topology)
for _, a := range t.cfg.SeedList {
addr := address.Address(a).Canonicalize()
err = t.addServer(addr)
if err != nil {
t.serversLock.Unlock()
return err
}
}
}
t.serversLock.Unlock()
if t.pollingRequired {
uri, err := url.Parse(t.cfg.URI)
if err != nil {
return err
}
// sanity check before passing the hostname to resolver
if parsedHosts := strings.Split(uri.Host, ","); len(parsedHosts) != 1 {
return fmt.Errorf("URI with SRV must include one and only one hostname")
}
_, _, err = net.SplitHostPort(uri.Host)
if err == nil {
// we were able to successfully extract a port from the host,
// but should not be able to when using SRV
return fmt.Errorf("URI with srv must not include a port number")
}
go t.pollSRVRecords(uri.Host)
t.pollingwg.Add(1)
}
t.subscriptionsClosed = false // explicitly set in case topology was disconnected and then reconnected
atomic.StoreInt64(&t.state, topologyConnected)
return nil
}
// Disconnect closes the topology. It stops the monitoring thread and
// closes all open subscriptions.
func (t *Topology) Disconnect(ctx context.Context) error {
if !atomic.CompareAndSwapInt64(&t.state, topologyConnected, topologyDisconnecting) {
return ErrTopologyClosed
}
servers := make(map[address.Address]*Server)
t.serversLock.Lock()
t.serversClosed = true
for addr, server := range t.servers {
servers[addr] = server
}
t.serversLock.Unlock()
for _, server := range servers {
_ = server.Disconnect(ctx)
t.publishServerClosedEvent(server.address)
}
t.subLock.Lock()
for id, ch := range t.subscribers {
close(ch)
delete(t.subscribers, id)
}
t.subscriptionsClosed = true
t.subLock.Unlock()
if t.pollingRequired {
t.pollingDone <- struct{}{}
t.pollingwg.Wait()
}
t.desc.Store(description.Topology{})
atomic.StoreInt64(&t.state, topologyDisconnected)
t.publishTopologyClosedEvent()
return nil
}
// Description returns a description of the topology.
func (t *Topology) Description() description.Topology {
td, ok := t.desc.Load().(description.Topology)
if !ok {
td = description.Topology{}
}
return td
}
// Kind returns the topology kind of this Topology.
func (t *Topology) Kind() description.TopologyKind { return t.Description().Kind }
// Subscribe returns a Subscription on which all updated description.Topology
// values will be sent. The subscription channel has a buffer size of one and
// is pre-populated with the current description.Topology.
// Subscribe implements the driver.Subscriber interface.
func (t *Topology) Subscribe() (*driver.Subscription, error) {
if atomic.LoadInt64(&t.state) != topologyConnected {
return nil, errors.New("cannot subscribe to Topology that is not connected")
}
ch := make(chan description.Topology, 1)
td, ok := t.desc.Load().(description.Topology)
if !ok {
td = description.Topology{}
}
ch <- td
t.subLock.Lock()
defer t.subLock.Unlock()
if t.subscriptionsClosed {
return nil, ErrSubscribeAfterClosed
}
id := t.currentSubscriberID
t.subscribers[id] = ch
t.currentSubscriberID++
return &driver.Subscription{
Updates: ch,
ID: id,
}, nil
}
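// An illustrative consumer of Subscribe (not an example from the driver
// itself): read buffered topology updates until Unsubscribe closes the
// channel.
//
//	sub, err := topo.Subscribe()
//	if err != nil {
//	    return err
//	}
//	defer func() { _ = topo.Unsubscribe(sub) }()
//	for desc := range sub.Updates {
//	    fmt.Printf("kind=%s servers=%d\n", desc.Kind, len(desc.Servers))
//	}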
// Unsubscribe unsubscribes the given subscription from the topology and closes the subscription channel.
// Unsubscribe implements the driver.Subscriber interface.
func (t *Topology) Unsubscribe(sub *driver.Subscription) error {
t.subLock.Lock()
defer t.subLock.Unlock()
if t.subscriptionsClosed {
return nil
}
ch, ok := t.subscribers[sub.ID]
if !ok {
return nil
}
close(ch)
delete(t.subscribers, sub.ID)
return nil
}
// RequestImmediateCheck will send heartbeats to all the servers in the
// topology right away, instead of waiting for the heartbeat timeout.
func (t *Topology) RequestImmediateCheck() {
if atomic.LoadInt64(&t.state) != topologyConnected {
return
}
t.serversLock.Lock()
for _, server := range t.servers {
server.RequestImmediateCheck()
}
t.serversLock.Unlock()
}
// SelectServer selects a server using the given selector. SelectServer complies with the
// server selection spec, and will time out after serverSelectionTimeout or when the
// parent context is done.
func (t *Topology) SelectServer(ctx context.Context, ss description.ServerSelector) (driver.Server, error) {
if atomic.LoadInt64(&t.state) != topologyConnected {
return nil, ErrTopologyClosed
}
var ssTimeoutCh <-chan time.Time
if t.cfg.ServerSelectionTimeout > 0 {
ssTimeout := time.NewTimer(t.cfg.ServerSelectionTimeout)
ssTimeoutCh = ssTimeout.C
defer ssTimeout.Stop()
}
var doneOnce bool
var sub *driver.Subscription
selectionState := newServerSelectionState(ss, ssTimeoutCh)
for {
var suitable []description.Server
var selectErr error
if !doneOnce {
// for the first pass, select a server from the current description.
// this improves selection speed for up-to-date topology descriptions.
suitable, selectErr = t.selectServerFromDescription(t.Description(), selectionState)
doneOnce = true
} else {
// if the first pass didn't select a server, the previous description did not contain a suitable server, so
// we subscribe to the topology and attempt to obtain a server from that subscription
if sub == nil {
var err error
sub, err = t.Subscribe()
if err != nil {
return nil, err
}
defer t.Unsubscribe(sub)
}
suitable, selectErr = t.selectServerFromSubscription(ctx, sub.Updates, selectionState)
}
if selectErr != nil {
return nil, selectErr
}
if len(suitable) == 0 {
// try again if there are no servers available
continue
}
// If there's only one suitable server description, try to find the associated server and
// return it. This is an optimization primarily for standalone and load-balanced deployments.
if len(suitable) == 1 {
server, err := t.FindServer(suitable[0])
if err != nil {
return nil, err
}
if server == nil {
continue
}
return server, nil
}
// Randomly select 2 suitable server descriptions and find servers for them. We select two
// so we can pick the one with fewer in-progress operations below.
desc1, desc2 := pick2(suitable)
server1, err := t.FindServer(desc1)
if err != nil {
return nil, err
}
server2, err := t.FindServer(desc2)
if err != nil {
return nil, err
}
// If we don't have an actual server for one or both of the provided descriptions, either
// return the one server we have, or try again if they're both nil. This could happen for a
// number of reasons, including that the server has since stopped being a part of this
// topology.
if server1 == nil || server2 == nil {
if server1 == nil && server2 == nil {
continue
}
if server1 != nil {
return server1, nil
}
return server2, nil
}
// Of the two randomly selected suitable servers, pick the one with fewer in-use connections.
// We use in-use connections as an analog for in-progress operations because they are almost
// always the same value for a given server.
if server1.OperationCount() < server2.OperationCount() {
return server1, nil
}
return server2, nil
}
}
// pick2 returns 2 random server descriptions from the input slice of server descriptions,
// guaranteeing that the same element from the slice is not picked twice. The order of server
// descriptions in the input slice may be modified. If fewer than 2 server descriptions are
// provided, pick2 will panic.
func pick2(ds []description.Server) (description.Server, description.Server) {
// Select a random index from the input slice and keep the server description from that index.
idx := random.Intn(len(ds))
s1 := ds[idx]
// Swap the selected index to the end and reslice to remove it so we don't pick the same server
// description twice.
ds[idx], ds[len(ds)-1] = ds[len(ds)-1], ds[idx]
ds = ds[:len(ds)-1]
// Select another random index from the input slice and return both selected server descriptions.
return s1, ds[random.Intn(len(ds))]
}
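// pick2 follows the "power of two random choices" heuristic: sampling two
// candidates and keeping the less loaded one approximates picking the global
// minimum at a fraction of the cost. A minimal illustration (hypothetical
// addresses):
//
//	ds := []description.Server{{Addr: "a:27017"}, {Addr: "b:27017"}, {Addr: "c:27017"}}
//	d1, d2 := pick2(ds) // d1 and d2 are always distinct elements; ds may be reordered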
// FindServer will attempt to find a server that fits the given server description.
// This method will return nil, nil if a matching server could not be found.
func (t *Topology) FindServer(selected description.Server) (*SelectedServer, error) {
if atomic.LoadInt64(&t.state) != topologyConnected {
return nil, ErrTopologyClosed
}
t.serversLock.Lock()
defer t.serversLock.Unlock()
server, ok := t.servers[selected.Addr]
if !ok {
return nil, nil
}
desc := t.Description()
return &SelectedServer{
Server: server,
Kind: desc.Kind,
}, nil
}
// selectServerFromSubscription loops until a topology description is available for server selection. It returns
// when the given context expires, server selection timeout is reached, or a description containing a selectable
// server is available.
func (t *Topology) selectServerFromSubscription(ctx context.Context, subscriptionCh <-chan description.Topology,
selectionState serverSelectionState) ([]description.Server, error) {
current := t.Description()
for {
select {
case <-ctx.Done():
return nil, ServerSelectionError{Wrapped: ctx.Err(), Desc: current}
case <-selectionState.timeoutChan:
return nil, ServerSelectionError{Wrapped: ErrServerSelectionTimeout, Desc: current}
case current = <-subscriptionCh:
}
suitable, err := t.selectServerFromDescription(current, selectionState)
if err != nil {
return nil, err
}
if len(suitable) > 0 {
return suitable, nil
}
t.RequestImmediateCheck()
}
}
// selectServerFromDescription processes the given topology description and returns a slice of suitable servers.
func (t *Topology) selectServerFromDescription(desc description.Topology,
selectionState serverSelectionState) ([]description.Server, error) {
// Unlike selectServerFromSubscription, this code path does not check ctx.Done or selectionState.timeoutChan because
// selecting a server from a description is not a blocking operation.
if desc.CompatibilityErr != nil {
return nil, desc.CompatibilityErr
}
// If the topology kind is LoadBalanced, the LB is the only server and it is always considered
// selectable. The selectors exported by the driver should already return the LB as a candidate,
// but this check ensures that the LB is always selectable even if a user of the low-level driver
// provides a custom selector.
if desc.Kind == description.LoadBalanced {
return desc.Servers, nil
}
var allowed []description.Server
for _, s := range desc.Servers {
if s.Kind != description.Unknown {
allowed = append(allowed, s)
}
}
suitable, err := selectionState.selector.SelectServer(desc, allowed)
if err != nil {
return nil, ServerSelectionError{Wrapped: err, Desc: desc}
}
return suitable, nil
}
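// The LoadBalanced short-circuit above protects users of the low-level driver
// who pass custom selectors. A minimal custom selector, sketched with the
// exported description.ServerSelectorFunc adapter:
//
//	onlySecondaries := description.ServerSelectorFunc(
//	    func(topo description.Topology, candidates []description.Server) ([]description.Server, error) {
//	        var out []description.Server
//	        for _, s := range candidates {
//	            if s.Kind == description.RSSecondary {
//	                out = append(out, s)
//	            }
//	        }
//	        return out, nil
//	    })
//
// Even if a selector like this filtered out a load balancer, the check above
// keeps the LB selectable.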
func (t *Topology) pollSRVRecords(hosts string) {
defer t.pollingwg.Done()
serverConfig := newServerConfig(t.cfg.ServerOpts...)
heartbeatInterval := serverConfig.heartbeatInterval
pollTicker := time.NewTicker(t.rescanSRVInterval)
defer pollTicker.Stop()
t.pollHeartbeatTime.Store(false)
var doneOnce bool
defer func() {
// If this routine panics before consuming the done signal, drain pollingDone
// so that Disconnect's send does not block forever.
if r := recover(); r != nil && !doneOnce {
<-t.pollingDone
}
}()
for {
select {
case <-pollTicker.C:
case <-t.pollingDone:
doneOnce = true
return
}
topoKind := t.Description().Kind
if !(topoKind == description.Unknown || topoKind == description.Sharded) {
break
}
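// The two ticker swaps below implement a simple fallback: while SRV lookups
// fail or return no verified hosts, rescans run on the server heartbeat
// interval; once a lookup succeeds, the ticker returns to rescanSRVInterval.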
parsedHosts, err := t.dnsResolver.ParseHosts(hosts, t.cfg.SRVServiceName, false)
// DNS problem or no verified hosts returned
if err != nil || len(parsedHosts) == 0 {
if !t.pollHeartbeatTime.Load().(bool) {
pollTicker.Stop()
pollTicker = time.NewTicker(heartbeatInterval)
t.pollHeartbeatTime.Store(true)
}
continue
}
if t.pollHeartbeatTime.Load().(bool) {
pollTicker.Stop()
pollTicker = time.NewTicker(t.rescanSRVInterval)
t.pollHeartbeatTime.Store(false)
}
cont := t.processSRVResults(parsedHosts)
if !cont {
break
}
}
<-t.pollingDone
doneOnce = true
}
func (t *Topology) processSRVResults(parsedHosts []string) bool {
t.serversLock.Lock()
defer t.serversLock.Unlock()
if t.serversClosed {
return false
}
prev := t.fsm.Topology
diff := diffHostList(t.fsm.Topology, parsedHosts)
if len(diff.Added) == 0 && len(diff.Removed) == 0 {
return true
}
for _, r := range diff.Removed {
addr := address.Address(r).Canonicalize()
s, ok := t.servers[addr]
if !ok {
continue
}
go func() {
cancelCtx, cancel := context.WithCancel(context.Background())
cancel()
_ = s.Disconnect(cancelCtx)
}()
delete(t.servers, addr)
t.fsm.removeServerByAddr(addr)
t.publishServerClosedEvent(s.address)
}
// Now that we've removed all the hosts that disappeared from the SRV record, we need to add any
// new hosts added to the SRV record. If adding all of the new hosts would increase the number
// of servers past srvMaxHosts, shuffle the list of added hosts.
if t.cfg.SRVMaxHosts > 0 && len(t.servers)+len(diff.Added) > t.cfg.SRVMaxHosts {
random.Shuffle(len(diff.Added), func(i, j int) {
diff.Added[i], diff.Added[j] = diff.Added[j], diff.Added[i]
})
}
// Add all added hosts until the number of servers reaches srvMaxHosts.
for _, a := range diff.Added {
if t.cfg.SRVMaxHosts > 0 && len(t.servers) >= t.cfg.SRVMaxHosts {
break
}
addr := address.Address(a).Canonicalize()
_ = t.addServer(addr)
t.fsm.addServer(addr)
}
// Store the new description.
newDesc := description.Topology{
Kind: t.fsm.Kind,
Servers: t.fsm.Servers,
SessionTimeoutMinutes: t.fsm.SessionTimeoutMinutes,
}
t.desc.Store(newDesc)
if !prev.Equal(newDesc) {
t.publishTopologyDescriptionChangedEvent(prev, newDesc)
}
t.subLock.Lock()
for _, ch := range t.subscribers {
// Drain any stale description from the channel before sending the new one.
select {
case <-ch:
default:
}
ch <- newDesc
}
t.subLock.Unlock()
return true
}
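// For intuition about diffHostList (illustrative values): with current
// topology hosts {a, b} and freshly resolved SRV hosts {b, c}, diff.Removed
// is {a} and diff.Added is {c}, so a's server is disconnected and c is
// dialed, subject to the SRVMaxHosts cap applied above.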
// apply updates the Topology and its underlying FSM based on the provided server description and returns the server
// description that should be stored.
func (t *Topology) apply(ctx context.Context, desc description.Server) description.Server {
t.serversLock.Lock()
defer t.serversLock.Unlock()
ind, ok := t.fsm.findServer(desc.Addr)
if t.serversClosed || !ok {
return desc
}
prev := t.fsm.Topology
oldDesc := t.fsm.Servers[ind]
if oldDesc.TopologyVersion.CompareToIncoming(desc.TopologyVersion) > 0 {
return oldDesc
}
var current description.Topology
current, desc = t.fsm.apply(desc)
if !oldDesc.Equal(desc) {
t.publishServerDescriptionChangedEvent(oldDesc, desc)
}
diff := diffTopology(prev, current)
for _, removed := range diff.Removed {
if s, ok := t.servers[removed.Addr]; ok {
go func() {
cancelCtx, cancel := context.WithCancel(ctx)
cancel()
_ = s.Disconnect(cancelCtx)
}()
delete(t.servers, removed.Addr)
t.publishServerClosedEvent(s.address)
}
}
for _, added := range diff.Added {
_ = t.addServer(added.Addr)
}
t.desc.Store(current)
if !prev.Equal(current) {
t.publishTopologyDescriptionChangedEvent(prev, current)
}
t.subLock.Lock()
for _, ch := range t.subscribers {
// Drain any stale description from the channel before sending the new one.
select {
case <-ch:
default:
}
ch <- current
}
t.subLock.Unlock()
return desc
}
func (t *Topology) addServer(addr address.Address) error {
if _, ok := t.servers[addr]; ok {
return nil
}
svr, err := ConnectServer(addr, t.updateCallback, t.id, t.cfg.ServerOpts...)
if err != nil {
return err
}
t.servers[addr] = svr
return nil
}
// String implements the fmt.Stringer interface.
func (t *Topology) String() string {
desc := t.Description()
serversStr := ""
t.serversLock.Lock()
defer t.serversLock.Unlock()
for _, s := range t.servers {
serversStr += "{ " + s.String() + " }, "
}
return fmt.Sprintf("Type: %s, Servers: [%s]", desc.Kind, serversStr)
}
// publishes a ServerDescriptionChangedEvent to indicate the server description has changed
func (t *Topology) publishServerDescriptionChangedEvent(prev description.Server, current description.Server) {
serverDescriptionChanged := &event.ServerDescriptionChangedEvent{
Address: current.Addr,
TopologyID: t.id,
PreviousDescription: prev,
NewDescription: current,
}
if t.cfg.ServerMonitor != nil && t.cfg.ServerMonitor.ServerDescriptionChanged != nil {
t.cfg.ServerMonitor.ServerDescriptionChanged(serverDescriptionChanged)
}
}
// publishes a ServerClosedEvent to indicate the server has closed
func (t *Topology) publishServerClosedEvent(addr address.Address) {
serverClosed := &event.ServerClosedEvent{
Address: addr,
TopologyID: t.id,
}
if t.cfg.ServerMonitor != nil && t.cfg.ServerMonitor.ServerClosed != nil {
t.cfg.ServerMonitor.ServerClosed(serverClosed)
}
}
// publishes a TopologyDescriptionChangedEvent to indicate the topology description has changed
func (t *Topology) publishTopologyDescriptionChangedEvent(prev description.Topology, current description.Topology) {
topologyDescriptionChanged := &event.TopologyDescriptionChangedEvent{
TopologyID: t.id,
PreviousDescription: prev,
NewDescription: current,
}
if t.cfg.ServerMonitor != nil && t.cfg.ServerMonitor.TopologyDescriptionChanged != nil {
t.cfg.ServerMonitor.TopologyDescriptionChanged(topologyDescriptionChanged)
}
}
// publishes a TopologyOpeningEvent to indicate the topology is being initialized
func (t *Topology) publishTopologyOpeningEvent() {
topologyOpening := &event.TopologyOpeningEvent{
TopologyID: t.id,
}
if t.cfg.ServerMonitor != nil && t.cfg.ServerMonitor.TopologyOpening != nil {
t.cfg.ServerMonitor.TopologyOpening(topologyOpening)
}
}
// publishes a TopologyClosedEvent to indicate the topology has been closed
func (t *Topology) publishTopologyClosedEvent() {
topologyClosed := &event.TopologyClosedEvent{
TopologyID: t.id,
}
if t.cfg.ServerMonitor != nil && t.cfg.ServerMonitor.TopologyClosed != nil {
t.cfg.ServerMonitor.TopologyClosed(topologyClosed)
}
}
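// These publish helpers are no-ops unless the corresponding monitor callback
// is configured. An event.ServerMonitor can be wired up roughly as below
// (illustrative; in the public API it is normally supplied via
// options.Client().SetServerMonitor):
//
//	monitor := &event.ServerMonitor{
//	    TopologyDescriptionChanged: func(e *event.TopologyDescriptionChangedEvent) {
//	        fmt.Printf("topology: %s -> %s\n", e.PreviousDescription.Kind, e.NewDescription.Kind)
//	    },
//	}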

View File

@@ -0,0 +1,344 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package topology
import (
"crypto/tls"
"net/http"
"strings"
"time"
"go.mongodb.org/mongo-driver/event"
"go.mongodb.org/mongo-driver/mongo/description"
"go.mongodb.org/mongo-driver/mongo/options"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/auth"
"go.mongodb.org/mongo-driver/x/mongo/driver/ocsp"
"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
)
const defaultServerSelectionTimeout = 30 * time.Second
// Config is used to construct a topology.
type Config struct {
Mode MonitorMode
ReplicaSetName string
SeedList []string
ServerOpts []ServerOption
URI string
ServerSelectionTimeout time.Duration
ServerMonitor *event.ServerMonitor
SRVMaxHosts int
SRVServiceName string
LoadBalanced bool
}
// ConvertToDriverAPIOptions converts a options.ServerAPIOptions instance to a driver.ServerAPIOptions.
func ConvertToDriverAPIOptions(s *options.ServerAPIOptions) *driver.ServerAPIOptions {
driverOpts := driver.NewServerAPIOptions(string(s.ServerAPIVersion))
if s.Strict != nil {
driverOpts.SetStrict(*s.Strict)
}
if s.DeprecationErrors != nil {
driverOpts.SetDeprecationErrors(*s.DeprecationErrors)
}
return driverOpts
}
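// An illustrative round trip from the public options type (assuming the
// standard options builder):
//
//	src := options.ServerAPI(options.ServerAPIVersion1).SetStrict(true)
//	driverOpts := ConvertToDriverAPIOptions(src)
//	// driverOpts now carries version "1" with strict mode enabled.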
// NewConfig will translate data from client options into a topology config for building non-default deployments.
// Server and topology options are not honored if a custom deployment is used.
func NewConfig(co *options.ClientOptions, clock *session.ClusterClock) (*Config, error) {
var serverAPI *driver.ServerAPIOptions
if err := co.Validate(); err != nil {
return nil, err
}
var connOpts []ConnectionOption
var serverOpts []ServerOption
cfgp := new(Config)
// Set the default "ServerSelectionTimeout" to 30 seconds.
cfgp.ServerSelectionTimeout = defaultServerSelectionTimeout
// Set the default "SeedList" to localhost.
cfgp.SeedList = []string{"localhost:27017"}
// TODO(GODRIVER-814): Add tests for topology, server, and connection related options.
// ServerAPIOptions need to be handled early as other client and server options below reference
// c.serverAPI and serverOpts.serverAPI.
if co.ServerAPIOptions != nil {
serverAPI = ConvertToDriverAPIOptions(co.ServerAPIOptions)
serverOpts = append(serverOpts, WithServerAPI(func(*driver.ServerAPIOptions) *driver.ServerAPIOptions {
return serverAPI
}))
}
cfgp.URI = co.GetURI()
if co.SRVServiceName != nil {
cfgp.SRVServiceName = *co.SRVServiceName
}
if co.SRVMaxHosts != nil {
cfgp.SRVMaxHosts = *co.SRVMaxHosts
}
// AppName
var appName string
if co.AppName != nil {
appName = *co.AppName
serverOpts = append(serverOpts, WithServerAppName(func(string) string {
return appName
}))
}
// Compressors & ZlibLevel
var comps []string
if len(co.Compressors) > 0 {
comps = co.Compressors
connOpts = append(connOpts, WithCompressors(
func(compressors []string) []string {
return append(compressors, comps...)
},
))
for _, comp := range comps {
switch comp {
case "zlib":
connOpts = append(connOpts, WithZlibLevel(func(level *int) *int {
return co.ZlibLevel
}))
case "zstd":
connOpts = append(connOpts, WithZstdLevel(func(level *int) *int {
return co.ZstdLevel
}))
}
}
serverOpts = append(serverOpts, WithCompressionOptions(
func(opts ...string) []string { return append(opts, comps...) },
))
}
var loadBalanced bool
if co.LoadBalanced != nil {
loadBalanced = *co.LoadBalanced
}
// Handshaker
var handshaker = func(driver.Handshaker) driver.Handshaker {
return operation.NewHello().AppName(appName).Compressors(comps).ClusterClock(clock).
ServerAPI(serverAPI).LoadBalanced(loadBalanced)
}
// Auth & Database & Password & Username
if co.Auth != nil {
cred := &auth.Cred{
Username: co.Auth.Username,
Password: co.Auth.Password,
PasswordSet: co.Auth.PasswordSet,
Props: co.Auth.AuthMechanismProperties,
Source: co.Auth.AuthSource,
}
mechanism := co.Auth.AuthMechanism
if len(cred.Source) == 0 {
switch strings.ToUpper(mechanism) {
case auth.MongoDBX509, auth.GSSAPI, auth.PLAIN:
cred.Source = "$external"
default:
cred.Source = "admin"
}
}
authenticator, err := auth.CreateAuthenticator(mechanism, cred)
if err != nil {
return nil, err
}
handshakeOpts := &auth.HandshakeOptions{
AppName: appName,
Authenticator: authenticator,
Compressors: comps,
ServerAPI: serverAPI,
LoadBalanced: loadBalanced,
ClusterClock: clock,
HTTPClient: co.HTTPClient,
}
if mechanism == "" {
// Required for SASL mechanism negotiation during handshake
handshakeOpts.DBUser = cred.Source + "." + cred.Username
}
if co.AuthenticateToAnything != nil && *co.AuthenticateToAnything {
// Authenticate arbiters
handshakeOpts.PerformAuthentication = func(serv description.Server) bool {
return true
}
}
handshaker = func(driver.Handshaker) driver.Handshaker {
return auth.Handshaker(nil, handshakeOpts)
}
}
connOpts = append(connOpts, WithHandshaker(handshaker))
// ConnectTimeout
if co.ConnectTimeout != nil {
serverOpts = append(serverOpts, WithHeartbeatTimeout(
func(time.Duration) time.Duration { return *co.ConnectTimeout },
))
connOpts = append(connOpts, WithConnectTimeout(
func(time.Duration) time.Duration { return *co.ConnectTimeout },
))
}
// Dialer
if co.Dialer != nil {
connOpts = append(connOpts, WithDialer(
func(Dialer) Dialer { return co.Dialer },
))
}
// Direct
if co.Direct != nil && *co.Direct {
cfgp.Mode = SingleMode
}
// HeartbeatInterval
if co.HeartbeatInterval != nil {
serverOpts = append(serverOpts, WithHeartbeatInterval(
func(time.Duration) time.Duration { return *co.HeartbeatInterval },
))
}
// Hosts (the default SeedList was already set above)
if len(co.Hosts) > 0 {
cfgp.SeedList = co.Hosts
}
// MaxConIdleTime
if co.MaxConnIdleTime != nil {
connOpts = append(connOpts, WithIdleTimeout(
func(time.Duration) time.Duration { return *co.MaxConnIdleTime },
))
}
// MaxPoolSize
if co.MaxPoolSize != nil {
serverOpts = append(
serverOpts,
WithMaxConnections(func(uint64) uint64 { return *co.MaxPoolSize }),
)
}
// MinPoolSize
if co.MinPoolSize != nil {
serverOpts = append(
serverOpts,
WithMinConnections(func(uint64) uint64 { return *co.MinPoolSize }),
)
}
// MaxConnecting
if co.MaxConnecting != nil {
serverOpts = append(
serverOpts,
WithMaxConnecting(func(uint64) uint64 { return *co.MaxConnecting }),
)
}
// PoolMonitor
if co.PoolMonitor != nil {
serverOpts = append(
serverOpts,
WithConnectionPoolMonitor(func(*event.PoolMonitor) *event.PoolMonitor { return co.PoolMonitor }),
)
}
// Monitor
if co.Monitor != nil {
connOpts = append(connOpts, WithMonitor(
func(*event.CommandMonitor) *event.CommandMonitor { return co.Monitor },
))
}
// ServerMonitor
if co.ServerMonitor != nil {
serverOpts = append(
serverOpts,
WithServerMonitor(func(*event.ServerMonitor) *event.ServerMonitor { return co.ServerMonitor }),
)
cfgp.ServerMonitor = co.ServerMonitor
}
// ReplicaSet
if co.ReplicaSet != nil {
cfgp.ReplicaSetName = *co.ReplicaSet
}
// ServerSelectionTimeout
if co.ServerSelectionTimeout != nil {
cfgp.ServerSelectionTimeout = *co.ServerSelectionTimeout
}
// SocketTimeout
if co.SocketTimeout != nil {
connOpts = append(
connOpts,
WithReadTimeout(func(time.Duration) time.Duration { return *co.SocketTimeout }),
WithWriteTimeout(func(time.Duration) time.Duration { return *co.SocketTimeout }),
)
}
// TLSConfig
if co.TLSConfig != nil {
connOpts = append(connOpts, WithTLSConfig(
func(*tls.Config) *tls.Config {
return co.TLSConfig
},
))
}
// HTTP Client
if co.HTTPClient != nil {
connOpts = append(connOpts, WithHTTPClient(
func(*http.Client) *http.Client {
return co.HTTPClient
},
))
}
// OCSP cache
ocspCache := ocsp.NewCache()
connOpts = append(
connOpts,
WithOCSPCache(func(ocsp.Cache) ocsp.Cache { return ocspCache }),
)
// Disable communication with external OCSP responders.
if co.DisableOCSPEndpointCheck != nil {
connOpts = append(
connOpts,
WithDisableOCSPEndpointCheck(func(bool) bool { return *co.DisableOCSPEndpointCheck }),
)
}
// LoadBalanced
if co.LoadBalanced != nil {
cfgp.LoadBalanced = *co.LoadBalanced
serverOpts = append(
serverOpts,
WithServerLoadBalanced(func(bool) bool { return *co.LoadBalanced }),
)
connOpts = append(
connOpts,
WithConnectionLoadBalanced(func(bool) bool { return *co.LoadBalanced }),
)
}
serverOpts = append(
serverOpts,
WithClock(func(*session.ClusterClock) *session.ClusterClock { return clock }),
WithConnectionOptions(func(...ConnectionOption) []ConnectionOption { return connOpts }))
cfgp.ServerOpts = serverOpts
return cfgp, nil
}
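// An illustrative use of NewConfig. Passing a nil clock is an assumption made
// for this sketch; callers that track cluster time would supply a real
// *session.ClusterClock:
//
//	co := options.Client().ApplyURI("mongodb://localhost:27017")
//	cfg, err := NewConfig(co, nil)
//	if err != nil {
//	    return err
//	}
//	// cfg.SeedList, cfg.ServerOpts, etc. now reflect the client options.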