package tfa

import (
	"fmt"
	"io/ioutil"
	"net/http"
	"net/http/httptest"
	"net/url"
	"strings"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"golang.org/x/oauth2"
)
/**
* Setup
*/
func init() {
config = newDefaultConfig()
config.LogLevel = "panic"
log = NewDefaultLogger()
2019-01-30 16:52:47 +00:00
}
/**
* Tests
*/
func TestServerAuthHandlerInvalid(t *testing.T) {
2019-04-23 17:49:16 +01:00
assert := assert.New(t)
config = newDefaultConfig()
2019-01-30 16:52:47 +00:00
// Should redirect vanilla request to login url
2019-05-07 14:16:38 +01:00
req := newDefaultHttpRequest("/foo")
2019-04-23 17:49:16 +01:00
res, _ := doHttpRequest(req, nil)
assert.Equal(307, res.StatusCode, "vanilla request should be redirected")
2019-01-30 16:52:47 +00:00
fwd, _ := res.Location()
2019-04-23 17:49:16 +01:00
assert.Equal("https", fwd.Scheme, "vanilla request should be redirected to google")
assert.Equal("accounts.google.com", fwd.Host, "vanilla request should be redirected to google")
assert.Equal("/o/oauth2/auth", fwd.Path, "vanilla request should be redirected to google")
2019-01-30 16:52:47 +00:00
// Check state string
qs := fwd.Query()
state, exists := qs["state"]
require.True(t, exists)
require.Len(t, state, 1)
parts := strings.SplitN(state[0], ":", 3)
require.Len(t, parts, 3)
assert.Equal("google", parts[1])
assert.Equal("http://example.com/foo", parts[2])
2019-01-30 16:52:47 +00:00
// Should catch invalid cookie
2019-05-07 14:16:38 +01:00
req = newDefaultHttpRequest("/foo")
c := MakeCookie(req, "test@example.com")
parts = strings.Split(c.Value, "|")
2019-01-30 16:52:47 +00:00
c.Value = fmt.Sprintf("bad|%s|%s", parts[1], parts[2])
2019-04-23 17:49:16 +01:00
res, _ = doHttpRequest(req, c)
assert.Equal(401, res.StatusCode, "invalid cookie should not be authorised")
2019-01-30 16:52:47 +00:00
// Should validate email
2019-05-07 14:16:38 +01:00
req = newDefaultHttpRequest("/foo")
c = MakeCookie(req, "test@example.com")
config.Domains = []string{"test.com"}
2019-01-30 16:52:47 +00:00
2019-04-23 17:49:16 +01:00
res, _ = doHttpRequest(req, c)
assert.Equal(401, res.StatusCode, "invalid email should not be authorised")
}
// TestServerAuthHandlerExpired checks that a request carrying an expired
// cookie is redirected to the provider login and is handed a CSRF cookie.
func TestServerAuthHandlerExpired(t *testing.T) {
	assert := assert.New(t)
	config = newDefaultConfig()
	// A negative lifetime: the cookie made below is expected to be expired
	// as soon as it is issued.
	config.Lifetime = time.Second * time.Duration(-1)
	config.Domains = []string{"test.com"}

	// Should redirect expired cookie
	req := newDefaultHttpRequest("/foo")
	c := MakeCookie(req, "test@example.com")
	res, _ := doHttpRequest(req, c)
	assert.Equal(307, res.StatusCode, "request with expired cookie should be redirected")

	// Check for CSRF cookie
	var cookie *http.Cookie
	for _, rc := range res.Cookies() {
		if rc.Name == config.CSRFCookieName {
			cookie = rc
		}
	}
	assert.NotNil(cookie, "request with expired cookie should be given a CSRF cookie")

	// Check redirection location; stop if it is missing, fwd would be nil.
	fwd, err := res.Location()
	require.NoError(t, err, "request with expired cookie should carry a redirect location")
	assert.Equal("https", fwd.Scheme, "request with expired cookie should be redirected to google")
	assert.Equal("accounts.google.com", fwd.Host, "request with expired cookie should be redirected to google")
	assert.Equal("/o/oauth2/auth", fwd.Path, "request with expired cookie should be redirected to google")
}
func TestServerAuthHandlerValid(t *testing.T) {
assert := assert.New(t)
config = newDefaultConfig()
2019-01-30 16:52:47 +00:00
// Should allow valid request email
req := newDefaultHttpRequest("/foo")
c := MakeCookie(req, "test@example.com")
config.Domains = []string{}
2019-01-30 16:52:47 +00:00
res, _ := doHttpRequest(req, c)
2019-04-23 17:49:16 +01:00
assert.Equal(200, res.StatusCode, "valid request should be allowed")
2019-01-30 16:52:47 +00:00
// Should pass through user
users := res.Header["X-Forwarded-User"]
2019-04-23 17:49:16 +01:00
assert.Len(users, 1, "valid request should have X-Forwarded-User header")
assert.Equal([]string{"test@example.com"}, users, "X-Forwarded-User header should match user")
2019-01-30 16:52:47 +00:00
}
func TestServerAuthCallback(t *testing.T) {
2019-04-23 17:49:16 +01:00
assert := assert.New(t)
config = newDefaultConfig()
2019-01-30 16:52:47 +00:00
// Setup token server
tokenServerHandler := &TokenServerHandler{}
tokenServer := httptest.NewServer(tokenServerHandler)
defer tokenServer.Close()
tokenUrl, _ := url.Parse(tokenServer.URL)
config.Providers.Google.TokenURL = tokenUrl
// Setup user server
userServerHandler := &UserServerHandler{}
userServer := httptest.NewServer(userServerHandler)
defer userServer.Close()
userUrl, _ := url.Parse(userServer.URL)
config.Providers.Google.UserURL = userUrl
// Should pass auth response request to callback
2019-05-07 14:16:38 +01:00
req := newDefaultHttpRequest("/_oauth")
2019-04-23 17:49:16 +01:00
res, _ := doHttpRequest(req, nil)
assert.Equal(401, res.StatusCode, "auth callback without cookie shouldn't be authorised")
2019-01-30 16:52:47 +00:00
// Should catch invalid csrf cookie
2019-05-07 14:16:38 +01:00
req = newDefaultHttpRequest("/_oauth?state=12345678901234567890123456789012:http://redirect")
c := MakeCSRFCookie(req, "nononononononononononononononono")
2019-04-23 17:49:16 +01:00
res, _ = doHttpRequest(req, c)
assert.Equal(401, res.StatusCode, "auth callback with invalid cookie shouldn't be authorised")
2019-01-30 16:52:47 +00:00
// Should redirect valid request
req = newDefaultHttpRequest("/_oauth?state=12345678901234567890123456789012:google:http://redirect")
c = MakeCSRFCookie(req, "12345678901234567890123456789012")
2019-04-23 17:49:16 +01:00
res, _ = doHttpRequest(req, c)
assert.Equal(307, res.StatusCode, "valid auth callback should be allowed")
2019-01-30 16:52:47 +00:00
fwd, _ := res.Location()
2019-04-23 17:49:16 +01:00
assert.Equal("http", fwd.Scheme, "valid request should be redirected to return url")
assert.Equal("redirect", fwd.Host, "valid request should be redirected to return url")
assert.Equal("", fwd.Path, "valid request should be redirected to return url")
2019-01-30 16:52:47 +00:00
}
func TestServerDefaultAction(t *testing.T) {
2019-04-23 17:49:16 +01:00
assert := assert.New(t)
config = newDefaultConfig()
2019-05-07 14:16:38 +01:00
req := newDefaultHttpRequest("/random")
2019-04-23 17:49:16 +01:00
res, _ := doHttpRequest(req, nil)
assert.Equal(307, res.StatusCode, "request should require auth with auth default handler")
config.DefaultAction = "allow"
2019-05-07 14:16:38 +01:00
req = newDefaultHttpRequest("/random")
2019-04-23 17:49:16 +01:00
res, _ = doHttpRequest(req, nil)
assert.Equal(200, res.StatusCode, "request should be allowed with default handler")
}
// TestServerDefaultProvider checks which provider an unauthenticated request
// is redirected to: google under the default configuration, and the OIDC
// auth URL once config.DefaultProvider is set to "oidc".
func TestServerDefaultProvider(t *testing.T) {
	assert := assert.New(t)
	config = newDefaultConfig()

	// Should use "google" as default provider when not specified
	req := newDefaultHttpRequest("/random")
	res, _ := doHttpRequest(req, nil)
	// Stop here if there is no Location header; fwd would be nil below.
	fwd, err := res.Location()
	require.NoError(t, err, "request should carry a redirect location")
	assert.Equal("https", fwd.Scheme, "request should be redirected to google by default")
	assert.Equal("accounts.google.com", fwd.Host, "request should be redirected to google by default")
	assert.Equal("/o/oauth2/auth", fwd.Path, "request should be redirected to google by default")

	// Should use alternative default provider when set
	config.DefaultProvider = "oidc"
	config.Providers.OIDC.OAuthProvider.Config = &oauth2.Config{
		Endpoint: oauth2.Endpoint{
			AuthURL: "https://oidc.com/oidcauth",
		},
	}

	// The same request object is sent a second time.
	res, _ = doHttpRequest(req, nil)
	fwd, err = res.Location()
	require.NoError(t, err, "request should carry a redirect location")
	assert.Equal("https", fwd.Scheme, "request should be redirected to oidc when it is the default provider")
	assert.Equal("oidc.com", fwd.Host, "request should be redirected to oidc when it is the default provider")
	assert.Equal("/oidcauth", fwd.Path, "request should be redirected to oidc when it is the default provider")
}
2019-05-07 14:16:38 +01:00
func TestServerRouteHeaders(t *testing.T) {
2019-04-23 17:49:16 +01:00
assert := assert.New(t)
config = newDefaultConfig()
config.Rules = map[string]*Rule{
2019-05-07 14:16:38 +01:00
"1": {
Action: "allow",
2019-05-07 14:16:38 +01:00
Rule: "Headers(`X-Test`, `test123`)",
},
"2": {
Action: "allow",
Rule: "HeadersRegexp(`X-Test`, `test(456|789)`)",
},
}
// Should block any request
req := newDefaultHttpRequest("/random")
req.Header.Add("X-Random", "hello")
res, _ := doHttpRequest(req, nil)
assert.Equal(307, res.StatusCode, "request not matching any rule should require auth")
// Should allow matching
req = newDefaultHttpRequest("/api")
req.Header.Add("X-Test", "test123")
res, _ = doHttpRequest(req, nil)
assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed")
// Should allow matching
req = newDefaultHttpRequest("/api")
req.Header.Add("X-Test", "test789")
res, _ = doHttpRequest(req, nil)
assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed")
}
func TestServerRouteHost(t *testing.T) {
assert := assert.New(t)
config = newDefaultConfig()
2019-05-07 14:16:38 +01:00
config.Rules = map[string]*Rule{
"1": {
Action: "allow",
Rule: "Host(`api.example.com`)",
},
"2": {
Action: "allow",
Rule: "HostRegexp(`sub{num:[0-9]}.example.com`)",
},
}
// Should block any request
req := newHttpRequest("GET", "https://example.com/", "/")
res, _ := doHttpRequest(req, nil)
assert.Equal(307, res.StatusCode, "request not matching any rule should require auth")
// Should allow matching request
req = newHttpRequest("GET", "https://api.example.com/", "/")
res, _ = doHttpRequest(req, nil)
assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed")
// Should allow matching request
req = newHttpRequest("GET", "https://sub8.example.com/", "/")
res, _ = doHttpRequest(req, nil)
assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed")
}
func TestServerRouteMethod(t *testing.T) {
assert := assert.New(t)
config = newDefaultConfig()
2019-05-07 14:16:38 +01:00
config.Rules = map[string]*Rule{
"1": {
Action: "allow",
Rule: "Method(`PUT`)",
2019-01-30 16:52:47 +00:00
},
}
2019-04-17 11:29:35 +01:00
// Should block any request
2019-05-07 14:16:38 +01:00
req := newHttpRequest("GET", "https://example.com/", "/")
res, _ := doHttpRequest(req, nil)
assert.Equal(307, res.StatusCode, "request not matching any rule should require auth")
// Should allow matching request
req = newHttpRequest("PUT", "https://example.com/", "/")
res, _ = doHttpRequest(req, nil)
assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed")
}
func TestServerRoutePath(t *testing.T) {
assert := assert.New(t)
config = newDefaultConfig()
2019-05-07 14:16:38 +01:00
config.Rules = map[string]*Rule{
"1": {
Action: "allow",
Rule: "Path(`/api`)",
},
"2": {
Action: "allow",
Rule: "PathPrefix(`/private`)",
},
}
// Should block any request
req := newDefaultHttpRequest("/random")
2019-04-23 17:49:16 +01:00
res, _ := doHttpRequest(req, nil)
assert.Equal(307, res.StatusCode, "request not matching any rule should require auth")
2019-01-30 16:52:47 +00:00
// Should allow /api request
2019-05-07 14:16:38 +01:00
req = newDefaultHttpRequest("/api")
res, _ = doHttpRequest(req, nil)
assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed")
// Should allow /private request
req = newDefaultHttpRequest("/private")
res, _ = doHttpRequest(req, nil)
assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed")
req = newDefaultHttpRequest("/private/path")
res, _ = doHttpRequest(req, nil)
assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed")
}
func TestServerRouteQuery(t *testing.T) {
assert := assert.New(t)
config = newDefaultConfig()
2019-05-07 14:16:38 +01:00
config.Rules = map[string]*Rule{
"1": {
Action: "allow",
Rule: "Query(`q=test123`)",
},
}
// Should block any request
req := newHttpRequest("GET", "https://example.com/", "/?q=no")
res, _ := doHttpRequest(req, nil)
assert.Equal(307, res.StatusCode, "request not matching any rule should require auth")
// Should allow matching request
req = newHttpRequest("GET", "https://api.example.com/", "/?q=test123")
2019-04-23 17:49:16 +01:00
res, _ = doHttpRequest(req, nil)
assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed")
2019-01-30 16:52:47 +00:00
}
/**
* Utilities
*/
// TokenServerHandler is a stub HTTP handler that answers every request with
// a fixed JSON document containing an access token.
type TokenServerHandler struct{}

