Prior to this change, the request URI was only ever read from the X-Forwarded-Uri header, which is only set when the container is accessed via the ForwardAuth middleware. As a result, when running in auth host mode it was necessary to apply the traefik-forward-auth middleware to the traefik-forward-auth container itself. This quirk added unnecessary complexity and was a frequent source of configuration issues.
		
			
				
	
	
		
			581 lines
		
	
	
		
			18 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			581 lines
		
	
	
		
			18 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package tfa
 | 
						|
 | 
						|
import (
 | 
						|
	"fmt"
 | 
						|
	"io/ioutil"
 | 
						|
	"net/http"
 | 
						|
	"net/http/httptest"
 | 
						|
	"net/url"
 | 
						|
	"strings"
 | 
						|
	"testing"
 | 
						|
	"time"
 | 
						|
 | 
						|
	"github.com/sirupsen/logrus"
 | 
						|
	"github.com/sirupsen/logrus/hooks/test"
 | 
						|
	"github.com/stretchr/testify/assert"
 | 
						|
	"github.com/stretchr/testify/require"
 | 
						|
	"golang.org/x/oauth2"
 | 
						|
)
 | 
						|
 | 
						|
/**
 | 
						|
 * Setup
 | 
						|
 */
 | 
						|
 | 
						|
func init() {
 | 
						|
	config = newDefaultConfig()
 | 
						|
	config.LogLevel = "panic"
 | 
						|
	log = NewDefaultLogger()
 | 
						|
}
 | 
						|
 | 
						|
/**
 | 
						|
 * Tests
 | 
						|
 */
 | 
						|
 | 
						|
func TestServerRootHandler(t *testing.T) {
 | 
						|
	assert := assert.New(t)
 | 
						|
	config = newDefaultConfig()
 | 
						|
 | 
						|
	// X-Forwarded headers should be read into request
 | 
						|
	req := httptest.NewRequest("POST", "http://should-use-x-forwarded.com/should?ignore=me", nil)
 | 
						|
	req.Header.Add("X-Forwarded-Method", "GET")
 | 
						|
	req.Header.Add("X-Forwarded-Proto", "https")
 | 
						|
	req.Header.Add("X-Forwarded-Host", "example.com")
 | 
						|
	req.Header.Add("X-Forwarded-Uri", "/foo?q=bar")
 | 
						|
	NewServer().RootHandler(httptest.NewRecorder(), req)
 | 
						|
 | 
						|
	assert.Equal("GET", req.Method, "x-forwarded-method should be read into request")
 | 
						|
	assert.Equal("example.com", req.Host, "x-forwarded-host should be read into request")
 | 
						|
	assert.Equal("/foo", req.URL.Path, "x-forwarded-uri should be read into request")
 | 
						|
	assert.Equal("/foo?q=bar", req.URL.RequestURI(), "x-forwarded-uri should be read into request")
 | 
						|
 | 
						|
	// Other X-Forwarded headers should be read in into request and original URL
 | 
						|
	// should be preserved if X-Forwarded-Uri not present
 | 
						|
	req = httptest.NewRequest("POST", "http://should-use-x-forwarded.com/should-not?ignore=me", nil)
 | 
						|
	req.Header.Add("X-Forwarded-Method", "GET")
 | 
						|
	req.Header.Add("X-Forwarded-Proto", "https")
 | 
						|
	req.Header.Add("X-Forwarded-Host", "example.com")
 | 
						|
	NewServer().RootHandler(httptest.NewRecorder(), req)
 | 
						|
 | 
						|
	assert.Equal("GET", req.Method, "x-forwarded-method should be read into request")
 | 
						|
	assert.Equal("example.com", req.Host, "x-forwarded-host should be read into request")
 | 
						|
	assert.Equal("/should-not", req.URL.Path, "request url should be preserved if x-forwarded-uri not present")
 | 
						|
	assert.Equal("/should-not?ignore=me", req.URL.RequestURI(), "request url should be preserved if x-forwarded-uri not present")
 | 
						|
}
 | 
						|
 | 
						|
