From d413870f6e877864a1d1f74453e42a4f4992b5da Mon Sep 17 00:00:00 2001 From: "m.nabokikh" Date: Fri, 7 May 2021 02:10:11 +0400 Subject: [PATCH] feat: Update token periodically if Dex is running in Kubernetes cluster Signed-off-by: m.nabokikh --- storage/kubernetes/client.go | 40 +--------- storage/kubernetes/client_test.go | 79 +++++++++++++++++++- storage/kubernetes/storage.go | 2 +- storage/kubernetes/transport.go | 119 ++++++++++++++++++++++++++++++ 4 files changed, 197 insertions(+), 43 deletions(-) create mode 100644 storage/kubernetes/transport.go diff --git a/storage/kubernetes/client.go b/storage/kubernetes/client.go index 07ed0182..1769bf49 100644 --- a/storage/kubernetes/client.go +++ b/storage/kubernetes/client.go @@ -303,7 +303,7 @@ func defaultTLSConfig() *tls.Config { } } -func newClient(cluster k8sapi.Cluster, user k8sapi.AuthInfo, namespace string, logger log.Logger) (*client, error) { +func newClient(cluster k8sapi.Cluster, user k8sapi.AuthInfo, namespace string, logger log.Logger, inCluster bool) (*client, error) { tlsConfig := defaultTLSConfig() data := func(b string, file string) ([]byte, error) { if b != "" { @@ -359,25 +359,7 @@ func newClient(cluster k8sapi.Cluster, user k8sapi.AuthInfo, namespace string, l if err := http2.ConfigureTransport(httpTransport); err != nil { return nil, err } - t = httpTransport - - if user.Token != "" { - t = transport{ - updateReq: func(r *http.Request) { - r.Header.Set("Authorization", "Bearer "+user.Token) - }, - base: t, - } - } - - if user.Username != "" && user.Password != "" { - t = transport{ - updateReq: func(r *http.Request) { - r.SetBasicAuth(user.Username, user.Password) - }, - base: t, - } - } + t = wrapRoundTripper(httpTransport, user, inCluster) apiVersion := "dex.coreos.com/v1" @@ -396,24 +378,6 @@ func newClient(cluster k8sapi.Cluster, user k8sapi.AuthInfo, namespace string, l }, nil } -type transport struct { - updateReq func(r *http.Request) - base http.RoundTripper -} - -func (t transport) RoundTrip(r *http.Request) (*http.Response, error) { - // shallow copy of the struct - r2 := new(http.Request) - *r2 = *r - // deep copy of the Header - r2.Header = make(http.Header, len(r.Header)) - for k, s := range r.Header { - r2.Header[k] = append([]string(nil), s...) - } - t.updateReq(r2) - return t.base.RoundTrip(r2) -} - func loadKubeConfig(kubeConfigPath string) (cluster k8sapi.Cluster, user k8sapi.AuthInfo, namespace string, err error) { data, err := ioutil.ReadFile(kubeConfigPath) if err != nil { diff --git a/storage/kubernetes/client_test.go b/storage/kubernetes/client_test.go index 0f7b0789..77fe8d48 100644 --- a/storage/kubernetes/client_test.go +++ b/storage/kubernetes/client_test.go @@ -3,12 +3,8 @@ package kubernetes import ( "hash" "hash/fnv" - "io/ioutil" - "os" "sync" "testing" - - "github.com/stretchr/testify/require" ) // This test does not have an explicit error condition but is used @@ -46,6 +42,81 @@ func TestOfflineTokenName(t *testing.T) { } } +func TestInClusterTransport(t *testing.T) { + logger := &logrus.Logger{ + Out: os.Stderr, + Formatter: &logrus.TextFormatter{DisableColors: true}, + Level: logrus.DebugLevel, + } + + user := k8sapi.AuthInfo{Token: "abc"} + cli, err := newClient( + k8sapi.Cluster{}, + user, + "test", + logger, + true, + ) + require.NoError(t, err) + + fpath := filepath.Join(os.TempDir(), "test.in_cluster") + defer os.RemoveAll(fpath) + + err = ioutil.WriteFile(fpath, []byte("def"), 0644) + require.NoError(t, err) + + tests := []struct { + name string + time func() time.Time + expected string + }{ + { + name: "Stale token", + time: func() time.Time { + return time.Now().Add(-24 * time.Hour) + }, + expected: "def", + }, + { + name: "Normal token", + time: func() time.Time { + return time.Time{} + }, + expected: "abc", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + helper := newInClusterTransportHelper(user) + helper.now = tc.time + helper.tokenLocation = fpath + + cli.client.Transport = transport{ + updateReq: func(r *http.Request) { + helper.UpdateToken() + r.Header.Set("Authorization", "Bearer "+helper.GetToken()) + }, + base: cli.client.Transport, + } + + _ = cli.isCRDReady("test") + require.Equal(t, tc.expected, helper.info.Token) + }) + } +} + +func TestNamespaceFromServiceAccountJWT(t *testing.T) { + namespace, err := namespaceFromServiceAccountJWT(serviceAccountToken) + if err != nil { + t.Fatal(err) + } + wantNamespace := "dex-test-namespace" + if namespace != wantNamespace { + t.Errorf("expected namespace %q got %q", wantNamespace, namespace) + } +} + func TestGetClusterConfigNamespace(t *testing.T) { const namespaceENVVariableName = "TEST_GET_CLUSTER_CONFIG_NAMESPACE" { diff --git a/storage/kubernetes/storage.go b/storage/kubernetes/storage.go index d6349793..13549ef5 100644 --- a/storage/kubernetes/storage.go +++ b/storage/kubernetes/storage.go @@ -83,7 +83,7 @@ func (c *Config) open(logger log.Logger, waitForResources bool) (*client, error) return nil, err } - cli, err := newClient(cluster, user, namespace, logger) + cli, err := newClient(cluster, user, namespace, logger, c.InCluster) if err != nil { return nil, fmt.Errorf("create client: %v", err) } diff --git a/storage/kubernetes/transport.go b/storage/kubernetes/transport.go new file mode 100644 index 00000000..084d15db --- /dev/null +++ b/storage/kubernetes/transport.go @@ -0,0 +1,119 @@ +package kubernetes + +import ( + "io/ioutil" + "net/http" + "sync" + "time" + + "github.com/dexidp/dex/storage/kubernetes/k8sapi" +) + +// transport is a simple http.Transport wrapper +type transport struct { + updateReq func(r *http.Request) + base http.RoundTripper +} + +func (t transport) RoundTrip(r *http.Request) (*http.Response, error) { + // shallow copy of the struct + r2 := new(http.Request) + *r2 = *r + // deep copy of the Header + r2.Header = make(http.Header, len(r.Header)) + for k, s := range r.Header { + r2.Header[k] = append([]string(nil), s...) + } + t.updateReq(r2) + return t.base.RoundTrip(r2) +} + +func wrapRoundTripper(base http.RoundTripper, user k8sapi.AuthInfo, inCluster bool) http.RoundTripper { + if inCluster { + inClusterTransportHelper := newInClusterTransportHelper(user) + return transport{ + updateReq: func(r *http.Request) { + inClusterTransportHelper.UpdateToken() + r.Header.Set("Authorization", "Bearer "+inClusterTransportHelper.GetToken()) + }, + base: base, + } + } + + if user.Token != "" { + return transport{ + updateReq: func(r *http.Request) { + r.Header.Set("Authorization", "Bearer "+user.Token) + }, + base: base, + } + } + + if user.Username != "" && user.Password != "" { + return transport{ + updateReq: func(r *http.Request) { + r.SetBasicAuth(user.Username, user.Password) + }, + base: base, + } + } + + return base +} + +// renewTokenPeriod is the interval after which dex will read the token from a well-known file. +// By Kubernetes documentation, this interval should be at least one minute long. +// Kubernetes client-go v0.15+ uses 10 seconds long interval. +// Dex uses the reasonable value between these two. +const renewTokenPeriod = 30 * time.Second + +// inClusterTransportHelper is capable of safely updating the user token. +// BoundServiceAccountTokenVolume feature is enabled in Kubernetes >=1.21 by default. +// With this feature, the service account token in the pod becomes periodically updated. +// Therefore, Dex needs to re-read the token from the disk after some time to be sure that it uses the valid token. +type inClusterTransportHelper struct { + mu sync.RWMutex + info k8sapi.AuthInfo + + expiry time.Time + now func() time.Time + + tokenLocation string +} + +func newInClusterTransportHelper(info k8sapi.AuthInfo) *inClusterTransportHelper { + user := inClusterTransportHelper{ + info: info, + now: time.Now, + tokenLocation: "/var/run/secrets/kubernetes.io/serviceaccount/token", + } + user.UpdateToken() + return &user +} + +func (c *inClusterTransportHelper) UpdateToken() { + c.mu.RLock() + exp := c.expiry + c.mu.RUnlock() + + if !c.now().After(exp) { + // Do not need to update token yet + return + } + + token, err := ioutil.ReadFile(c.tokenLocation) + if err != nil { + return + } + + c.mu.Lock() + defer c.mu.Unlock() + c.info.Token = string(token) + c.expiry = c.now().Add(renewTokenPeriod) +} + +func (c *inClusterTransportHelper) GetToken() string { + c.mu.RLock() + defer c.mu.RUnlock() + return c.info.Token +}