go mod vendor
+ move k8s.io/apimachinery fork from go.work to go.mod (and include it in vendor)
40 vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/DESIGN.md (generated, vendored, new file)
@@ -0,0 +1,40 @@
# Topology Package Design

This document outlines the design for this package.

## Topology

The `Topology` type handles monitoring the state of a MongoDB deployment and selecting servers.
Updating the description is handled by a finite state machine which implements the server discovery
and monitoring specification. A `Topology` can be connected and fully disconnected, which enables
saving resources. The `Topology` type also handles server selection following the server selection
specification.
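As a rough sketch of how heartbeat results drive the description update (illustrative only; `newFSM` and `apply` are the unexported helpers added in `fsm.go` later in this commit, and the helper name, address, and server kind below are fabricated for the example):

```go
package topology

import (
	"go.mongodb.org/mongo-driver/mongo/address"
	"go.mongodb.org/mongo-driver/mongo/description"
)

// applyHeartbeat is illustrative only (not part of the driver): it folds one
// heartbeat-derived server description into the FSM and returns the updated
// topology view, which is what server selection consults.
func applyHeartbeat(f *fsm, addr address.Address) description.Topology {
	topo, _ := f.apply(description.Server{
		Addr: addr,
		Kind: description.RSPrimary, // fabricated kind for the example
	})
	return topo
}
```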
## Server

The `Server` type handles heartbeating a MongoDB server and holds a pool of connections.

## Connection

Connections are handled by two main types and an auxiliary type. The two main types are `connection`
and `Connection`. The first holds most of the logic required to actually read and write wire
messages. Instances can be created with the `newConnection` method. Inside the `newConnection`
method the auxiliary type, `initConnection`, is used to perform the connection handshake. This is
required because the `connection` type does not fully implement `driver.Connection`, which is
required during handshaking. The `Connection` type is what is actually returned to a consumer of the
`topology` package. This type does implement the `driver.Connection` type, holds a reference to a
`connection` instance, and exists mainly to prevent accidental continued usage of a connection after
closing it.

The connection implementations in this package are conduits for wire messages, but they have no
ability to encode, decode, or validate wire messages. That must be handled by consumers.
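The "prevent accidental continued usage" guard amounts to nil-ing out the embedded handle under a lock, which is exactly what the `Connection` methods in `connection.go` below do. A minimal standalone sketch of the same pattern (every name here is invented for illustration):

```go
package sketch

import (
	"errors"
	"sync"
)

// guardedConn illustrates the wrapper pattern: once closed, the inner handle
// is set to nil so every later call fails fast instead of touching a
// connection that may have been recycled by the pool.
type guardedConn struct {
	mu    sync.RWMutex
	inner *innerConn // nil after Close
}

type innerConn struct{}

func (i *innerConn) write(p []byte) error { return nil }

func (g *guardedConn) Write(p []byte) error {
	g.mu.RLock()
	defer g.mu.RUnlock()
	if g.inner == nil {
		return errors.New("connection closed")
	}
	return g.inner.write(p)
}

func (g *guardedConn) Close() {
	g.mu.Lock()
	defer g.mu.Unlock()
	g.inner = nil // detach; the pool owns the real socket lifecycle
}
```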
## Pool

The `pool` type implements a connection pool. It handles caching idle connections and dialing
new ones, but it does not track a maximum number of connections. That is the responsibility of a
wrapping type, like `Server`.

The `pool` type has no concept of closing; instead it has concepts of connecting and disconnecting.
This allows a `Topology` to be disconnected while keeping the memory around to be reconnected
later. There is a `close` method, but this is used to close a connection.

There are three methods related to getting and putting connections: `get`, `close`, and `put`. The
`get` method will either retrieve a connection from the cache or it will dial a new `connection`.
The `close` method will close the underlying socket of a `connection`. The `put` method will put a
connection into the pool, placing it in the cache if there is space, otherwise it will close it.
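A toy illustration of those `get`/`put` semantics (this is not the driver's pool, which additionally tracks generations and pinned connections; `toyPool` and its fields are invented):

```go
package sketch

import "net"

// toyPool caches idle connections in a bounded channel. get prefers a cached
// connection and dials otherwise; put caches when there is space and closes
// the connection when there is not, mirroring the prose above.
type toyPool struct {
	idle chan net.Conn
	dial func() (net.Conn, error)
}

func (p *toyPool) get() (net.Conn, error) {
	select {
	case c := <-p.idle:
		return c, nil
	default:
		return p.dial()
	}
}

func (p *toyPool) put(c net.Conn) {
	select {
	case p.idle <- c: // space in the cache
	default:
		_ = c.Close() // cache full: close instead
	}
}
```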
14 vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/cancellation_listener.go (generated, vendored, new file)
@@ -0,0 +1,14 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package topology

import "context"

type cancellationListener interface {
	Listen(context.Context, func())
	StopListening() bool
}
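The interface's contract is visible in how `connection.write` and `connection.read` use it below: `Listen` blocks until either the context is cancelled (in which case it runs the abort callback) or `StopListening` is called, and `StopListening` reports whether the abort fired. A minimal implementation sketch consistent with that contract (an assumption about how the driver's internal listener behaves, not its verbatim code):

```go
package sketch

import "context"

// ctxListener sketches a cancellationListener. Listen must be running before
// StopListening is called, and the two always pair up: StopListening's send
// on done unblocks Listen in both branches.
type ctxListener struct {
	aborted bool
	done    chan struct{}
}

func newCtxListener() *ctxListener {
	return &ctxListener{done: make(chan struct{})}
}

func (l *ctxListener) Listen(ctx context.Context, abortFn func()) {
	l.aborted = false
	select {
	case <-ctx.Done():
		l.aborted = true
		abortFn()
		<-l.done // wait for the paired StopListening call
	case <-l.done:
	}
}

func (l *ctxListener) StopListening() bool {
	l.done <- struct{}{}
	return l.aborted
}
```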
825 vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/connection.go (generated, vendored, new file)
@@ -0,0 +1,825 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package topology

import (
	"context"
	"crypto/tls"
	"errors"
	"fmt"
	"io"
	"net"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"go.mongodb.org/mongo-driver/internal"
	"go.mongodb.org/mongo-driver/mongo/address"
	"go.mongodb.org/mongo-driver/mongo/description"
	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
	"go.mongodb.org/mongo-driver/x/mongo/driver"
	"go.mongodb.org/mongo-driver/x/mongo/driver/ocsp"
	"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
)

// Connection state constants.
const (
	connDisconnected int64 = iota
	connConnected
	connInitialized
)

var globalConnectionID uint64 = 1

var (
	defaultMaxMessageSize        uint32 = 48000000
	errResponseTooLarge                 = errors.New("length of read message too large")
	errLoadBalancedStateMismatch        = errors.New("driver attempted to initialize in load balancing mode, but the server does not support this mode")
)

func nextConnectionID() uint64 { return atomic.AddUint64(&globalConnectionID, 1) }

type connection struct {
	// state must be accessed using the atomic package and should be at the beginning of the struct.
	// - atomic bug: https://pkg.go.dev/sync/atomic#pkg-note-BUG
	// - suggested layout: https://go101.org/article/memory-layout.html
	state int64

	id                   string
	nc                   net.Conn // When nil, the connection is closed.
	addr                 address.Address
	idleTimeout          time.Duration
	idleDeadline         atomic.Value // Stores a time.Time
	readTimeout          time.Duration
	writeTimeout         time.Duration
	desc                 description.Server
	helloRTT             time.Duration
	compressor           wiremessage.CompressorID
	zliblevel            int
	zstdLevel            int
	connectDone          chan struct{}
	config               *connectionConfig
	cancelConnectContext context.CancelFunc
	connectContextMade   chan struct{}
	canStream            bool
	currentlyStreaming   bool
	connectContextMutex  sync.Mutex
	cancellationListener cancellationListener
	serverConnectionID   *int32 // the server's ID for this client's connection

	// pool related fields
	pool       *pool
	poolID     uint64
	generation uint64
}

// newConnection handles the creation of a connection. It does not connect the connection.
func newConnection(addr address.Address, opts ...ConnectionOption) *connection {
	cfg := newConnectionConfig(opts...)

	id := fmt.Sprintf("%s[-%d]", addr, nextConnectionID())

	c := &connection{
		id:                   id,
		addr:                 addr,
		idleTimeout:          cfg.idleTimeout,
		readTimeout:          cfg.readTimeout,
		writeTimeout:         cfg.writeTimeout,
		connectDone:          make(chan struct{}),
		config:               cfg,
		connectContextMade:   make(chan struct{}),
		cancellationListener: internal.NewCancellationListener(),
	}
	// Connections to non-load balanced deployments should eagerly set the generation numbers so errors encountered
	// at any point during connection establishment can be processed without the connection being considered stale.
	if !c.config.loadBalanced {
		c.setGenerationNumber()
	}
	atomic.StoreInt64(&c.state, connInitialized)

	return c
}

// setGenerationNumber sets the connection's generation number if a callback has been provided to do so in connection
// configuration.
func (c *connection) setGenerationNumber() {
	if c.config.getGenerationFn != nil {
		c.generation = c.config.getGenerationFn(c.desc.ServiceID)
	}
}

// hasGenerationNumber returns true if the connection has set its generation number. If so, this indicates that the
// generationNumberFn provided via the connection options has been called exactly once.
func (c *connection) hasGenerationNumber() bool {
	if !c.config.loadBalanced {
		// The generation is known for all non-LB clusters once the connection object has been created.
		return true
	}

	// For LB clusters, we set the generation after the initial handshake, so we know it's set if the connection
	// description has been updated to reflect that it's behind an LB.
	return c.desc.LoadBalanced()
}

// connect handles the I/O for a connection. It will dial, configure TLS, and perform initialization
// handshakes. All errors returned by connect are considered "before the handshake completes" and
// must be handled by calling the appropriate SDAM handshake error handler.
func (c *connection) connect(ctx context.Context) (err error) {
	if !atomic.CompareAndSwapInt64(&c.state, connInitialized, connConnected) {
		return nil
	}

	defer close(c.connectDone)

	// If connect returns an error, set the connection status as disconnected and close the
	// underlying net.Conn if it was created.
	defer func() {
		if err != nil {
			atomic.StoreInt64(&c.state, connDisconnected)

			if c.nc != nil {
				_ = c.nc.Close()
			}
		}
	}()

	// Create separate contexts for dialing a connection and doing the MongoDB/auth handshakes.
	//
	// handshakeCtx is simply a cancellable version of ctx because there's no default timeout that needs to be applied
	// to the full handshake. The cancellation allows consumers to bail out early when dialing a connection if it's no
	// longer required. This is done in lock because it accesses the shared cancelConnectContext field.
	//
	// dialCtx is equal to handshakeCtx if connectTimeoutMS=0. Otherwise, it is derived from handshakeCtx so the
	// cancellation still applies but with an added timeout to ensure the connectTimeoutMS option is applied to socket
	// establishment and the TLS handshake as a whole. This is created outside of the connectContextMutex lock to avoid
	// holding the lock longer than necessary.
	c.connectContextMutex.Lock()
	var handshakeCtx context.Context
	handshakeCtx, c.cancelConnectContext = context.WithCancel(ctx)
	c.connectContextMutex.Unlock()

	dialCtx := handshakeCtx
	var dialCancel context.CancelFunc
	if c.config.connectTimeout != 0 {
		dialCtx, dialCancel = context.WithTimeout(handshakeCtx, c.config.connectTimeout)
		defer dialCancel()
	}

	defer func() {
		var cancelFn context.CancelFunc

		c.connectContextMutex.Lock()
		cancelFn = c.cancelConnectContext
		c.cancelConnectContext = nil
		c.connectContextMutex.Unlock()

		if cancelFn != nil {
			cancelFn()
		}
	}()

	close(c.connectContextMade)

	// Assign the result of DialContext to a temporary net.Conn to ensure that c.nc is not set in an error case.
	tempNc, err := c.config.dialer.DialContext(dialCtx, c.addr.Network(), c.addr.String())
	if err != nil {
		return ConnectionError{Wrapped: err, init: true}
	}
	c.nc = tempNc

	if c.config.tlsConfig != nil {
		tlsConfig := c.config.tlsConfig.Clone()

		// Store the result of configureTLS in a separate variable from c.nc to avoid overwriting c.nc with nil in
		// error cases.
		ocspOpts := &ocsp.VerifyOptions{
			Cache:                   c.config.ocspCache,
			DisableEndpointChecking: c.config.disableOCSPEndpointCheck,
			HTTPClient:              c.config.httpClient,
		}
		tlsNc, err := configureTLS(dialCtx, c.config.tlsConnectionSource, c.nc, c.addr, tlsConfig, ocspOpts)
		if err != nil {
			return ConnectionError{Wrapped: err, init: true}
		}
		c.nc = tlsNc
	}

	// Running hello and authentication is handled by a handshaker on the configuration instance.
	handshaker := c.config.handshaker
	if handshaker == nil {
		return nil
	}

	var handshakeInfo driver.HandshakeInformation
	handshakeStartTime := time.Now()
	handshakeConn := initConnection{c}
	handshakeInfo, err = handshaker.GetHandshakeInformation(handshakeCtx, c.addr, handshakeConn)
	if err == nil {
		// We only need to retain the Description field as the connection's description. The authentication-related
		// fields in handshakeInfo are tracked by the handshaker if necessary.
		c.desc = handshakeInfo.Description
		c.serverConnectionID = handshakeInfo.ServerConnectionID
		c.helloRTT = time.Since(handshakeStartTime)

		// If the application has indicated that the cluster is load balanced, ensure the server has included serviceId
		// in its handshake response to signal that it knows it's behind an LB as well.
		if c.config.loadBalanced && c.desc.ServiceID == nil {
			err = errLoadBalancedStateMismatch
		}
	}
	if err == nil {
		// For load-balanced connections, the generation number depends on the service ID, which isn't known until the
		// initial MongoDB handshake is done. To account for this, we don't attempt to set the connection's generation
		// number unless GetHandshakeInformation succeeds.
		if c.config.loadBalanced {
			c.setGenerationNumber()
		}

		// If we successfully finished the first part of the handshake and verified LB state, continue with the rest of
		// the handshake.
		err = handshaker.FinishHandshake(handshakeCtx, handshakeConn)
	}

	// We have a failed handshake here.
	if err != nil {
		return ConnectionError{Wrapped: err, init: true}
	}

	if len(c.desc.Compression) > 0 {
	clientMethodLoop:
		for _, method := range c.config.compressors {
			for _, serverMethod := range c.desc.Compression {
				if method != serverMethod {
					continue
				}

				switch strings.ToLower(method) {
				case "snappy":
					c.compressor = wiremessage.CompressorSnappy
				case "zlib":
					c.compressor = wiremessage.CompressorZLib
					c.zliblevel = wiremessage.DefaultZlibLevel
					if c.config.zlibLevel != nil {
						c.zliblevel = *c.config.zlibLevel
					}
				case "zstd":
					c.compressor = wiremessage.CompressorZstd
					c.zstdLevel = wiremessage.DefaultZstdLevel
					if c.config.zstdLevel != nil {
						c.zstdLevel = *c.config.zstdLevel
					}
				}
				break clientMethodLoop
			}
		}
	}
	return nil
}

func (c *connection) wait() {
	if c.connectDone != nil {
		<-c.connectDone
	}
}

func (c *connection) closeConnectContext() {
	<-c.connectContextMade
	var cancelFn context.CancelFunc

	c.connectContextMutex.Lock()
	cancelFn = c.cancelConnectContext
	c.cancelConnectContext = nil
	c.connectContextMutex.Unlock()

	if cancelFn != nil {
		cancelFn()
	}
}

func transformNetworkError(ctx context.Context, originalError error, contextDeadlineUsed bool) error {
	if originalError == nil {
		return nil
	}

	// If there was an error and the context was cancelled, we assume it happened due to the cancellation.
	if ctx.Err() == context.Canceled {
		return context.Canceled
	}

	// If there was a timeout error and the context deadline was used, we convert the error into
	// context.DeadlineExceeded.
	if !contextDeadlineUsed {
		return originalError
	}
	if netErr, ok := originalError.(net.Error); ok && netErr.Timeout() {
		return context.DeadlineExceeded
	}

	return originalError
}

func (c *connection) cancellationListenerCallback() {
	_ = c.close()
}

func (c *connection) writeWireMessage(ctx context.Context, wm []byte) error {
	var err error
	if atomic.LoadInt64(&c.state) != connConnected {
		return ConnectionError{ConnectionID: c.id, message: "connection is closed"}
	}

	var deadline time.Time
	if c.writeTimeout != 0 {
		deadline = time.Now().Add(c.writeTimeout)
	}

	var contextDeadlineUsed bool
	if dl, ok := ctx.Deadline(); ok && (deadline.IsZero() || dl.Before(deadline)) {
		contextDeadlineUsed = true
		deadline = dl
	}

	if err := c.nc.SetWriteDeadline(deadline); err != nil {
		return ConnectionError{ConnectionID: c.id, Wrapped: err, message: "failed to set write deadline"}
	}

	err = c.write(ctx, wm)
	if err != nil {
		c.close()
		return ConnectionError{
			ConnectionID: c.id,
			Wrapped:      transformNetworkError(ctx, err, contextDeadlineUsed),
			message:      "unable to write wire message to network",
		}
	}

	return nil
}

func (c *connection) write(ctx context.Context, wm []byte) (err error) {
	go c.cancellationListener.Listen(ctx, c.cancellationListenerCallback)
	defer func() {
		// There is a race condition between Write and StopListening. If the context is cancelled after c.nc.Write
		// succeeds, the cancellation listener could fire and close the connection. In this case, the connection has
		// been invalidated but the error is nil. To account for this, overwrite the error to context.Canceled if
		// the abortedForCancellation flag was set.
		if aborted := c.cancellationListener.StopListening(); aborted && err == nil {
			err = context.Canceled
		}
	}()

	_, err = c.nc.Write(wm)
	return err
}

// readWireMessage reads a wiremessage from the connection. The dst parameter will be overwritten.
func (c *connection) readWireMessage(ctx context.Context, dst []byte) ([]byte, error) {
	if atomic.LoadInt64(&c.state) != connConnected {
		return dst, ConnectionError{ConnectionID: c.id, message: "connection is closed"}
	}

	var deadline time.Time
	if c.readTimeout != 0 {
		deadline = time.Now().Add(c.readTimeout)
	}

	var contextDeadlineUsed bool
	if dl, ok := ctx.Deadline(); ok && (deadline.IsZero() || dl.Before(deadline)) {
		contextDeadlineUsed = true
		deadline = dl
	}

	if err := c.nc.SetReadDeadline(deadline); err != nil {
		return nil, ConnectionError{ConnectionID: c.id, Wrapped: err, message: "failed to set read deadline"}
	}

	dst, errMsg, err := c.read(ctx, dst)
	if err != nil {
		// We close the connection because we don't know if there are other bytes left to read.
		c.close()
		message := errMsg
		if err == io.EOF {
			message = "socket was unexpectedly closed"
		}
		return dst, ConnectionError{
			ConnectionID: c.id,
			Wrapped:      transformNetworkError(ctx, err, contextDeadlineUsed),
			message:      message,
		}
	}

	return dst, nil
}

func (c *connection) read(ctx context.Context, dst []byte) (bytesRead []byte, errMsg string, err error) {
	go c.cancellationListener.Listen(ctx, c.cancellationListenerCallback)
	defer func() {
		// If the context is cancelled after we finish reading the server response, the cancellation listener could fire
		// even though the socket reads succeed. To account for this, we overwrite err to be context.Canceled if the
		// abortedForCancellation flag is set.
		if aborted := c.cancellationListener.StopListening(); aborted && err == nil {
			errMsg = "unable to read server response"
			err = context.Canceled
		}
	}()

	// We use an array here because it only costs 4 bytes on the stack and means we'll only need to
	// reslice dst once instead of twice.
	var sizeBuf [4]byte

	// We do a ReadFull into an array here instead of doing an opportunistic ReadAtLeast into dst
	// because there might be more than one wire message waiting to be read, for example when
	// reading messages from an exhaust cursor.
	_, err = io.ReadFull(c.nc, sizeBuf[:])
	if err != nil {
		return dst, "incomplete read of message header", err
	}

	// read the length as an int32
	size := (int32(sizeBuf[0])) | (int32(sizeBuf[1]) << 8) | (int32(sizeBuf[2]) << 16) | (int32(sizeBuf[3]) << 24)

	// In the case of a hello response where MaxMessageSize has not yet been set, use the hard-coded
	// defaultMaxMessageSize instead.
	maxMessageSize := c.desc.MaxMessageSize
	if maxMessageSize == 0 {
		maxMessageSize = defaultMaxMessageSize
	}
	if uint32(size) > maxMessageSize {
		return dst, errResponseTooLarge.Error(), errResponseTooLarge
	}

	if int(size) > cap(dst) {
		// Since we can't grow this slice without allocating, just allocate an entirely new slice.
		dst = make([]byte, 0, size)
	}
	// We need to ensure we don't accidentally read into a subsequent wire message, so we set the
	// size to read exactly this wire message.
	dst = dst[:size]
	copy(dst, sizeBuf[:])

	_, err = io.ReadFull(c.nc, dst[4:])
	if err != nil {
		return dst, "incomplete read of full message", err
	}

	return dst, "", nil
}

func (c *connection) close() error {
	// Overwrite the connection state as the first step so only the first close call will execute.
	if !atomic.CompareAndSwapInt64(&c.state, connConnected, connDisconnected) {
		return nil
	}

	var err error
	if c.nc != nil {
		err = c.nc.Close()
	}

	return err
}

func (c *connection) closed() bool {
	return atomic.LoadInt64(&c.state) == connDisconnected
}

func (c *connection) idleTimeoutExpired() bool {
	now := time.Now()
	if c.idleTimeout > 0 {
		idleDeadline, ok := c.idleDeadline.Load().(time.Time)
		if ok && now.After(idleDeadline) {
			return true
		}
	}

	return false
}

func (c *connection) bumpIdleDeadline() {
	if c.idleTimeout > 0 {
		c.idleDeadline.Store(time.Now().Add(c.idleTimeout))
	}
}

func (c *connection) setCanStream(canStream bool) {
	c.canStream = canStream
}

func (c initConnection) supportsStreaming() bool {
	return c.canStream
}

func (c *connection) setStreaming(streaming bool) {
	c.currentlyStreaming = streaming
}

func (c *connection) getCurrentlyStreaming() bool {
	return c.currentlyStreaming
}

func (c *connection) setSocketTimeout(timeout time.Duration) {
	c.readTimeout = timeout
	c.writeTimeout = timeout
}

func (c *connection) ID() string {
	return c.id
}

func (c *connection) ServerConnectionID() *int32 {
	return c.serverConnectionID
}

// initConnection is an adapter used during connection initialization. It has the minimum
// functionality necessary to implement the driver.Connection interface, which is required to pass a
// *connection to a Handshaker.
type initConnection struct{ *connection }

var _ driver.Connection = initConnection{}
var _ driver.StreamerConnection = initConnection{}

func (c initConnection) Description() description.Server {
	if c.connection == nil {
		return description.Server{}
	}
	return c.connection.desc
}
func (c initConnection) Close() error             { return nil }
func (c initConnection) ID() string               { return c.id }
func (c initConnection) Address() address.Address { return c.addr }
func (c initConnection) Stale() bool              { return false }
func (c initConnection) LocalAddress() address.Address {
	if c.connection == nil || c.nc == nil {
		return address.Address("0.0.0.0")
	}
	return address.Address(c.nc.LocalAddr().String())
}
func (c initConnection) WriteWireMessage(ctx context.Context, wm []byte) error {
	return c.writeWireMessage(ctx, wm)
}
func (c initConnection) ReadWireMessage(ctx context.Context, dst []byte) ([]byte, error) {
	return c.readWireMessage(ctx, dst)
}
func (c initConnection) SetStreaming(streaming bool) {
	c.setStreaming(streaming)
}
func (c initConnection) CurrentlyStreaming() bool {
	return c.getCurrentlyStreaming()
}
func (c initConnection) SupportsStreaming() bool {
	return c.supportsStreaming()
}

// Connection implements the driver.Connection interface to allow reading and writing wire
// messages and the driver.Expirable interface to allow expiring.
type Connection struct {
	*connection
	refCount      int
	cleanupPoolFn func()

	// cleanupServerFn resets the server state when a connection is returned to the connection pool
	// via Close() or expired via Expire().
	cleanupServerFn func()

	mu sync.RWMutex
}

var _ driver.Connection = (*Connection)(nil)
var _ driver.Expirable = (*Connection)(nil)
var _ driver.PinnedConnection = (*Connection)(nil)

// WriteWireMessage handles writing a wire message to the underlying connection.
func (c *Connection) WriteWireMessage(ctx context.Context, wm []byte) error {
	c.mu.RLock()
	defer c.mu.RUnlock()
	if c.connection == nil {
		return ErrConnectionClosed
	}
	return c.writeWireMessage(ctx, wm)
}

// ReadWireMessage handles reading a wire message from the underlying connection. The dst parameter
// will be overwritten with the new wire message.
func (c *Connection) ReadWireMessage(ctx context.Context, dst []byte) ([]byte, error) {
	c.mu.RLock()
	defer c.mu.RUnlock()
	if c.connection == nil {
		return dst, ErrConnectionClosed
	}
	return c.readWireMessage(ctx, dst)
}

// CompressWireMessage handles compressing the provided wire message using the underlying
// connection's compressor. The dst parameter will be overwritten with the new wire message. If
// there is no compressor set on the underlying connection, then no compression will be performed.
func (c *Connection) CompressWireMessage(src, dst []byte) ([]byte, error) {
	c.mu.RLock()
	defer c.mu.RUnlock()
	if c.connection == nil {
		return dst, ErrConnectionClosed
	}
	if c.connection.compressor == wiremessage.CompressorNoOp {
		return append(dst, src...), nil
	}
	_, reqid, respto, origcode, rem, ok := wiremessage.ReadHeader(src)
	if !ok {
		return dst, errors.New("wiremessage is too short to compress, less than 16 bytes")
	}
	idx, dst := wiremessage.AppendHeaderStart(dst, reqid, respto, wiremessage.OpCompressed)
	dst = wiremessage.AppendCompressedOriginalOpCode(dst, origcode)
	dst = wiremessage.AppendCompressedUncompressedSize(dst, int32(len(rem)))
	dst = wiremessage.AppendCompressedCompressorID(dst, c.connection.compressor)
	opts := driver.CompressionOpts{
		Compressor: c.connection.compressor,
		ZlibLevel:  c.connection.zliblevel,
		ZstdLevel:  c.connection.zstdLevel,
	}
	compressed, err := driver.CompressPayload(rem, opts)
	if err != nil {
		return nil, err
	}
	dst = wiremessage.AppendCompressedCompressedMessage(dst, compressed)
	return bsoncore.UpdateLength(dst, idx, int32(len(dst[idx:]))), nil
}

// Description returns the server description of the server this connection is connected to.
func (c *Connection) Description() description.Server {
	c.mu.RLock()
	defer c.mu.RUnlock()
	if c.connection == nil {
		return description.Server{}
	}
	return c.desc
}

// Close returns this connection to the connection pool. This method may not close the underlying
// socket.
func (c *Connection) Close() error {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.connection == nil || c.refCount > 0 {
		return nil
	}

	return c.cleanupReferences()
}

// Expire closes this connection and will close the underlying socket.
func (c *Connection) Expire() error {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.connection == nil {
		return nil
	}

	_ = c.close()
	return c.cleanupReferences()
}

func (c *Connection) cleanupReferences() error {
	err := c.pool.checkIn(c.connection)
	if c.cleanupPoolFn != nil {
		c.cleanupPoolFn()
		c.cleanupPoolFn = nil
	}
	if c.cleanupServerFn != nil {
		c.cleanupServerFn()
		c.cleanupServerFn = nil
	}
	c.connection = nil
	return err
}

// Alive returns if the connection is still alive.
func (c *Connection) Alive() bool {
	return c.connection != nil
}

// ID returns the ID of this connection.
func (c *Connection) ID() string {
	c.mu.RLock()
	defer c.mu.RUnlock()
	if c.connection == nil {
		return "<closed>"
	}
	return c.id
}

// Stale returns if the connection is stale.
func (c *Connection) Stale() bool {
	c.mu.RLock()
	defer c.mu.RUnlock()
	return c.pool.stale(c.connection)
}

// Address returns the address of this connection.
func (c *Connection) Address() address.Address {
	c.mu.RLock()
	defer c.mu.RUnlock()
	if c.connection == nil {
		return address.Address("0.0.0.0")
	}
	return c.addr
}

// LocalAddress returns the local address of the connection.
func (c *Connection) LocalAddress() address.Address {
	c.mu.RLock()
	defer c.mu.RUnlock()
	if c.connection == nil || c.nc == nil {
		return address.Address("0.0.0.0")
	}
	return address.Address(c.nc.LocalAddr().String())
}

// PinToCursor updates this connection to reflect that it is pinned to a cursor.
func (c *Connection) PinToCursor() error {
	return c.pin("cursor", c.pool.pinConnectionToCursor, c.pool.unpinConnectionFromCursor)
}

// PinToTransaction updates this connection to reflect that it is pinned to a transaction.
func (c *Connection) PinToTransaction() error {
	return c.pin("transaction", c.pool.pinConnectionToTransaction, c.pool.unpinConnectionFromTransaction)
}

func (c *Connection) pin(reason string, updatePoolFn, cleanupPoolFn func()) error {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.connection == nil {
		return fmt.Errorf("attempted to pin a connection for a %s, but the connection has already been returned to the pool", reason)
	}

	// Only use the provided callbacks for the first reference to avoid double-counting pinned connection statistics
	// in the pool.
	if c.refCount == 0 {
		updatePoolFn()
		c.cleanupPoolFn = cleanupPoolFn
	}
	c.refCount++
	return nil
}

// UnpinFromCursor updates this connection to reflect that it is no longer pinned to a cursor.
func (c *Connection) UnpinFromCursor() error {
	return c.unpin("cursor")
}

// UnpinFromTransaction updates this connection to reflect that it is no longer pinned to a transaction.
func (c *Connection) UnpinFromTransaction() error {
	return c.unpin("transaction")
}

func (c *Connection) unpin(reason string) error {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.connection == nil {
		// We don't error here because the resource could have been forcefully closed via Expire.
		return nil
	}
	if c.refCount == 0 {
		return fmt.Errorf("attempted to unpin a connection from a %s, but the connection is not pinned by any resources", reason)
	}

	c.refCount--
	return nil
}

func configureTLS(ctx context.Context,
	tlsConnSource tlsConnectionSource,
	nc net.Conn,
	addr address.Address,
	config *tls.Config,
	ocspOpts *ocsp.VerifyOptions,
) (net.Conn, error) {
	// Ensure config.ServerName is always set for SNI.
	if config.ServerName == "" {
		hostname := addr.String()
		colonPos := strings.LastIndex(hostname, ":")
		if colonPos == -1 {
			colonPos = len(hostname)
		}

		hostname = hostname[:colonPos]
		config.ServerName = hostname
	}

	client := tlsConnSource.Client(nc, config)
	if err := clientHandshake(ctx, client); err != nil {
		return nil, err
	}

	// Only do OCSP verification if TLS verification is requested.
	if !config.InsecureSkipVerify {
		if ocspErr := ocsp.Verify(ctx, client.ConnectionState(), ocspOpts); ocspErr != nil {
			return nil, ocspErr
		}
	}
	return client, nil
}
7 vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/connection_legacy.go (generated, vendored, new file)
@@ -0,0 +1,7 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package topology
214 vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/connection_options.go (generated, vendored, new file)
@@ -0,0 +1,214 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package topology

import (
	"context"
	"crypto/tls"
	"net"
	"net/http"
	"time"

	"go.mongodb.org/mongo-driver/bson/primitive"
	"go.mongodb.org/mongo-driver/event"
	"go.mongodb.org/mongo-driver/internal"
	"go.mongodb.org/mongo-driver/x/mongo/driver"
	"go.mongodb.org/mongo-driver/x/mongo/driver/ocsp"
)

// Dialer is used to make network connections.
type Dialer interface {
	DialContext(ctx context.Context, network, address string) (net.Conn, error)
}

// DialerFunc is a type implemented by functions that can be used as a Dialer.
type DialerFunc func(ctx context.Context, network, address string) (net.Conn, error)

// DialContext implements the Dialer interface.
func (df DialerFunc) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
	return df(ctx, network, address)
}

// DefaultDialer is the Dialer implementation that is used by this package. Changing this
// will also change the Dialer used for this package. This should only be changed when all
// of the connections being made need to use a different Dialer. Most of the time, using a
// WithDialer option is more appropriate than changing this variable.
var DefaultDialer Dialer = &net.Dialer{}
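Because `DialerFunc` adapts any function with the matching signature into a `Dialer`, instrumenting dialing is a small wrapper. A hedged sketch (the logging wrapper below is invented for illustration; it would be plugged in via `DialerFunc` or the `WithDialer` option):

```go
package sketch

import (
	"context"
	"log"
	"net"
)

// loggingDialer wraps a *net.Dialer and logs each dial attempt. The returned
// function has the DialContext signature expected by topology.DialerFunc.
func loggingDialer(base *net.Dialer) func(ctx context.Context, network, address string) (net.Conn, error) {
	return func(ctx context.Context, network, address string) (net.Conn, error) {
		log.Printf("dialing %s %s", network, address)
		return base.DialContext(ctx, network, address)
	}
}
```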
// Handshaker is the interface implemented by types that can perform a MongoDB
// handshake over a provided driver.Connection. This is used during connection
// initialization. Implementations must be goroutine safe.
type Handshaker = driver.Handshaker

// generationNumberFn is a callback type used by a connection to fetch its generation number given its service ID.
type generationNumberFn func(serviceID *primitive.ObjectID) uint64

type connectionConfig struct {
	connectTimeout           time.Duration
	dialer                   Dialer
	handshaker               Handshaker
	idleTimeout              time.Duration
	cmdMonitor               *event.CommandMonitor
	readTimeout              time.Duration
	writeTimeout             time.Duration
	tlsConfig                *tls.Config
	httpClient               *http.Client
	compressors              []string
	zlibLevel                *int
	zstdLevel                *int
	ocspCache                ocsp.Cache
	disableOCSPEndpointCheck bool
	tlsConnectionSource      tlsConnectionSource
	loadBalanced             bool
	getGenerationFn          generationNumberFn
}

func newConnectionConfig(opts ...ConnectionOption) *connectionConfig {
	cfg := &connectionConfig{
		connectTimeout:      30 * time.Second,
		dialer:              nil,
		tlsConnectionSource: defaultTLSConnectionSource,
		httpClient:          internal.DefaultHTTPClient,
	}

	for _, opt := range opts {
		if opt == nil {
			continue
		}
		opt(cfg)
	}

	if cfg.dialer == nil {
		cfg.dialer = &net.Dialer{}
	}

	return cfg
}

// ConnectionOption is used to configure a connection.
type ConnectionOption func(*connectionConfig)

func withTLSConnectionSource(fn func(tlsConnectionSource) tlsConnectionSource) ConnectionOption {
	return func(c *connectionConfig) {
		c.tlsConnectionSource = fn(c.tlsConnectionSource)
	}
}

// WithCompressors sets the compressors that can be used for communication.
func WithCompressors(fn func([]string) []string) ConnectionOption {
	return func(c *connectionConfig) {
		c.compressors = fn(c.compressors)
	}
}

// WithConnectTimeout configures the maximum amount of time a dial will wait for a
// Connect to complete. The default is 30 seconds.
func WithConnectTimeout(fn func(time.Duration) time.Duration) ConnectionOption {
	return func(c *connectionConfig) {
		c.connectTimeout = fn(c.connectTimeout)
	}
}

// WithDialer configures the Dialer to use when making a new connection to MongoDB.
func WithDialer(fn func(Dialer) Dialer) ConnectionOption {
	return func(c *connectionConfig) {
		c.dialer = fn(c.dialer)
	}
}

// WithHandshaker configures the Handshaker that will be used to initialize newly
// dialed connections.
func WithHandshaker(fn func(Handshaker) Handshaker) ConnectionOption {
	return func(c *connectionConfig) {
		c.handshaker = fn(c.handshaker)
	}
}

// WithIdleTimeout configures the maximum idle time to allow for a connection.
func WithIdleTimeout(fn func(time.Duration) time.Duration) ConnectionOption {
	return func(c *connectionConfig) {
		c.idleTimeout = fn(c.idleTimeout)
	}
}

// WithReadTimeout configures the maximum read time for a connection.
func WithReadTimeout(fn func(time.Duration) time.Duration) ConnectionOption {
	return func(c *connectionConfig) {
		c.readTimeout = fn(c.readTimeout)
	}
}

// WithWriteTimeout configures the maximum write time for a connection.
func WithWriteTimeout(fn func(time.Duration) time.Duration) ConnectionOption {
	return func(c *connectionConfig) {
		c.writeTimeout = fn(c.writeTimeout)
	}
}

// WithTLSConfig configures the TLS options for a connection.
func WithTLSConfig(fn func(*tls.Config) *tls.Config) ConnectionOption {
	return func(c *connectionConfig) {
		c.tlsConfig = fn(c.tlsConfig)
	}
}

// WithHTTPClient configures the HTTP client for a connection.
func WithHTTPClient(fn func(*http.Client) *http.Client) ConnectionOption {
	return func(c *connectionConfig) {
		c.httpClient = fn(c.httpClient)
	}
}

// WithMonitor configures an event monitor for command monitoring.
func WithMonitor(fn func(*event.CommandMonitor) *event.CommandMonitor) ConnectionOption {
	return func(c *connectionConfig) {
		c.cmdMonitor = fn(c.cmdMonitor)
	}
}

// WithZlibLevel sets the zLib compression level.
func WithZlibLevel(fn func(*int) *int) ConnectionOption {
	return func(c *connectionConfig) {
		c.zlibLevel = fn(c.zlibLevel)
	}
}

// WithZstdLevel sets the zstd compression level.
func WithZstdLevel(fn func(*int) *int) ConnectionOption {
	return func(c *connectionConfig) {
		c.zstdLevel = fn(c.zstdLevel)
	}
}

// WithOCSPCache specifies a cache to use for OCSP verification.
func WithOCSPCache(fn func(ocsp.Cache) ocsp.Cache) ConnectionOption {
	return func(c *connectionConfig) {
		c.ocspCache = fn(c.ocspCache)
	}
}

// WithDisableOCSPEndpointCheck specifies whether or not the driver should perform non-stapled OCSP verification. If set
// to true, the driver will only check stapled responses and will continue the connection without reaching out to
// OCSP responders.
func WithDisableOCSPEndpointCheck(fn func(bool) bool) ConnectionOption {
	return func(c *connectionConfig) {
		c.disableOCSPEndpointCheck = fn(c.disableOCSPEndpointCheck)
	}
}

// WithConnectionLoadBalanced specifies whether or not the connection is to a server behind a load balancer.
func WithConnectionLoadBalanced(fn func(bool) bool) ConnectionOption {
	return func(c *connectionConfig) {
		c.loadBalanced = fn(c.loadBalanced)
	}
}

func withGenerationNumberFn(fn func(generationNumberFn) generationNumberFn) ConnectionOption {
	return func(c *connectionConfig) {
		c.getGenerationFn = fn(c.getGenerationFn)
	}
}
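Note that each `With*` option takes a transform over the current value rather than the value itself, so options compose and can inspect defaults. A hedged usage sketch (the timeout values below are arbitrary):

```go
package sketch

import (
	"time"

	"go.mongodb.org/mongo-driver/x/mongo/driver/topology"
)

// exampleOptions builds a connection option list. Each callback receives the
// current (default or previously set) value and returns the new one.
func exampleOptions() []topology.ConnectionOption {
	return []topology.ConnectionOption{
		topology.WithConnectTimeout(func(time.Duration) time.Duration {
			return 10 * time.Second // override the 30s default
		}),
		topology.WithIdleTimeout(func(d time.Duration) time.Duration {
			return d + 5*time.Minute // transforms can build on the prior value
		}),
	}
}
```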
73 vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/diff.go (generated, vendored, new file)
@@ -0,0 +1,73 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package topology

import "go.mongodb.org/mongo-driver/mongo/description"

// hostlistDiff is the difference between a topology and a host list.
type hostlistDiff struct {
	Added   []string
	Removed []string
}

// diffHostList compares the topology description and host list and returns the difference.
func diffHostList(t description.Topology, hostlist []string) hostlistDiff {
	var diff hostlistDiff

	oldServers := make(map[string]bool)
	for _, s := range t.Servers {
		oldServers[s.Addr.String()] = true
	}

	for _, addr := range hostlist {
		if oldServers[addr] {
			delete(oldServers, addr)
		} else {
			diff.Added = append(diff.Added, addr)
		}
	}

	for addr := range oldServers {
		diff.Removed = append(diff.Removed, addr)
	}

	return diff
}

// topologyDiff is the difference between two different topology descriptions.
type topologyDiff struct {
	Added   []description.Server
	Removed []description.Server
}

// diffTopology compares the two topology descriptions and returns the difference.
func diffTopology(old, new description.Topology) topologyDiff {
	var diff topologyDiff

	oldServers := make(map[string]bool)
	for _, s := range old.Servers {
		oldServers[s.Addr.String()] = true
	}

	for _, s := range new.Servers {
		addr := s.Addr.String()
		if oldServers[addr] {
			delete(oldServers, addr)
		} else {
			diff.Added = append(diff.Added, s)
		}
	}

	for _, s := range old.Servers {
		addr := s.Addr.String()
		if oldServers[addr] {
			diff.Removed = append(diff.Removed, s)
		}
	}

	return diff
}
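To make the set semantics concrete: servers present in the old description but absent from the new host list land in `Removed`, and vice versa for `Added`. A small sketch (addresses invented; it only compiles inside the `topology` package since the helpers are unexported):

```go
package topology

import (
	"go.mongodb.org/mongo-driver/mongo/address"
	"go.mongodb.org/mongo-driver/mongo/description"
)

// exampleDiff is illustrative only: "b" survives, "c" is new, "a" is gone.
func exampleDiff() hostlistDiff {
	old := description.Topology{Servers: []description.Server{
		{Addr: address.Address("a:27017")}, // fabricated addresses
		{Addr: address.Address("b:27017")},
	}}
	d := diffHostList(old, []string{"b:27017", "c:27017"})
	// d.Added == ["c:27017"], d.Removed == ["a:27017"]
	return d
}
```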
111 vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/errors.go (generated, vendored, new file)
@@ -0,0 +1,111 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package topology

import (
	"context"
	"fmt"

	"go.mongodb.org/mongo-driver/mongo/description"
)

// ConnectionError represents a connection error.
type ConnectionError struct {
	ConnectionID string
	Wrapped      error

	// init will be set to true if this error occurred during connection initialization or
	// during a connection handshake.
	init    bool
	message string
}

// Error implements the error interface.
func (e ConnectionError) Error() string {
	message := e.message
	if e.init {
		fullMsg := "error occurred during connection handshake"
		if message != "" {
			fullMsg = fmt.Sprintf("%s: %s", fullMsg, message)
		}
		message = fullMsg
	}
	if e.Wrapped != nil && message != "" {
		return fmt.Sprintf("connection(%s) %s: %s", e.ConnectionID, message, e.Wrapped.Error())
	}
	if e.Wrapped != nil {
		return fmt.Sprintf("connection(%s) %s", e.ConnectionID, e.Wrapped.Error())
	}
	return fmt.Sprintf("connection(%s) %s", e.ConnectionID, message)
}

// Unwrap returns the underlying error.
func (e ConnectionError) Unwrap() error {
	return e.Wrapped
}

// ServerSelectionError represents a Server Selection error.
type ServerSelectionError struct {
	Desc    description.Topology
	Wrapped error
}

// Error implements the error interface.
func (e ServerSelectionError) Error() string {
	if e.Wrapped != nil {
		return fmt.Sprintf("server selection error: %s, current topology: { %s }", e.Wrapped.Error(), e.Desc.String())
	}
	return fmt.Sprintf("server selection error: current topology: { %s }", e.Desc.String())
}

// Unwrap returns the underlying error.
func (e ServerSelectionError) Unwrap() error {
	return e.Wrapped
}

// WaitQueueTimeoutError represents a timeout when requesting a connection from the pool.
type WaitQueueTimeoutError struct {
	Wrapped                      error
	PinnedCursorConnections      uint64
	PinnedTransactionConnections uint64
	maxPoolSize                  uint64
	totalConnectionCount         int
}

// Error implements the error interface.
func (w WaitQueueTimeoutError) Error() string {
	errorMsg := "timed out while checking out a connection from connection pool"
	switch w.Wrapped {
	case nil:
	case context.Canceled:
		errorMsg = fmt.Sprintf(
			"%s: %s",
			"canceled while checking out a connection from connection pool",
			w.Wrapped.Error(),
		)
	default:
		errorMsg = fmt.Sprintf(
			"%s: %s",
			errorMsg,
			w.Wrapped.Error(),
		)
	}

	return fmt.Sprintf(
		"%s; maxPoolSize: %d, connections in use by cursors: %d"+
			", connections in use by transactions: %d, connections in use by other operations: %d",
		errorMsg,
		w.maxPoolSize,
		w.PinnedCursorConnections,
		w.PinnedTransactionConnections,
		uint64(w.totalConnectionCount)-w.PinnedCursorConnections-w.PinnedTransactionConnections)
}

// Unwrap returns the underlying error.
func (w WaitQueueTimeoutError) Unwrap() error {
	return w.Wrapped
}
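Since all three error types implement `Unwrap`, they cooperate with the standard `errors` package. A hedged sketch of how a caller might classify a failure (the `classify` helper is invented for illustration):

```go
package sketch

import (
	"context"
	"errors"

	"go.mongodb.org/mongo-driver/x/mongo/driver/topology"
)

// classify shows how Unwrap lets errors.As / errors.Is see through the
// topology error wrappers to the connection error or the root cause.
func classify(err error) string {
	var connErr topology.ConnectionError
	switch {
	case errors.As(err, &connErr):
		return "connection error on " + connErr.ConnectionID
	case errors.Is(err, context.DeadlineExceeded):
		return "timed out"
	default:
		return "other"
	}
}
```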
438 vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/fsm.go (generated, vendored, new file)
@@ -0,0 +1,438 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package topology
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/mongo/address"
|
||||
"go.mongodb.org/mongo-driver/mongo/description"
|
||||
)
|
||||
|
||||
var (
|
||||
// MinSupportedMongoDBVersion is the version string for the lowest MongoDB version supported by the driver.
|
||||
MinSupportedMongoDBVersion = "3.6"
|
||||
|
||||
// SupportedWireVersions is the range of wire versions supported by the driver.
|
||||
SupportedWireVersions = description.NewVersionRange(6, 17)
|
||||
)
|
||||
|
||||
type fsm struct {
|
||||
description.Topology
|
||||
maxElectionID primitive.ObjectID
|
||||
maxSetVersion uint32
|
||||
compatible atomic.Value
|
||||
compatibilityErr error
|
||||
}
|
||||
|
||||
func newFSM() *fsm {
|
||||
f := fsm{}
|
||||
f.compatible.Store(true)
|
||||
return &f
|
||||
}
|
||||
|
||||
// apply takes a new server description and modifies the FSM's topology description based on it. It returns the
|
||||
// updated topology description as well as a server description. The returned server description is either the same
|
||||
// one that was passed in, or a new one in the case that it had to be changed.
|
||||
//
|
||||
// apply should operation on immutable descriptions so we don't have to lock for the entire time we're applying the
|
||||
// server description.
|
||||
func (f *fsm) apply(s description.Server) (description.Topology, description.Server) {
|
||||
newServers := make([]description.Server, len(f.Servers))
|
||||
copy(newServers, f.Servers)
|
||||
|
||||
oldMinutes := f.SessionTimeoutMinutes
|
||||
f.Topology = description.Topology{
|
||||
Kind: f.Kind,
|
||||
Servers: newServers,
|
||||
SetName: f.SetName,
|
||||
}
|
||||
|
||||
// For data bearing servers, set SessionTimeoutMinutes to the lowest among them
|
||||
if oldMinutes == 0 {
|
||||
// If timeout currently 0, check all servers to see if any still don't have a timeout
|
||||
// If they all have timeout, pick the lowest.
|
||||
timeout := s.SessionTimeoutMinutes
|
||||
for _, server := range f.Servers {
|
||||
if server.DataBearing() && server.SessionTimeoutMinutes < timeout {
|
||||
timeout = server.SessionTimeoutMinutes
|
||||
}
|
||||
}
|
||||
f.SessionTimeoutMinutes = timeout
|
||||
} else {
|
||||
if s.DataBearing() && oldMinutes > s.SessionTimeoutMinutes {
|
||||
f.SessionTimeoutMinutes = s.SessionTimeoutMinutes
|
||||
} else {
|
||||
f.SessionTimeoutMinutes = oldMinutes
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := f.findServer(s.Addr); !ok {
|
||||
return f.Topology, s
|
||||
}
|
||||
|
||||
updatedDesc := s
|
||||
switch f.Kind {
|
||||
case description.Unknown:
|
||||
updatedDesc = f.applyToUnknown(s)
|
||||
case description.Sharded:
|
||||
updatedDesc = f.applyToSharded(s)
|
||||
case description.ReplicaSetNoPrimary:
|
||||
updatedDesc = f.applyToReplicaSetNoPrimary(s)
|
||||
case description.ReplicaSetWithPrimary:
|
||||
updatedDesc = f.applyToReplicaSetWithPrimary(s)
|
||||
case description.Single:
|
||||
updatedDesc = f.applyToSingle(s)
|
||||
}
|
||||
|
||||
for _, server := range f.Servers {
|
||||
if server.WireVersion != nil {
|
||||
if server.WireVersion.Max < SupportedWireVersions.Min {
|
||||
f.compatible.Store(false)
|
||||
f.compatibilityErr = fmt.Errorf(
|
||||
"server at %s reports wire version %d, but this version of the Go driver requires "+
|
||||
"at least %d (MongoDB %s)",
|
||||
server.Addr.String(),
|
||||
server.WireVersion.Max,
|
||||
SupportedWireVersions.Min,
|
||||
MinSupportedMongoDBVersion,
|
||||
)
|
||||
f.Topology.CompatibilityErr = f.compatibilityErr
|
||||
return f.Topology, s
|
||||
}
|
||||
|
||||
if server.WireVersion.Min > SupportedWireVersions.Max {
|
||||
f.compatible.Store(false)
|
||||
f.compatibilityErr = fmt.Errorf(
|
||||
"server at %s requires wire version %d, but this version of the Go driver only supports up to %d",
|
||||
server.Addr.String(),
|
||||
server.WireVersion.Min,
|
||||
SupportedWireVersions.Max,
|
||||
)
|
||||
f.Topology.CompatibilityErr = f.compatibilityErr
|
||||
return f.Topology, s
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
f.compatible.Store(true)
|
||||
f.compatibilityErr = nil
|
||||
return f.Topology, updatedDesc
|
||||
}
|
||||
|
||||
func (f *fsm) applyToReplicaSetNoPrimary(s description.Server) description.Server {
    switch s.Kind {
    case description.Standalone, description.Mongos:
        f.removeServerByAddr(s.Addr)
    case description.RSPrimary:
        f.updateRSFromPrimary(s)
    case description.RSSecondary, description.RSArbiter, description.RSMember:
        f.updateRSWithoutPrimary(s)
    case description.Unknown, description.RSGhost:
        f.replaceServer(s)
    }

    return s
}

func (f *fsm) applyToReplicaSetWithPrimary(s description.Server) description.Server {
    switch s.Kind {
    case description.Standalone, description.Mongos:
        f.removeServerByAddr(s.Addr)
        f.checkIfHasPrimary()
    case description.RSPrimary:
        f.updateRSFromPrimary(s)
    case description.RSSecondary, description.RSArbiter, description.RSMember:
        f.updateRSWithPrimaryFromMember(s)
    case description.Unknown, description.RSGhost:
        f.replaceServer(s)
        f.checkIfHasPrimary()
    }

    return s
}

func (f *fsm) applyToSharded(s description.Server) description.Server {
    switch s.Kind {
    case description.Mongos, description.Unknown:
        f.replaceServer(s)
    case description.Standalone, description.RSPrimary, description.RSSecondary, description.RSArbiter, description.RSMember, description.RSGhost:
        f.removeServerByAddr(s.Addr)
    }

    return s
}

func (f *fsm) applyToSingle(s description.Server) description.Server {
    switch s.Kind {
    case description.Unknown:
        f.replaceServer(s)
    case description.Standalone, description.Mongos:
        if f.SetName != "" {
            f.removeServerByAddr(s.Addr)
            return s
        }

        f.replaceServer(s)
    case description.RSPrimary, description.RSSecondary, description.RSArbiter, description.RSMember, description.RSGhost:
        // A replica set name can be provided when creating a direct connection. In this case, if the set name returned
        // by the hello response doesn't match up with the one provided during configuration, the server description
        // is replaced with a default Unknown description.
        //
        // We create a new server description rather than doing s.Kind = description.Unknown because the other fields,
        // such as RTT, need to be cleared for Unknown descriptions as well.
        if f.SetName != "" && f.SetName != s.SetName {
            s = description.Server{
                Addr: s.Addr,
                Kind: description.Unknown,
            }
        }

        f.replaceServer(s)
    }

    return s
}

func (f *fsm) applyToUnknown(s description.Server) description.Server {
    switch s.Kind {
    case description.Mongos:
        f.setKind(description.Sharded)
        f.replaceServer(s)
    case description.RSPrimary:
        f.updateRSFromPrimary(s)
    case description.RSSecondary, description.RSArbiter, description.RSMember:
        f.setKind(description.ReplicaSetNoPrimary)
        f.updateRSWithoutPrimary(s)
    case description.Standalone:
        f.updateUnknownWithStandalone(s)
    case description.Unknown, description.RSGhost:
        f.replaceServer(s)
    }

    return s
}

func (f *fsm) checkIfHasPrimary() {
    if _, ok := f.findPrimary(); ok {
        f.setKind(description.ReplicaSetWithPrimary)
    } else {
        f.setKind(description.ReplicaSetNoPrimary)
    }
}

// hasStalePrimary returns true if the topology has a primary that is "stale".
func hasStalePrimary(fsm fsm, srv description.Server) bool {
    // Compare the election ID values of the server and the topology lexicographically.
    compRes := bytes.Compare(srv.ElectionID[:], fsm.maxElectionID[:])

    if wireVersion := srv.WireVersion; wireVersion != nil && wireVersion.Max >= 17 {
        // In the post-6.0 case, a primary is considered "stale" if the server's election ID is less than the
        // topology's max election ID. In these versions, the primary is also considered "stale" if the server's
        // election ID is LTE to the topology's max election ID and the server's "setVersion" is less than the
        // topology's max "setVersion".
        return compRes == -1 || (compRes != 1 && srv.SetVersion < fsm.maxSetVersion)
    }

    // If the server's election ID is less than the topology's max election ID, the primary is considered
    // "stale". Similarly, if the server's "setVersion" is less than the topology's max "setVersion", the
    // primary is considered stale.
    return compRes == -1 || fsm.maxSetVersion > srv.SetVersion
}
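
// Editor's note: an illustrative sketch (not part of the driver) of the
// comparison hasStalePrimary relies on. primitive.ObjectID values begin with a
// timestamp, so a lexicographic bytes.Compare of election IDs orders them
// roughly by creation time. The helper name electionIDOrder is hypothetical.
func electionIDOrder(server, topologyMax primitive.ObjectID) string {
    switch bytes.Compare(server[:], topologyMax[:]) {
    case -1:
        return "server election ID predates the topology max; primary is stale"
    case 0:
        return "election IDs are equal; fall back to comparing setVersion"
    default:
        return "server election ID is newer; it becomes the topology max"
    }
}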

// transferEVTuple will transfer the ("ElectionID", "SetVersion") tuple from the description server to the topology.
// If the primary is stale, the tuple will not be transferred, the topology will update its "Kind" value, and this
// routine will return "false".
func transferEVTuple(srv description.Server, fsm *fsm) bool {
    stalePrimary := hasStalePrimary(*fsm, srv)

    if wireVersion := srv.WireVersion; wireVersion != nil && wireVersion.Max >= 17 {
        if stalePrimary {
            fsm.checkIfHasPrimary()
            return false
        }

        fsm.maxElectionID = srv.ElectionID
        fsm.maxSetVersion = srv.SetVersion

        return true
    }

    if srv.SetVersion != 0 && !srv.ElectionID.IsZero() {
        if stalePrimary {
            fsm.replaceServer(description.Server{
                Addr: srv.Addr,
                LastError: fmt.Errorf(
                    "was a primary, but its set version or election id is stale"),
            })

            fsm.checkIfHasPrimary()

            return false
        }

        fsm.maxElectionID = srv.ElectionID
    }

    if srv.SetVersion > fsm.maxSetVersion {
        fsm.maxSetVersion = srv.SetVersion
    }

    return true
}

func (f *fsm) updateRSFromPrimary(srv description.Server) {
    if f.SetName == "" {
        f.SetName = srv.SetName
    } else if f.SetName != srv.SetName {
        f.removeServerByAddr(srv.Addr)
        f.checkIfHasPrimary()

        return
    }

    if ok := transferEVTuple(srv, f); !ok {
        return
    }

    if j, ok := f.findPrimary(); ok {
        f.setServer(j, description.Server{
            Addr:      f.Servers[j].Addr,
            LastError: fmt.Errorf("was a primary, but a new primary was discovered"),
        })
    }

    f.replaceServer(srv)

    for j := len(f.Servers) - 1; j >= 0; j-- {
        found := false
        for _, member := range srv.Members {
            if member == f.Servers[j].Addr {
                found = true
                break
            }
        }

        if !found {
            f.removeServer(j)
        }
    }

    for _, member := range srv.Members {
        if _, ok := f.findServer(member); !ok {
            f.addServer(member)
        }
    }

    f.checkIfHasPrimary()
}
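
// Editor's note: the two loops above are a simple two-pass reconciliation of
// the topology's server list against the primary's reported membership: first
// remove (iterating backwards so removals don't shift unvisited indices) any
// tracked server the primary no longer lists, then add any listed member that
// isn't tracked yet.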

func (f *fsm) updateRSWithPrimaryFromMember(s description.Server) {
    if f.SetName != s.SetName {
        f.removeServerByAddr(s.Addr)
        f.checkIfHasPrimary()
        return
    }

    if s.Addr != s.CanonicalAddr {
        f.removeServerByAddr(s.Addr)
        f.checkIfHasPrimary()
        return
    }

    f.replaceServer(s)

    if _, ok := f.findPrimary(); !ok {
        f.setKind(description.ReplicaSetNoPrimary)
    }
}

func (f *fsm) updateRSWithoutPrimary(s description.Server) {
    if f.SetName == "" {
        f.SetName = s.SetName
    } else if f.SetName != s.SetName {
        f.removeServerByAddr(s.Addr)
        return
    }

    for _, member := range s.Members {
        if _, ok := f.findServer(member); !ok {
            f.addServer(member)
        }
    }

    if s.Addr != s.CanonicalAddr {
        f.removeServerByAddr(s.Addr)
        return
    }

    f.replaceServer(s)
}

func (f *fsm) updateUnknownWithStandalone(s description.Server) {
    if len(f.Servers) > 1 {
        f.removeServerByAddr(s.Addr)
        return
    }

    f.setKind(description.Single)
    f.replaceServer(s)
}

func (f *fsm) addServer(addr address.Address) {
    f.Servers = append(f.Servers, description.Server{
        Addr: addr.Canonicalize(),
    })
}

func (f *fsm) findPrimary() (int, bool) {
    for i, s := range f.Servers {
        if s.Kind == description.RSPrimary {
            return i, true
        }
    }

    return 0, false
}

func (f *fsm) findServer(addr address.Address) (int, bool) {
    canon := addr.Canonicalize()
    for i, s := range f.Servers {
        if canon == s.Addr {
            return i, true
        }
    }

    return 0, false
}

func (f *fsm) removeServer(i int) {
    f.Servers = append(f.Servers[:i], f.Servers[i+1:]...)
}

func (f *fsm) removeServerByAddr(addr address.Address) {
    if i, ok := f.findServer(addr); ok {
        f.removeServer(i)
    }
}

func (f *fsm) replaceServer(s description.Server) {
    if i, ok := f.findServer(s.Addr); ok {
        f.setServer(i, s)
    }
}

func (f *fsm) setServer(i int, s description.Server) {
    f.Servers[i] = s
}

func (f *fsm) setKind(k description.TopologyKind) {
    f.Kind = k
}
1135
vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/pool.go
generated
vendored
Normal file
File diff suppressed because it is too large
152
vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/pool_generation_counter.go
generated
vendored
Normal file
@@ -0,0 +1,152 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package topology

import (
    "sync"
    "sync/atomic"

    "go.mongodb.org/mongo-driver/bson/primitive"
)

// Pool generation state constants.
const (
    generationDisconnected int64 = iota
    generationConnected
)

// generationStats represents the version of a pool. It tracks the generation number as well as the number of
// connections that have been created in the generation.
type generationStats struct {
    generation uint64
    numConns   uint64
}

// poolGenerationMap tracks the version for each service ID present in a pool. For deployments that are not behind a
// load balancer, there is only one service ID: primitive.NilObjectID. For load-balanced deployments, each server behind
// the load balancer will have a unique service ID.
type poolGenerationMap struct {
    // state must be accessed using the atomic package and should be at the beginning of the struct.
    // - atomic bug: https://pkg.go.dev/sync/atomic#pkg-note-BUG
    // - suggested layout: https://go101.org/article/memory-layout.html
    state         int64
    generationMap map[primitive.ObjectID]*generationStats

    sync.Mutex
}

func newPoolGenerationMap() *poolGenerationMap {
    pgm := &poolGenerationMap{
        generationMap: make(map[primitive.ObjectID]*generationStats),
    }
    pgm.generationMap[primitive.NilObjectID] = &generationStats{}
    return pgm
}

func (p *poolGenerationMap) connect() {
    atomic.StoreInt64(&p.state, generationConnected)
}

func (p *poolGenerationMap) disconnect() {
    atomic.StoreInt64(&p.state, generationDisconnected)
}

// addConnection increments the connection count for the generation associated with the given service ID and returns the
// generation number for the connection.
func (p *poolGenerationMap) addConnection(serviceIDPtr *primitive.ObjectID) uint64 {
    serviceID := getServiceID(serviceIDPtr)
    p.Lock()
    defer p.Unlock()

    stats, ok := p.generationMap[serviceID]
    if ok {
        // If the serviceID is already being tracked, we only need to increment the connection count.
        stats.numConns++
        return stats.generation
    }

    // If the serviceID is untracked, create a new entry with a starting generation number of 0.
    stats = &generationStats{
        numConns: 1,
    }
    p.generationMap[serviceID] = stats
    return 0
}

func (p *poolGenerationMap) removeConnection(serviceIDPtr *primitive.ObjectID) {
    serviceID := getServiceID(serviceIDPtr)
    p.Lock()
    defer p.Unlock()

    stats, ok := p.generationMap[serviceID]
    if !ok {
        return
    }

    // If the serviceID is being tracked, decrement the connection count and delete this serviceID to prevent the map
    // from growing unboundedly. This case would happen if a server behind a load-balancer was permanently removed
    // and its connections were pruned after a network error or idle timeout.
    stats.numConns--
    if stats.numConns == 0 {
        delete(p.generationMap, serviceID)
    }
}

func (p *poolGenerationMap) clear(serviceIDPtr *primitive.ObjectID) {
    serviceID := getServiceID(serviceIDPtr)
    p.Lock()
    defer p.Unlock()

    if stats, ok := p.generationMap[serviceID]; ok {
        stats.generation++
    }
}

func (p *poolGenerationMap) stale(serviceIDPtr *primitive.ObjectID, knownGeneration uint64) bool {
    // If the map has been disconnected, all connections should be considered stale to ensure that they're closed.
    if atomic.LoadInt64(&p.state) == generationDisconnected {
        return true
    }

    serviceID := getServiceID(serviceIDPtr)
    p.Lock()
    defer p.Unlock()

    if stats, ok := p.generationMap[serviceID]; ok {
        return knownGeneration < stats.generation
    }
    return false
}

func (p *poolGenerationMap) getGeneration(serviceIDPtr *primitive.ObjectID) uint64 {
    serviceID := getServiceID(serviceIDPtr)
    p.Lock()
    defer p.Unlock()

    if stats, ok := p.generationMap[serviceID]; ok {
        return stats.generation
    }
    return 0
}

func (p *poolGenerationMap) getNumConns(serviceIDPtr *primitive.ObjectID) uint64 {
    serviceID := getServiceID(serviceIDPtr)
    p.Lock()
    defer p.Unlock()

    if stats, ok := p.generationMap[serviceID]; ok {
        return stats.numConns
    }
    return 0
}

func getServiceID(oid *primitive.ObjectID) primitive.ObjectID {
    if oid == nil {
        return primitive.NilObjectID
    }
    return *oid
}
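
// Editor's note: an illustrative walk-through (not part of the driver) of the
// generation lifecycle. A connection records the generation it was created
// under; clear() bumps the generation, which retroactively marks that
// connection stale.
func examplePoolGenerationLifecycle() bool {
    pgm := newPoolGenerationMap()
    pgm.connect()

    gen := pgm.addConnection(nil) // nil service ID maps to primitive.NilObjectID
    pgm.clear(nil)                // e.g. after a network error; increments the generation

    return pgm.stale(nil, gen) // true: the connection predates the clear
}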
307
vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/rtt_monitor.go
generated
vendored
Normal file
@@ -0,0 +1,307 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package topology

import (
    "context"
    "fmt"
    "math"
    "sync"
    "time"

    "github.com/montanaflynn/stats"
    "go.mongodb.org/mongo-driver/x/mongo/driver"
    "go.mongodb.org/mongo-driver/x/mongo/driver/operation"
)

const (
    rttAlphaValue = 0.2
    minSamples    = 10
    maxSamples    = 500
)

type rttConfig struct {
    // The minimum interval between RTT measurements. The actual interval may be greater if running
    // the operation takes longer than the interval.
    interval time.Duration

    // The timeout applied to running the "hello" operation. If the timeout is reached while running
    // the operation, the RTT sample is discarded. The default is 1 minute.
    timeout time.Duration

    minRTTWindow       time.Duration
    createConnectionFn func() *connection
    createOperationFn  func(driver.Connection) *operation.Hello
}

type rttMonitor struct {
    mu            sync.RWMutex // mu guards samples, offset, minRTT, averageRTT, and averageRTTSet
    samples       []time.Duration
    offset        int
    minRTT        time.Duration
    rtt90         time.Duration
    averageRTT    time.Duration
    averageRTTSet bool

    closeWg  sync.WaitGroup
    cfg      *rttConfig
    ctx      context.Context
    cancelFn context.CancelFunc
}

var _ driver.RTTMonitor = &rttMonitor{}

func newRTTMonitor(cfg *rttConfig) *rttMonitor {
    if cfg.interval <= 0 {
        panic("RTT monitor interval must be greater than 0")
    }

    ctx, cancel := context.WithCancel(context.Background())
    // Determine the number of samples we need to keep to store the minWindow of RTT durations. The
    // number of samples must be between [10, 500].
    numSamples := int(math.Max(minSamples, math.Min(maxSamples, float64((cfg.minRTTWindow)/cfg.interval))))

    return &rttMonitor{
        samples:  make([]time.Duration, numSamples),
        cfg:      cfg,
        ctx:      ctx,
        cancelFn: cancel,
    }
}
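
// Editor's note: as a worked example, NewServer in server.go (later in this
// diff) configures a 5-minute minRTTWindow and passes the heartbeat interval,
// which defaults to 10s (an assumption; it is configurable), yielding
// 5min / 10s = 30 samples, well within the [10, 500] clamp.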

func (r *rttMonitor) connect() {
    r.closeWg.Add(1)
    go r.start()
}

func (r *rttMonitor) disconnect() {
    // Signal for the routine to stop.
    r.cancelFn()
    r.closeWg.Wait()
}

func (r *rttMonitor) start() {
    defer r.closeWg.Done()

    var conn *connection
    defer func() {
        if conn != nil {
            // If the connection exists, we need to wait for it to be connected because
            // conn.connect() and conn.close() cannot be called concurrently. If the connection
            // wasn't successfully opened, its state was set back to disconnected, so calling
            // conn.close() will be a no-op.
            conn.closeConnectContext()
            conn.wait()
            _ = conn.close()
        }
    }()

    ticker := time.NewTicker(r.cfg.interval)
    defer ticker.Stop()

    for {
        conn = r.cfg.createConnectionFn()
        err := conn.connect(r.ctx)

        // Add an RTT sample from the new connection handshake and start a runHellos() loop if we
        // successfully established the new connection. Otherwise, close the connection and try to
        // create another new connection.
        if err == nil {
            r.addSample(conn.helloRTT)
            r.runHellos(conn)
        }

        // Close any connection here because we're either about to try to create another new
        // connection or we're about to exit the loop.
        _ = conn.close()

        // If a connection error happens quickly, always wait for the monitoring interval to try
        // to create a new connection to prevent creating connections too quickly.
        select {
        case <-ticker.C:
        case <-r.ctx.Done():
            return
        }
    }
}

// runHellos runs "hello" operations in a loop using the provided connection, measuring and
// recording the operation durations as RTT samples. If it encounters any errors, it returns.
func (r *rttMonitor) runHellos(conn *connection) {
    ticker := time.NewTicker(r.cfg.interval)
    defer ticker.Stop()

    for {
        // Assume that the connection establishment recorded the first RTT sample, so wait for the
        // first tick before trying to record another RTT sample.
        select {
        case <-ticker.C:
        case <-r.ctx.Done():
            return
        }

        // Create a Context with the operation timeout specified in the RTT monitor config. If a
        // timeout is not set in the RTT monitor config, default to the connection's
        // "connectTimeoutMS". The purpose of the timeout is to allow the RTT monitor to continue
        // monitoring server RTTs after an operation gets stuck. An operation can get stuck if the
        // server or a proxy stops responding to requests on the RTT connection but does not close
        // the TCP socket, effectively creating an operation that will never complete. We expect
        // that "connectTimeoutMS" provides at least enough time for a single round trip.
        timeout := r.cfg.timeout
        if timeout <= 0 {
            timeout = conn.config.connectTimeout
        }
        ctx, cancel := context.WithTimeout(r.ctx, timeout)

        start := time.Now()
        err := r.cfg.createOperationFn(initConnection{conn}).Execute(ctx)
        cancel()
        if err != nil {
            return
        }
        // Only record a sample if the "hello" operation was successful. If it was not successful,
        // the operation may not have actually performed a complete round trip, so the duration may
        // be artificially short.
        r.addSample(time.Since(start))
    }
}

// reset sets the average and min RTT to 0. This should only be called from the server monitor when an error
// occurs during a server check. Errors in the RTT monitor should not reset the RTTs.
func (r *rttMonitor) reset() {
    r.mu.Lock()
    defer r.mu.Unlock()

    for i := range r.samples {
        r.samples[i] = 0
    }
    r.offset = 0
    r.minRTT = 0
    r.rtt90 = 0
    r.averageRTT = 0
    r.averageRTTSet = false
}

func (r *rttMonitor) addSample(rtt time.Duration) {
    // Lock for the duration of this method. We're doing computationally inexpensive work very infrequently, so lock
    // contention isn't expected.
    r.mu.Lock()
    defer r.mu.Unlock()

    r.samples[r.offset] = rtt
    r.offset = (r.offset + 1) % len(r.samples)
    // Set the minRTT and 90th percentile RTT of all collected samples. Require at least 10 samples before
    // setting these to prevent noisy samples on startup from artificially increasing RTT and to allow the
    // calculation of a 90th percentile.
    r.minRTT = min(r.samples, minSamples)
    r.rtt90 = percentile(90.0, r.samples, minSamples)

    if !r.averageRTTSet {
        r.averageRTT = rtt
        r.averageRTTSet = true
        return
    }

    r.averageRTT = time.Duration(rttAlphaValue*float64(rtt) + (1-rttAlphaValue)*float64(r.averageRTT))
}
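
// Editor's note: the update above is a standard exponentially weighted moving
// average, newAvg = alpha*sample + (1-alpha)*oldAvg with alpha = 0.2, so each
// new sample contributes 20% and older samples decay geometrically. A minimal
// standalone sketch (illustrative only, not part of the driver):
func ewma(prev, sample time.Duration, alpha float64) time.Duration {
    return time.Duration(alpha*float64(sample) + (1-alpha)*float64(prev))
}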

// min returns the minimum value of the slice of duration samples. Zero values are not considered
// samples and are ignored. If no samples or fewer than minSamples are found in the slice, min
// returns 0.
func min(samples []time.Duration, minSamples int) time.Duration {
    count := 0
    min := time.Duration(math.MaxInt64)
    for _, d := range samples {
        if d > 0 {
            count++
        }
        if d > 0 && d < min {
            min = d
        }
    }
    if count == 0 || count < minSamples {
        return 0
    }
    return min
}
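
// Editor's note: zero durations are skipped above (and in percentile below)
// because r.samples is a fixed-size ring buffer that starts zero-filled and is
// zeroed again by reset(), so a zero slot means "no sample recorded", not an
// observed zero RTT.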

// percentile returns the specified percentile value of the slice of duration samples. Zero values
// are not considered samples and are ignored. If no samples or fewer than minSamples are found
// in the slice, percentile returns 0.
func percentile(perc float64, samples []time.Duration, minSamples int) time.Duration {
    // Convert Durations to float64s.
    floatSamples := make([]float64, 0, len(samples))
    for _, sample := range samples {
        if sample > 0 {
            floatSamples = append(floatSamples, float64(sample))
        }
    }
    if len(floatSamples) == 0 || len(floatSamples) < minSamples {
        return 0
    }

    p, err := stats.Percentile(floatSamples, perc)
    if err != nil {
        panic(fmt.Errorf("x/mongo/driver/topology: error calculating %f percentile RTT: %v for samples:\n%v", perc, err, floatSamples))
    }
    return time.Duration(p)
}

// EWMA returns the exponentially weighted moving average observed round-trip time.
func (r *rttMonitor) EWMA() time.Duration {
    r.mu.RLock()
    defer r.mu.RUnlock()

    return r.averageRTT
}

// Min returns the minimum observed round-trip time over the window period.
func (r *rttMonitor) Min() time.Duration {
    r.mu.RLock()
    defer r.mu.RUnlock()

    return r.minRTT
}

// P90 returns the 90th percentile observed round-trip time over the window period.
func (r *rttMonitor) P90() time.Duration {
    r.mu.RLock()
    defer r.mu.RUnlock()

    return r.rtt90
}

// Stats returns stringified stats of the current state of the monitor.
func (r *rttMonitor) Stats() string {
    r.mu.RLock()
    defer r.mu.RUnlock()

    // Calculate standard deviation and average (non-EWMA) of samples.
    var sum float64
    floatSamples := make([]float64, 0, len(r.samples))
    for _, sample := range r.samples {
        if sample > 0 {
            floatSamples = append(floatSamples, float64(sample))
            sum += float64(sample)
        }
    }

    var avg, stdDev float64
    if len(floatSamples) > 0 {
        avg = sum / float64(len(floatSamples))

        var err error
        stdDev, err = stats.StandardDeviation(floatSamples)
        if err != nil {
            panic(fmt.Errorf("x/mongo/driver/topology: error calculating standard deviation RTT: %v for samples:\n%v", err, floatSamples))
        }
    }

    return fmt.Sprintf(`Round-trip-time monitor statistics:`+"\n"+
        `average RTT: %v, minimum RTT: %v, 90th percentile RTT: %v, standard dev: %v`+"\n",
        time.Duration(avg), r.minRTT, r.rtt90, time.Duration(stdDev))
}
957
vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/server.go
generated
vendored
Normal file
@@ -0,0 +1,957 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package topology

import (
    "context"
    "errors"
    "fmt"
    "net"
    "sync"
    "sync/atomic"
    "time"

    "go.mongodb.org/mongo-driver/bson/primitive"
    "go.mongodb.org/mongo-driver/event"
    "go.mongodb.org/mongo-driver/mongo/address"
    "go.mongodb.org/mongo-driver/mongo/description"
    "go.mongodb.org/mongo-driver/x/mongo/driver"
    "go.mongodb.org/mongo-driver/x/mongo/driver/operation"
)

const minHeartbeatInterval = 500 * time.Millisecond

// Server state constants.
const (
    serverDisconnected int64 = iota
    serverDisconnecting
    serverConnected
)

func serverStateString(state int64) string {
    switch state {
    case serverDisconnected:
        return "Disconnected"
    case serverDisconnecting:
        return "Disconnecting"
    case serverConnected:
        return "Connected"
    }

    return ""
}

var (
    // ErrServerClosed occurs when an attempt to Get a connection is made after
    // the server has been closed.
    ErrServerClosed = errors.New("server is closed")
    // ErrServerConnected occurs when an attempt to Connect is made after a server
    // has already been connected.
    ErrServerConnected = errors.New("server is connected")

    errCheckCancelled = errors.New("server check cancelled")
    emptyDescription  = description.NewDefaultServer("")
)

// SelectedServer represents a specific server that was selected during server selection.
// It contains the kind of the topology it was selected from.
type SelectedServer struct {
    *Server

    Kind description.TopologyKind
}

// Description returns a description of the server as of the last heartbeat.
func (ss *SelectedServer) Description() description.SelectedServer {
    sdesc := ss.Server.Description()
    return description.SelectedServer{
        Server: sdesc,
        Kind:   ss.Kind,
    }
}

// Server is a single server within a topology.
type Server struct {
    // The following integer fields must be accessed using the atomic package and should be at the
    // beginning of the struct.
    // - atomic bug: https://pkg.go.dev/sync/atomic#pkg-note-BUG
    // - suggested layout: https://go101.org/article/memory-layout.html

    state          int64
    operationCount int64

    cfg     *serverConfig
    address address.Address

    // connection related fields
    pool *pool

    // goroutine management fields
    done          chan struct{}
    checkNow      chan struct{}
    disconnecting chan struct{}
    closewg       sync.WaitGroup

    // description related fields
    desc                   atomic.Value // holds a description.Server
    updateTopologyCallback atomic.Value
    topologyID             primitive.ObjectID

    // subscriber related fields
    subLock             sync.Mutex
    subscribers         map[uint64]chan description.Server
    currentSubscriberID uint64
    subscriptionsClosed bool

    // heartbeat and cancellation related fields
    // globalCtx should be created in NewServer and cancelled in Disconnect to signal that the server is shutting down.
    // heartbeatCtx should be used for individual heartbeats and should be a child of globalCtx so that it will be
    // cancelled automatically during shutdown.
    heartbeatLock      sync.Mutex
    conn               *connection
    globalCtx          context.Context
    globalCtxCancel    context.CancelFunc
    heartbeatCtx       context.Context
    heartbeatCtxCancel context.CancelFunc

    processErrorLock sync.Mutex
    rttMonitor       *rttMonitor
}

// updateTopologyCallback is a callback used to create a server that should be called when the parent Topology instance
// should be updated based on a new server description. The callback must return the server description that should be
// stored by the server.
type updateTopologyCallback func(description.Server) description.Server

// ConnectServer creates a new Server and then initializes it using the
// Connect method.
func ConnectServer(addr address.Address, updateCallback updateTopologyCallback, topologyID primitive.ObjectID, opts ...ServerOption) (*Server, error) {
    srvr := NewServer(addr, topologyID, opts...)
    err := srvr.Connect(updateCallback)
    if err != nil {
        return nil, err
    }
    return srvr, nil
}

// NewServer creates a new server. The mongodb server at the address will be monitored
// by an internal monitoring goroutine.
func NewServer(addr address.Address, topologyID primitive.ObjectID, opts ...ServerOption) *Server {
    cfg := newServerConfig(opts...)
    globalCtx, globalCtxCancel := context.WithCancel(context.Background())
    s := &Server{
        state: serverDisconnected,

        cfg:     cfg,
        address: addr,

        done:          make(chan struct{}),
        checkNow:      make(chan struct{}, 1),
        disconnecting: make(chan struct{}),

        topologyID: topologyID,

        subscribers:     make(map[uint64]chan description.Server),
        globalCtx:       globalCtx,
        globalCtxCancel: globalCtxCancel,
    }
    s.desc.Store(description.NewDefaultServer(addr))
    rttCfg := &rttConfig{
        interval:           cfg.heartbeatInterval,
        minRTTWindow:       5 * time.Minute,
        createConnectionFn: s.createConnection,
        createOperationFn:  s.createBaseOperation,
    }
    s.rttMonitor = newRTTMonitor(rttCfg)

    pc := poolConfig{
        Address:          addr,
        MinPoolSize:      cfg.minConns,
        MaxPoolSize:      cfg.maxConns,
        MaxConnecting:    cfg.maxConnecting,
        MaxIdleTime:      cfg.poolMaxIdleTime,
        MaintainInterval: cfg.poolMaintainInterval,
        PoolMonitor:      cfg.poolMonitor,
        handshakeErrFn:   s.ProcessHandshakeError,
    }

    connectionOpts := copyConnectionOpts(cfg.connectionOpts)
    s.pool = newPool(pc, connectionOpts...)
    s.publishServerOpeningEvent(s.address)

    return s
}

// Connect initializes the Server by starting background monitoring goroutines.
// This method must be called before a Server can be used.
func (s *Server) Connect(updateCallback updateTopologyCallback) error {
    if !atomic.CompareAndSwapInt64(&s.state, serverDisconnected, serverConnected) {
        return ErrServerConnected
    }

    desc := description.NewDefaultServer(s.address)
    if s.cfg.loadBalanced {
        // LBs automatically start off with kind LoadBalancer because there is no monitoring routine for state changes.
        desc.Kind = description.LoadBalancer
    }
    s.desc.Store(desc)
    s.updateTopologyCallback.Store(updateCallback)

    if !s.cfg.monitoringDisabled && !s.cfg.loadBalanced {
        s.rttMonitor.connect()
        s.closewg.Add(1)
        go s.update()
    }

    // The CMAP spec describes that pools should only be marked "ready" when the server description
    // is updated to something other than "Unknown". However, we maintain the previous Server
    // behavior here and immediately mark the pool as ready during Connect() to simplify and speed
    // up the Client startup behavior. The risk of marking a pool as ready proactively during
    // Connect() is that we could attempt to create connections to a server that was configured
    // erroneously until the first server check or checkOut() failure occurs, when the SDAM error
    // handler would transition the Server back to "Unknown" and set the pool to "paused".
    return s.pool.ready()
}

// Disconnect closes sockets to the server referenced by this Server.
// Subscriptions to this Server will be closed. Disconnect will shutdown
// any monitoring goroutines, close the idle connection pool, and will
// wait until all the in use connections have been returned to the connection
// pool and are closed before returning. If the context expires via
// cancellation, deadline, or timeout before the in use connections have been
// returned, the in use connections will be closed, resulting in the failure of
// any in flight read or write operations. If this method returns with no
// errors, all connections associated with this Server have been closed.
func (s *Server) Disconnect(ctx context.Context) error {
    if !atomic.CompareAndSwapInt64(&s.state, serverConnected, serverDisconnecting) {
        return ErrServerClosed
    }

    s.updateTopologyCallback.Store((updateTopologyCallback)(nil))

    // Cancel the global context so any new contexts created from it will be automatically cancelled. Close the done
    // channel so the update() routine will know that it can stop. Cancel any in-progress monitoring checks at the end.
    // The done channel is closed before cancelling the check so the update() routine will immediately detect that it
    // can stop rather than trying to create new connections until the read from done succeeds.
    s.globalCtxCancel()
    close(s.done)
    s.cancelCheck()

    s.rttMonitor.disconnect()
    s.pool.close(ctx)

    s.closewg.Wait()
    atomic.StoreInt64(&s.state, serverDisconnected)

    return nil
}

// Connection gets a connection to the server.
func (s *Server) Connection(ctx context.Context) (driver.Connection, error) {
    if atomic.LoadInt64(&s.state) != serverConnected {
        return nil, ErrServerClosed
    }

    // Increment the operation count before calling checkOut to make sure that all connection
    // requests are included in the operation count, including those in the wait queue. If we got an
    // error instead of a connection, immediately decrement the operation count.
    atomic.AddInt64(&s.operationCount, 1)
    conn, err := s.pool.checkOut(ctx)
    if err != nil {
        atomic.AddInt64(&s.operationCount, -1)
        return nil, err
    }

    return &Connection{
        connection: conn,
        cleanupServerFn: func() {
            // Decrement the operation count whenever the caller is done with the connection. Note
            // that cleanupServerFn() is not called while the connection is pinned to a cursor or
            // transaction, so the operation count is not decremented until the cursor is closed or
            // the transaction is committed or aborted. Use an int64 instead of a uint64 to mitigate
            // the impact of any possible bugs that could cause the uint64 to underflow, which would
            // make the server much less selectable.
            atomic.AddInt64(&s.operationCount, -1)
        },
    }, nil
}
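
// Editor's note: because operationCount is incremented before checkOut and
// decremented by cleanupServerFn, it counts queued checkout requests as well
// as in-flight operations. Presumably the Topology's server selection consults
// this count (via OperationCount below) to prefer less-loaded servers; that
// caller is not visible in this file, so treat it as an assumption.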

// ProcessHandshakeError implements SDAM error handling for errors that occur before a connection
// finishes handshaking.
func (s *Server) ProcessHandshakeError(err error, startingGenerationNumber uint64, serviceID *primitive.ObjectID) {
    // Ignore the error if the server is behind a load balancer but the service ID is unknown. This indicates that the
    // error happened when dialing the connection or during the MongoDB handshake, so we don't know the service ID to
    // use for clearing the pool.
    if err == nil || s.cfg.loadBalanced && serviceID == nil {
        return
    }
    // Ignore the error if the connection is stale.
    if startingGenerationNumber < s.pool.generation.getGeneration(serviceID) {
        return
    }

    wrappedConnErr := unwrapConnectionError(err)
    if wrappedConnErr == nil {
        return
    }

    // Must hold the processErrorLock while updating the server description and clearing the pool.
    // Not holding the lock leads to possible out-of-order processing of pool.clear() and
    // pool.ready() calls from concurrent server description updates.
    s.processErrorLock.Lock()
    defer s.processErrorLock.Unlock()

    // Since the only kind of ConnectionError we receive from pool.Get will be an initialization error, we should set
    // the description.Server appropriately. The description should not have a TopologyVersion because the staleness
    // checking logic above has already determined that this description is not stale.
    s.updateDescription(description.NewServerFromError(s.address, wrappedConnErr, nil))
    s.pool.clear(err, serviceID)
    s.cancelCheck()
}

// Description returns a description of the server as of the last heartbeat.
func (s *Server) Description() description.Server {
    return s.desc.Load().(description.Server)
}

// SelectedDescription returns a description.SelectedServer with a Kind of
// Single. This can be used when performing tasks like monitoring a batch
// of servers and you want to run one off commands against those servers.
func (s *Server) SelectedDescription() description.SelectedServer {
    sdesc := s.Description()
    return description.SelectedServer{
        Server: sdesc,
        Kind:   description.Single,
    }
}

// Subscribe returns a ServerSubscription which has a channel on which all
// updated server descriptions will be sent. The channel will have a buffer
// size of one, and will be pre-populated with the current description.
func (s *Server) Subscribe() (*ServerSubscription, error) {
    if atomic.LoadInt64(&s.state) != serverConnected {
        return nil, ErrSubscribeAfterClosed
    }
    ch := make(chan description.Server, 1)
    ch <- s.desc.Load().(description.Server)

    s.subLock.Lock()
    defer s.subLock.Unlock()
    if s.subscriptionsClosed {
        return nil, ErrSubscribeAfterClosed
    }
    id := s.currentSubscriberID
    s.subscribers[id] = ch
    s.currentSubscriberID++

    ss := &ServerSubscription{
        C:  ch,
        s:  s,
        id: id,
    }

    return ss, nil
}

// RequestImmediateCheck will cause the server to send a heartbeat immediately
// instead of waiting for the heartbeat timeout.
func (s *Server) RequestImmediateCheck() {
    select {
    case s.checkNow <- struct{}{}:
    default:
    }
}

// getWriteConcernErrorForProcessing extracts a driver.WriteConcernError from the provided error. This function returns
// (error, true) if the error is a WriteConcernError and it falls under the requirements for SDAM error
// handling and (nil, false) otherwise.
func getWriteConcernErrorForProcessing(err error) (*driver.WriteConcernError, bool) {
    writeCmdErr, ok := err.(driver.WriteCommandError)
    if !ok {
        return nil, false
    }

    wcerr := writeCmdErr.WriteConcernError
    if wcerr != nil && (wcerr.NodeIsRecovering() || wcerr.NotPrimary()) {
        return wcerr, true
    }
    return nil, false
}

// ProcessError handles SDAM error handling and implements driver.ErrorProcessor.
func (s *Server) ProcessError(err error, conn driver.Connection) driver.ProcessErrorResult {
    // ignore nil error
    if err == nil {
        return driver.NoChange
    }

    // Must hold the processErrorLock while updating the server description and clearing the pool.
    // Not holding the lock leads to possible out-of-order processing of pool.clear() and
    // pool.ready() calls from concurrent server description updates.
    s.processErrorLock.Lock()
    defer s.processErrorLock.Unlock()

    // ignore stale error
    if conn.Stale() {
        return driver.NoChange
    }
    // Invalidate server description if not primary or node recovering error occurs.
    // These errors can be reported as a command error or a write concern error.
    desc := conn.Description()
    if cerr, ok := err.(driver.Error); ok && (cerr.NodeIsRecovering() || cerr.NotPrimary()) {
        // ignore stale error
        if desc.TopologyVersion.CompareToIncoming(cerr.TopologyVersion) >= 0 {
            return driver.NoChange
        }

        // updates description to unknown
        s.updateDescription(description.NewServerFromError(s.address, err, cerr.TopologyVersion))
        s.RequestImmediateCheck()

        res := driver.ServerMarkedUnknown
        // If the node is shutting down or is older than 4.2, we synchronously clear the pool
        if cerr.NodeIsShuttingDown() || desc.WireVersion == nil || desc.WireVersion.Max < 8 {
            res = driver.ConnectionPoolCleared
            s.pool.clear(err, desc.ServiceID)
        }

        return res
    }
    if wcerr, ok := getWriteConcernErrorForProcessing(err); ok {
        // ignore stale error
        if desc.TopologyVersion.CompareToIncoming(wcerr.TopologyVersion) >= 0 {
            return driver.NoChange
        }

        // updates description to unknown
        s.updateDescription(description.NewServerFromError(s.address, err, wcerr.TopologyVersion))
        s.RequestImmediateCheck()

        res := driver.ServerMarkedUnknown
        // If the node is shutting down or is older than 4.2, we synchronously clear the pool
        if wcerr.NodeIsShuttingDown() || desc.WireVersion == nil || desc.WireVersion.Max < 8 {
            res = driver.ConnectionPoolCleared
            s.pool.clear(err, desc.ServiceID)
        }
        return res
    }

    wrappedConnErr := unwrapConnectionError(err)
    if wrappedConnErr == nil {
        return driver.NoChange
    }

    // Ignore transient timeout errors.
    if netErr, ok := wrappedConnErr.(net.Error); ok && netErr.Timeout() {
        return driver.NoChange
    }
    if wrappedConnErr == context.Canceled || wrappedConnErr == context.DeadlineExceeded {
        return driver.NoChange
    }

    // For a non-timeout network error, we clear the pool, set the description to Unknown, and cancel the in-progress
    // monitoring check. The check is cancelled last to avoid a post-cancellation reconnect racing with
    // updateDescription.
    s.updateDescription(description.NewServerFromError(s.address, err, nil))
    s.pool.clear(err, desc.ServiceID)
    s.cancelCheck()
    return driver.ConnectionPoolCleared
}
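
// Editor's note: summarizing the SDAM decision tree above: "not primary" /
// "node is recovering" command or write-concern errors mark the server Unknown
// and request an immediate check, clearing the pool only when the node is
// shutting down or the server is pre-4.2 (wire version < 8); non-timeout
// network errors always clear the pool and cancel the in-progress check; and
// timeouts or context cancellation produce no state change.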

// update handles performing heartbeats and updating any subscribers of the
// newest description.Server retrieved.
func (s *Server) update() {
    defer s.closewg.Done()
    heartbeatTicker := time.NewTicker(s.cfg.heartbeatInterval)
    rateLimiter := time.NewTicker(minHeartbeatInterval)
    defer heartbeatTicker.Stop()
    defer rateLimiter.Stop()
    checkNow := s.checkNow
    done := s.done

    var doneOnce bool
    defer func() {
        if r := recover(); r != nil {
            if doneOnce {
                return
            }
            // We keep this goroutine alive attempting to read from the done channel.
            <-done
        }
    }()

    closeServer := func() {
        doneOnce = true
        s.subLock.Lock()
        for id, c := range s.subscribers {
            close(c)
            delete(s.subscribers, id)
        }
        s.subscriptionsClosed = true
        s.subLock.Unlock()

        // We don't need to take s.heartbeatLock here because closeServer is called synchronously when the select checks
        // below detect that the server is being closed, so we can be sure that the connection isn't being used.
        if s.conn != nil {
            _ = s.conn.close()
        }
    }

    waitUntilNextCheck := func() {
        // Wait until heartbeatFrequency elapses, an application operation requests an immediate check, or the server
        // is disconnecting.
        select {
        case <-heartbeatTicker.C:
        case <-checkNow:
        case <-done:
            // Return because the next update iteration will check the done channel again and clean up.
            return
        }

        // Ensure we only return if minHeartbeatFrequency has elapsed or the server is disconnecting.
        select {
        case <-rateLimiter.C:
        case <-done:
            return
        }
    }

    for {
        // Check if the server is disconnecting. Even if waitUntilNextCheck has already read from the done channel, we
        // can safely read from it again because Disconnect closes the channel.
        select {
        case <-done:
            closeServer()
            return
        default:
        }

        previousDescription := s.Description()

        // Perform the next check.
        desc, err := s.check()
        if err == errCheckCancelled {
            if atomic.LoadInt64(&s.state) != serverConnected {
                continue
            }

            // If the server is not disconnecting, the check was cancelled by an application operation after an error.
            // Wait before running the next check.
            waitUntilNextCheck()
            continue
        }

        // Must hold the processErrorLock while updating the server description and clearing the
        // pool. Not holding the lock leads to possible out-of-order processing of pool.clear() and
        // pool.ready() calls from concurrent server description updates.
        s.processErrorLock.Lock()
        s.updateDescription(desc)
        if err := desc.LastError; err != nil {
            // Clear the pool once the description has been updated to Unknown. Pass in a nil service ID to clear
            // because the monitoring routine only runs for non-load balanced deployments in which servers don't return
            // IDs.
            s.pool.clear(err, nil)
        }
        s.processErrorLock.Unlock()

        // If the server supports streaming or we're already streaming, we want to move to streaming the next response
        // without waiting. If the server has transitioned to Unknown from a network error, we want to do another
        // check without waiting in case it was a transient error and the server isn't actually down.
        serverSupportsStreaming := desc.Kind != description.Unknown && desc.TopologyVersion != nil
        connectionIsStreaming := s.conn != nil && s.conn.getCurrentlyStreaming()
        transitionedFromNetworkError := desc.LastError != nil && unwrapConnectionError(desc.LastError) != nil &&
            previousDescription.Kind != description.Unknown

        if serverSupportsStreaming || connectionIsStreaming || transitionedFromNetworkError {
            continue
        }

        // The server either does not support the streamable protocol or is not in a healthy state, so we wait until
        // the next check.
        waitUntilNextCheck()
    }
}

// updateDescription handles updating the description on the Server, notifying
// subscribers, and potentially draining the connection pool.
func (s *Server) updateDescription(desc description.Server) {
    if s.cfg.loadBalanced {
        // In load balanced mode, there are no updates from the monitoring routine. For errors encountered in pooled
        // connections, the server should not be marked Unknown to ensure that the LB remains selectable.
        return
    }

    defer func() {
        // ¯\_(ツ)_/¯
        _ = recover()
    }()

    // Anytime we update the server description to something other than "unknown", set the pool to
    // "ready". Do this before updating the description so that connections can be checked out as
    // soon as the server is selectable. If the pool is already ready, this operation is a no-op.
    // Note that this behavior is roughly consistent with the current Go driver behavior (connects
    // to all servers, even non-data-bearing nodes) but deviates slightly from CMAP spec, which
    // specifies a more restricted set of server descriptions and topologies that should mark the
    // pool ready. We don't have access to the topology here, so prefer the current Go driver
    // behavior for simplicity.
    if desc.Kind != description.Unknown {
        _ = s.pool.ready()
    }

    // Use the updateTopologyCallback to update the parent Topology and get the description that should be stored.
    callback, ok := s.updateTopologyCallback.Load().(updateTopologyCallback)
    if ok && callback != nil {
        desc = callback(desc)
    }
    s.desc.Store(desc)

    s.subLock.Lock()
    for _, c := range s.subscribers {
        select {
        // drain the channel if it isn't empty
        case <-c:
        default:
        }
        c <- desc
    }
    s.subLock.Unlock()
}

// createConnection creates a new connection instance but does not call connect on it. The caller must call connect
// before the connection can be used for network operations.
func (s *Server) createConnection() *connection {
    opts := copyConnectionOpts(s.cfg.connectionOpts)
    opts = append(opts,
        WithConnectTimeout(func(time.Duration) time.Duration { return s.cfg.heartbeatTimeout }),
        WithReadTimeout(func(time.Duration) time.Duration { return s.cfg.heartbeatTimeout }),
        WithWriteTimeout(func(time.Duration) time.Duration { return s.cfg.heartbeatTimeout }),
        // We override whatever handshaker is currently attached to the options with a basic
        // one because we need to make sure we don't do auth.
        WithHandshaker(func(h Handshaker) Handshaker {
            return operation.NewHello().AppName(s.cfg.appname).Compressors(s.cfg.compressionOpts).
                ServerAPI(s.cfg.serverAPI)
        }),
        // Override any monitors specified in options with nil to avoid monitoring heartbeats.
        WithMonitor(func(*event.CommandMonitor) *event.CommandMonitor { return nil }),
    )

    return newConnection(s.address, opts...)
}

func copyConnectionOpts(opts []ConnectionOption) []ConnectionOption {
    optsCopy := make([]ConnectionOption, len(opts))
    copy(optsCopy, opts)
    return optsCopy
}

func (s *Server) setupHeartbeatConnection() error {
    conn := s.createConnection()

    // Take the lock when assigning the context and connection because they're accessed by cancelCheck.
    s.heartbeatLock.Lock()
    if s.heartbeatCtxCancel != nil {
        // Ensure the previous context is cancelled to avoid a leak.
        s.heartbeatCtxCancel()
    }
    s.heartbeatCtx, s.heartbeatCtxCancel = context.WithCancel(s.globalCtx)
    s.conn = conn
    s.heartbeatLock.Unlock()

    return s.conn.connect(s.heartbeatCtx)
}

// cancelCheck cancels in-progress connection dials and reads. It does not set any fields on the server.
func (s *Server) cancelCheck() {
    var conn *connection

    // Take heartbeatLock for mutual exclusion with the checks in the update function.
    s.heartbeatLock.Lock()
    if s.heartbeatCtx != nil {
        s.heartbeatCtxCancel()
    }
    conn = s.conn
    s.heartbeatLock.Unlock()

    if conn == nil {
        return
    }

    // If the connection exists, we need to wait for it to be connected because conn.connect() and
    // conn.close() cannot be called concurrently. If the connection wasn't successfully opened, its
    // state was set back to disconnected, so calling conn.close() will be a no-op.
    conn.closeConnectContext()
    conn.wait()
    _ = conn.close()
}

func (s *Server) checkWasCancelled() bool {
    return s.heartbeatCtx.Err() != nil
}

func (s *Server) createBaseOperation(conn driver.Connection) *operation.Hello {
    return operation.
        NewHello().
        ClusterClock(s.cfg.clock).
        Deployment(driver.SingleConnectionDeployment{conn}).
        ServerAPI(s.cfg.serverAPI)
}
|
||||
|
||||
func (s *Server) check() (description.Server, error) {
    var descPtr *description.Server
    var err error
    var durationNanos int64

    // Create a new connection if this is the first check, the connection was closed after an error during the previous
    // check, or the previous check was cancelled.
    if s.conn == nil || s.conn.closed() || s.checkWasCancelled() {
        // Create a new connection and add its handshake RTT as a sample.
        err = s.setupHeartbeatConnection()
        if err == nil {
            // Use the description from the connection handshake as the value for this check.
            s.rttMonitor.addSample(s.conn.helloRTT)
            descPtr = &s.conn.desc
        }
    }

    if descPtr == nil && err == nil {
        // An existing connection is being used. Use the server description properties to execute the right heartbeat.

        // Wrap conn in a type that implements driver.StreamerConnection.
        heartbeatConn := initConnection{s.conn}
        baseOperation := s.createBaseOperation(heartbeatConn)
        previousDescription := s.Description()
        streamable := previousDescription.TopologyVersion != nil

        s.publishServerHeartbeatStartedEvent(s.conn.ID(), s.conn.getCurrentlyStreaming() || streamable)
        start := time.Now()
        switch {
        case s.conn.getCurrentlyStreaming():
            // The connection is already in a streaming state, so we stream the next response.
            err = baseOperation.StreamResponse(s.heartbeatCtx, heartbeatConn)
        case streamable:
            // The server supports the streamable protocol. Set the socket timeout to
            // connectTimeoutMS+heartbeatFrequencyMS and execute an awaitable hello request. Set conn.canStream so
            // the wire message will advertise streaming support to the server.

            // Calculation for maxAwaitTimeMS is taken from time.Duration.Milliseconds (added in Go 1.13).
            maxAwaitTimeMS := int64(s.cfg.heartbeatInterval) / 1e6
            // If connectTimeoutMS=0, the socket timeout should be infinite. Otherwise, it is connectTimeoutMS +
            // heartbeatFrequencyMS to account for the fact that the query will block for heartbeatFrequencyMS
            // server-side.
            socketTimeout := s.cfg.heartbeatTimeout
            if socketTimeout != 0 {
                socketTimeout += s.cfg.heartbeatInterval
            }
            s.conn.setSocketTimeout(socketTimeout)
            baseOperation = baseOperation.TopologyVersion(previousDescription.TopologyVersion).
                MaxAwaitTimeMS(maxAwaitTimeMS)
            s.conn.setCanStream(true)
            err = baseOperation.Execute(s.heartbeatCtx)
        default:
            // The server doesn't support the awaitable protocol. Set the socket timeout to connectTimeoutMS and
            // execute a regular heartbeat without any additional parameters.

            s.conn.setSocketTimeout(s.cfg.heartbeatTimeout)
            err = baseOperation.Execute(s.heartbeatCtx)
        }
        durationNanos = time.Since(start).Nanoseconds()

        if err == nil {
            tempDesc := baseOperation.Result(s.address)
            descPtr = &tempDesc
            s.publishServerHeartbeatSucceededEvent(s.conn.ID(), durationNanos, tempDesc, s.conn.getCurrentlyStreaming() || streamable)
        } else {
            // Close the connection here rather than below so we ensure we're not closing a connection that wasn't
            // successfully created.
            if s.conn != nil {
                _ = s.conn.close()
            }
            s.publishServerHeartbeatFailedEvent(s.conn.ID(), durationNanos, err, s.conn.getCurrentlyStreaming() || streamable)
        }
    }

    if descPtr != nil {
        // The check was successful. Set the average RTT on the description and return.
        desc := *descPtr
        desc = desc.SetAverageRTT(s.rttMonitor.EWMA())
        desc.HeartbeatInterval = s.cfg.heartbeatInterval
        return desc, nil
    }

    if s.checkWasCancelled() {
        // If the previous check was cancelled, we don't want to clear the pool. Return a sentinel error so the caller
        // will know that an actual error didn't occur.
        return emptyDescription, errCheckCancelled
    }

    // An error occurred. We reset the RTT monitor for all errors and return an Unknown description. The pool must also
    // be cleared, but only after the description has already been updated, so that is handled by the caller.
    topologyVersion := extractTopologyVersion(err)
    s.rttMonitor.reset()
    return description.NewServerFromError(s.address, err, topologyVersion), nil
}
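
// Illustrative sketch (not part of the driver): with the defaults set by
// newServerConfig below (heartbeatInterval = 10s, heartbeatTimeout = 10s),
// the awaitable branch above computes:
//
//      maxAwaitTimeMS := int64(10*time.Second) / 1e6    // = 10000 ms
//      socketTimeout := 10*time.Second + 10*time.Second // = 20s
//
// i.e. the server may hold the awaitable hello for up to heartbeatFrequencyMS,
// so the socket deadline must cover connectTimeoutMS plus that server-side wait.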

func extractTopologyVersion(err error) *description.TopologyVersion {
    if ce, ok := err.(ConnectionError); ok {
        err = ce.Wrapped
    }

    switch converted := err.(type) {
    case driver.Error:
        return converted.TopologyVersion
    case driver.WriteCommandError:
        if converted.WriteConcernError != nil {
            return converted.WriteConcernError.TopologyVersion
        }
    }

    return nil
}

// RTTMonitor returns this server's round-trip-time monitor.
func (s *Server) RTTMonitor() driver.RTTMonitor {
    return s.rttMonitor
}

// OperationCount returns the current number of in-progress operations for this server.
func (s *Server) OperationCount() int64 {
    return atomic.LoadInt64(&s.operationCount)
}

// String implements the Stringer interface.
func (s *Server) String() string {
    desc := s.Description()
    state := atomic.LoadInt64(&s.state)
    str := fmt.Sprintf("Addr: %s, Type: %s, State: %s",
        s.address, desc.Kind, serverStateString(state))
    if len(desc.Tags) != 0 {
        str += fmt.Sprintf(", Tag sets: %s", desc.Tags)
    }
    if state == serverConnected {
        str += fmt.Sprintf(", Average RTT: %s, Min RTT: %s", desc.AverageRTT, s.RTTMonitor().Min())
    }
    if desc.LastError != nil {
        str += fmt.Sprintf(", Last error: %s", desc.LastError)
    }

    return str
}

// ServerSubscription represents a subscription to the description.Server updates for
// a specific server.
type ServerSubscription struct {
    C  <-chan description.Server
    s  *Server
    id uint64
}

// Unsubscribe unsubscribes this ServerSubscription from updates and closes the
// subscription channel.
func (ss *ServerSubscription) Unsubscribe() error {
    ss.s.subLock.Lock()
    defer ss.s.subLock.Unlock()
    if ss.s.subscriptionsClosed {
        return nil
    }

    ch, ok := ss.s.subscribers[ss.id]
    if !ok {
        return nil
    }

    close(ch)
    delete(ss.s.subscribers, ss.id)

    return nil
}

// publishes a ServerOpeningEvent to indicate the server is being initialized
func (s *Server) publishServerOpeningEvent(addr address.Address) {
    if s == nil {
        return
    }

    serverOpening := &event.ServerOpeningEvent{
        Address:    addr,
        TopologyID: s.topologyID,
    }

    if s.cfg.serverMonitor != nil && s.cfg.serverMonitor.ServerOpening != nil {
        s.cfg.serverMonitor.ServerOpening(serverOpening)
    }
}

// publishes a ServerHeartbeatStartedEvent to indicate a hello command has started
func (s *Server) publishServerHeartbeatStartedEvent(connectionID string, await bool) {
    serverHeartbeatStarted := &event.ServerHeartbeatStartedEvent{
        ConnectionID: connectionID,
        Awaited:      await,
    }

    if s != nil && s.cfg.serverMonitor != nil && s.cfg.serverMonitor.ServerHeartbeatStarted != nil {
        s.cfg.serverMonitor.ServerHeartbeatStarted(serverHeartbeatStarted)
    }
}

// publishes a ServerHeartbeatSucceededEvent to indicate hello has succeeded
func (s *Server) publishServerHeartbeatSucceededEvent(connectionID string,
    durationNanos int64,
    desc description.Server,
    await bool) {
    serverHeartbeatSucceeded := &event.ServerHeartbeatSucceededEvent{
        DurationNanos: durationNanos,
        Reply:         desc,
        ConnectionID:  connectionID,
        Awaited:       await,
    }

    if s != nil && s.cfg.serverMonitor != nil && s.cfg.serverMonitor.ServerHeartbeatSucceeded != nil {
        s.cfg.serverMonitor.ServerHeartbeatSucceeded(serverHeartbeatSucceeded)
    }
}

// publishes a ServerHeartbeatFailedEvent to indicate hello has failed
func (s *Server) publishServerHeartbeatFailedEvent(connectionID string,
    durationNanos int64,
    err error,
    await bool) {
    serverHeartbeatFailed := &event.ServerHeartbeatFailedEvent{
        DurationNanos: durationNanos,
        Failure:       err,
        ConnectionID:  connectionID,
        Awaited:       await,
    }

    if s != nil && s.cfg.serverMonitor != nil && s.cfg.serverMonitor.ServerHeartbeatFailed != nil {
        s.cfg.serverMonitor.ServerHeartbeatFailed(serverHeartbeatFailed)
    }
}

// unwrapConnectionError returns the connection error wrapped by err, or nil if err does not wrap a connection error.
func unwrapConnectionError(err error) error {
    // This is essentially an implementation of errors.As to unwrap this error until we get a ConnectionError and then
    // return ConnectionError.Wrapped.

    connErr, ok := err.(ConnectionError)
    if ok {
        return connErr.Wrapped
    }

    driverErr, ok := err.(driver.Error)
    if !ok || !driverErr.NetworkError() {
        return nil
    }

    connErr, ok = driverErr.Wrapped.(ConnectionError)
    if ok {
        return connErr.Wrapped
    }

    return nil
}
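
// Illustrative sketch (not part of the driver): the manual unwrapping above is
// roughly what errors.As would do with an error chain, e.g.:
//
//      var connErr ConnectionError
//      if errors.As(err, &connErr) {
//          return connErr.Wrapped
//      }
//      return nil
//
// except that unwrapConnectionError additionally requires a driver.Error to be
// a network error before looking inside it.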
195
vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/server_options.go
generated
vendored
Normal file
@@ -0,0 +1,195 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package topology

import (
    "time"

    "go.mongodb.org/mongo-driver/bson"
    "go.mongodb.org/mongo-driver/bson/bsoncodec"
    "go.mongodb.org/mongo-driver/event"
    "go.mongodb.org/mongo-driver/x/mongo/driver"
    "go.mongodb.org/mongo-driver/x/mongo/driver/session"
)

var defaultRegistry = bson.NewRegistryBuilder().Build()

type serverConfig struct {
    clock              *session.ClusterClock
    compressionOpts    []string
    connectionOpts     []ConnectionOption
    appname            string
    heartbeatInterval  time.Duration
    heartbeatTimeout   time.Duration
    serverMonitor      *event.ServerMonitor
    registry           *bsoncodec.Registry
    monitoringDisabled bool
    serverAPI          *driver.ServerAPIOptions
    loadBalanced       bool

    // Connection pool options.
    maxConns             uint64
    minConns             uint64
    maxConnecting        uint64
    poolMonitor          *event.PoolMonitor
    poolMaxIdleTime      time.Duration
    poolMaintainInterval time.Duration
}

func newServerConfig(opts ...ServerOption) *serverConfig {
    cfg := &serverConfig{
        heartbeatInterval: 10 * time.Second,
        heartbeatTimeout:  10 * time.Second,
        registry:          defaultRegistry,
    }

    for _, opt := range opts {
        if opt == nil {
            continue
        }
        opt(cfg)
    }

    return cfg
}
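
// Illustrative sketch (not part of the driver): the ServerOption values below
// follow the functional options pattern, where each option receives the
// current value and returns the new one. For example:
//
//      cfg := newServerConfig(
//          WithHeartbeatInterval(func(time.Duration) time.Duration { return 30 * time.Second }),
//          WithMaxConnections(func(uint64) uint64 { return 100 }),
//      )
//
// leaves every field not mentioned at its default (e.g. heartbeatTimeout
// stays at 10 seconds).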

// ServerOption configures a server.
type ServerOption func(*serverConfig)

// ServerAPIFromServerOptions will return the server API options if they have been functionally set on the ServerOption
// slice.
func ServerAPIFromServerOptions(opts []ServerOption) *driver.ServerAPIOptions {
    return newServerConfig(opts...).serverAPI
}

func withMonitoringDisabled(fn func(bool) bool) ServerOption {
    return func(cfg *serverConfig) {
        cfg.monitoringDisabled = fn(cfg.monitoringDisabled)
    }
}

// WithConnectionOptions configures the server's connections.
func WithConnectionOptions(fn func(...ConnectionOption) []ConnectionOption) ServerOption {
    return func(cfg *serverConfig) {
        cfg.connectionOpts = fn(cfg.connectionOpts...)
    }
}

// WithCompressionOptions configures the server's compressors.
func WithCompressionOptions(fn func(...string) []string) ServerOption {
    return func(cfg *serverConfig) {
        cfg.compressionOpts = fn(cfg.compressionOpts...)
    }
}

// WithServerAppName configures the server's application name.
func WithServerAppName(fn func(string) string) ServerOption {
    return func(cfg *serverConfig) {
        cfg.appname = fn(cfg.appname)
    }
}

// WithHeartbeatInterval configures a server's heartbeat interval.
func WithHeartbeatInterval(fn func(time.Duration) time.Duration) ServerOption {
    return func(cfg *serverConfig) {
        cfg.heartbeatInterval = fn(cfg.heartbeatInterval)
    }
}

// WithHeartbeatTimeout configures how long to wait for a heartbeat socket to
// connect.
func WithHeartbeatTimeout(fn func(time.Duration) time.Duration) ServerOption {
    return func(cfg *serverConfig) {
        cfg.heartbeatTimeout = fn(cfg.heartbeatTimeout)
    }
}

// WithMaxConnections configures the maximum number of connections to allow for
// a given server. If max is 0, then the maximum connection pool size is not limited.
func WithMaxConnections(fn func(uint64) uint64) ServerOption {
    return func(cfg *serverConfig) {
        cfg.maxConns = fn(cfg.maxConns)
    }
}

// WithMinConnections configures the minimum number of connections to allow for
// a given server. If min is 0, then there is no lower limit to the number of
// connections.
func WithMinConnections(fn func(uint64) uint64) ServerOption {
    return func(cfg *serverConfig) {
        cfg.minConns = fn(cfg.minConns)
    }
}

// WithMaxConnecting configures the maximum number of connections a connection
// pool may establish simultaneously. If maxConnecting is 0, the default value
// of 2 is used.
func WithMaxConnecting(fn func(uint64) uint64) ServerOption {
    return func(cfg *serverConfig) {
        cfg.maxConnecting = fn(cfg.maxConnecting)
    }
}

// WithConnectionPoolMaxIdleTime configures the maximum time that a connection can remain idle in the connection pool
// before being removed. If connectionPoolMaxIdleTime is 0, then no idle time is set and connections will not be removed
// because of their age.
func WithConnectionPoolMaxIdleTime(fn func(time.Duration) time.Duration) ServerOption {
    return func(cfg *serverConfig) {
        cfg.poolMaxIdleTime = fn(cfg.poolMaxIdleTime)
    }
}

// WithConnectionPoolMaintainInterval configures the interval that the background connection pool
// maintenance goroutine runs.
func WithConnectionPoolMaintainInterval(fn func(time.Duration) time.Duration) ServerOption {
    return func(cfg *serverConfig) {
        cfg.poolMaintainInterval = fn(cfg.poolMaintainInterval)
    }
}

// WithConnectionPoolMonitor configures the monitor for all connection pool actions.
func WithConnectionPoolMonitor(fn func(*event.PoolMonitor) *event.PoolMonitor) ServerOption {
    return func(cfg *serverConfig) {
        cfg.poolMonitor = fn(cfg.poolMonitor)
    }
}

// WithServerMonitor configures the monitor for all SDAM events for a server.
func WithServerMonitor(fn func(*event.ServerMonitor) *event.ServerMonitor) ServerOption {
    return func(cfg *serverConfig) {
        cfg.serverMonitor = fn(cfg.serverMonitor)
    }
}

// WithClock configures the ClusterClock for the server to use.
func WithClock(fn func(clock *session.ClusterClock) *session.ClusterClock) ServerOption {
    return func(cfg *serverConfig) {
        cfg.clock = fn(cfg.clock)
    }
}

// WithRegistry configures the registry for the server to use when creating
// cursors.
func WithRegistry(fn func(*bsoncodec.Registry) *bsoncodec.Registry) ServerOption {
    return func(cfg *serverConfig) {
        cfg.registry = fn(cfg.registry)
    }
}

// WithServerAPI configures the server API options for the server to use.
func WithServerAPI(fn func(serverAPI *driver.ServerAPIOptions) *driver.ServerAPIOptions) ServerOption {
    return func(cfg *serverConfig) {
        cfg.serverAPI = fn(cfg.serverAPI)
    }
}

// WithServerLoadBalanced specifies whether or not the server is behind a load balancer.
func WithServerLoadBalanced(fn func(bool) bool) ServerOption {
    return func(cfg *serverConfig) {
        cfg.loadBalanced = fn(cfg.loadBalanced)
    }
}
58
vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/tls_connection_source_1_16.go
generated
vendored
Normal file
@@ -0,0 +1,58 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

//go:build !go1.17
// +build !go1.17

package topology

import (
    "context"
    "crypto/tls"
    "net"
)

type tlsConn interface {
    net.Conn

    // Only require Handshake on the interface for Go 1.16 and earlier.
    Handshake() error
    ConnectionState() tls.ConnectionState
}

var _ tlsConn = (*tls.Conn)(nil)

type tlsConnectionSource interface {
    Client(net.Conn, *tls.Config) tlsConn
}

type tlsConnectionSourceFn func(net.Conn, *tls.Config) tlsConn

var _ tlsConnectionSource = (tlsConnectionSourceFn)(nil)

func (t tlsConnectionSourceFn) Client(nc net.Conn, cfg *tls.Config) tlsConn {
    return t(nc, cfg)
}

var defaultTLSConnectionSource tlsConnectionSourceFn = func(nc net.Conn, cfg *tls.Config) tlsConn {
    return tls.Client(nc, cfg)
}

// clientHandshake performs the handshake in a goroutine and waits for its completion on Go 1.16
// and earlier, where HandshakeContext is not available.
func clientHandshake(ctx context.Context, client tlsConn) error {
    errChan := make(chan error, 1)
    go func() {
        errChan <- client.Handshake()
    }()

    select {
    case err := <-errChan:
        return err
    case <-ctx.Done():
        return ctx.Err()
    }
}
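
// Illustrative sketch (not part of the driver): the goroutine-plus-select
// shape above is the general way to bolt context cancellation onto any
// blocking call that has no context-aware variant, e.g.:
//
//      func withContext(ctx context.Context, blocking func() error) error {
//          errChan := make(chan error, 1)
//          go func() { errChan <- blocking() }()
//          select {
//          case err := <-errChan:
//              return err
//          case <-ctx.Done():
//              return ctx.Err()
//          }
//      }
//
// Note the buffered channel: it lets the goroutine exit even if the caller
// has already returned because the context was cancelled first.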
47
vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/tls_connection_source_1_17.go
generated
vendored
Normal file
@@ -0,0 +1,47 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

//go:build go1.17
// +build go1.17

package topology

import (
    "context"
    "crypto/tls"
    "net"
)

type tlsConn interface {
    net.Conn

    // Require HandshakeContext on the interface for Go 1.17 and higher.
    HandshakeContext(ctx context.Context) error
    ConnectionState() tls.ConnectionState
}

var _ tlsConn = (*tls.Conn)(nil)

type tlsConnectionSource interface {
    Client(net.Conn, *tls.Config) tlsConn
}

type tlsConnectionSourceFn func(net.Conn, *tls.Config) tlsConn

var _ tlsConnectionSource = (tlsConnectionSourceFn)(nil)

func (t tlsConnectionSourceFn) Client(nc net.Conn, cfg *tls.Config) tlsConn {
    return t(nc, cfg)
}

var defaultTLSConnectionSource tlsConnectionSourceFn = func(nc net.Conn, cfg *tls.Config) tlsConn {
    return tls.Client(nc, cfg)
}

// clientHandshake will perform a handshake on Go 1.17 and higher with HandshakeContext.
func clientHandshake(ctx context.Context, client tlsConn) error {
    return client.HandshakeContext(ctx)
}
851
vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/topology.go
generated
vendored
Normal file
@@ -0,0 +1,851 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

// Package topology contains types that handle the discovery, monitoring, and selection
// of servers. This package is designed to expose enough inner workings of service discovery
// and monitoring to allow low level applications to have fine grained control, while hiding
// most of the detailed implementation of the algorithms.
package topology // import "go.mongodb.org/mongo-driver/x/mongo/driver/topology"

import (
    "context"
    "errors"
    "fmt"
    "net"
    "net/url"
    "strings"
    "sync"
    "sync/atomic"
    "time"

    "go.mongodb.org/mongo-driver/bson/primitive"
    "go.mongodb.org/mongo-driver/event"
    "go.mongodb.org/mongo-driver/internal/randutil"
    "go.mongodb.org/mongo-driver/mongo/address"
    "go.mongodb.org/mongo-driver/mongo/description"
    "go.mongodb.org/mongo-driver/mongo/options"
    "go.mongodb.org/mongo-driver/x/mongo/driver"
    "go.mongodb.org/mongo-driver/x/mongo/driver/dns"
)

// Topology state constants.
const (
    topologyDisconnected int64 = iota
    topologyDisconnecting
    topologyConnected
    topologyConnecting
)

// ErrSubscribeAfterClosed is returned when a user attempts to subscribe to a
// closed Server or Topology.
var ErrSubscribeAfterClosed = errors.New("cannot subscribe after closeConnection")

// ErrTopologyClosed is returned when a user attempts to call a method on a
// closed Topology.
var ErrTopologyClosed = errors.New("topology is closed")

// ErrTopologyConnected is returned when a user attempts to Connect to an
// already connected Topology.
var ErrTopologyConnected = errors.New("topology is connected or connecting")

// ErrServerSelectionTimeout is returned from server selection when the server
// selection process took longer than allowed by the timeout.
var ErrServerSelectionTimeout = errors.New("server selection timeout")

// MonitorMode represents the way in which a server is monitored.
type MonitorMode uint8

// random is a package-global pseudo-random number generator.
var random = randutil.NewLockedRand()

// These constants are the available monitoring modes.
const (
    AutomaticMode MonitorMode = iota
    SingleMode
)

// Topology represents a MongoDB deployment.
type Topology struct {
    state int64

    cfg *Config

    desc atomic.Value // holds a description.Topology

    dnsResolver *dns.Resolver

    done chan struct{}

    pollingRequired   bool
    pollingDone       chan struct{}
    pollingwg         sync.WaitGroup
    rescanSRVInterval time.Duration
    pollHeartbeatTime atomic.Value // holds a bool

    updateCallback updateTopologyCallback
    fsm            *fsm

    // This should really be encapsulated into its own type. This will likely
    // require a redesign so we can share a minimum of data between the
    // subscribers and the topology.
    subscribers         map[uint64]chan description.Topology
    currentSubscriberID uint64
    subscriptionsClosed bool
    subLock             sync.Mutex

    // We should redesign how we Connect and handle individual servers. This is
    // too difficult to maintain and it's rather easy to accidentally access
    // the servers without acquiring the lock or checking if the servers are
    // closed. This lock should also be an RWMutex.
    serversLock   sync.Mutex
    serversClosed bool
    servers       map[address.Address]*Server

    id primitive.ObjectID
}

var _ driver.Deployment = &Topology{}
var _ driver.Subscriber = &Topology{}

type serverSelectionState struct {
    selector    description.ServerSelector
    timeoutChan <-chan time.Time
}

func newServerSelectionState(selector description.ServerSelector, timeoutChan <-chan time.Time) serverSelectionState {
    return serverSelectionState{
        selector:    selector,
        timeoutChan: timeoutChan,
    }
}

// New creates a new topology. A "nil" config is interpreted as the default configuration.
func New(cfg *Config) (*Topology, error) {
    if cfg == nil {
        var err error
        cfg, err = NewConfig(options.Client(), nil)
        if err != nil {
            return nil, err
        }
    }

    t := &Topology{
        cfg:               cfg,
        done:              make(chan struct{}),
        pollingDone:       make(chan struct{}),
        rescanSRVInterval: 60 * time.Second,
        fsm:               newFSM(),
        subscribers:       make(map[uint64]chan description.Topology),
        servers:           make(map[address.Address]*Server),
        dnsResolver:       dns.DefaultResolver,
        id:                primitive.NewObjectID(),
    }
    t.desc.Store(description.Topology{})
    t.updateCallback = func(desc description.Server) description.Server {
        return t.apply(context.TODO(), desc)
    }

    if t.cfg.URI != "" {
        t.pollingRequired = strings.HasPrefix(t.cfg.URI, "mongodb+srv://") && !t.cfg.LoadBalanced
    }

    t.publishTopologyOpeningEvent()

    return t, nil
}

// Connect initializes a Topology and starts the monitoring process. This function
// must be called to properly monitor the topology.
func (t *Topology) Connect() error {
    if !atomic.CompareAndSwapInt64(&t.state, topologyDisconnected, topologyConnecting) {
        return ErrTopologyConnected
    }

    t.desc.Store(description.Topology{})
    var err error
    t.serversLock.Lock()

    // A replica set name sets the initial topology type to ReplicaSetNoPrimary unless a direct connection is also
    // specified, in which case the initial type is Single.
    if t.cfg.ReplicaSetName != "" {
        t.fsm.SetName = t.cfg.ReplicaSetName
        t.fsm.Kind = description.ReplicaSetNoPrimary
    }

    // A direct connection unconditionally sets the topology type to Single.
    if t.cfg.Mode == SingleMode {
        t.fsm.Kind = description.Single
    }

    for _, a := range t.cfg.SeedList {
        addr := address.Address(a).Canonicalize()
        t.fsm.Servers = append(t.fsm.Servers, description.NewDefaultServer(addr))
    }

    switch {
    case t.cfg.LoadBalanced:
        // In LoadBalanced mode, we mock a series of events: TopologyDescriptionChanged from Unknown to LoadBalanced,
        // ServerDescriptionChanged from Unknown to LoadBalancer, and then TopologyDescriptionChanged to reflect the
        // previous ServerDescriptionChanged event. We publish all of these events here because we don't start server
        // monitoring routines in this mode, so we have to mock state changes.

        // Transition from Unknown with no servers to LoadBalanced with a single Unknown server.
        t.fsm.Kind = description.LoadBalanced
        t.publishTopologyDescriptionChangedEvent(description.Topology{}, t.fsm.Topology)

        addr := address.Address(t.cfg.SeedList[0]).Canonicalize()
        if err := t.addServer(addr); err != nil {
            t.serversLock.Unlock()
            return err
        }

        // Transition the server from Unknown to LoadBalancer.
        newServerDesc := t.servers[addr].Description()
        t.publishServerDescriptionChangedEvent(t.fsm.Servers[0], newServerDesc)

        // Transition from LoadBalanced with an Unknown server to LoadBalanced with a LoadBalancer.
        oldDesc := t.fsm.Topology
        t.fsm.Servers = []description.Server{newServerDesc}
        t.desc.Store(t.fsm.Topology)
        t.publishTopologyDescriptionChangedEvent(oldDesc, t.fsm.Topology)
    default:
        // In non-LB mode, we only publish an initial TopologyDescriptionChanged event from Unknown with no servers to
        // the current state (e.g. Unknown with one or more servers if we're discovering or Single with one server if
        // we're connecting directly). Other events are published when state changes occur due to responses in the
        // server monitoring goroutines.

        newDesc := description.Topology{
            Kind:                  t.fsm.Kind,
            Servers:               t.fsm.Servers,
            SessionTimeoutMinutes: t.fsm.SessionTimeoutMinutes,
        }
        t.desc.Store(newDesc)
        t.publishTopologyDescriptionChangedEvent(description.Topology{}, t.fsm.Topology)
        for _, a := range t.cfg.SeedList {
            addr := address.Address(a).Canonicalize()
            err = t.addServer(addr)
            if err != nil {
                t.serversLock.Unlock()
                return err
            }
        }
    }

    t.serversLock.Unlock()
    if t.pollingRequired {
        uri, err := url.Parse(t.cfg.URI)
        if err != nil {
            return err
        }
        // Sanity check before passing the hostname to the resolver.
        if parsedHosts := strings.Split(uri.Host, ","); len(parsedHosts) != 1 {
            return fmt.Errorf("URI with SRV must include one and only one hostname")
        }
        _, _, err = net.SplitHostPort(uri.Host)
        if err == nil {
            // We were able to successfully extract a port from the host,
            // but should not be able to when using SRV.
            return fmt.Errorf("URI with srv must not include a port number")
        }
        // Add to the wait group before starting the goroutine so the counter can
        // never be observed at zero while polling is still running.
        t.pollingwg.Add(1)
        go t.pollSRVRecords(uri.Host)
    }

    t.subscriptionsClosed = false // explicitly set in case topology was disconnected and then reconnected

    atomic.StoreInt64(&t.state, topologyConnected)
    return nil
}
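
// Illustrative sketch (not part of the driver): the expected lifecycle of a
// Topology is construct, Connect, use, Disconnect. Assuming a reachable
// deployment:
//
//      topo, err := topology.New(nil) // nil means default configuration
//      if err != nil { /* handle */ }
//      if err := topo.Connect(); err != nil { /* handle */ }
//      defer topo.Disconnect(context.Background())
//
// Connect is guarded by a CompareAndSwap on the state word, so calling it on
// an already-connected Topology returns ErrTopologyConnected rather than
// starting a second set of monitors.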

// Disconnect closes the topology. It stops the monitoring thread and
// closes all open subscriptions.
func (t *Topology) Disconnect(ctx context.Context) error {
    if !atomic.CompareAndSwapInt64(&t.state, topologyConnected, topologyDisconnecting) {
        return ErrTopologyClosed
    }

    servers := make(map[address.Address]*Server)
    t.serversLock.Lock()
    t.serversClosed = true
    for addr, server := range t.servers {
        servers[addr] = server
    }
    t.serversLock.Unlock()

    for _, server := range servers {
        _ = server.Disconnect(ctx)
        t.publishServerClosedEvent(server.address)
    }

    t.subLock.Lock()
    for id, ch := range t.subscribers {
        close(ch)
        delete(t.subscribers, id)
    }
    t.subscriptionsClosed = true
    t.subLock.Unlock()

    if t.pollingRequired {
        t.pollingDone <- struct{}{}
        t.pollingwg.Wait()
    }

    t.desc.Store(description.Topology{})

    atomic.StoreInt64(&t.state, topologyDisconnected)
    t.publishTopologyClosedEvent()
    return nil
}

// Description returns a description of the topology.
func (t *Topology) Description() description.Topology {
    td, ok := t.desc.Load().(description.Topology)
    if !ok {
        td = description.Topology{}
    }
    return td
}

// Kind returns the topology kind of this Topology.
func (t *Topology) Kind() description.TopologyKind { return t.Description().Kind }

// Subscribe returns a Subscription on which all updated description.Topologys
// will be sent. The channel of the subscription will have a buffer size of one,
// and will be pre-populated with the current description.Topology.
// Subscribe implements the driver.Subscriber interface.
func (t *Topology) Subscribe() (*driver.Subscription, error) {
    if atomic.LoadInt64(&t.state) != topologyConnected {
        return nil, errors.New("cannot subscribe to Topology that is not connected")
    }
    ch := make(chan description.Topology, 1)
    td, ok := t.desc.Load().(description.Topology)
    if !ok {
        td = description.Topology{}
    }
    ch <- td

    t.subLock.Lock()
    defer t.subLock.Unlock()
    if t.subscriptionsClosed {
        return nil, ErrSubscribeAfterClosed
    }
    id := t.currentSubscriberID
    t.subscribers[id] = ch
    t.currentSubscriberID++

    return &driver.Subscription{
        Updates: ch,
        ID:      id,
    }, nil
}
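
// Illustrative sketch (not part of the driver): a subscriber reads the
// buffered channel to observe topology changes, then unsubscribes:
//
//      sub, err := topo.Subscribe()
//      if err != nil { /* handle */ }
//      desc := <-sub.Updates // pre-populated with the current description
//      _ = desc
//      _ = topo.Unsubscribe(sub)
//
// Because the channel has a buffer of one and publishers drain any stale
// value before sending, a slow subscriber only ever sees the latest
// description rather than a backlog.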

// Unsubscribe unsubscribes the given subscription from the topology and closes the subscription channel.
// Unsubscribe implements the driver.Subscriber interface.
func (t *Topology) Unsubscribe(sub *driver.Subscription) error {
    t.subLock.Lock()
    defer t.subLock.Unlock()

    if t.subscriptionsClosed {
        return nil
    }

    ch, ok := t.subscribers[sub.ID]
    if !ok {
        return nil
    }

    close(ch)
    delete(t.subscribers, sub.ID)
    return nil
}

// RequestImmediateCheck will send heartbeats to all the servers in the
// topology right away, instead of waiting for the heartbeat timeout.
func (t *Topology) RequestImmediateCheck() {
    if atomic.LoadInt64(&t.state) != topologyConnected {
        return
    }
    t.serversLock.Lock()
    for _, server := range t.servers {
        server.RequestImmediateCheck()
    }
    t.serversLock.Unlock()
}

// SelectServer selects a server given a selector. SelectServer complies with the
// server selection spec, and will time out after serverSelectionTimeout or when the
// parent context is done.
func (t *Topology) SelectServer(ctx context.Context, ss description.ServerSelector) (driver.Server, error) {
    if atomic.LoadInt64(&t.state) != topologyConnected {
        return nil, ErrTopologyClosed
    }
    var ssTimeoutCh <-chan time.Time

    if t.cfg.ServerSelectionTimeout > 0 {
        ssTimeout := time.NewTimer(t.cfg.ServerSelectionTimeout)
        ssTimeoutCh = ssTimeout.C
        defer ssTimeout.Stop()
    }

    var doneOnce bool
    var sub *driver.Subscription
    selectionState := newServerSelectionState(ss, ssTimeoutCh)
    for {
        var suitable []description.Server
        var selectErr error

        if !doneOnce {
            // For the first pass, select a server from the current description.
            // This improves selection speed for up-to-date topology descriptions.
            suitable, selectErr = t.selectServerFromDescription(t.Description(), selectionState)
            doneOnce = true
        } else {
            // If the first pass didn't select a server, the previous description did not contain a suitable server, so
            // we subscribe to the topology and attempt to obtain a server from that subscription.
            if sub == nil {
                var err error
                sub, err = t.Subscribe()
                if err != nil {
                    return nil, err
                }
                defer t.Unsubscribe(sub)
            }

            suitable, selectErr = t.selectServerFromSubscription(ctx, sub.Updates, selectionState)
        }
        if selectErr != nil {
            return nil, selectErr
        }

        if len(suitable) == 0 {
            // Try again if there are no servers available.
            continue
        }

        // If there's only one suitable server description, try to find the associated server and
        // return it. This is an optimization primarily for standalone and load-balanced deployments.
        if len(suitable) == 1 {
            server, err := t.FindServer(suitable[0])
            if err != nil {
                return nil, err
            }
            if server == nil {
                continue
            }
            return server, nil
        }

        // Randomly select 2 suitable server descriptions and find servers for them. We select two
        // so we can pick the one with fewer in-progress operations below.
        desc1, desc2 := pick2(suitable)
        server1, err := t.FindServer(desc1)
        if err != nil {
            return nil, err
        }
        server2, err := t.FindServer(desc2)
        if err != nil {
            return nil, err
        }

        // If we don't have an actual server for one or both of the provided descriptions, either
        // return the one server we have, or try again if they're both nil. This could happen for a
        // number of reasons, including that the server has since stopped being a part of this
        // topology.
        if server1 == nil || server2 == nil {
            if server1 == nil && server2 == nil {
                continue
            }
            if server1 != nil {
                return server1, nil
            }
            return server2, nil
        }

        // Of the two randomly selected suitable servers, pick the one with fewer in-use connections.
        // We use in-use connections as an analog for in-progress operations because they are almost
        // always the same value for a given server.
        if server1.OperationCount() < server2.OperationCount() {
            return server1, nil
        }
        return server2, nil
    }
}
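
// Illustrative sketch (not part of the driver): callers typically pass one of
// the selectors from the description package. For example, selecting a
// writable server on a connected topology:
//
//      srv, err := topo.SelectServer(ctx, description.WriteSelector())
//      if err != nil { /* handle */ }
//      conn, err := srv.Connection(ctx)
//
// Any custom description.ServerSelector works the same way; the selector only
// sees the topology description, never live connections.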

// pick2 returns 2 random server descriptions from the input slice of server descriptions,
// guaranteeing that the same element from the slice is not picked twice. The order of server
// descriptions in the input slice may be modified. If fewer than 2 server descriptions are
// provided, pick2 will panic.
func pick2(ds []description.Server) (description.Server, description.Server) {
    // Select a random index from the input slice and keep the server description from that index.
    idx := random.Intn(len(ds))
    s1 := ds[idx]

    // Swap the selected index to the end and reslice to remove it so we don't pick the same server
    // description twice.
    ds[idx], ds[len(ds)-1] = ds[len(ds)-1], ds[idx]
    ds = ds[:len(ds)-1]

    // Select another random index from the input slice and return both selected server descriptions.
    return s1, ds[random.Intn(len(ds))]
}
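
// Illustrative note (not part of the driver): pick2 plus the OperationCount
// comparison in SelectServer is the "power of two random choices" strategy:
// sampling two candidates and taking the less loaded one avoids both the
// herding of always picking the globally least-loaded server and the
// imbalance of picking uniformly at random. A minimal sketch of the same
// idea over plain integer loads:
//
//      func lessLoaded(loads []int) int {
//          i, j := random.Intn(len(loads)), random.Intn(len(loads)-1)
//          if j >= i {
//              j++ // ensure the two sampled indexes differ
//          }
//          if loads[i] <= loads[j] {
//              return i
//          }
//          return j
//      }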

// FindServer will attempt to find a server that fits the given server description.
// This method will return nil, nil if a matching server could not be found.
func (t *Topology) FindServer(selected description.Server) (*SelectedServer, error) {
    if atomic.LoadInt64(&t.state) != topologyConnected {
        return nil, ErrTopologyClosed
    }
    t.serversLock.Lock()
    defer t.serversLock.Unlock()
    server, ok := t.servers[selected.Addr]
    if !ok {
        return nil, nil
    }

    desc := t.Description()
    return &SelectedServer{
        Server: server,
        Kind:   desc.Kind,
    }, nil
}

// selectServerFromSubscription loops until a topology description is available for server selection. It returns
// when the given context expires, server selection timeout is reached, or a description containing a selectable
// server is available.
func (t *Topology) selectServerFromSubscription(ctx context.Context, subscriptionCh <-chan description.Topology,
    selectionState serverSelectionState) ([]description.Server, error) {

    current := t.Description()
    for {
        select {
        case <-ctx.Done():
            return nil, ServerSelectionError{Wrapped: ctx.Err(), Desc: current}
        case <-selectionState.timeoutChan:
            return nil, ServerSelectionError{Wrapped: ErrServerSelectionTimeout, Desc: current}
        case current = <-subscriptionCh:
        }

        suitable, err := t.selectServerFromDescription(current, selectionState)
        if err != nil {
            return nil, err
        }

        if len(suitable) > 0 {
            return suitable, nil
        }
        t.RequestImmediateCheck()
    }
}

// selectServerFromDescription processes the given topology description and returns a slice of suitable servers.
func (t *Topology) selectServerFromDescription(desc description.Topology,
    selectionState serverSelectionState) ([]description.Server, error) {

    // Unlike selectServerFromSubscription, this code path does not check ctx.Done or selectionState.timeoutChan because
    // selecting a server from a description is not a blocking operation.

    if desc.CompatibilityErr != nil {
        return nil, desc.CompatibilityErr
    }

    // If the topology kind is LoadBalanced, the LB is the only server and it is always considered selectable. The
    // selectors exported by the driver should already return the LB as a candidate, but this check ensures that
    // the LB is always selectable even if a user of the low-level driver provides a custom selector.
    if desc.Kind == description.LoadBalanced {
        return desc.Servers, nil
    }

    var allowed []description.Server
    for _, s := range desc.Servers {
        if s.Kind != description.Unknown {
            allowed = append(allowed, s)
        }
    }

    suitable, err := selectionState.selector.SelectServer(desc, allowed)
    if err != nil {
        return nil, ServerSelectionError{Wrapped: err, Desc: desc}
    }
    return suitable, nil
}

func (t *Topology) pollSRVRecords(hosts string) {
    defer t.pollingwg.Done()

    serverConfig := newServerConfig(t.cfg.ServerOpts...)
    heartbeatInterval := serverConfig.heartbeatInterval

    pollTicker := time.NewTicker(t.rescanSRVInterval)
    defer pollTicker.Stop()
    t.pollHeartbeatTime.Store(false)
    var doneOnce bool
    defer func() {
        // ¯\_(ツ)_/¯
        if r := recover(); r != nil && !doneOnce {
            <-t.pollingDone
        }
    }()

    for {
        select {
        case <-pollTicker.C:
        case <-t.pollingDone:
            doneOnce = true
            return
        }
        topoKind := t.Description().Kind
        if !(topoKind == description.Unknown || topoKind == description.Sharded) {
            break
        }

        parsedHosts, err := t.dnsResolver.ParseHosts(hosts, t.cfg.SRVServiceName, false)
        // DNS problem or no verified hosts returned
        if err != nil || len(parsedHosts) == 0 {
            if !t.pollHeartbeatTime.Load().(bool) {
                pollTicker.Stop()
                pollTicker = time.NewTicker(heartbeatInterval)
                t.pollHeartbeatTime.Store(true)
            }
            continue
        }
        if t.pollHeartbeatTime.Load().(bool) {
            pollTicker.Stop()
            pollTicker = time.NewTicker(t.rescanSRVInterval)
            t.pollHeartbeatTime.Store(false)
        }

        cont := t.processSRVResults(parsedHosts)
        if !cont {
            break
        }
    }
    <-t.pollingDone
    doneOnce = true
}

func (t *Topology) processSRVResults(parsedHosts []string) bool {
    t.serversLock.Lock()
    defer t.serversLock.Unlock()

    if t.serversClosed {
        return false
    }
    prev := t.fsm.Topology
    diff := diffHostList(t.fsm.Topology, parsedHosts)

    if len(diff.Added) == 0 && len(diff.Removed) == 0 {
        return true
    }

    for _, r := range diff.Removed {
        addr := address.Address(r).Canonicalize()
        s, ok := t.servers[addr]
        if !ok {
            continue
        }
        go func() {
            cancelCtx, cancel := context.WithCancel(context.Background())
            cancel()
            _ = s.Disconnect(cancelCtx)
        }()
        delete(t.servers, addr)
        t.fsm.removeServerByAddr(addr)
        t.publishServerClosedEvent(s.address)
    }

    // Now that we've removed all the hosts that disappeared from the SRV record, we need to add any
    // new hosts added to the SRV record. If adding all of the new hosts would increase the number
    // of servers past srvMaxHosts, shuffle the list of added hosts.
    if t.cfg.SRVMaxHosts > 0 && len(t.servers)+len(diff.Added) > t.cfg.SRVMaxHosts {
        random.Shuffle(len(diff.Added), func(i, j int) {
            diff.Added[i], diff.Added[j] = diff.Added[j], diff.Added[i]
        })
    }
    // Add all added hosts until the number of servers reaches srvMaxHosts.
    for _, a := range diff.Added {
        if t.cfg.SRVMaxHosts > 0 && len(t.servers) >= t.cfg.SRVMaxHosts {
            break
        }
        addr := address.Address(a).Canonicalize()
        _ = t.addServer(addr)
        t.fsm.addServer(addr)
    }

    // Store the new description.
    newDesc := description.Topology{
        Kind:                  t.fsm.Kind,
        Servers:               t.fsm.Servers,
        SessionTimeoutMinutes: t.fsm.SessionTimeoutMinutes,
    }
    t.desc.Store(newDesc)

    if !prev.Equal(newDesc) {
        t.publishTopologyDescriptionChangedEvent(prev, newDesc)
    }

    t.subLock.Lock()
    for _, ch := range t.subscribers {
        // We drain the description if there's one in the channel.
        select {
        case <-ch:
        default:
        }
        ch <- newDesc
    }
    t.subLock.Unlock()

    return true
}

// apply updates the Topology and its underlying FSM based on the provided server description and returns the server
// description that should be stored.
func (t *Topology) apply(ctx context.Context, desc description.Server) description.Server {
    t.serversLock.Lock()
    defer t.serversLock.Unlock()

    ind, ok := t.fsm.findServer(desc.Addr)
    if t.serversClosed || !ok {
        return desc
    }

    prev := t.fsm.Topology
    oldDesc := t.fsm.Servers[ind]
    if oldDesc.TopologyVersion.CompareToIncoming(desc.TopologyVersion) > 0 {
        return oldDesc
    }

    var current description.Topology
    current, desc = t.fsm.apply(desc)

    if !oldDesc.Equal(desc) {
        t.publishServerDescriptionChangedEvent(oldDesc, desc)
    }

    diff := diffTopology(prev, current)

    for _, removed := range diff.Removed {
        if s, ok := t.servers[removed.Addr]; ok {
            go func() {
                cancelCtx, cancel := context.WithCancel(ctx)
                cancel()
                _ = s.Disconnect(cancelCtx)
            }()
            delete(t.servers, removed.Addr)
            t.publishServerClosedEvent(s.address)
        }
    }

    for _, added := range diff.Added {
        _ = t.addServer(added.Addr)
    }

    t.desc.Store(current)
    if !prev.Equal(current) {
        t.publishTopologyDescriptionChangedEvent(prev, current)
    }

    t.subLock.Lock()
    for _, ch := range t.subscribers {
        // We drain the description if there's one in the channel.
        select {
        case <-ch:
        default:
        }
        ch <- current
    }
    t.subLock.Unlock()

    return desc
}

func (t *Topology) addServer(addr address.Address) error {
    if _, ok := t.servers[addr]; ok {
        return nil
    }

    svr, err := ConnectServer(addr, t.updateCallback, t.id, t.cfg.ServerOpts...)
    if err != nil {
        return err
    }

    t.servers[addr] = svr

    return nil
}

// String implements the Stringer interface.
func (t *Topology) String() string {
    desc := t.Description()

    serversStr := ""
    t.serversLock.Lock()
    defer t.serversLock.Unlock()
    for _, s := range t.servers {
        serversStr += "{ " + s.String() + " }, "
    }
    return fmt.Sprintf("Type: %s, Servers: [%s]", desc.Kind, serversStr)
}

// publishes a ServerDescriptionChangedEvent to indicate the server description has changed
func (t *Topology) publishServerDescriptionChangedEvent(prev description.Server, current description.Server) {
    serverDescriptionChanged := &event.ServerDescriptionChangedEvent{
        Address:             current.Addr,
        TopologyID:          t.id,
        PreviousDescription: prev,
        NewDescription:      current,
    }

    if t.cfg.ServerMonitor != nil && t.cfg.ServerMonitor.ServerDescriptionChanged != nil {
        t.cfg.ServerMonitor.ServerDescriptionChanged(serverDescriptionChanged)
    }
}

// publishes a ServerClosedEvent to indicate the server has closed
func (t *Topology) publishServerClosedEvent(addr address.Address) {
    serverClosed := &event.ServerClosedEvent{
        Address:    addr,
        TopologyID: t.id,
    }

    if t.cfg.ServerMonitor != nil && t.cfg.ServerMonitor.ServerClosed != nil {
        t.cfg.ServerMonitor.ServerClosed(serverClosed)
    }
}

// publishes a TopologyDescriptionChangedEvent to indicate the topology description has changed
func (t *Topology) publishTopologyDescriptionChangedEvent(prev description.Topology, current description.Topology) {
    topologyDescriptionChanged := &event.TopologyDescriptionChangedEvent{
        TopologyID:          t.id,
        PreviousDescription: prev,
        NewDescription:      current,
    }

    if t.cfg.ServerMonitor != nil && t.cfg.ServerMonitor.TopologyDescriptionChanged != nil {
        t.cfg.ServerMonitor.TopologyDescriptionChanged(topologyDescriptionChanged)
    }
}

// publishes a TopologyOpeningEvent to indicate the topology is being initialized
func (t *Topology) publishTopologyOpeningEvent() {
    topologyOpening := &event.TopologyOpeningEvent{
        TopologyID: t.id,
    }

    if t.cfg.ServerMonitor != nil && t.cfg.ServerMonitor.TopologyOpening != nil {
        t.cfg.ServerMonitor.TopologyOpening(topologyOpening)
    }
}

// publishes a TopologyClosedEvent to indicate the topology has been closed
func (t *Topology) publishTopologyClosedEvent() {
    topologyClosed := &event.TopologyClosedEvent{
        TopologyID: t.id,
    }

    if t.cfg.ServerMonitor != nil && t.cfg.ServerMonitor.TopologyClosed != nil {
        t.cfg.ServerMonitor.TopologyClosed(topologyClosed)
    }
}
344
vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/topology_options.go
generated
vendored
Normal file
@@ -0,0 +1,344 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package topology

import (
    "crypto/tls"
    "net/http"
    "strings"
    "time"

    "go.mongodb.org/mongo-driver/event"
    "go.mongodb.org/mongo-driver/mongo/description"
    "go.mongodb.org/mongo-driver/mongo/options"
    "go.mongodb.org/mongo-driver/x/mongo/driver"
    "go.mongodb.org/mongo-driver/x/mongo/driver/auth"
    "go.mongodb.org/mongo-driver/x/mongo/driver/ocsp"
    "go.mongodb.org/mongo-driver/x/mongo/driver/operation"
    "go.mongodb.org/mongo-driver/x/mongo/driver/session"
)

const defaultServerSelectionTimeout = 30 * time.Second

// Config is used to construct a topology.
type Config struct {
    Mode                   MonitorMode
    ReplicaSetName         string
    SeedList               []string
    ServerOpts             []ServerOption
    URI                    string
    ServerSelectionTimeout time.Duration
    ServerMonitor          *event.ServerMonitor
    SRVMaxHosts            int
    SRVServiceName         string
    LoadBalanced           bool
}

// ConvertToDriverAPIOptions converts an options.ServerAPIOptions instance to a driver.ServerAPIOptions.
func ConvertToDriverAPIOptions(s *options.ServerAPIOptions) *driver.ServerAPIOptions {
    driverOpts := driver.NewServerAPIOptions(string(s.ServerAPIVersion))
    if s.Strict != nil {
        driverOpts.SetStrict(*s.Strict)
    }
    if s.DeprecationErrors != nil {
        driverOpts.SetDeprecationErrors(*s.DeprecationErrors)
    }
    return driverOpts
}

// NewConfig will translate data from client options into a topology config for building non-default deployments.
// Server and topology options are not honored if a custom deployment is used.
func NewConfig(co *options.ClientOptions, clock *session.ClusterClock) (*Config, error) {
    var serverAPI *driver.ServerAPIOptions

    if err := co.Validate(); err != nil {
        return nil, err
    }

    var connOpts []ConnectionOption
    var serverOpts []ServerOption

    cfgp := new(Config)

    // Set the default "ServerSelectionTimeout" to 30 seconds.
    cfgp.ServerSelectionTimeout = defaultServerSelectionTimeout

    // Set the default "SeedList" to localhost.
    cfgp.SeedList = []string{"localhost:27017"}

    // TODO(GODRIVER-814): Add tests for topology, server, and connection related options.

    // ServerAPIOptions need to be handled early as other client and server options below reference
    // c.serverAPI and serverOpts.serverAPI.
    if co.ServerAPIOptions != nil {
        serverAPI = ConvertToDriverAPIOptions(co.ServerAPIOptions)
        serverOpts = append(serverOpts, WithServerAPI(func(*driver.ServerAPIOptions) *driver.ServerAPIOptions {
            return serverAPI
        }))
    }

    cfgp.URI = co.GetURI()

    if co.SRVServiceName != nil {
        cfgp.SRVServiceName = *co.SRVServiceName
    }

    if co.SRVMaxHosts != nil {
        cfgp.SRVMaxHosts = *co.SRVMaxHosts
    }

    // AppName
    var appName string
    if co.AppName != nil {
        appName = *co.AppName

        serverOpts = append(serverOpts, WithServerAppName(func(string) string {
            return appName
        }))
    }
    // Compressors & ZlibLevel
    var comps []string
    if len(co.Compressors) > 0 {
        comps = co.Compressors

        connOpts = append(connOpts, WithCompressors(
            func(compressors []string) []string {
                return append(compressors, comps...)
            },
        ))

        for _, comp := range comps {
            switch comp {
            case "zlib":
                connOpts = append(connOpts, WithZlibLevel(func(level *int) *int {
                    return co.ZlibLevel
                }))
            case "zstd":
                connOpts = append(connOpts, WithZstdLevel(func(level *int) *int {
                    return co.ZstdLevel
                }))
            }
        }

        serverOpts = append(serverOpts, WithCompressionOptions(
            func(opts ...string) []string { return append(opts, comps...) },
        ))
    }

    var loadBalanced bool
    if co.LoadBalanced != nil {
        loadBalanced = *co.LoadBalanced
    }

    // Handshaker
    var handshaker = func(driver.Handshaker) driver.Handshaker {
        return operation.NewHello().AppName(appName).Compressors(comps).ClusterClock(clock).
            ServerAPI(serverAPI).LoadBalanced(loadBalanced)
    }
    // Auth & Database & Password & Username
    if co.Auth != nil {
        cred := &auth.Cred{
            Username:    co.Auth.Username,
            Password:    co.Auth.Password,
            PasswordSet: co.Auth.PasswordSet,
            Props:       co.Auth.AuthMechanismProperties,
            Source:      co.Auth.AuthSource,
        }
        mechanism := co.Auth.AuthMechanism

        if len(cred.Source) == 0 {
            switch strings.ToUpper(mechanism) {
            case auth.MongoDBX509, auth.GSSAPI, auth.PLAIN:
                cred.Source = "$external"
            default:
                cred.Source = "admin"
            }
        }

        authenticator, err := auth.CreateAuthenticator(mechanism, cred)
        if err != nil {
            return nil, err
        }

        handshakeOpts := &auth.HandshakeOptions{
            AppName:       appName,
            Authenticator: authenticator,
            Compressors:   comps,
            ServerAPI:     serverAPI,
            LoadBalanced:  loadBalanced,
            ClusterClock:  clock,
            HTTPClient:    co.HTTPClient,
        }

        if mechanism == "" {
            // Required for SASL mechanism negotiation during handshake
            handshakeOpts.DBUser = cred.Source + "." + cred.Username
        }
        if co.AuthenticateToAnything != nil && *co.AuthenticateToAnything {
            // Authenticate arbiters
            handshakeOpts.PerformAuthentication = func(serv description.Server) bool {
                return true
            }
        }

        handshaker = func(driver.Handshaker) driver.Handshaker {
            return auth.Handshaker(nil, handshakeOpts)
        }
    }
    connOpts = append(connOpts, WithHandshaker(handshaker))
    // ConnectTimeout
    if co.ConnectTimeout != nil {
        serverOpts = append(serverOpts, WithHeartbeatTimeout(
            func(time.Duration) time.Duration { return *co.ConnectTimeout },
        ))
        connOpts = append(connOpts, WithConnectTimeout(
            func(time.Duration) time.Duration { return *co.ConnectTimeout },
        ))
    }
    // Dialer
    if co.Dialer != nil {
        connOpts = append(connOpts, WithDialer(
            func(Dialer) Dialer { return co.Dialer },
        ))
    }
    // Direct
    if co.Direct != nil && *co.Direct {
        cfgp.Mode = SingleMode
    }

    // HeartbeatInterval
    if co.HeartbeatInterval != nil {
        serverOpts = append(serverOpts, WithHeartbeatInterval(
            func(time.Duration) time.Duration { return *co.HeartbeatInterval },
        ))
    }
    // Hosts
    cfgp.SeedList = []string{"localhost:27017"} // default host
    if len(co.Hosts) > 0 {
        cfgp.SeedList = co.Hosts
    }

    // MaxConIdleTime
    if co.MaxConnIdleTime != nil {
        connOpts = append(connOpts, WithIdleTimeout(
            func(time.Duration) time.Duration { return *co.MaxConnIdleTime },
        ))
    }
    // MaxPoolSize
    if co.MaxPoolSize != nil {
        serverOpts = append(
            serverOpts,
            WithMaxConnections(func(uint64) uint64 { return *co.MaxPoolSize }),
        )
    }
    // MinPoolSize
    if co.MinPoolSize != nil {
        serverOpts = append(
            serverOpts,
            WithMinConnections(func(uint64) uint64 { return *co.MinPoolSize }),
        )
    }
    // MaxConnecting
    if co.MaxConnecting != nil {
        serverOpts = append(
            serverOpts,
            WithMaxConnecting(func(uint64) uint64 { return *co.MaxConnecting }),
        )
    }
    // PoolMonitor
    if co.PoolMonitor != nil {
        serverOpts = append(
            serverOpts,
            WithConnectionPoolMonitor(func(*event.PoolMonitor) *event.PoolMonitor { return co.PoolMonitor }),
        )
    }
    // Monitor
    if co.Monitor != nil {
        connOpts = append(connOpts, WithMonitor(
            func(*event.CommandMonitor) *event.CommandMonitor { return co.Monitor },
        ))
    }
    // ServerMonitor
    if co.ServerMonitor != nil {
        serverOpts = append(
            serverOpts,
            WithServerMonitor(func(*event.ServerMonitor) *event.ServerMonitor { return co.ServerMonitor }),
        )
        cfgp.ServerMonitor = co.ServerMonitor
    }
    // ReplicaSet
    if co.ReplicaSet != nil {
        cfgp.ReplicaSetName = *co.ReplicaSet
    }
    // ServerSelectionTimeout
    if co.ServerSelectionTimeout != nil {
        cfgp.ServerSelectionTimeout = *co.ServerSelectionTimeout
    }
    // SocketTimeout
    if co.SocketTimeout != nil {
        connOpts = append(
            connOpts,
            WithReadTimeout(func(time.Duration) time.Duration { return *co.SocketTimeout }),
            WithWriteTimeout(func(time.Duration) time.Duration { return *co.SocketTimeout }),
        )
    }
    // TLSConfig
    if co.TLSConfig != nil {
        connOpts = append(connOpts, WithTLSConfig(
            func(*tls.Config) *tls.Config {
                return co.TLSConfig
            },
        ))
    }

    // HTTP Client
    if co.HTTPClient != nil {
        connOpts = append(connOpts, WithHTTPClient(
            func(*http.Client) *http.Client {
                return co.HTTPClient
            },
        ))
    }

    // OCSP cache
    ocspCache := ocsp.NewCache()
    connOpts = append(
        connOpts,
        WithOCSPCache(func(ocsp.Cache) ocsp.Cache { return ocspCache }),
    )

    // Disable communication with external OCSP responders.
    if co.DisableOCSPEndpointCheck != nil {
        connOpts = append(
            connOpts,
            WithDisableOCSPEndpointCheck(func(bool) bool { return *co.DisableOCSPEndpointCheck }),
        )
    }

    // LoadBalanced
    if co.LoadBalanced != nil {
        cfgp.LoadBalanced = *co.LoadBalanced

        serverOpts = append(
            serverOpts,
            WithServerLoadBalanced(func(bool) bool { return *co.LoadBalanced }),
        )
        connOpts = append(
            connOpts,
            WithConnectionLoadBalanced(func(bool) bool { return *co.LoadBalanced }),
        )
    }

    serverOpts = append(
        serverOpts,
        WithClock(func(*session.ClusterClock) *session.ClusterClock { return clock }),
        WithConnectionOptions(func(...ConnectionOption) []ConnectionOption { return connOpts }))

    cfgp.ServerOpts = serverOpts

    return cfgp, nil
}
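
// Illustrative sketch (not part of the driver): NewConfig is the bridge from
// the public client options to this package. A typical call, assuming a
// reachable URI:
//
//      opts := options.Client().ApplyURI("mongodb://localhost:27017").SetMaxPoolSize(64)
//      cfg, err := topology.NewConfig(opts, nil) // nil cluster clock
//      if err != nil { /* handle */ }
//      topo, err := topology.New(cfg)
//
// Each recognized client option is translated into a ServerOption or
// ConnectionOption closure, so unset options simply fall through to the
// defaults established above.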