func TestServerAuthHandlerInvalid(t *testing.T) {
 | 
						|
	assert := assert.New(t)
 | 
						|
	config = newDefaultConfig()
 | 
						|
	var hook *test.Hook
 | 
						|
	log, hook = test.NewNullLogger()
 | 
						|
 | 
						|
	// Should redirect vanilla request to login url
 | 
						|
	req := newDefaultHttpRequest("/foo")
 | 
						|
	res, _ := doHttpRequest(req, nil)
 | 
						|
	assert.Equal(307, res.StatusCode, "vanilla request should be redirected")
 | 
						|
 | 
						|
	fwd, _ := res.Location()
 | 
						|
	assert.Equal("https", fwd.Scheme, "vanilla request should be redirected to google")
 | 
						|
	assert.Equal("accounts.google.com", fwd.Host, "vanilla request should be redirected to google")
 | 
						|
	assert.Equal("/o/oauth2/auth", fwd.Path, "vanilla request should be redirected to google")
 | 
						|
 | 
						|
	// Check state string
 | 
						|
	qs := fwd.Query()
 | 
						|
	state, exists := qs["state"]
 | 
						|
	require.True(t, exists)
 | 
						|
	require.Len(t, state, 1)
 | 
						|
	parts := strings.SplitN(state[0], ":", 3)
 | 
						|
	require.Len(t, parts, 3)
 | 
						|
	assert.Equal("google", parts[1])
 | 
						|
	assert.Equal("http://example.com/foo", parts[2])
 | 
						|
 | 
						|
	// Should warn as using http without insecure cookie
 | 
						|
	logs := hook.AllEntries()
 | 
						|
	assert.Len(logs, 1)
 | 
						|
	assert.Equal("You are using \"secure\" cookies for a request that was not "+
 | 
						|
		"received via https. You should either redirect to https or pass the "+
 | 
						|
		"\"insecure-cookie\" config option to permit cookies via http.", logs[0].Message)
 | 
						|
	assert.Equal(logrus.WarnLevel, logs[0].Level)
 | 
						|
 | 
						|
	// Should catch invalid cookie
 | 
						|
	req = newDefaultHttpRequest("/foo")
 | 
						|
	c := MakeCookie(req, "test@example.com")
 | 
						|
	parts = strings.Split(c.Value, "|")
 | 
						|
	c.Value = fmt.Sprintf("bad|%s|%s", parts[1], parts[2])
 | 
						|
 | 
						|
	res, _ = doHttpRequest(req, c)
 | 
						|
	assert.Equal(401, res.StatusCode, "invalid cookie should not be authorised")
 | 
						|
 | 
						|
	// Should validate email
 | 
						|
	req = newDefaultHttpRequest("/foo")
 | 
						|
	c = MakeCookie(req, "test@example.com")
 | 
						|
	config.Domains = []string{"test.com"}
 | 
						|
 | 
						|
	res, _ = doHttpRequest(req, c)
 | 
						|
	assert.Equal(401, res.StatusCode, "invalid email should not be authorised")
 | 
						|
}
 | 
						|
 | 
						|
func TestServerAuthHandlerExpired(t *testing.T) {
 | 
						|
	assert := assert.New(t)
 | 
						|
	config = newDefaultConfig()
 | 
						|
	config.Lifetime = time.Second * time.Duration(-1)
 | 
						|
	config.Domains = []string{"test.com"}
 | 
						|
 | 
						|
	// Should redirect expired cookie
 | 
						|
	req := newHTTPRequest("GET", "http://example.com/foo")
 | 
						|
	c := MakeCookie(req, "test@example.com")
 | 
						|
	res, _ := doHttpRequest(req, c)
 | 
						|
	require.Equal(t, 307, res.StatusCode, "request with expired cookie should be redirected")
 | 
						|
 | 
						|
	// Check for CSRF cookie
 | 
						|
	var cookie *http.Cookie
 | 
						|
	for _, c := range res.Cookies() {
 | 
						|
		if strings.HasPrefix(c.Name, config.CSRFCookieName) {
 | 
						|
			cookie = c
 | 
						|
		}
 | 
						|
	}
 | 
						|
	assert.NotNil(cookie)
 | 
						|
 | 
						|
	// Check redirection location
 | 
						|
	fwd, _ := res.Location()
 | 
						|
	assert.Equal("https", fwd.Scheme, "request with expired cookie should be redirected to google")
 | 
						|
	assert.Equal("accounts.google.com", fwd.Host, "request with expired cookie should be redirected to google")
 | 
						|
	assert.Equal("/o/oauth2/auth", fwd.Path, "request with expired cookie should be redirected to google")
 | 
						|
}
 | 
						|
 | 
						|
