238 lines
6.0 KiB
Go
Raw Normal View History

package tfa
2019-01-30 16:52:47 +00:00
import (
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
)
/**
* Setup
*/
func init() {
config.LogLevel = "panic"
log = NewDefaultLogger()
2019-01-30 16:52:47 +00:00
}
/**
* Tests
*/
func TestServerAuthHandler(t *testing.T) {
config, _ = NewConfig([]string{})
2019-01-30 16:52:47 +00:00
// Should redirect vanilla request to login url
req := newHttpRequest("/foo")
res, _ := httpRequest(req, nil)
2019-01-30 16:52:47 +00:00
if res.StatusCode != 307 {
t.Error("Vanilla request should be redirected with 307, got:", res.StatusCode)
}
fwd, _ := res.Location()
if fwd.Scheme != "https" || fwd.Host != "accounts.google.com" || fwd.Path != "/o/oauth2/auth" {
2019-01-30 16:52:47 +00:00
t.Error("Vanilla request should be redirected to login url, got:", fwd)
}
// Should catch invalid cookie
req = newHttpRequest("/foo")
c := MakeCookie(req, "test@example.com")
2019-01-30 16:52:47 +00:00
parts := strings.Split(c.Value, "|")
c.Value = fmt.Sprintf("bad|%s|%s", parts[1], parts[2])
res, _ = httpRequest(req, c)
2019-01-30 16:52:47 +00:00
if res.StatusCode != 401 {
t.Error("Request with invalid cookie shound't be authorised", res.StatusCode)
}
// Should validate email
req = newHttpRequest("/foo")
c = MakeCookie(req, "test@example.com")
config.Domains = []string{"test.com"}
2019-01-30 16:52:47 +00:00
res, _ = httpRequest(req, c)
2019-01-30 16:52:47 +00:00
if res.StatusCode != 401 {
t.Error("Request with invalid email shound't be authorised", res.StatusCode)
2019-01-30 16:52:47 +00:00
}
// Should allow valid request email
req = newHttpRequest("/foo")
c = MakeCookie(req, "test@example.com")
config.Domains = []string{}
2019-01-30 16:52:47 +00:00
res, _ = httpRequest(req, c)
2019-01-30 16:52:47 +00:00
if res.StatusCode != 200 {
t.Error("Valid request should be allowed, got:", res.StatusCode)
}
// Should pass through user
users := res.Header["X-Forwarded-User"]
if len(users) != 1 {
t.Error("Valid request missing X-Forwarded-User header")
} else if users[0] != "test@example.com" {
t.Error("X-Forwarded-User should match user, got: ", users)
}
}
func TestServerAuthCallback(t *testing.T) {
config, _ = NewConfig([]string{})
2019-01-30 16:52:47 +00:00
// Setup token server
tokenServerHandler := &TokenServerHandler{}
tokenServer := httptest.NewServer(tokenServerHandler)
defer tokenServer.Close()
tokenUrl, _ := url.Parse(tokenServer.URL)
config.Providers.Google.TokenURL = tokenUrl
// Setup user server
userServerHandler := &UserServerHandler{}
userServer := httptest.NewServer(userServerHandler)
defer userServer.Close()
userUrl, _ := url.Parse(userServer.URL)
config.Providers.Google.UserURL = userUrl
// Should pass auth response request to callback
req := newHttpRequest("/_oauth")
res, _ := httpRequest(req, nil)
2019-01-30 16:52:47 +00:00
if res.StatusCode != 401 {
t.Error("Auth callback without cookie shound't be authorised, got:", res.StatusCode)
}
// Should catch invalid csrf cookie
req = newHttpRequest("/_oauth?state=12345678901234567890123456789012:http://redirect")
c := MakeCSRFCookie(req, "nononononononononononononononono")
res, _ = httpRequest(req, c)
2019-01-30 16:52:47 +00:00
if res.StatusCode != 401 {
t.Error("Auth callback with invalid cookie shound't be authorised, got:", res.StatusCode)
}
// Should redirect valid request
req = newHttpRequest("/_oauth?state=12345678901234567890123456789012:http://redirect")
c = MakeCSRFCookie(req, "12345678901234567890123456789012")
res, _ = httpRequest(req, c)
2019-01-30 16:52:47 +00:00
if res.StatusCode != 307 {
t.Error("Valid callback should be allowed, got:", res.StatusCode)
}
fwd, _ := res.Location()
if fwd.Scheme != "http" || fwd.Host != "redirect" || fwd.Path != "" {
t.Error("Valid request should be redirected to return url, got:", fwd)
}
}
// TestServerDefaultAction checks config.DefaultAction for a path that no rule
// matches. With a fresh config, an unauthenticated request must be redirected
// (307) to authenticate. After DefaultAction is set to "allow", the same
// request must be let through (200).
func TestServerDefaultAction(t *testing.T) {
	var err error
	config, err = NewConfig([]string{})
	if err != nil {
		t.Fatal("NewConfig failed:", err)
	}

	req := newHttpRequest("/random")
	res, _ := httpRequest(req, nil)
	if res.StatusCode != http.StatusTemporaryRedirect {
		t.Error("Request should require auth with auth default handler, got:", res.StatusCode)
	}

	config.DefaultAction = "allow"
	req = newHttpRequest("/random")
	res, _ = httpRequest(req, nil)
	if res.StatusCode != http.StatusOK {
		t.Error("Request should be allowed with allow default handler, got:", res.StatusCode)
	}
}
func TestServerRoutePathPrefix(t *testing.T) {
config, _ = NewConfig([]string{})
config.Rules = map[string]*Rule{
"web1": {
Action: "allow",
Rule: "PathPrefix(`/api`)",
2019-01-30 16:52:47 +00:00
},
}
2019-04-17 11:29:35 +01:00
// Should block any request
req := newHttpRequest("/random")
res, _ := httpRequest(req, nil)
2019-04-17 11:29:35 +01:00
if res.StatusCode != 307 {
t.Error("Request not matching any rule should require auth, got:", res.StatusCode)
}
2019-01-30 16:52:47 +00:00
// Should allow /api request
2019-04-17 11:29:35 +01:00
req = newHttpRequest("/api")
res, _ = httpRequest(req, nil)
2019-01-30 16:52:47 +00:00
if res.StatusCode != 200 {
t.Error("Request matching allowed rule should be allowed, got:", res.StatusCode)
}
}
/**
* Utilities
*/
// TokenServerHandler is a stand-in for the provider's token endpoint. It
// answers every request with the same fixed JSON body containing an
// access_token.
type TokenServerHandler struct{}