// Compile-time check that the stub can be served by net/http.
var _ http.Handler = (*TokenServerHandler)(nil)

// ServeHTTP writes the canned token document; the request is not inspected.
func (h *TokenServerHandler) ServeHTTP(w http.ResponseWriter, _ *http.Request) {
	const payload = `{"access_token":"123456789"}`
	fmt.Fprint(w, payload)
}
// UserServerHandler is a stub HTTP handler that answers every request with
// a fixed JSON user document.
type UserServerHandler struct{}

// Compile-time check that the stub can be served by net/http.
var _ http.Handler = (*UserServerHandler)(nil)

// ServeHTTP writes the canned user document; the request is not inspected.
func (h *UserServerHandler) ServeHTTP(w http.ResponseWriter, _ *http.Request) {
	const payload = `{
"id":"1",
"email":"example@example.com",
"verified_email":true,
"hd":"example.com"
}`
	fmt.Fprint(w, payload)
}
2019-04-23 17:49:16 +01:00
func doHttpRequest(r *http.Request, c *http.Cookie) (*http.Response, string) {
w := httptest.NewRecorder()
// Set cookies on recorder
if c != nil {
http.SetCookie(w, c)
}
// Copy into request
for _, c := range w.HeaderMap["Set-Cookie"] {
r.Header.Add("Cookie", c)
}
NewServer().RootHandler(w, r)
res := w.Result()
body, _ := ioutil.ReadAll(res.Body)
// if res.StatusCode > 300 && res.StatusCode < 400 {
// fmt.Printf("%#v", res.Header)
// }
return res, string(body)
}
// newDefaultConfig builds a config from a fixed pair of google client
// flags, stores it in the package-level config variable and returns it.
func newDefaultConfig() *Config {
	args := []string{
		"--providers.google.client-id=id",
		"--providers.google.client-secret=secret",
	}

	// The error returned by NewConfig is discarded.
	config, _ = NewConfig(args)

	// Setup the google providers without running all the config validation
	config.Providers.Google.Setup()

	return config
}
// newDefaultHttpRequest builds a request for uri on http://example.com/ with
// an empty forwarded method.
func newDefaultHttpRequest(uri string) *http.Request {
	return newHttpRequest("", "http://example.com/", uri)
}

// newHttpRequest builds a request whose own target is a placeholder host and
// whose intended method, scheme, host and URI are carried in the
// X-Forwarded-Method, X-Forwarded-Proto, X-Forwarded-Host and
// X-Forwarded-Uri headers. The scheme and host are taken from dest.
//
// It panics if dest cannot be parsed as a URL.
func newHttpRequest(method, dest, uri string) *http.Request {
	p, err := url.Parse(dest)
	if err != nil {
		// Fail with a message naming the bad input instead of a nil
		// dereference on p below.
		panic(fmt.Sprintf("newHttpRequest: invalid dest %q: %v", dest, err))
	}

	// An empty method makes httptest.NewRequest use GET.
	r := httptest.NewRequest("", "http://should-use-x-forwarded.com", nil)
	r.Header.Add("X-Forwarded-Method", method)
	r.Header.Add("X-Forwarded-Proto", p.Scheme)
	r.Header.Add("X-Forwarded-Host", p.Host)
	r.Header.Add("X-Forwarded-Uri", uri)

	return r
}