diff --git a/server/handlers.go b/server/handlers.go index 5bdf39f0..d961dbef 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -1,6 +1,7 @@ package server import ( + "context" "encoding/json" "errors" "fmt" @@ -10,6 +11,7 @@ import ( "sort" "strconv" "strings" + "sync" "time" "github.com/gorilla/mux" @@ -20,31 +22,85 @@ import ( "github.com/dexidp/dex/storage" ) -func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) { - start := s.now() - err := func() error { - // Instead of trying to introspect health, just try to use the underlying storage. - a := storage.AuthRequest{ - ID: storage.NewID(), - ClientID: storage.NewID(), +// newHealthChecker returns the healthz handler. The handler runs until the +// provided context is canceled. +func (s *Server) newHealthChecker(ctx context.Context) http.Handler { + h := &healthChecker{s: s} - // Set a short expiry so if the delete fails this will be cleaned up quickly by garbage collection. - Expiry: s.now().Add(time.Minute), - } + // Perform one health check synchronously so the returned handler returns + // valid data immediately. + h.runHealthCheck() - if err := s.storage.CreateAuthRequest(a); err != nil { - return fmt.Errorf("create auth request: %v", err) + go func() { + for { + select { + case <-ctx.Done(): + return + case <-time.After(time.Second * 15): + } + h.runHealthCheck() } - if err := s.storage.DeleteAuthRequest(a.ID); err != nil { - return fmt.Errorf("delete auth request: %v", err) - } - return nil }() + return h +} - t := s.now().Sub(start) +// healthChecker periodically performs health checks on server dependenices. +// Currently, it only checks that the storage layer is avialable. +type healthChecker struct { + s *Server + + // Result of the last health check: any error and the amount of time it took + // to query the storage. + mu sync.RWMutex + // Guarded by the mutex + err error + passed time.Duration +} + +// runHealthCheck performs a single health check and makes the result available +// for any clients performing and HTTP request against the healthChecker. +func (h *healthChecker) runHealthCheck() { + t := h.s.now() + err := checkStorageHealth(h.s.storage, h.s.now) + passed := h.s.now().Sub(t) if err != nil { - s.logger.Errorf("Storage health check failed: %v", err) - s.renderError(w, http.StatusInternalServerError, "Health check failed.") + h.s.logger.Errorf("Storage health check failed: %v", err) + } + + // Make sure to only hold the mutex to access the fields, and not while + // we're querying the storage object. + h.mu.Lock() + h.err = err + h.passed = passed + h.mu.Unlock() +} + +func checkStorageHealth(s storage.Storage, now func() time.Time) error { + a := storage.AuthRequest{ + ID: storage.NewID(), + ClientID: storage.NewID(), + + // Set a short expiry so if the delete fails this will be cleaned up quickly by garbage collection. + Expiry: now().Add(time.Minute), + } + + if err := s.CreateAuthRequest(a); err != nil { + return fmt.Errorf("create auth request: %v", err) + } + if err := s.DeleteAuthRequest(a.ID); err != nil { + return fmt.Errorf("delete auth request: %v", err) + } + return nil +} + +func (h *healthChecker) ServeHTTP(w http.ResponseWriter, r *http.Request) { + h.mu.RLock() + err := h.err + t := h.passed + h.mu.RUnlock() + + if err != nil { + h.s.renderError(w, http.StatusInternalServerError, "Health check failed.") return } fmt.Fprintf(w, "Health check passed in %s", t) diff --git a/server/handlers_test.go b/server/handlers_test.go index 4c410b8e..3e0b1e81 100644 --- a/server/handlers_test.go +++ b/server/handlers_test.go @@ -2,9 +2,12 @@ package server import ( "context" + "errors" "net/http" "net/http/httptest" "testing" + + "github.com/dexidp/dex/storage" ) func TestHandleHealth(t *testing.T) { @@ -15,9 +18,33 @@ func TestHandleHealth(t *testing.T) { defer httpServer.Close() rr := httptest.NewRecorder() - server.handleHealth(rr, httptest.NewRequest("GET", "/healthz", nil)) + server.ServeHTTP(rr, httptest.NewRequest("GET", "/healthz", nil)) if rr.Code != http.StatusOK { t.Errorf("expected 200 got %d", rr.Code) } } + +type badStorage struct { + storage.Storage +} + +func (b *badStorage) CreateAuthRequest(r storage.AuthRequest) error { + return errors.New("storage unavailable") +} + +func TestHandleHealthFailure(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + httpServer, server := newTestServer(ctx, t, func(c *Config) { + c.Storage = &badStorage{c.Storage} + }) + defer httpServer.Close() + + rr := httptest.NewRecorder() + server.ServeHTTP(rr, httptest.NewRequest("GET", "/healthz", nil)) + if rr.Code != http.StatusInternalServerError { + t.Errorf("expected 500 got %d", rr.Code) + } +} diff --git a/server/server.go b/server/server.go index ee3355b5..58968a1b 100644 --- a/server/server.go +++ b/server/server.go @@ -242,8 +242,11 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) } r := mux.NewRouter() + handle := func(p string, h http.Handler) { + r.Handle(path.Join(issuerURL.Path, p), instrumentHandlerCounter(p, h)) + } handleFunc := func(p string, h http.HandlerFunc) { - r.HandleFunc(path.Join(issuerURL.Path, p), instrumentHandlerCounter(p, h)) + handle(p, h) } handlePrefix := func(p string, h http.Handler) { prefix := path.Join(issuerURL.Path, p) @@ -284,7 +287,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) // "authproxy" connector. handleFunc("/callback/{connector}", s.handleConnectorCallback) handleFunc("/approval", s.handleApproval) - handleFunc("/healthz", s.handleHealth) + handle("/healthz", s.newHealthChecker(ctx)) handlePrefix("/static", static) handlePrefix("/theme", theme) s.mux = r