func TestServerAuthHandlerValid(t *testing.T) {
 | 
						|
	assert := assert.New(t)
 | 
						|
	config = newDefaultConfig()
 | 
						|
 | 
						|
	// Should allow valid request email
 | 
						|
	req := newHTTPRequest("GET", "http://example.com/foo")
 | 
						|
	c := MakeCookie(req, "test@example.com")
 | 
						|
	config.Domains = []string{}
 | 
						|
 | 
						|
	res, _ := doHttpRequest(req, c)
 | 
						|
	assert.Equal(200, res.StatusCode, "valid request should be allowed")
 | 
						|
 | 
						|
	// Should pass through user
 | 
						|
	users := res.Header["X-Forwarded-User"]
 | 
						|
	assert.Len(users, 1, "valid request should have X-Forwarded-User header")
 | 
						|
	assert.Equal([]string{"test@example.com"}, users, "X-Forwarded-User header should match user")
 | 
						|
}
 | 
						|
 | 
						|
func TestServerAuthCallback(t *testing.T) {
 | 
						|
	assert := assert.New(t)
 | 
						|
	require := require.New(t)
 | 
						|
	config = newDefaultConfig()
 | 
						|
 | 
						|
	// Setup OAuth server
 | 
						|
	server, serverURL := NewOAuthServer(t)
 | 
						|
	defer server.Close()
 | 
						|
	config.Providers.Google.TokenURL = &url.URL{
 | 
						|
		Scheme: serverURL.Scheme,
 | 
						|
		Host:   serverURL.Host,
 | 
						|
		Path:   "/token",
 | 
						|
	}
 | 
						|
	config.Providers.Google.UserURL = &url.URL{
 | 
						|
		Scheme: serverURL.Scheme,
 | 
						|
		Host:   serverURL.Host,
 | 
						|
		Path:   "/userinfo",
 | 
						|
	}
 | 
						|
 | 
						|
	// Should pass auth response request to callback
 | 
						|
	req := newHTTPRequest("GET", "http://example.com/_oauth")
 | 
						|
	res, _ := doHttpRequest(req, nil)
 | 
						|
	assert.Equal(401, res.StatusCode, "auth callback without cookie shouldn't be authorised")
 | 
						|
 | 
						|
	// Should catch invalid csrf cookie
 | 
						|
	nonce := "12345678901234567890123456789012"
 | 
						|
	req = newHTTPRequest("GET", "http://example.com/_oauth?state="+nonce+":http://redirect")
 | 
						|
	c := MakeCSRFCookie(req, "nononononononononononononononono")
 | 
						|
	res, _ = doHttpRequest(req, c)
 | 
						|
	assert.Equal(401, res.StatusCode, "auth callback with invalid cookie shouldn't be authorised")
 | 
						|
 | 
						|
	// Should catch invalid provider cookie
 | 
						|
	req = newHTTPRequest("GET", "http://example.com/_oauth?state="+nonce+":invalid:http://redirect")
 | 
						|
	c = MakeCSRFCookie(req, nonce)
 | 
						|
	res, _ = doHttpRequest(req, c)
 | 
						|
	assert.Equal(401, res.StatusCode, "auth callback with invalid provider shouldn't be authorised")
 | 
						|
 | 
						|
	// Should redirect valid request
 | 
						|
	req = newHTTPRequest("GET", "http://example.com/_oauth?state="+nonce+":google:http://redirect")
 | 
						|
	c = MakeCSRFCookie(req, nonce)
 | 
						|
	res, _ = doHttpRequest(req, c)
 | 
						|
	require.Equal(307, res.StatusCode, "valid auth callback should be allowed")
 | 
						|
 | 
						|
	fwd, _ := res.Location()
 | 
						|
	assert.Equal("http", fwd.Scheme, "valid request should be redirected to return url")
 | 
						|
	assert.Equal("redirect", fwd.Host, "valid request should be redirected to return url")
 | 
						|
	assert.Equal("", fwd.Path, "valid request should be redirected to return url")
 | 
						|
}
 | 
						|
 | 
						|
