Overhaul testing to use testify
This commit is contained in:
parent
2074bc7727
commit
93912f4a6e
14
go.mod
14
go.mod
@ -9,23 +9,27 @@ require (
|
||||
github.com/containous/flaeg v1.4.1 // indirect
|
||||
github.com/containous/mux v0.0.0-20181024131434-c33f32e26898 // indirect
|
||||
github.com/containous/traefik v2.0.0-alpha2+incompatible
|
||||
github.com/go-acme/lego v2.4.0+incompatible // indirect
|
||||
github.com/go-acme/lego v2.5.0+incompatible // indirect
|
||||
github.com/go-kit/kit v0.8.0 // indirect
|
||||
github.com/gorilla/context v1.1.1 // indirect
|
||||
github.com/gravitational/trace v0.0.0-20190409171327-f30095ced5ff // indirect
|
||||
github.com/jessevdk/go-flags v1.4.0
|
||||
github.com/jonboulle/clockwork v0.1.0 // indirect
|
||||
github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect
|
||||
github.com/kr/pretty v0.1.0 // indirect
|
||||
github.com/kr/pty v1.1.4 // indirect
|
||||
github.com/miekg/dns v1.1.8 // indirect
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible // indirect
|
||||
github.com/pkg/errors v0.8.1 // indirect
|
||||
github.com/ryanuber/go-glob v1.0.0 // indirect
|
||||
github.com/sirupsen/logrus v1.4.1
|
||||
github.com/stretchr/testify v1.3.0 // indirect
|
||||
github.com/stretchr/objx v0.2.0 // indirect
|
||||
github.com/stretchr/testify v1.3.0
|
||||
github.com/vulcand/predicate v1.1.0 // indirect
|
||||
golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a // indirect
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 // indirect
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f // indirect
|
||||
golang.org/x/crypto v0.0.0-20190422183909-d864b10871cd // indirect
|
||||
golang.org/x/net v0.0.0-20190420063019-afa5a82059c6 // indirect
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58 // indirect
|
||||
golang.org/x/sys v0.0.0-20190422165155-953cdadca894 // indirect
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
|
||||
gopkg.in/square/go-jose.v2 v2.3.1 // indirect
|
||||
)
|
||||
|
12
go.sum
12
go.sum
@ -15,6 +15,7 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/go-acme/lego v2.4.0+incompatible h1:+BTLUfLtDc5qQauyiTCXH6lupEUOCvXyGlEjdeU0YQI=
|
||||
github.com/go-acme/lego v2.4.0+incompatible/go.mod h1:yzMNe9CasVUhkquNvti5nAtPmG94USbYxYrZfTkIn0M=
|
||||
github.com/go-acme/lego v2.5.0+incompatible/go.mod h1:yzMNe9CasVUhkquNvti5nAtPmG94USbYxYrZfTkIn0M=
|
||||
github.com/go-kit/kit v0.8.0 h1:Wz+5lgoB0kkuqLEc6NVmwRknTKP6dTGbSqvhZtBI/j0=
|
||||
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
|
||||
github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8=
|
||||
@ -27,9 +28,11 @@ github.com/jonboulle/clockwork v0.1.0 h1:VKV+ZcuP6l3yW9doeqz6ziZGgcynBVQO+obU0+0
|
||||
github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo=
|
||||
github.com/konsorten/go-windows-terminal-sequences v1.0.1 h1:mweAR1A6xJ3oS2pRaGiHgQ4OO8tzTaLawm8vnODuwDk=
|
||||
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
||||
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
||||
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/pty v1.1.4/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/miekg/dns v1.1.8 h1:1QYRAKU3lN5cRfLCkPU08hwvLJFhvjP6MqNMmQz6ZVI=
|
||||
@ -46,6 +49,7 @@ github.com/sirupsen/logrus v1.4.1 h1:GL2rEmy6nsikmW0r8opw9JIRScdMF5hA8cOYLH7In1k
|
||||
github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
|
||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
@ -54,15 +58,23 @@ github.com/vulcand/predicate v1.1.0/go.mod h1:mlccC5IRBoc2cIFmCB8ZM62I3VDb6p2GXE
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a h1:Igim7XhdOpBnWPuYJ70XcNpq8q3BCACtVgNfoJxOV7g=
|
||||
golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE=
|
||||
golang.org/x/crypto v0.0.0-20190422183909-d864b10871cd h1:sMHc2rZHuzQmrbVoSpt9HgerkXPyIeCSO6k0zUMGfFk=
|
||||
golang.org/x/crypto v0.0.0-20190422183909-d864b10871cd/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 h1:0GoQqolDA55aaLxZyTzK/Y2ePZzZTUrRacwib7cNsYQ=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190420063019-afa5a82059c6 h1:HdqqaWmYAUI7/dmByKKEw+yxDksGSo+9GjkUc9Zp34E=
|
||||
golang.org/x/net v0.0.0-20190420063019-afa5a82059c6/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f h1:wMNYb4v58l5UBM7MYRLPG6ZhfOqbKu7X5eyFl8ZhKvA=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33 h1:I6FyU15t786LL7oL/hn43zqTuEGr4PN7F4XJ1p4E3Y8=
|
||||
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e h1:nFYrTHrdrAOpShe27kaFHjsqYSEQ0KWqdWLu3xuZJts=
|
||||
golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190422165155-953cdadca894 h1:Cz4ceDQGXuKRnVBDTS23GTn/pU5OE2C0WrNTOYK1Uuc=
|
||||
golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
|
@ -4,11 +4,11 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/thomseddon/traefik-forward-auth/internal/provider"
|
||||
)
|
||||
|
||||
@ -17,6 +17,7 @@ import (
|
||||
*/
|
||||
|
||||
func TestAuthValidateCookie(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
config, _ = NewConfig([]string{})
|
||||
r, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
c := &http.Cookie{}
|
||||
@ -24,87 +25,85 @@ func TestAuthValidateCookie(t *testing.T) {
|
||||
// Should require 3 parts
|
||||
c.Value = ""
|
||||
valid, _, err := ValidateCookie(r, c)
|
||||
if valid || err.Error() != "Invalid cookie format" {
|
||||
t.Error("Should get \"Invalid cookie format\", got:", err)
|
||||
assert.False(valid)
|
||||
if assert.Error(err) {
|
||||
assert.Equal("Invalid cookie format", err.Error())
|
||||
}
|
||||
c.Value = "1|2"
|
||||
valid, _, err = ValidateCookie(r, c)
|
||||
if valid || err.Error() != "Invalid cookie format" {
|
||||
t.Error("Should get \"Invalid cookie format\", got:", err)
|
||||
assert.False(valid)
|
||||
if assert.Error(err) {
|
||||
assert.Equal("Invalid cookie format", err.Error())
|
||||
}
|
||||
c.Value = "1|2|3|4"
|
||||
valid, _, err = ValidateCookie(r, c)
|
||||
if valid || err.Error() != "Invalid cookie format" {
|
||||
t.Error("Should get \"Invalid cookie format\", got:", err)
|
||||
assert.False(valid)
|
||||
if assert.Error(err) {
|
||||
assert.Equal("Invalid cookie format", err.Error())
|
||||
}
|
||||
|
||||
// Should catch invalid mac
|
||||
c.Value = "MQ==|2|3"
|
||||
valid, _, err = ValidateCookie(r, c)
|
||||
if valid || err.Error() != "Invalid cookie mac" {
|
||||
t.Error("Should get \"Invalid cookie mac\", got:", err)
|
||||
assert.False(valid)
|
||||
if assert.Error(err) {
|
||||
assert.Equal("Invalid cookie mac", err.Error())
|
||||
}
|
||||
|
||||
// Should catch expired
|
||||
config.Lifetime = time.Second * time.Duration(-1)
|
||||
c = MakeCookie(r, "test@test.com")
|
||||
valid, _, err = ValidateCookie(r, c)
|
||||
if valid || err.Error() != "Cookie has expired" {
|
||||
t.Error("Should get \"Cookie has expired\", got:", err)
|
||||
assert.False(valid)
|
||||
if assert.Error(err) {
|
||||
assert.Equal("Cookie has expired", err.Error())
|
||||
}
|
||||
|
||||
// Should accept valid cookie
|
||||
config.Lifetime = time.Second * time.Duration(10)
|
||||
c = MakeCookie(r, "test@test.com")
|
||||
valid, email, err := ValidateCookie(r, c)
|
||||
if !valid {
|
||||
t.Error("Valid request should return as valid")
|
||||
}
|
||||
if err != nil {
|
||||
t.Error("Valid request should not return error, got:", err)
|
||||
}
|
||||
if email != "test@test.com" {
|
||||
t.Error("Valid request should return user email")
|
||||
}
|
||||
assert.True(valid, "valid request should return valid")
|
||||
assert.Nil(err, "valid request should not return an error")
|
||||
assert.Equal("test@test.com", email, "valid request should return user email")
|
||||
}
|
||||
|
||||
func TestAuthValidateEmail(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
config, _ = NewConfig([]string{})
|
||||
|
||||
// Should allow any
|
||||
if !ValidateEmail("test@test.com") || !ValidateEmail("one@two.com") {
|
||||
t.Error("Should allow any domain if email domain is not defined")
|
||||
}
|
||||
v := ValidateEmail("test@test.com")
|
||||
assert.True(v, "should allow any domain if email domain is not defined")
|
||||
v = ValidateEmail("one@two.com")
|
||||
assert.True(v, "should allow any domain if email domain is not defined")
|
||||
|
||||
// Should block non matching domain
|
||||
config.Domains = []string{"test.com"}
|
||||
if ValidateEmail("one@two.com") {
|
||||
t.Error("Should not allow user from another domain")
|
||||
}
|
||||
v = ValidateEmail("one@two.com")
|
||||
assert.False(v, "should not allow user from another domain")
|
||||
|
||||
// Should allow matching domain
|
||||
config.Domains = []string{"test.com"}
|
||||
if !ValidateEmail("test@test.com") {
|
||||
t.Error("Should allow user from allowed domain")
|
||||
}
|
||||
v = ValidateEmail("test@test.com")
|
||||
assert.True(v, "should allow user from allowed domain")
|
||||
|
||||
// Should block non whitelisted email address
|
||||
config.Domains = []string{}
|
||||
config.Whitelist = []string{"test@test.com"}
|
||||
if ValidateEmail("one@two.com") {
|
||||
t.Error("Should not allow user not in whitelist.")
|
||||
}
|
||||
v = ValidateEmail("one@two.com")
|
||||
assert.False(v, "should not allow user not in whitelist")
|
||||
|
||||
// Should allow matching whitelisted email address
|
||||
config.Domains = []string{}
|
||||
config.Whitelist = []string{"test@test.com"}
|
||||
if !ValidateEmail("test@test.com") {
|
||||
t.Error("Should allow user in whitelist.")
|
||||
}
|
||||
v = ValidateEmail("test@test.com")
|
||||
assert.True(v, "should allow user in whitelist")
|
||||
}
|
||||
|
||||
// TODO: Split google tests out
|
||||
func TestAuthGetLoginURL(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
google := provider.Google{
|
||||
ClientId: "idtest",
|
||||
ClientSecret: "sectest",
|
||||
@ -127,18 +126,10 @@ func TestAuthGetLoginURL(t *testing.T) {
|
||||
|
||||
// Check url
|
||||
uri, err := url.Parse(GetLoginURL(r, "nonce"))
|
||||
if err != nil {
|
||||
t.Error("Error parsing login url:", err)
|
||||
}
|
||||
if uri.Scheme != "https" {
|
||||
t.Error("Expected login Scheme to be \"https\", got:", uri.Scheme)
|
||||
}
|
||||
if uri.Host != "test.com" {
|
||||
t.Error("Expected login Host to be \"test.com\", got:", uri.Host)
|
||||
}
|
||||
if uri.Path != "/auth" {
|
||||
t.Error("Expected login Path to be \"/auth\", got:", uri.Path)
|
||||
}
|
||||
assert.Nil(err)
|
||||
assert.Equal("https", uri.Scheme)
|
||||
assert.Equal("test.com", uri.Host)
|
||||
assert.Equal("/auth", uri.Path)
|
||||
|
||||
// Check query string
|
||||
qs := uri.Query()
|
||||
@ -150,11 +141,7 @@ func TestAuthGetLoginURL(t *testing.T) {
|
||||
"prompt": []string{"consent select_account"},
|
||||
"state": []string{"nonce:http://example.com/hello"},
|
||||
}
|
||||
if !reflect.DeepEqual(qs, expectedQs) {
|
||||
for _, err := range qsDiff(t, expectedQs, qs) {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
assert.Equal(expectedQs, qs)
|
||||
|
||||
//
|
||||
// With Auth URL but no matching cookie domain
|
||||
@ -166,18 +153,10 @@ func TestAuthGetLoginURL(t *testing.T) {
|
||||
|
||||
// Check url
|
||||
uri, err = url.Parse(GetLoginURL(r, "nonce"))
|
||||
if err != nil {
|
||||
t.Error("Error parsing login url:", err)
|
||||
}
|
||||
if uri.Scheme != "https" {
|
||||
t.Error("Expected login Scheme to be \"https\", got:", uri.Scheme)
|
||||
}
|
||||
if uri.Host != "test.com" {
|
||||
t.Error("Expected login Host to be \"test.com\", got:", uri.Host)
|
||||
}
|
||||
if uri.Path != "/auth" {
|
||||
t.Error("Expected login Path to be \"/auth\", got:", uri.Path)
|
||||
}
|
||||
assert.Nil(err)
|
||||
assert.Equal("https", uri.Scheme)
|
||||
assert.Equal("test.com", uri.Host)
|
||||
assert.Equal("/auth", uri.Path)
|
||||
|
||||
// Check query string
|
||||
qs = uri.Query()
|
||||
@ -189,11 +168,7 @@ func TestAuthGetLoginURL(t *testing.T) {
|
||||
"prompt": []string{"consent select_account"},
|
||||
"state": []string{"nonce:http://example.com/hello"},
|
||||
}
|
||||
if !reflect.DeepEqual(qs, expectedQs) {
|
||||
for _, err := range qsDiff(t, expectedQs, qs) {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
assert.Equal(expectedQs, qs)
|
||||
|
||||
//
|
||||
// With correct Auth URL + cookie domain
|
||||
@ -205,18 +180,10 @@ func TestAuthGetLoginURL(t *testing.T) {
|
||||
|
||||
// Check url
|
||||
uri, err = url.Parse(GetLoginURL(r, "nonce"))
|
||||
if err != nil {
|
||||
t.Error("Error parsing login url:", err)
|
||||
}
|
||||
if uri.Scheme != "https" {
|
||||
t.Error("Expected login Scheme to be \"https\", got:", uri.Scheme)
|
||||
}
|
||||
if uri.Host != "test.com" {
|
||||
t.Error("Expected login Host to be \"test.com\", got:", uri.Host)
|
||||
}
|
||||
if uri.Path != "/auth" {
|
||||
t.Error("Expected login Path to be \"/auth\", got:", uri.Path)
|
||||
}
|
||||
assert.Nil(err)
|
||||
assert.Equal("https", uri.Scheme)
|
||||
assert.Equal("test.com", uri.Host)
|
||||
assert.Equal("/auth", uri.Path)
|
||||
|
||||
// Check query string
|
||||
qs = uri.Query()
|
||||
@ -228,14 +195,7 @@ func TestAuthGetLoginURL(t *testing.T) {
|
||||
"state": []string{"nonce:http://example.com/hello"},
|
||||
"prompt": []string{"consent select_account"},
|
||||
}
|
||||
for _, err := range qsDiff(t, expectedQs, qs) {
|
||||
t.Error(err)
|
||||
}
|
||||
if !reflect.DeepEqual(qs, expectedQs) {
|
||||
for _, err := range qsDiff(t, expectedQs, qs) {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
assert.Equal(expectedQs, qs)
|
||||
|
||||
//
|
||||
// With Auth URL + cookie domain, but from different domain
|
||||
@ -248,18 +208,10 @@ func TestAuthGetLoginURL(t *testing.T) {
|
||||
|
||||
// Check url
|
||||
uri, err = url.Parse(GetLoginURL(r, "nonce"))
|
||||
if err != nil {
|
||||
t.Error("Error parsing login url:", err)
|
||||
}
|
||||
if uri.Scheme != "https" {
|
||||
t.Error("Expected login Scheme to be \"https\", got:", uri.Scheme)
|
||||
}
|
||||
if uri.Host != "test.com" {
|
||||
t.Error("Expected login Host to be \"test.com\", got:", uri.Host)
|
||||
}
|
||||
if uri.Path != "/auth" {
|
||||
t.Error("Expected login Path to be \"/auth\", got:", uri.Path)
|
||||
}
|
||||
assert.Nil(err)
|
||||
assert.Equal("https", uri.Scheme)
|
||||
assert.Equal("test.com", uri.Host)
|
||||
assert.Equal("/auth", uri.Path)
|
||||
|
||||
// Check query string
|
||||
qs = uri.Query()
|
||||
@ -271,14 +223,7 @@ func TestAuthGetLoginURL(t *testing.T) {
|
||||
"state": []string{"nonce:http://another.com/hello"},
|
||||
"prompt": []string{"consent select_account"},
|
||||
}
|
||||
for _, err := range qsDiff(t, expectedQs, qs) {
|
||||
t.Error(err)
|
||||
}
|
||||
if !reflect.DeepEqual(qs, expectedQs) {
|
||||
for _, err := range qsDiff(t, expectedQs, qs) {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
assert.Equal(expectedQs, qs)
|
||||
}
|
||||
|
||||
// TODO
|
||||
@ -290,68 +235,47 @@ func TestAuthGetLoginURL(t *testing.T) {
|
||||
// }
|
||||
|
||||
func TestAuthMakeCookie(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
config, _ = NewConfig([]string{})
|
||||
r, _ := http.NewRequest("GET", "http://app.example.com", nil)
|
||||
r.Header.Add("X-Forwarded-Host", "app.example.com")
|
||||
|
||||
c := MakeCookie(r, "test@example.com")
|
||||
if c.Name != "_forward_auth" {
|
||||
t.Error("Cookie name should be \"_forward_auth\", got:", c.Name)
|
||||
}
|
||||
assert.Equal("_forward_auth", c.Name)
|
||||
parts := strings.Split(c.Value, "|")
|
||||
if len(parts) != 3 {
|
||||
t.Error("Cookie should be in 3 parts, got:", c.Value)
|
||||
}
|
||||
assert.Len(parts, 3, "cookie should be 3 parts")
|
||||
valid, _, _ := ValidateCookie(r, c)
|
||||
if !valid {
|
||||
t.Error("Should generate valid cookie:", c.Value)
|
||||
}
|
||||
if c.Path != "/" {
|
||||
t.Error("Cookie path should be \"/\", got:", c.Path)
|
||||
}
|
||||
if c.Domain != "app.example.com" {
|
||||
t.Error("Cookie domain should be \"app.example.com\", got:", c.Domain)
|
||||
}
|
||||
if c.Secure != true {
|
||||
t.Error("Cookie domain should be true, got:", c.Secure)
|
||||
}
|
||||
if !c.Expires.After(time.Now().Local()) {
|
||||
t.Error("Expires should be after now, got:", c.Expires)
|
||||
}
|
||||
if !c.Expires.Before(time.Now().Local().Add(config.Lifetime).Add(10 * time.Second)) {
|
||||
t.Error("Expires should be before lifetime + 10 seconds, got:", c.Expires)
|
||||
}
|
||||
assert.True(valid, "should generate valid cookie")
|
||||
assert.Equal("/", c.Path)
|
||||
assert.Equal("app.example.com", c.Domain)
|
||||
assert.True(c.Secure)
|
||||
|
||||
expires := time.Now().Local().Add(config.Lifetime)
|
||||
assert.WithinDuration(expires, c.Expires, 10*time.Second)
|
||||
|
||||
config.CookieName = "testname"
|
||||
config.InsecureCookie = true
|
||||
c = MakeCookie(r, "test@example.com")
|
||||
if c.Name != "testname" {
|
||||
t.Error("Cookie name should be \"testname\", got:", c.Name)
|
||||
}
|
||||
if c.Secure != false {
|
||||
t.Error("Cookie domain should be false, got:", c.Secure)
|
||||
}
|
||||
assert.Equal("testname", c.Name)
|
||||
assert.False(c.Secure)
|
||||
}
|
||||
|
||||
func TestAuthMakeCSRFCookie(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
config, _ = NewConfig([]string{})
|
||||
r, _ := http.NewRequest("GET", "http://app.example.com", nil)
|
||||
r.Header.Add("X-Forwarded-Host", "app.example.com")
|
||||
|
||||
// No cookie domain or auth url
|
||||
c := MakeCSRFCookie(r, "12345678901234567890123456789012")
|
||||
if c.Domain != "app.example.com" {
|
||||
t.Error("Cookie Domain should match request domain, got:", c.Domain)
|
||||
}
|
||||
assert.Equal("app.example.com", c.Domain)
|
||||
|
||||
// With cookie domain but no auth url
|
||||
config = Config{
|
||||
CookieDomains: []CookieDomain{*NewCookieDomain("example.com")},
|
||||
}
|
||||
c = MakeCSRFCookie(r, "12345678901234567890123456789012")
|
||||
if c.Domain != "app.example.com" {
|
||||
t.Error("Cookie Domain should match request domain, got:", c.Domain)
|
||||
}
|
||||
assert.Equal("app.example.com", c.Domain)
|
||||
|
||||
// With cookie domain and auth url
|
||||
config = Config{
|
||||
@ -359,9 +283,7 @@ func TestAuthMakeCSRFCookie(t *testing.T) {
|
||||
CookieDomains: []CookieDomain{*NewCookieDomain("example.com")},
|
||||
}
|
||||
c = MakeCSRFCookie(r, "12345678901234567890123456789012")
|
||||
if c.Domain != "example.com" {
|
||||
t.Error("Cookie Domain should match request domain, got:", c.Domain)
|
||||
}
|
||||
assert.Equal("example.com", c.Domain)
|
||||
}
|
||||
|
||||
func TestAuthClearCSRFCookie(t *testing.T) {
|
||||
@ -375,6 +297,7 @@ func TestAuthClearCSRFCookie(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAuthValidateCSRFCookie(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
config, _ = NewConfig([]string{})
|
||||
c := &http.Cookie{}
|
||||
|
||||
@ -388,103 +311,88 @@ func TestAuthValidateCSRFCookie(t *testing.T) {
|
||||
r := newCsrfRequest("")
|
||||
c.Value = ""
|
||||
valid, _, err := ValidateCSRFCookie(r, c)
|
||||
if valid || err.Error() != "Invalid CSRF cookie value" {
|
||||
t.Error("Should get \"Invalid CSRF cookie value\", got:", err)
|
||||
assert.False(valid)
|
||||
if assert.Error(err) {
|
||||
assert.Equal("Invalid CSRF cookie value", err.Error())
|
||||
}
|
||||
c.Value = "123456789012345678901234567890123"
|
||||
valid, _, err = ValidateCSRFCookie(r, c)
|
||||
if valid || err.Error() != "Invalid CSRF cookie value" {
|
||||
t.Error("Should get \"Invalid CSRF cookie value\", got:", err)
|
||||
assert.False(valid)
|
||||
if assert.Error(err) {
|
||||
assert.Equal("Invalid CSRF cookie value", err.Error())
|
||||
}
|
||||
|
||||
// Should require valid state
|
||||
r = newCsrfRequest("12345678901234567890123456789012:")
|
||||
c.Value = "12345678901234567890123456789012"
|
||||
valid, _, err = ValidateCSRFCookie(r, c)
|
||||
if valid || err.Error() != "Invalid CSRF state value" {
|
||||
t.Error("Should get \"Invalid CSRF state value\", got:", err)
|
||||
assert.False(valid)
|
||||
if assert.Error(err) {
|
||||
assert.Equal("Invalid CSRF state value", err.Error())
|
||||
}
|
||||
|
||||
// Should allow valid state
|
||||
r = newCsrfRequest("12345678901234567890123456789012:99")
|
||||
c.Value = "12345678901234567890123456789012"
|
||||
valid, state, err := ValidateCSRFCookie(r, c)
|
||||
if !valid {
|
||||
t.Error("Valid request should return as valid")
|
||||
}
|
||||
if err != nil {
|
||||
t.Error("Valid request should not return error, got:", err)
|
||||
}
|
||||
if state != "99" {
|
||||
t.Error("Valid request should return correct state, got:", state)
|
||||
}
|
||||
assert.True(valid, "valid request should return valid")
|
||||
assert.Nil(err, "valid request should not return an error")
|
||||
assert.Equal("99", state, "valid request should return correct state")
|
||||
}
|
||||
|
||||
func TestAuthNonce(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
err, nonce1 := Nonce()
|
||||
if err != nil {
|
||||
t.Error("Error generation nonce:", err)
|
||||
}
|
||||
assert.Nil(err, "error generating nonce")
|
||||
assert.Len(nonce1, 32, "length should be 32 chars")
|
||||
|
||||
err, nonce2 := Nonce()
|
||||
if err != nil {
|
||||
t.Error("Error generation nonce:", err)
|
||||
}
|
||||
assert.Nil(err, "error generating nonce")
|
||||
assert.Len(nonce2, 32, "length should be 32 chars")
|
||||
|
||||
if len(nonce1) != 32 || len(nonce2) != 32 {
|
||||
t.Error("Nonce should be 32 chars")
|
||||
}
|
||||
if nonce1 == nonce2 {
|
||||
t.Error("Nonce should not be equal")
|
||||
}
|
||||
assert.NotEqual(nonce1, nonce2, "nonce should not be equal")
|
||||
}
|
||||
|
||||
func TestAuthCookieDomainMatch(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
cd := NewCookieDomain("example.com")
|
||||
|
||||
// Exact should match
|
||||
if !cd.Match("example.com") {
|
||||
t.Error("Exact domain should match")
|
||||
}
|
||||
assert.True(cd.Match("example.com"), "exact domain should match")
|
||||
|
||||
// Subdomain should match
|
||||
if !cd.Match("test.example.com") {
|
||||
t.Error("Subdomain should match")
|
||||
}
|
||||
assert.True(cd.Match("test.example.com"), "subdomain should match")
|
||||
|
||||
// Derived domain should not match
|
||||
if cd.Match("testexample.com") {
|
||||
t.Error("Derived domain should not match")
|
||||
}
|
||||
assert.False(cd.Match("testexample.com"), "derived domain should not match")
|
||||
|
||||
// Other domain should not match
|
||||
if cd.Match("test.com") {
|
||||
t.Error("Other domain should not match")
|
||||
}
|
||||
assert.False(cd.Match("test.com"), "other domain should not match")
|
||||
}
|
||||
|
||||
func TestAuthCookieDomains(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
cds := CookieDomains{}
|
||||
|
||||
err := cds.UnmarshalFlag("one.com,two.org")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if len(cds) != 2 {
|
||||
t.Error("Expected UnmarshalFlag to provide 2 CookieDomains, got", cds)
|
||||
}
|
||||
if cds[0].Domain != "one.com" || cds[0].SubDomain != ".one.com" {
|
||||
t.Error("Expected UnmarshalFlag to provide one.com, got", cds[0])
|
||||
}
|
||||
if cds[1].Domain != "two.org" || cds[1].SubDomain != ".two.org" {
|
||||
t.Error("Expected UnmarshalFlag to provide two.org, got", cds[1])
|
||||
assert.Nil(err)
|
||||
expected := CookieDomains{
|
||||
CookieDomain{
|
||||
Domain: "one.com",
|
||||
DomainLen: 7,
|
||||
SubDomain: ".one.com",
|
||||
SubDomainLen: 8,
|
||||
},
|
||||
CookieDomain{
|
||||
Domain: "two.org",
|
||||
DomainLen: 7,
|
||||
SubDomain: ".two.org",
|
||||
SubDomainLen: 8,
|
||||
},
|
||||
}
|
||||
assert.Equal(expected, cds)
|
||||
|
||||
marshal, err := cds.MarshalFlag()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if marshal != "one.com,two.org" {
|
||||
t.Error("Expected MarshalFlag to provide \"one.com,two.org\", got", cds)
|
||||
}
|
||||
assert.Nil(err)
|
||||
assert.Equal("one.com,two.org", marshal)
|
||||
}
|
||||
|
@ -37,7 +37,7 @@ type Config struct {
|
||||
Whitelist CommaSeparatedList `long:"whitelist" env:"WHITELIST" description:"Comma separated list of email addresses to allow"`
|
||||
|
||||
Providers provider.Providers `group:"providers" namespace:"providers" env-namespace:"PROVIDERS"`
|
||||
Rules map[string]*Rule `long:"rules.<name>.<param>" description:"Rule definitions, see docs, param can be: \"action\", \"rule\""`
|
||||
Rules map[string]*Rule `long:"rules.<name>.<param>" description:"Rule definitions, see docs, param can be: \"action\", \"rule\""`
|
||||
|
||||
// Filled during transformations
|
||||
Secret []byte
|
||||
|
@ -1,10 +1,11 @@
|
||||
package tfa
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
/**
|
||||
@ -12,56 +13,29 @@ import (
|
||||
*/
|
||||
|
||||
func TestConfigDefaults(t *testing.T) {
|
||||
// Check defaults
|
||||
assert := assert.New(t)
|
||||
c, err := NewConfig([]string{})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.Nil(err)
|
||||
|
||||
if c.LogLevel != "warn" {
|
||||
t.Error("LogLevel default should be warn, got", c.LogLevel)
|
||||
}
|
||||
if c.LogFormat != "text" {
|
||||
t.Error("LogFormat default should be text, got", c.LogFormat)
|
||||
}
|
||||
assert.Equal("warn", c.LogLevel)
|
||||
assert.Equal("text", c.LogFormat)
|
||||
|
||||
if c.AuthHost != "" {
|
||||
t.Error("AuthHost default should be empty, got", c.AuthHost)
|
||||
}
|
||||
if len(c.CookieDomains) != 0 {
|
||||
t.Error("CookieDomains default should be empty, got", c.CookieDomains)
|
||||
}
|
||||
if c.InsecureCookie != false {
|
||||
t.Error("InsecureCookie default should be false, got", c.InsecureCookie)
|
||||
}
|
||||
if c.CookieName != "_forward_auth" {
|
||||
t.Error("CookieName default should be _forward_auth, got", c.CookieName)
|
||||
}
|
||||
if c.CSRFCookieName != "_forward_auth_csrf" {
|
||||
t.Error("CSRFCookieName default should be _forward_auth_csrf, got", c.CSRFCookieName)
|
||||
}
|
||||
if c.DefaultAction != "auth" {
|
||||
t.Error("DefaultAction default should be auth, got", c.DefaultAction)
|
||||
}
|
||||
if len(c.Domains) != 0 {
|
||||
t.Error("Domain default should be empty, got", c.Domains)
|
||||
}
|
||||
if c.Lifetime != time.Second*time.Duration(43200) {
|
||||
t.Error("Lifetime default should be 43200, got", c.Lifetime)
|
||||
}
|
||||
if c.Path != "/_oauth" {
|
||||
t.Error("Path default should be /_oauth, got", c.Path)
|
||||
}
|
||||
if len(c.Whitelist) != 0 {
|
||||
t.Error("Whitelist default should be empty, got", c.Whitelist)
|
||||
}
|
||||
assert.Equal("", c.AuthHost)
|
||||
assert.Len(c.CookieDomains, 0)
|
||||
assert.False(c.InsecureCookie)
|
||||
assert.Equal("_forward_auth", c.CookieName)
|
||||
assert.Equal("_forward_auth_csrf", c.CSRFCookieName)
|
||||
assert.Equal("auth", c.DefaultAction)
|
||||
assert.Len(c.Domains, 0)
|
||||
assert.Equal(time.Second*time.Duration(43200), c.Lifetime)
|
||||
assert.Equal("/_oauth", c.Path)
|
||||
assert.Len(c.Whitelist, 0)
|
||||
|
||||
if c.Providers.Google.Prompt != "" {
|
||||
t.Error("Providers.Google.Prompt default should be empty, got", c.Providers.Google.Prompt)
|
||||
}
|
||||
assert.Equal("", c.Providers.Google.Prompt)
|
||||
}
|
||||
|
||||
func TestConfigParseArgs(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
c, err := NewConfig([]string{
|
||||
"--cookie-name=cookiename",
|
||||
"--csrf-cookie-name", "\"csrfcookiename\"",
|
||||
@ -70,64 +44,38 @@ func TestConfigParseArgs(t *testing.T) {
|
||||
"--rule.two.action=auth",
|
||||
"--rule.two.rule=\"Host(`two.com`) && Path(`/two`)\"",
|
||||
})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.Nil(err)
|
||||
|
||||
// Check normal flags
|
||||
if c.CookieName != "cookiename" {
|
||||
t.Error("CookieName default should be cookiename, got", c.CookieName)
|
||||
}
|
||||
if c.CSRFCookieName != "csrfcookiename" {
|
||||
t.Error("CSRFCookieName default should be csrfcookiename, got", c.CSRFCookieName)
|
||||
}
|
||||
assert.Equal("cookiename", c.CookieName)
|
||||
assert.Equal("csrfcookiename", c.CSRFCookieName)
|
||||
|
||||
// Check rules
|
||||
if len(c.Rules) != 2 {
|
||||
t.Error("Should create 2 rules, got:", len(c.Rules))
|
||||
}
|
||||
|
||||
// First rule
|
||||
if rule, ok := c.Rules["1"]; !ok {
|
||||
t.Error("Could not find rule key '1'")
|
||||
} else {
|
||||
if rule.Action != "allow" {
|
||||
t.Error("First rule action should be allow, got:", rule.Action)
|
||||
}
|
||||
if rule.Rule != "PathPrefix(`/one`)" {
|
||||
t.Error("First rule rule should be PathPrefix(`/one`), got:", rule.Rule)
|
||||
}
|
||||
if rule.Provider != "google" {
|
||||
t.Error("First rule provider should be google, got:", rule.Provider)
|
||||
}
|
||||
}
|
||||
|
||||
// Second rule
|
||||
if rule, ok := c.Rules["two"]; !ok {
|
||||
t.Error("Could not find rule key '1'")
|
||||
} else {
|
||||
if rule.Action != "auth" {
|
||||
t.Error("Second rule action should be auth, got:", rule.Action)
|
||||
}
|
||||
if rule.Rule != "Host(`two.com`) && Path(`/two`)" {
|
||||
t.Error("Second rule rule should be Host(`two.com`) && Path(`/two`), got:", rule.Rule)
|
||||
}
|
||||
if rule.Provider != "google" {
|
||||
t.Error("Second rule provider should be google, got:", rule.Provider)
|
||||
}
|
||||
}
|
||||
assert.Equal(map[string]*Rule{
|
||||
"1": {
|
||||
Action: "allow",
|
||||
Rule: "PathPrefix(`/one`)",
|
||||
Provider: "google",
|
||||
},
|
||||
"two": {
|
||||
Action: "auth",
|
||||
Rule: "Host(`two.com`) && Path(`/two`)",
|
||||
Provider: "google",
|
||||
},
|
||||
}, c.Rules)
|
||||
}
|
||||
|
||||
func TestConfigParseUnknownFlags(t *testing.T) {
|
||||
_, err := NewConfig([]string{
|
||||
"--unknown=_oauthpath2",
|
||||
})
|
||||
if err.Error() != "unknown flag: unknown" {
|
||||
t.Error("Error should be \"unknown flag: unknown\", got:", err)
|
||||
if assert.Error(t, err) {
|
||||
assert.Equal(t, "unknown flag: unknown", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigFlagBackwardsCompatability(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
c, err := NewConfig([]string{
|
||||
"--client-id=clientid",
|
||||
"--client-secret=verysecret",
|
||||
@ -135,135 +83,87 @@ func TestConfigFlagBackwardsCompatability(t *testing.T) {
|
||||
"--lifetime=200",
|
||||
"--cookie-secure=false",
|
||||
})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.Nil(err)
|
||||
|
||||
if c.ClientIdLegacy != "clientid" {
|
||||
t.Error("ClientIdLegacy should be clientid, got", c.ClientIdLegacy)
|
||||
}
|
||||
if c.Providers.Google.ClientId != "clientid" {
|
||||
t.Error("Providers.Google.ClientId should be clientid, got", c.Providers.Google.ClientId)
|
||||
}
|
||||
if c.ClientSecretLegacy != "verysecret" {
|
||||
t.Error("ClientSecretLegacy should be verysecret, got", c.ClientSecretLegacy)
|
||||
}
|
||||
if c.Providers.Google.ClientSecret != "verysecret" {
|
||||
t.Error("Providers.Google.ClientSecret should be verysecret, got", c.Providers.Google.ClientSecret)
|
||||
}
|
||||
if c.PromptLegacy != "prompt" {
|
||||
t.Error("PromptLegacy should be prompt, got", c.PromptLegacy)
|
||||
}
|
||||
if c.Providers.Google.Prompt != "prompt" {
|
||||
t.Error("Providers.Google.Prompt should be prompt, got", c.Providers.Google.Prompt)
|
||||
}
|
||||
assert.Equal("clientid", c.ClientIdLegacy)
|
||||
assert.Equal("clientid", c.Providers.Google.ClientId, "--client-id should set providers.google.client-id")
|
||||
assert.Equal("verysecret", c.ClientSecretLegacy)
|
||||
assert.Equal("verysecret", c.Providers.Google.ClientSecret, "--client-secret should set providers.google.client-secret")
|
||||
assert.Equal("prompt", c.PromptLegacy)
|
||||
assert.Equal("prompt", c.Providers.Google.Prompt, "--prompt should set providers.google.promot")
|
||||
|
||||
// "cookie-secure" used to be a standard go bool flag that could take
|
||||
// true, TRUE, 1, false, FALSE, 0 etc. values.
|
||||
// Here we're checking that format is still suppoted
|
||||
if c.CookieSecureLegacy != "false" || c.InsecureCookie != true {
|
||||
t.Error("Setting cookie-secure=false should set InsecureCookie true, got", c.InsecureCookie)
|
||||
}
|
||||
assert.Equal("false", c.CookieSecureLegacy)
|
||||
assert.True(c.InsecureCookie, "--cookie-secure=false should set insecure-cookie true")
|
||||
|
||||
c, err = NewConfig([]string{"--cookie-secure=TRUE"})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if c.CookieSecureLegacy != "TRUE" || c.InsecureCookie != false {
|
||||
t.Error("Setting cookie-secure=TRUE should set InsecureCookie false, got", c.InsecureCookie)
|
||||
}
|
||||
assert.Nil(err)
|
||||
assert.Equal("TRUE", c.CookieSecureLegacy)
|
||||
assert.False(c.InsecureCookie, "--cookie-secure=TRUE should set insecure-cookie false")
|
||||
}
|
||||
|
||||
func TestConfigParseIni(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
c, err := NewConfig([]string{
|
||||
"--config=../test/config0",
|
||||
"--config=../test/config1",
|
||||
"--csrf-cookie-name=csrfcookiename",
|
||||
})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.Nil(err)
|
||||
|
||||
if c.CookieName != "inicookiename" {
|
||||
t.Error("CookieName should be read as inicookiename from ini file, got", c.CookieName)
|
||||
}
|
||||
if c.CSRFCookieName != "csrfcookiename" {
|
||||
t.Error("CSRFCookieName argument should override ini file, got", c.CSRFCookieName)
|
||||
}
|
||||
if c.Path != "/two" {
|
||||
t.Error("Path in second ini file should override first ini file, got", c.Path)
|
||||
}
|
||||
assert.Equal("inicookiename", c.CookieName, "should be read from ini file")
|
||||
assert.Equal("csrfcookiename", c.CSRFCookieName, "should be read from ini file")
|
||||
assert.Equal("/two", c.Path, "variable in second ini file should override first ini file")
|
||||
}
|
||||
|
||||
func TestConfigFileBackwardsCompatability(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
c, err := NewConfig([]string{
|
||||
"--config=../test/config-legacy",
|
||||
})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.Nil(err)
|
||||
|
||||
if c.Path != "/two" {
|
||||
t.Error("Path in legacy config file should be read, got", c.Path)
|
||||
}
|
||||
assert.Equal("/two", c.Path, "Variable in legacy config file should be read")
|
||||
}
|
||||
|
||||
func TestConfigParseEnvironment(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
os.Setenv("COOKIE_NAME", "env_cookie_name")
|
||||
c, err := NewConfig([]string{})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.Nil(err)
|
||||
|
||||
if c.CookieName != "env_cookie_name" {
|
||||
t.Error("CookieName should be read as env_cookie_name from environment, got", c.CookieName)
|
||||
}
|
||||
assert.Equal("env_cookie_name", c.CookieName, "variable should be read from environment")
|
||||
}
|
||||
|
||||
func TestConfigTransformation(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
c, err := NewConfig([]string{
|
||||
"--url-path=_oauthpath",
|
||||
"--secret=verysecret",
|
||||
"--lifetime=200",
|
||||
})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.Nil(err)
|
||||
|
||||
if c.Path != "/_oauthpath" {
|
||||
t.Error("Path should add slash to front to get /_oauthpath, got:", c.Path)
|
||||
}
|
||||
assert.Equal("/_oauthpath", c.Path, "path should add slash to front")
|
||||
|
||||
if c.SecretString != "verysecret" {
|
||||
t.Error("SecretString should be verysecret, got:", c.SecretString)
|
||||
}
|
||||
if bytes.Compare(c.Secret, []byte("verysecret")) != 0 {
|
||||
t.Error("Secret should be []byte(verysecret), got:", string(c.Secret))
|
||||
}
|
||||
assert.Equal("verysecret", c.SecretString)
|
||||
assert.Equal([]byte("verysecret"), c.Secret, "secret should be converted to byte array")
|
||||
|
||||
if c.LifetimeString != 200 {
|
||||
t.Error("LifetimeString should be 200, got:", c.LifetimeString)
|
||||
}
|
||||
if c.Lifetime != time.Second*time.Duration(200) {
|
||||
t.Error("Lifetime default should be 200, got", c.Lifetime)
|
||||
}
|
||||
assert.Equal(200, c.LifetimeString)
|
||||
assert.Equal(time.Second*time.Duration(200), c.Lifetime, "lifetime should be read and converted to duration")
|
||||
}
|
||||
|
||||
func TestConfigCommaSeparatedList(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
list := CommaSeparatedList{}
|
||||
|
||||
err := list.UnmarshalFlag("one,two")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if len(list) != 2 || list[0] != "one" || list[1] != "two" {
|
||||
t.Error("Expected UnmarshalFlag to provide CommaSeparatedList{one,two}, got", list)
|
||||
}
|
||||
assert.Nil(err)
|
||||
assert.Equal(CommaSeparatedList{"one", "two"}, list, "should parse comma sepearated list")
|
||||
|
||||
marshal, err := list.MarshalFlag()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if marshal != "one,two" {
|
||||
t.Error("Expected MarshalFlag to provide \"one,two\", got", list)
|
||||
}
|
||||
assert.Nil(err)
|
||||
assert.Equal("one,two", marshal, "should marshal back to comma sepearated list")
|
||||
}
|
||||
|
@ -85,6 +85,7 @@ func (s *Server) AuthHandler() http.HandlerFunc {
|
||||
// Forward them on
|
||||
http.Redirect(w, r, GetLoginURL(r, nonce), http.StatusTemporaryRedirect)
|
||||
|
||||
logger.Debug("Done")
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -8,8 +8,12 @@ import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TODO:
|
||||
|
||||
/**
|
||||
* Setup
|
||||
*/
|
||||
@ -24,19 +28,18 @@ func init() {
|
||||
*/
|
||||
|
||||
func TestServerAuthHandler(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
config, _ = NewConfig([]string{})
|
||||
|
||||
// Should redirect vanilla request to login url
|
||||
req := newHttpRequest("/foo")
|
||||
res, _ := doHttpRequest(req, nil)
|
||||
assert.Equal(307, res.StatusCode, "vanilla request should be redirected")
|
||||
|
||||
res, _ := httpRequest(req, nil)
|
||||
if res.StatusCode != 307 {
|
||||
t.Error("Vanilla request should be redirected with 307, got:", res.StatusCode)
|
||||
}
|
||||
fwd, _ := res.Location()
|
||||
if fwd.Scheme != "https" || fwd.Host != "accounts.google.com" || fwd.Path != "/o/oauth2/auth" {
|
||||
t.Error("Vanilla request should be redirected to login url, got:", fwd)
|
||||
}
|
||||
assert.Equal("https", fwd.Scheme, "vanilla request should be redirected to google")
|
||||
assert.Equal("accounts.google.com", fwd.Host, "vanilla request should be redirected to google")
|
||||
assert.Equal("/o/oauth2/auth", fwd.Path, "vanilla request should be redirected to google")
|
||||
|
||||
// Should catch invalid cookie
|
||||
req = newHttpRequest("/foo")
|
||||
@ -44,42 +47,33 @@ func TestServerAuthHandler(t *testing.T) {
|
||||
parts := strings.Split(c.Value, "|")
|
||||
c.Value = fmt.Sprintf("bad|%s|%s", parts[1], parts[2])
|
||||
|
||||
res, _ = httpRequest(req, c)
|
||||
if res.StatusCode != 401 {
|
||||
t.Error("Request with invalid cookie shound't be authorised", res.StatusCode)
|
||||
}
|
||||
res, _ = doHttpRequest(req, c)
|
||||
assert.Equal(401, res.StatusCode, "invalid cookie should not be authorised")
|
||||
|
||||
// Should validate email
|
||||
req = newHttpRequest("/foo")
|
||||
c = MakeCookie(req, "test@example.com")
|
||||
config.Domains = []string{"test.com"}
|
||||
|
||||
res, _ = httpRequest(req, c)
|
||||
if res.StatusCode != 401 {
|
||||
t.Error("Request with invalid email shound't be authorised", res.StatusCode)
|
||||
}
|
||||
res, _ = doHttpRequest(req, c)
|
||||
assert.Equal(401, res.StatusCode, "invalid email should not be authorised")
|
||||
|
||||
// Should allow valid request email
|
||||
req = newHttpRequest("/foo")
|
||||
|
||||
c = MakeCookie(req, "test@example.com")
|
||||
config.Domains = []string{}
|
||||
|
||||
res, _ = httpRequest(req, c)
|
||||
if res.StatusCode != 200 {
|
||||
t.Error("Valid request should be allowed, got:", res.StatusCode)
|
||||
}
|
||||
res, _ = doHttpRequest(req, c)
|
||||
assert.Equal(200, res.StatusCode, "valid request should be allowed")
|
||||
|
||||
// Should pass through user
|
||||
users := res.Header["X-Forwarded-User"]
|
||||
if len(users) != 1 {
|
||||
t.Error("Valid request missing X-Forwarded-User header")
|
||||
} else if users[0] != "test@example.com" {
|
||||
t.Error("X-Forwarded-User should match user, got: ", users)
|
||||
}
|
||||
assert.Len(users, 1, "valid request should have X-Forwarded-User header")
|
||||
assert.Equal([]string{"test@example.com"}, users, "X-Forwarded-User header should match user")
|
||||
}
|
||||
|
||||
func TestServerAuthCallback(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
config, _ = NewConfig([]string{})
|
||||
|
||||
// Setup token server
|
||||
@ -98,50 +92,43 @@ func TestServerAuthCallback(t *testing.T) {
|
||||
|
||||
// Should pass auth response request to callback
|
||||
req := newHttpRequest("/_oauth")
|
||||
res, _ := httpRequest(req, nil)
|
||||
if res.StatusCode != 401 {
|
||||
t.Error("Auth callback without cookie shound't be authorised, got:", res.StatusCode)
|
||||
}
|
||||
res, _ := doHttpRequest(req, nil)
|
||||
assert.Equal(401, res.StatusCode, "auth callback without cookie shouldn't be authorised")
|
||||
|
||||
// Should catch invalid csrf cookie
|
||||
req = newHttpRequest("/_oauth?state=12345678901234567890123456789012:http://redirect")
|
||||
c := MakeCSRFCookie(req, "nononononononononononononononono")
|
||||
res, _ = httpRequest(req, c)
|
||||
if res.StatusCode != 401 {
|
||||
t.Error("Auth callback with invalid cookie shound't be authorised, got:", res.StatusCode)
|
||||
}
|
||||
res, _ = doHttpRequest(req, c)
|
||||
assert.Equal(401, res.StatusCode, "auth callback with invalid cookie shouldn't be authorised")
|
||||
|
||||
// Should redirect valid request
|
||||
req = newHttpRequest("/_oauth?state=12345678901234567890123456789012:http://redirect")
|
||||
c = MakeCSRFCookie(req, "12345678901234567890123456789012")
|
||||
res, _ = httpRequest(req, c)
|
||||
if res.StatusCode != 307 {
|
||||
t.Error("Valid callback should be allowed, got:", res.StatusCode)
|
||||
}
|
||||
res, _ = doHttpRequest(req, c)
|
||||
assert.Equal(307, res.StatusCode, "valid auth callback should be allowed")
|
||||
|
||||
fwd, _ := res.Location()
|
||||
if fwd.Scheme != "http" || fwd.Host != "redirect" || fwd.Path != "" {
|
||||
t.Error("Valid request should be redirected to return url, got:", fwd)
|
||||
}
|
||||
assert.Equal("http", fwd.Scheme, "valid request should be redirected to return url")
|
||||
assert.Equal("redirect", fwd.Host, "valid request should be redirected to return url")
|
||||
assert.Equal("", fwd.Path, "valid request should be redirected to return url")
|
||||
}
|
||||
|
||||
func TestServerDefaultAction(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
config, _ = NewConfig([]string{})
|
||||
|
||||
req := newHttpRequest("/random")
|
||||
res, _ := httpRequest(req, nil)
|
||||
if res.StatusCode != 307 {
|
||||
t.Error("Request should require auth with auth default handler, got:", res.StatusCode)
|
||||
}
|
||||
res, _ := doHttpRequest(req, nil)
|
||||
assert.Equal(307, res.StatusCode, "request should require auth with auth default handler")
|
||||
|
||||
config.DefaultAction = "allow"
|
||||
req = newHttpRequest("/random")
|
||||
res, _ = httpRequest(req, nil)
|
||||
if res.StatusCode != 200 {
|
||||
t.Error("Request should be allowed with allow default handler, got:", res.StatusCode)
|
||||
}
|
||||
res, _ = doHttpRequest(req, nil)
|
||||
assert.Equal(200, res.StatusCode, "request should be allowed with default handler")
|
||||
}
|
||||
|
||||
func TestServerRoutePathPrefix(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
config, _ = NewConfig([]string{})
|
||||
config.Rules = map[string]*Rule{
|
||||
"web1": {
|
||||
@ -152,17 +139,13 @@ func TestServerRoutePathPrefix(t *testing.T) {
|
||||
|
||||
// Should block any request
|
||||
req := newHttpRequest("/random")
|
||||
res, _ := httpRequest(req, nil)
|
||||
if res.StatusCode != 307 {
|
||||
t.Error("Request not matching any rule should require auth, got:", res.StatusCode)
|
||||
}
|
||||
res, _ := doHttpRequest(req, nil)
|
||||
assert.Equal(307, res.StatusCode, "request not matching any rule should require auth")
|
||||
|
||||
// Should allow /api request
|
||||
req = newHttpRequest("/api")
|
||||
res, _ = httpRequest(req, nil)
|
||||
if res.StatusCode != 200 {
|
||||
t.Error("Request matching allowed rule should be allowed, got:", res.StatusCode)
|
||||
}
|
||||
res, _ = doHttpRequest(req, nil)
|
||||
assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed")
|
||||
}
|
||||
|
||||
/**
|
||||
@ -186,7 +169,7 @@ func (t *UserServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}`)
|
||||
}
|
||||
|
||||
func httpRequest(r *http.Request, c *http.Cookie) (*http.Response, string) {
|
||||
func doHttpRequest(r *http.Request, c *http.Cookie) (*http.Response, string) {
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Set cookies on recorder
|
||||
@ -199,7 +182,6 @@ func httpRequest(r *http.Request, c *http.Cookie) (*http.Response, string) {
|
||||
r.Header.Add("Cookie", c)
|
||||
}
|
||||
|
||||
|
||||
NewServer().RootHandler(w, r)
|
||||
|
||||
res := w.Result()
|
||||
|
Loading…
x
Reference in New Issue
Block a user