
Prior to this change, the request URI was only ever read from the X-Forwarded-Uri header, which was only set when the container was accessed via the forwardauth middleware. As such, it was necessary to apply the traefik-forward-auth middleware to the traefik-forward-auth container itself when running in auth host mode. This quirk added unnecessary complexity and was a frequent source of configuration issues.
455 lines
14 KiB
Go
455 lines
14 KiB
Go
package tfa
|
|
|
|
import (
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/thomseddon/traefik-forward-auth/internal/provider"
|
|
)
|
|
|
|
/**
|
|
* Tests
|
|
*/
|
|
|
|
func TestAuthValidateCookie(t *testing.T) {
|
|
assert := assert.New(t)
|
|
config, _ = NewConfig([]string{})
|
|
r, _ := http.NewRequest("GET", "http://example.com", nil)
|
|
c := &http.Cookie{}
|
|
|
|
// Should require 3 parts
|
|
c.Value = ""
|
|
_, err := ValidateCookie(r, c)
|
|
if assert.Error(err) {
|
|
assert.Equal("Invalid cookie format", err.Error())
|
|
}
|
|
c.Value = "1|2"
|
|
_, err = ValidateCookie(r, c)
|
|
if assert.Error(err) {
|
|
assert.Equal("Invalid cookie format", err.Error())
|
|
}
|
|
c.Value = "1|2|3|4"
|
|
_, err = ValidateCookie(r, c)
|
|
if assert.Error(err) {
|
|
assert.Equal("Invalid cookie format", err.Error())
|
|
}
|
|
|
|
// Should catch invalid mac
|
|
c.Value = "MQ==|2|3"
|
|
_, err = ValidateCookie(r, c)
|
|
if assert.Error(err) {
|
|
assert.Equal("Invalid cookie mac", err.Error())
|
|
}
|
|
|
|
// Should catch expired
|
|
config.Lifetime = time.Second * time.Duration(-1)
|
|
c = MakeCookie(r, "test@test.com")
|
|
_, err = ValidateCookie(r, c)
|
|
if assert.Error(err) {
|
|
assert.Equal("Cookie has expired", err.Error())
|
|
}
|
|
|
|
// Should accept valid cookie
|
|
config.Lifetime = time.Second * time.Duration(10)
|
|
c = MakeCookie(r, "test@test.com")
|
|
email, err := ValidateCookie(r, c)
|
|
assert.Nil(err, "valid request should not return an error")
|
|
assert.Equal("test@test.com", email, "valid request should return user email")
|
|
}
|
|
|
|
func TestAuthValidateEmail(t *testing.T) {
|
|
assert := assert.New(t)
|
|
config, _ = NewConfig([]string{})
|
|
|
|
// Should allow any with no whitelist/domain is specified
|
|
v := ValidateEmail("test@test.com", "default")
|
|
assert.True(v, "should allow any domain if email domain is not defined")
|
|
v = ValidateEmail("one@two.com", "default")
|
|
assert.True(v, "should allow any domain if email domain is not defined")
|
|
|
|
// Should allow matching domain
|
|
config.Domains = []string{"test.com"}
|
|
v = ValidateEmail("one@two.com", "default")
|
|
assert.False(v, "should not allow user from another domain")
|
|
v = ValidateEmail("test@test.com", "default")
|
|
assert.True(v, "should allow user from allowed domain")
|
|
|
|
// Should allow matching whitelisted email address
|
|
config.Domains = []string{}
|
|
config.Whitelist = []string{"test@test.com"}
|
|
v = ValidateEmail("one@two.com", "default")
|
|
assert.False(v, "should not allow user not in whitelist")
|
|
v = ValidateEmail("test@test.com", "default")
|
|
assert.True(v, "should allow user in whitelist")
|
|
|
|
// Should allow only matching email address when
|
|
// MatchWhitelistOrDomain is disabled
|
|
config.Domains = []string{"example.com"}
|
|
config.Whitelist = []string{"test@test.com"}
|
|
config.MatchWhitelistOrDomain = false
|
|
v = ValidateEmail("one@two.com", "default")
|
|
assert.False(v, "should not allow user not in either")
|
|
v = ValidateEmail("test@example.com", "default")
|
|
assert.False(v, "should not allow user from allowed domain")
|
|
v = ValidateEmail("test@test.com", "default")
|
|
assert.True(v, "should allow user in whitelist")
|
|
|
|
// Should allow either matching domain or email address when
|
|
// MatchWhitelistOrDomain is enabled
|
|
config.Domains = []string{"example.com"}
|
|
config.Whitelist = []string{"test@test.com"}
|
|
config.MatchWhitelistOrDomain = true
|
|
v = ValidateEmail("one@two.com", "default")
|
|
assert.False(v, "should not allow user not in either")
|
|
v = ValidateEmail("test@example.com", "default")
|
|
assert.True(v, "should allow user from allowed domain")
|
|
v = ValidateEmail("test@test.com", "default")
|
|
assert.True(v, "should allow user in whitelist")
|
|
|
|
// Rule testing
|
|
|
|
// Should use global whitelist/domain when not specified on rule
|
|
config.Domains = []string{"example.com"}
|
|
config.Whitelist = []string{"test@test.com"}
|
|
config.Rules = map[string]*Rule{"test": NewRule()}
|
|
config.MatchWhitelistOrDomain = true
|
|
v = ValidateEmail("one@two.com", "test")
|
|
assert.False(v, "should not allow user not in either")
|
|
v = ValidateEmail("test@example.com", "test")
|
|
assert.True(v, "should allow user from allowed global domain")
|
|
v = ValidateEmail("test@test.com", "test")
|
|
assert.True(v, "should allow user in global whitelist")
|
|
|
|
// Should allow matching domain in rule
|
|
config.Domains = []string{"testglobal.com"}
|
|
config.Whitelist = []string{}
|
|
rule := NewRule()
|
|
config.Rules = map[string]*Rule{"test": rule}
|
|
rule.Domains = []string{"testrule.com"}
|
|
config.MatchWhitelistOrDomain = false
|
|
v = ValidateEmail("one@two.com", "test")
|
|
assert.False(v, "should not allow user from another domain")
|
|
v = ValidateEmail("one@testglobal.com", "test")
|
|
assert.False(v, "should not allow user from global domain")
|
|
v = ValidateEmail("test@testrule.com", "test")
|
|
assert.True(v, "should allow user from allowed domain")
|
|
|
|
// Should allow matching whitelist in rule
|
|
config.Domains = []string{}
|
|
config.Whitelist = []string{"test@testglobal.com"}
|
|
rule = NewRule()
|
|
config.Rules = map[string]*Rule{"test": rule}
|
|
rule.Whitelist = []string{"test@testrule.com"}
|
|
config.MatchWhitelistOrDomain = false
|
|
v = ValidateEmail("one@two.com", "test")
|
|
assert.False(v, "should not allow user from another domain")
|
|
v = ValidateEmail("test@testglobal.com", "test")
|
|
assert.False(v, "should not allow user from global domain")
|
|
v = ValidateEmail("test@testrule.com", "test")
|
|
assert.True(v, "should allow user from allowed domain")
|
|
|
|
// Should allow only matching email address when
|
|
// MatchWhitelistOrDomain is disabled
|
|
config.Domains = []string{"exampleglobal.com"}
|
|
config.Whitelist = []string{"test@testglobal.com"}
|
|
rule = NewRule()
|
|
config.Rules = map[string]*Rule{"test": rule}
|
|
rule.Domains = []string{"examplerule.com"}
|
|
rule.Whitelist = []string{"test@testrule.com"}
|
|
config.MatchWhitelistOrDomain = false
|
|
v = ValidateEmail("one@two.com", "test")
|
|
assert.False(v, "should not allow user not in either")
|
|
v = ValidateEmail("test@testglobal.com", "test")
|
|
assert.False(v, "should not allow user in global whitelist")
|
|
v = ValidateEmail("test@exampleglobal.com", "test")
|
|
assert.False(v, "should not allow user from global domain")
|
|
v = ValidateEmail("test@examplerule.com", "test")
|
|
assert.False(v, "should not allow user from allowed domain")
|
|
v = ValidateEmail("test@testrule.com", "test")
|
|
assert.True(v, "should allow user in whitelist")
|
|
|
|
// Should allow either matching domain or email address when
|
|
// MatchWhitelistOrDomain is enabled
|
|
config.Domains = []string{"exampleglobal.com"}
|
|
config.Whitelist = []string{"test@testglobal.com"}
|
|
rule = NewRule()
|
|
config.Rules = map[string]*Rule{"test": rule}
|
|
rule.Domains = []string{"examplerule.com"}
|
|
rule.Whitelist = []string{"test@testrule.com"}
|
|
config.MatchWhitelistOrDomain = true
|
|
v = ValidateEmail("one@two.com", "test")
|
|
assert.False(v, "should not allow user not in either")
|
|
v = ValidateEmail("test@testglobal.com", "test")
|
|
assert.False(v, "should not allow user in global whitelist")
|
|
v = ValidateEmail("test@exampleglobal.com", "test")
|
|
assert.False(v, "should not allow user from global domain")
|
|
v = ValidateEmail("test@examplerule.com", "test")
|
|
assert.True(v, "should allow user from allowed domain")
|
|
v = ValidateEmail("test@testrule.com", "test")
|
|
assert.True(v, "should allow user in whitelist")
|
|
}
|
|
|
|
func TestRedirectUri(t *testing.T) {
|
|
assert := assert.New(t)
|
|
|
|
r := httptest.NewRequest("GET", "http://app.example.com/hello", nil)
|
|
r.Header.Add("X-Forwarded-Proto", "http")
|
|
|
|
//
|
|
// No Auth Host
|
|
//
|
|
config, _ = NewConfig([]string{})
|
|
|
|
uri, err := url.Parse(redirectUri(r))
|
|
assert.Nil(err)
|
|
assert.Equal("http", uri.Scheme)
|
|
assert.Equal("app.example.com", uri.Host)
|
|
assert.Equal("/_oauth", uri.Path)
|
|
|
|
//
|
|
// With Auth URL but no matching cookie domain
|
|
// - will not use auth host
|
|
//
|
|
config.AuthHost = "auth.example.com"
|
|
|
|
uri, err = url.Parse(redirectUri(r))
|
|
assert.Nil(err)
|
|
assert.Equal("http", uri.Scheme)
|
|
assert.Equal("app.example.com", uri.Host)
|
|
assert.Equal("/_oauth", uri.Path)
|
|
|
|
//
|
|
// With correct Auth URL + cookie domain
|
|
//
|
|
config.AuthHost = "auth.example.com"
|
|
config.CookieDomains = []CookieDomain{*NewCookieDomain("example.com")}
|
|
|
|
// Check url
|
|
uri, err = url.Parse(redirectUri(r))
|
|
assert.Nil(err)
|
|
assert.Equal("http", uri.Scheme)
|
|
assert.Equal("auth.example.com", uri.Host)
|
|
assert.Equal("/_oauth", uri.Path)
|
|
|
|
//
|
|
// With Auth URL + cookie domain, but from different domain
|
|
// - will not use auth host
|
|
//
|
|
r = httptest.NewRequest("GET", "https://another.com/hello", nil)
|
|
r.Header.Add("X-Forwarded-Proto", "https")
|
|
|
|
config.AuthHost = "auth.example.com"
|
|
config.CookieDomains = []CookieDomain{*NewCookieDomain("example.com")}
|
|
|
|
// Check url
|
|
uri, err = url.Parse(redirectUri(r))
|
|
assert.Nil(err)
|
|
assert.Equal("https", uri.Scheme)
|
|
assert.Equal("another.com", uri.Host)
|
|
assert.Equal("/_oauth", uri.Path)
|
|
}
|
|
|
|
func TestAuthMakeCookie(t *testing.T) {
|
|
assert := assert.New(t)
|
|
config, _ = NewConfig([]string{})
|
|
r, _ := http.NewRequest("GET", "http://app.example.com", nil)
|
|
r.Header.Add("X-Forwarded-Host", "app.example.com")
|
|
|
|
c := MakeCookie(r, "test@example.com")
|
|
assert.Equal("_forward_auth", c.Name)
|
|
parts := strings.Split(c.Value, "|")
|
|
assert.Len(parts, 3, "cookie should be 3 parts")
|
|
_, err := ValidateCookie(r, c)
|
|
assert.Nil(err, "should generate valid cookie")
|
|
assert.Equal("/", c.Path)
|
|
assert.Equal("app.example.com", c.Domain)
|
|
assert.True(c.Secure)
|
|
|
|
expires := time.Now().Local().Add(config.Lifetime)
|
|
assert.WithinDuration(expires, c.Expires, 10*time.Second)
|
|
|
|
config.CookieName = "testname"
|
|
config.InsecureCookie = true
|
|
c = MakeCookie(r, "test@example.com")
|
|
assert.Equal("testname", c.Name)
|
|
assert.False(c.Secure)
|
|
}
|
|
|
|
func TestAuthMakeCSRFCookie(t *testing.T) {
|
|
assert := assert.New(t)
|
|
config, _ = NewConfig([]string{})
|
|
r, _ := http.NewRequest("GET", "http://app.example.com", nil)
|
|
r.Header.Add("X-Forwarded-Host", "app.example.com")
|
|
|
|
// No cookie domain or auth url
|
|
c := MakeCSRFCookie(r, "12345678901234567890123456789012")
|
|
assert.Equal("_forward_auth_csrf_123456", c.Name)
|
|
assert.Equal("app.example.com", c.Domain)
|
|
|
|
// With cookie domain but no auth url
|
|
config.CookieDomains = []CookieDomain{*NewCookieDomain("example.com")}
|
|
c = MakeCSRFCookie(r, "12222278901234567890123456789012")
|
|
assert.Equal("_forward_auth_csrf_122222", c.Name)
|
|
assert.Equal("app.example.com", c.Domain)
|
|
|
|
// With cookie domain and auth url
|
|
config.AuthHost = "auth.example.com"
|
|
config.CookieDomains = []CookieDomain{*NewCookieDomain("example.com")}
|
|
c = MakeCSRFCookie(r, "12333378901234567890123456789012")
|
|
assert.Equal("_forward_auth_csrf_123333", c.Name)
|
|
assert.Equal("example.com", c.Domain)
|
|
}
|
|
|
|
func TestAuthClearCSRFCookie(t *testing.T) {
|
|
assert := assert.New(t)
|
|
config, _ = NewConfig([]string{})
|
|
r, _ := http.NewRequest("GET", "http://example.com", nil)
|
|
|
|
c := ClearCSRFCookie(r, &http.Cookie{Name: "someCsrfCookie"})
|
|
assert.Equal("someCsrfCookie", c.Name)
|
|
if c.Value != "" {
|
|
t.Error("ClearCSRFCookie should create cookie with empty value")
|
|
}
|
|
}
|
|
|
|
func TestAuthValidateCSRFCookie(t *testing.T) {
|
|
assert := assert.New(t)
|
|
config, _ = NewConfig([]string{})
|
|
c := &http.Cookie{}
|
|
state := ""
|
|
|
|
// Should require 32 char string
|
|
state = ""
|
|
c.Value = ""
|
|
valid, _, _, err := ValidateCSRFCookie(c, state)
|
|
assert.False(valid)
|
|
if assert.Error(err) {
|
|
assert.Equal("Invalid CSRF cookie value", err.Error())
|
|
}
|
|
c.Value = "123456789012345678901234567890123"
|
|
valid, _, _, err = ValidateCSRFCookie(c, state)
|
|
assert.False(valid)
|
|
if assert.Error(err) {
|
|
assert.Equal("Invalid CSRF cookie value", err.Error())
|
|
}
|
|
|
|
// Should require provider
|
|
state = "12345678901234567890123456789012:99"
|
|
c.Value = "12345678901234567890123456789012"
|
|
valid, _, _, err = ValidateCSRFCookie(c, state)
|
|
assert.False(valid)
|
|
if assert.Error(err) {
|
|
assert.Equal("Invalid CSRF state format", err.Error())
|
|
}
|
|
|
|
// Should allow valid state
|
|
state = "12345678901234567890123456789012:p99:url123"
|
|
c.Value = "12345678901234567890123456789012"
|
|
valid, provider, redirect, err := ValidateCSRFCookie(c, state)
|
|
assert.True(valid, "valid request should return valid")
|
|
assert.Nil(err, "valid request should not return an error")
|
|
assert.Equal("p99", provider, "valid request should return correct provider")
|
|
assert.Equal("url123", redirect, "valid request should return correct redirect")
|
|
}
|
|
|
|
func TestValidateState(t *testing.T) {
|
|
assert := assert.New(t)
|
|
|
|
// Should require valid state
|
|
state := "12345678901234567890123456789012:"
|
|
err := ValidateState(state)
|
|
if assert.Error(err) {
|
|
assert.Equal("Invalid CSRF state value", err.Error())
|
|
}
|
|
// Should pass this state
|
|
state = "12345678901234567890123456789012:p99:url123"
|
|
err = ValidateState(state)
|
|
assert.Nil(err, "valid request should not return an error")
|
|
}
|
|
|
|
func TestMakeState(t *testing.T) {
|
|
assert := assert.New(t)
|
|
|
|
r := httptest.NewRequest("GET", "http://example.com/hello", nil)
|
|
r.Header.Add("X-Forwarded-Proto", "http")
|
|
|
|
// Test with google
|
|
p := provider.Google{}
|
|
state := MakeState(r, &p, "nonce")
|
|
assert.Equal("nonce:google:http://example.com/hello", state)
|
|
|
|
// Test with OIDC
|
|
p2 := provider.OIDC{}
|
|
state = MakeState(r, &p2, "nonce")
|
|
assert.Equal("nonce:oidc:http://example.com/hello", state)
|
|
|
|
// Test with Generic OAuth
|
|
p3 := provider.GenericOAuth{}
|
|
state = MakeState(r, &p3, "nonce")
|
|
assert.Equal("nonce:generic-oauth:http://example.com/hello", state)
|
|
}
|
|
|
|
func TestAuthNonce(t *testing.T) {
|
|
assert := assert.New(t)
|
|
err, nonce1 := Nonce()
|
|
assert.Nil(err, "error generating nonce")
|
|
assert.Len(nonce1, 32, "length should be 32 chars")
|
|
|
|
err, nonce2 := Nonce()
|
|
assert.Nil(err, "error generating nonce")
|
|
assert.Len(nonce2, 32, "length should be 32 chars")
|
|
|
|
assert.NotEqual(nonce1, nonce2, "nonce should not be equal")
|
|
}
|
|
|
|
func TestAuthCookieDomainMatch(t *testing.T) {
|
|
assert := assert.New(t)
|
|
cd := NewCookieDomain("example.com")
|
|
|
|
// Exact should match
|
|
assert.True(cd.Match("example.com"), "exact domain should match")
|
|
|
|
// Subdomain should match
|
|
assert.True(cd.Match("test.example.com"), "subdomain should match")
|
|
assert.True(cd.Match("twolevels.test.example.com"), "subdomain should match")
|
|
assert.True(cd.Match("many.many.levels.test.example.com"), "subdomain should match")
|
|
|
|
// Derived domain should not match
|
|
assert.False(cd.Match("testexample.com"), "derived domain should not match")
|
|
|
|
// Other domain should not match
|
|
assert.False(cd.Match("test.com"), "other domain should not match")
|
|
}
|
|
|
|
func TestAuthCookieDomains(t *testing.T) {
|
|
assert := assert.New(t)
|
|
cds := CookieDomains{}
|
|
|
|
err := cds.UnmarshalFlag("one.com,two.org")
|
|
assert.Nil(err)
|
|
expected := CookieDomains{
|
|
CookieDomain{
|
|
Domain: "one.com",
|
|
DomainLen: 7,
|
|
SubDomain: ".one.com",
|
|
SubDomainLen: 8,
|
|
},
|
|
CookieDomain{
|
|
Domain: "two.org",
|
|
DomainLen: 7,
|
|
SubDomain: ".two.org",
|
|
SubDomainLen: 8,
|
|
},
|
|
}
|
|
assert.Equal(expected, cds)
|
|
|
|
marshal, err := cds.MarshalFlag()
|
|
assert.Nil(err)
|
|
assert.Equal("one.com,two.org", marshal)
|
|
}
|