fix: prevent cross-site scripting for the device flow

Signed-off-by: m.nabokikh <maksim.nabokikh@flant.com>
This commit is contained in:
m.nabokikh 2022-04-11 14:49:47 +04:00
parent 0270536a2e
commit 3d5a3befb4
2 changed files with 31 additions and 8 deletions

View File

@ -11,6 +11,8 @@ import (
"strings"
"time"
"golang.org/x/net/html"
"github.com/dexidp/dex/pkg/log"
"github.com/dexidp/dex/storage"
)
@ -247,7 +249,9 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
// Authorization redirect callback from OAuth2 auth flow.
if errMsg := r.FormValue("error"); errMsg != "" {
http.Error(w, errMsg+": "+r.FormValue("error_description"), http.StatusBadRequest)
// escape the message to prevent cross-site scripting
msg := html.EscapeString(errMsg + ": " + r.FormValue("error_description"))
http.Error(w, msg, http.StatusBadRequest)
return
}

View File

@ -162,6 +162,7 @@ func TestDeviceCallback(t *testing.T) {
tests := []struct {
testName string
expectedResponseCode int
expectedServerResponse string
values formValues
testAuthCode storage.AuthCode
testDeviceRequest storage.DeviceRequest
@ -193,6 +194,7 @@ func TestDeviceCallback(t *testing.T) {
error: "Error Condition",
},
expectedResponseCode: http.StatusBadRequest,
expectedServerResponse: "Error Condition: \n",
},
{
testName: "Expired Auth Code",
@ -314,6 +316,16 @@ func TestDeviceCallback(t *testing.T) {
testDeviceToken: baseDeviceToken,
expectedResponseCode: http.StatusOK,
},
{
testName: "Prevent cross-site scripting",
values: formValues{
state: "XXXX-XXXX",
code: "somecode",
error: "<script>console.log(window);</script>",
},
expectedResponseCode: http.StatusBadRequest,
expectedServerResponse: "&lt;script&gt;console.log(window);&lt;/script&gt;: \n",
},
}
for _, tc := range tests {
t.Run(tc.testName, func(t *testing.T) {
@ -366,6 +378,13 @@ func TestDeviceCallback(t *testing.T) {
if rr.Code != tc.expectedResponseCode {
t.Errorf("%s: Unexpected Response Type. Expected %v got %v", tc.testName, tc.expectedResponseCode, rr.Code)
}
if len(tc.expectedServerResponse) > 0 {
result, _ := io.ReadAll(rr.Body)
if string(result) != tc.expectedServerResponse {
t.Errorf("%s: Unexpected Response. Expected %q got %q", tc.testName, tc.expectedServerResponse, result)
}
}
})
}
}