fix: prevent cross-site scripting for the device flow
Signed-off-by: m.nabokikh <maksim.nabokikh@flant.com>
This commit is contained in:
		| @@ -11,6 +11,8 @@ import ( | |||||||
| 	"strings" | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
|  | 	"golang.org/x/net/html" | ||||||
|  |  | ||||||
| 	"github.com/dexidp/dex/pkg/log" | 	"github.com/dexidp/dex/pkg/log" | ||||||
| 	"github.com/dexidp/dex/storage" | 	"github.com/dexidp/dex/storage" | ||||||
| ) | ) | ||||||
| @@ -247,7 +249,9 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) { | |||||||
|  |  | ||||||
| 		// Authorization redirect callback from OAuth2 auth flow. | 		// Authorization redirect callback from OAuth2 auth flow. | ||||||
| 		if errMsg := r.FormValue("error"); errMsg != "" { | 		if errMsg := r.FormValue("error"); errMsg != "" { | ||||||
| 			http.Error(w, errMsg+": "+r.FormValue("error_description"), http.StatusBadRequest) | 			// escape the message to prevent cross-site scripting | ||||||
|  | 			msg := html.EscapeString(errMsg + ": " + r.FormValue("error_description")) | ||||||
|  | 			http.Error(w, msg, http.StatusBadRequest) | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|   | |||||||
| @@ -160,12 +160,13 @@ func TestDeviceCallback(t *testing.T) { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	tests := []struct { | 	tests := []struct { | ||||||
| 		testName             string | 		testName               string | ||||||
| 		expectedResponseCode int | 		expectedResponseCode   int | ||||||
| 		values               formValues | 		expectedServerResponse string | ||||||
| 		testAuthCode         storage.AuthCode | 		values                 formValues | ||||||
| 		testDeviceRequest    storage.DeviceRequest | 		testAuthCode           storage.AuthCode | ||||||
| 		testDeviceToken      storage.DeviceToken | 		testDeviceRequest      storage.DeviceRequest | ||||||
|  | 		testDeviceToken        storage.DeviceToken | ||||||
| 	}{ | 	}{ | ||||||
| 		{ | 		{ | ||||||
| 			testName: "Missing State", | 			testName: "Missing State", | ||||||
| @@ -192,7 +193,8 @@ func TestDeviceCallback(t *testing.T) { | |||||||
| 				code:  "somecode", | 				code:  "somecode", | ||||||
| 				error: "Error Condition", | 				error: "Error Condition", | ||||||
| 			}, | 			}, | ||||||
| 			expectedResponseCode: http.StatusBadRequest, | 			expectedResponseCode:   http.StatusBadRequest, | ||||||
|  | 			expectedServerResponse: "Error Condition: \n", | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			testName: "Expired Auth Code", | 			testName: "Expired Auth Code", | ||||||
| @@ -314,6 +316,16 @@ func TestDeviceCallback(t *testing.T) { | |||||||
| 			testDeviceToken:      baseDeviceToken, | 			testDeviceToken:      baseDeviceToken, | ||||||
| 			expectedResponseCode: http.StatusOK, | 			expectedResponseCode: http.StatusOK, | ||||||
| 		}, | 		}, | ||||||
|  | 		{ | ||||||
|  | 			testName: "Prevent cross-site scripting", | ||||||
|  | 			values: formValues{ | ||||||
|  | 				state: "XXXX-XXXX", | ||||||
|  | 				code:  "somecode", | ||||||
|  | 				error: "<script>console.log(window);</script>", | ||||||
|  | 			}, | ||||||
|  | 			expectedResponseCode:   http.StatusBadRequest, | ||||||
|  | 			expectedServerResponse: "<script>console.log(window);</script>: \n", | ||||||
|  | 		}, | ||||||
| 	} | 	} | ||||||
| 	for _, tc := range tests { | 	for _, tc := range tests { | ||||||
| 		t.Run(tc.testName, func(t *testing.T) { | 		t.Run(tc.testName, func(t *testing.T) { | ||||||
| @@ -366,6 +378,13 @@ func TestDeviceCallback(t *testing.T) { | |||||||
| 			if rr.Code != tc.expectedResponseCode { | 			if rr.Code != tc.expectedResponseCode { | ||||||
| 				t.Errorf("%s: Unexpected Response Type.  Expected %v got %v", tc.testName, tc.expectedResponseCode, rr.Code) | 				t.Errorf("%s: Unexpected Response Type.  Expected %v got %v", tc.testName, tc.expectedResponseCode, rr.Code) | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
|  | 			if len(tc.expectedServerResponse) > 0 { | ||||||
|  | 				result, _ := io.ReadAll(rr.Body) | ||||||
|  | 				if string(result) != tc.expectedServerResponse { | ||||||
|  | 					t.Errorf("%s: Unexpected Response.  Expected %q got %q", tc.testName, tc.expectedServerResponse, result) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
| 		}) | 		}) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user