Multiple provider support + OIDC provider

This commit is contained in:
Thom Seddon 2019-09-18 17:55:52 +01:00
parent 5dfd4f2878
commit c9289d6fc1
16 changed files with 1043 additions and 278 deletions

View File

@ -1,5 +1,5 @@
format: format:
gofmt -w -s internal/*.go cmd/*.go gofmt -w -s internal/*.go internal/provider/*.go cmd/*.go
.PHONY: format .PHONY: format

5
go.mod
View File

@ -9,6 +9,7 @@ require (
github.com/containous/flaeg v1.4.1 // indirect github.com/containous/flaeg v1.4.1 // indirect
github.com/containous/mux v0.0.0-20181024131434-c33f32e26898 // indirect github.com/containous/mux v0.0.0-20181024131434-c33f32e26898 // indirect
github.com/containous/traefik v2.0.0-alpha2+incompatible github.com/containous/traefik v2.0.0-alpha2+incompatible
github.com/coreos/go-oidc v2.1.0+incompatible
github.com/go-acme/lego v2.5.0+incompatible // indirect github.com/go-acme/lego v2.5.0+incompatible // indirect
github.com/go-kit/kit v0.8.0 // indirect github.com/go-kit/kit v0.8.0 // indirect
github.com/gorilla/context v1.1.1 // indirect github.com/gorilla/context v1.1.1 // indirect
@ -21,6 +22,7 @@ require (
github.com/miekg/dns v1.1.8 // indirect github.com/miekg/dns v1.1.8 // indirect
github.com/patrickmn/go-cache v2.1.0+incompatible // indirect github.com/patrickmn/go-cache v2.1.0+incompatible // indirect
github.com/pkg/errors v0.8.1 // indirect github.com/pkg/errors v0.8.1 // indirect
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect
github.com/ryanuber/go-glob v1.0.0 // indirect github.com/ryanuber/go-glob v1.0.0 // indirect
github.com/sirupsen/logrus v1.4.1 github.com/sirupsen/logrus v1.4.1
github.com/stretchr/objx v0.2.0 // indirect github.com/stretchr/objx v0.2.0 // indirect
@ -29,8 +31,9 @@ require (
github.com/vulcand/predicate v1.1.0 // indirect github.com/vulcand/predicate v1.1.0 // indirect
golang.org/x/crypto v0.0.0-20190422183909-d864b10871cd // indirect golang.org/x/crypto v0.0.0-20190422183909-d864b10871cd // indirect
golang.org/x/net v0.0.0-20190420063019-afa5a82059c6 // indirect golang.org/x/net v0.0.0-20190420063019-afa5a82059c6 // indirect
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45
golang.org/x/sync v0.0.0-20190423024810-112230192c58 // indirect golang.org/x/sync v0.0.0-20190423024810-112230192c58 // indirect
golang.org/x/sys v0.0.0-20190422165155-953cdadca894 // indirect golang.org/x/sys v0.0.0-20190422165155-953cdadca894 // indirect
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
gopkg.in/square/go-jose.v2 v2.3.1 // indirect gopkg.in/square/go-jose.v2 v2.3.1
) )

13
go.sum
View File

@ -1,3 +1,4 @@
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
github.com/VividCortex/gohistogram v1.0.0 h1:6+hBz+qvs0JOrrNhhmR7lFxo5sINxBCGXrdtl/UvroE= github.com/VividCortex/gohistogram v1.0.0 h1:6+hBz+qvs0JOrrNhhmR7lFxo5sINxBCGXrdtl/UvroE=
github.com/VividCortex/gohistogram v1.0.0/go.mod h1:Pf5mBqqDxYaXu3hDrrU+w6nw50o/4+TcAqDqk/vUH7g= github.com/VividCortex/gohistogram v1.0.0/go.mod h1:Pf5mBqqDxYaXu3hDrrU+w6nw50o/4+TcAqDqk/vUH7g=
github.com/cenkalti/backoff v2.1.1+incompatible h1:tKJnvO2kl0zmb/jA5UKAt4VoEVw1qxKWjE/Bpp46npY= github.com/cenkalti/backoff v2.1.1+incompatible h1:tKJnvO2kl0zmb/jA5UKAt4VoEVw1qxKWjE/Bpp46npY=
@ -10,14 +11,18 @@ github.com/containous/mux v0.0.0-20181024131434-c33f32e26898 h1:1srn9voikJGofblB
github.com/containous/mux v0.0.0-20181024131434-c33f32e26898/go.mod h1:z8WW7n06n8/1xF9Jl9WmuDeZuHAhfL+bwarNjsciwwg= github.com/containous/mux v0.0.0-20181024131434-c33f32e26898/go.mod h1:z8WW7n06n8/1xF9Jl9WmuDeZuHAhfL+bwarNjsciwwg=
github.com/containous/traefik v2.0.0-alpha2+incompatible h1:5RS6mUAOPQCy1jAmcmxLj2nChIcs3fKuxZxH9AF6ih8= github.com/containous/traefik v2.0.0-alpha2+incompatible h1:5RS6mUAOPQCy1jAmcmxLj2nChIcs3fKuxZxH9AF6ih8=
github.com/containous/traefik v2.0.0-alpha2+incompatible/go.mod h1:epDRqge3JzKOhlSWzOpNYEEKXmM6yfN5tPzDGKk3ljo= github.com/containous/traefik v2.0.0-alpha2+incompatible/go.mod h1:epDRqge3JzKOhlSWzOpNYEEKXmM6yfN5tPzDGKk3ljo=
github.com/coreos/go-oidc v2.1.0+incompatible h1:sdJrfw8akMnCuUlaZU3tE/uYXFgfqom8DBE9so9EBsM=
github.com/coreos/go-oidc v2.1.0+incompatible/go.mod h1:CgnwVTmzoESiwO9qyAFEMiHoZ1nMCKZlZ9V6mm3/LKc=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-acme/lego v2.4.0+incompatible h1:+BTLUfLtDc5qQauyiTCXH6lupEUOCvXyGlEjdeU0YQI= github.com/go-acme/lego v2.4.0+incompatible h1:+BTLUfLtDc5qQauyiTCXH6lupEUOCvXyGlEjdeU0YQI=
github.com/go-acme/lego v2.4.0+incompatible/go.mod h1:yzMNe9CasVUhkquNvti5nAtPmG94USbYxYrZfTkIn0M= github.com/go-acme/lego v2.4.0+incompatible/go.mod h1:yzMNe9CasVUhkquNvti5nAtPmG94USbYxYrZfTkIn0M=
github.com/go-acme/lego v2.5.0+incompatible h1:5fNN9yRQfv8ymH3DSsxla+4aYeQt2IgfZqHKVnK8f0s=
github.com/go-acme/lego v2.5.0+incompatible/go.mod h1:yzMNe9CasVUhkquNvti5nAtPmG94USbYxYrZfTkIn0M= github.com/go-acme/lego v2.5.0+incompatible/go.mod h1:yzMNe9CasVUhkquNvti5nAtPmG94USbYxYrZfTkIn0M=
github.com/go-kit/kit v0.8.0 h1:Wz+5lgoB0kkuqLEc6NVmwRknTKP6dTGbSqvhZtBI/j0= github.com/go-kit/kit v0.8.0 h1:Wz+5lgoB0kkuqLEc6NVmwRknTKP6dTGbSqvhZtBI/j0=
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8= github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8=
github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
github.com/gravitational/trace v0.0.0-20190409171327-f30095ced5ff h1:xL/fJdlTJL6R/6Qk2tPu3EP1NsXgap9hXLvxKH0Ytko= github.com/gravitational/trace v0.0.0-20190409171327-f30095ced5ff h1:xL/fJdlTJL6R/6Qk2tPu3EP1NsXgap9hXLvxKH0Ytko=
@ -45,6 +50,8 @@ github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 h1:J9b7z+QKAmPf4YLrFg6oQUotqHQeUNWwkvo7jZp1GLU=
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35/go.mod h1:prYjPmNq4d1NPVmpShWobRqXY3q7Vp+80DqgxxUrUIA=
github.com/ryanuber/go-glob v1.0.0 h1:iQh3xXAumdQ+4Ufa5b25cRpC5TYKlno6hsv6Cb3pkBk= github.com/ryanuber/go-glob v1.0.0 h1:iQh3xXAumdQ+4Ufa5b25cRpC5TYKlno6hsv6Cb3pkBk=
github.com/ryanuber/go-glob v1.0.0/go.mod h1:807d1WSdnB0XRJzKNil9Om6lcp/3a0v4qIHxIXzX/Yc= github.com/ryanuber/go-glob v1.0.0/go.mod h1:807d1WSdnB0XRJzKNil9Om6lcp/3a0v4qIHxIXzX/Yc=
github.com/sirupsen/logrus v1.4.1 h1:GL2rEmy6nsikmW0r8opw9JIRScdMF5hA8cOYLH7In1k= github.com/sirupsen/logrus v1.4.1 h1:GL2rEmy6nsikmW0r8opw9JIRScdMF5hA8cOYLH7In1k=
@ -68,12 +75,17 @@ golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a h1:Igim7XhdOpBnWPuYJ70XcN
golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE=
golang.org/x/crypto v0.0.0-20190422183909-d864b10871cd h1:sMHc2rZHuzQmrbVoSpt9HgerkXPyIeCSO6k0zUMGfFk= golang.org/x/crypto v0.0.0-20190422183909-d864b10871cd h1:sMHc2rZHuzQmrbVoSpt9HgerkXPyIeCSO6k0zUMGfFk=
golang.org/x/crypto v0.0.0-20190422183909-d864b10871cd/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190422183909-d864b10871cd/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 h1:0GoQqolDA55aaLxZyTzK/Y2ePZzZTUrRacwib7cNsYQ= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 h1:0GoQqolDA55aaLxZyTzK/Y2ePZzZTUrRacwib7cNsYQ=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190420063019-afa5a82059c6 h1:HdqqaWmYAUI7/dmByKKEw+yxDksGSo+9GjkUc9Zp34E= golang.org/x/net v0.0.0-20190420063019-afa5a82059c6 h1:HdqqaWmYAUI7/dmByKKEw+yxDksGSo+9GjkUc9Zp34E=
golang.org/x/net v0.0.0-20190420063019-afa5a82059c6/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190420063019-afa5a82059c6/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45 h1:SVwTIAaPC2U/AvvLNZ2a7OVsmBpC8L5BlwK1whH3hm0=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f h1:wMNYb4v58l5UBM7MYRLPG6ZhfOqbKu7X5eyFl8ZhKvA= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f h1:wMNYb4v58l5UBM7MYRLPG6ZhfOqbKu7X5eyFl8ZhKvA=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33 h1:I6FyU15t786LL7oL/hn43zqTuEGr4PN7F4XJ1p4E3Y8= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33 h1:I6FyU15t786LL7oL/hn43zqTuEGr4PN7F4XJ1p4E3Y8=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@ -84,6 +96,7 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20190422165155-953cdadca894 h1:Cz4ceDQGXuKRnVBDTS23GTn/pU5OE2C0WrNTOYK1Uuc= golang.org/x/sys v0.0.0-20190422165155-953cdadca894 h1:Cz4ceDQGXuKRnVBDTS23GTn/pU5OE2C0WrNTOYK1Uuc=
golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/square/go-jose.v2 v2.3.1 h1:SK5KegNXmKmqE342YYN2qPHEnUYeoMiXXl1poUlI+o4= gopkg.in/square/go-jose.v2 v2.3.1 h1:SK5KegNXmKmqE342YYN2qPHEnUYeoMiXXl1poUlI+o4=

View File

@ -81,32 +81,6 @@ func ValidateEmail(email string) bool {
return found return found
} }
// OAuth Methods
// Get login url
func GetLoginURL(r *http.Request, nonce string) string {
state := fmt.Sprintf("%s:%s", nonce, returnUrl(r))
// TODO: Support multiple providers
return config.Providers.Google.GetLoginURL(redirectUri(r), state)
}
// Exchange code for token
func ExchangeCode(r *http.Request) (string, error) {
code := r.URL.Query().Get("code")
// TODO: Support multiple providers
return config.Providers.Google.ExchangeCode(redirectUri(r), code)
}
// Get user with token
func GetUser(token string) (provider.User, error) {
// TODO: Support multiple providers
return config.Providers.Google.GetUser(token)
}
// Utility methods // Utility methods
// Get the redirect base // Get the redirect base
@ -117,7 +91,7 @@ func redirectBase(r *http.Request) string {
return fmt.Sprintf("%s://%s", proto, host) return fmt.Sprintf("%s://%s", proto, host)
} }
// // Return url // Return url
func returnUrl(r *http.Request) string { func returnUrl(r *http.Request) string {
path := r.Header.Get("X-Forwarded-Uri") path := r.Header.Get("X-Forwarded-Uri")
@ -196,24 +170,35 @@ func ClearCSRFCookie(r *http.Request) *http.Cookie {
} }
// Validate the csrf cookie against state // Validate the csrf cookie against state
func ValidateCSRFCookie(r *http.Request, c *http.Cookie) (bool, string, error) { func ValidateCSRFCookie(r *http.Request, c *http.Cookie) (valid bool, provider string, redirect string, err error) {
state := r.URL.Query().Get("state") state := r.URL.Query().Get("state")
if len(c.Value) != 32 { if len(c.Value) != 32 {
return false, "", errors.New("Invalid CSRF cookie value") return false, "", "", errors.New("Invalid CSRF cookie value")
} }
if len(state) < 34 { if len(state) < 34 {
return false, "", errors.New("Invalid CSRF state value") return false, "", "", errors.New("Invalid CSRF state value")
} }
// Check nonce match // Check nonce match
if c.Value != state[:32] { if c.Value != state[:32] {
return false, "", errors.New("CSRF cookie does not match state") return false, "", "", errors.New("CSRF cookie does not match state")
} }
// Valid, return redirect // Extract provider
return true, state[33:], nil params := state[33:]
split := strings.Index(params, ":")
if split == -1 {
return false, "", "", errors.New("Invalid CSRF state format")
}
// Valid, return provider and redirect
return true, params[:split], params[split+1:], nil
}
func MakeState(r *http.Request, p provider.Provider, nonce string) string {
return fmt.Sprintf("%s:%s:%s", nonce, p.Name(), returnUrl(r))
} }
func Nonce() (error, string) { func Nonce() (error, string) {
@ -282,10 +267,10 @@ func cookieExpiry() time.Time {
// Cookie Domain // Cookie Domain
type CookieDomain struct { type CookieDomain struct {
Domain string `description:"TEST1"` Domain string
DomainLen int `description:"TEST2"` DomainLen int
SubDomain string `description:"TEST3"` SubDomain string
SubDomainLen int `description:"TEST4"` SubDomainLen int
} }
func NewCookieDomain(domain string) *CookieDomain { func NewCookieDomain(domain string) *CookieDomain {

View File

@ -95,138 +95,69 @@ func TestAuthValidateEmail(t *testing.T) {
assert.True(v, "should allow user in whitelist") assert.True(v, "should allow user in whitelist")
} }
// TODO: Split google tests out func TestRedirectUri(t *testing.T) {
func TestAuthGetLoginURL(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
google := provider.Google{
ClientId: "idtest",
ClientSecret: "sectest",
Scope: "scopetest",
Prompt: "consent select_account",
LoginURL: &url.URL{
Scheme: "https",
Host: "test.com",
Path: "/auth",
},
}
config, _ = NewConfig([]string{})
config.Providers.Google = google
r, _ := http.NewRequest("GET", "http://example.com", nil) r, _ := http.NewRequest("GET", "http://example.com", nil)
r.Header.Add("X-Forwarded-Proto", "http") r.Header.Add("X-Forwarded-Proto", "http")
r.Header.Add("X-Forwarded-Host", "example.com") r.Header.Add("X-Forwarded-Host", "app.example.com")
r.Header.Add("X-Forwarded-Uri", "/hello") r.Header.Add("X-Forwarded-Uri", "/hello")
// Check url //
uri, err := url.Parse(GetLoginURL(r, "nonce")) // No Auth Host
assert.Nil(err) //
assert.Equal("https", uri.Scheme) config, _ = NewConfig([]string{})
assert.Equal("test.com", uri.Host)
assert.Equal("/auth", uri.Path)
// Check query string uri, err := url.Parse(redirectUri(r))
qs := uri.Query() assert.Nil(err)
expectedQs := url.Values{ assert.Equal("http", uri.Scheme)
"client_id": []string{"idtest"}, assert.Equal("app.example.com", uri.Host)
"redirect_uri": []string{"http://example.com/_oauth"}, assert.Equal("/_oauth", uri.Path)
"response_type": []string{"code"},
"scope": []string{"scopetest"},
"prompt": []string{"consent select_account"},
"state": []string{"nonce:http://example.com/hello"},
}
assert.Equal(expectedQs, qs)
// //
// With Auth URL but no matching cookie domain // With Auth URL but no matching cookie domain
// - will not use auth host // - will not use auth host
// //
config, _ = NewConfig([]string{})
config.AuthHost = "auth.example.com" config.AuthHost = "auth.example.com"
config.Providers.Google = google
// Check url uri, err = url.Parse(redirectUri(r))
uri, err = url.Parse(GetLoginURL(r, "nonce"))
assert.Nil(err) assert.Nil(err)
assert.Equal("https", uri.Scheme) assert.Equal("http", uri.Scheme)
assert.Equal("test.com", uri.Host) assert.Equal("app.example.com", uri.Host)
assert.Equal("/auth", uri.Path) assert.Equal("/_oauth", uri.Path)
// Check query string
qs = uri.Query()
expectedQs = url.Values{
"client_id": []string{"idtest"},
"redirect_uri": []string{"http://example.com/_oauth"},
"response_type": []string{"code"},
"scope": []string{"scopetest"},
"prompt": []string{"consent select_account"},
"state": []string{"nonce:http://example.com/hello"},
}
assert.Equal(expectedQs, qs)
// //
// With correct Auth URL + cookie domain // With correct Auth URL + cookie domain
// //
config, _ = NewConfig([]string{})
config.AuthHost = "auth.example.com" config.AuthHost = "auth.example.com"
config.CookieDomains = []CookieDomain{*NewCookieDomain("example.com")} config.CookieDomains = []CookieDomain{*NewCookieDomain("example.com")}
config.Providers.Google = google
// Check url // Check url
uri, err = url.Parse(GetLoginURL(r, "nonce")) uri, err = url.Parse(redirectUri(r))
assert.Nil(err) assert.Nil(err)
assert.Equal("https", uri.Scheme) assert.Equal("http", uri.Scheme)
assert.Equal("test.com", uri.Host) assert.Equal("auth.example.com", uri.Host)
assert.Equal("/auth", uri.Path) assert.Equal("/_oauth", uri.Path)
// Check query string
qs = uri.Query()
expectedQs = url.Values{
"client_id": []string{"idtest"},
"redirect_uri": []string{"http://auth.example.com/_oauth"},
"response_type": []string{"code"},
"scope": []string{"scopetest"},
"state": []string{"nonce:http://example.com/hello"},
"prompt": []string{"consent select_account"},
}
assert.Equal(expectedQs, qs)
// //
// With Auth URL + cookie domain, but from different domain // With Auth URL + cookie domain, but from different domain
// - will not use auth host // - will not use auth host
// //
r, _ = http.NewRequest("GET", "http://another.com", nil) r, _ = http.NewRequest("GET", "http://another.com", nil)
r.Header.Add("X-Forwarded-Proto", "http") r.Header.Add("X-Forwarded-Proto", "https")
r.Header.Add("X-Forwarded-Host", "another.com") r.Header.Add("X-Forwarded-Host", "another.com")
r.Header.Add("X-Forwarded-Uri", "/hello") r.Header.Add("X-Forwarded-Uri", "/hello")
config.AuthHost = "auth.example.com"
config.CookieDomains = []CookieDomain{*NewCookieDomain("example.com")}
// Check url // Check url
uri, err = url.Parse(GetLoginURL(r, "nonce")) uri, err = url.Parse(redirectUri(r))
assert.Nil(err) assert.Nil(err)
assert.Equal("https", uri.Scheme) assert.Equal("https", uri.Scheme)
assert.Equal("test.com", uri.Host) assert.Equal("another.com", uri.Host)
assert.Equal("/auth", uri.Path) assert.Equal("/_oauth", uri.Path)
// Check query string
qs = uri.Query()
expectedQs = url.Values{
"client_id": []string{"idtest"},
"redirect_uri": []string{"http://another.com/_oauth"},
"response_type": []string{"code"},
"scope": []string{"scopetest"},
"state": []string{"nonce:http://another.com/hello"},
"prompt": []string{"consent select_account"},
} }
assert.Equal(expectedQs, qs)
}
// TODO
// func TestAuthExchangeCode(t *testing.T) {
// }
// TODO
// func TestAuthGetUser(t *testing.T) {
// }
func TestAuthMakeCookie(t *testing.T) { func TestAuthMakeCookie(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
@ -265,14 +196,14 @@ func TestAuthMakeCSRFCookie(t *testing.T) {
assert.Equal("app.example.com", c.Domain) assert.Equal("app.example.com", c.Domain)
// With cookie domain but no auth url // With cookie domain but no auth url
config = Config{ config = &Config{
CookieDomains: []CookieDomain{*NewCookieDomain("example.com")}, CookieDomains: []CookieDomain{*NewCookieDomain("example.com")},
} }
c = MakeCSRFCookie(r, "12345678901234567890123456789012") c = MakeCSRFCookie(r, "12345678901234567890123456789012")
assert.Equal("app.example.com", c.Domain) assert.Equal("app.example.com", c.Domain)
// With cookie domain and auth url // With cookie domain and auth url
config = Config{ config = &Config{
AuthHost: "auth.example.com", AuthHost: "auth.example.com",
CookieDomains: []CookieDomain{*NewCookieDomain("example.com")}, CookieDomains: []CookieDomain{*NewCookieDomain("example.com")},
} }
@ -304,13 +235,13 @@ func TestAuthValidateCSRFCookie(t *testing.T) {
// Should require 32 char string // Should require 32 char string
r := newCsrfRequest("") r := newCsrfRequest("")
c.Value = "" c.Value = ""
valid, _, err := ValidateCSRFCookie(r, c) valid, _, _, err := ValidateCSRFCookie(r, c)
assert.False(valid) assert.False(valid)
if assert.Error(err) { if assert.Error(err) {
assert.Equal("Invalid CSRF cookie value", err.Error()) assert.Equal("Invalid CSRF cookie value", err.Error())
} }
c.Value = "123456789012345678901234567890123" c.Value = "123456789012345678901234567890123"
valid, _, err = ValidateCSRFCookie(r, c) valid, _, _, err = ValidateCSRFCookie(r, c)
assert.False(valid) assert.False(valid)
if assert.Error(err) { if assert.Error(err) {
assert.Equal("Invalid CSRF cookie value", err.Error()) assert.Equal("Invalid CSRF cookie value", err.Error())
@ -319,19 +250,48 @@ func TestAuthValidateCSRFCookie(t *testing.T) {
// Should require valid state // Should require valid state
r = newCsrfRequest("12345678901234567890123456789012:") r = newCsrfRequest("12345678901234567890123456789012:")
c.Value = "12345678901234567890123456789012" c.Value = "12345678901234567890123456789012"
valid, _, err = ValidateCSRFCookie(r, c) valid, _, _, err = ValidateCSRFCookie(r, c)
assert.False(valid) assert.False(valid)
if assert.Error(err) { if assert.Error(err) {
assert.Equal("Invalid CSRF state value", err.Error()) assert.Equal("Invalid CSRF state value", err.Error())
} }
// Should allow valid state // Should require provider
r = newCsrfRequest("12345678901234567890123456789012:99") r = newCsrfRequest("12345678901234567890123456789012:99")
c.Value = "12345678901234567890123456789012" c.Value = "12345678901234567890123456789012"
valid, state, err := ValidateCSRFCookie(r, c) valid, _, _, err = ValidateCSRFCookie(r, c)
assert.False(valid)
if assert.Error(err) {
assert.Equal("Invalid CSRF state format", err.Error())
}
// Should allow valid state
r = newCsrfRequest("12345678901234567890123456789012:p99:url123")
c.Value = "12345678901234567890123456789012"
valid, provider, redirect, err := ValidateCSRFCookie(r, c)
assert.True(valid, "valid request should return valid") assert.True(valid, "valid request should return valid")
assert.Nil(err, "valid request should not return an error") assert.Nil(err, "valid request should not return an error")
assert.Equal("99", state, "valid request should return correct state") assert.Equal("p99", provider, "valid request should return correct provider")
assert.Equal("url123", redirect, "valid request should return correct redirect")
}
func TestMakeState(t *testing.T) {
assert := assert.New(t)
r, _ := http.NewRequest("GET", "http://example.com", nil)
r.Header.Add("X-Forwarded-Proto", "http")
r.Header.Add("X-Forwarded-Host", "example.com")
r.Header.Add("X-Forwarded-Uri", "/hello")
// Test with google
p := provider.Google{}
state := MakeState(r, &p, "nonce")
assert.Equal("nonce:google:http://example.com/hello", state)
// Test with OIDC
p2 := provider.OIDC{}
state = MakeState(r, &p2, "nonce")
assert.Equal("nonce:oidc:http://example.com/hello", state)
} }
func TestAuthNonce(t *testing.T) { func TestAuthNonce(t *testing.T) {
@ -356,6 +316,8 @@ func TestAuthCookieDomainMatch(t *testing.T) {
// Subdomain should match // Subdomain should match
assert.True(cd.Match("test.example.com"), "subdomain should match") assert.True(cd.Match("test.example.com"), "subdomain should match")
assert.True(cd.Match("twolevels.test.example.com"), "subdomain should match")
assert.True(cd.Match("many.many.levels.test.example.com"), "subdomain should match")
// Derived domain should not match // Derived domain should not match
assert.False(cd.Match("testexample.com"), "derived domain should not match") assert.False(cd.Match("testexample.com"), "derived domain should not match")

View File

@ -7,7 +7,6 @@ import (
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"net/url"
"os" "os"
"regexp" "regexp"
"strconv" "strconv"
@ -18,7 +17,7 @@ import (
"github.com/thomseddon/traefik-forward-auth/internal/provider" "github.com/thomseddon/traefik-forward-auth/internal/provider"
) )
var config Config var config *Config
type Config struct { type Config struct {
LogLevel string `long:"log-level" env:"LOG_LEVEL" default:"warn" choice:"trace" choice:"debug" choice:"info" choice:"warn" choice:"error" choice:"fatal" choice:"panic" description:"Log level"` LogLevel string `long:"log-level" env:"LOG_LEVEL" default:"warn" choice:"trace" choice:"debug" choice:"info" choice:"warn" choice:"error" choice:"fatal" choice:"panic" description:"Log level"`
@ -31,6 +30,7 @@ type Config struct {
CookieName string `long:"cookie-name" env:"COOKIE_NAME" default:"_forward_auth" description:"Cookie Name"` CookieName string `long:"cookie-name" env:"COOKIE_NAME" default:"_forward_auth" description:"Cookie Name"`
CSRFCookieName string `long:"csrf-cookie-name" env:"CSRF_COOKIE_NAME" default:"_forward_auth_csrf" description:"CSRF Cookie Name"` CSRFCookieName string `long:"csrf-cookie-name" env:"CSRF_COOKIE_NAME" default:"_forward_auth_csrf" description:"CSRF Cookie Name"`
DefaultAction string `long:"default-action" env:"DEFAULT_ACTION" default:"auth" choice:"auth" choice:"allow" description:"Default action"` DefaultAction string `long:"default-action" env:"DEFAULT_ACTION" default:"auth" choice:"auth" choice:"allow" description:"Default action"`
DefaultProvider string `long:"default-provider" env:"DEFAULT_PROVIDER" default:"google" choice:"google" choice:"oidc" description:"Default provider"`
Domains CommaSeparatedList `long:"domain" env:"DOMAIN" description:"Only allow given email domains, can be set multiple times"` Domains CommaSeparatedList `long:"domain" env:"DOMAIN" description:"Only allow given email domains, can be set multiple times"`
LifetimeString int `long:"lifetime" env:"LIFETIME" default:"43200" description:"Lifetime in seconds"` LifetimeString int `long:"lifetime" env:"LIFETIME" default:"43200" description:"Lifetime in seconds"`
Path string `long:"url-path" env:"URL_PATH" default:"/_oauth" description:"Callback URL Path"` Path string `long:"url-path" env:"URL_PATH" default:"/_oauth" description:"Callback URL Path"`
@ -53,7 +53,7 @@ type Config struct {
PromptLegacy string `long:"prompt" env:"PROMPT" description:"DEPRECATED - Use \"providers.google.prompt\""` PromptLegacy string `long:"prompt" env:"PROMPT" description:"DEPRECATED - Use \"providers.google.prompt\""`
} }
func NewGlobalConfig() Config { func NewGlobalConfig() *Config {
var err error var err error
config, err = NewConfig(os.Args[1:]) config, err = NewConfig(os.Args[1:])
if err != nil { if err != nil {
@ -64,29 +64,11 @@ func NewGlobalConfig() Config {
return config return config
} }
func NewConfig(args []string) (Config, error) { // TODO: move config parsing into new func "NewParsedConfig"
c := Config{
func NewConfig(args []string) (*Config, error) {
c := &Config{
Rules: map[string]*Rule{}, Rules: map[string]*Rule{},
Providers: provider.Providers{
Google: provider.Google{
Scope: "https://www.googleapis.com/auth/userinfo.profile https://www.googleapis.com/auth/userinfo.email",
LoginURL: &url.URL{
Scheme: "https",
Host: "accounts.google.com",
Path: "/o/oauth2/auth",
},
TokenURL: &url.URL{
Scheme: "https",
Host: "www.googleapis.com",
Path: "/oauth2/v3/token",
},
UserURL: &url.URL{
Scheme: "https",
Host: "www.googleapis.com",
Path: "/oauth2/v2/userinfo",
},
},
},
} }
err := c.parseFlags(args) err := c.parseFlags(args)
@ -97,13 +79,23 @@ func NewConfig(args []string) (Config, error) {
// TODO: as log flags have now been parsed maybe we should return here so // TODO: as log flags have now been parsed maybe we should return here so
// any further errors can be logged via logrus instead of printed? // any further errors can be logged via logrus instead of printed?
// TODO: Rename "Validate" method to "Setup" and move all below logic
// Setup
// Set default provider on any rules where it's not specified
for _, rule := range c.Rules {
if rule.Provider == "" {
rule.Provider = c.DefaultProvider
}
}
// Backwards compatability // Backwards compatability
if c.CookieSecretLegacy != "" && c.SecretString == "" { if c.CookieSecretLegacy != "" && c.SecretString == "" {
fmt.Println("cookie-secret config option is deprecated, please use secret") fmt.Println("cookie-secret config option is deprecated, please use secret")
c.SecretString = c.CookieSecretLegacy c.SecretString = c.CookieSecretLegacy
} }
if c.ClientIdLegacy != "" { if c.ClientIdLegacy != "" {
c.Providers.Google.ClientId = c.ClientIdLegacy c.Providers.Google.ClientID = c.ClientIdLegacy
} }
if c.ClientSecretLegacy != "" { if c.ClientSecretLegacy != "" {
c.Providers.Google.ClientSecret = c.ClientSecretLegacy c.Providers.Google.ClientSecret = c.ClientSecretLegacy
@ -247,16 +239,21 @@ func convertLegacyToIni(name string) (io.Reader, error) {
func (c *Config) Validate() { func (c *Config) Validate() {
// Check for show stopper errors // Check for show stopper errors
if len(c.Secret) == 0 { if len(c.Secret) == 0 {
log.Fatal("\"secret\" option must be set.") log.Fatal("\"secret\" option must be set")
} }
if c.Providers.Google.ClientId == "" || c.Providers.Google.ClientSecret == "" { // Setup default provider
log.Fatal("providers.google.client-id, providers.google.client-secret must be set") err := c.setupProvider(c.DefaultProvider)
if err != nil {
log.Fatal(err)
} }
// Check rules // Check rules (validates the rule and the rule provider)
for _, rule := range c.Rules { for _, rule := range c.Rules {
rule.Validate() err = rule.Validate(c)
if err != nil {
log.Fatal(err)
}
} }
} }
@ -265,6 +262,61 @@ func (c Config) String() string {
return string(jsonConf) return string(jsonConf)
} }
// GetProvider returns the provider of the given name
func (c *Config) GetProvider(name string) (provider.Provider, error) {
switch name {
case "google":
return &c.Providers.Google, nil
case "oidc":
return &c.Providers.OIDC, nil
}
return nil, fmt.Errorf("Unknown provider: %s", name)
}
// GetConfiguredProvider returns the provider of the given name, if it has been
// configured. Returns an error if the provider is unknown, or hasn't been configured
func (c *Config) GetConfiguredProvider(name string) (provider.Provider, error) {
// Check the provider has been configured
if !c.providerConfigured(name) {
return nil, fmt.Errorf("Unconfigured provider: %s", name)
}
return c.GetProvider(name)
}
func (c *Config) providerConfigured(name string) bool {
// Check default provider
if name == c.DefaultProvider {
return true
}
// Check rule providers
for _, rule := range c.Rules {
if name == rule.Provider {
return true
}
}
return false
}
func (c *Config) setupProvider(name string) error {
// Check provider exists
p, err := c.GetProvider(name)
if err != nil {
return err
}
// Setup
err = p.Setup()
if err != nil {
return err
}
return nil
}
type Rule struct { type Rule struct {
Action string Action string
Rule string Rule string
@ -274,7 +326,6 @@ type Rule struct {
func NewRule() *Rule { func NewRule() *Rule {
return &Rule{ return &Rule{
Action: "auth", Action: "auth",
Provider: "google", // TODO: Use default provider
} }
} }
@ -284,15 +335,12 @@ func (r *Rule) formattedRule() string {
return strings.ReplaceAll(r.Rule, "Host(", "HostRegexp(") return strings.ReplaceAll(r.Rule, "Host(", "HostRegexp(")
} }
func (r *Rule) Validate() { func (r *Rule) Validate(c *Config) error {
if r.Action != "auth" && r.Action != "allow" { if r.Action != "auth" && r.Action != "allow" {
log.Fatal("invalid rule action, must be \"auth\" or \"allow\"") return errors.New("invalid rule action, must be \"auth\" or \"allow\"")
} }
// TODO: Update with more provider support return c.setupProvider(r.Provider)
if r.Provider != "google" {
log.Fatal("invalid rule provider, must be \"google\"")
}
} }
// Legacy support for comma separated lists // Legacy support for comma separated lists

View File

@ -1,11 +1,13 @@
package tfa package tfa
import ( import (
"net/url" // "fmt"
"os" "os"
"testing" "testing"
"time" "time"
"github.com/sirupsen/logrus"
"github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -28,34 +30,11 @@ func TestConfigDefaults(t *testing.T) {
assert.Equal("_forward_auth", c.CookieName) assert.Equal("_forward_auth", c.CookieName)
assert.Equal("_forward_auth_csrf", c.CSRFCookieName) assert.Equal("_forward_auth_csrf", c.CSRFCookieName)
assert.Equal("auth", c.DefaultAction) assert.Equal("auth", c.DefaultAction)
assert.Equal("google", c.DefaultProvider)
assert.Len(c.Domains, 0) assert.Len(c.Domains, 0)
assert.Equal(time.Second*time.Duration(43200), c.Lifetime) assert.Equal(time.Second*time.Duration(43200), c.Lifetime)
assert.Equal("/_oauth", c.Path) assert.Equal("/_oauth", c.Path)
assert.Len(c.Whitelist, 0) assert.Len(c.Whitelist, 0)
assert.Equal("https://www.googleapis.com/auth/userinfo.profile https://www.googleapis.com/auth/userinfo.email", c.Providers.Google.Scope)
assert.Equal("", c.Providers.Google.Prompt)
loginURL := &url.URL{
Scheme: "https",
Host: "accounts.google.com",
Path: "/o/oauth2/auth",
}
assert.Equal(loginURL, c.Providers.Google.LoginURL)
tokenURL := &url.URL{
Scheme: "https",
Host: "www.googleapis.com",
Path: "/oauth2/v3/token",
}
assert.Equal(tokenURL, c.Providers.Google.TokenURL)
userURL := &url.URL{
Scheme: "https",
Host: "www.googleapis.com",
Path: "/oauth2/v2/userinfo",
}
assert.Equal(userURL, c.Providers.Google.UserURL)
} }
func TestConfigParseArgs(t *testing.T) { func TestConfigParseArgs(t *testing.T) {
@ -63,6 +42,7 @@ func TestConfigParseArgs(t *testing.T) {
c, err := NewConfig([]string{ c, err := NewConfig([]string{
"--cookie-name=cookiename", "--cookie-name=cookiename",
"--csrf-cookie-name", "\"csrfcookiename\"", "--csrf-cookie-name", "\"csrfcookiename\"",
"--default-provider", "\"oidc\"",
"--rule.1.action=allow", "--rule.1.action=allow",
"--rule.1.rule=PathPrefix(`/one`)", "--rule.1.rule=PathPrefix(`/one`)",
"--rule.two.action=auth", "--rule.two.action=auth",
@ -73,18 +53,19 @@ func TestConfigParseArgs(t *testing.T) {
// Check normal flags // Check normal flags
assert.Equal("cookiename", c.CookieName) assert.Equal("cookiename", c.CookieName)
assert.Equal("csrfcookiename", c.CSRFCookieName) assert.Equal("csrfcookiename", c.CSRFCookieName)
assert.Equal("oidc", c.DefaultProvider)
// Check rules // Check rules
assert.Equal(map[string]*Rule{ assert.Equal(map[string]*Rule{
"1": { "1": {
Action: "allow", Action: "allow",
Rule: "PathPrefix(`/one`)", Rule: "PathPrefix(`/one`)",
Provider: "google", Provider: "oidc",
}, },
"two": { "two": {
Action: "auth", Action: "auth",
Rule: "Host(`two.com`) && Path(`/two`)", Rule: "Host(`two.com`) && Path(`/two`)",
Provider: "google", Provider: "oidc",
}, },
}, c.Rules) }, c.Rules)
} }
@ -157,7 +138,7 @@ func TestConfigFlagBackwardsCompatability(t *testing.T) {
// Google provider params used to be top level // Google provider params used to be top level
assert.Equal("clientid", c.ClientIdLegacy) assert.Equal("clientid", c.ClientIdLegacy)
assert.Equal("clientid", c.Providers.Google.ClientId, "--client-id should set providers.google.client-id") assert.Equal("clientid", c.Providers.Google.ClientID, "--client-id should set providers.google.client-id")
assert.Equal("verysecret", c.ClientSecretLegacy) assert.Equal("verysecret", c.ClientSecretLegacy)
assert.Equal("verysecret", c.Providers.Google.ClientSecret, "--client-secret should set providers.google.client-secret") assert.Equal("verysecret", c.Providers.Google.ClientSecret, "--client-secret should set providers.google.client-secret")
assert.Equal("prompt", c.PromptLegacy) assert.Equal("prompt", c.PromptLegacy)
@ -220,7 +201,7 @@ func TestConfigParseEnvironment(t *testing.T) {
assert.Nil(err) assert.Nil(err)
assert.Equal("env_cookie_name", c.CookieName, "variable should be read from environment") assert.Equal("env_cookie_name", c.CookieName, "variable should be read from environment")
assert.Equal("env_client_id", c.Providers.Google.ClientId, "namespace variable should be read from environment") assert.Equal("env_client_id", c.Providers.Google.ClientID, "namespace variable should be read from environment")
os.Unsetenv("COOKIE_NAME") os.Unsetenv("COOKIE_NAME")
os.Unsetenv("PROVIDERS_GOOGLE_CLIENT_ID") os.Unsetenv("PROVIDERS_GOOGLE_CLIENT_ID")
@ -265,7 +246,7 @@ func TestConfigParseEnvironmentBackwardsCompatability(t *testing.T) {
// Google provider params used to be top level // Google provider params used to be top level
assert.Equal("clientid", c.ClientIdLegacy) assert.Equal("clientid", c.ClientIdLegacy)
assert.Equal("clientid", c.Providers.Google.ClientId, "--client-id should set providers.google.client-id") assert.Equal("clientid", c.Providers.Google.ClientID, "--client-id should set providers.google.client-id")
assert.Equal("verysecret", c.ClientSecretLegacy) assert.Equal("verysecret", c.ClientSecretLegacy)
assert.Equal("verysecret", c.Providers.Google.ClientSecret, "--client-secret should set providers.google.client-secret") assert.Equal("verysecret", c.Providers.Google.ClientSecret, "--client-secret should set providers.google.client-secret")
assert.Equal("prompt", c.PromptLegacy) assert.Equal("prompt", c.PromptLegacy)
@ -305,6 +286,92 @@ func TestConfigTransformation(t *testing.T) {
assert.Equal(time.Second*time.Duration(200), c.Lifetime, "lifetime should be read and converted to duration") assert.Equal(time.Second*time.Duration(200), c.Lifetime, "lifetime should be read and converted to duration")
} }
func TestConfigValidate(t *testing.T) {
assert := assert.New(t)
// Install new logger + hook
var hook *test.Hook
log, hook = test.NewNullLogger()
log.ExitFunc = func(code int) {}
// Validate defualt config + rule error
c, _ := NewConfig([]string{
"--rule.1.action=bad",
})
c.Validate()
logs := hook.AllEntries()
assert.Len(logs, 3)
// Should have fatal error requiring secret
assert.Equal("\"secret\" option must be set", logs[0].Message)
assert.Equal(logrus.FatalLevel, logs[0].Level)
// Should also have default provider (google) error
assert.Equal("providers.google.client-id, providers.google.client-secret must be set", logs[1].Message)
assert.Equal(logrus.FatalLevel, logs[1].Level)
// Should validate rule
assert.Equal("invalid rule action, must be \"auth\" or \"allow\"", logs[2].Message)
assert.Equal(logrus.FatalLevel, logs[2].Level)
hook.Reset()
// Validate with invalid providers
c, _ = NewConfig([]string{
"--secret=veryverysecret",
"--providers.google.client-id=id",
"--providers.google.client-secret=secret",
"--rule.1.action=auth",
"--rule.1.provider=bad2",
})
c.Validate()
logs = hook.AllEntries()
assert.Len(logs, 1)
// Should have error for rule provider
assert.Equal("Unknown provider: bad2", logs[0].Message)
assert.Equal(logrus.FatalLevel, logs[0].Level)
}
func TestConfigGetProvider(t *testing.T) {
assert := assert.New(t)
c, _ := NewConfig([]string{})
// Should be able to get "google" provider
p, err := c.GetProvider("google")
assert.Nil(err)
assert.Equal(&c.Providers.Google, p)
// Should be able to get "oidc" provider
p, err = c.GetProvider("oidc")
assert.Nil(err)
assert.Equal(&c.Providers.OIDC, p)
// Should catch unknown provider
p, err = c.GetProvider("bad")
if assert.Error(err) {
assert.Equal("Unknown provider: bad", err.Error())
}
}
func TestConfigGetConfiguredProvider(t *testing.T) {
assert := assert.New(t)
c, _ := NewConfig([]string{})
// Should be able to get "google" default provider
p, err := c.GetConfiguredProvider("google")
assert.Nil(err)
assert.Equal(&c.Providers.Google, p)
// Should fail to get valid "oidc" provider as it's not configured
p, err = c.GetConfiguredProvider("oidc")
if assert.Error(err) {
assert.Equal("Unconfigured provider: oidc", err.Error())
}
}
func TestConfigCommaSeparatedList(t *testing.T) { func TestConfigCommaSeparatedList(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
list := CommaSeparatedList{} list := CommaSeparatedList{}

View File

@ -6,9 +6,9 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
var log logrus.FieldLogger var log *logrus.Logger
func NewDefaultLogger() logrus.FieldLogger { func NewDefaultLogger() *logrus.Logger {
// Setup logger // Setup logger
log = logrus.StandardLogger() log = logrus.StandardLogger()
logrus.SetOutput(os.Stdout) logrus.SetOutput(os.Stdout)

View File

@ -2,13 +2,15 @@ package provider
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"net/url" "net/url"
) )
// Google provider
type Google struct { type Google struct {
ClientId string `long:"client-id" env:"CLIENT_ID" description:"Client ID"` ClientID string `long:"client-id" env:"CLIENT_ID" description:"Client ID"`
ClientSecret string `long:"client-secret" env:"CLIENT_SECRET" description:"Client Secret" json:"-"` ClientSecret string `long:"client-secret" env:"CLIENT_SECRET" description:"Client Secret" json:"-"`
Scope string Scope string
Prompt string `long:"prompt" env:"PROMPT" description:"Space separated list of OpenID prompt options"` Prompt string `long:"prompt" env:"PROMPT" description:"Space separated list of OpenID prompt options"`
@ -18,15 +20,48 @@ type Google struct {
UserURL *url.URL UserURL *url.URL
} }
func (g *Google) GetLoginURL(redirectUri, state string) string { // Name returns the name of the provider
func (g *Google) Name() string {
return "google"
}
// Setup performs validation and setup
func (g *Google) Setup() error {
if g.ClientID == "" || g.ClientSecret == "" {
return errors.New("providers.google.client-id, providers.google.client-secret must be set")
}
// Set static values
g.Scope = "https://www.googleapis.com/auth/userinfo.profile https://www.googleapis.com/auth/userinfo.email"
g.LoginURL = &url.URL{
Scheme: "https",
Host: "accounts.google.com",
Path: "/o/oauth2/auth",
}
g.TokenURL = &url.URL{
Scheme: "https",
Host: "www.googleapis.com",
Path: "/oauth2/v3/token",
}
g.UserURL = &url.URL{
Scheme: "https",
Host: "www.googleapis.com",
Path: "/oauth2/v2/userinfo",
}
return nil
}
// GetLoginURL provides the login url for the given redirect uri and state
func (g *Google) GetLoginURL(redirectURI, state string) string {
q := url.Values{} q := url.Values{}
q.Set("client_id", g.ClientId) q.Set("client_id", g.ClientID)
q.Set("response_type", "code") q.Set("response_type", "code")
q.Set("scope", g.Scope) q.Set("scope", g.Scope)
if g.Prompt != "" { if g.Prompt != "" {
q.Set("prompt", g.Prompt) q.Set("prompt", g.Prompt)
} }
q.Set("redirect_uri", redirectUri) q.Set("redirect_uri", redirectURI)
q.Set("state", state) q.Set("state", state)
var u url.URL var u url.URL
@ -36,12 +71,13 @@ func (g *Google) GetLoginURL(redirectUri, state string) string {
return u.String() return u.String()
} }
func (g *Google) ExchangeCode(redirectUri, code string) (string, error) { // ExchangeCode exchanges the given redirect uri and code for a token
func (g *Google) ExchangeCode(redirectURI, code string) (string, error) {
form := url.Values{} form := url.Values{}
form.Set("client_id", g.ClientId) form.Set("client_id", g.ClientID)
form.Set("client_secret", g.ClientSecret) form.Set("client_secret", g.ClientSecret)
form.Set("grant_type", "authorization_code") form.Set("grant_type", "authorization_code")
form.Set("redirect_uri", redirectUri) form.Set("redirect_uri", redirectURI)
form.Set("code", code) form.Set("code", code)
res, err := http.PostForm(g.TokenURL.String(), form) res, err := http.PostForm(g.TokenURL.String(), form)
@ -49,13 +85,14 @@ func (g *Google) ExchangeCode(redirectUri, code string) (string, error) {
return "", err return "", err
} }
var token Token var token token
defer res.Body.Close() defer res.Body.Close()
err = json.NewDecoder(res.Body).Decode(&token) err = json.NewDecoder(res.Body).Decode(&token)
return token.Token, err return token.Token, err
} }
// GetUser uses the given token and returns a complete provider.User object
func (g *Google) GetUser(token string) (User, error) { func (g *Google) GetUser(token string) (User, error) {
var user User var user User

View File

@ -0,0 +1,151 @@
package provider
import (
"net/url"
"testing"
"github.com/stretchr/testify/assert"
)
// Tests
func TestGoogleName(t *testing.T) {
p := Google{}
assert.Equal(t, "google", p.Name())
}
func TestGoogleSetup(t *testing.T) {
assert := assert.New(t)
p := Google{}
// Check validation
err := p.Setup()
if assert.Error(err) {
assert.Equal("providers.google.client-id, providers.google.client-secret must be set", err.Error())
}
// Check setup
p = Google{
ClientID: "id",
ClientSecret: "secret",
}
err = p.Setup()
assert.Nil(err)
assert.Equal("https://www.googleapis.com/auth/userinfo.profile https://www.googleapis.com/auth/userinfo.email", p.Scope)
assert.Equal("", p.Prompt)
assert.Equal(&url.URL{
Scheme: "https",
Host: "accounts.google.com",
Path: "/o/oauth2/auth",
}, p.LoginURL)
assert.Equal(&url.URL{
Scheme: "https",
Host: "www.googleapis.com",
Path: "/oauth2/v3/token",
}, p.TokenURL)
assert.Equal(&url.URL{
Scheme: "https",
Host: "www.googleapis.com",
Path: "/oauth2/v2/userinfo",
}, p.UserURL)
}
func TestGoogleGetLoginURL(t *testing.T) {
assert := assert.New(t)
p := Google{
ClientID: "idtest",
ClientSecret: "sectest",
Scope: "scopetest",
Prompt: "consent select_account",
LoginURL: &url.URL{
Scheme: "https",
Host: "google.com",
Path: "/auth",
},
}
// Check url
uri, err := url.Parse(p.GetLoginURL("http://example.com/_oauth", "state"))
assert.Nil(err)
assert.Equal("https", uri.Scheme)
assert.Equal("google.com", uri.Host)
assert.Equal("/auth", uri.Path)
// Check query string
qs := uri.Query()
expectedQs := url.Values{
"client_id": []string{"idtest"},
"redirect_uri": []string{"http://example.com/_oauth"},
"response_type": []string{"code"},
"scope": []string{"scopetest"},
"prompt": []string{"consent select_account"},
"state": []string{"state"},
}
assert.Equal(expectedQs, qs)
}
func TestGoogleExchangeCode(t *testing.T) {
assert := assert.New(t)
// Setup server
expected := url.Values{
"client_id": []string{"idtest"},
"client_secret": []string{"sectest"},
"code": []string{"code"},
"grant_type": []string{"authorization_code"},
"redirect_uri": []string{"http://example.com/_oauth"},
}
server, serverURL := NewOAuthServer(t, map[string]string{
"token": expected.Encode(),
})
defer server.Close()
// Setup provider
p := Google{
ClientID: "idtest",
ClientSecret: "sectest",
Scope: "scopetest",
Prompt: "consent select_account",
TokenURL: &url.URL{
Scheme: serverURL.Scheme,
Host: serverURL.Host,
Path: "/token",
},
}
token, err := p.ExchangeCode("http://example.com/_oauth", "code")
assert.Nil(err)
assert.Equal("123456789", token)
}
func TestGoogleGetUser(t *testing.T) {
assert := assert.New(t)
// Setup server
server, serverURL := NewOAuthServer(t, nil)
defer server.Close()
// Setup provider
p := Google{
ClientID: "idtest",
ClientSecret: "sectest",
Scope: "scopetest",
Prompt: "consent select_account",
UserURL: &url.URL{
Scheme: serverURL.Scheme,
Host: serverURL.Host,
Path: "/userinfo",
},
}
user, err := p.GetUser("123456789")
assert.Nil(err)
assert.Equal("1", user.ID)
assert.Equal("example@example.com", user.Email)
assert.True(user.Verified)
assert.Equal("example.com", user.Hd)
}

108
internal/provider/oidc.go Normal file
View File

@ -0,0 +1,108 @@
package provider
import (
"context"
"errors"
"github.com/coreos/go-oidc"
"golang.org/x/oauth2"
)
// OIDC provider
type OIDC struct {
OAuthProvider
IssuerURL string `long:"issuer-url" env:"ISSUER_URL" description:"Issuer URL"`
ClientID string `long:"client-id" env:"CLIENT_ID" description:"Client ID"`
ClientSecret string `long:"client-secret" env:"CLIENT_SECRET" description:"Client Secret" json:"-"`
provider *oidc.Provider
verifier *oidc.IDTokenVerifier
}
// Name returns the name of the provider
func (o *OIDC) Name() string {
return "oidc"
}
// Setup performs validation and setup
func (o *OIDC) Setup() error {
// Check parms
if o.IssuerURL == "" || o.ClientID == "" || o.ClientSecret == "" {
return errors.New("providers.oidc.issuer-url, providers.oidc.client-id, providers.oidc.client-secret must be set")
}
var err error
o.ctx = context.Background()
// Try to initiate provider
o.provider, err = oidc.NewProvider(o.ctx, o.IssuerURL)
if err != nil {
return err
}
// Create oauth2 config
o.Config = &oauth2.Config{
ClientID: o.ClientID,
ClientSecret: o.ClientSecret,
Endpoint: o.provider.Endpoint(),
// "openid" is a required scope for OpenID Connect flows.
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
}
// Create OIDC verifier
o.verifier = o.provider.Verifier(&oidc.Config{
ClientID: o.ClientID,
})
return nil
}
// GetLoginURL provides the login url for the given redirect uri and state
func (o *OIDC) GetLoginURL(redirectURI, state string) string {
return o.OAuthGetLoginURL(redirectURI, state)
}
// ExchangeCode exchanges the given redirect uri and code for a token
func (o *OIDC) ExchangeCode(redirectURI, code string) (string, error) {
token, err := o.OAuthExchangeCode(redirectURI, code)
if err != nil {
return "", err
}
// Extract ID token
rawIDToken, ok := token.Extra("id_token").(string)
if !ok {
return "", errors.New("Missing id_token")
}
return rawIDToken, nil
}
// GetUser uses the given token and returns a complete provider.User object
func (o *OIDC) GetUser(token string) (User, error) {
var user User
// Parse & Verify ID Token
idToken, err := o.verifier.Verify(o.ctx, token)
if err != nil {
return user, err
}
// Extract custom claims
var claims struct {
ID string `json:"sub"`
Email string `json:"email"`
Verified bool `json:"email_verified"`
}
if err := idToken.Claims(&claims); err != nil {
return user, err
}
user.ID = claims.ID
user.Email = claims.Email
user.Verified = claims.Verified
return user, nil
}

View File

@ -0,0 +1,252 @@
package provider
import (
"crypto/rand"
"crypto/rsa"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"strconv"
"testing"
"time"
"github.com/stretchr/testify/assert"
jose "gopkg.in/square/go-jose.v2"
)
// Tests
func TestOIDCName(t *testing.T) {
p := OIDC{}
assert.Equal(t, "oidc", p.Name())
}
func TestOIDCSetup(t *testing.T) {
assert := assert.New(t)
p := OIDC{}
err := p.Setup()
if assert.Error(err) {
assert.Equal("providers.oidc.issuer-url, providers.oidc.client-id, providers.oidc.client-secret must be set", err.Error())
}
}
func TestOIDCGetLoginURL(t *testing.T) {
assert := assert.New(t)
provider, server, serverURL, _ := setupOIDCTest(t, nil)
defer server.Close()
// Check url
uri, err := url.Parse(provider.GetLoginURL("http://example.com/_oauth", "state"))
assert.Nil(err)
assert.Equal(serverURL.Scheme, uri.Scheme)
assert.Equal(serverURL.Host, uri.Host)
assert.Equal("/auth", uri.Path)
// Check query string
qs := uri.Query()
expectedQs := url.Values{
"client_id": []string{"idtest"},
"redirect_uri": []string{"http://example.com/_oauth"},
"response_type": []string{"code"},
"scope": []string{"openid profile email"},
"state": []string{"state"},
}
assert.Equal(expectedQs, qs)
// Calling the method should not modify the underlying config
assert.Equal("", provider.Config.RedirectURL)
}
func TestOIDCExchangeCode(t *testing.T) {
assert := assert.New(t)
provider, server, _, _ := setupOIDCTest(t, map[string]map[string]string{
"token": {
"code": "code",
"grant_type": "authorization_code",
"redirect_uri": "http://example.com/_oauth",
},
})
defer server.Close()
token, err := provider.ExchangeCode("http://example.com/_oauth", "code")
assert.Nil(err)
assert.Equal("id_123456789", token)
}
func TestOIDCGetUser(t *testing.T) {
assert := assert.New(t)
provider, server, serverURL, key := setupOIDCTest(t, nil)
defer server.Close()
// Generate JWT
token := key.sign(t, []byte(`{
"iss": "`+serverURL.String()+`",
"exp":`+strconv.FormatInt(time.Now().Add(time.Hour).Unix(), 10)+`,
"aud": "idtest",
"sub": "1",
"email": "example@example.com",
"email_verified": true
}`))
// Get user
user, err := provider.GetUser(token)
assert.Nil(err)
assert.Equal("1", user.ID)
assert.Equal("example@example.com", user.Email)
assert.True(user.Verified)
}
// Utils
// setOIDCTest creates a key, OIDCServer and initilises an OIDC provider
func setupOIDCTest(t *testing.T, bodyValues map[string]map[string]string) (*OIDC, *httptest.Server, *url.URL, *rsaKey) {
// Generate key
key, err := newRSAKey()
if err != nil {
t.Fatal(err)
}
body := make(map[string]string)
if bodyValues != nil {
// URL encode bodyValues into body
for method, values := range bodyValues {
q := url.Values{}
for k, v := range values {
q.Set(k, v)
}
body[method] = q.Encode()
}
}
// Set up oidc server
server, serverURL := NewOIDCServer(t, key, body)
// Setup provider
p := OIDC{
ClientID: "idtest",
ClientSecret: "sectest",
IssuerURL: serverURL.String(),
}
// Initialise config/verifier
err = p.Setup()
if err != nil {
t.Fatal(err)
}
return &p, server, serverURL, key
}
// OIDCServer is used in the OIDC Tests to mock an OIDC server
type OIDCServer struct {
t *testing.T
url *url.URL
body map[string]string // method -> body
key *rsaKey
}
func NewOIDCServer(t *testing.T, key *rsaKey, body map[string]string) (*httptest.Server, *url.URL) {
handler := &OIDCServer{t: t, key: key, body: body}
server := httptest.NewServer(handler)
handler.url, _ = url.Parse(server.URL)
return server, handler.url
}
func (s *OIDCServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
body, _ := ioutil.ReadAll(r.Body)
if r.URL.Path == "/.well-known/openid-configuration" {
// Open id config
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, `{
"issuer":"`+s.url.String()+`",
"authorization_endpoint":"`+s.url.String()+`/auth",
"token_endpoint":"`+s.url.String()+`/token",
"jwks_uri":"`+s.url.String()+`/jwks"
}`)
} else if r.URL.Path == "/token" {
// Token request
// Check body
if b, ok := s.body["token"]; ok {
if b != string(body) {
s.t.Fatal("Unexpected request body, expected", b, "got", string(body))
}
}
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, `{
"access_token":"123456789",
"id_token":"id_123456789"
}`)
} else if r.URL.Path == "/jwks" {
// Key request
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, `{"keys":[`+s.key.publicJWK(s.t)+`]}`)
} else {
s.t.Fatal("Unrecognised request: ", r.URL, string(body))
}
}
// rsaKey is used in the OIDCServer tests to sign and verify requests
type rsaKey struct {
key *rsa.PrivateKey
alg jose.SignatureAlgorithm
jwkPub *jose.JSONWebKey
jwkPriv *jose.JSONWebKey
}
func newRSAKey() (*rsaKey, error) {
key, err := rsa.GenerateKey(rand.Reader, 1028)
if err != nil {
return nil, err
}
return &rsaKey{
key: key,
alg: jose.RS256,
jwkPub: &jose.JSONWebKey{
Key: key.Public(),
Algorithm: string(jose.RS256),
},
jwkPriv: &jose.JSONWebKey{
Key: key,
Algorithm: string(jose.RS256),
},
}, nil
}
func (k *rsaKey) publicJWK(t *testing.T) string {
b, err := k.jwkPub.MarshalJSON()
if err != nil {
t.Fatal(err)
}
return string(b)
}
// sign creates a JWS using the private key from the provided payload.
func (k *rsaKey) sign(t *testing.T, payload []byte) string {
signer, err := jose.NewSigner(jose.SigningKey{
Algorithm: k.alg,
Key: k.key,
}, nil)
if err != nil {
t.Fatal(err)
}
jws, err := signer.Sign(payload)
if err != nil {
t.Fatal(err)
}
data, err := jws.CompactSerialize()
if err != nil {
t.Fatal(err)
}
return data
}

View File

@ -1,16 +1,61 @@
package provider package provider
import (
"context"
// "net/url"
"golang.org/x/oauth2"
)
// Providers contains all the implemented providers
type Providers struct { type Providers struct {
Google Google `group:"Google Provider" namespace:"google" env-namespace:"GOOGLE"` Google Google `group:"Google Provider" namespace:"google" env-namespace:"GOOGLE"`
OIDC OIDC `group:"OIDC Provider" namespace:"oidc" env-namespace:"OIDC"`
} }
type Token struct { // Provider is used to authenticate users
type Provider interface {
Name() string
GetLoginURL(redirectURI, state string) string
ExchangeCode(redirectURI, code string) (string, error)
GetUser(token string) (User, error)
Setup() error
}
type token struct {
Token string `json:"access_token"` Token string `json:"access_token"`
} }
// User is the authenticated user
type User struct { type User struct {
Id string `json:"id"` ID string `json:"id"`
Email string `json:"email"` Email string `json:"email"`
Verified bool `json:"verified_email"` Verified bool `json:"verified_email"`
Hd string `json:"hd"` Hd string `json:"hd"`
} }
// OAuthProvider is a provider using the oauth2 library
type OAuthProvider struct {
Config *oauth2.Config
ctx context.Context
}
// ConfigCopy returns a copy of the oauth2 config with the given redirectURI
// which ensures the underlying config is not modified
func (p *OAuthProvider) ConfigCopy(redirectURI string) oauth2.Config {
config := *p.Config
config.RedirectURL = redirectURI
return config
}
// OAuthGetLoginURL provides a base "GetLoginURL" for proiders using OAauth2
func (p *OAuthProvider) OAuthGetLoginURL(redirectURI, state string) string {
config := p.ConfigCopy(redirectURI)
return config.AuthCodeURL(state)
}
// OAuthExchangeCode provides a base "ExchangeCode" for proiders using OAauth2
func (p *OAuthProvider) OAuthExchangeCode(redirectURI, code string) (*oauth2.Token, error) {
config := p.ConfigCopy(redirectURI)
return config.Exchange(p.ctx, code)
}

View File

@ -0,0 +1,48 @@
package provider
import (
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"testing"
)
// Utilities
type OAuthServer struct {
t *testing.T
url *url.URL
body map[string]string // method -> body
}
func NewOAuthServer(t *testing.T, body map[string]string) (*httptest.Server, *url.URL) {
handler := &OAuthServer{t: t, body: body}
server := httptest.NewServer(handler)
handler.url, _ = url.Parse(server.URL)
return server, handler.url
}
func (s *OAuthServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
body, _ := ioutil.ReadAll(r.Body)
// fmt.Println("Got request:", r.URL, r.Method, string(body))
if r.Method == "POST" && r.URL.Path == "/token" {
if s.body["token"] != string(body) {
s.t.Fatal("Unexpected request body, expected", s.body["token"], "got", string(body))
}
w.Header().Set("Content-Type", "application/json")
fmt.Fprintf(w, `{"access_token":"123456789"}`)
} else if r.Method == "GET" && r.URL.Path == "/userinfo" {
fmt.Fprint(w, `{
"id":"1",
"email":"example@example.com",
"verified_email":true,
"hd":"example.com"
}`)
} else {
s.t.Fatal("Unrecognised request: ", r.Method, r.URL, string(body))
}
}

View File

@ -6,6 +6,7 @@ import (
"github.com/containous/traefik/pkg/rules" "github.com/containous/traefik/pkg/rules"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/thomseddon/traefik-forward-auth/internal/provider"
) )
type Server struct { type Server struct {
@ -27,10 +28,11 @@ func (s *Server) buildRoutes() {
// Let's build a router // Let's build a router
for name, rule := range config.Rules { for name, rule := range config.Rules {
matchRule := rule.formattedRule()
if rule.Action == "allow" { if rule.Action == "allow" {
s.router.AddRoute(rule.formattedRule(), 1, s.AllowHandler(name)) s.router.AddRoute(matchRule, 1, s.AllowHandler(name))
} else { } else {
s.router.AddRoute(rule.formattedRule(), 1, s.AuthHandler(name)) s.router.AddRoute(matchRule, 1, s.AuthHandler(rule.Provider, name))
} }
} }
@ -41,7 +43,7 @@ func (s *Server) buildRoutes() {
if config.DefaultAction == "allow" { if config.DefaultAction == "allow" {
s.router.NewRoute().Handler(s.AllowHandler("default")) s.router.NewRoute().Handler(s.AllowHandler("default"))
} else { } else {
s.router.NewRoute().Handler(s.AuthHandler("default")) s.router.NewRoute().Handler(s.AuthHandler(config.DefaultProvider, "default"))
} }
} }
@ -64,7 +66,9 @@ func (s *Server) AllowHandler(rule string) http.HandlerFunc {
} }
// Authenticate requests // Authenticate requests
func (s *Server) AuthHandler(rule string) http.HandlerFunc { func (s *Server) AuthHandler(providerName, rule string) http.HandlerFunc {
p, _ := config.GetConfiguredProvider(providerName)
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
// Logging setup // Logging setup
logger := s.logger(r, rule, "Authenticating request") logger := s.logger(r, rule, "Authenticating request")
@ -72,7 +76,7 @@ func (s *Server) AuthHandler(rule string) http.HandlerFunc {
// Get auth cookie // Get auth cookie
c, err := r.Cookie(config.CookieName) c, err := r.Cookie(config.CookieName)
if err != nil { if err != nil {
s.authRedirect(logger, w, r) s.authRedirect(logger, w, r, p)
return return
} }
@ -81,7 +85,7 @@ func (s *Server) AuthHandler(rule string) http.HandlerFunc {
if err != nil { if err != nil {
if err.Error() == "Cookie has expired" { if err.Error() == "Cookie has expired" {
logger.Info("Cookie has expired") logger.Info("Cookie has expired")
s.authRedirect(logger, w, r) s.authRedirect(logger, w, r, p)
} else { } else {
logger.Errorf("Invalid cookie: %v", err) logger.Errorf("Invalid cookie: %v", err)
http.Error(w, "Not authorized", 401) http.Error(w, "Not authorized", 401)
@ -121,18 +125,26 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc {
} }
// Validate state // Validate state
valid, redirect, err := ValidateCSRFCookie(r, c) valid, providerName, redirect, err := ValidateCSRFCookie(r, c)
if !valid { if !valid {
logger.Warnf("Error validating csrf cookie: %v", err) logger.Warnf("Error validating csrf cookie: %v", err)
http.Error(w, "Not authorized", 401) http.Error(w, "Not authorized", 401)
return return
} }
// Get provider
p, err := config.GetConfiguredProvider(providerName)
if err != nil {
logger.Warnf("Invalid provider in csrf cookie: %s, %v", providerName, err)
http.Error(w, "Not authorized", 401)
return
}
// Clear CSRF cookie // Clear CSRF cookie
http.SetCookie(w, ClearCSRFCookie(r)) http.SetCookie(w, ClearCSRFCookie(r))
// Exchange code for token // Exchange code for token
token, err := ExchangeCode(r) token, err := p.ExchangeCode(redirectUri(r), r.URL.Query().Get("code"))
if err != nil { if err != nil {
logger.Errorf("Code exchange failed with: %v", err) logger.Errorf("Code exchange failed with: %v", err)
http.Error(w, "Service unavailable", 503) http.Error(w, "Service unavailable", 503)
@ -140,7 +152,7 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc {
} }
// Get user // Get user
user, err := GetUser(token) user, err := p.GetUser(token)
if err != nil { if err != nil {
logger.Errorf("Error getting user: %s", err) logger.Errorf("Error getting user: %s", err)
return return
@ -157,7 +169,7 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc {
} }
} }
func (s *Server) authRedirect(logger *logrus.Entry, w http.ResponseWriter, r *http.Request) { func (s *Server) authRedirect(logger *logrus.Entry, w http.ResponseWriter, r *http.Request, p provider.Provider) {
// Error indicates no cookie, generate nonce // Error indicates no cookie, generate nonce
err, nonce := Nonce() err, nonce := Nonce()
if err != nil { if err != nil {
@ -171,7 +183,8 @@ func (s *Server) authRedirect(logger *logrus.Entry, w http.ResponseWriter, r *ht
logger.Debug("Set CSRF cookie and redirecting to google login") logger.Debug("Set CSRF cookie and redirecting to google login")
// Forward them on // Forward them on
http.Redirect(w, r, GetLoginURL(r, nonce), http.StatusTemporaryRedirect) loginUrl := p.GetLoginURL(redirectUri(r), MakeState(r, p, nonce))
http.Redirect(w, r, loginUrl, http.StatusTemporaryRedirect)
logger.Debug("Done") logger.Debug("Done")
return return

View File

@ -11,15 +11,16 @@ import (
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
) )
// TODO:
/** /**
* Setup * Setup
*/ */
func init() { func init() {
config = newDefaultConfig()
config.LogLevel = "panic" config.LogLevel = "panic"
log = NewDefaultLogger() log = NewDefaultLogger()
} }
@ -30,7 +31,7 @@ func init() {
func TestServerAuthHandlerInvalid(t *testing.T) { func TestServerAuthHandlerInvalid(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
config, _ = NewConfig([]string{}) config = newDefaultConfig()
// Should redirect vanilla request to login url // Should redirect vanilla request to login url
req := newDefaultHttpRequest("/foo") req := newDefaultHttpRequest("/foo")
@ -42,10 +43,20 @@ func TestServerAuthHandlerInvalid(t *testing.T) {
assert.Equal("accounts.google.com", fwd.Host, "vanilla request should be redirected to google") assert.Equal("accounts.google.com", fwd.Host, "vanilla request should be redirected to google")
assert.Equal("/o/oauth2/auth", fwd.Path, "vanilla request should be redirected to google") assert.Equal("/o/oauth2/auth", fwd.Path, "vanilla request should be redirected to google")
// Check state string
qs := fwd.Query()
state, exists := qs["state"]
require.True(t, exists)
require.Len(t, state, 1)
parts := strings.SplitN(state[0], ":", 3)
require.Len(t, parts, 3)
assert.Equal("google", parts[1])
assert.Equal("http://example.com/foo", parts[2])
// Should catch invalid cookie // Should catch invalid cookie
req = newDefaultHttpRequest("/foo") req = newDefaultHttpRequest("/foo")
c := MakeCookie(req, "test@example.com") c := MakeCookie(req, "test@example.com")
parts := strings.Split(c.Value, "|") parts = strings.Split(c.Value, "|")
c.Value = fmt.Sprintf("bad|%s|%s", parts[1], parts[2]) c.Value = fmt.Sprintf("bad|%s|%s", parts[1], parts[2])
res, _ = doHttpRequest(req, c) res, _ = doHttpRequest(req, c)
@ -62,7 +73,7 @@ func TestServerAuthHandlerInvalid(t *testing.T) {
func TestServerAuthHandlerExpired(t *testing.T) { func TestServerAuthHandlerExpired(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
config, _ = NewConfig([]string{}) config = newDefaultConfig()
config.Lifetime = time.Second * time.Duration(-1) config.Lifetime = time.Second * time.Duration(-1)
config.Domains = []string{"test.com"} config.Domains = []string{"test.com"}
@ -90,7 +101,7 @@ func TestServerAuthHandlerExpired(t *testing.T) {
func TestServerAuthHandlerValid(t *testing.T) { func TestServerAuthHandlerValid(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
config, _ = NewConfig([]string{}) config = newDefaultConfig()
// Should allow valid request email // Should allow valid request email
req := newDefaultHttpRequest("/foo") req := newDefaultHttpRequest("/foo")
@ -108,7 +119,7 @@ func TestServerAuthHandlerValid(t *testing.T) {
func TestServerAuthCallback(t *testing.T) { func TestServerAuthCallback(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
config, _ = NewConfig([]string{}) config = newDefaultConfig()
// Setup token server // Setup token server
tokenServerHandler := &TokenServerHandler{} tokenServerHandler := &TokenServerHandler{}
@ -136,7 +147,7 @@ func TestServerAuthCallback(t *testing.T) {
assert.Equal(401, res.StatusCode, "auth callback with invalid cookie shouldn't be authorised") assert.Equal(401, res.StatusCode, "auth callback with invalid cookie shouldn't be authorised")
// Should redirect valid request // Should redirect valid request
req = newDefaultHttpRequest("/_oauth?state=12345678901234567890123456789012:http://redirect") req = newDefaultHttpRequest("/_oauth?state=12345678901234567890123456789012:google:http://redirect")
c = MakeCSRFCookie(req, "12345678901234567890123456789012") c = MakeCSRFCookie(req, "12345678901234567890123456789012")
res, _ = doHttpRequest(req, c) res, _ = doHttpRequest(req, c)
assert.Equal(307, res.StatusCode, "valid auth callback should be allowed") assert.Equal(307, res.StatusCode, "valid auth callback should be allowed")
@ -149,7 +160,7 @@ func TestServerAuthCallback(t *testing.T) {
func TestServerDefaultAction(t *testing.T) { func TestServerDefaultAction(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
config, _ = NewConfig([]string{}) config = newDefaultConfig()
req := newDefaultHttpRequest("/random") req := newDefaultHttpRequest("/random")
res, _ := doHttpRequest(req, nil) res, _ := doHttpRequest(req, nil)
@ -161,9 +172,36 @@ func TestServerDefaultAction(t *testing.T) {
assert.Equal(200, res.StatusCode, "request should be allowed with default handler") assert.Equal(200, res.StatusCode, "request should be allowed with default handler")
} }
func TestServerDefaultProvider(t *testing.T) {
assert := assert.New(t)
config = newDefaultConfig()
// Should use "google" as default provider when not specified
req := newDefaultHttpRequest("/random")
res, _ := doHttpRequest(req, nil)
fwd, _ := res.Location()
assert.Equal("https", fwd.Scheme, "request with expired cookie should be redirected to google")
assert.Equal("accounts.google.com", fwd.Host, "request with expired cookie should be redirected to google")
assert.Equal("/o/oauth2/auth", fwd.Path, "request with expired cookie should be redirected to google")
// Should use alternative default provider when set
config.DefaultProvider = "oidc"
config.Providers.OIDC.OAuthProvider.Config = &oauth2.Config{
Endpoint: oauth2.Endpoint{
AuthURL: "https://oidc.com/oidcauth",
},
}
res, _ = doHttpRequest(req, nil)
fwd, _ = res.Location()
assert.Equal("https", fwd.Scheme, "request with expired cookie should be redirected to oidc")
assert.Equal("oidc.com", fwd.Host, "request with expired cookie should be redirected to oidc")
assert.Equal("/oidcauth", fwd.Path, "request with expired cookie should be redirected to oidc")
}
func TestServerRouteHeaders(t *testing.T) { func TestServerRouteHeaders(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
config, _ = NewConfig([]string{}) config = newDefaultConfig()
config.Rules = map[string]*Rule{ config.Rules = map[string]*Rule{
"1": { "1": {
Action: "allow", Action: "allow",
@ -196,7 +234,7 @@ func TestServerRouteHeaders(t *testing.T) {
func TestServerRouteHost(t *testing.T) { func TestServerRouteHost(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
config, _ = NewConfig([]string{}) config = newDefaultConfig()
config.Rules = map[string]*Rule{ config.Rules = map[string]*Rule{
"1": { "1": {
Action: "allow", Action: "allow",
@ -226,7 +264,7 @@ func TestServerRouteHost(t *testing.T) {
func TestServerRouteMethod(t *testing.T) { func TestServerRouteMethod(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
config, _ = NewConfig([]string{}) config = newDefaultConfig()
config.Rules = map[string]*Rule{ config.Rules = map[string]*Rule{
"1": { "1": {
Action: "allow", Action: "allow",
@ -247,7 +285,7 @@ func TestServerRouteMethod(t *testing.T) {
func TestServerRoutePath(t *testing.T) { func TestServerRoutePath(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
config, _ = NewConfig([]string{}) config = newDefaultConfig()
config.Rules = map[string]*Rule{ config.Rules = map[string]*Rule{
"1": { "1": {
Action: "allow", Action: "allow",
@ -281,7 +319,7 @@ func TestServerRoutePath(t *testing.T) {
func TestServerRouteQuery(t *testing.T) { func TestServerRouteQuery(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
config, _ = NewConfig([]string{}) config = newDefaultConfig()
config.Rules = map[string]*Rule{ config.Rules = map[string]*Rule{
"1": { "1": {
Action: "allow", Action: "allow",
@ -346,6 +384,18 @@ func doHttpRequest(r *http.Request, c *http.Cookie) (*http.Response, string) {
return res, string(body) return res, string(body)
} }
func newDefaultConfig() *Config {
config, _ = NewConfig([]string{
"--providers.google.client-id=id",
"--providers.google.client-secret=secret",
})
// Setup the google providers without running all the config validation
config.Providers.Google.Setup()
return config
}
func newDefaultHttpRequest(uri string) *http.Request { func newDefaultHttpRequest(uri string) *http.Request {
return newHttpRequest("", "http://example.com/", uri) return newHttpRequest("", "http://example.com/", uri)
} }
@ -354,25 +404,8 @@ func newHttpRequest(method, dest, uri string) *http.Request {
r := httptest.NewRequest("", "http://should-use-x-forwarded.com", nil) r := httptest.NewRequest("", "http://should-use-x-forwarded.com", nil)
p, _ := url.Parse(dest) p, _ := url.Parse(dest)
r.Header.Add("X-Forwarded-Method", method) r.Header.Add("X-Forwarded-Method", method)
r.Header.Add("X-Forwarded-Proto", p.Scheme)
r.Header.Add("X-Forwarded-Host", p.Host) r.Header.Add("X-Forwarded-Host", p.Host)
r.Header.Add("X-Forwarded-Uri", uri) r.Header.Add("X-Forwarded-Uri", uri)
return r return r
} }
func qsDiff(t *testing.T, one, two url.Values) []string {
errs := make([]string, 0)
for k := range one {
if two.Get(k) == "" {
errs = append(errs, fmt.Sprintf("Key missing: %s", k))
}
if one.Get(k) != two.Get(k) {
errs = append(errs, fmt.Sprintf("Value different for %s: expected: '%s' got: '%s'", k, one.Get(k), two.Get(k)))
}
}
for k := range two {
if one.Get(k) == "" {
errs = append(errs, fmt.Sprintf("Extra key: %s", k))
}
}
return errs
}