func TestServerAuthCallbackExchangeFailure(t *testing.T) {
 | 
						|
	assert := assert.New(t)
 | 
						|
	config = newDefaultConfig()
 | 
						|
 | 
						|
	// Setup OAuth server
 | 
						|
	server, serverURL := NewFailingOAuthServer(t)
 | 
						|
	defer server.Close()
 | 
						|
	config.Providers.Google.TokenURL = &url.URL{
 | 
						|
		Scheme: serverURL.Scheme,
 | 
						|
		Host:   serverURL.Host,
 | 
						|
		Path:   "/token",
 | 
						|
	}
 | 
						|
	config.Providers.Google.UserURL = &url.URL{
 | 
						|
		Scheme: serverURL.Scheme,
 | 
						|
		Host:   serverURL.Host,
 | 
						|
		Path:   "/userinfo",
 | 
						|
	}
 | 
						|
 | 
						|
	// Should handle failed code exchange
 | 
						|
	req := newDefaultHttpRequest("/_oauth?state=12345678901234567890123456789012:google:http://redirect")
 | 
						|
	c := MakeCSRFCookie(req, "12345678901234567890123456789012")
 | 
						|
	res, _ := doHttpRequest(req, c)
 | 
						|
	assert.Equal(503, res.StatusCode, "auth callback should handle failed code exchange")
 | 
						|
}
 | 
						|
 | 
						|
func TestServerAuthCallbackUserFailure(t *testing.T) {
 | 
						|
	assert := assert.New(t)
 | 
						|
	config = newDefaultConfig()
 | 
						|
 | 
						|
	// Setup OAuth server
 | 
						|
	server, serverURL := NewOAuthServer(t)
 | 
						|
	defer server.Close()
 | 
						|
	config.Providers.Google.TokenURL = &url.URL{
 | 
						|
		Scheme: serverURL.Scheme,
 | 
						|
		Host:   serverURL.Host,
 | 
						|
		Path:   "/token",
 | 
						|
	}
 | 
						|
	serverFail, serverFailURL := NewFailingOAuthServer(t)
 | 
						|
	defer serverFail.Close()
 | 
						|
	config.Providers.Google.UserURL = &url.URL{
 | 
						|
		Scheme: serverFailURL.Scheme,
 | 
						|
		Host:   serverFailURL.Host,
 | 
						|
		Path:   "/userinfo",
 | 
						|
	}
 | 
						|
 | 
						|
	// Should handle failed user request
 | 
						|
	req := newDefaultHttpRequest("/_oauth?state=12345678901234567890123456789012:google:http://redirect")
 | 
						|
	c := MakeCSRFCookie(req, "12345678901234567890123456789012")
 | 
						|
	res, _ := doHttpRequest(req, c)
 | 
						|
	assert.Equal(503, res.StatusCode, "auth callback should handle failed user request")
 | 
						|
}
 | 
						|
 | 
						|
