|
|
|
@ -31,6 +31,37 @@ func init() {
|
|
|
|
|
* Tests
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
func TestServerRootHandler(t *testing.T) {
|
|
|
|
|
assert := assert.New(t)
|
|
|
|
|
config = newDefaultConfig()
|
|
|
|
|
|
|
|
|
|
// X-Forwarded headers should be read into request
|
|
|
|
|
req := httptest.NewRequest("POST", "http://should-use-x-forwarded.com/should?ignore=me", nil)
|
|
|
|
|
req.Header.Add("X-Forwarded-Method", "GET")
|
|
|
|
|
req.Header.Add("X-Forwarded-Proto", "https")
|
|
|
|
|
req.Header.Add("X-Forwarded-Host", "example.com")
|
|
|
|
|
req.Header.Add("X-Forwarded-Uri", "/foo?q=bar")
|
|
|
|
|
NewServer().RootHandler(httptest.NewRecorder(), req)
|
|
|
|
|
|
|
|
|
|
assert.Equal("GET", req.Method, "x-forwarded-method should be read into request")
|
|
|
|
|
assert.Equal("example.com", req.Host, "x-forwarded-host should be read into request")
|
|
|
|
|
assert.Equal("/foo", req.URL.Path, "x-forwarded-uri should be read into request")
|
|
|
|
|
assert.Equal("/foo?q=bar", req.URL.RequestURI(), "x-forwarded-uri should be read into request")
|
|
|
|
|
|
|
|
|
|
// Other X-Forwarded headers should be read in into request and original URL
|
|
|
|
|
// should be preserved if X-Forwarded-Uri not present
|
|
|
|
|
req = httptest.NewRequest("POST", "http://should-use-x-forwarded.com/should-not?ignore=me", nil)
|
|
|
|
|
req.Header.Add("X-Forwarded-Method", "GET")
|
|
|
|
|
req.Header.Add("X-Forwarded-Proto", "https")
|
|
|
|
|
req.Header.Add("X-Forwarded-Host", "example.com")
|
|
|
|
|
NewServer().RootHandler(httptest.NewRecorder(), req)
|
|
|
|
|
|
|
|
|
|
assert.Equal("GET", req.Method, "x-forwarded-method should be read into request")
|
|
|
|
|
assert.Equal("example.com", req.Host, "x-forwarded-host should be read into request")
|
|
|
|
|
assert.Equal("/should-not", req.URL.Path, "request url should be preserved if x-forwarded-uri not present")
|
|
|
|
|
assert.Equal("/should-not?ignore=me", req.URL.RequestURI(), "request url should be preserved if x-forwarded-uri not present")
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestServerAuthHandlerInvalid(t *testing.T) {
|
|
|
|
|
assert := assert.New(t)
|
|
|
|
|
config = newDefaultConfig()
|
|
|
|
@ -90,10 +121,10 @@ func TestServerAuthHandlerExpired(t *testing.T) {
|
|
|
|
|
config.Domains = []string{"test.com"}
|
|
|
|
|
|
|
|
|
|
// Should redirect expired cookie
|
|
|
|
|
req := newDefaultHttpRequest("/foo")
|
|
|
|
|
req := newHTTPRequest("GET", "http://example.com/foo")
|
|
|
|
|
c := MakeCookie(req, "test@example.com")
|
|
|
|
|
res, _ := doHttpRequest(req, c)
|
|
|
|
|
assert.Equal(307, res.StatusCode, "request with expired cookie should be redirected")
|
|
|
|
|
require.Equal(t, 307, res.StatusCode, "request with expired cookie should be redirected")
|
|
|
|
|
|
|
|
|
|
// Check for CSRF cookie
|
|
|
|
|
var cookie *http.Cookie
|
|
|
|
@ -116,7 +147,7 @@ func TestServerAuthHandlerValid(t *testing.T) {
|
|
|
|
|
config = newDefaultConfig()
|
|
|
|
|
|
|
|
|
|
// Should allow valid request email
|
|
|
|
|
req := newDefaultHttpRequest("/foo")
|
|
|
|
|
req := newHTTPRequest("GET", "http://example.com/foo")
|
|
|
|
|
c := MakeCookie(req, "test@example.com")
|
|
|
|
|
config.Domains = []string{}
|
|
|
|
|
|
|
|
|
@ -131,6 +162,7 @@ func TestServerAuthHandlerValid(t *testing.T) {
|
|
|
|
|
|
|
|
|
|
func TestServerAuthCallback(t *testing.T) {
|
|
|
|
|
assert := assert.New(t)
|
|
|
|
|
require := require.New(t)
|
|
|
|
|
config = newDefaultConfig()
|
|
|
|
|
|
|
|
|
|
// Setup OAuth server
|
|
|
|
@ -148,27 +180,28 @@ func TestServerAuthCallback(t *testing.T) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Should pass auth response request to callback
|
|
|
|
|
req := newDefaultHttpRequest("/_oauth")
|
|
|
|
|
req := newHTTPRequest("GET", "http://example.com/_oauth")
|
|
|
|
|
res, _ := doHttpRequest(req, nil)
|
|
|
|
|
assert.Equal(401, res.StatusCode, "auth callback without cookie shouldn't be authorised")
|
|
|
|
|
|
|
|
|
|
// Should catch invalid csrf cookie
|
|
|
|
|
req = newDefaultHttpRequest("/_oauth?state=12345678901234567890123456789012:http://redirect")
|
|
|
|
|
nonce := "12345678901234567890123456789012"
|
|
|
|
|
req = newHTTPRequest("GET", "http://example.com/_oauth?state="+nonce+":http://redirect")
|
|
|
|
|
c := MakeCSRFCookie(req, "nononononononononononononononono")
|
|
|
|
|
res, _ = doHttpRequest(req, c)
|
|
|
|
|
assert.Equal(401, res.StatusCode, "auth callback with invalid cookie shouldn't be authorised")
|
|
|
|
|
|
|
|
|
|
// Should catch invalid provider cookie
|
|
|
|
|
req = newDefaultHttpRequest("/_oauth?state=12345678901234567890123456789012:invalid:http://redirect")
|
|
|
|
|
c = MakeCSRFCookie(req, "12345678901234567890123456789012")
|
|
|
|
|
req = newHTTPRequest("GET", "http://example.com/_oauth?state="+nonce+":invalid:http://redirect")
|
|
|
|
|
c = MakeCSRFCookie(req, nonce)
|
|
|
|
|
res, _ = doHttpRequest(req, c)
|
|
|
|
|
assert.Equal(401, res.StatusCode, "auth callback with invalid provider shouldn't be authorised")
|
|
|
|
|
|
|
|
|
|
// Should redirect valid request
|
|
|
|
|
req = newDefaultHttpRequest("/_oauth?state=12345678901234567890123456789012:google:http://redirect")
|
|
|
|
|
c = MakeCSRFCookie(req, "12345678901234567890123456789012")
|
|
|
|
|
req = newHTTPRequest("GET", "http://example.com/_oauth?state="+nonce+":google:http://redirect")
|
|
|
|
|
c = MakeCSRFCookie(req, nonce)
|
|
|
|
|
res, _ = doHttpRequest(req, c)
|
|
|
|
|
assert.Equal(307, res.StatusCode, "valid auth callback should be allowed")
|
|
|
|
|
require.Equal(307, res.StatusCode, "valid auth callback should be allowed")
|
|
|
|
|
|
|
|
|
|
fwd, _ := res.Location()
|
|
|
|
|
assert.Equal("http", fwd.Scheme, "valid request should be redirected to return url")
|
|
|
|
@ -360,17 +393,17 @@ func TestServerRouteHost(t *testing.T) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Should block any request
|
|
|
|
|
req := newHttpRequest("GET", "https://example.com/", "/")
|
|
|
|
|
req := newHTTPRequest("GET", "https://example.com/")
|
|
|
|
|
res, _ := doHttpRequest(req, nil)
|
|
|
|
|
assert.Equal(307, res.StatusCode, "request not matching any rule should require auth")
|
|
|
|
|
|
|
|
|
|
// Should allow matching request
|
|
|
|
|
req = newHttpRequest("GET", "https://api.example.com/", "/")
|
|
|
|
|
req = newHTTPRequest("GET", "https://api.example.com/")
|
|
|
|
|
res, _ = doHttpRequest(req, nil)
|
|
|
|
|
assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed")
|
|
|
|
|
|
|
|
|
|
// Should allow matching request
|
|
|
|
|
req = newHttpRequest("GET", "https://sub8.example.com/", "/")
|
|
|
|
|
req = newHTTPRequest("GET", "https://sub8.example.com/")
|
|
|
|
|
res, _ = doHttpRequest(req, nil)
|
|
|
|
|
assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed")
|
|
|
|
|
}
|
|
|
|
@ -386,12 +419,12 @@ func TestServerRouteMethod(t *testing.T) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Should block any request
|
|
|
|
|
req := newHttpRequest("GET", "https://example.com/", "/")
|
|
|
|
|
req := newHTTPRequest("GET", "https://example.com/")
|
|
|
|
|
res, _ := doHttpRequest(req, nil)
|
|
|
|
|
assert.Equal(307, res.StatusCode, "request not matching any rule should require auth")
|
|
|
|
|
|
|
|
|
|
// Should allow matching request
|
|
|
|
|
req = newHttpRequest("PUT", "https://example.com/", "/")
|
|
|
|
|
req = newHTTPRequest("PUT", "https://example.com/")
|
|
|
|
|
res, _ = doHttpRequest(req, nil)
|
|
|
|
|
assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed")
|
|
|
|
|
}
|
|
|
|
@ -441,12 +474,12 @@ func TestServerRouteQuery(t *testing.T) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Should block any request
|
|
|
|
|
req := newHttpRequest("GET", "https://example.com/", "/?q=no")
|
|
|
|
|
req := newHTTPRequest("GET", "https://example.com/?q=no")
|
|
|
|
|
res, _ := doHttpRequest(req, nil)
|
|
|
|
|
assert.Equal(307, res.StatusCode, "request not matching any rule should require auth")
|
|
|
|
|
|
|
|
|
|
// Should allow matching request
|
|
|
|
|
req = newHttpRequest("GET", "https://api.example.com/", "/?q=test123")
|
|
|
|
|
req = newHTTPRequest("GET", "https://api.example.com/?q=test123")
|
|
|
|
|
res, _ = doHttpRequest(req, nil)
|
|
|
|
|
assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed")
|
|
|
|
|
}
|
|
|
|
@ -531,16 +564,17 @@ func newDefaultConfig() *Config {
|
|
|
|
|
return config
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// TODO: replace with newHTTPRequest("GET", "http://example.com/"+uri)
|
|
|
|
|
func newDefaultHttpRequest(uri string) *http.Request {
|
|
|
|
|
return newHttpRequest("", "http://example.com/", uri)
|
|
|
|
|
return newHTTPRequest("GET", "http://example.com"+uri)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func newHttpRequest(method, dest, uri string) *http.Request {
|
|
|
|
|
r := httptest.NewRequest("", "http://should-use-x-forwarded.com", nil)
|
|
|
|
|
p, _ := url.Parse(dest)
|
|
|
|
|
func newHTTPRequest(method, target string) *http.Request {
|
|
|
|
|
u, _ := url.Parse(target)
|
|
|
|
|
r := httptest.NewRequest(method, target, nil)
|
|
|
|
|
r.Header.Add("X-Forwarded-Method", method)
|
|
|
|
|
r.Header.Add("X-Forwarded-Proto", p.Scheme)
|
|
|
|
|
r.Header.Add("X-Forwarded-Host", p.Host)
|
|
|
|
|
r.Header.Add("X-Forwarded-Uri", uri)
|
|
|
|
|
r.Header.Add("X-Forwarded-Proto", u.Scheme)
|
|
|
|
|
r.Header.Add("X-Forwarded-Host", u.Host)
|
|
|
|
|
r.Header.Add("X-Forwarded-Uri", u.RequestURI())
|
|
|
|
|
return r
|
|
|
|
|
}
|
|
|
|
|