2016-07-25 20:00:28 +00:00
|
|
|
package server
|
2016-10-05 15:01:35 +00:00
|
|
|
|
|
|
|
import (
|
2019-07-26 01:13:37 +00:00
|
|
|
"bytes"
|
2017-03-08 18:33:19 +00:00
|
|
|
"context"
|
2019-07-26 01:13:37 +00:00
|
|
|
"encoding/json"
|
2019-02-04 17:45:13 +00:00
|
|
|
"errors"
|
2016-10-05 15:01:35 +00:00
|
|
|
"net/http"
|
|
|
|
"net/http/httptest"
|
|
|
|
"testing"
|
2019-02-04 17:45:13 +00:00
|
|
|
|
|
|
|
"github.com/dexidp/dex/storage"
|
2016-10-05 15:01:35 +00:00
|
|
|
)
|
|
|
|
|
|
|
|
func TestHandleHealth(t *testing.T) {
|
2016-10-13 01:51:32 +00:00
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
defer cancel()
|
|
|
|
|
2016-10-14 01:15:20 +00:00
|
|
|
httpServer, server := newTestServer(ctx, t, nil)
|
2016-10-05 15:01:35 +00:00
|
|
|
defer httpServer.Close()
|
|
|
|
|
|
|
|
rr := httptest.NewRecorder()
|
2019-02-04 17:45:13 +00:00
|
|
|
server.ServeHTTP(rr, httptest.NewRequest("GET", "/healthz", nil))
|
2016-10-05 15:01:35 +00:00
|
|
|
if rr.Code != http.StatusOK {
|
|
|
|
t.Errorf("expected 200 got %d", rr.Code)
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
2019-02-04 17:45:13 +00:00
|
|
|
|
|
|
|
type badStorage struct {
|
|
|
|
storage.Storage
|
|
|
|
}
|
|
|
|
|
|
|
|
func (b *badStorage) CreateAuthRequest(r storage.AuthRequest) error {
|
|
|
|
return errors.New("storage unavailable")
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestHandleHealthFailure(t *testing.T) {
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
defer cancel()
|
|
|
|
|
|
|
|
httpServer, server := newTestServer(ctx, t, func(c *Config) {
|
|
|
|
c.Storage = &badStorage{c.Storage}
|
|
|
|
})
|
|
|
|
defer httpServer.Close()
|
|
|
|
|
|
|
|
rr := httptest.NewRecorder()
|
|
|
|
server.ServeHTTP(rr, httptest.NewRequest("GET", "/healthz", nil))
|
|
|
|
if rr.Code != http.StatusInternalServerError {
|
|
|
|
t.Errorf("expected 500 got %d", rr.Code)
|
|
|
|
}
|
|
|
|
}
|
2019-07-26 01:13:37 +00:00
|
|
|
|
|
|
|
type emptyStorage struct {
|
|
|
|
storage.Storage
|
|
|
|
}
|
|
|
|
|
|
|
|
func (*emptyStorage) GetAuthRequest(string) (storage.AuthRequest, error) {
|
|
|
|
return storage.AuthRequest{}, storage.ErrNotFound
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestHandleInvalidOAuth2Callbacks(t *testing.T) {
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
defer cancel()
|
|
|
|
|
|
|
|
httpServer, server := newTestServer(ctx, t, func(c *Config) {
|
|
|
|
c.Storage = &emptyStorage{c.Storage}
|
|
|
|
})
|
|
|
|
defer httpServer.Close()
|
|
|
|
|
|
|
|
tests := []struct {
|
|
|
|
TargetURI string
|
|
|
|
ExpectedCode int
|
|
|
|
}{
|
|
|
|
{"/callback", http.StatusBadRequest},
|
|
|
|
{"/callback?code=&state=", http.StatusBadRequest},
|
|
|
|
{"/callback?code=AAAAAAA&state=BBBBBBB", http.StatusBadRequest},
|
|
|
|
}
|
|
|
|
|
|
|
|
rr := httptest.NewRecorder()
|
|
|
|
|
|
|
|
for i, r := range tests {
|
|
|
|
server.ServeHTTP(rr, httptest.NewRequest("GET", r.TargetURI, nil))
|
|
|
|
if rr.Code != r.ExpectedCode {
|
|
|
|
t.Fatalf("test %d expected %d, got %d", i, r.ExpectedCode, rr.Code)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestHandleInvalidSAMLCallbacks(t *testing.T) {
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
defer cancel()
|
|
|
|
|
|
|
|
httpServer, server := newTestServer(ctx, t, func(c *Config) {
|
|
|
|
c.Storage = &emptyStorage{c.Storage}
|
|
|
|
})
|
|
|
|
defer httpServer.Close()
|
|
|
|
|
|
|
|
type requestForm struct {
|
|
|
|
RelayState string
|
|
|
|
}
|
|
|
|
tests := []struct {
|
|
|
|
RequestForm requestForm
|
|
|
|
ExpectedCode int
|
|
|
|
}{
|
|
|
|
{requestForm{}, http.StatusBadRequest},
|
|
|
|
{requestForm{RelayState: "AAAAAAA"}, http.StatusBadRequest},
|
|
|
|
}
|
|
|
|
|
|
|
|
rr := httptest.NewRecorder()
|
|
|
|
|
|
|
|
for i, r := range tests {
|
|
|
|
jsonValue, err := json.Marshal(r.RequestForm)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatal(err.Error())
|
|
|
|
}
|
|
|
|
server.ServeHTTP(rr, httptest.NewRequest("POST", "/callback", bytes.NewBuffer(jsonValue)))
|
|
|
|
if rr.Code != r.ExpectedCode {
|
|
|
|
t.Fatalf("test %d expected %d, got %d", i, r.ExpectedCode, rr.Code)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|