func TestServerLogout(t *testing.T) {
 | 
						|
	require := require.New(t)
 | 
						|
	assert := assert.New(t)
 | 
						|
	config = newDefaultConfig()
 | 
						|
 | 
						|
	req := newDefaultHttpRequest("/_oauth/logout")
 | 
						|
	res, _ := doHttpRequest(req, nil)
 | 
						|
	require.Equal(401, res.StatusCode, "should return a 401")
 | 
						|
 | 
						|
	// Check for cookie
 | 
						|
	var cookie *http.Cookie
 | 
						|
	for _, c := range res.Cookies() {
 | 
						|
		if c.Name == config.CookieName {
 | 
						|
			cookie = c
 | 
						|
		}
 | 
						|
	}
 | 
						|
	require.NotNil(cookie)
 | 
						|
	require.Less(cookie.Expires.Local().Unix(), time.Now().Local().Unix()-50, "cookie should have expired")
 | 
						|
 | 
						|
	// Test with redirect
 | 
						|
	config.LogoutRedirect = "http://redirect/path"
 | 
						|
	req = newDefaultHttpRequest("/_oauth/logout")
 | 
						|
	res, _ = doHttpRequest(req, nil)
 | 
						|
	require.Equal(307, res.StatusCode, "should return a 307")
 | 
						|
 | 
						|
	// Check for cookie
 | 
						|
	cookie = nil
 | 
						|
	for _, c := range res.Cookies() {
 | 
						|
		if c.Name == config.CookieName {
 | 
						|
			cookie = c
 | 
						|
		}
 | 
						|
	}
 | 
						|
	require.NotNil(cookie)
 | 
						|
	require.Less(cookie.Expires.Local().Unix(), time.Now().Local().Unix()-50, "cookie should have expired")
 | 
						|
 | 
						|
	fwd, _ := res.Location()
 | 
						|
	require.NotNil(fwd)
 | 
						|
	assert.Equal("http", fwd.Scheme, "valid request should be redirected to return url")
 | 
						|
	assert.Equal("redirect", fwd.Host, "valid request should be redirected to return url")
 | 
						|
	assert.Equal("/path", fwd.Path, "valid request should be redirected to return url")
 | 
						|
 | 
						|
}
 | 
						|
 | 
						|
func TestServerDefaultAction(t *testing.T) {
 | 
						|
	assert := assert.New(t)
 | 
						|
	config = newDefaultConfig()
 | 
						|
 | 
						|
	req := newDefaultHttpRequest("/random")
 | 
						|
	res, _ := doHttpRequest(req, nil)
 | 
						|
	assert.Equal(307, res.StatusCode, "request should require auth with auth default handler")
 | 
						|
 | 
						|
	config.DefaultAction = "allow"
 | 
						|
	req = newDefaultHttpRequest("/random")
 | 
						|
	res, _ = doHttpRequest(req, nil)
 | 
						|
	assert.Equal(200, res.StatusCode, "request should be allowed with default handler")
 | 
						|
}
 | 
						|
 | 
						|
func TestServerDefaultProvider(t *testing.T) {
 | 
						|
	assert := assert.New(t)
 | 
						|
	config = newDefaultConfig()
 | 
						|
 | 
						|
	// Should use "google" as default provider when not specified
 | 
						|
	req := newDefaultHttpRequest("/random")
 | 
						|
	res, _ := doHttpRequest(req, nil)
 | 
						|
	fwd, _ := res.Location()
 | 
						|
	assert.Equal("https", fwd.Scheme, "request with expired cookie should be redirected to google")
 | 
						|
	assert.Equal("accounts.google.com", fwd.Host, "request with expired cookie should be redirected to google")
 | 
						|
	assert.Equal("/o/oauth2/auth", fwd.Path, "request with expired cookie should be redirected to google")
 | 
						|
 | 
						|
	// Should use alternative default provider when set
 | 
						|
	config.DefaultProvider = "oidc"
 | 
						|
	config.Providers.OIDC.OAuthProvider.Config = &oauth2.Config{
 | 
						|
		Endpoint: oauth2.Endpoint{
 | 
						|
			AuthURL: "https://oidc.com/oidcauth",
 | 
						|
		},
 | 
						|
	}
 | 
						|
 | 
						|
	res, _ = doHttpRequest(req, nil)
 | 
						|
	fwd, _ = res.Location()
 | 
						|
	assert.Equal("https", fwd.Scheme, "request with expired cookie should be redirected to oidc")
 | 
						|
	assert.Equal("oidc.com", fwd.Host, "request with expired cookie should be redirected to oidc")
 | 
						|
	assert.Equal("/oidcauth", fwd.Path, "request with expired cookie should be redirected to oidc")
 | 
						|
}
 | 
						|
 | 
						|
