Multiple provider support + OIDC provider

This commit is contained in:
Thom Seddon 2019-09-18 17:55:52 +01:00
parent 5dfd4f2878
commit 5a9c6adedf
16 changed files with 1043 additions and 278 deletions

View File

@ -1,5 +1,5 @@
format:
gofmt -w -s internal/*.go cmd/*.go
gofmt -w -s internal/*.go internal/provider/*.go cmd/*.go
.PHONY: format

5
go.mod
View File

@ -9,6 +9,7 @@ require (
github.com/containous/flaeg v1.4.1 // indirect
github.com/containous/mux v0.0.0-20181024131434-c33f32e26898 // indirect
github.com/containous/traefik v2.0.0-alpha2+incompatible
github.com/coreos/go-oidc v2.1.0+incompatible
github.com/go-acme/lego v2.5.0+incompatible // indirect
github.com/go-kit/kit v0.8.0 // indirect
github.com/gorilla/context v1.1.1 // indirect
@ -21,6 +22,7 @@ require (
github.com/miekg/dns v1.1.8 // indirect
github.com/patrickmn/go-cache v2.1.0+incompatible // indirect
github.com/pkg/errors v0.8.1 // indirect
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect
github.com/ryanuber/go-glob v1.0.0 // indirect
github.com/sirupsen/logrus v1.4.1
github.com/stretchr/objx v0.2.0 // indirect
@ -29,8 +31,9 @@ require (
github.com/vulcand/predicate v1.1.0 // indirect
golang.org/x/crypto v0.0.0-20190422183909-d864b10871cd // indirect
golang.org/x/net v0.0.0-20190420063019-afa5a82059c6 // indirect
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45
golang.org/x/sync v0.0.0-20190423024810-112230192c58 // indirect
golang.org/x/sys v0.0.0-20190422165155-953cdadca894 // indirect
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
gopkg.in/square/go-jose.v2 v2.3.1 // indirect
gopkg.in/square/go-jose.v2 v2.3.1
)

13
go.sum
View File

@ -1,3 +1,4 @@
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
github.com/VividCortex/gohistogram v1.0.0 h1:6+hBz+qvs0JOrrNhhmR7lFxo5sINxBCGXrdtl/UvroE=
github.com/VividCortex/gohistogram v1.0.0/go.mod h1:Pf5mBqqDxYaXu3hDrrU+w6nw50o/4+TcAqDqk/vUH7g=
github.com/cenkalti/backoff v2.1.1+incompatible h1:tKJnvO2kl0zmb/jA5UKAt4VoEVw1qxKWjE/Bpp46npY=
@ -10,14 +11,18 @@ github.com/containous/mux v0.0.0-20181024131434-c33f32e26898 h1:1srn9voikJGofblB
github.com/containous/mux v0.0.0-20181024131434-c33f32e26898/go.mod h1:z8WW7n06n8/1xF9Jl9WmuDeZuHAhfL+bwarNjsciwwg=
github.com/containous/traefik v2.0.0-alpha2+incompatible h1:5RS6mUAOPQCy1jAmcmxLj2nChIcs3fKuxZxH9AF6ih8=
github.com/containous/traefik v2.0.0-alpha2+incompatible/go.mod h1:epDRqge3JzKOhlSWzOpNYEEKXmM6yfN5tPzDGKk3ljo=
github.com/coreos/go-oidc v2.1.0+incompatible h1:sdJrfw8akMnCuUlaZU3tE/uYXFgfqom8DBE9so9EBsM=
github.com/coreos/go-oidc v2.1.0+incompatible/go.mod h1:CgnwVTmzoESiwO9qyAFEMiHoZ1nMCKZlZ9V6mm3/LKc=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-acme/lego v2.4.0+incompatible h1:+BTLUfLtDc5qQauyiTCXH6lupEUOCvXyGlEjdeU0YQI=
github.com/go-acme/lego v2.4.0+incompatible/go.mod h1:yzMNe9CasVUhkquNvti5nAtPmG94USbYxYrZfTkIn0M=
github.com/go-acme/lego v2.5.0+incompatible h1:5fNN9yRQfv8ymH3DSsxla+4aYeQt2IgfZqHKVnK8f0s=
github.com/go-acme/lego v2.5.0+incompatible/go.mod h1:yzMNe9CasVUhkquNvti5nAtPmG94USbYxYrZfTkIn0M=
github.com/go-kit/kit v0.8.0 h1:Wz+5lgoB0kkuqLEc6NVmwRknTKP6dTGbSqvhZtBI/j0=
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8=
github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
github.com/gravitational/trace v0.0.0-20190409171327-f30095ced5ff h1:xL/fJdlTJL6R/6Qk2tPu3EP1NsXgap9hXLvxKH0Ytko=
@ -45,6 +50,8 @@ github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 h1:J9b7z+QKAmPf4YLrFg6oQUotqHQeUNWwkvo7jZp1GLU=
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35/go.mod h1:prYjPmNq4d1NPVmpShWobRqXY3q7Vp+80DqgxxUrUIA=
github.com/ryanuber/go-glob v1.0.0 h1:iQh3xXAumdQ+4Ufa5b25cRpC5TYKlno6hsv6Cb3pkBk=
github.com/ryanuber/go-glob v1.0.0/go.mod h1:807d1WSdnB0XRJzKNil9Om6lcp/3a0v4qIHxIXzX/Yc=
github.com/sirupsen/logrus v1.4.1 h1:GL2rEmy6nsikmW0r8opw9JIRScdMF5hA8cOYLH7In1k=
@ -68,12 +75,17 @@ golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a h1:Igim7XhdOpBnWPuYJ70XcN
golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE=
golang.org/x/crypto v0.0.0-20190422183909-d864b10871cd h1:sMHc2rZHuzQmrbVoSpt9HgerkXPyIeCSO6k0zUMGfFk=
golang.org/x/crypto v0.0.0-20190422183909-d864b10871cd/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 h1:0GoQqolDA55aaLxZyTzK/Y2ePZzZTUrRacwib7cNsYQ=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190420063019-afa5a82059c6 h1:HdqqaWmYAUI7/dmByKKEw+yxDksGSo+9GjkUc9Zp34E=
golang.org/x/net v0.0.0-20190420063019-afa5a82059c6/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45 h1:SVwTIAaPC2U/AvvLNZ2a7OVsmBpC8L5BlwK1whH3hm0=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f h1:wMNYb4v58l5UBM7MYRLPG6ZhfOqbKu7X5eyFl8ZhKvA=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33 h1:I6FyU15t786LL7oL/hn43zqTuEGr4PN7F4XJ1p4E3Y8=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@ -84,6 +96,7 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20190422165155-953cdadca894 h1:Cz4ceDQGXuKRnVBDTS23GTn/pU5OE2C0WrNTOYK1Uuc=
golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/square/go-jose.v2 v2.3.1 h1:SK5KegNXmKmqE342YYN2qPHEnUYeoMiXXl1poUlI+o4=

View File

@ -81,32 +81,6 @@ func ValidateEmail(email string) bool {
return found
}
// OAuth Methods
// Get login url
func GetLoginURL(r *http.Request, nonce string) string {
state := fmt.Sprintf("%s:%s", nonce, returnUrl(r))
// TODO: Support multiple providers
return config.Providers.Google.GetLoginURL(redirectUri(r), state)
}
// Exchange code for token
func ExchangeCode(r *http.Request) (string, error) {
code := r.URL.Query().Get("code")
// TODO: Support multiple providers
return config.Providers.Google.ExchangeCode(redirectUri(r), code)
}
// Get user with token
func GetUser(token string) (provider.User, error) {
// TODO: Support multiple providers
return config.Providers.Google.GetUser(token)
}
// Utility methods
// Get the redirect base
@ -117,7 +91,7 @@ func redirectBase(r *http.Request) string {
return fmt.Sprintf("%s://%s", proto, host)
}
// // Return url
// Return url
func returnUrl(r *http.Request) string {
path := r.Header.Get("X-Forwarded-Uri")
@ -196,24 +170,35 @@ func ClearCSRFCookie(r *http.Request) *http.Cookie {
}
// Validate the csrf cookie against state
func ValidateCSRFCookie(r *http.Request, c *http.Cookie) (bool, string, error) {
func ValidateCSRFCookie(r *http.Request, c *http.Cookie) (valid bool, provider string, redirect string, err error) {
state := r.URL.Query().Get("state")
if len(c.Value) != 32 {
return false, "", errors.New("Invalid CSRF cookie value")
return false, "", "", errors.New("Invalid CSRF cookie value")
}
if len(state) < 34 {
return false, "", errors.New("Invalid CSRF state value")
return false, "", "", errors.New("Invalid CSRF state value")
}
// Check nonce match
if c.Value != state[:32] {
return false, "", errors.New("CSRF cookie does not match state")
return false, "", "", errors.New("CSRF cookie does not match state")
}
// Valid, return redirect
return true, state[33:], nil
// Extract provider
params := state[33:]
split := strings.Index(params, ":")
if split == -1 {
return false, "", "", errors.New("Invalid CSRF state format")
}
// Valid, return provider and redirect
return true, params[:split], params[split+1:], nil
}
func MakeState(r *http.Request, p provider.Provider, nonce string) string {
return fmt.Sprintf("%s:%s:%s", nonce, p.Name(), returnUrl(r))
}
func Nonce() (error, string) {
@ -282,10 +267,10 @@ func cookieExpiry() time.Time {
// Cookie Domain
type CookieDomain struct {
Domain string `description:"TEST1"`
DomainLen int `description:"TEST2"`
SubDomain string `description:"TEST3"`
SubDomainLen int `description:"TEST4"`
Domain string
DomainLen int
SubDomain string
SubDomainLen int
}
func NewCookieDomain(domain string) *CookieDomain {

View File

@ -95,139 +95,70 @@ func TestAuthValidateEmail(t *testing.T) {
assert.True(v, "should allow user in whitelist")
}
// TODO: Split google tests out
func TestAuthGetLoginURL(t *testing.T) {
func TestRedirectUri(t *testing.T) {
assert := assert.New(t)
google := provider.Google{
ClientId: "idtest",
ClientSecret: "sectest",
Scope: "scopetest",
Prompt: "consent select_account",
LoginURL: &url.URL{
Scheme: "https",
Host: "test.com",
Path: "/auth",
},
}
config, _ = NewConfig([]string{})
config.Providers.Google = google
r, _ := http.NewRequest("GET", "http://example.com", nil)
r.Header.Add("X-Forwarded-Proto", "http")
r.Header.Add("X-Forwarded-Host", "example.com")
r.Header.Add("X-Forwarded-Host", "app.example.com")
r.Header.Add("X-Forwarded-Uri", "/hello")
// Check url
uri, err := url.Parse(GetLoginURL(r, "nonce"))
assert.Nil(err)
assert.Equal("https", uri.Scheme)
assert.Equal("test.com", uri.Host)
assert.Equal("/auth", uri.Path)
//
// No Auth Host
//
config, _ = NewConfig([]string{})
// Check query string
qs := uri.Query()
expectedQs := url.Values{
"client_id": []string{"idtest"},
"redirect_uri": []string{"http://example.com/_oauth"},
"response_type": []string{"code"},
"scope": []string{"scopetest"},
"prompt": []string{"consent select_account"},
"state": []string{"nonce:http://example.com/hello"},
}
assert.Equal(expectedQs, qs)
uri, err := url.Parse(redirectUri(r))
assert.Nil(err)
assert.Equal("http", uri.Scheme)
assert.Equal("app.example.com", uri.Host)
assert.Equal("/_oauth", uri.Path)
//
// With Auth URL but no matching cookie domain
// - will not use auth host
//
config, _ = NewConfig([]string{})
config.AuthHost = "auth.example.com"
config.Providers.Google = google
// Check url
uri, err = url.Parse(GetLoginURL(r, "nonce"))
uri, err = url.Parse(redirectUri(r))
assert.Nil(err)
assert.Equal("https", uri.Scheme)
assert.Equal("test.com", uri.Host)
assert.Equal("/auth", uri.Path)
// Check query string
qs = uri.Query()
expectedQs = url.Values{
"client_id": []string{"idtest"},
"redirect_uri": []string{"http://example.com/_oauth"},
"response_type": []string{"code"},
"scope": []string{"scopetest"},
"prompt": []string{"consent select_account"},
"state": []string{"nonce:http://example.com/hello"},
}
assert.Equal(expectedQs, qs)
assert.Equal("http", uri.Scheme)
assert.Equal("app.example.com", uri.Host)
assert.Equal("/_oauth", uri.Path)
//
// With correct Auth URL + cookie domain
//
config, _ = NewConfig([]string{})
config.AuthHost = "auth.example.com"
config.CookieDomains = []CookieDomain{*NewCookieDomain("example.com")}
config.Providers.Google = google
// Check url
uri, err = url.Parse(GetLoginURL(r, "nonce"))
uri, err = url.Parse(redirectUri(r))
assert.Nil(err)
assert.Equal("https", uri.Scheme)
assert.Equal("test.com", uri.Host)
assert.Equal("/auth", uri.Path)
// Check query string
qs = uri.Query()
expectedQs = url.Values{
"client_id": []string{"idtest"},
"redirect_uri": []string{"http://auth.example.com/_oauth"},
"response_type": []string{"code"},
"scope": []string{"scopetest"},
"state": []string{"nonce:http://example.com/hello"},
"prompt": []string{"consent select_account"},
}
assert.Equal(expectedQs, qs)
assert.Equal("http", uri.Scheme)
assert.Equal("auth.example.com", uri.Host)
assert.Equal("/_oauth", uri.Path)
//
// With Auth URL + cookie domain, but from different domain
// - will not use auth host
//
r, _ = http.NewRequest("GET", "http://another.com", nil)
r.Header.Add("X-Forwarded-Proto", "http")
r.Header.Add("X-Forwarded-Proto", "https")
r.Header.Add("X-Forwarded-Host", "another.com")
r.Header.Add("X-Forwarded-Uri", "/hello")
config.AuthHost = "auth.example.com"
config.CookieDomains = []CookieDomain{*NewCookieDomain("example.com")}
// Check url
uri, err = url.Parse(GetLoginURL(r, "nonce"))
uri, err = url.Parse(redirectUri(r))
assert.Nil(err)
assert.Equal("https", uri.Scheme)
assert.Equal("test.com", uri.Host)
assert.Equal("/auth", uri.Path)
// Check query string
qs = uri.Query()
expectedQs = url.Values{
"client_id": []string{"idtest"},
"redirect_uri": []string{"http://another.com/_oauth"},
"response_type": []string{"code"},
"scope": []string{"scopetest"},
"state": []string{"nonce:http://another.com/hello"},
"prompt": []string{"consent select_account"},
}
assert.Equal(expectedQs, qs)
assert.Equal("another.com", uri.Host)
assert.Equal("/_oauth", uri.Path)
}
// TODO
// func TestAuthExchangeCode(t *testing.T) {
// }
// TODO
// func TestAuthGetUser(t *testing.T) {
// }
func TestAuthMakeCookie(t *testing.T) {
assert := assert.New(t)
config, _ = NewConfig([]string{})
@ -265,14 +196,14 @@ func TestAuthMakeCSRFCookie(t *testing.T) {
assert.Equal("app.example.com", c.Domain)
// With cookie domain but no auth url
config = Config{
config = &Config{
CookieDomains: []CookieDomain{*NewCookieDomain("example.com")},
}
c = MakeCSRFCookie(r, "12345678901234567890123456789012")
assert.Equal("app.example.com", c.Domain)
// With cookie domain and auth url
config = Config{
config = &Config{
AuthHost: "auth.example.com",
CookieDomains: []CookieDomain{*NewCookieDomain("example.com")},
}
@ -304,13 +235,13 @@ func TestAuthValidateCSRFCookie(t *testing.T) {
// Should require 32 char string
r := newCsrfRequest("")
c.Value = ""
valid, _, err := ValidateCSRFCookie(r, c)
valid, _, _, err := ValidateCSRFCookie(r, c)
assert.False(valid)
if assert.Error(err) {
assert.Equal("Invalid CSRF cookie value", err.Error())
}
c.Value = "123456789012345678901234567890123"
valid, _, err = ValidateCSRFCookie(r, c)
valid, _, _, err = ValidateCSRFCookie(r, c)
assert.False(valid)
if assert.Error(err) {
assert.Equal("Invalid CSRF cookie value", err.Error())
@ -319,19 +250,48 @@ func TestAuthValidateCSRFCookie(t *testing.T) {
// Should require valid state
r = newCsrfRequest("12345678901234567890123456789012:")
c.Value = "12345678901234567890123456789012"
valid, _, err = ValidateCSRFCookie(r, c)
valid, _, _, err = ValidateCSRFCookie(r, c)
assert.False(valid)
if assert.Error(err) {
assert.Equal("Invalid CSRF state value", err.Error())
}
// Should allow valid state
// Should require provider
r = newCsrfRequest("12345678901234567890123456789012:99")
c.Value = "12345678901234567890123456789012"
valid, state, err := ValidateCSRFCookie(r, c)
valid, _, _, err = ValidateCSRFCookie(r, c)
assert.False(valid)
if assert.Error(err) {
assert.Equal("Invalid CSRF state format", err.Error())
}
// Should allow valid state
r = newCsrfRequest("12345678901234567890123456789012:p99:url123")
c.Value = "12345678901234567890123456789012"
valid, provider, redirect, err := ValidateCSRFCookie(r, c)
assert.True(valid, "valid request should return valid")
assert.Nil(err, "valid request should not return an error")
assert.Equal("99", state, "valid request should return correct state")
assert.Equal("p99", provider, "valid request should return correct provider")
assert.Equal("url123", redirect, "valid request should return correct redirect")
}
func TestMakeState(t *testing.T) {
assert := assert.New(t)
r, _ := http.NewRequest("GET", "http://example.com", nil)
r.Header.Add("X-Forwarded-Proto", "http")
r.Header.Add("X-Forwarded-Host", "example.com")
r.Header.Add("X-Forwarded-Uri", "/hello")
// Test with google
p := provider.Google{}
state := MakeState(r, &p, "nonce")
assert.Equal("nonce:google:http://example.com/hello", state)
// Test with OIDC
p2 := provider.OIDC{}
state = MakeState(r, &p2, "nonce")
assert.Equal("nonce:oidc:http://example.com/hello", state)
}
func TestAuthNonce(t *testing.T) {
@ -356,6 +316,8 @@ func TestAuthCookieDomainMatch(t *testing.T) {
// Subdomain should match
assert.True(cd.Match("test.example.com"), "subdomain should match")
assert.True(cd.Match("twolevels.test.example.com"), "subdomain should match")
assert.True(cd.Match("many.many.levels.test.example.com"), "subdomain should match")
// Derived domain should not match
assert.False(cd.Match("testexample.com"), "derived domain should not match")

View File

@ -7,7 +7,6 @@ import (
"fmt"
"io"
"io/ioutil"
"net/url"
"os"
"regexp"
"strconv"
@ -18,24 +17,25 @@ import (
"github.com/thomseddon/traefik-forward-auth/internal/provider"
)
var config Config
var config *Config
type Config struct {
LogLevel string `long:"log-level" env:"LOG_LEVEL" default:"warn" choice:"trace" choice:"debug" choice:"info" choice:"warn" choice:"error" choice:"fatal" choice:"panic" description:"Log level"`
LogFormat string `long:"log-format" env:"LOG_FORMAT" default:"text" choice:"text" choice:"json" choice:"pretty" description:"Log format"`
AuthHost string `long:"auth-host" env:"AUTH_HOST" description:"Single host to use when returning from 3rd party auth"`
Config func(s string) error `long:"config" env:"CONFIG" description:"Path to config file" json:"-"`
CookieDomains []CookieDomain `long:"cookie-domain" env:"COOKIE_DOMAIN" description:"Domain to set auth cookie on, can be set multiple times"`
InsecureCookie bool `long:"insecure-cookie" env:"INSECURE_COOKIE" description:"Use insecure cookies"`
CookieName string `long:"cookie-name" env:"COOKIE_NAME" default:"_forward_auth" description:"Cookie Name"`
CSRFCookieName string `long:"csrf-cookie-name" env:"CSRF_COOKIE_NAME" default:"_forward_auth_csrf" description:"CSRF Cookie Name"`
DefaultAction string `long:"default-action" env:"DEFAULT_ACTION" default:"auth" choice:"auth" choice:"allow" description:"Default action"`
Domains CommaSeparatedList `long:"domain" env:"DOMAIN" description:"Only allow given email domains, can be set multiple times"`
LifetimeString int `long:"lifetime" env:"LIFETIME" default:"43200" description:"Lifetime in seconds"`
Path string `long:"url-path" env:"URL_PATH" default:"/_oauth" description:"Callback URL Path"`
SecretString string `long:"secret" env:"SECRET" description:"Secret used for signing (required)" json:"-"`
Whitelist CommaSeparatedList `long:"whitelist" env:"WHITELIST" description:"Only allow given email addresses, can be set multiple times"`
AuthHost string `long:"auth-host" env:"AUTH_HOST" description:"Single host to use when returning from 3rd party auth"`
Config func(s string) error `long:"config" env:"CONFIG" description:"Path to config file" json:"-"`
CookieDomains []CookieDomain `long:"cookie-domain" env:"COOKIE_DOMAIN" description:"Domain to set auth cookie on, can be set multiple times"`
InsecureCookie bool `long:"insecure-cookie" env:"INSECURE_COOKIE" description:"Use insecure cookies"`
CookieName string `long:"cookie-name" env:"COOKIE_NAME" default:"_forward_auth" description:"Cookie Name"`
CSRFCookieName string `long:"csrf-cookie-name" env:"CSRF_COOKIE_NAME" default:"_forward_auth_csrf" description:"CSRF Cookie Name"`
DefaultAction string `long:"default-action" env:"DEFAULT_ACTION" default:"auth" choice:"auth" choice:"allow" description:"Default action"`
DefaultProvider string `long:"default-provider" env:"DEFAULT_PROVIDER" default:"google" choice:"google" choice:"oidc" description:"Default provider"`
Domains CommaSeparatedList `long:"domain" env:"DOMAIN" description:"Only allow given email domains, can be set multiple times"`
LifetimeString int `long:"lifetime" env:"LIFETIME" default:"43200" description:"Lifetime in seconds"`
Path string `long:"url-path" env:"URL_PATH" default:"/_oauth" description:"Callback URL Path"`
SecretString string `long:"secret" env:"SECRET" description:"Secret used for signing (required)" json:"-"`
Whitelist CommaSeparatedList `long:"whitelist" env:"WHITELIST" description:"Only allow given email addresses, can be set multiple times"`
Providers provider.Providers `group:"providers" namespace:"providers" env-namespace:"PROVIDERS"`
Rules map[string]*Rule `long:"rule.<name>.<param>" description:"Rule definitions, param can be: \"action\" or \"rule\""`
@ -53,7 +53,7 @@ type Config struct {
PromptLegacy string `long:"prompt" env:"PROMPT" description:"DEPRECATED - Use \"providers.google.prompt\""`
}
func NewGlobalConfig() Config {
func NewGlobalConfig() *Config {
var err error
config, err = NewConfig(os.Args[1:])
if err != nil {
@ -64,29 +64,11 @@ func NewGlobalConfig() Config {
return config
}
func NewConfig(args []string) (Config, error) {
c := Config{
// TODO: move config parsing into new func "NewParsedConfig"
func NewConfig(args []string) (*Config, error) {
c := &Config{
Rules: map[string]*Rule{},
Providers: provider.Providers{
Google: provider.Google{
Scope: "https://www.googleapis.com/auth/userinfo.profile https://www.googleapis.com/auth/userinfo.email",
LoginURL: &url.URL{
Scheme: "https",
Host: "accounts.google.com",
Path: "/o/oauth2/auth",
},
TokenURL: &url.URL{
Scheme: "https",
Host: "www.googleapis.com",
Path: "/oauth2/v3/token",
},
UserURL: &url.URL{
Scheme: "https",
Host: "www.googleapis.com",
Path: "/oauth2/v2/userinfo",
},
},
},
}
err := c.parseFlags(args)
@ -97,13 +79,23 @@ func NewConfig(args []string) (Config, error) {
// TODO: as log flags have now been parsed maybe we should return here so
// any further errors can be logged via logrus instead of printed?
// TODO: Rename "Validate" method to "Setup" and move all below logic
// Setup
// Set default provider on any rules where it's not specified
for _, rule := range c.Rules {
if rule.Provider == "" {
rule.Provider = c.DefaultProvider
}
}
// Backwards compatability
if c.CookieSecretLegacy != "" && c.SecretString == "" {
fmt.Println("cookie-secret config option is deprecated, please use secret")
c.SecretString = c.CookieSecretLegacy
}
if c.ClientIdLegacy != "" {
c.Providers.Google.ClientId = c.ClientIdLegacy
c.Providers.Google.ClientID = c.ClientIdLegacy
}
if c.ClientSecretLegacy != "" {
c.Providers.Google.ClientSecret = c.ClientSecretLegacy
@ -247,16 +239,21 @@ func convertLegacyToIni(name string) (io.Reader, error) {
func (c *Config) Validate() {
// Check for show stopper errors
if len(c.Secret) == 0 {
log.Fatal("\"secret\" option must be set.")
log.Fatal("\"secret\" option must be set")
}
if c.Providers.Google.ClientId == "" || c.Providers.Google.ClientSecret == "" {
log.Fatal("providers.google.client-id, providers.google.client-secret must be set")
// Setup default provider
err := c.setupProvider(c.DefaultProvider)
if err != nil {
log.Fatal(err)
}
// Check rules
// Check rules (validates the rule and the rule provider)
for _, rule := range c.Rules {
rule.Validate()
err = rule.Validate(c)
if err != nil {
log.Fatal(err)
}
}
}
@ -265,6 +262,61 @@ func (c Config) String() string {
return string(jsonConf)
}
// GetProvider returns the provider of the given name
func (c *Config) GetProvider(name string) (provider.Provider, error) {
switch name {
case "google":
return &c.Providers.Google, nil
case "oidc":
return &c.Providers.OIDC, nil
}
return nil, fmt.Errorf("Unknown provider: %s", name)
}
// GetConfiguredProvider returns the provider of the given name, if it has been
// configured. Returns an error if the provider is unknown, or hasn't been configured
func (c *Config) GetConfiguredProvider(name string) (provider.Provider, error) {
// Check the provider has been configured
if !c.providerConfigured(name) {
return nil, fmt.Errorf("Unconfigured provider: %s", name)
}
return c.GetProvider(name)
}
func (c *Config) providerConfigured(name string) bool {
// Check default provider
if name == c.DefaultProvider {
return true
}
// Check rule providers
for _, rule := range c.Rules {
if name == rule.Provider {
return true
}
}
return false
}
func (c *Config) setupProvider(name string) error {
// Check provider exists
p, err := c.GetProvider(name)
if err != nil {
return err
}
// Setup
err = p.Setup()
if err != nil {
return err
}
return nil
}
type Rule struct {
Action string
Rule string
@ -273,8 +325,7 @@ type Rule struct {
func NewRule() *Rule {
return &Rule{
Action: "auth",
Provider: "google", // TODO: Use default provider
Action: "auth",
}
}
@ -284,15 +335,12 @@ func (r *Rule) formattedRule() string {
return strings.ReplaceAll(r.Rule, "Host(", "HostRegexp(")
}
func (r *Rule) Validate() {
func (r *Rule) Validate(c *Config) error {
if r.Action != "auth" && r.Action != "allow" {
log.Fatal("invalid rule action, must be \"auth\" or \"allow\"")
return errors.New("invalid rule action, must be \"auth\" or \"allow\"")
}
// TODO: Update with more provider support
if r.Provider != "google" {
log.Fatal("invalid rule provider, must be \"google\"")
}
return c.setupProvider(r.Provider)
}
// Legacy support for comma separated lists

View File

@ -1,11 +1,13 @@
package tfa
import (
"net/url"
// "fmt"
"os"
"testing"
"time"
"github.com/sirupsen/logrus"
"github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -28,34 +30,11 @@ func TestConfigDefaults(t *testing.T) {
assert.Equal("_forward_auth", c.CookieName)
assert.Equal("_forward_auth_csrf", c.CSRFCookieName)
assert.Equal("auth", c.DefaultAction)
assert.Equal("google", c.DefaultProvider)
assert.Len(c.Domains, 0)
assert.Equal(time.Second*time.Duration(43200), c.Lifetime)
assert.Equal("/_oauth", c.Path)
assert.Len(c.Whitelist, 0)
assert.Equal("https://www.googleapis.com/auth/userinfo.profile https://www.googleapis.com/auth/userinfo.email", c.Providers.Google.Scope)
assert.Equal("", c.Providers.Google.Prompt)
loginURL := &url.URL{
Scheme: "https",
Host: "accounts.google.com",
Path: "/o/oauth2/auth",
}
assert.Equal(loginURL, c.Providers.Google.LoginURL)
tokenURL := &url.URL{
Scheme: "https",
Host: "www.googleapis.com",
Path: "/oauth2/v3/token",
}
assert.Equal(tokenURL, c.Providers.Google.TokenURL)
userURL := &url.URL{
Scheme: "https",
Host: "www.googleapis.com",
Path: "/oauth2/v2/userinfo",
}
assert.Equal(userURL, c.Providers.Google.UserURL)
}
func TestConfigParseArgs(t *testing.T) {
@ -63,6 +42,7 @@ func TestConfigParseArgs(t *testing.T) {
c, err := NewConfig([]string{
"--cookie-name=cookiename",
"--csrf-cookie-name", "\"csrfcookiename\"",
"--default-provider", "\"oidc\"",
"--rule.1.action=allow",
"--rule.1.rule=PathPrefix(`/one`)",
"--rule.two.action=auth",
@ -73,18 +53,19 @@ func TestConfigParseArgs(t *testing.T) {
// Check normal flags
assert.Equal("cookiename", c.CookieName)
assert.Equal("csrfcookiename", c.CSRFCookieName)
assert.Equal("oidc", c.DefaultProvider)
// Check rules
assert.Equal(map[string]*Rule{
"1": {
Action: "allow",
Rule: "PathPrefix(`/one`)",
Provider: "google",
Provider: "oidc",
},
"two": {
Action: "auth",
Rule: "Host(`two.com`) && Path(`/two`)",
Provider: "google",
Provider: "oidc",
},
}, c.Rules)
}
@ -157,7 +138,7 @@ func TestConfigFlagBackwardsCompatability(t *testing.T) {
// Google provider params used to be top level
assert.Equal("clientid", c.ClientIdLegacy)
assert.Equal("clientid", c.Providers.Google.ClientId, "--client-id should set providers.google.client-id")
assert.Equal("clientid", c.Providers.Google.ClientID, "--client-id should set providers.google.client-id")
assert.Equal("verysecret", c.ClientSecretLegacy)
assert.Equal("verysecret", c.Providers.Google.ClientSecret, "--client-secret should set providers.google.client-secret")
assert.Equal("prompt", c.PromptLegacy)
@ -220,7 +201,7 @@ func TestConfigParseEnvironment(t *testing.T) {
assert.Nil(err)
assert.Equal("env_cookie_name", c.CookieName, "variable should be read from environment")
assert.Equal("env_client_id", c.Providers.Google.ClientId, "namespace variable should be read from environment")
assert.Equal("env_client_id", c.Providers.Google.ClientID, "namespace variable should be read from environment")
os.Unsetenv("COOKIE_NAME")
os.Unsetenv("PROVIDERS_GOOGLE_CLIENT_ID")
@ -265,7 +246,7 @@ func TestConfigParseEnvironmentBackwardsCompatability(t *testing.T) {
// Google provider params used to be top level
assert.Equal("clientid", c.ClientIdLegacy)
assert.Equal("clientid", c.Providers.Google.ClientId, "--client-id should set providers.google.client-id")
assert.Equal("clientid", c.Providers.Google.ClientID, "--client-id should set providers.google.client-id")
assert.Equal("verysecret", c.ClientSecretLegacy)
assert.Equal("verysecret", c.Providers.Google.ClientSecret, "--client-secret should set providers.google.client-secret")
assert.Equal("prompt", c.PromptLegacy)
@ -305,6 +286,92 @@ func TestConfigTransformation(t *testing.T) {
assert.Equal(time.Second*time.Duration(200), c.Lifetime, "lifetime should be read and converted to duration")
}
func TestConfigValidate(t *testing.T) {
assert := assert.New(t)
// Install new logger + hook
var hook *test.Hook
log, hook = test.NewNullLogger()
log.ExitFunc = func(code int) {}
// Validate defualt config + rule error
c, _ := NewConfig([]string{
"--rule.1.action=bad",
})
c.Validate()
logs := hook.AllEntries()
assert.Len(logs, 3)
// Should have fatal error requiring secret
assert.Equal("\"secret\" option must be set", logs[0].Message)
assert.Equal(logrus.FatalLevel, logs[0].Level)
// Should also have default provider (google) error
assert.Equal("providers.google.client-id, providers.google.client-secret must be set", logs[1].Message)
assert.Equal(logrus.FatalLevel, logs[1].Level)
// Should validate rule
assert.Equal("invalid rule action, must be \"auth\" or \"allow\"", logs[2].Message)
assert.Equal(logrus.FatalLevel, logs[2].Level)
hook.Reset()
// Validate with invalid providers
c, _ = NewConfig([]string{
"--secret=veryverysecret",
"--providers.google.client-id=id",
"--providers.google.client-secret=secret",
"--rule.1.action=auth",
"--rule.1.provider=bad2",
})
c.Validate()
logs = hook.AllEntries()
assert.Len(logs, 1)
// Should have error for rule provider
assert.Equal("Unknown provider: bad2", logs[0].Message)
assert.Equal(logrus.FatalLevel, logs[0].Level)
}
func TestConfigGetProvider(t *testing.T) {
assert := assert.New(t)
c, _ := NewConfig([]string{})
// Should be able to get "google" provider
p, err := c.GetProvider("google")
assert.Nil(err)
assert.Equal(&c.Providers.Google, p)
// Should be able to get "oidc" provider
p, err = c.GetProvider("oidc")
assert.Nil(err)
assert.Equal(&c.Providers.OIDC, p)
// Should catch unknown provider
p, err = c.GetProvider("bad")
if assert.Error(err) {
assert.Equal("Unknown provider: bad", err.Error())
}
}
func TestConfigGetConfiguredProvider(t *testing.T) {
assert := assert.New(t)
c, _ := NewConfig([]string{})
// Should be able to get "google" default provider
p, err := c.GetConfiguredProvider("google")
assert.Nil(err)
assert.Equal(&c.Providers.Google, p)
// Should fail to get valid "oidc" provider as it's not configured
p, err = c.GetConfiguredProvider("oidc")
if assert.Error(err) {
assert.Equal("Unconfigured provider: oidc", err.Error())
}
}
func TestConfigCommaSeparatedList(t *testing.T) {
assert := assert.New(t)
list := CommaSeparatedList{}

View File

@ -6,9 +6,9 @@ import (
"github.com/sirupsen/logrus"
)
var log logrus.FieldLogger
var log *logrus.Logger
func NewDefaultLogger() logrus.FieldLogger {
func NewDefaultLogger() *logrus.Logger {
// Setup logger
log = logrus.StandardLogger()
logrus.SetOutput(os.Stdout)

View File

@ -2,13 +2,15 @@ package provider
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
)
// Google provider
type Google struct {
ClientId string `long:"client-id" env:"CLIENT_ID" description:"Client ID"`
ClientID string `long:"client-id" env:"CLIENT_ID" description:"Client ID"`
ClientSecret string `long:"client-secret" env:"CLIENT_SECRET" description:"Client Secret" json:"-"`
Scope string
Prompt string `long:"prompt" env:"PROMPT" description:"Space separated list of OpenID prompt options"`
@ -18,15 +20,48 @@ type Google struct {
UserURL *url.URL
}
func (g *Google) GetLoginURL(redirectUri, state string) string {
// Name returns the name of the provider
func (g *Google) Name() string {
return "google"
}
// Setup performs validation and setup
func (g *Google) Setup() error {
if g.ClientID == "" || g.ClientSecret == "" {
return errors.New("providers.google.client-id, providers.google.client-secret must be set")
}
// Set static values
g.Scope = "https://www.googleapis.com/auth/userinfo.profile https://www.googleapis.com/auth/userinfo.email"
g.LoginURL = &url.URL{
Scheme: "https",
Host: "accounts.google.com",
Path: "/o/oauth2/auth",
}
g.TokenURL = &url.URL{
Scheme: "https",
Host: "www.googleapis.com",
Path: "/oauth2/v3/token",
}
g.UserURL = &url.URL{
Scheme: "https",
Host: "www.googleapis.com",
Path: "/oauth2/v2/userinfo",
}
return nil
}
// GetLoginURL provides the login url for the given redirect uri and state
func (g *Google) GetLoginURL(redirectURI, state string) string {
q := url.Values{}
q.Set("client_id", g.ClientId)
q.Set("client_id", g.ClientID)
q.Set("response_type", "code")
q.Set("scope", g.Scope)
if g.Prompt != "" {
q.Set("prompt", g.Prompt)
}
q.Set("redirect_uri", redirectUri)
q.Set("redirect_uri", redirectURI)
q.Set("state", state)
var u url.URL
@ -36,12 +71,13 @@ func (g *Google) GetLoginURL(redirectUri, state string) string {
return u.String()
}
func (g *Google) ExchangeCode(redirectUri, code string) (string, error) {
// ExchangeCode exchanges the given redirect uri and code for a token
func (g *Google) ExchangeCode(redirectURI, code string) (string, error) {
form := url.Values{}
form.Set("client_id", g.ClientId)
form.Set("client_id", g.ClientID)
form.Set("client_secret", g.ClientSecret)
form.Set("grant_type", "authorization_code")
form.Set("redirect_uri", redirectUri)
form.Set("redirect_uri", redirectURI)
form.Set("code", code)
res, err := http.PostForm(g.TokenURL.String(), form)
@ -49,13 +85,14 @@ func (g *Google) ExchangeCode(redirectUri, code string) (string, error) {
return "", err
}
var token Token
var token token
defer res.Body.Close()
err = json.NewDecoder(res.Body).Decode(&token)
return token.Token, err
}
// GetUser uses the given token and returns a complete provider.User object
func (g *Google) GetUser(token string) (User, error) {
var user User

View File

@ -0,0 +1,151 @@
package provider
import (
"net/url"
"testing"
"github.com/stretchr/testify/assert"
)
// Tests
func TestGoogleName(t *testing.T) {
p := Google{}
assert.Equal(t, "google", p.Name())
}
func TestGoogleSetup(t *testing.T) {
assert := assert.New(t)
p := Google{}
// Check validation
err := p.Setup()
if assert.Error(err) {
assert.Equal("providers.google.client-id, providers.google.client-secret must be set", err.Error())
}
// Check setup
p = Google{
ClientID: "id",
ClientSecret: "secret",
}
err = p.Setup()
assert.Nil(err)
assert.Equal("https://www.googleapis.com/auth/userinfo.profile https://www.googleapis.com/auth/userinfo.email", p.Scope)
assert.Equal("", p.Prompt)
assert.Equal(&url.URL{
Scheme: "https",
Host: "accounts.google.com",
Path: "/o/oauth2/auth",
}, p.LoginURL)
assert.Equal(&url.URL{
Scheme: "https",
Host: "www.googleapis.com",
Path: "/oauth2/v3/token",
}, p.TokenURL)
assert.Equal(&url.URL{
Scheme: "https",
Host: "www.googleapis.com",
Path: "/oauth2/v2/userinfo",
}, p.UserURL)
}
func TestGoogleGetLoginURL(t *testing.T) {
assert := assert.New(t)
p := Google{
ClientID: "idtest",
ClientSecret: "sectest",
Scope: "scopetest",
Prompt: "consent select_account",
LoginURL: &url.URL{
Scheme: "https",
Host: "google.com",
Path: "/auth",
},
}
// Check url
uri, err := url.Parse(p.GetLoginURL("http://example.com/_oauth", "state"))
assert.Nil(err)
assert.Equal("https", uri.Scheme)
assert.Equal("google.com", uri.Host)
assert.Equal("/auth", uri.Path)
// Check query string
qs := uri.Query()
expectedQs := url.Values{
"client_id": []string{"idtest"},
"redirect_uri": []string{"http://example.com/_oauth"},
"response_type": []string{"code"},
"scope": []string{"scopetest"},
"prompt": []string{"consent select_account"},
"state": []string{"state"},
}
assert.Equal(expectedQs, qs)
}
func TestGoogleExchangeCode(t *testing.T) {
assert := assert.New(t)
// Setup server
expected := url.Values{
"client_id": []string{"idtest"},
"client_secret": []string{"sectest"},
"code": []string{"code"},
"grant_type": []string{"authorization_code"},
"redirect_uri": []string{"http://example.com/_oauth"},
}
server, serverURL := NewOAuthServer(t, map[string]string{
"token": expected.Encode(),
})
defer server.Close()
// Setup provider
p := Google{
ClientID: "idtest",
ClientSecret: "sectest",
Scope: "scopetest",
Prompt: "consent select_account",
TokenURL: &url.URL{
Scheme: serverURL.Scheme,
Host: serverURL.Host,
Path: "/token",
},
}
token, err := p.ExchangeCode("http://example.com/_oauth", "code")
assert.Nil(err)
assert.Equal("123456789", token)
}
func TestGoogleGetUser(t *testing.T) {
assert := assert.New(t)
// Setup server
server, serverURL := NewOAuthServer(t, nil)
defer server.Close()
// Setup provider
p := Google{
ClientID: "idtest",
ClientSecret: "sectest",
Scope: "scopetest",
Prompt: "consent select_account",
UserURL: &url.URL{
Scheme: serverURL.Scheme,
Host: serverURL.Host,
Path: "/userinfo",
},
}
user, err := p.GetUser("123456789")
assert.Nil(err)
assert.Equal("1", user.ID)
assert.Equal("example@example.com", user.Email)
assert.True(user.Verified)
assert.Equal("example.com", user.Hd)
}

108
internal/provider/oidc.go Normal file
View File

@ -0,0 +1,108 @@
package provider
import (
"context"
"errors"
"github.com/coreos/go-oidc"
"golang.org/x/oauth2"
)
// OIDC provider
type OIDC struct {
OAuthProvider
IssuerURL string `long:"issuer-url" env:"ISSUER_URL" description:"Issuer URL"`
ClientID string `long:"client-id" env:"CLIENT_ID" description:"Client ID"`
ClientSecret string `long:"client-secret" env:"CLIENT_SECRET" description:"Client Secret" json:"-"`
provider *oidc.Provider
verifier *oidc.IDTokenVerifier
}
// Name returns the name of the provider
func (o *OIDC) Name() string {
return "oidc"
}
// Setup performs validation and setup
func (o *OIDC) Setup() error {
// Check parms
if o.IssuerURL == "" || o.ClientID == "" || o.ClientSecret == "" {
return errors.New("providers.oidc.issuer-url, providers.oidc.client-id, providers.oidc.client-secret must be set")
}
var err error
o.ctx = context.Background()
// Try to initiate provider
o.provider, err = oidc.NewProvider(o.ctx, o.IssuerURL)
if err != nil {
return err
}
// Create oauth2 config
o.Config = &oauth2.Config{
ClientID: o.ClientID,
ClientSecret: o.ClientSecret,
Endpoint: o.provider.Endpoint(),
// "openid" is a required scope for OpenID Connect flows.
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
}
// Create OIDC verifier
o.verifier = o.provider.Verifier(&oidc.Config{
ClientID: o.ClientID,
})
return nil
}
// GetLoginURL provides the login url for the given redirect uri and state
func (o *OIDC) GetLoginURL(redirectURI, state string) string {
return o.OAuthGetLoginURL(redirectURI, state)
}
// ExchangeCode exchanges the given redirect uri and code for a token
func (o *OIDC) ExchangeCode(redirectURI, code string) (string, error) {
token, err := o.OAuthExchangeCode(redirectURI, code)
if err != nil {
return "", err
}
// Extract ID token
rawIDToken, ok := token.Extra("id_token").(string)
if !ok {
return "", errors.New("Missing id_token")
}
return rawIDToken, nil
}
// GetUser uses the given token and returns a complete provider.User object
func (o *OIDC) GetUser(token string) (User, error) {
var user User
// Parse & Verify ID Token
idToken, err := o.verifier.Verify(o.ctx, token)
if err != nil {
return user, err
}
// Extract custom claims
var claims struct {
ID string `json:"sub"`
Email string `json:"email"`
Verified bool `json:"email_verified"`
}
if err := idToken.Claims(&claims); err != nil {
return user, err
}
user.ID = claims.ID
user.Email = claims.Email
user.Verified = claims.Verified
return user, nil
}

View File

@ -0,0 +1,252 @@
package provider
import (
"crypto/rand"
"crypto/rsa"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"strconv"
"testing"
"time"
"github.com/stretchr/testify/assert"
jose "gopkg.in/square/go-jose.v2"
)
// Tests
func TestOIDCName(t *testing.T) {
p := OIDC{}
assert.Equal(t, "oidc", p.Name())
}
func TestOIDCSetup(t *testing.T) {
assert := assert.New(t)
p := OIDC{}
err := p.Setup()
if assert.Error(err) {
assert.Equal("providers.oidc.issuer-url, providers.oidc.client-id, providers.oidc.client-secret must be set", err.Error())
}
}
func TestOIDCGetLoginURL(t *testing.T) {
assert := assert.New(t)
provider, server, serverURL, _ := setupOIDCTest(t, nil)
defer server.Close()
// Check url
uri, err := url.Parse(provider.GetLoginURL("http://example.com/_oauth", "state"))
assert.Nil(err)
assert.Equal(serverURL.Scheme, uri.Scheme)
assert.Equal(serverURL.Host, uri.Host)
assert.Equal("/auth", uri.Path)
// Check query string
qs := uri.Query()
expectedQs := url.Values{
"client_id": []string{"idtest"},
"redirect_uri": []string{"http://example.com/_oauth"},
"response_type": []string{"code"},
"scope": []string{"openid profile email"},
"state": []string{"state"},
}
assert.Equal(expectedQs, qs)
// Calling the method should not modify the underlying config
assert.Equal("", provider.Config.RedirectURL)
}
func TestOIDCExchangeCode(t *testing.T) {
assert := assert.New(t)
provider, server, _, _ := setupOIDCTest(t, map[string]map[string]string{
"token": {
"code": "code",
"grant_type": "authorization_code",
"redirect_uri": "http://example.com/_oauth",
},
})
defer server.Close()
token, err := provider.ExchangeCode("http://example.com/_oauth", "code")
assert.Nil(err)
assert.Equal("id_123456789", token)
}
func TestOIDCGetUser(t *testing.T) {
assert := assert.New(t)
provider, server, serverURL, key := setupOIDCTest(t, nil)
defer server.Close()
// Generate JWT
token := key.sign(t, []byte(`{
"iss": "`+serverURL.String()+`",
"exp":`+strconv.FormatInt(time.Now().Add(time.Hour).Unix(), 10)+`,
"aud": "idtest",
"sub": "1",
"email": "example@example.com",
"email_verified": true
}`))
// Get user
user, err := provider.GetUser(token)
assert.Nil(err)
assert.Equal("1", user.ID)
assert.Equal("example@example.com", user.Email)
assert.True(user.Verified)
}
// Utils
// setOIDCTest creates a key, OIDCServer and initilises an OIDC provider
func setupOIDCTest(t *testing.T, bodyValues map[string]map[string]string) (*OIDC, *httptest.Server, *url.URL, *rsaKey) {
// Generate key
key, err := newRSAKey()
if err != nil {
t.Fatal(err)
}
body := make(map[string]string)
if bodyValues != nil {
// URL encode bodyValues into body
for method, values := range bodyValues {
q := url.Values{}
for k, v := range values {
q.Set(k, v)
}
body[method] = q.Encode()
}
}
// Set up oidc server
server, serverURL := NewOIDCServer(t, key, body)
// Setup provider
p := OIDC{
ClientID: "idtest",
ClientSecret: "sectest",
IssuerURL: serverURL.String(),
}
// Initialise config/verifier
err = p.Setup()
if err != nil {
t.Fatal(err)
}
return &p, server, serverURL, key
}
// OIDCServer is used in the OIDC Tests to mock an OIDC server
type OIDCServer struct {
t *testing.T
url *url.URL
body map[string]string // method -> body
key *rsaKey
}
func NewOIDCServer(t *testing.T, key *rsaKey, body map[string]string) (*httptest.Server, *url.URL) {
handler := &OIDCServer{t: t, key: key, body: body}
server := httptest.NewServer(handler)
handler.url, _ = url.Parse(server.URL)
return server, handler.url
}
func (s *OIDCServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
body, _ := ioutil.ReadAll(r.Body)
if r.URL.Path == "/.well-known/openid-configuration" {
// Open id config
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, `{
"issuer":"`+s.url.String()+`",
"authorization_endpoint":"`+s.url.String()+`/auth",
"token_endpoint":"`+s.url.String()+`/token",
"jwks_uri":"`+s.url.String()+`/jwks"
}`)
} else if r.URL.Path == "/token" {
// Token request
// Check body
if b, ok := s.body["token"]; ok {
if b != string(body) {
s.t.Fatal("Unexpected request body, expected", b, "got", string(body))
}
}
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, `{
"access_token":"123456789",
"id_token":"id_123456789"
}`)
} else if r.URL.Path == "/jwks" {
// Key request
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, `{"keys":[`+s.key.publicJWK(s.t)+`]}`)
} else {
s.t.Fatal("Unrecognised request: ", r.URL, string(body))
}
}
// rsaKey is used in the OIDCServer tests to sign and verify requests
type rsaKey struct {
key *rsa.PrivateKey
alg jose.SignatureAlgorithm
jwkPub *jose.JSONWebKey
jwkPriv *jose.JSONWebKey
}
func newRSAKey() (*rsaKey, error) {
key, err := rsa.GenerateKey(rand.Reader, 1028)
if err != nil {
return nil, err
}
return &rsaKey{
key: key,
alg: jose.RS256,
jwkPub: &jose.JSONWebKey{
Key: key.Public(),
Algorithm: string(jose.RS256),
},
jwkPriv: &jose.JSONWebKey{
Key: key,
Algorithm: string(jose.RS256),
},
}, nil
}
func (k *rsaKey) publicJWK(t *testing.T) string {
b, err := k.jwkPub.MarshalJSON()
if err != nil {
t.Fatal(err)
}
return string(b)
}
// sign creates a JWS using the private key from the provided payload.
func (k *rsaKey) sign(t *testing.T, payload []byte) string {
signer, err := jose.NewSigner(jose.SigningKey{
Algorithm: k.alg,
Key: k.key,
}, nil)
if err != nil {
t.Fatal(err)
}
jws, err := signer.Sign(payload)
if err != nil {
t.Fatal(err)
}
data, err := jws.CompactSerialize()
if err != nil {
t.Fatal(err)
}
return data
}

View File

@ -1,16 +1,61 @@
package provider
import (
"context"
// "net/url"
"golang.org/x/oauth2"
)
// Providers contains all the implemented providers
type Providers struct {
Google Google `group:"Google Provider" namespace:"google" env-namespace:"GOOGLE"`
OIDC OIDC `group:"OIDC Provider" namespace:"oidc" env-namespace:"OIDC"`
}
type Token struct {
// Provider is used to authenticate users
type Provider interface {
Name() string
GetLoginURL(redirectURI, state string) string
ExchangeCode(redirectURI, code string) (string, error)
GetUser(token string) (User, error)
Setup() error
}
type token struct {
Token string `json:"access_token"`
}
// User is the authenticated user
type User struct {
Id string `json:"id"`
ID string `json:"id"`
Email string `json:"email"`
Verified bool `json:"verified_email"`
Hd string `json:"hd"`
}
// OAuthProvider is a provider using the oauth2 library
type OAuthProvider struct {
Config *oauth2.Config
ctx context.Context
}
// ConfigCopy returns a copy of the oauth2 config with the given redirectURI
// which ensures the underlying config is not modified
func (p *OAuthProvider) ConfigCopy(redirectURI string) oauth2.Config {
config := *p.Config
config.RedirectURL = redirectURI
return config
}
// OAuthGetLoginURL provides a base "GetLoginURL" for proiders using OAauth2
func (p *OAuthProvider) OAuthGetLoginURL(redirectURI, state string) string {
config := p.ConfigCopy(redirectURI)
return config.AuthCodeURL(state)
}
// OAuthExchangeCode provides a base "ExchangeCode" for proiders using OAauth2
func (p *OAuthProvider) OAuthExchangeCode(redirectURI, code string) (*oauth2.Token, error) {
config := p.ConfigCopy(redirectURI)
return config.Exchange(p.ctx, code)
}

View File

@ -0,0 +1,48 @@
package provider
import (
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"testing"
)
// Utilities
type OAuthServer struct {
t *testing.T
url *url.URL
body map[string]string // method -> body
}
func NewOAuthServer(t *testing.T, body map[string]string) (*httptest.Server, *url.URL) {
handler := &OAuthServer{t: t, body: body}
server := httptest.NewServer(handler)
handler.url, _ = url.Parse(server.URL)
return server, handler.url
}
func (s *OAuthServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
body, _ := ioutil.ReadAll(r.Body)
// fmt.Println("Got request:", r.URL, r.Method, string(body))
if r.Method == "POST" && r.URL.Path == "/token" {
if s.body["token"] != string(body) {
s.t.Fatal("Unexpected request body, expected", s.body["token"], "got", string(body))
}
w.Header().Set("Content-Type", "application/json")
fmt.Fprintf(w, `{"access_token":"123456789"}`)
} else if r.Method == "GET" && r.URL.Path == "/userinfo" {
fmt.Fprint(w, `{
"id":"1",
"email":"example@example.com",
"verified_email":true,
"hd":"example.com"
}`)
} else {
s.t.Fatal("Unrecognised request: ", r.Method, r.URL, string(body))
}
}

View File

@ -6,6 +6,7 @@ import (
"github.com/containous/traefik/pkg/rules"
"github.com/sirupsen/logrus"
"github.com/thomseddon/traefik-forward-auth/internal/provider"
)
type Server struct {
@ -27,10 +28,11 @@ func (s *Server) buildRoutes() {
// Let's build a router
for name, rule := range config.Rules {
matchRule := rule.formattedRule()
if rule.Action == "allow" {
s.router.AddRoute(rule.formattedRule(), 1, s.AllowHandler(name))
s.router.AddRoute(matchRule, 1, s.AllowHandler(name))
} else {
s.router.AddRoute(rule.formattedRule(), 1, s.AuthHandler(name))
s.router.AddRoute(matchRule, 1, s.AuthHandler(rule.Provider, name))
}
}
@ -41,7 +43,7 @@ func (s *Server) buildRoutes() {
if config.DefaultAction == "allow" {
s.router.NewRoute().Handler(s.AllowHandler("default"))
} else {
s.router.NewRoute().Handler(s.AuthHandler("default"))
s.router.NewRoute().Handler(s.AuthHandler(config.DefaultProvider, "default"))
}
}
@ -64,7 +66,9 @@ func (s *Server) AllowHandler(rule string) http.HandlerFunc {
}
// Authenticate requests
func (s *Server) AuthHandler(rule string) http.HandlerFunc {
func (s *Server) AuthHandler(providerName, rule string) http.HandlerFunc {
p, _ := config.GetConfiguredProvider(providerName)
return func(w http.ResponseWriter, r *http.Request) {
// Logging setup
logger := s.logger(r, rule, "Authenticating request")
@ -72,7 +76,7 @@ func (s *Server) AuthHandler(rule string) http.HandlerFunc {
// Get auth cookie
c, err := r.Cookie(config.CookieName)
if err != nil {
s.authRedirect(logger, w, r)
s.authRedirect(logger, w, r, p)
return
}
@ -81,7 +85,7 @@ func (s *Server) AuthHandler(rule string) http.HandlerFunc {
if err != nil {
if err.Error() == "Cookie has expired" {
logger.Info("Cookie has expired")
s.authRedirect(logger, w, r)
s.authRedirect(logger, w, r, p)
} else {
logger.Errorf("Invalid cookie: %v", err)
http.Error(w, "Not authorized", 401)
@ -121,18 +125,26 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc {
}
// Validate state
valid, redirect, err := ValidateCSRFCookie(r, c)
valid, providerName, redirect, err := ValidateCSRFCookie(r, c)
if !valid {
logger.Warnf("Error validating csrf cookie: %v", err)
http.Error(w, "Not authorized", 401)
return
}
// Get provider
p, err := config.GetConfiguredProvider(providerName)
if err != nil {
logger.Warnf("Invalid provider in csrf cookie: %s, %v", providerName, err)
http.Error(w, "Not authorized", 401)
return
}
// Clear CSRF cookie
http.SetCookie(w, ClearCSRFCookie(r))
// Exchange code for token
token, err := ExchangeCode(r)
token, err := p.ExchangeCode(redirectUri(r), r.URL.Query().Get("code"))
if err != nil {
logger.Errorf("Code exchange failed with: %v", err)
http.Error(w, "Service unavailable", 503)
@ -140,7 +152,7 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc {
}
// Get user
user, err := GetUser(token)
user, err := p.GetUser(token)
if err != nil {
logger.Errorf("Error getting user: %s", err)
return
@ -157,7 +169,7 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc {
}
}
func (s *Server) authRedirect(logger *logrus.Entry, w http.ResponseWriter, r *http.Request) {
func (s *Server) authRedirect(logger *logrus.Entry, w http.ResponseWriter, r *http.Request, p provider.Provider) {
// Error indicates no cookie, generate nonce
err, nonce := Nonce()
if err != nil {
@ -171,7 +183,8 @@ func (s *Server) authRedirect(logger *logrus.Entry, w http.ResponseWriter, r *ht
logger.Debug("Set CSRF cookie and redirecting to google login")
// Forward them on
http.Redirect(w, r, GetLoginURL(r, nonce), http.StatusTemporaryRedirect)
loginUrl := p.GetLoginURL(redirectUri(r), MakeState(r, p, nonce))
http.Redirect(w, r, loginUrl, http.StatusTemporaryRedirect)
logger.Debug("Done")
return

View File

@ -11,15 +11,16 @@ import (
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
)
// TODO:
/**
* Setup
*/
func init() {
config = newDefaultConfig()
config.LogLevel = "panic"
log = NewDefaultLogger()
}
@ -30,7 +31,7 @@ func init() {
func TestServerAuthHandlerInvalid(t *testing.T) {
assert := assert.New(t)
config, _ = NewConfig([]string{})
config = newDefaultConfig()
// Should redirect vanilla request to login url
req := newDefaultHttpRequest("/foo")
@ -42,10 +43,20 @@ func TestServerAuthHandlerInvalid(t *testing.T) {
assert.Equal("accounts.google.com", fwd.Host, "vanilla request should be redirected to google")
assert.Equal("/o/oauth2/auth", fwd.Path, "vanilla request should be redirected to google")
// Check state string
qs := fwd.Query()
state, exists := qs["state"]
require.True(t, exists)
require.Len(t, state, 1)
parts := strings.SplitN(state[0], ":", 3)
require.Len(t, parts, 3)
assert.Equal("google", parts[1])
assert.Equal("http://example.com/foo", parts[2])
// Should catch invalid cookie
req = newDefaultHttpRequest("/foo")
c := MakeCookie(req, "test@example.com")
parts := strings.Split(c.Value, "|")
parts = strings.Split(c.Value, "|")
c.Value = fmt.Sprintf("bad|%s|%s", parts[1], parts[2])
res, _ = doHttpRequest(req, c)
@ -62,7 +73,7 @@ func TestServerAuthHandlerInvalid(t *testing.T) {
func TestServerAuthHandlerExpired(t *testing.T) {
assert := assert.New(t)
config, _ = NewConfig([]string{})
config = newDefaultConfig()
config.Lifetime = time.Second * time.Duration(-1)
config.Domains = []string{"test.com"}
@ -90,7 +101,7 @@ func TestServerAuthHandlerExpired(t *testing.T) {
func TestServerAuthHandlerValid(t *testing.T) {
assert := assert.New(t)
config, _ = NewConfig([]string{})
config = newDefaultConfig()
// Should allow valid request email
req := newDefaultHttpRequest("/foo")
@ -108,7 +119,7 @@ func TestServerAuthHandlerValid(t *testing.T) {
func TestServerAuthCallback(t *testing.T) {
assert := assert.New(t)
config, _ = NewConfig([]string{})
config = newDefaultConfig()
// Setup token server
tokenServerHandler := &TokenServerHandler{}
@ -136,7 +147,7 @@ func TestServerAuthCallback(t *testing.T) {
assert.Equal(401, res.StatusCode, "auth callback with invalid cookie shouldn't be authorised")
// Should redirect valid request
req = newDefaultHttpRequest("/_oauth?state=12345678901234567890123456789012:http://redirect")
req = newDefaultHttpRequest("/_oauth?state=12345678901234567890123456789012:google:http://redirect")
c = MakeCSRFCookie(req, "12345678901234567890123456789012")
res, _ = doHttpRequest(req, c)
assert.Equal(307, res.StatusCode, "valid auth callback should be allowed")
@ -149,7 +160,7 @@ func TestServerAuthCallback(t *testing.T) {
func TestServerDefaultAction(t *testing.T) {
assert := assert.New(t)
config, _ = NewConfig([]string{})
config = newDefaultConfig()
req := newDefaultHttpRequest("/random")
res, _ := doHttpRequest(req, nil)
@ -161,9 +172,36 @@ func TestServerDefaultAction(t *testing.T) {
assert.Equal(200, res.StatusCode, "request should be allowed with default handler")
}
func TestServerDefaultProvider(t *testing.T) {
assert := assert.New(t)
config = newDefaultConfig()
// Should use "google" as default provider when not specified
req := newDefaultHttpRequest("/random")
res, _ := doHttpRequest(req, nil)
fwd, _ := res.Location()
assert.Equal("https", fwd.Scheme, "request with expired cookie should be redirected to google")
assert.Equal("accounts.google.com", fwd.Host, "request with expired cookie should be redirected to google")
assert.Equal("/o/oauth2/auth", fwd.Path, "request with expired cookie should be redirected to google")
// Should use alternative default provider when set
config.DefaultProvider = "oidc"
config.Providers.OIDC.OAuthProvider.Config = &oauth2.Config{
Endpoint: oauth2.Endpoint{
AuthURL: "https://oidc.com/oidcauth",
},
}
res, _ = doHttpRequest(req, nil)
fwd, _ = res.Location()
assert.Equal("https", fwd.Scheme, "request with expired cookie should be redirected to oidc")
assert.Equal("oidc.com", fwd.Host, "request with expired cookie should be redirected to oidc")
assert.Equal("/oidcauth", fwd.Path, "request with expired cookie should be redirected to oidc")
}
func TestServerRouteHeaders(t *testing.T) {
assert := assert.New(t)
config, _ = NewConfig([]string{})
config = newDefaultConfig()
config.Rules = map[string]*Rule{
"1": {
Action: "allow",
@ -196,7 +234,7 @@ func TestServerRouteHeaders(t *testing.T) {
func TestServerRouteHost(t *testing.T) {
assert := assert.New(t)
config, _ = NewConfig([]string{})
config = newDefaultConfig()
config.Rules = map[string]*Rule{
"1": {
Action: "allow",
@ -226,7 +264,7 @@ func TestServerRouteHost(t *testing.T) {
func TestServerRouteMethod(t *testing.T) {
assert := assert.New(t)
config, _ = NewConfig([]string{})
config = newDefaultConfig()
config.Rules = map[string]*Rule{
"1": {
Action: "allow",
@ -247,7 +285,7 @@ func TestServerRouteMethod(t *testing.T) {
func TestServerRoutePath(t *testing.T) {
assert := assert.New(t)
config, _ = NewConfig([]string{})
config = newDefaultConfig()
config.Rules = map[string]*Rule{
"1": {
Action: "allow",
@ -281,7 +319,7 @@ func TestServerRoutePath(t *testing.T) {
func TestServerRouteQuery(t *testing.T) {
assert := assert.New(t)
config, _ = NewConfig([]string{})
config = newDefaultConfig()
config.Rules = map[string]*Rule{
"1": {
Action: "allow",
@ -346,6 +384,18 @@ func doHttpRequest(r *http.Request, c *http.Cookie) (*http.Response, string) {
return res, string(body)
}
func newDefaultConfig() *Config {
config, _ = NewConfig([]string{
"--providers.google.client-id=id",
"--providers.google.client-secret=secret",
})
// Setup the google providers without running all the config validation
config.Providers.Google.Setup()
return config
}
func newDefaultHttpRequest(uri string) *http.Request {
return newHttpRequest("", "http://example.com/", uri)
}
@ -354,25 +404,8 @@ func newHttpRequest(method, dest, uri string) *http.Request {
r := httptest.NewRequest("", "http://should-use-x-forwarded.com", nil)
p, _ := url.Parse(dest)
r.Header.Add("X-Forwarded-Method", method)
r.Header.Add("X-Forwarded-Proto", p.Scheme)
r.Header.Add("X-Forwarded-Host", p.Host)
r.Header.Add("X-Forwarded-Uri", uri)
return r
}
func qsDiff(t *testing.T, one, two url.Values) []string {
errs := make([]string, 0)
for k := range one {
if two.Get(k) == "" {
errs = append(errs, fmt.Sprintf("Key missing: %s", k))
}
if one.Get(k) != two.Get(k) {
errs = append(errs, fmt.Sprintf("Value different for %s: expected: '%s' got: '%s'", k, one.Get(k), two.Get(k)))
}
}
for k := range two {
if one.Get(k) == "" {
errs = append(errs, fmt.Sprintf("Extra key: %s", k))
}
}
return errs
}