Skip to content

Commit

Permalink
Merge pull request #1887 from vmware-tanzu/cli_callback_cors_get
Browse files Browse the repository at this point in the history
CLI's localhost listener handles CORS preflight requests for GETs
  • Loading branch information
cfryanr authored Mar 8, 2024
2 parents f881bbb + d49b011 commit 61835e9
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 110 deletions.
109 changes: 52 additions & 57 deletions pkg/oidcclient/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -941,77 +941,72 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req
}
}()

// Calculate the allowed origin for CORS.
issuerURL, err := url.Parse(h.issuer)
if err != nil {
// This shouldn't happen in practice because the URL is normally validated before this function is called.
// Avoid using httperr.Wrap because that would hide the details of err from the browser output.
return httperr.Newf(http.StatusInternalServerError, "invalid issuer url: %s", err.Error())
}
allowOrigin := issuerURL.Scheme + "://" + issuerURL.Host

var params url.Values
if h.useFormPost { //nolint:nestif
// Return HTTP 405 for anything that's not a POST or an OPTIONS request.
if r.Method != http.MethodPost && r.Method != http.MethodOptions {
h.logger.V(plog.KlogLevelDebug).Info("Pinniped: Got unexpected request on callback listener", "method", r.Method)
w.WriteHeader(http.StatusMethodNotAllowed)

switch r.Method {
case http.MethodOptions:
// Google Chrome decided that it should do CORS preflight checks for this Javascript form submission POST request.
// See https://developer.chrome.com/blog/private-network-access-preflight/
// It seems like Chrome will likely soon also add CORS preflight checks for GET requests on redirects.
// See https://chromestatus.com/feature/4869685172764672
origin := r.Header.Get("Origin")
if origin == "" {
// The CORS preflight request should have an origin.
h.logger.V(plog.KlogLevelDebug).Info("Pinniped: Got OPTIONS request without origin header")
w.WriteHeader(http.StatusBadRequest)
return nil // keep listening for more requests
}

// For POST and OPTIONS requests, calculate the allowed origin for CORS.
issuerURL, err := url.Parse(h.issuer)
if err != nil {
// Avoid using httperr.Wrap because that would hide the details of err from the browser output.
return httperr.Newf(http.StatusInternalServerError, "invalid issuer url: %s", err.Error())
h.logger.V(plog.KlogLevelDebug).Info("Pinniped: Got CORS preflight request from browser", "origin", origin)
// To tell the browser that it is okay to make the real POST or GET request, return the following response.
w.Header().Set("Access-Control-Allow-Origin", allowOrigin)
w.Header().Set("Vary", "*") // supposed to use Vary when Access-Control-Allow-Origin is a specific host
w.Header().Set("Access-Control-Allow-Credentials", "false")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
w.Header().Set("Access-Control-Allow-Private-Network", "true")
// If the browser would like to send some headers on the real request, allow them. Chrome doesn't
// currently send this header at the moment. This is in case some browser in the future decides to
// request to be allowed to send specific headers by using Access-Control-Request-Headers.
requestedHeaders := r.Header.Get("Access-Control-Request-Headers")
if requestedHeaders != "" {
w.Header().Set("Access-Control-Allow-Headers", requestedHeaders)
}
allowOrigin := issuerURL.Scheme + "://" + issuerURL.Host

if r.Method == http.MethodOptions {
// Google Chrome decided that it should do CORS preflight checks for this Javascript form submission POST request.
// See https://developer.chrome.com/blog/private-network-access-preflight/
origin := r.Header.Get("Origin")
if origin == "" {
// The CORS preflight request should have an origin.
h.logger.V(plog.KlogLevelDebug).Info("Pinniped: Got OPTIONS request without origin header")
w.WriteHeader(http.StatusBadRequest)
return nil // keep listening for more requests
}
h.logger.V(plog.KlogLevelDebug).Info("Pinniped: Got CORS preflight request from browser", "origin", origin)
// To tell the browser that it is okay to make the real POST request, return the following response.
w.Header().Set("Access-Control-Allow-Origin", allowOrigin)
w.Header().Set("Vary", "*") // supposed to use Vary when Access-Control-Allow-Origin is a specific host
w.Header().Set("Access-Control-Allow-Credentials", "false")
w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS")
w.Header().Set("Access-Control-Allow-Private-Network", "true")
// If the browser would like to send some headers on the real request, allow them. Chrome doesn't
// currently send this header at the moment. This is in case some browser in the future decides to
// request to be allowed to send specific headers by using Access-Control-Request-Headers.
requestedHeaders := r.Header.Get("Access-Control-Request-Headers")
if requestedHeaders != "" {
w.Header().Set("Access-Control-Allow-Headers", requestedHeaders)
}
w.WriteHeader(http.StatusNoContent)
return nil // keep listening for more requests
} // Otherwise, this is a POST request...
w.WriteHeader(http.StatusNoContent)
return nil // keep listening for more requests

case http.MethodPost:
// Parse and pull the response parameters from an application/x-www-form-urlencoded request body.
if err = r.ParseForm(); err != nil {
// Avoid using httperr.Wrap because that would hide the details of err from the browser output.
return httperr.Newf(http.StatusBadRequest, "invalid form: %s", err.Error())
}
params = r.Form

// Allow CORS requests for POST so in the future our Javascript code can be updated to use the fetch API's
// mode "cors", and still be compatible with older CLI versions starting with those that have this code
// for CORS headers. Updating to use CORS would allow our Javascript code (form_post.js) to see the true
// http response status from this endpoint. Note that the POST response does not need to set as many CORS
// headers as the OPTIONS preflight response.
w.Header().Set("Access-Control-Allow-Origin", allowOrigin)
w.Header().Set("Vary", "*") // supposed to use Vary when Access-Control-Allow-Origin is a specific host
} else {
// When we are not using form_post, then return HTTP 405 for anything that's not a GET.
if r.Method != http.MethodGet {
h.logger.V(plog.KlogLevelDebug).Info("Pinniped: Got unexpected request on callback listener", "method", r.Method)
w.WriteHeader(http.StatusMethodNotAllowed)
return nil // keep listening for more requests
}
params = r.Form // grab the params and continue handling this request below

case http.MethodGet:
// Pull response parameters from the URL query string.
params = r.URL.Query()
params = r.URL.Query() // grab the params and continue handling this request below

default:
// Return HTTP 405 for anything that's not a POST, GET, or an OPTIONS request.
h.logger.V(plog.KlogLevelDebug).Info("Pinniped: Got unexpected request on callback listener", "method", r.Method)
w.WriteHeader(http.StatusMethodNotAllowed)
return nil // keep listening for more requests
}

// Allow CORS requests for POST so our Javascript code can use the fetch API's mode "cors" (see form_post.js)
// to allow the JS code see the true http response status from this endpoint. Note that the POST response
// does not need to set as many CORS headers as the OPTIONS preflight response.
w.Header().Set("Access-Control-Allow-Origin", allowOrigin)
w.Header().Set("Vary", "*") // supposed to use Vary when Access-Control-Allow-Origin is a specific host

// At this point, it doesn't matter if we got the params from a form_post POST request or a regular GET request.
// Next, validate the params, and if we got an authcode then try to use it to complete the login.

Expand Down
117 changes: 66 additions & 51 deletions pkg/oidcclient/login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2673,12 +2673,6 @@ func TestHandlePasteCallback(t *testing.T) {
func TestHandleAuthCodeCallback(t *testing.T) {
const testRedirectURI = "http://127.0.0.1:12324/callback"

withFormPostMode := func(t *testing.T) Option {
return func(h *handlerState) error {
h.useFormPost = true
return nil
}
}
tests := []struct {
name string
method string
Expand All @@ -2694,17 +2688,8 @@ func TestHandleAuthCodeCallback(t *testing.T) {
}{
{
name: "wrong method returns an error but keeps listening",
method: http.MethodPost,
query: "",
wantNoCallbacks: true,
wantHeaders: map[string][]string{},
wantHTTPStatus: http.StatusMethodNotAllowed,
},
{
name: "wrong method for form_post returns an error but keeps listening",
method: http.MethodGet,
method: http.MethodHead,
query: "",
opt: withFormPostMode,
wantNoCallbacks: true,
wantHeaders: map[string][]string{},
wantHTTPStatus: http.StatusMethodNotAllowed,
Expand All @@ -2715,34 +2700,45 @@ func TestHandleAuthCodeCallback(t *testing.T) {
query: "",
headers: map[string][]string{"Content-Type": {"application/x-www-form-urlencoded"}},
body: []byte(`%`),
opt: withFormPostMode,
wantErr: `invalid form: invalid URL escape "%"`,
wantHeaders: map[string][]string{},
wantHTTPStatus: http.StatusBadRequest,
},
{
name: "invalid state",
query: "state=invalid",
wantErr: "missing or invalid state parameter",
wantHeaders: map[string][]string{},
name: "invalid state",
method: http.MethodGet,
query: "state=invalid",
wantErr: "missing or invalid state parameter",
wantHeaders: map[string][]string{
"Access-Control-Allow-Origin": {"https://valid-issuer.com"},
"Vary": {"*"},
},
wantHTTPStatus: http.StatusForbidden,
},
{
name: "error code from provider",
query: "state=test-state&error=some_error",
wantErr: `login failed with code "some_error"`,
wantHeaders: map[string][]string{},
name: "error code from provider",
method: http.MethodGet,
query: "state=test-state&error=some_error",
wantErr: `login failed with code "some_error"`,
wantHeaders: map[string][]string{
"Access-Control-Allow-Origin": {"https://valid-issuer.com"},
"Vary": {"*"},
},
wantHTTPStatus: http.StatusBadRequest,
},
{
name: "error code with a description from provider",
query: "state=test-state&error=some_error&error_description=optional%20error%20description",
wantErr: `login failed with code "some_error": optional error description`,
wantHeaders: map[string][]string{},
name: "error code with a description from provider",
method: http.MethodGet,
query: "state=test-state&error=some_error&error_description=optional%20error%20description",
wantErr: `login failed with code "some_error": optional error description`,
wantHeaders: map[string][]string{
"Access-Control-Allow-Origin": {"https://valid-issuer.com"},
"Vary": {"*"},
},
wantHTTPStatus: http.StatusBadRequest,
},
{
name: "in form post mode, invalid issuer url config during CORS preflight request returns an error",
name: "invalid issuer url config during CORS preflight request returns an error",
method: http.MethodOptions,
query: "",
headers: map[string][]string{"Origin": {"https://some-origin.com"}},
Expand All @@ -2751,14 +2747,13 @@ func TestHandleAuthCodeCallback(t *testing.T) {
wantHTTPStatus: http.StatusInternalServerError,
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.useFormPost = true
h.issuer = "://bad-url"
return nil
}
},
},
{
name: "in form post mode, invalid issuer url config during POST request returns an error",
name: "invalid issuer url config during POST request returns an error",
method: http.MethodPost,
query: "",
headers: map[string][]string{"Origin": {"https://some-origin.com"}},
Expand All @@ -2767,45 +2762,57 @@ func TestHandleAuthCodeCallback(t *testing.T) {
wantHTTPStatus: http.StatusInternalServerError,
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.useFormPost = true
h.issuer = "://bad-url"
return nil
}
},
},
{
name: "in form post mode, options request is missing origin header results in 400 and keeps listener running",
name: "invalid issuer url config during GET request returns an error",
method: http.MethodGet,
query: "code=foo",
headers: map[string][]string{},
wantErr: `invalid issuer url: parse "://bad-url": missing protocol scheme`,
wantHeaders: map[string][]string{},
wantHTTPStatus: http.StatusInternalServerError,
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.issuer = "://bad-url"
return nil
}
},
},
{
name: "options request is missing origin header results in 400 and keeps listener running",
method: http.MethodOptions,
query: "",
opt: withFormPostMode,
wantNoCallbacks: true,
wantHeaders: map[string][]string{},
wantHTTPStatus: http.StatusBadRequest,
},
{
name: "in form post mode, valid CORS request responds with 402 and CORS headers and keeps listener running",
name: "valid CORS request responds with 402 and CORS headers and keeps listener running",
method: http.MethodOptions,
query: "",
headers: map[string][]string{"Origin": {"https://some-origin.com"}},
wantNoCallbacks: true,
wantHTTPStatus: http.StatusNoContent,
wantHeaders: map[string][]string{
"Access-Control-Allow-Credentials": {"false"},
"Access-Control-Allow-Methods": {"POST, OPTIONS"},
"Access-Control-Allow-Methods": {"GET, POST, OPTIONS"},
"Access-Control-Allow-Origin": {"https://valid-issuer.com"},
"Vary": {"*"},
"Access-Control-Allow-Private-Network": {"true"},
},
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.useFormPost = true
h.issuer = "https://valid-issuer.com/with/some/path"
return nil
}
},
},
{
name: "in form post mode, valid CORS request with Access-Control-Request-Headers responds with 402 and CORS headers including Access-Control-Allow-Headers and keeps listener running",
name: "valid CORS request with Access-Control-Request-Headers responds with 402 and CORS headers including Access-Control-Allow-Headers and keeps listener running",
method: http.MethodOptions,
query: "",
headers: map[string][]string{
Expand All @@ -2816,25 +2823,28 @@ func TestHandleAuthCodeCallback(t *testing.T) {
wantHTTPStatus: http.StatusNoContent,
wantHeaders: map[string][]string{
"Access-Control-Allow-Credentials": {"false"},
"Access-Control-Allow-Methods": {"POST, OPTIONS"},
"Access-Control-Allow-Methods": {"GET, POST, OPTIONS"},
"Access-Control-Allow-Origin": {"https://valid-issuer.com"},
"Vary": {"*"},
"Access-Control-Allow-Private-Network": {"true"},
"Access-Control-Allow-Headers": {"header1, header2, header3"},
},
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.useFormPost = true
h.issuer = "https://valid-issuer.com/with/some/path"
return nil
}
},
},
{
name: "invalid code",
query: "state=test-state&code=invalid",
wantErr: "could not complete authorization code exchange: some exchange error",
wantHeaders: map[string][]string{},
name: "invalid code",
method: http.MethodGet,
query: "state=test-state&code=invalid",
wantErr: "could not complete authorization code exchange: some exchange error",
wantHeaders: map[string][]string{
"Access-Control-Allow-Origin": {"https://valid-issuer.com"},
"Vary": {"*"},
},
wantHTTPStatus: http.StatusBadRequest,
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
Expand All @@ -2852,9 +2862,14 @@ func TestHandleAuthCodeCallback(t *testing.T) {
},
{
name: "valid",
method: http.MethodGet,
query: "state=test-state&code=valid",
wantHTTPStatus: http.StatusOK,
wantHeaders: map[string][]string{"Content-Type": {"text/plain; charset=utf-8"}},
wantHeaders: map[string][]string{
"Access-Control-Allow-Origin": {"https://valid-issuer.com"},
"Vary": {"*"},
"Content-Type": {"text/plain; charset=utf-8"},
},
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
Expand Down Expand Up @@ -2882,7 +2897,6 @@ func TestHandleAuthCodeCallback(t *testing.T) {
wantHTTPStatus: http.StatusOK,
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.useFormPost = true
h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) upstreamprovider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t)
Expand Down Expand Up @@ -2911,7 +2925,6 @@ func TestHandleAuthCodeCallback(t *testing.T) {
wantHTTPStatus: http.StatusOK,
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.useFormPost = true
h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) upstreamprovider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t)
Expand All @@ -2925,6 +2938,7 @@ func TestHandleAuthCodeCallback(t *testing.T) {
},
},
}

for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
Expand All @@ -2948,10 +2962,11 @@ func TestHandleAuthCodeCallback(t *testing.T) {
resp := httptest.NewRecorder()
req, err := http.NewRequestWithContext(ctx, "GET", "/test-callback", bytes.NewBuffer(tt.body))
require.NoError(t, err)

require.NotEmptyf(t, tt.method, "test author mistake: method is required on the test table entry")
req.Method = tt.method

req.URL.RawQuery = tt.query
if tt.method != "" {
req.Method = tt.method
}
if tt.headers != nil {
req.Header = tt.headers
}
Expand Down
Loading

0 comments on commit 61835e9

Please sign in to comment.