func TestServerRouteHeaders(t *testing.T) {
 | 
						|
	assert := assert.New(t)
 | 
						|
	config = newDefaultConfig()
 | 
						|
	config.Rules = map[string]*Rule{
 | 
						|
		"1": {
 | 
						|
			Action: "allow",
 | 
						|
			Rule:   "Headers(`X-Test`, `test123`)",
 | 
						|
		},
 | 
						|
		"2": {
 | 
						|
			Action: "allow",
 | 
						|
			Rule:   "HeadersRegexp(`X-Test`, `test(456|789)`)",
 | 
						|
		},
 | 
						|
	}
 | 
						|
 | 
						|
	// Should block any request
 | 
						|
	req := newDefaultHttpRequest("/random")
 | 
						|
	req.Header.Add("X-Random", "hello")
 | 
						|
	res, _ := doHttpRequest(req, nil)
 | 
						|
	assert.Equal(307, res.StatusCode, "request not matching any rule should require auth")
 | 
						|
 | 
						|
	// Should allow matching
 | 
						|
	req = newDefaultHttpRequest("/api")
 | 
						|
	req.Header.Add("X-Test", "test123")
 | 
						|
	res, _ = doHttpRequest(req, nil)
 | 
						|
	assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed")
 | 
						|
 | 
						|
	// Should allow matching
 | 
						|
	req = newDefaultHttpRequest("/api")
 | 
						|
	req.Header.Add("X-Test", "test789")
 | 
						|
	res, _ = doHttpRequest(req, nil)
 | 
						|
	assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed")
 | 
						|
}
 | 
						|
 | 
						|
func TestServerRouteHost(t *testing.T) {
 | 
						|
	assert := assert.New(t)
 | 
						|
	config = newDefaultConfig()
 | 
						|
	config.Rules = map[string]*Rule{
 | 
						|
		"1": {
 | 
						|
			Action: "allow",
 | 
						|
			Rule:   "Host(`api.example.com`)",
 | 
						|
		},
 | 
						|
		"2": {
 | 
						|
			Action: "allow",
 | 
						|
			Rule:   "HostRegexp(`sub{num:[0-9]}.example.com`)",
 | 
						|
		},
 | 
						|
	}
 | 
						|
 | 
						|
	// Should block any request
 | 
						|
	req := newHTTPRequest("GET", "https://example.com/")
 | 
						|
	res, _ := doHttpRequest(req, nil)
 | 
						|
	assert.Equal(307, res.StatusCode, "request not matching any rule should require auth")
 | 
						|
 | 
						|
	// Should allow matching request
 | 
						|
	req = newHTTPRequest("GET", "https://api.example.com/")
 | 
						|
	res, _ = doHttpRequest(req, nil)
 | 
						|
	assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed")
 | 
						|
 | 
						|
	// Should allow matching request
 | 
						|
	req = newHTTPRequest("GET", "https://sub8.example.com/")
 | 
						|
	res, _ = doHttpRequest(req, nil)
 | 
						|
	assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed")
 | 
						|
}
 | 
						|
 | 
						|
func TestServerRouteMethod(t *testing.T) {
 | 
						|
	assert := assert.New(t)
 | 
						|
	config = newDefaultConfig()
 | 
						|
	config.Rules = map[string]*Rule{
 | 
						|
		"1": {
 | 
						|
			Action: "allow",
 | 
						|
			Rule:   "Method(`PUT`)",
 | 
						|
		},
 | 
						|
	}
 | 
						|
 | 
						|
	// Should block any request
 | 
						|
	req := newHTTPRequest("GET", "https://example.com/")
 | 
						|
	res, _ := doHttpRequest(req, nil)
 | 
						|
	assert.Equal(307, res.StatusCode, "request not matching any rule should require auth")
 | 
						|
 | 
						|
	// Should allow matching request
 | 
						|
	req = newHTTPRequest("PUT", "https://example.com/")
 | 
						|
	res, _ = doHttpRequest(req, nil)
 | 
						|
	assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed")
 | 
						|
}
 | 
						|
 | 
						|
