Multiple provider support + OIDC provider
This commit is contained in:
parent
5dfd4f2878
commit
c9289d6fc1
2
Makefile
2
Makefile
@ -1,5 +1,5 @@
|
||||
|
||||
format:
|
||||
gofmt -w -s internal/*.go cmd/*.go
|
||||
gofmt -w -s internal/*.go internal/provider/*.go cmd/*.go
|
||||
|
||||
.PHONY: format
|
||||
|
5
go.mod
5
go.mod
@ -9,6 +9,7 @@ require (
|
||||
github.com/containous/flaeg v1.4.1 // indirect
|
||||
github.com/containous/mux v0.0.0-20181024131434-c33f32e26898 // indirect
|
||||
github.com/containous/traefik v2.0.0-alpha2+incompatible
|
||||
github.com/coreos/go-oidc v2.1.0+incompatible
|
||||
github.com/go-acme/lego v2.5.0+incompatible // indirect
|
||||
github.com/go-kit/kit v0.8.0 // indirect
|
||||
github.com/gorilla/context v1.1.1 // indirect
|
||||
@ -21,6 +22,7 @@ require (
|
||||
github.com/miekg/dns v1.1.8 // indirect
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible // indirect
|
||||
github.com/pkg/errors v0.8.1 // indirect
|
||||
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect
|
||||
github.com/ryanuber/go-glob v1.0.0 // indirect
|
||||
github.com/sirupsen/logrus v1.4.1
|
||||
github.com/stretchr/objx v0.2.0 // indirect
|
||||
@ -29,8 +31,9 @@ require (
|
||||
github.com/vulcand/predicate v1.1.0 // indirect
|
||||
golang.org/x/crypto v0.0.0-20190422183909-d864b10871cd // indirect
|
||||
golang.org/x/net v0.0.0-20190420063019-afa5a82059c6 // indirect
|
||||
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58 // indirect
|
||||
golang.org/x/sys v0.0.0-20190422165155-953cdadca894 // indirect
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
|
||||
gopkg.in/square/go-jose.v2 v2.3.1 // indirect
|
||||
gopkg.in/square/go-jose.v2 v2.3.1
|
||||
)
|
||||
|
13
go.sum
13
go.sum
@ -1,3 +1,4 @@
|
||||
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
|
||||
github.com/VividCortex/gohistogram v1.0.0 h1:6+hBz+qvs0JOrrNhhmR7lFxo5sINxBCGXrdtl/UvroE=
|
||||
github.com/VividCortex/gohistogram v1.0.0/go.mod h1:Pf5mBqqDxYaXu3hDrrU+w6nw50o/4+TcAqDqk/vUH7g=
|
||||
github.com/cenkalti/backoff v2.1.1+incompatible h1:tKJnvO2kl0zmb/jA5UKAt4VoEVw1qxKWjE/Bpp46npY=
|
||||
@ -10,14 +11,18 @@ github.com/containous/mux v0.0.0-20181024131434-c33f32e26898 h1:1srn9voikJGofblB
|
||||
github.com/containous/mux v0.0.0-20181024131434-c33f32e26898/go.mod h1:z8WW7n06n8/1xF9Jl9WmuDeZuHAhfL+bwarNjsciwwg=
|
||||
github.com/containous/traefik v2.0.0-alpha2+incompatible h1:5RS6mUAOPQCy1jAmcmxLj2nChIcs3fKuxZxH9AF6ih8=
|
||||
github.com/containous/traefik v2.0.0-alpha2+incompatible/go.mod h1:epDRqge3JzKOhlSWzOpNYEEKXmM6yfN5tPzDGKk3ljo=
|
||||
github.com/coreos/go-oidc v2.1.0+incompatible h1:sdJrfw8akMnCuUlaZU3tE/uYXFgfqom8DBE9so9EBsM=
|
||||
github.com/coreos/go-oidc v2.1.0+incompatible/go.mod h1:CgnwVTmzoESiwO9qyAFEMiHoZ1nMCKZlZ9V6mm3/LKc=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/go-acme/lego v2.4.0+incompatible h1:+BTLUfLtDc5qQauyiTCXH6lupEUOCvXyGlEjdeU0YQI=
|
||||
github.com/go-acme/lego v2.4.0+incompatible/go.mod h1:yzMNe9CasVUhkquNvti5nAtPmG94USbYxYrZfTkIn0M=
|
||||
github.com/go-acme/lego v2.5.0+incompatible h1:5fNN9yRQfv8ymH3DSsxla+4aYeQt2IgfZqHKVnK8f0s=
|
||||
github.com/go-acme/lego v2.5.0+incompatible/go.mod h1:yzMNe9CasVUhkquNvti5nAtPmG94USbYxYrZfTkIn0M=
|
||||
github.com/go-kit/kit v0.8.0 h1:Wz+5lgoB0kkuqLEc6NVmwRknTKP6dTGbSqvhZtBI/j0=
|
||||
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
|
||||
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8=
|
||||
github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
|
||||
github.com/gravitational/trace v0.0.0-20190409171327-f30095ced5ff h1:xL/fJdlTJL6R/6Qk2tPu3EP1NsXgap9hXLvxKH0Ytko=
|
||||
@ -45,6 +50,8 @@ github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 h1:J9b7z+QKAmPf4YLrFg6oQUotqHQeUNWwkvo7jZp1GLU=
|
||||
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35/go.mod h1:prYjPmNq4d1NPVmpShWobRqXY3q7Vp+80DqgxxUrUIA=
|
||||
github.com/ryanuber/go-glob v1.0.0 h1:iQh3xXAumdQ+4Ufa5b25cRpC5TYKlno6hsv6Cb3pkBk=
|
||||
github.com/ryanuber/go-glob v1.0.0/go.mod h1:807d1WSdnB0XRJzKNil9Om6lcp/3a0v4qIHxIXzX/Yc=
|
||||
github.com/sirupsen/logrus v1.4.1 h1:GL2rEmy6nsikmW0r8opw9JIRScdMF5hA8cOYLH7In1k=
|
||||
@ -68,12 +75,17 @@ golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a h1:Igim7XhdOpBnWPuYJ70XcN
|
||||
golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE=
|
||||
golang.org/x/crypto v0.0.0-20190422183909-d864b10871cd h1:sMHc2rZHuzQmrbVoSpt9HgerkXPyIeCSO6k0zUMGfFk=
|
||||
golang.org/x/crypto v0.0.0-20190422183909-d864b10871cd/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 h1:0GoQqolDA55aaLxZyTzK/Y2ePZzZTUrRacwib7cNsYQ=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190420063019-afa5a82059c6 h1:HdqqaWmYAUI7/dmByKKEw+yxDksGSo+9GjkUc9Zp34E=
|
||||
golang.org/x/net v0.0.0-20190420063019-afa5a82059c6/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45 h1:SVwTIAaPC2U/AvvLNZ2a7OVsmBpC8L5BlwK1whH3hm0=
|
||||
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f h1:wMNYb4v58l5UBM7MYRLPG6ZhfOqbKu7X5eyFl8ZhKvA=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33 h1:I6FyU15t786LL7oL/hn43zqTuEGr4PN7F4XJ1p4E3Y8=
|
||||
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
@ -84,6 +96,7 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||
golang.org/x/sys v0.0.0-20190422165155-953cdadca894 h1:Cz4ceDQGXuKRnVBDTS23GTn/pU5OE2C0WrNTOYK1Uuc=
|
||||
golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/square/go-jose.v2 v2.3.1 h1:SK5KegNXmKmqE342YYN2qPHEnUYeoMiXXl1poUlI+o4=
|
||||
|
@ -81,32 +81,6 @@ func ValidateEmail(email string) bool {
|
||||
return found
|
||||
}
|
||||
|
||||
// OAuth Methods
|
||||
|
||||
// Get login url
|
||||
func GetLoginURL(r *http.Request, nonce string) string {
|
||||
state := fmt.Sprintf("%s:%s", nonce, returnUrl(r))
|
||||
|
||||
// TODO: Support multiple providers
|
||||
return config.Providers.Google.GetLoginURL(redirectUri(r), state)
|
||||
}
|
||||
|
||||
// Exchange code for token
|
||||
|
||||
func ExchangeCode(r *http.Request) (string, error) {
|
||||
code := r.URL.Query().Get("code")
|
||||
|
||||
// TODO: Support multiple providers
|
||||
return config.Providers.Google.ExchangeCode(redirectUri(r), code)
|
||||
}
|
||||
|
||||
// Get user with token
|
||||
|
||||
func GetUser(token string) (provider.User, error) {
|
||||
// TODO: Support multiple providers
|
||||
return config.Providers.Google.GetUser(token)
|
||||
}
|
||||
|
||||
// Utility methods
|
||||
|
||||
// Get the redirect base
|
||||
@ -117,7 +91,7 @@ func redirectBase(r *http.Request) string {
|
||||
return fmt.Sprintf("%s://%s", proto, host)
|
||||
}
|
||||
|
||||
// // Return url
|
||||
// Return url
|
||||
func returnUrl(r *http.Request) string {
|
||||
path := r.Header.Get("X-Forwarded-Uri")
|
||||
|
||||
@ -196,24 +170,35 @@ func ClearCSRFCookie(r *http.Request) *http.Cookie {
|
||||
}
|
||||
|
||||
// Validate the csrf cookie against state
|
||||
func ValidateCSRFCookie(r *http.Request, c *http.Cookie) (bool, string, error) {
|
||||
func ValidateCSRFCookie(r *http.Request, c *http.Cookie) (valid bool, provider string, redirect string, err error) {
|
||||
state := r.URL.Query().Get("state")
|
||||
|
||||
if len(c.Value) != 32 {
|
||||
return false, "", errors.New("Invalid CSRF cookie value")
|
||||
return false, "", "", errors.New("Invalid CSRF cookie value")
|
||||
}
|
||||
|
||||
if len(state) < 34 {
|
||||
return false, "", errors.New("Invalid CSRF state value")
|
||||
return false, "", "", errors.New("Invalid CSRF state value")
|
||||
}
|
||||
|
||||
// Check nonce match
|
||||
if c.Value != state[:32] {
|
||||
return false, "", errors.New("CSRF cookie does not match state")
|
||||
return false, "", "", errors.New("CSRF cookie does not match state")
|
||||
}
|
||||
|
||||
// Valid, return redirect
|
||||
return true, state[33:], nil
|
||||
// Extract provider
|
||||
params := state[33:]
|
||||
split := strings.Index(params, ":")
|
||||
if split == -1 {
|
||||
return false, "", "", errors.New("Invalid CSRF state format")
|
||||
}
|
||||
|
||||
// Valid, return provider and redirect
|
||||
return true, params[:split], params[split+1:], nil
|
||||
}
|
||||
|
||||
func MakeState(r *http.Request, p provider.Provider, nonce string) string {
|
||||
return fmt.Sprintf("%s:%s:%s", nonce, p.Name(), returnUrl(r))
|
||||
}
|
||||
|
||||
func Nonce() (error, string) {
|
||||
@ -282,10 +267,10 @@ func cookieExpiry() time.Time {
|
||||
|
||||
// Cookie Domain
|
||||
type CookieDomain struct {
|
||||
Domain string `description:"TEST1"`
|
||||
DomainLen int `description:"TEST2"`
|
||||
SubDomain string `description:"TEST3"`
|
||||
SubDomainLen int `description:"TEST4"`
|
||||
Domain string
|
||||
DomainLen int
|
||||
SubDomain string
|
||||
SubDomainLen int
|
||||
}
|
||||
|
||||
func NewCookieDomain(domain string) *CookieDomain {
|
||||
|
@ -95,139 +95,70 @@ func TestAuthValidateEmail(t *testing.T) {
|
||||
assert.True(v, "should allow user in whitelist")
|
||||
}
|
||||
|
||||
// TODO: Split google tests out
|
||||
func TestAuthGetLoginURL(t *testing.T) {
|
||||
func TestRedirectUri(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
google := provider.Google{
|
||||
ClientId: "idtest",
|
||||
ClientSecret: "sectest",
|
||||
Scope: "scopetest",
|
||||
Prompt: "consent select_account",
|
||||
LoginURL: &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "test.com",
|
||||
Path: "/auth",
|
||||
},
|
||||
}
|
||||
|
||||
config, _ = NewConfig([]string{})
|
||||
config.Providers.Google = google
|
||||
|
||||
r, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
r.Header.Add("X-Forwarded-Proto", "http")
|
||||
r.Header.Add("X-Forwarded-Host", "example.com")
|
||||
r.Header.Add("X-Forwarded-Host", "app.example.com")
|
||||
r.Header.Add("X-Forwarded-Uri", "/hello")
|
||||
|
||||
// Check url
|
||||
uri, err := url.Parse(GetLoginURL(r, "nonce"))
|
||||
assert.Nil(err)
|
||||
assert.Equal("https", uri.Scheme)
|
||||
assert.Equal("test.com", uri.Host)
|
||||
assert.Equal("/auth", uri.Path)
|
||||
//
|
||||
// No Auth Host
|
||||
//
|
||||
config, _ = NewConfig([]string{})
|
||||
|
||||
// Check query string
|
||||
qs := uri.Query()
|
||||
expectedQs := url.Values{
|
||||
"client_id": []string{"idtest"},
|
||||
"redirect_uri": []string{"http://example.com/_oauth"},
|
||||
"response_type": []string{"code"},
|
||||
"scope": []string{"scopetest"},
|
||||
"prompt": []string{"consent select_account"},
|
||||
"state": []string{"nonce:http://example.com/hello"},
|
||||
}
|
||||
assert.Equal(expectedQs, qs)
|
||||
uri, err := url.Parse(redirectUri(r))
|
||||
assert.Nil(err)
|
||||
assert.Equal("http", uri.Scheme)
|
||||
assert.Equal("app.example.com", uri.Host)
|
||||
assert.Equal("/_oauth", uri.Path)
|
||||
|
||||
//
|
||||
// With Auth URL but no matching cookie domain
|
||||
// - will not use auth host
|
||||
//
|
||||
config, _ = NewConfig([]string{})
|
||||
config.AuthHost = "auth.example.com"
|
||||
config.Providers.Google = google
|
||||
|
||||
// Check url
|
||||
uri, err = url.Parse(GetLoginURL(r, "nonce"))
|
||||
uri, err = url.Parse(redirectUri(r))
|
||||
assert.Nil(err)
|
||||
assert.Equal("https", uri.Scheme)
|
||||
assert.Equal("test.com", uri.Host)
|
||||
assert.Equal("/auth", uri.Path)
|
||||
|
||||
// Check query string
|
||||
qs = uri.Query()
|
||||
expectedQs = url.Values{
|
||||
"client_id": []string{"idtest"},
|
||||
"redirect_uri": []string{"http://example.com/_oauth"},
|
||||
"response_type": []string{"code"},
|
||||
"scope": []string{"scopetest"},
|
||||
"prompt": []string{"consent select_account"},
|
||||
"state": []string{"nonce:http://example.com/hello"},
|
||||
}
|
||||
assert.Equal(expectedQs, qs)
|
||||
assert.Equal("http", uri.Scheme)
|
||||
assert.Equal("app.example.com", uri.Host)
|
||||
assert.Equal("/_oauth", uri.Path)
|
||||
|
||||
//
|
||||
// With correct Auth URL + cookie domain
|
||||
//
|
||||
config, _ = NewConfig([]string{})
|
||||
config.AuthHost = "auth.example.com"
|
||||
config.CookieDomains = []CookieDomain{*NewCookieDomain("example.com")}
|
||||
config.Providers.Google = google
|
||||
|
||||
// Check url
|
||||
uri, err = url.Parse(GetLoginURL(r, "nonce"))
|
||||
uri, err = url.Parse(redirectUri(r))
|
||||
assert.Nil(err)
|
||||
assert.Equal("https", uri.Scheme)
|
||||
assert.Equal("test.com", uri.Host)
|
||||
assert.Equal("/auth", uri.Path)
|
||||
|
||||
// Check query string
|
||||
qs = uri.Query()
|
||||
expectedQs = url.Values{
|
||||
"client_id": []string{"idtest"},
|
||||
"redirect_uri": []string{"http://auth.example.com/_oauth"},
|
||||
"response_type": []string{"code"},
|
||||
"scope": []string{"scopetest"},
|
||||
"state": []string{"nonce:http://example.com/hello"},
|
||||
"prompt": []string{"consent select_account"},
|
||||
}
|
||||
assert.Equal(expectedQs, qs)
|
||||
assert.Equal("http", uri.Scheme)
|
||||
assert.Equal("auth.example.com", uri.Host)
|
||||
assert.Equal("/_oauth", uri.Path)
|
||||
|
||||
//
|
||||
// With Auth URL + cookie domain, but from different domain
|
||||
// - will not use auth host
|
||||
//
|
||||
r, _ = http.NewRequest("GET", "http://another.com", nil)
|
||||
r.Header.Add("X-Forwarded-Proto", "http")
|
||||
r.Header.Add("X-Forwarded-Proto", "https")
|
||||
r.Header.Add("X-Forwarded-Host", "another.com")
|
||||
r.Header.Add("X-Forwarded-Uri", "/hello")
|
||||
|
||||
config.AuthHost = "auth.example.com"
|
||||
config.CookieDomains = []CookieDomain{*NewCookieDomain("example.com")}
|
||||
|
||||
// Check url
|
||||
uri, err = url.Parse(GetLoginURL(r, "nonce"))
|
||||
uri, err = url.Parse(redirectUri(r))
|
||||
assert.Nil(err)
|
||||
assert.Equal("https", uri.Scheme)
|
||||
assert.Equal("test.com", uri.Host)
|
||||
assert.Equal("/auth", uri.Path)
|
||||
|
||||
// Check query string
|
||||
qs = uri.Query()
|
||||
expectedQs = url.Values{
|
||||
"client_id": []string{"idtest"},
|
||||
"redirect_uri": []string{"http://another.com/_oauth"},
|
||||
"response_type": []string{"code"},
|
||||
"scope": []string{"scopetest"},
|
||||
"state": []string{"nonce:http://another.com/hello"},
|
||||
"prompt": []string{"consent select_account"},
|
||||
}
|
||||
assert.Equal(expectedQs, qs)
|
||||
assert.Equal("another.com", uri.Host)
|
||||
assert.Equal("/_oauth", uri.Path)
|
||||
}
|
||||
|
||||
// TODO
|
||||
// func TestAuthExchangeCode(t *testing.T) {
|
||||
// }
|
||||
|
||||
// TODO
|
||||
// func TestAuthGetUser(t *testing.T) {
|
||||
// }
|
||||
|
||||
func TestAuthMakeCookie(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
config, _ = NewConfig([]string{})
|
||||
@ -265,14 +196,14 @@ func TestAuthMakeCSRFCookie(t *testing.T) {
|
||||
assert.Equal("app.example.com", c.Domain)
|
||||
|
||||
// With cookie domain but no auth url
|
||||
config = Config{
|
||||
config = &Config{
|
||||
CookieDomains: []CookieDomain{*NewCookieDomain("example.com")},
|
||||
}
|
||||
c = MakeCSRFCookie(r, "12345678901234567890123456789012")
|
||||
assert.Equal("app.example.com", c.Domain)
|
||||
|
||||
// With cookie domain and auth url
|
||||
config = Config{
|
||||
config = &Config{
|
||||
AuthHost: "auth.example.com",
|
||||
CookieDomains: []CookieDomain{*NewCookieDomain("example.com")},
|
||||
}
|
||||
@ -304,13 +235,13 @@ func TestAuthValidateCSRFCookie(t *testing.T) {
|
||||
// Should require 32 char string
|
||||
r := newCsrfRequest("")
|
||||
c.Value = ""
|
||||
valid, _, err := ValidateCSRFCookie(r, c)
|
||||
valid, _, _, err := ValidateCSRFCookie(r, c)
|
||||
assert.False(valid)
|
||||
if assert.Error(err) {
|
||||
assert.Equal("Invalid CSRF cookie value", err.Error())
|
||||
}
|
||||
c.Value = "123456789012345678901234567890123"
|
||||
valid, _, err = ValidateCSRFCookie(r, c)
|
||||
valid, _, _, err = ValidateCSRFCookie(r, c)
|
||||
assert.False(valid)
|
||||
if assert.Error(err) {
|
||||
assert.Equal("Invalid CSRF cookie value", err.Error())
|
||||
@ -319,19 +250,48 @@ func TestAuthValidateCSRFCookie(t *testing.T) {
|
||||
// Should require valid state
|
||||
r = newCsrfRequest("12345678901234567890123456789012:")
|
||||
c.Value = "12345678901234567890123456789012"
|
||||
valid, _, err = ValidateCSRFCookie(r, c)
|
||||
valid, _, _, err = ValidateCSRFCookie(r, c)
|
||||
assert.False(valid)
|
||||
if assert.Error(err) {
|
||||
assert.Equal("Invalid CSRF state value", err.Error())
|
||||
}
|
||||
|
||||
// Should allow valid state
|
||||
// Should require provider
|
||||
r = newCsrfRequest("12345678901234567890123456789012:99")
|
||||
c.Value = "12345678901234567890123456789012"
|
||||
valid, state, err := ValidateCSRFCookie(r, c)
|
||||
valid, _, _, err = ValidateCSRFCookie(r, c)
|
||||
assert.False(valid)
|
||||
if assert.Error(err) {
|
||||
assert.Equal("Invalid CSRF state format", err.Error())
|
||||
}
|
||||
|
||||
// Should allow valid state
|
||||
r = newCsrfRequest("12345678901234567890123456789012:p99:url123")
|
||||
c.Value = "12345678901234567890123456789012"
|
||||
valid, provider, redirect, err := ValidateCSRFCookie(r, c)
|
||||
assert.True(valid, "valid request should return valid")
|
||||
assert.Nil(err, "valid request should not return an error")
|
||||
assert.Equal("99", state, "valid request should return correct state")
|
||||
assert.Equal("p99", provider, "valid request should return correct provider")
|
||||
assert.Equal("url123", redirect, "valid request should return correct redirect")
|
||||
}
|
||||
|
||||
func TestMakeState(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
r, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
r.Header.Add("X-Forwarded-Proto", "http")
|
||||
r.Header.Add("X-Forwarded-Host", "example.com")
|
||||
r.Header.Add("X-Forwarded-Uri", "/hello")
|
||||
|
||||
// Test with google
|
||||
p := provider.Google{}
|
||||
state := MakeState(r, &p, "nonce")
|
||||
assert.Equal("nonce:google:http://example.com/hello", state)
|
||||
|
||||
// Test with OIDC
|
||||
p2 := provider.OIDC{}
|
||||
state = MakeState(r, &p2, "nonce")
|
||||
assert.Equal("nonce:oidc:http://example.com/hello", state)
|
||||
}
|
||||
|
||||
func TestAuthNonce(t *testing.T) {
|
||||
@ -356,6 +316,8 @@ func TestAuthCookieDomainMatch(t *testing.T) {
|
||||
|
||||
// Subdomain should match
|
||||
assert.True(cd.Match("test.example.com"), "subdomain should match")
|
||||
assert.True(cd.Match("twolevels.test.example.com"), "subdomain should match")
|
||||
assert.True(cd.Match("many.many.levels.test.example.com"), "subdomain should match")
|
||||
|
||||
// Derived domain should not match
|
||||
assert.False(cd.Match("testexample.com"), "derived domain should not match")
|
||||
|
@ -7,7 +7,6 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/url"
|
||||
"os"
|
||||
"regexp"
|
||||
"strconv"
|
||||
@ -18,24 +17,25 @@ import (
|
||||
"github.com/thomseddon/traefik-forward-auth/internal/provider"
|
||||
)
|
||||
|
||||
var config Config
|
||||
var config *Config
|
||||
|
||||
type Config struct {
|
||||
LogLevel string `long:"log-level" env:"LOG_LEVEL" default:"warn" choice:"trace" choice:"debug" choice:"info" choice:"warn" choice:"error" choice:"fatal" choice:"panic" description:"Log level"`
|
||||
LogFormat string `long:"log-format" env:"LOG_FORMAT" default:"text" choice:"text" choice:"json" choice:"pretty" description:"Log format"`
|
||||
|
||||
AuthHost string `long:"auth-host" env:"AUTH_HOST" description:"Single host to use when returning from 3rd party auth"`
|
||||
Config func(s string) error `long:"config" env:"CONFIG" description:"Path to config file" json:"-"`
|
||||
CookieDomains []CookieDomain `long:"cookie-domain" env:"COOKIE_DOMAIN" description:"Domain to set auth cookie on, can be set multiple times"`
|
||||
InsecureCookie bool `long:"insecure-cookie" env:"INSECURE_COOKIE" description:"Use insecure cookies"`
|
||||
CookieName string `long:"cookie-name" env:"COOKIE_NAME" default:"_forward_auth" description:"Cookie Name"`
|
||||
CSRFCookieName string `long:"csrf-cookie-name" env:"CSRF_COOKIE_NAME" default:"_forward_auth_csrf" description:"CSRF Cookie Name"`
|
||||
DefaultAction string `long:"default-action" env:"DEFAULT_ACTION" default:"auth" choice:"auth" choice:"allow" description:"Default action"`
|
||||
Domains CommaSeparatedList `long:"domain" env:"DOMAIN" description:"Only allow given email domains, can be set multiple times"`
|
||||
LifetimeString int `long:"lifetime" env:"LIFETIME" default:"43200" description:"Lifetime in seconds"`
|
||||
Path string `long:"url-path" env:"URL_PATH" default:"/_oauth" description:"Callback URL Path"`
|
||||
SecretString string `long:"secret" env:"SECRET" description:"Secret used for signing (required)" json:"-"`
|
||||
Whitelist CommaSeparatedList `long:"whitelist" env:"WHITELIST" description:"Only allow given email addresses, can be set multiple times"`
|
||||
AuthHost string `long:"auth-host" env:"AUTH_HOST" description:"Single host to use when returning from 3rd party auth"`
|
||||
Config func(s string) error `long:"config" env:"CONFIG" description:"Path to config file" json:"-"`
|
||||
CookieDomains []CookieDomain `long:"cookie-domain" env:"COOKIE_DOMAIN" description:"Domain to set auth cookie on, can be set multiple times"`
|
||||
InsecureCookie bool `long:"insecure-cookie" env:"INSECURE_COOKIE" description:"Use insecure cookies"`
|
||||
CookieName string `long:"cookie-name" env:"COOKIE_NAME" default:"_forward_auth" description:"Cookie Name"`
|
||||
CSRFCookieName string `long:"csrf-cookie-name" env:"CSRF_COOKIE_NAME" default:"_forward_auth_csrf" description:"CSRF Cookie Name"`
|
||||
DefaultAction string `long:"default-action" env:"DEFAULT_ACTION" default:"auth" choice:"auth" choice:"allow" description:"Default action"`
|
||||
DefaultProvider string `long:"default-provider" env:"DEFAULT_PROVIDER" default:"google" choice:"google" choice:"oidc" description:"Default provider"`
|
||||
Domains CommaSeparatedList `long:"domain" env:"DOMAIN" description:"Only allow given email domains, can be set multiple times"`
|
||||
LifetimeString int `long:"lifetime" env:"LIFETIME" default:"43200" description:"Lifetime in seconds"`
|
||||
Path string `long:"url-path" env:"URL_PATH" default:"/_oauth" description:"Callback URL Path"`
|
||||
SecretString string `long:"secret" env:"SECRET" description:"Secret used for signing (required)" json:"-"`
|
||||
Whitelist CommaSeparatedList `long:"whitelist" env:"WHITELIST" description:"Only allow given email addresses, can be set multiple times"`
|
||||
|
||||
Providers provider.Providers `group:"providers" namespace:"providers" env-namespace:"PROVIDERS"`
|
||||
Rules map[string]*Rule `long:"rule.<name>.<param>" description:"Rule definitions, param can be: \"action\" or \"rule\""`
|
||||
@ -53,7 +53,7 @@ type Config struct {
|
||||
PromptLegacy string `long:"prompt" env:"PROMPT" description:"DEPRECATED - Use \"providers.google.prompt\""`
|
||||
}
|
||||
|
||||
func NewGlobalConfig() Config {
|
||||
func NewGlobalConfig() *Config {
|
||||
var err error
|
||||
config, err = NewConfig(os.Args[1:])
|
||||
if err != nil {
|
||||
@ -64,29 +64,11 @@ func NewGlobalConfig() Config {
|
||||
return config
|
||||
}
|
||||
|
||||
func NewConfig(args []string) (Config, error) {
|
||||
c := Config{
|
||||
// TODO: move config parsing into new func "NewParsedConfig"
|
||||
|
||||
func NewConfig(args []string) (*Config, error) {
|
||||
c := &Config{
|
||||
Rules: map[string]*Rule{},
|
||||
Providers: provider.Providers{
|
||||
Google: provider.Google{
|
||||
Scope: "https://www.googleapis.com/auth/userinfo.profile https://www.googleapis.com/auth/userinfo.email",
|
||||
LoginURL: &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "accounts.google.com",
|
||||
Path: "/o/oauth2/auth",
|
||||
},
|
||||
TokenURL: &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "www.googleapis.com",
|
||||
Path: "/oauth2/v3/token",
|
||||
},
|
||||
UserURL: &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "www.googleapis.com",
|
||||
Path: "/oauth2/v2/userinfo",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := c.parseFlags(args)
|
||||
@ -97,13 +79,23 @@ func NewConfig(args []string) (Config, error) {
|
||||
// TODO: as log flags have now been parsed maybe we should return here so
|
||||
// any further errors can be logged via logrus instead of printed?
|
||||
|
||||
// TODO: Rename "Validate" method to "Setup" and move all below logic
|
||||
|
||||
// Setup
|
||||
// Set default provider on any rules where it's not specified
|
||||
for _, rule := range c.Rules {
|
||||
if rule.Provider == "" {
|
||||
rule.Provider = c.DefaultProvider
|
||||
}
|
||||
}
|
||||
|
||||
// Backwards compatability
|
||||
if c.CookieSecretLegacy != "" && c.SecretString == "" {
|
||||
fmt.Println("cookie-secret config option is deprecated, please use secret")
|
||||
c.SecretString = c.CookieSecretLegacy
|
||||
}
|
||||
if c.ClientIdLegacy != "" {
|
||||
c.Providers.Google.ClientId = c.ClientIdLegacy
|
||||
c.Providers.Google.ClientID = c.ClientIdLegacy
|
||||
}
|
||||
if c.ClientSecretLegacy != "" {
|
||||
c.Providers.Google.ClientSecret = c.ClientSecretLegacy
|
||||
@ -247,16 +239,21 @@ func convertLegacyToIni(name string) (io.Reader, error) {
|
||||
func (c *Config) Validate() {
|
||||
// Check for show stopper errors
|
||||
if len(c.Secret) == 0 {
|
||||
log.Fatal("\"secret\" option must be set.")
|
||||
log.Fatal("\"secret\" option must be set")
|
||||
}
|
||||
|
||||
if c.Providers.Google.ClientId == "" || c.Providers.Google.ClientSecret == "" {
|
||||
log.Fatal("providers.google.client-id, providers.google.client-secret must be set")
|
||||
// Setup default provider
|
||||
err := c.setupProvider(c.DefaultProvider)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Check rules
|
||||
// Check rules (validates the rule and the rule provider)
|
||||
for _, rule := range c.Rules {
|
||||
rule.Validate()
|
||||
err = rule.Validate(c)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -265,6 +262,61 @@ func (c Config) String() string {
|
||||
return string(jsonConf)
|
||||
}
|
||||
|
||||
// GetProvider returns the provider of the given name
|
||||
func (c *Config) GetProvider(name string) (provider.Provider, error) {
|
||||
switch name {
|
||||
case "google":
|
||||
return &c.Providers.Google, nil
|
||||
case "oidc":
|
||||
return &c.Providers.OIDC, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("Unknown provider: %s", name)
|
||||
}
|
||||
|
||||
// GetConfiguredProvider returns the provider of the given name, if it has been
|
||||
// configured. Returns an error if the provider is unknown, or hasn't been configured
|
||||
func (c *Config) GetConfiguredProvider(name string) (provider.Provider, error) {
|
||||
// Check the provider has been configured
|
||||
if !c.providerConfigured(name) {
|
||||
return nil, fmt.Errorf("Unconfigured provider: %s", name)
|
||||
}
|
||||
|
||||
return c.GetProvider(name)
|
||||
}
|
||||
|
||||
func (c *Config) providerConfigured(name string) bool {
|
||||
// Check default provider
|
||||
if name == c.DefaultProvider {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check rule providers
|
||||
for _, rule := range c.Rules {
|
||||
if name == rule.Provider {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *Config) setupProvider(name string) error {
|
||||
// Check provider exists
|
||||
p, err := c.GetProvider(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Setup
|
||||
err = p.Setup()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type Rule struct {
|
||||
Action string
|
||||
Rule string
|
||||
@ -273,8 +325,7 @@ type Rule struct {
|
||||
|
||||
func NewRule() *Rule {
|
||||
return &Rule{
|
||||
Action: "auth",
|
||||
Provider: "google", // TODO: Use default provider
|
||||
Action: "auth",
|
||||
}
|
||||
}
|
||||
|
||||
@ -284,15 +335,12 @@ func (r *Rule) formattedRule() string {
|
||||
return strings.ReplaceAll(r.Rule, "Host(", "HostRegexp(")
|
||||
}
|
||||
|
||||
func (r *Rule) Validate() {
|
||||
func (r *Rule) Validate(c *Config) error {
|
||||
if r.Action != "auth" && r.Action != "allow" {
|
||||
log.Fatal("invalid rule action, must be \"auth\" or \"allow\"")
|
||||
return errors.New("invalid rule action, must be \"auth\" or \"allow\"")
|
||||
}
|
||||
|
||||
// TODO: Update with more provider support
|
||||
if r.Provider != "google" {
|
||||
log.Fatal("invalid rule provider, must be \"google\"")
|
||||
}
|
||||
return c.setupProvider(r.Provider)
|
||||
}
|
||||
|
||||
// Legacy support for comma separated lists
|
||||
|
@ -1,11 +1,13 @@
|
||||
package tfa
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
// "fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/sirupsen/logrus/hooks/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@ -28,34 +30,11 @@ func TestConfigDefaults(t *testing.T) {
|
||||
assert.Equal("_forward_auth", c.CookieName)
|
||||
assert.Equal("_forward_auth_csrf", c.CSRFCookieName)
|
||||
assert.Equal("auth", c.DefaultAction)
|
||||
assert.Equal("google", c.DefaultProvider)
|
||||
assert.Len(c.Domains, 0)
|
||||
assert.Equal(time.Second*time.Duration(43200), c.Lifetime)
|
||||
assert.Equal("/_oauth", c.Path)
|
||||
assert.Len(c.Whitelist, 0)
|
||||
|
||||
assert.Equal("https://www.googleapis.com/auth/userinfo.profile https://www.googleapis.com/auth/userinfo.email", c.Providers.Google.Scope)
|
||||
assert.Equal("", c.Providers.Google.Prompt)
|
||||
|
||||
loginURL := &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "accounts.google.com",
|
||||
Path: "/o/oauth2/auth",
|
||||
}
|
||||
assert.Equal(loginURL, c.Providers.Google.LoginURL)
|
||||
|
||||
tokenURL := &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "www.googleapis.com",
|
||||
Path: "/oauth2/v3/token",
|
||||
}
|
||||
assert.Equal(tokenURL, c.Providers.Google.TokenURL)
|
||||
|
||||
userURL := &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "www.googleapis.com",
|
||||
Path: "/oauth2/v2/userinfo",
|
||||
}
|
||||
assert.Equal(userURL, c.Providers.Google.UserURL)
|
||||
}
|
||||
|
||||
func TestConfigParseArgs(t *testing.T) {
|
||||
@ -63,6 +42,7 @@ func TestConfigParseArgs(t *testing.T) {
|
||||
c, err := NewConfig([]string{
|
||||
"--cookie-name=cookiename",
|
||||
"--csrf-cookie-name", "\"csrfcookiename\"",
|
||||
"--default-provider", "\"oidc\"",
|
||||
"--rule.1.action=allow",
|
||||
"--rule.1.rule=PathPrefix(`/one`)",
|
||||
"--rule.two.action=auth",
|
||||
@ -73,18 +53,19 @@ func TestConfigParseArgs(t *testing.T) {
|
||||
// Check normal flags
|
||||
assert.Equal("cookiename", c.CookieName)
|
||||
assert.Equal("csrfcookiename", c.CSRFCookieName)
|
||||
assert.Equal("oidc", c.DefaultProvider)
|
||||
|
||||
// Check rules
|
||||
assert.Equal(map[string]*Rule{
|
||||
"1": {
|
||||
Action: "allow",
|
||||
Rule: "PathPrefix(`/one`)",
|
||||
Provider: "google",
|
||||
Provider: "oidc",
|
||||
},
|
||||
"two": {
|
||||
Action: "auth",
|
||||
Rule: "Host(`two.com`) && Path(`/two`)",
|
||||
Provider: "google",
|
||||
Provider: "oidc",
|
||||
},
|
||||
}, c.Rules)
|
||||
}
|
||||
@ -157,7 +138,7 @@ func TestConfigFlagBackwardsCompatability(t *testing.T) {
|
||||
|
||||
// Google provider params used to be top level
|
||||
assert.Equal("clientid", c.ClientIdLegacy)
|
||||
assert.Equal("clientid", c.Providers.Google.ClientId, "--client-id should set providers.google.client-id")
|
||||
assert.Equal("clientid", c.Providers.Google.ClientID, "--client-id should set providers.google.client-id")
|
||||
assert.Equal("verysecret", c.ClientSecretLegacy)
|
||||
assert.Equal("verysecret", c.Providers.Google.ClientSecret, "--client-secret should set providers.google.client-secret")
|
||||
assert.Equal("prompt", c.PromptLegacy)
|
||||
@ -220,7 +201,7 @@ func TestConfigParseEnvironment(t *testing.T) {
|
||||
assert.Nil(err)
|
||||
|
||||
assert.Equal("env_cookie_name", c.CookieName, "variable should be read from environment")
|
||||
assert.Equal("env_client_id", c.Providers.Google.ClientId, "namespace variable should be read from environment")
|
||||
assert.Equal("env_client_id", c.Providers.Google.ClientID, "namespace variable should be read from environment")
|
||||
|
||||
os.Unsetenv("COOKIE_NAME")
|
||||
os.Unsetenv("PROVIDERS_GOOGLE_CLIENT_ID")
|
||||
@ -265,7 +246,7 @@ func TestConfigParseEnvironmentBackwardsCompatability(t *testing.T) {
|
||||
|
||||
// Google provider params used to be top level
|
||||
assert.Equal("clientid", c.ClientIdLegacy)
|
||||
assert.Equal("clientid", c.Providers.Google.ClientId, "--client-id should set providers.google.client-id")
|
||||
assert.Equal("clientid", c.Providers.Google.ClientID, "--client-id should set providers.google.client-id")
|
||||
assert.Equal("verysecret", c.ClientSecretLegacy)
|
||||
assert.Equal("verysecret", c.Providers.Google.ClientSecret, "--client-secret should set providers.google.client-secret")
|
||||
assert.Equal("prompt", c.PromptLegacy)
|
||||
@ -305,6 +286,92 @@ func TestConfigTransformation(t *testing.T) {
|
||||
assert.Equal(time.Second*time.Duration(200), c.Lifetime, "lifetime should be read and converted to duration")
|
||||
}
|
||||
|
||||
func TestConfigValidate(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
// Install new logger + hook
|
||||
var hook *test.Hook
|
||||
log, hook = test.NewNullLogger()
|
||||
log.ExitFunc = func(code int) {}
|
||||
|
||||
// Validate defualt config + rule error
|
||||
c, _ := NewConfig([]string{
|
||||
"--rule.1.action=bad",
|
||||
})
|
||||
c.Validate()
|
||||
|
||||
logs := hook.AllEntries()
|
||||
assert.Len(logs, 3)
|
||||
|
||||
// Should have fatal error requiring secret
|
||||
assert.Equal("\"secret\" option must be set", logs[0].Message)
|
||||
assert.Equal(logrus.FatalLevel, logs[0].Level)
|
||||
|
||||
// Should also have default provider (google) error
|
||||
assert.Equal("providers.google.client-id, providers.google.client-secret must be set", logs[1].Message)
|
||||
assert.Equal(logrus.FatalLevel, logs[1].Level)
|
||||
|
||||
// Should validate rule
|
||||
assert.Equal("invalid rule action, must be \"auth\" or \"allow\"", logs[2].Message)
|
||||
assert.Equal(logrus.FatalLevel, logs[2].Level)
|
||||
|
||||
hook.Reset()
|
||||
|
||||
// Validate with invalid providers
|
||||
c, _ = NewConfig([]string{
|
||||
"--secret=veryverysecret",
|
||||
"--providers.google.client-id=id",
|
||||
"--providers.google.client-secret=secret",
|
||||
"--rule.1.action=auth",
|
||||
"--rule.1.provider=bad2",
|
||||
})
|
||||
c.Validate()
|
||||
|
||||
logs = hook.AllEntries()
|
||||
assert.Len(logs, 1)
|
||||
|
||||
// Should have error for rule provider
|
||||
assert.Equal("Unknown provider: bad2", logs[0].Message)
|
||||
assert.Equal(logrus.FatalLevel, logs[0].Level)
|
||||
}
|
||||
|
||||
func TestConfigGetProvider(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
c, _ := NewConfig([]string{})
|
||||
|
||||
// Should be able to get "google" provider
|
||||
p, err := c.GetProvider("google")
|
||||
assert.Nil(err)
|
||||
assert.Equal(&c.Providers.Google, p)
|
||||
|
||||
// Should be able to get "oidc" provider
|
||||
p, err = c.GetProvider("oidc")
|
||||
assert.Nil(err)
|
||||
assert.Equal(&c.Providers.OIDC, p)
|
||||
|
||||
// Should catch unknown provider
|
||||
p, err = c.GetProvider("bad")
|
||||
if assert.Error(err) {
|
||||
assert.Equal("Unknown provider: bad", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigGetConfiguredProvider(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
c, _ := NewConfig([]string{})
|
||||
|
||||
// Should be able to get "google" default provider
|
||||
p, err := c.GetConfiguredProvider("google")
|
||||
assert.Nil(err)
|
||||
assert.Equal(&c.Providers.Google, p)
|
||||
|
||||
// Should fail to get valid "oidc" provider as it's not configured
|
||||
p, err = c.GetConfiguredProvider("oidc")
|
||||
if assert.Error(err) {
|
||||
assert.Equal("Unconfigured provider: oidc", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigCommaSeparatedList(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
list := CommaSeparatedList{}
|
||||
|
@ -6,9 +6,9 @@ import (
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var log logrus.FieldLogger
|
||||
var log *logrus.Logger
|
||||
|
||||
func NewDefaultLogger() logrus.FieldLogger {
|
||||
func NewDefaultLogger() *logrus.Logger {
|
||||
// Setup logger
|
||||
log = logrus.StandardLogger()
|
||||
logrus.SetOutput(os.Stdout)
|
||||
|
@ -2,13 +2,15 @@ package provider
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// Google provider
|
||||
type Google struct {
|
||||
ClientId string `long:"client-id" env:"CLIENT_ID" description:"Client ID"`
|
||||
ClientID string `long:"client-id" env:"CLIENT_ID" description:"Client ID"`
|
||||
ClientSecret string `long:"client-secret" env:"CLIENT_SECRET" description:"Client Secret" json:"-"`
|
||||
Scope string
|
||||
Prompt string `long:"prompt" env:"PROMPT" description:"Space separated list of OpenID prompt options"`
|
||||
@ -18,15 +20,48 @@ type Google struct {
|
||||
UserURL *url.URL
|
||||
}
|
||||
|
||||
func (g *Google) GetLoginURL(redirectUri, state string) string {
|
||||
// Name returns the name of the provider
|
||||
func (g *Google) Name() string {
|
||||
return "google"
|
||||
}
|
||||
|
||||
// Setup performs validation and setup
|
||||
func (g *Google) Setup() error {
|
||||
if g.ClientID == "" || g.ClientSecret == "" {
|
||||
return errors.New("providers.google.client-id, providers.google.client-secret must be set")
|
||||
}
|
||||
|
||||
// Set static values
|
||||
g.Scope = "https://www.googleapis.com/auth/userinfo.profile https://www.googleapis.com/auth/userinfo.email"
|
||||
g.LoginURL = &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "accounts.google.com",
|
||||
Path: "/o/oauth2/auth",
|
||||
}
|
||||
g.TokenURL = &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "www.googleapis.com",
|
||||
Path: "/oauth2/v3/token",
|
||||
}
|
||||
g.UserURL = &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "www.googleapis.com",
|
||||
Path: "/oauth2/v2/userinfo",
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetLoginURL provides the login url for the given redirect uri and state
|
||||
func (g *Google) GetLoginURL(redirectURI, state string) string {
|
||||
q := url.Values{}
|
||||
q.Set("client_id", g.ClientId)
|
||||
q.Set("client_id", g.ClientID)
|
||||
q.Set("response_type", "code")
|
||||
q.Set("scope", g.Scope)
|
||||
if g.Prompt != "" {
|
||||
q.Set("prompt", g.Prompt)
|
||||
}
|
||||
q.Set("redirect_uri", redirectUri)
|
||||
q.Set("redirect_uri", redirectURI)
|
||||
q.Set("state", state)
|
||||
|
||||
var u url.URL
|
||||
@ -36,12 +71,13 @@ func (g *Google) GetLoginURL(redirectUri, state string) string {
|
||||
return u.String()
|
||||
}
|
||||
|
||||
func (g *Google) ExchangeCode(redirectUri, code string) (string, error) {
|
||||
// ExchangeCode exchanges the given redirect uri and code for a token
|
||||
func (g *Google) ExchangeCode(redirectURI, code string) (string, error) {
|
||||
form := url.Values{}
|
||||
form.Set("client_id", g.ClientId)
|
||||
form.Set("client_id", g.ClientID)
|
||||
form.Set("client_secret", g.ClientSecret)
|
||||
form.Set("grant_type", "authorization_code")
|
||||
form.Set("redirect_uri", redirectUri)
|
||||
form.Set("redirect_uri", redirectURI)
|
||||
form.Set("code", code)
|
||||
|
||||
res, err := http.PostForm(g.TokenURL.String(), form)
|
||||
@ -49,13 +85,14 @@ func (g *Google) ExchangeCode(redirectUri, code string) (string, error) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var token Token
|
||||
var token token
|
||||
defer res.Body.Close()
|
||||
err = json.NewDecoder(res.Body).Decode(&token)
|
||||
|
||||
return token.Token, err
|
||||
}
|
||||
|
||||
// GetUser uses the given token and returns a complete provider.User object
|
||||
func (g *Google) GetUser(token string) (User, error) {
|
||||
var user User
|
||||
|
||||
|
151
internal/provider/google_test.go
Normal file
151
internal/provider/google_test.go
Normal file
@ -0,0 +1,151 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// Tests
|
||||
|
||||
func TestGoogleName(t *testing.T) {
|
||||
p := Google{}
|
||||
assert.Equal(t, "google", p.Name())
|
||||
}
|
||||
|
||||
func TestGoogleSetup(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
p := Google{}
|
||||
|
||||
// Check validation
|
||||
err := p.Setup()
|
||||
if assert.Error(err) {
|
||||
assert.Equal("providers.google.client-id, providers.google.client-secret must be set", err.Error())
|
||||
}
|
||||
|
||||
// Check setup
|
||||
p = Google{
|
||||
ClientID: "id",
|
||||
ClientSecret: "secret",
|
||||
}
|
||||
err = p.Setup()
|
||||
assert.Nil(err)
|
||||
assert.Equal("https://www.googleapis.com/auth/userinfo.profile https://www.googleapis.com/auth/userinfo.email", p.Scope)
|
||||
assert.Equal("", p.Prompt)
|
||||
|
||||
assert.Equal(&url.URL{
|
||||
Scheme: "https",
|
||||
Host: "accounts.google.com",
|
||||
Path: "/o/oauth2/auth",
|
||||
}, p.LoginURL)
|
||||
|
||||
assert.Equal(&url.URL{
|
||||
Scheme: "https",
|
||||
Host: "www.googleapis.com",
|
||||
Path: "/oauth2/v3/token",
|
||||
}, p.TokenURL)
|
||||
|
||||
assert.Equal(&url.URL{
|
||||
Scheme: "https",
|
||||
Host: "www.googleapis.com",
|
||||
Path: "/oauth2/v2/userinfo",
|
||||
}, p.UserURL)
|
||||
}
|
||||
|
||||
func TestGoogleGetLoginURL(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
p := Google{
|
||||
ClientID: "idtest",
|
||||
ClientSecret: "sectest",
|
||||
Scope: "scopetest",
|
||||
Prompt: "consent select_account",
|
||||
LoginURL: &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "google.com",
|
||||
Path: "/auth",
|
||||
},
|
||||
}
|
||||
|
||||
// Check url
|
||||
uri, err := url.Parse(p.GetLoginURL("http://example.com/_oauth", "state"))
|
||||
assert.Nil(err)
|
||||
assert.Equal("https", uri.Scheme)
|
||||
assert.Equal("google.com", uri.Host)
|
||||
assert.Equal("/auth", uri.Path)
|
||||
|
||||
// Check query string
|
||||
qs := uri.Query()
|
||||
expectedQs := url.Values{
|
||||
"client_id": []string{"idtest"},
|
||||
"redirect_uri": []string{"http://example.com/_oauth"},
|
||||
"response_type": []string{"code"},
|
||||
"scope": []string{"scopetest"},
|
||||
"prompt": []string{"consent select_account"},
|
||||
"state": []string{"state"},
|
||||
}
|
||||
assert.Equal(expectedQs, qs)
|
||||
}
|
||||
|
||||
func TestGoogleExchangeCode(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
// Setup server
|
||||
expected := url.Values{
|
||||
"client_id": []string{"idtest"},
|
||||
"client_secret": []string{"sectest"},
|
||||
"code": []string{"code"},
|
||||
"grant_type": []string{"authorization_code"},
|
||||
"redirect_uri": []string{"http://example.com/_oauth"},
|
||||
}
|
||||
server, serverURL := NewOAuthServer(t, map[string]string{
|
||||
"token": expected.Encode(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
// Setup provider
|
||||
p := Google{
|
||||
ClientID: "idtest",
|
||||
ClientSecret: "sectest",
|
||||
Scope: "scopetest",
|
||||
Prompt: "consent select_account",
|
||||
TokenURL: &url.URL{
|
||||
Scheme: serverURL.Scheme,
|
||||
Host: serverURL.Host,
|
||||
Path: "/token",
|
||||
},
|
||||
}
|
||||
|
||||
token, err := p.ExchangeCode("http://example.com/_oauth", "code")
|
||||
assert.Nil(err)
|
||||
assert.Equal("123456789", token)
|
||||
}
|
||||
|
||||
func TestGoogleGetUser(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
// Setup server
|
||||
server, serverURL := NewOAuthServer(t, nil)
|
||||
defer server.Close()
|
||||
|
||||
// Setup provider
|
||||
p := Google{
|
||||
ClientID: "idtest",
|
||||
ClientSecret: "sectest",
|
||||
Scope: "scopetest",
|
||||
Prompt: "consent select_account",
|
||||
UserURL: &url.URL{
|
||||
Scheme: serverURL.Scheme,
|
||||
Host: serverURL.Host,
|
||||
Path: "/userinfo",
|
||||
},
|
||||
}
|
||||
|
||||
user, err := p.GetUser("123456789")
|
||||
assert.Nil(err)
|
||||
|
||||
assert.Equal("1", user.ID)
|
||||
assert.Equal("example@example.com", user.Email)
|
||||
assert.True(user.Verified)
|
||||
assert.Equal("example.com", user.Hd)
|
||||
}
|
108
internal/provider/oidc.go
Normal file
108
internal/provider/oidc.go
Normal file
@ -0,0 +1,108 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/coreos/go-oidc"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// OIDC provider
|
||||
type OIDC struct {
|
||||
OAuthProvider
|
||||
|
||||
IssuerURL string `long:"issuer-url" env:"ISSUER_URL" description:"Issuer URL"`
|
||||
ClientID string `long:"client-id" env:"CLIENT_ID" description:"Client ID"`
|
||||
ClientSecret string `long:"client-secret" env:"CLIENT_SECRET" description:"Client Secret" json:"-"`
|
||||
|
||||
provider *oidc.Provider
|
||||
verifier *oidc.IDTokenVerifier
|
||||
}
|
||||
|
||||
// Name returns the name of the provider
|
||||
func (o *OIDC) Name() string {
|
||||
return "oidc"
|
||||
}
|
||||
|
||||
// Setup performs validation and setup
|
||||
func (o *OIDC) Setup() error {
|
||||
// Check parms
|
||||
if o.IssuerURL == "" || o.ClientID == "" || o.ClientSecret == "" {
|
||||
return errors.New("providers.oidc.issuer-url, providers.oidc.client-id, providers.oidc.client-secret must be set")
|
||||
}
|
||||
|
||||
var err error
|
||||
o.ctx = context.Background()
|
||||
|
||||
// Try to initiate provider
|
||||
o.provider, err = oidc.NewProvider(o.ctx, o.IssuerURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create oauth2 config
|
||||
o.Config = &oauth2.Config{
|
||||
ClientID: o.ClientID,
|
||||
ClientSecret: o.ClientSecret,
|
||||
Endpoint: o.provider.Endpoint(),
|
||||
|
||||
// "openid" is a required scope for OpenID Connect flows.
|
||||
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||
}
|
||||
|
||||
// Create OIDC verifier
|
||||
o.verifier = o.provider.Verifier(&oidc.Config{
|
||||
ClientID: o.ClientID,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetLoginURL provides the login url for the given redirect uri and state
|
||||
func (o *OIDC) GetLoginURL(redirectURI, state string) string {
|
||||
return o.OAuthGetLoginURL(redirectURI, state)
|
||||
}
|
||||
|
||||
// ExchangeCode exchanges the given redirect uri and code for a token
|
||||
func (o *OIDC) ExchangeCode(redirectURI, code string) (string, error) {
|
||||
token, err := o.OAuthExchangeCode(redirectURI, code)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Extract ID token
|
||||
rawIDToken, ok := token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
return "", errors.New("Missing id_token")
|
||||
}
|
||||
|
||||
return rawIDToken, nil
|
||||
}
|
||||
|
||||
// GetUser uses the given token and returns a complete provider.User object
|
||||
func (o *OIDC) GetUser(token string) (User, error) {
|
||||
var user User
|
||||
|
||||
// Parse & Verify ID Token
|
||||
idToken, err := o.verifier.Verify(o.ctx, token)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
|
||||
// Extract custom claims
|
||||
var claims struct {
|
||||
ID string `json:"sub"`
|
||||
Email string `json:"email"`
|
||||
Verified bool `json:"email_verified"`
|
||||
}
|
||||
if err := idToken.Claims(&claims); err != nil {
|
||||
return user, err
|
||||
}
|
||||
|
||||
user.ID = claims.ID
|
||||
user.Email = claims.Email
|
||||
user.Verified = claims.Verified
|
||||
|
||||
return user, nil
|
||||
}
|
252
internal/provider/oidc_test.go
Normal file
252
internal/provider/oidc_test.go
Normal file
@ -0,0 +1,252 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
jose "gopkg.in/square/go-jose.v2"
|
||||
)
|
||||
|
||||
// Tests
|
||||
|
||||
func TestOIDCName(t *testing.T) {
|
||||
p := OIDC{}
|
||||
assert.Equal(t, "oidc", p.Name())
|
||||
}
|
||||
|
||||
func TestOIDCSetup(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
p := OIDC{}
|
||||
|
||||
err := p.Setup()
|
||||
if assert.Error(err) {
|
||||
assert.Equal("providers.oidc.issuer-url, providers.oidc.client-id, providers.oidc.client-secret must be set", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDCGetLoginURL(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
provider, server, serverURL, _ := setupOIDCTest(t, nil)
|
||||
defer server.Close()
|
||||
|
||||
// Check url
|
||||
uri, err := url.Parse(provider.GetLoginURL("http://example.com/_oauth", "state"))
|
||||
assert.Nil(err)
|
||||
assert.Equal(serverURL.Scheme, uri.Scheme)
|
||||
assert.Equal(serverURL.Host, uri.Host)
|
||||
assert.Equal("/auth", uri.Path)
|
||||
|
||||
// Check query string
|
||||
qs := uri.Query()
|
||||
expectedQs := url.Values{
|
||||
"client_id": []string{"idtest"},
|
||||
"redirect_uri": []string{"http://example.com/_oauth"},
|
||||
"response_type": []string{"code"},
|
||||
"scope": []string{"openid profile email"},
|
||||
"state": []string{"state"},
|
||||
}
|
||||
assert.Equal(expectedQs, qs)
|
||||
|
||||
// Calling the method should not modify the underlying config
|
||||
assert.Equal("", provider.Config.RedirectURL)
|
||||
}
|
||||
|
||||
func TestOIDCExchangeCode(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
provider, server, _, _ := setupOIDCTest(t, map[string]map[string]string{
|
||||
"token": {
|
||||
"code": "code",
|
||||
"grant_type": "authorization_code",
|
||||
"redirect_uri": "http://example.com/_oauth",
|
||||
},
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
token, err := provider.ExchangeCode("http://example.com/_oauth", "code")
|
||||
assert.Nil(err)
|
||||
assert.Equal("id_123456789", token)
|
||||
}
|
||||
|
||||
func TestOIDCGetUser(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
provider, server, serverURL, key := setupOIDCTest(t, nil)
|
||||
defer server.Close()
|
||||
|
||||
// Generate JWT
|
||||
token := key.sign(t, []byte(`{
|
||||
"iss": "`+serverURL.String()+`",
|
||||
"exp":`+strconv.FormatInt(time.Now().Add(time.Hour).Unix(), 10)+`,
|
||||
"aud": "idtest",
|
||||
"sub": "1",
|
||||
"email": "example@example.com",
|
||||
"email_verified": true
|
||||
}`))
|
||||
|
||||
// Get user
|
||||
user, err := provider.GetUser(token)
|
||||
assert.Nil(err)
|
||||
assert.Equal("1", user.ID)
|
||||
assert.Equal("example@example.com", user.Email)
|
||||
assert.True(user.Verified)
|
||||
}
|
||||
|
||||
// Utils
|
||||
|
||||
// setOIDCTest creates a key, OIDCServer and initilises an OIDC provider
|
||||
func setupOIDCTest(t *testing.T, bodyValues map[string]map[string]string) (*OIDC, *httptest.Server, *url.URL, *rsaKey) {
|
||||
// Generate key
|
||||
key, err := newRSAKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
body := make(map[string]string)
|
||||
if bodyValues != nil {
|
||||
// URL encode bodyValues into body
|
||||
for method, values := range bodyValues {
|
||||
q := url.Values{}
|
||||
for k, v := range values {
|
||||
q.Set(k, v)
|
||||
}
|
||||
body[method] = q.Encode()
|
||||
}
|
||||
}
|
||||
|
||||
// Set up oidc server
|
||||
server, serverURL := NewOIDCServer(t, key, body)
|
||||
|
||||
// Setup provider
|
||||
p := OIDC{
|
||||
ClientID: "idtest",
|
||||
ClientSecret: "sectest",
|
||||
IssuerURL: serverURL.String(),
|
||||
}
|
||||
|
||||
// Initialise config/verifier
|
||||
err = p.Setup()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return &p, server, serverURL, key
|
||||
}
|
||||
|
||||
// OIDCServer is used in the OIDC Tests to mock an OIDC server
|
||||
type OIDCServer struct {
|
||||
t *testing.T
|
||||
url *url.URL
|
||||
body map[string]string // method -> body
|
||||
key *rsaKey
|
||||
}
|
||||
|
||||
func NewOIDCServer(t *testing.T, key *rsaKey, body map[string]string) (*httptest.Server, *url.URL) {
|
||||
handler := &OIDCServer{t: t, key: key, body: body}
|
||||
server := httptest.NewServer(handler)
|
||||
handler.url, _ = url.Parse(server.URL)
|
||||
return server, handler.url
|
||||
}
|
||||
|
||||
func (s *OIDCServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := ioutil.ReadAll(r.Body)
|
||||
|
||||
if r.URL.Path == "/.well-known/openid-configuration" {
|
||||
// Open id config
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
fmt.Fprint(w, `{
|
||||
"issuer":"`+s.url.String()+`",
|
||||
"authorization_endpoint":"`+s.url.String()+`/auth",
|
||||
"token_endpoint":"`+s.url.String()+`/token",
|
||||
"jwks_uri":"`+s.url.String()+`/jwks"
|
||||
}`)
|
||||
} else if r.URL.Path == "/token" {
|
||||
// Token request
|
||||
// Check body
|
||||
if b, ok := s.body["token"]; ok {
|
||||
if b != string(body) {
|
||||
s.t.Fatal("Unexpected request body, expected", b, "got", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
fmt.Fprint(w, `{
|
||||
"access_token":"123456789",
|
||||
"id_token":"id_123456789"
|
||||
}`)
|
||||
} else if r.URL.Path == "/jwks" {
|
||||
// Key request
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
fmt.Fprint(w, `{"keys":[`+s.key.publicJWK(s.t)+`]}`)
|
||||
} else {
|
||||
s.t.Fatal("Unrecognised request: ", r.URL, string(body))
|
||||
}
|
||||
}
|
||||
|
||||
// rsaKey is used in the OIDCServer tests to sign and verify requests
|
||||
type rsaKey struct {
|
||||
key *rsa.PrivateKey
|
||||
alg jose.SignatureAlgorithm
|
||||
jwkPub *jose.JSONWebKey
|
||||
jwkPriv *jose.JSONWebKey
|
||||
}
|
||||
|
||||
func newRSAKey() (*rsaKey, error) {
|
||||
key, err := rsa.GenerateKey(rand.Reader, 1028)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &rsaKey{
|
||||
key: key,
|
||||
alg: jose.RS256,
|
||||
jwkPub: &jose.JSONWebKey{
|
||||
Key: key.Public(),
|
||||
Algorithm: string(jose.RS256),
|
||||
},
|
||||
jwkPriv: &jose.JSONWebKey{
|
||||
Key: key,
|
||||
Algorithm: string(jose.RS256),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (k *rsaKey) publicJWK(t *testing.T) string {
|
||||
b, err := k.jwkPub.MarshalJSON()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// sign creates a JWS using the private key from the provided payload.
|
||||
func (k *rsaKey) sign(t *testing.T, payload []byte) string {
|
||||
signer, err := jose.NewSigner(jose.SigningKey{
|
||||
Algorithm: k.alg,
|
||||
Key: k.key,
|
||||
}, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
jws, err := signer.Sign(payload)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, err := jws.CompactSerialize()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return data
|
||||
}
|
@ -1,16 +1,61 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
// "net/url"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// Providers contains all the implemented providers
|
||||
type Providers struct {
|
||||
Google Google `group:"Google Provider" namespace:"google" env-namespace:"GOOGLE"`
|
||||
OIDC OIDC `group:"OIDC Provider" namespace:"oidc" env-namespace:"OIDC"`
|
||||
}
|
||||
|
||||
type Token struct {
|
||||
// Provider is used to authenticate users
|
||||
type Provider interface {
|
||||
Name() string
|
||||
GetLoginURL(redirectURI, state string) string
|
||||
ExchangeCode(redirectURI, code string) (string, error)
|
||||
GetUser(token string) (User, error)
|
||||
Setup() error
|
||||
}
|
||||
|
||||
type token struct {
|
||||
Token string `json:"access_token"`
|
||||
}
|
||||
|
||||
// User is the authenticated user
|
||||
type User struct {
|
||||
Id string `json:"id"`
|
||||
ID string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Verified bool `json:"verified_email"`
|
||||
Hd string `json:"hd"`
|
||||
}
|
||||
|
||||
// OAuthProvider is a provider using the oauth2 library
|
||||
type OAuthProvider struct {
|
||||
Config *oauth2.Config
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// ConfigCopy returns a copy of the oauth2 config with the given redirectURI
|
||||
// which ensures the underlying config is not modified
|
||||
func (p *OAuthProvider) ConfigCopy(redirectURI string) oauth2.Config {
|
||||
config := *p.Config
|
||||
config.RedirectURL = redirectURI
|
||||
return config
|
||||
}
|
||||
|
||||
// OAuthGetLoginURL provides a base "GetLoginURL" for proiders using OAauth2
|
||||
func (p *OAuthProvider) OAuthGetLoginURL(redirectURI, state string) string {
|
||||
config := p.ConfigCopy(redirectURI)
|
||||
return config.AuthCodeURL(state)
|
||||
}
|
||||
|
||||
// OAuthExchangeCode provides a base "ExchangeCode" for proiders using OAauth2
|
||||
func (p *OAuthProvider) OAuthExchangeCode(redirectURI, code string) (*oauth2.Token, error) {
|
||||
config := p.ConfigCopy(redirectURI)
|
||||
return config.Exchange(p.ctx, code)
|
||||
}
|
||||
|
48
internal/provider/providers_test.go
Normal file
48
internal/provider/providers_test.go
Normal file
@ -0,0 +1,48 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Utilities
|
||||
|
||||
type OAuthServer struct {
|
||||
t *testing.T
|
||||
url *url.URL
|
||||
body map[string]string // method -> body
|
||||
}
|
||||
|
||||
func NewOAuthServer(t *testing.T, body map[string]string) (*httptest.Server, *url.URL) {
|
||||
handler := &OAuthServer{t: t, body: body}
|
||||
server := httptest.NewServer(handler)
|
||||
handler.url, _ = url.Parse(server.URL)
|
||||
return server, handler.url
|
||||
}
|
||||
|
||||
func (s *OAuthServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := ioutil.ReadAll(r.Body)
|
||||
// fmt.Println("Got request:", r.URL, r.Method, string(body))
|
||||
|
||||
if r.Method == "POST" && r.URL.Path == "/token" {
|
||||
if s.body["token"] != string(body) {
|
||||
s.t.Fatal("Unexpected request body, expected", s.body["token"], "got", string(body))
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
fmt.Fprintf(w, `{"access_token":"123456789"}`)
|
||||
} else if r.Method == "GET" && r.URL.Path == "/userinfo" {
|
||||
fmt.Fprint(w, `{
|
||||
"id":"1",
|
||||
"email":"example@example.com",
|
||||
"verified_email":true,
|
||||
"hd":"example.com"
|
||||
}`)
|
||||
} else {
|
||||
s.t.Fatal("Unrecognised request: ", r.Method, r.URL, string(body))
|
||||
}
|
||||
}
|
@ -6,6 +6,7 @@ import (
|
||||
|
||||
"github.com/containous/traefik/pkg/rules"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/thomseddon/traefik-forward-auth/internal/provider"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
@ -27,10 +28,11 @@ func (s *Server) buildRoutes() {
|
||||
|
||||
// Let's build a router
|
||||
for name, rule := range config.Rules {
|
||||
matchRule := rule.formattedRule()
|
||||
if rule.Action == "allow" {
|
||||
s.router.AddRoute(rule.formattedRule(), 1, s.AllowHandler(name))
|
||||
s.router.AddRoute(matchRule, 1, s.AllowHandler(name))
|
||||
} else {
|
||||
s.router.AddRoute(rule.formattedRule(), 1, s.AuthHandler(name))
|
||||
s.router.AddRoute(matchRule, 1, s.AuthHandler(rule.Provider, name))
|
||||
}
|
||||
}
|
||||
|
||||
@ -41,7 +43,7 @@ func (s *Server) buildRoutes() {
|
||||
if config.DefaultAction == "allow" {
|
||||
s.router.NewRoute().Handler(s.AllowHandler("default"))
|
||||
} else {
|
||||
s.router.NewRoute().Handler(s.AuthHandler("default"))
|
||||
s.router.NewRoute().Handler(s.AuthHandler(config.DefaultProvider, "default"))
|
||||
}
|
||||
}
|
||||
|
||||
@ -64,7 +66,9 @@ func (s *Server) AllowHandler(rule string) http.HandlerFunc {
|
||||
}
|
||||
|
||||
// Authenticate requests
|
||||
func (s *Server) AuthHandler(rule string) http.HandlerFunc {
|
||||
func (s *Server) AuthHandler(providerName, rule string) http.HandlerFunc {
|
||||
p, _ := config.GetConfiguredProvider(providerName)
|
||||
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// Logging setup
|
||||
logger := s.logger(r, rule, "Authenticating request")
|
||||
@ -72,7 +76,7 @@ func (s *Server) AuthHandler(rule string) http.HandlerFunc {
|
||||
// Get auth cookie
|
||||
c, err := r.Cookie(config.CookieName)
|
||||
if err != nil {
|
||||
s.authRedirect(logger, w, r)
|
||||
s.authRedirect(logger, w, r, p)
|
||||
return
|
||||
}
|
||||
|
||||
@ -81,7 +85,7 @@ func (s *Server) AuthHandler(rule string) http.HandlerFunc {
|
||||
if err != nil {
|
||||
if err.Error() == "Cookie has expired" {
|
||||
logger.Info("Cookie has expired")
|
||||
s.authRedirect(logger, w, r)
|
||||
s.authRedirect(logger, w, r, p)
|
||||
} else {
|
||||
logger.Errorf("Invalid cookie: %v", err)
|
||||
http.Error(w, "Not authorized", 401)
|
||||
@ -121,18 +125,26 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc {
|
||||
}
|
||||
|
||||
// Validate state
|
||||
valid, redirect, err := ValidateCSRFCookie(r, c)
|
||||
valid, providerName, redirect, err := ValidateCSRFCookie(r, c)
|
||||
if !valid {
|
||||
logger.Warnf("Error validating csrf cookie: %v", err)
|
||||
http.Error(w, "Not authorized", 401)
|
||||
return
|
||||
}
|
||||
|
||||
// Get provider
|
||||
p, err := config.GetConfiguredProvider(providerName)
|
||||
if err != nil {
|
||||
logger.Warnf("Invalid provider in csrf cookie: %s, %v", providerName, err)
|
||||
http.Error(w, "Not authorized", 401)
|
||||
return
|
||||
}
|
||||
|
||||
// Clear CSRF cookie
|
||||
http.SetCookie(w, ClearCSRFCookie(r))
|
||||
|
||||
// Exchange code for token
|
||||
token, err := ExchangeCode(r)
|
||||
token, err := p.ExchangeCode(redirectUri(r), r.URL.Query().Get("code"))
|
||||
if err != nil {
|
||||
logger.Errorf("Code exchange failed with: %v", err)
|
||||
http.Error(w, "Service unavailable", 503)
|
||||
@ -140,7 +152,7 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc {
|
||||
}
|
||||
|
||||
// Get user
|
||||
user, err := GetUser(token)
|
||||
user, err := p.GetUser(token)
|
||||
if err != nil {
|
||||
logger.Errorf("Error getting user: %s", err)
|
||||
return
|
||||
@ -157,7 +169,7 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) authRedirect(logger *logrus.Entry, w http.ResponseWriter, r *http.Request) {
|
||||
func (s *Server) authRedirect(logger *logrus.Entry, w http.ResponseWriter, r *http.Request, p provider.Provider) {
|
||||
// Error indicates no cookie, generate nonce
|
||||
err, nonce := Nonce()
|
||||
if err != nil {
|
||||
@ -171,7 +183,8 @@ func (s *Server) authRedirect(logger *logrus.Entry, w http.ResponseWriter, r *ht
|
||||
logger.Debug("Set CSRF cookie and redirecting to google login")
|
||||
|
||||
// Forward them on
|
||||
http.Redirect(w, r, GetLoginURL(r, nonce), http.StatusTemporaryRedirect)
|
||||
loginUrl := p.GetLoginURL(redirectUri(r), MakeState(r, p, nonce))
|
||||
http.Redirect(w, r, loginUrl, http.StatusTemporaryRedirect)
|
||||
|
||||
logger.Debug("Done")
|
||||
return
|
||||
|
@ -11,15 +11,16 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// TODO:
|
||||
|
||||
/**
|
||||
* Setup
|
||||
*/
|
||||
|
||||
func init() {
|
||||
config = newDefaultConfig()
|
||||
config.LogLevel = "panic"
|
||||
log = NewDefaultLogger()
|
||||
}
|
||||
@ -30,7 +31,7 @@ func init() {
|
||||
|
||||
func TestServerAuthHandlerInvalid(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
config, _ = NewConfig([]string{})
|
||||
config = newDefaultConfig()
|
||||
|
||||
// Should redirect vanilla request to login url
|
||||
req := newDefaultHttpRequest("/foo")
|
||||
@ -42,10 +43,20 @@ func TestServerAuthHandlerInvalid(t *testing.T) {
|
||||
assert.Equal("accounts.google.com", fwd.Host, "vanilla request should be redirected to google")
|
||||
assert.Equal("/o/oauth2/auth", fwd.Path, "vanilla request should be redirected to google")
|
||||
|
||||
// Check state string
|
||||
qs := fwd.Query()
|
||||
state, exists := qs["state"]
|
||||
require.True(t, exists)
|
||||
require.Len(t, state, 1)
|
||||
parts := strings.SplitN(state[0], ":", 3)
|
||||
require.Len(t, parts, 3)
|
||||
assert.Equal("google", parts[1])
|
||||
assert.Equal("http://example.com/foo", parts[2])
|
||||
|
||||
// Should catch invalid cookie
|
||||
req = newDefaultHttpRequest("/foo")
|
||||
c := MakeCookie(req, "test@example.com")
|
||||
parts := strings.Split(c.Value, "|")
|
||||
parts = strings.Split(c.Value, "|")
|
||||
c.Value = fmt.Sprintf("bad|%s|%s", parts[1], parts[2])
|
||||
|
||||
res, _ = doHttpRequest(req, c)
|
||||
@ -62,7 +73,7 @@ func TestServerAuthHandlerInvalid(t *testing.T) {
|
||||
|
||||
func TestServerAuthHandlerExpired(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
config, _ = NewConfig([]string{})
|
||||
config = newDefaultConfig()
|
||||
config.Lifetime = time.Second * time.Duration(-1)
|
||||
config.Domains = []string{"test.com"}
|
||||
|
||||
@ -90,7 +101,7 @@ func TestServerAuthHandlerExpired(t *testing.T) {
|
||||
|
||||
func TestServerAuthHandlerValid(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
config, _ = NewConfig([]string{})
|
||||
config = newDefaultConfig()
|
||||
|
||||
// Should allow valid request email
|
||||
req := newDefaultHttpRequest("/foo")
|
||||
@ -108,7 +119,7 @@ func TestServerAuthHandlerValid(t *testing.T) {
|
||||
|
||||
func TestServerAuthCallback(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
config, _ = NewConfig([]string{})
|
||||
config = newDefaultConfig()
|
||||
|
||||
// Setup token server
|
||||
tokenServerHandler := &TokenServerHandler{}
|
||||
@ -136,7 +147,7 @@ func TestServerAuthCallback(t *testing.T) {
|
||||
assert.Equal(401, res.StatusCode, "auth callback with invalid cookie shouldn't be authorised")
|
||||
|
||||
// Should redirect valid request
|
||||
req = newDefaultHttpRequest("/_oauth?state=12345678901234567890123456789012:http://redirect")
|
||||
req = newDefaultHttpRequest("/_oauth?state=12345678901234567890123456789012:google:http://redirect")
|
||||
c = MakeCSRFCookie(req, "12345678901234567890123456789012")
|
||||
res, _ = doHttpRequest(req, c)
|
||||
assert.Equal(307, res.StatusCode, "valid auth callback should be allowed")
|
||||
@ -149,7 +160,7 @@ func TestServerAuthCallback(t *testing.T) {
|
||||
|
||||
func TestServerDefaultAction(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
config, _ = NewConfig([]string{})
|
||||
config = newDefaultConfig()
|
||||
|
||||
req := newDefaultHttpRequest("/random")
|
||||
res, _ := doHttpRequest(req, nil)
|
||||
@ -161,9 +172,36 @@ func TestServerDefaultAction(t *testing.T) {
|
||||
assert.Equal(200, res.StatusCode, "request should be allowed with default handler")
|
||||
}
|
||||
|
||||
func TestServerDefaultProvider(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
config = newDefaultConfig()
|
||||
|
||||
// Should use "google" as default provider when not specified
|
||||
req := newDefaultHttpRequest("/random")
|
||||
res, _ := doHttpRequest(req, nil)
|
||||
fwd, _ := res.Location()
|
||||
assert.Equal("https", fwd.Scheme, "request with expired cookie should be redirected to google")
|
||||
assert.Equal("accounts.google.com", fwd.Host, "request with expired cookie should be redirected to google")
|
||||
assert.Equal("/o/oauth2/auth", fwd.Path, "request with expired cookie should be redirected to google")
|
||||
|
||||
// Should use alternative default provider when set
|
||||
config.DefaultProvider = "oidc"
|
||||
config.Providers.OIDC.OAuthProvider.Config = &oauth2.Config{
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: "https://oidc.com/oidcauth",
|
||||
},
|
||||
}
|
||||
|
||||
res, _ = doHttpRequest(req, nil)
|
||||
fwd, _ = res.Location()
|
||||
assert.Equal("https", fwd.Scheme, "request with expired cookie should be redirected to oidc")
|
||||
assert.Equal("oidc.com", fwd.Host, "request with expired cookie should be redirected to oidc")
|
||||
assert.Equal("/oidcauth", fwd.Path, "request with expired cookie should be redirected to oidc")
|
||||
}
|
||||
|
||||
func TestServerRouteHeaders(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
config, _ = NewConfig([]string{})
|
||||
config = newDefaultConfig()
|
||||
config.Rules = map[string]*Rule{
|
||||
"1": {
|
||||
Action: "allow",
|
||||
@ -196,7 +234,7 @@ func TestServerRouteHeaders(t *testing.T) {
|
||||
|
||||
func TestServerRouteHost(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
config, _ = NewConfig([]string{})
|
||||
config = newDefaultConfig()
|
||||
config.Rules = map[string]*Rule{
|
||||
"1": {
|
||||
Action: "allow",
|
||||
@ -226,7 +264,7 @@ func TestServerRouteHost(t *testing.T) {
|
||||
|
||||
func TestServerRouteMethod(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
config, _ = NewConfig([]string{})
|
||||
config = newDefaultConfig()
|
||||
config.Rules = map[string]*Rule{
|
||||
"1": {
|
||||
Action: "allow",
|
||||
@ -247,7 +285,7 @@ func TestServerRouteMethod(t *testing.T) {
|
||||
|
||||
func TestServerRoutePath(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
config, _ = NewConfig([]string{})
|
||||
config = newDefaultConfig()
|
||||
config.Rules = map[string]*Rule{
|
||||
"1": {
|
||||
Action: "allow",
|
||||
@ -281,7 +319,7 @@ func TestServerRoutePath(t *testing.T) {
|
||||
|
||||
func TestServerRouteQuery(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
config, _ = NewConfig([]string{})
|
||||
config = newDefaultConfig()
|
||||
config.Rules = map[string]*Rule{
|
||||
"1": {
|
||||
Action: "allow",
|
||||
@ -346,6 +384,18 @@ func doHttpRequest(r *http.Request, c *http.Cookie) (*http.Response, string) {
|
||||
return res, string(body)
|
||||
}
|
||||
|
||||
func newDefaultConfig() *Config {
|
||||
config, _ = NewConfig([]string{
|
||||
"--providers.google.client-id=id",
|
||||
"--providers.google.client-secret=secret",
|
||||
})
|
||||
|
||||
// Setup the google providers without running all the config validation
|
||||
config.Providers.Google.Setup()
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
func newDefaultHttpRequest(uri string) *http.Request {
|
||||
return newHttpRequest("", "http://example.com/", uri)
|
||||
}
|
||||
@ -354,25 +404,8 @@ func newHttpRequest(method, dest, uri string) *http.Request {
|
||||
r := httptest.NewRequest("", "http://should-use-x-forwarded.com", nil)
|
||||
p, _ := url.Parse(dest)
|
||||
r.Header.Add("X-Forwarded-Method", method)
|
||||
r.Header.Add("X-Forwarded-Proto", p.Scheme)
|
||||
r.Header.Add("X-Forwarded-Host", p.Host)
|
||||
r.Header.Add("X-Forwarded-Uri", uri)
|
||||
return r
|
||||
}
|
||||
|
||||
func qsDiff(t *testing.T, one, two url.Values) []string {
|
||||
errs := make([]string, 0)
|
||||
for k := range one {
|
||||
if two.Get(k) == "" {
|
||||
errs = append(errs, fmt.Sprintf("Key missing: %s", k))
|
||||
}
|
||||
if one.Get(k) != two.Get(k) {
|
||||
errs = append(errs, fmt.Sprintf("Value different for %s: expected: '%s' got: '%s'", k, one.Get(k), two.Get(k)))
|
||||
}
|
||||
}
|
||||
for k := range two {
|
||||
if one.Get(k) == "" {
|
||||
errs = append(errs, fmt.Sprintf("Extra key: %s", k))
|
||||
}
|
||||
}
|
||||
return errs
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user