Allow to be run without middleware + improve request reading consistency (#217)

Prior to this change, the request URI was only ever read from the
X-Forwarded-Uri header which was only set when the container was
accessed via the forwardauth middleware. As such, it was necessary
to apply the treafik-forward-auth middleware to the treafik-forward-auth
container when running auth host mode.
This is a quirk, unnecessary complexity and is a frequent source of
configuration issues.
This commit is contained in:
Thom Seddon
2021-06-24 21:45:28 +01:00
committed by GitHub
parent 4ffb6593d5
commit c4317b7503
6 changed files with 74 additions and 54 deletions

View File

@ -31,6 +31,37 @@ func init() {
* Tests
*/
func TestServerRootHandler(t *testing.T) {
assert := assert.New(t)
config = newDefaultConfig()
// X-Forwarded headers should be read into request
req := httptest.NewRequest("POST", "http://should-use-x-forwarded.com/should?ignore=me", nil)
req.Header.Add("X-Forwarded-Method", "GET")
req.Header.Add("X-Forwarded-Proto", "https")
req.Header.Add("X-Forwarded-Host", "example.com")
req.Header.Add("X-Forwarded-Uri", "/foo?q=bar")
NewServer().RootHandler(httptest.NewRecorder(), req)
assert.Equal("GET", req.Method, "x-forwarded-method should be read into request")
assert.Equal("example.com", req.Host, "x-forwarded-host should be read into request")
assert.Equal("/foo", req.URL.Path, "x-forwarded-uri should be read into request")
assert.Equal("/foo?q=bar", req.URL.RequestURI(), "x-forwarded-uri should be read into request")
// Other X-Forwarded headers should be read in into request and original URL
// should be preserved if X-Forwarded-Uri not present
req = httptest.NewRequest("POST", "http://should-use-x-forwarded.com/should-not?ignore=me", nil)
req.Header.Add("X-Forwarded-Method", "GET")
req.Header.Add("X-Forwarded-Proto", "https")
req.Header.Add("X-Forwarded-Host", "example.com")
NewServer().RootHandler(httptest.NewRecorder(), req)
assert.Equal("GET", req.Method, "x-forwarded-method should be read into request")
assert.Equal("example.com", req.Host, "x-forwarded-host should be read into request")
assert.Equal("/should-not", req.URL.Path, "request url should be preserved if x-forwarded-uri not present")
assert.Equal("/should-not?ignore=me", req.URL.RequestURI(), "request url should be preserved if x-forwarded-uri not present")
}
func TestServerAuthHandlerInvalid(t *testing.T) {
assert := assert.New(t)
config = newDefaultConfig()
@ -90,10 +121,10 @@ func TestServerAuthHandlerExpired(t *testing.T) {
config.Domains = []string{"test.com"}
// Should redirect expired cookie
req := newDefaultHttpRequest("/foo")
req := newHTTPRequest("GET", "http://example.com/foo")
c := MakeCookie(req, "test@example.com")
res, _ := doHttpRequest(req, c)
assert.Equal(307, res.StatusCode, "request with expired cookie should be redirected")
require.Equal(t, 307, res.StatusCode, "request with expired cookie should be redirected")
// Check for CSRF cookie
var cookie *http.Cookie
@ -116,7 +147,7 @@ func TestServerAuthHandlerValid(t *testing.T) {
config = newDefaultConfig()
// Should allow valid request email
req := newDefaultHttpRequest("/foo")
req := newHTTPRequest("GET", "http://example.com/foo")
c := MakeCookie(req, "test@example.com")
config.Domains = []string{}
@ -131,6 +162,7 @@ func TestServerAuthHandlerValid(t *testing.T) {
func TestServerAuthCallback(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
config = newDefaultConfig()
// Setup OAuth server
@ -148,27 +180,28 @@ func TestServerAuthCallback(t *testing.T) {
}
// Should pass auth response request to callback
req := newDefaultHttpRequest("/_oauth")
req := newHTTPRequest("GET", "http://example.com/_oauth")
res, _ := doHttpRequest(req, nil)
assert.Equal(401, res.StatusCode, "auth callback without cookie shouldn't be authorised")
// Should catch invalid csrf cookie
req = newDefaultHttpRequest("/_oauth?state=12345678901234567890123456789012:http://redirect")
nonce := "12345678901234567890123456789012"
req = newHTTPRequest("GET", "http://example.com/_oauth?state="+nonce+":http://redirect")
c := MakeCSRFCookie(req, "nononononononononononononononono")
res, _ = doHttpRequest(req, c)
assert.Equal(401, res.StatusCode, "auth callback with invalid cookie shouldn't be authorised")
// Should catch invalid provider cookie
req = newDefaultHttpRequest("/_oauth?state=12345678901234567890123456789012:invalid:http://redirect")
c = MakeCSRFCookie(req, "12345678901234567890123456789012")
req = newHTTPRequest("GET", "http://example.com/_oauth?state="+nonce+":invalid:http://redirect")
c = MakeCSRFCookie(req, nonce)
res, _ = doHttpRequest(req, c)
assert.Equal(401, res.StatusCode, "auth callback with invalid provider shouldn't be authorised")
// Should redirect valid request
req = newDefaultHttpRequest("/_oauth?state=12345678901234567890123456789012:google:http://redirect")
c = MakeCSRFCookie(req, "12345678901234567890123456789012")
req = newHTTPRequest("GET", "http://example.com/_oauth?state="+nonce+":google:http://redirect")
c = MakeCSRFCookie(req, nonce)
res, _ = doHttpRequest(req, c)
assert.Equal(307, res.StatusCode, "valid auth callback should be allowed")
require.Equal(307, res.StatusCode, "valid auth callback should be allowed")
fwd, _ := res.Location()
assert.Equal("http", fwd.Scheme, "valid request should be redirected to return url")
@ -360,17 +393,17 @@ func TestServerRouteHost(t *testing.T) {
}
// Should block any request
req := newHttpRequest("GET", "https://example.com/", "/")
req := newHTTPRequest("GET", "https://example.com/")
res, _ := doHttpRequest(req, nil)
assert.Equal(307, res.StatusCode, "request not matching any rule should require auth")
// Should allow matching request
req = newHttpRequest("GET", "https://api.example.com/", "/")
req = newHTTPRequest("GET", "https://api.example.com/")
res, _ = doHttpRequest(req, nil)
assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed")
// Should allow matching request
req = newHttpRequest("GET", "https://sub8.example.com/", "/")
req = newHTTPRequest("GET", "https://sub8.example.com/")
res, _ = doHttpRequest(req, nil)
assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed")
}
@ -386,12 +419,12 @@ func TestServerRouteMethod(t *testing.T) {
}
// Should block any request
req := newHttpRequest("GET", "https://example.com/", "/")
req := newHTTPRequest("GET", "https://example.com/")
res, _ := doHttpRequest(req, nil)
assert.Equal(307, res.StatusCode, "request not matching any rule should require auth")
// Should allow matching request
req = newHttpRequest("PUT", "https://example.com/", "/")
req = newHTTPRequest("PUT", "https://example.com/")
res, _ = doHttpRequest(req, nil)
assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed")
}
@ -441,12 +474,12 @@ func TestServerRouteQuery(t *testing.T) {
}
// Should block any request
req := newHttpRequest("GET", "https://example.com/", "/?q=no")
req := newHTTPRequest("GET", "https://example.com/?q=no")
res, _ := doHttpRequest(req, nil)
assert.Equal(307, res.StatusCode, "request not matching any rule should require auth")
// Should allow matching request
req = newHttpRequest("GET", "https://api.example.com/", "/?q=test123")
req = newHTTPRequest("GET", "https://api.example.com/?q=test123")
res, _ = doHttpRequest(req, nil)
assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed")
}
@ -531,16 +564,17 @@ func newDefaultConfig() *Config {
return config
}
// TODO: replace with newHTTPRequest("GET", "http://example.com/"+uri)
func newDefaultHttpRequest(uri string) *http.Request {
return newHttpRequest("", "http://example.com/", uri)
return newHTTPRequest("GET", "http://example.com"+uri)
}
func newHttpRequest(method, dest, uri string) *http.Request {
r := httptest.NewRequest("", "http://should-use-x-forwarded.com", nil)
p, _ := url.Parse(dest)
func newHTTPRequest(method, target string) *http.Request {
u, _ := url.Parse(target)
r := httptest.NewRequest(method, target, nil)
r.Header.Add("X-Forwarded-Method", method)
r.Header.Add("X-Forwarded-Proto", p.Scheme)
r.Header.Add("X-Forwarded-Host", p.Host)
r.Header.Add("X-Forwarded-Uri", uri)
r.Header.Add("X-Forwarded-Proto", u.Scheme)
r.Header.Add("X-Forwarded-Host", u.Host)
r.Header.Add("X-Forwarded-Uri", u.RequestURI())
return r
}