func TestServerRoutePath(t *testing.T) {
 | 
						|
	assert := assert.New(t)
 | 
						|
	config = newDefaultConfig()
 | 
						|
	config.Rules = map[string]*Rule{
 | 
						|
		"1": {
 | 
						|
			Action: "allow",
 | 
						|
			Rule:   "Path(`/api`)",
 | 
						|
		},
 | 
						|
		"2": {
 | 
						|
			Action: "allow",
 | 
						|
			Rule:   "PathPrefix(`/private`)",
 | 
						|
		},
 | 
						|
	}
 | 
						|
 | 
						|
	// Should block any request
 | 
						|
	req := newDefaultHttpRequest("/random")
 | 
						|
	res, _ := doHttpRequest(req, nil)
 | 
						|
	assert.Equal(307, res.StatusCode, "request not matching any rule should require auth")
 | 
						|
 | 
						|
	// Should allow /api request
 | 
						|
	req = newDefaultHttpRequest("/api")
 | 
						|
	res, _ = doHttpRequest(req, nil)
 | 
						|
	assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed")
 | 
						|
 | 
						|
	// Should allow /private request
 | 
						|
	req = newDefaultHttpRequest("/private")
 | 
						|
	res, _ = doHttpRequest(req, nil)
 | 
						|
	assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed")
 | 
						|
 | 
						|
	req = newDefaultHttpRequest("/private/path")
 | 
						|
	res, _ = doHttpRequest(req, nil)
 | 
						|
	assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed")
 | 
						|
}
 | 
						|
 | 
						|
func TestServerRouteQuery(t *testing.T) {
 | 
						|
	assert := assert.New(t)
 | 
						|
	config = newDefaultConfig()
 | 
						|
	config.Rules = map[string]*Rule{
 | 
						|
		"1": {
 | 
						|
			Action: "allow",
 | 
						|
			Rule:   "Query(`q=test123`)",
 | 
						|
		},
 | 
						|
	}
 | 
						|
 | 
						|
	// Should block any request
 | 
						|
	req := newHTTPRequest("GET", "https://example.com/?q=no")
 | 
						|
	res, _ := doHttpRequest(req, nil)
 | 
						|
	assert.Equal(307, res.StatusCode, "request not matching any rule should require auth")
 | 
						|
 | 
						|
	// Should allow matching request
 | 
						|
	req = newHTTPRequest("GET", "https://api.example.com/?q=test123")
 | 
						|
	res, _ = doHttpRequest(req, nil)
 | 
						|
	assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed")
 | 
						|
}
 | 
						|
 | 
						|
/**
 | 
						|
 * Utilities
 | 
						|
 */
 | 
						|
 | 
						|
// OAuthServer is a fake OAuth provider for tests. It serves canned JSON for
// /token and /userinfo, or answers every request with a 500 when fail is set.
type OAuthServer struct {
	t    *testing.T // reports unexpected requests; may be nil
	fail bool       // when true, every request gets a 500
}

// ServeHTTP implements http.Handler.
//
// It runs on the httptest server's goroutine rather than the test goroutine,
// so unexpected requests are reported with t.Errorf. t.Fatal/FailNow may only
// be called from the goroutine running the test.
func (s *OAuthServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	if s.fail {
		http.Error(w, "Service unavailable", http.StatusInternalServerError)
		return
	}

	switch r.URL.Path {
	case "/token":
		w.Header().Set("Content-Type", "application/json")
		fmt.Fprint(w, `{"access_token":"123456789"}`)
	case "/userinfo":
		w.Header().Set("Content-Type", "application/json")
		fmt.Fprint(w, `{"id":"1","email":"example@example.com","verified_email":true,"hd":"example.com"}`)
	default:
		if s.t != nil {
			s.t.Errorf("Unrecognised request: %s %s", r.Method, r.URL)
		}
		http.NotFound(w, r)
	}
}

