fix: prevent cross-site scripting for the device flow
Signed-off-by: m.nabokikh <maksim.nabokikh@flant.com>
This commit is contained in:
		| @@ -11,6 +11,8 @@ import ( | ||||
| 	"strings" | ||||
| 	"time" | ||||
|  | ||||
| 	"golang.org/x/net/html" | ||||
|  | ||||
| 	"github.com/dexidp/dex/pkg/log" | ||||
| 	"github.com/dexidp/dex/storage" | ||||
| ) | ||||
| @@ -247,7 +249,9 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) { | ||||
|  | ||||
| 		// Authorization redirect callback from OAuth2 auth flow. | ||||
| 		if errMsg := r.FormValue("error"); errMsg != "" { | ||||
| 			http.Error(w, errMsg+": "+r.FormValue("error_description"), http.StatusBadRequest) | ||||
| 			// escape the message to prevent cross-site scripting | ||||
| 			msg := html.EscapeString(errMsg + ": " + r.FormValue("error_description")) | ||||
| 			http.Error(w, msg, http.StatusBadRequest) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
|   | ||||
| @@ -160,12 +160,13 @@ func TestDeviceCallback(t *testing.T) { | ||||
| 	} | ||||
|  | ||||
| 	tests := []struct { | ||||
| 		testName             string | ||||
| 		expectedResponseCode int | ||||
| 		values               formValues | ||||
| 		testAuthCode         storage.AuthCode | ||||
| 		testDeviceRequest    storage.DeviceRequest | ||||
| 		testDeviceToken      storage.DeviceToken | ||||
| 		testName               string | ||||
| 		expectedResponseCode   int | ||||
| 		expectedServerResponse string | ||||
| 		values                 formValues | ||||
| 		testAuthCode           storage.AuthCode | ||||
| 		testDeviceRequest      storage.DeviceRequest | ||||
| 		testDeviceToken        storage.DeviceToken | ||||
| 	}{ | ||||
| 		{ | ||||
| 			testName: "Missing State", | ||||
| @@ -192,7 +193,8 @@ func TestDeviceCallback(t *testing.T) { | ||||
| 				code:  "somecode", | ||||
| 				error: "Error Condition", | ||||
| 			}, | ||||
| 			expectedResponseCode: http.StatusBadRequest, | ||||
| 			expectedResponseCode:   http.StatusBadRequest, | ||||
| 			expectedServerResponse: "Error Condition: \n", | ||||
| 		}, | ||||
| 		{ | ||||
| 			testName: "Expired Auth Code", | ||||
| @@ -314,6 +316,16 @@ func TestDeviceCallback(t *testing.T) { | ||||
| 			testDeviceToken:      baseDeviceToken, | ||||
| 			expectedResponseCode: http.StatusOK, | ||||
| 		}, | ||||
| 		{ | ||||
| 			testName: "Prevent cross-site scripting", | ||||
| 			values: formValues{ | ||||
| 				state: "XXXX-XXXX", | ||||
| 				code:  "somecode", | ||||
| 				error: "<script>console.log(window);</script>", | ||||
| 			}, | ||||
| 			expectedResponseCode:   http.StatusBadRequest, | ||||
| 			expectedServerResponse: "<script>console.log(window);</script>: \n", | ||||
| 		}, | ||||
| 	} | ||||
| 	for _, tc := range tests { | ||||
| 		t.Run(tc.testName, func(t *testing.T) { | ||||
| @@ -366,6 +378,13 @@ func TestDeviceCallback(t *testing.T) { | ||||
| 			if rr.Code != tc.expectedResponseCode { | ||||
| 				t.Errorf("%s: Unexpected Response Type.  Expected %v got %v", tc.testName, tc.expectedResponseCode, rr.Code) | ||||
| 			} | ||||
|  | ||||
| 			if len(tc.expectedServerResponse) > 0 { | ||||
| 				result, _ := io.ReadAll(rr.Body) | ||||
| 				if string(result) != tc.expectedServerResponse { | ||||
| 					t.Errorf("%s: Unexpected Response.  Expected %q got %q", tc.testName, tc.expectedServerResponse, result) | ||||
| 				} | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user