feat: Update token periodically if Dex is running in Kubernetes cluster
Signed-off-by: m.nabokikh <maksim.nabokikh@flant.com>
This commit is contained in:
		| @@ -303,7 +303,7 @@ func defaultTLSConfig() *tls.Config { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func newClient(cluster k8sapi.Cluster, user k8sapi.AuthInfo, namespace string, logger log.Logger) (*client, error) { | func newClient(cluster k8sapi.Cluster, user k8sapi.AuthInfo, namespace string, logger log.Logger, inCluster bool) (*client, error) { | ||||||
| 	tlsConfig := defaultTLSConfig() | 	tlsConfig := defaultTLSConfig() | ||||||
| 	data := func(b string, file string) ([]byte, error) { | 	data := func(b string, file string) ([]byte, error) { | ||||||
| 		if b != "" { | 		if b != "" { | ||||||
| @@ -359,25 +359,7 @@ func newClient(cluster k8sapi.Cluster, user k8sapi.AuthInfo, namespace string, l | |||||||
| 	if err := http2.ConfigureTransport(httpTransport); err != nil { | 	if err := http2.ConfigureTransport(httpTransport); err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	t = httpTransport | 	t = wrapRoundTripper(httpTransport, user, inCluster) | ||||||
|  |  | ||||||
| 	if user.Token != "" { |  | ||||||
| 		t = transport{ |  | ||||||
| 			updateReq: func(r *http.Request) { |  | ||||||
| 				r.Header.Set("Authorization", "Bearer "+user.Token) |  | ||||||
| 			}, |  | ||||||
| 			base: t, |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if user.Username != "" && user.Password != "" { |  | ||||||
| 		t = transport{ |  | ||||||
| 			updateReq: func(r *http.Request) { |  | ||||||
| 				r.SetBasicAuth(user.Username, user.Password) |  | ||||||
| 			}, |  | ||||||
| 			base: t, |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	apiVersion := "dex.coreos.com/v1" | 	apiVersion := "dex.coreos.com/v1" | ||||||
|  |  | ||||||
| @@ -396,24 +378,6 @@ func newClient(cluster k8sapi.Cluster, user k8sapi.AuthInfo, namespace string, l | |||||||
| 	}, nil | 	}, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| type transport struct { |  | ||||||
| 	updateReq func(r *http.Request) |  | ||||||
| 	base      http.RoundTripper |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (t transport) RoundTrip(r *http.Request) (*http.Response, error) { |  | ||||||
| 	// shallow copy of the struct |  | ||||||
| 	r2 := new(http.Request) |  | ||||||
| 	*r2 = *r |  | ||||||
| 	// deep copy of the Header |  | ||||||
| 	r2.Header = make(http.Header, len(r.Header)) |  | ||||||
| 	for k, s := range r.Header { |  | ||||||
| 		r2.Header[k] = append([]string(nil), s...) |  | ||||||
| 	} |  | ||||||
| 	t.updateReq(r2) |  | ||||||
| 	return t.base.RoundTrip(r2) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func loadKubeConfig(kubeConfigPath string) (cluster k8sapi.Cluster, user k8sapi.AuthInfo, namespace string, err error) { | func loadKubeConfig(kubeConfigPath string) (cluster k8sapi.Cluster, user k8sapi.AuthInfo, namespace string, err error) { | ||||||
| 	data, err := ioutil.ReadFile(kubeConfigPath) | 	data, err := ioutil.ReadFile(kubeConfigPath) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|   | |||||||
| @@ -3,12 +3,8 @@ package kubernetes | |||||||
| import ( | import ( | ||||||
| 	"hash" | 	"hash" | ||||||
| 	"hash/fnv" | 	"hash/fnv" | ||||||
| 	"io/ioutil" |  | ||||||
| 	"os" |  | ||||||
| 	"sync" | 	"sync" | ||||||
| 	"testing" | 	"testing" | ||||||
|  |  | ||||||
| 	"github.com/stretchr/testify/require" |  | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // This test does not have an explicit error condition but is used | // This test does not have an explicit error condition but is used | ||||||
| @@ -46,6 +42,81 @@ func TestOfflineTokenName(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func TestInClusterTransport(t *testing.T) { | ||||||
|  | 	logger := &logrus.Logger{ | ||||||
|  | 		Out:       os.Stderr, | ||||||
|  | 		Formatter: &logrus.TextFormatter{DisableColors: true}, | ||||||
|  | 		Level:     logrus.DebugLevel, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	user := k8sapi.AuthInfo{Token: "abc"} | ||||||
|  | 	cli, err := newClient( | ||||||
|  | 		k8sapi.Cluster{}, | ||||||
|  | 		user, | ||||||
|  | 		"test", | ||||||
|  | 		logger, | ||||||
|  | 		true, | ||||||
|  | 	) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  |  | ||||||
|  | 	fpath := filepath.Join(os.TempDir(), "test.in_cluster") | ||||||
|  | 	defer os.RemoveAll(fpath) | ||||||
|  |  | ||||||
|  | 	err = ioutil.WriteFile(fpath, []byte("def"), 0644) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  |  | ||||||
|  | 	tests := []struct { | ||||||
|  | 		name     string | ||||||
|  | 		time     func() time.Time | ||||||
|  | 		expected string | ||||||
|  | 	}{ | ||||||
|  | 		{ | ||||||
|  | 			name: "Stale token", | ||||||
|  | 			time: func() time.Time { | ||||||
|  | 				return time.Now().Add(-24 * time.Hour) | ||||||
|  | 			}, | ||||||
|  | 			expected: "def", | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "Normal token", | ||||||
|  | 			time: func() time.Time { | ||||||
|  | 				return time.Time{} | ||||||
|  | 			}, | ||||||
|  | 			expected: "abc", | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for _, tc := range tests { | ||||||
|  | 		t.Run(tc.name, func(t *testing.T) { | ||||||
|  | 			helper := newInClusterTransportHelper(user) | ||||||
|  | 			helper.now = tc.time | ||||||
|  | 			helper.tokenLocation = fpath | ||||||
|  |  | ||||||
|  | 			cli.client.Transport = transport{ | ||||||
|  | 				updateReq: func(r *http.Request) { | ||||||
|  | 					helper.UpdateToken() | ||||||
|  | 					r.Header.Set("Authorization", "Bearer "+helper.GetToken()) | ||||||
|  | 				}, | ||||||
|  | 				base: cli.client.Transport, | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			_ = cli.isCRDReady("test") | ||||||
|  | 			require.Equal(t, tc.expected, helper.info.Token) | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestNamespaceFromServiceAccountJWT(t *testing.T) { | ||||||
|  | 	namespace, err := namespaceFromServiceAccountJWT(serviceAccountToken) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	wantNamespace := "dex-test-namespace" | ||||||
|  | 	if namespace != wantNamespace { | ||||||
|  | 		t.Errorf("expected namespace %q got %q", wantNamespace, namespace) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
| func TestGetClusterConfigNamespace(t *testing.T) { | func TestGetClusterConfigNamespace(t *testing.T) { | ||||||
| 	const namespaceENVVariableName = "TEST_GET_CLUSTER_CONFIG_NAMESPACE" | 	const namespaceENVVariableName = "TEST_GET_CLUSTER_CONFIG_NAMESPACE" | ||||||
| 	{ | 	{ | ||||||
|   | |||||||
| @@ -83,7 +83,7 @@ func (c *Config) open(logger log.Logger, waitForResources bool) (*client, error) | |||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	cli, err := newClient(cluster, user, namespace, logger) | 	cli, err := newClient(cluster, user, namespace, logger, c.InCluster) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("create client: %v", err) | 		return nil, fmt.Errorf("create client: %v", err) | ||||||
| 	} | 	} | ||||||
|   | |||||||
							
								
								
									
										119
									
								
								storage/kubernetes/transport.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										119
									
								
								storage/kubernetes/transport.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,119 @@ | |||||||
|  | package kubernetes | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"io/ioutil" | ||||||
|  | 	"net/http" | ||||||
|  | 	"sync" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
|  | 	"github.com/dexidp/dex/storage/kubernetes/k8sapi" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // transport is a simple http.Transport wrapper | ||||||
|  | type transport struct { | ||||||
|  | 	updateReq func(r *http.Request) | ||||||
|  | 	base      http.RoundTripper | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (t transport) RoundTrip(r *http.Request) (*http.Response, error) { | ||||||
|  | 	// shallow copy of the struct | ||||||
|  | 	r2 := new(http.Request) | ||||||
|  | 	*r2 = *r | ||||||
|  | 	// deep copy of the Header | ||||||
|  | 	r2.Header = make(http.Header, len(r.Header)) | ||||||
|  | 	for k, s := range r.Header { | ||||||
|  | 		r2.Header[k] = append([]string(nil), s...) | ||||||
|  | 	} | ||||||
|  | 	t.updateReq(r2) | ||||||
|  | 	return t.base.RoundTrip(r2) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func wrapRoundTripper(base http.RoundTripper, user k8sapi.AuthInfo, inCluster bool) http.RoundTripper { | ||||||
|  | 	if inCluster { | ||||||
|  | 		inClusterTransportHelper := newInClusterTransportHelper(user) | ||||||
|  | 		return transport{ | ||||||
|  | 			updateReq: func(r *http.Request) { | ||||||
|  | 				inClusterTransportHelper.UpdateToken() | ||||||
|  | 				r.Header.Set("Authorization", "Bearer "+inClusterTransportHelper.GetToken()) | ||||||
|  | 			}, | ||||||
|  | 			base: base, | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if user.Token != "" { | ||||||
|  | 		return transport{ | ||||||
|  | 			updateReq: func(r *http.Request) { | ||||||
|  | 				r.Header.Set("Authorization", "Bearer "+user.Token) | ||||||
|  | 			}, | ||||||
|  | 			base: base, | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if user.Username != "" && user.Password != "" { | ||||||
|  | 		return transport{ | ||||||
|  | 			updateReq: func(r *http.Request) { | ||||||
|  | 				r.SetBasicAuth(user.Username, user.Password) | ||||||
|  | 			}, | ||||||
|  | 			base: base, | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return base | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // renewTokenPeriod is the interval after which dex will read the token from a well-known file. | ||||||
|  | //   By Kubernetes documentation, this interval should be at least one minute long. | ||||||
|  | //   Kubernetes client-go v0.15+ uses 10 seconds long interval. | ||||||
|  | //   Dex uses the reasonable value between these two. | ||||||
|  | const renewTokenPeriod = 30 * time.Second | ||||||
|  |  | ||||||
|  | // inClusterTransportHelper is capable of safely updating the user token. | ||||||
|  | //   BoundServiceAccountTokenVolume feature is enabled in Kubernetes >=1.21 by default. | ||||||
|  | //   With this feature, the service account token in the pod becomes periodically updated. | ||||||
|  | //   Therefore, Dex needs to re-read the token from the disk after some time to be sure that it uses the valid token. | ||||||
|  | type inClusterTransportHelper struct { | ||||||
|  | 	mu   sync.RWMutex | ||||||
|  | 	info k8sapi.AuthInfo | ||||||
|  |  | ||||||
|  | 	expiry time.Time | ||||||
|  | 	now    func() time.Time | ||||||
|  |  | ||||||
|  | 	tokenLocation string | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func newInClusterTransportHelper(info k8sapi.AuthInfo) *inClusterTransportHelper { | ||||||
|  | 	user := inClusterTransportHelper{ | ||||||
|  | 		info:          info, | ||||||
|  | 		now:           time.Now, | ||||||
|  | 		tokenLocation: "/var/run/secrets/kubernetes.io/serviceaccount/token", | ||||||
|  | 	} | ||||||
|  | 	user.UpdateToken() | ||||||
|  | 	return &user | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *inClusterTransportHelper) UpdateToken() { | ||||||
|  | 	c.mu.RLock() | ||||||
|  | 	exp := c.expiry | ||||||
|  | 	c.mu.RUnlock() | ||||||
|  |  | ||||||
|  | 	if !c.now().After(exp) { | ||||||
|  | 		// Do not need to update token yet | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	token, err := ioutil.ReadFile(c.tokenLocation) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	c.mu.Lock() | ||||||
|  | 	defer c.mu.Unlock() | ||||||
|  | 	c.info.Token = string(token) | ||||||
|  | 	c.expiry = c.now().Add(renewTokenPeriod) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *inClusterTransportHelper) GetToken() string { | ||||||
|  | 	c.mu.RLock() | ||||||
|  | 	defer c.mu.RUnlock() | ||||||
|  | 	return c.info.Token | ||||||
|  | } | ||||||
		Reference in New Issue
	
	Block a user