// NewOAuthServer starts a fake OAuth provider that answers /token and
// /userinfo successfully. The caller must Close the returned server.
func NewOAuthServer(t *testing.T) (*httptest.Server, *url.URL) {
	t.Helper()
	return startOAuthServer(t, &OAuthServer{t: t})
}

// NewFailingOAuthServer starts a fake OAuth provider that answers every
// request with a 500. The caller must Close the returned server.
func NewFailingOAuthServer(t *testing.T) (*httptest.Server, *url.URL) {
	t.Helper()
	return startOAuthServer(t, &OAuthServer{t: t, fail: true})
}

// startOAuthServer serves handler on a new httptest server and returns the
// server together with its parsed base URL.
func startOAuthServer(t *testing.T, handler *OAuthServer) (*httptest.Server, *url.URL) {
	t.Helper()
	server := httptest.NewServer(handler)
	serverURL, err := url.Parse(server.URL)
	if err != nil {
		server.Close()
		t.Fatalf("parsing test server URL %q: %v", server.URL, err)
	}
	return server, serverURL
}
 | 
						|
 | 
						|
func doHttpRequest(r *http.Request, c *http.Cookie) (*http.Response, string) {
 | 
						|
	w := httptest.NewRecorder()
 | 
						|
 | 
						|
	// Set cookies on recorder
 | 
						|
	if c != nil {
 | 
						|
		http.SetCookie(w, c)
 | 
						|
	}
 | 
						|
 | 
						|
	// Copy into request
 | 
						|
	for _, c := range w.HeaderMap["Set-Cookie"] {
 | 
						|
		r.Header.Add("Cookie", c)
 | 
						|
	}
 | 
						|
 | 
						|
	NewServer().RootHandler(w, r)
 | 
						|
 | 
						|
	res := w.Result()
 | 
						|
	body, _ := ioutil.ReadAll(res.Body)
 | 
						|
 | 
						|
	// if res.StatusCode > 300 && res.StatusCode < 400 {
 | 
						|
	// 	fmt.Printf("%#v", res.Header)
 | 
						|
	// }
 | 
						|
 | 
						|
	return res, string(body)
 | 
						|
}
 | 
						|
 | 
						|
func newDefaultConfig() *Config {
 | 
						|
	config, _ = NewConfig([]string{
 | 
						|
		"--providers.google.client-id=id",
 | 
						|
		"--providers.google.client-secret=secret",
 | 
						|
	})
 | 
						|
 | 
						|
	// Setup the google providers without running all the config validation
 | 
						|
	config.Providers.Google.Setup()
 | 
						|
 | 
						|
	return config
 | 
						|
}
 | 
						|
 | 
						|
// TODO: replace with newHTTPRequest("GET", "http://example.com/"+uri)
 | 
						|
func newDefaultHttpRequest(uri string) *http.Request {
 | 
						|
	return newHTTPRequest("GET", "http://example.com"+uri)
 | 
						|
}
 | 
						|
 | 
						|
// newHTTPRequest builds a server-side test request for method and target. It
// also sets the X-Forwarded-Method/Proto/Host/Uri headers that a ForwardAuth
// proxy would send for that same request.
func newHTTPRequest(method, target string) *http.Request {
	parsed, _ := url.Parse(target)
	req := httptest.NewRequest(method, target, nil)

	forwarded := [...]struct{ name, value string }{
		{"X-Forwarded-Method", method},
		{"X-Forwarded-Proto", parsed.Scheme},
		{"X-Forwarded-Host", parsed.Host},
		{"X-Forwarded-Uri", parsed.RequestURI()},
	}
	for _, h := range forwarded {
		req.Header.Add(h.name, h.value)
	}
	return req
}
 |