diff --git a/internal/provider/generic_oauth.go b/internal/provider/generic_oauth.go index 2b10a7c..23a4e13 100644 --- a/internal/provider/generic_oauth.go +++ b/internal/provider/generic_oauth.go @@ -69,12 +69,13 @@ func (o *GenericOAuth) ExchangeCode(redirectURI, code string) (string, error) { } // GetUser uses the given token and returns a complete provider.User object -func (o *GenericOAuth) GetUser(token string) (User, error) { +func (o *GenericOAuth) GetUser(token string) (User, Roles, error) { var user User + var roles Roles req, err := http.NewRequest("GET", o.UserURL, nil) if err != nil { - return user, err + return user, roles, err } if o.TokenStyle == "header" { @@ -88,11 +89,11 @@ func (o *GenericOAuth) GetUser(token string) (User, error) { client := &http.Client{} res, err := client.Do(req) if err != nil { - return user, err + return user, roles, err } defer res.Body.Close() err = json.NewDecoder(res.Body).Decode(&user) - return user, err + return user, roles, err } diff --git a/internal/provider/google.go b/internal/provider/google.go index f2ee6a7..94fe840 100644 --- a/internal/provider/google.go +++ b/internal/provider/google.go @@ -95,23 +95,24 @@ func (g *Google) ExchangeCode(redirectURI, code string) (string, error) { } // GetUser uses the given token and returns a complete provider.User object -func (g *Google) GetUser(token string) (User, error) { +func (g *Google) GetUser(token string) (User, Roles, error) { var user User + var roles Roles client := &http.Client{} req, err := http.NewRequest("GET", g.UserURL.String(), nil) if err != nil { - return user, err + return user, roles, err } req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", token)) res, err := client.Do(req) if err != nil { - return user, err + return user, roles, err } defer res.Body.Close() err = json.NewDecoder(res.Body).Decode(&user) - return user, err + return user, roles, err } diff --git a/internal/provider/oidc.go b/internal/provider/oidc.go index 880b7ed..14031a1 100644 --- a/internal/provider/oidc.go +++ b/internal/provider/oidc.go @@ -3,6 +3,7 @@ package provider import ( "context" "errors" + "fmt" "github.com/coreos/go-oidc" "golang.org/x/oauth2" @@ -88,14 +89,14 @@ func (o *OIDC) ExchangeCode(redirectURI, code string) (string, error) { } // GetUser uses the given token and returns a complete provider.User object -func (o *OIDC) GetUser(token string) (User, error) { +func (o *OIDC) GetUser(token string) (User, Roles, error) { var user User var roles Roles // Parse & Verify ID Token idToken, err := o.verifier.Verify(o.ctx, token) if err != nil { - return user, err + return user, roles, err } @@ -103,16 +104,18 @@ func (o *OIDC) GetUser(token string) (User, error) { // Extract custom claims if err := idToken.Claims(&user); err != nil { - return user, err + return user, roles, err } o.log.WithField("user", user).Debug("getUser") if err := idToken.Claims(&roles); err != nil { - return user, err + return user, roles, err } o.log.WithField("roles", roles).Debug("getUser") + for i, r := range roles.Roles { + o.log.Debug(fmt.Sprintf("%d, %s", i, r)) + } - return user, errors.New("access denied") -// return user, nil + return user, roles, nil } diff --git a/internal/provider/providers.go b/internal/provider/providers.go index 558e776..d7be57f 100644 --- a/internal/provider/providers.go +++ b/internal/provider/providers.go @@ -20,7 +20,7 @@ type Provider interface { Name() string GetLoginURL(redirectURI, state string) string ExchangeCode(redirectURI, code string) (string, error) - GetUser(token string) (User, error) + GetUser(token string) (User, Roles, error) Setup(*logrus.Logger) error } diff --git a/internal/server.go b/internal/server.go index 2e20df5..89aed78 100644 --- a/internal/server.go +++ b/internal/server.go @@ -178,12 +178,23 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc { } // Get user - user, err := p.GetUser(token) + user, roles, err := p.GetUser(token) if err != nil { logger.WithField("error", err).Error("Error getting user") http.Error(w, "Service unavailable", 503) return } + found := false + for _, r := range roles.Roles { + if r == "whoami_admin" { + found = true + } + } + if ! found { + logger.Debug("required role not found, deny access") + http.Error(w, "Forbidden", 403) + return + } // Generate cookie http.SetCookie(w, MakeCookie(r, user.Email)) @@ -191,6 +202,7 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc { "provider": providerName, "redirect": redirect, "user": user.Email, + "roles": roles.Roles, }).Info("Successfully generated auth cookie, redirecting user.") // Redirect