Multiple provider support + OIDC provider
This commit is contained in:
@ -6,6 +6,7 @@ import (
|
||||
|
||||
"github.com/containous/traefik/pkg/rules"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/thomseddon/traefik-forward-auth/internal/provider"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
@ -27,10 +28,11 @@ func (s *Server) buildRoutes() {
|
||||
|
||||
// Let's build a router
|
||||
for name, rule := range config.Rules {
|
||||
matchRule := rule.formattedRule()
|
||||
if rule.Action == "allow" {
|
||||
s.router.AddRoute(rule.formattedRule(), 1, s.AllowHandler(name))
|
||||
s.router.AddRoute(matchRule, 1, s.AllowHandler(name))
|
||||
} else {
|
||||
s.router.AddRoute(rule.formattedRule(), 1, s.AuthHandler(name))
|
||||
s.router.AddRoute(matchRule, 1, s.AuthHandler(rule.Provider, name))
|
||||
}
|
||||
}
|
||||
|
||||
@ -41,7 +43,7 @@ func (s *Server) buildRoutes() {
|
||||
if config.DefaultAction == "allow" {
|
||||
s.router.NewRoute().Handler(s.AllowHandler("default"))
|
||||
} else {
|
||||
s.router.NewRoute().Handler(s.AuthHandler("default"))
|
||||
s.router.NewRoute().Handler(s.AuthHandler(config.DefaultProvider, "default"))
|
||||
}
|
||||
}
|
||||
|
||||
@ -64,7 +66,9 @@ func (s *Server) AllowHandler(rule string) http.HandlerFunc {
|
||||
}
|
||||
|
||||
// Authenticate requests
|
||||
func (s *Server) AuthHandler(rule string) http.HandlerFunc {
|
||||
func (s *Server) AuthHandler(providerName, rule string) http.HandlerFunc {
|
||||
p, _ := config.GetConfiguredProvider(providerName)
|
||||
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// Logging setup
|
||||
logger := s.logger(r, rule, "Authenticating request")
|
||||
@ -72,7 +76,7 @@ func (s *Server) AuthHandler(rule string) http.HandlerFunc {
|
||||
// Get auth cookie
|
||||
c, err := r.Cookie(config.CookieName)
|
||||
if err != nil {
|
||||
s.authRedirect(logger, w, r)
|
||||
s.authRedirect(logger, w, r, p)
|
||||
return
|
||||
}
|
||||
|
||||
@ -81,7 +85,7 @@ func (s *Server) AuthHandler(rule string) http.HandlerFunc {
|
||||
if err != nil {
|
||||
if err.Error() == "Cookie has expired" {
|
||||
logger.Info("Cookie has expired")
|
||||
s.authRedirect(logger, w, r)
|
||||
s.authRedirect(logger, w, r, p)
|
||||
} else {
|
||||
logger.Errorf("Invalid cookie: %v", err)
|
||||
http.Error(w, "Not authorized", 401)
|
||||
@ -121,18 +125,26 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc {
|
||||
}
|
||||
|
||||
// Validate state
|
||||
valid, redirect, err := ValidateCSRFCookie(r, c)
|
||||
valid, providerName, redirect, err := ValidateCSRFCookie(r, c)
|
||||
if !valid {
|
||||
logger.Warnf("Error validating csrf cookie: %v", err)
|
||||
http.Error(w, "Not authorized", 401)
|
||||
return
|
||||
}
|
||||
|
||||
// Get provider
|
||||
p, err := config.GetConfiguredProvider(providerName)
|
||||
if err != nil {
|
||||
logger.Warnf("Invalid provider in csrf cookie: %s, %v", providerName, err)
|
||||
http.Error(w, "Not authorized", 401)
|
||||
return
|
||||
}
|
||||
|
||||
// Clear CSRF cookie
|
||||
http.SetCookie(w, ClearCSRFCookie(r))
|
||||
|
||||
// Exchange code for token
|
||||
token, err := ExchangeCode(r)
|
||||
token, err := p.ExchangeCode(redirectUri(r), r.URL.Query().Get("code"))
|
||||
if err != nil {
|
||||
logger.Errorf("Code exchange failed with: %v", err)
|
||||
http.Error(w, "Service unavailable", 503)
|
||||
@ -140,7 +152,7 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc {
|
||||
}
|
||||
|
||||
// Get user
|
||||
user, err := GetUser(token)
|
||||
user, err := p.GetUser(token)
|
||||
if err != nil {
|
||||
logger.Errorf("Error getting user: %s", err)
|
||||
return
|
||||
@ -157,7 +169,7 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) authRedirect(logger *logrus.Entry, w http.ResponseWriter, r *http.Request) {
|
||||
func (s *Server) authRedirect(logger *logrus.Entry, w http.ResponseWriter, r *http.Request, p provider.Provider) {
|
||||
// Error indicates no cookie, generate nonce
|
||||
err, nonce := Nonce()
|
||||
if err != nil {
|
||||
@ -171,7 +183,8 @@ func (s *Server) authRedirect(logger *logrus.Entry, w http.ResponseWriter, r *ht
|
||||
logger.Debug("Set CSRF cookie and redirecting to google login")
|
||||
|
||||
// Forward them on
|
||||
http.Redirect(w, r, GetLoginURL(r, nonce), http.StatusTemporaryRedirect)
|
||||
loginUrl := p.GetLoginURL(redirectUri(r), MakeState(r, p, nonce))
|
||||
http.Redirect(w, r, loginUrl, http.StatusTemporaryRedirect)
|
||||
|
||||
logger.Debug("Done")
|
||||
return
|
||||
|
Reference in New Issue
Block a user