Multiple provider support + OIDC provider

This commit is contained in:
Thom Seddon
2019-09-18 17:55:52 +01:00
parent 5dfd4f2878
commit 5a9c6adedf
16 changed files with 1043 additions and 278 deletions

View File

@ -11,15 +11,16 @@ import (
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
)
// TODO:
/**
* Setup
*/
func init() {
config = newDefaultConfig()
config.LogLevel = "panic"
log = NewDefaultLogger()
}
@ -30,7 +31,7 @@ func init() {
func TestServerAuthHandlerInvalid(t *testing.T) {
assert := assert.New(t)
config, _ = NewConfig([]string{})
config = newDefaultConfig()
// Should redirect vanilla request to login url
req := newDefaultHttpRequest("/foo")
@ -42,10 +43,20 @@ func TestServerAuthHandlerInvalid(t *testing.T) {
assert.Equal("accounts.google.com", fwd.Host, "vanilla request should be redirected to google")
assert.Equal("/o/oauth2/auth", fwd.Path, "vanilla request should be redirected to google")
// Check state string
qs := fwd.Query()
state, exists := qs["state"]
require.True(t, exists)
require.Len(t, state, 1)
parts := strings.SplitN(state[0], ":", 3)
require.Len(t, parts, 3)
assert.Equal("google", parts[1])
assert.Equal("http://example.com/foo", parts[2])
// Should catch invalid cookie
req = newDefaultHttpRequest("/foo")
c := MakeCookie(req, "test@example.com")
parts := strings.Split(c.Value, "|")
parts = strings.Split(c.Value, "|")
c.Value = fmt.Sprintf("bad|%s|%s", parts[1], parts[2])
res, _ = doHttpRequest(req, c)
@ -62,7 +73,7 @@ func TestServerAuthHandlerInvalid(t *testing.T) {
func TestServerAuthHandlerExpired(t *testing.T) {
assert := assert.New(t)
config, _ = NewConfig([]string{})
config = newDefaultConfig()
config.Lifetime = time.Second * time.Duration(-1)
config.Domains = []string{"test.com"}
@ -90,7 +101,7 @@ func TestServerAuthHandlerExpired(t *testing.T) {
func TestServerAuthHandlerValid(t *testing.T) {
assert := assert.New(t)
config, _ = NewConfig([]string{})
config = newDefaultConfig()
// Should allow valid request email
req := newDefaultHttpRequest("/foo")
@ -108,7 +119,7 @@ func TestServerAuthHandlerValid(t *testing.T) {
func TestServerAuthCallback(t *testing.T) {
assert := assert.New(t)
config, _ = NewConfig([]string{})
config = newDefaultConfig()
// Setup token server
tokenServerHandler := &TokenServerHandler{}
@ -136,7 +147,7 @@ func TestServerAuthCallback(t *testing.T) {
assert.Equal(401, res.StatusCode, "auth callback with invalid cookie shouldn't be authorised")
// Should redirect valid request
req = newDefaultHttpRequest("/_oauth?state=12345678901234567890123456789012:http://redirect")
req = newDefaultHttpRequest("/_oauth?state=12345678901234567890123456789012:google:http://redirect")
c = MakeCSRFCookie(req, "12345678901234567890123456789012")
res, _ = doHttpRequest(req, c)
assert.Equal(307, res.StatusCode, "valid auth callback should be allowed")
@ -149,7 +160,7 @@ func TestServerAuthCallback(t *testing.T) {
func TestServerDefaultAction(t *testing.T) {
assert := assert.New(t)
config, _ = NewConfig([]string{})
config = newDefaultConfig()
req := newDefaultHttpRequest("/random")
res, _ := doHttpRequest(req, nil)
@ -161,9 +172,36 @@ func TestServerDefaultAction(t *testing.T) {
assert.Equal(200, res.StatusCode, "request should be allowed with default handler")
}
func TestServerDefaultProvider(t *testing.T) {
assert := assert.New(t)
config = newDefaultConfig()
// Should use "google" as default provider when not specified
req := newDefaultHttpRequest("/random")
res, _ := doHttpRequest(req, nil)
fwd, _ := res.Location()
assert.Equal("https", fwd.Scheme, "request with expired cookie should be redirected to google")
assert.Equal("accounts.google.com", fwd.Host, "request with expired cookie should be redirected to google")
assert.Equal("/o/oauth2/auth", fwd.Path, "request with expired cookie should be redirected to google")
// Should use alternative default provider when set
config.DefaultProvider = "oidc"
config.Providers.OIDC.OAuthProvider.Config = &oauth2.Config{
Endpoint: oauth2.Endpoint{
AuthURL: "https://oidc.com/oidcauth",
},
}
res, _ = doHttpRequest(req, nil)
fwd, _ = res.Location()
assert.Equal("https", fwd.Scheme, "request with expired cookie should be redirected to oidc")
assert.Equal("oidc.com", fwd.Host, "request with expired cookie should be redirected to oidc")
assert.Equal("/oidcauth", fwd.Path, "request with expired cookie should be redirected to oidc")
}
func TestServerRouteHeaders(t *testing.T) {
assert := assert.New(t)
config, _ = NewConfig([]string{})
config = newDefaultConfig()
config.Rules = map[string]*Rule{
"1": {
Action: "allow",
@ -196,7 +234,7 @@ func TestServerRouteHeaders(t *testing.T) {
func TestServerRouteHost(t *testing.T) {
assert := assert.New(t)
config, _ = NewConfig([]string{})
config = newDefaultConfig()
config.Rules = map[string]*Rule{
"1": {
Action: "allow",
@ -226,7 +264,7 @@ func TestServerRouteHost(t *testing.T) {
func TestServerRouteMethod(t *testing.T) {
assert := assert.New(t)
config, _ = NewConfig([]string{})
config = newDefaultConfig()
config.Rules = map[string]*Rule{
"1": {
Action: "allow",
@ -247,7 +285,7 @@ func TestServerRouteMethod(t *testing.T) {
func TestServerRoutePath(t *testing.T) {
assert := assert.New(t)
config, _ = NewConfig([]string{})
config = newDefaultConfig()
config.Rules = map[string]*Rule{
"1": {
Action: "allow",
@ -281,7 +319,7 @@ func TestServerRoutePath(t *testing.T) {
func TestServerRouteQuery(t *testing.T) {
assert := assert.New(t)
config, _ = NewConfig([]string{})
config = newDefaultConfig()
config.Rules = map[string]*Rule{
"1": {
Action: "allow",
@ -346,6 +384,18 @@ func doHttpRequest(r *http.Request, c *http.Cookie) (*http.Response, string) {
return res, string(body)
}
func newDefaultConfig() *Config {
config, _ = NewConfig([]string{
"--providers.google.client-id=id",
"--providers.google.client-secret=secret",
})
// Setup the google providers without running all the config validation
config.Providers.Google.Setup()
return config
}
func newDefaultHttpRequest(uri string) *http.Request {
return newHttpRequest("", "http://example.com/", uri)
}
@ -354,25 +404,8 @@ func newHttpRequest(method, dest, uri string) *http.Request {
r := httptest.NewRequest("", "http://should-use-x-forwarded.com", nil)
p, _ := url.Parse(dest)
r.Header.Add("X-Forwarded-Method", method)
r.Header.Add("X-Forwarded-Proto", p.Scheme)
r.Header.Add("X-Forwarded-Host", p.Host)
r.Header.Add("X-Forwarded-Uri", uri)
return r
}
func qsDiff(t *testing.T, one, two url.Values) []string {
errs := make([]string, 0)
for k := range one {
if two.Get(k) == "" {
errs = append(errs, fmt.Sprintf("Key missing: %s", k))
}
if one.Get(k) != two.Get(k) {
errs = append(errs, fmt.Sprintf("Value different for %s: expected: '%s' got: '%s'", k, one.Get(k), two.Get(k)))
}
}
for k := range two {
if one.Get(k) == "" {
errs = append(errs, fmt.Sprintf("Extra key: %s", k))
}
}
return errs
}