// ServeHTTP ignores the incoming request and writes the canned token body.
func (h *TokenServerHandler) ServeHTTP(w http.ResponseWriter, _ *http.Request) {
	const payload = `{"access_token":"123456789"}`
	fmt.Fprint(w, payload)
}
// UserServerHandler is a stand-in for the provider's user-info endpoint. It
// answers every request with the same fixed JSON user record (id, email,
// verified_email, hd).
type UserServerHandler struct{}

// ServeHTTP ignores the incoming request and writes the canned user record.
func (h *UserServerHandler) ServeHTTP(w http.ResponseWriter, _ *http.Request) {
	const payload = `{
"id":"1",
"email":"example@example.com",
"verified_email":true,
"hd":"example.com"
}`
	fmt.Fprint(w, payload)
}
// httpRequest runs r through a freshly constructed server's RootHandler. It
// returns the recorded response together with its body as a string.
//
// If c is non-nil, it is attached to the request's Cookie header as a plain
// name=value pair (as a browser would send it) before the handler runs.
func httpRequest(r *http.Request, c *http.Cookie) (*http.Response, string) {
	// Attach the cookie to the request itself. Writing it as Set-Cookie on the
	// recorder and copying that header across would send its attributes
	// (Path, Expires, HttpOnly, ...) as extra bogus cookies. It would also
	// leave a Set-Cookie header in the response the handler produces.
	if c != nil {
		r.AddCookie(c)
	}

	w := httptest.NewRecorder()
	NewServer().RootHandler(w, r)

	res := w.Result()
	// Reading a ResponseRecorder's in-memory body cannot fail in practice.
	body, _ := ioutil.ReadAll(res.Body)
	return res, string(body)
}
// newHttpRequest builds a test request to http://example.com/ (an empty
// method, which httptest treats as GET). The path under test is carried in
// the X-Forwarded-Uri header rather than in the request URL.
func newHttpRequest(uri string) *http.Request {
	req := httptest.NewRequest("", "http://example.com/", nil)
	req.Header.Set("X-Forwarded-Uri", uri)
	return req
}
// qsDiff compares the query-string values `one` (expected) against `two`
// (actual) and returns one message per difference:
//   - "Key missing: K" for a key of one that is absent from two,
//   - "Value different for K: ..." for a key in both whose first values differ,
//   - "Extra key: K" for a key of two that is absent from one.
//
// An empty (non-nil) slice means the two sets match. Presence is decided by
// map membership, so a key whose value is the empty string still counts as
// present. Map iteration order is random, so message order is not stable.
//
// t is currently unused; it is kept so existing callers compile unchanged.
func qsDiff(t *testing.T, one, two url.Values) []string {
	errs := make([]string, 0)
	for k := range one {
		if _, ok := two[k]; !ok {
			errs = append(errs, fmt.Sprintf("Key missing: %s", k))
			// Nothing to compare against; don't also report a value mismatch.
			continue
		}
		if one.Get(k) != two.Get(k) {
			errs = append(errs, fmt.Sprintf("Value different for %s: expected: '%s' got: '%s'", k, one.Get(k), two.Get(k)))
		}
	}
	for k := range two {
		if _, ok := one[k]; !ok {
			errs = append(errs, fmt.Sprintf("Extra key: %s", k))
		}
	}
	return errs
}