Multiple provider support + OIDC provider

This commit is contained in:
Thom Seddon
2019-09-18 17:55:52 +01:00
parent 5dfd4f2878
commit 5a9c6adedf
16 changed files with 1043 additions and 278 deletions

View File

@ -1,11 +1,13 @@
package tfa
import (
"net/url"
// "fmt"
"os"
"testing"
"time"
"github.com/sirupsen/logrus"
"github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -28,34 +30,11 @@ func TestConfigDefaults(t *testing.T) {
assert.Equal("_forward_auth", c.CookieName)
assert.Equal("_forward_auth_csrf", c.CSRFCookieName)
assert.Equal("auth", c.DefaultAction)
assert.Equal("google", c.DefaultProvider)
assert.Len(c.Domains, 0)
assert.Equal(time.Second*time.Duration(43200), c.Lifetime)
assert.Equal("/_oauth", c.Path)
assert.Len(c.Whitelist, 0)
assert.Equal("https://www.googleapis.com/auth/userinfo.profile https://www.googleapis.com/auth/userinfo.email", c.Providers.Google.Scope)
assert.Equal("", c.Providers.Google.Prompt)
loginURL := &url.URL{
Scheme: "https",
Host: "accounts.google.com",
Path: "/o/oauth2/auth",
}
assert.Equal(loginURL, c.Providers.Google.LoginURL)
tokenURL := &url.URL{
Scheme: "https",
Host: "www.googleapis.com",
Path: "/oauth2/v3/token",
}
assert.Equal(tokenURL, c.Providers.Google.TokenURL)
userURL := &url.URL{
Scheme: "https",
Host: "www.googleapis.com",
Path: "/oauth2/v2/userinfo",
}
assert.Equal(userURL, c.Providers.Google.UserURL)
}
func TestConfigParseArgs(t *testing.T) {
@ -63,6 +42,7 @@ func TestConfigParseArgs(t *testing.T) {
c, err := NewConfig([]string{
"--cookie-name=cookiename",
"--csrf-cookie-name", "\"csrfcookiename\"",
"--default-provider", "\"oidc\"",
"--rule.1.action=allow",
"--rule.1.rule=PathPrefix(`/one`)",
"--rule.two.action=auth",
@ -73,18 +53,19 @@ func TestConfigParseArgs(t *testing.T) {
// Check normal flags
assert.Equal("cookiename", c.CookieName)
assert.Equal("csrfcookiename", c.CSRFCookieName)
assert.Equal("oidc", c.DefaultProvider)
// Check rules
assert.Equal(map[string]*Rule{
"1": {
Action: "allow",
Rule: "PathPrefix(`/one`)",
Provider: "google",
Provider: "oidc",
},
"two": {
Action: "auth",
Rule: "Host(`two.com`) && Path(`/two`)",
Provider: "google",
Provider: "oidc",
},
}, c.Rules)
}
@ -157,7 +138,7 @@ func TestConfigFlagBackwardsCompatability(t *testing.T) {
// Google provider params used to be top level
assert.Equal("clientid", c.ClientIdLegacy)
assert.Equal("clientid", c.Providers.Google.ClientId, "--client-id should set providers.google.client-id")
assert.Equal("clientid", c.Providers.Google.ClientID, "--client-id should set providers.google.client-id")
assert.Equal("verysecret", c.ClientSecretLegacy)
assert.Equal("verysecret", c.Providers.Google.ClientSecret, "--client-secret should set providers.google.client-secret")
assert.Equal("prompt", c.PromptLegacy)
@ -220,7 +201,7 @@ func TestConfigParseEnvironment(t *testing.T) {
assert.Nil(err)
assert.Equal("env_cookie_name", c.CookieName, "variable should be read from environment")
assert.Equal("env_client_id", c.Providers.Google.ClientId, "namespace variable should be read from environment")
assert.Equal("env_client_id", c.Providers.Google.ClientID, "namespace variable should be read from environment")
os.Unsetenv("COOKIE_NAME")
os.Unsetenv("PROVIDERS_GOOGLE_CLIENT_ID")
@ -265,7 +246,7 @@ func TestConfigParseEnvironmentBackwardsCompatability(t *testing.T) {
// Google provider params used to be top level
assert.Equal("clientid", c.ClientIdLegacy)
assert.Equal("clientid", c.Providers.Google.ClientId, "--client-id should set providers.google.client-id")
assert.Equal("clientid", c.Providers.Google.ClientID, "--client-id should set providers.google.client-id")
assert.Equal("verysecret", c.ClientSecretLegacy)
assert.Equal("verysecret", c.Providers.Google.ClientSecret, "--client-secret should set providers.google.client-secret")
assert.Equal("prompt", c.PromptLegacy)
@ -305,6 +286,92 @@ func TestConfigTransformation(t *testing.T) {
assert.Equal(time.Second*time.Duration(200), c.Lifetime, "lifetime should be read and converted to duration")
}
func TestConfigValidate(t *testing.T) {
assert := assert.New(t)
// Install new logger + hook
var hook *test.Hook
log, hook = test.NewNullLogger()
log.ExitFunc = func(code int) {}
// Validate defualt config + rule error
c, _ := NewConfig([]string{
"--rule.1.action=bad",
})
c.Validate()
logs := hook.AllEntries()
assert.Len(logs, 3)
// Should have fatal error requiring secret
assert.Equal("\"secret\" option must be set", logs[0].Message)
assert.Equal(logrus.FatalLevel, logs[0].Level)
// Should also have default provider (google) error
assert.Equal("providers.google.client-id, providers.google.client-secret must be set", logs[1].Message)
assert.Equal(logrus.FatalLevel, logs[1].Level)
// Should validate rule
assert.Equal("invalid rule action, must be \"auth\" or \"allow\"", logs[2].Message)
assert.Equal(logrus.FatalLevel, logs[2].Level)
hook.Reset()
// Validate with invalid providers
c, _ = NewConfig([]string{
"--secret=veryverysecret",
"--providers.google.client-id=id",
"--providers.google.client-secret=secret",
"--rule.1.action=auth",
"--rule.1.provider=bad2",
})
c.Validate()
logs = hook.AllEntries()
assert.Len(logs, 1)
// Should have error for rule provider
assert.Equal("Unknown provider: bad2", logs[0].Message)
assert.Equal(logrus.FatalLevel, logs[0].Level)
}
func TestConfigGetProvider(t *testing.T) {
assert := assert.New(t)
c, _ := NewConfig([]string{})
// Should be able to get "google" provider
p, err := c.GetProvider("google")
assert.Nil(err)
assert.Equal(&c.Providers.Google, p)
// Should be able to get "oidc" provider
p, err = c.GetProvider("oidc")
assert.Nil(err)
assert.Equal(&c.Providers.OIDC, p)
// Should catch unknown provider
p, err = c.GetProvider("bad")
if assert.Error(err) {
assert.Equal("Unknown provider: bad", err.Error())
}
}
func TestConfigGetConfiguredProvider(t *testing.T) {
assert := assert.New(t)
c, _ := NewConfig([]string{})
// Should be able to get "google" default provider
p, err := c.GetConfiguredProvider("google")
assert.Nil(err)
assert.Equal(&c.Providers.Google, p)
// Should fail to get valid "oidc" provider as it's not configured
p, err = c.GetConfiguredProvider("oidc")
if assert.Error(err) {
assert.Equal("Unconfigured provider: oidc", err.Error())
}
}
func TestConfigCommaSeparatedList(t *testing.T) {
assert := assert.New(t)
list := CommaSeparatedList{}