Multiple provider support + OIDC provider
This commit is contained in:
@ -11,15 +11,16 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// TODO:
|
||||
|
||||
/**
|
||||
* Setup
|
||||
*/
|
||||
|
||||
func init() {
|
||||
config = newDefaultConfig()
|
||||
config.LogLevel = "panic"
|
||||
log = NewDefaultLogger()
|
||||
}
|
||||
@ -30,7 +31,7 @@ func init() {
|
||||
|
||||
func TestServerAuthHandlerInvalid(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
config, _ = NewConfig([]string{})
|
||||
config = newDefaultConfig()
|
||||
|
||||
// Should redirect vanilla request to login url
|
||||
req := newDefaultHttpRequest("/foo")
|
||||
@ -42,10 +43,20 @@ func TestServerAuthHandlerInvalid(t *testing.T) {
|
||||
assert.Equal("accounts.google.com", fwd.Host, "vanilla request should be redirected to google")
|
||||
assert.Equal("/o/oauth2/auth", fwd.Path, "vanilla request should be redirected to google")
|
||||
|
||||
// Check state string
|
||||
qs := fwd.Query()
|
||||
state, exists := qs["state"]
|
||||
require.True(t, exists)
|
||||
require.Len(t, state, 1)
|
||||
parts := strings.SplitN(state[0], ":", 3)
|
||||
require.Len(t, parts, 3)
|
||||
assert.Equal("google", parts[1])
|
||||
assert.Equal("http://example.com/foo", parts[2])
|
||||
|
||||
// Should catch invalid cookie
|
||||
req = newDefaultHttpRequest("/foo")
|
||||
c := MakeCookie(req, "test@example.com")
|
||||
parts := strings.Split(c.Value, "|")
|
||||
parts = strings.Split(c.Value, "|")
|
||||
c.Value = fmt.Sprintf("bad|%s|%s", parts[1], parts[2])
|
||||
|
||||
res, _ = doHttpRequest(req, c)
|
||||
@ -62,7 +73,7 @@ func TestServerAuthHandlerInvalid(t *testing.T) {
|
||||
|
||||
func TestServerAuthHandlerExpired(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
config, _ = NewConfig([]string{})
|
||||
config = newDefaultConfig()
|
||||
config.Lifetime = time.Second * time.Duration(-1)
|
||||
config.Domains = []string{"test.com"}
|
||||
|
||||
@ -90,7 +101,7 @@ func TestServerAuthHandlerExpired(t *testing.T) {
|
||||
|
||||
func TestServerAuthHandlerValid(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
config, _ = NewConfig([]string{})
|
||||
config = newDefaultConfig()
|
||||
|
||||
// Should allow valid request email
|
||||
req := newDefaultHttpRequest("/foo")
|
||||
@ -108,7 +119,7 @@ func TestServerAuthHandlerValid(t *testing.T) {
|
||||
|
||||
func TestServerAuthCallback(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
config, _ = NewConfig([]string{})
|
||||
config = newDefaultConfig()
|
||||
|
||||
// Setup token server
|
||||
tokenServerHandler := &TokenServerHandler{}
|
||||
@ -136,7 +147,7 @@ func TestServerAuthCallback(t *testing.T) {
|
||||
assert.Equal(401, res.StatusCode, "auth callback with invalid cookie shouldn't be authorised")
|
||||
|
||||
// Should redirect valid request
|
||||
req = newDefaultHttpRequest("/_oauth?state=12345678901234567890123456789012:http://redirect")
|
||||
req = newDefaultHttpRequest("/_oauth?state=12345678901234567890123456789012:google:http://redirect")
|
||||
c = MakeCSRFCookie(req, "12345678901234567890123456789012")
|
||||
res, _ = doHttpRequest(req, c)
|
||||
assert.Equal(307, res.StatusCode, "valid auth callback should be allowed")
|
||||
@ -149,7 +160,7 @@ func TestServerAuthCallback(t *testing.T) {
|
||||
|
||||
func TestServerDefaultAction(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
config, _ = NewConfig([]string{})
|
||||
config = newDefaultConfig()
|
||||
|
||||
req := newDefaultHttpRequest("/random")
|
||||
res, _ := doHttpRequest(req, nil)
|
||||
@ -161,9 +172,36 @@ func TestServerDefaultAction(t *testing.T) {
|
||||
assert.Equal(200, res.StatusCode, "request should be allowed with default handler")
|
||||
}
|
||||
|
||||
func TestServerDefaultProvider(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
config = newDefaultConfig()
|
||||
|
||||
// Should use "google" as default provider when not specified
|
||||
req := newDefaultHttpRequest("/random")
|
||||
res, _ := doHttpRequest(req, nil)
|
||||
fwd, _ := res.Location()
|
||||
assert.Equal("https", fwd.Scheme, "request with expired cookie should be redirected to google")
|
||||
assert.Equal("accounts.google.com", fwd.Host, "request with expired cookie should be redirected to google")
|
||||
assert.Equal("/o/oauth2/auth", fwd.Path, "request with expired cookie should be redirected to google")
|
||||
|
||||
// Should use alternative default provider when set
|
||||
config.DefaultProvider = "oidc"
|
||||
config.Providers.OIDC.OAuthProvider.Config = &oauth2.Config{
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: "https://oidc.com/oidcauth",
|
||||
},
|
||||
}
|
||||
|
||||
res, _ = doHttpRequest(req, nil)
|
||||
fwd, _ = res.Location()
|
||||
assert.Equal("https", fwd.Scheme, "request with expired cookie should be redirected to oidc")
|
||||
assert.Equal("oidc.com", fwd.Host, "request with expired cookie should be redirected to oidc")
|
||||
assert.Equal("/oidcauth", fwd.Path, "request with expired cookie should be redirected to oidc")
|
||||
}
|
||||
|
||||
func TestServerRouteHeaders(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
config, _ = NewConfig([]string{})
|
||||
config = newDefaultConfig()
|
||||
config.Rules = map[string]*Rule{
|
||||
"1": {
|
||||
Action: "allow",
|
||||
@ -196,7 +234,7 @@ func TestServerRouteHeaders(t *testing.T) {
|
||||
|
||||
func TestServerRouteHost(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
config, _ = NewConfig([]string{})
|
||||
config = newDefaultConfig()
|
||||
config.Rules = map[string]*Rule{
|
||||
"1": {
|
||||
Action: "allow",
|
||||
@ -226,7 +264,7 @@ func TestServerRouteHost(t *testing.T) {
|
||||
|
||||
func TestServerRouteMethod(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
config, _ = NewConfig([]string{})
|
||||
config = newDefaultConfig()
|
||||
config.Rules = map[string]*Rule{
|
||||
"1": {
|
||||
Action: "allow",
|
||||
@ -247,7 +285,7 @@ func TestServerRouteMethod(t *testing.T) {
|
||||
|
||||
func TestServerRoutePath(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
config, _ = NewConfig([]string{})
|
||||
config = newDefaultConfig()
|
||||
config.Rules = map[string]*Rule{
|
||||
"1": {
|
||||
Action: "allow",
|
||||
@ -281,7 +319,7 @@ func TestServerRoutePath(t *testing.T) {
|
||||
|
||||
func TestServerRouteQuery(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
config, _ = NewConfig([]string{})
|
||||
config = newDefaultConfig()
|
||||
config.Rules = map[string]*Rule{
|
||||
"1": {
|
||||
Action: "allow",
|
||||
@ -346,6 +384,18 @@ func doHttpRequest(r *http.Request, c *http.Cookie) (*http.Response, string) {
|
||||
return res, string(body)
|
||||
}
|
||||
|
||||
func newDefaultConfig() *Config {
|
||||
config, _ = NewConfig([]string{
|
||||
"--providers.google.client-id=id",
|
||||
"--providers.google.client-secret=secret",
|
||||
})
|
||||
|
||||
// Setup the google providers without running all the config validation
|
||||
config.Providers.Google.Setup()
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
func newDefaultHttpRequest(uri string) *http.Request {
|
||||
return newHttpRequest("", "http://example.com/", uri)
|
||||
}
|
||||
@ -354,25 +404,8 @@ func newHttpRequest(method, dest, uri string) *http.Request {
|
||||
r := httptest.NewRequest("", "http://should-use-x-forwarded.com", nil)
|
||||
p, _ := url.Parse(dest)
|
||||
r.Header.Add("X-Forwarded-Method", method)
|
||||
r.Header.Add("X-Forwarded-Proto", p.Scheme)
|
||||
r.Header.Add("X-Forwarded-Host", p.Host)
|
||||
r.Header.Add("X-Forwarded-Uri", uri)
|
||||
return r
|
||||
}
|
||||
|
||||
func qsDiff(t *testing.T, one, two url.Values) []string {
|
||||
errs := make([]string, 0)
|
||||
for k := range one {
|
||||
if two.Get(k) == "" {
|
||||
errs = append(errs, fmt.Sprintf("Key missing: %s", k))
|
||||
}
|
||||
if one.Get(k) != two.Get(k) {
|
||||
errs = append(errs, fmt.Sprintf("Value different for %s: expected: '%s' got: '%s'", k, one.Get(k), two.Get(k)))
|
||||
}
|
||||
}
|
||||
for k := range two {
|
||||
if one.Get(k) == "" {
|
||||
errs = append(errs, fmt.Sprintf("Extra key: %s", k))
|
||||
}
|
||||
}
|
||||
return errs
|
||||
}
|
||||
|
Reference in New Issue
Block a user