evaluate role in higher layer

This commit is contained in:
Wolfgang Hottgenroth 2023-11-06 22:09:29 +01:00
parent ab2d527dbd
commit f6120640d2
Signed by: wn
GPG Key ID: 836E9E1192A6B132
5 changed files with 33 additions and 16 deletions

View File

@ -69,12 +69,13 @@ func (o *GenericOAuth) ExchangeCode(redirectURI, code string) (string, error) {
} }
// GetUser uses the given token and returns a complete provider.User object // GetUser uses the given token and returns a complete provider.User object
func (o *GenericOAuth) GetUser(token string) (User, error) { func (o *GenericOAuth) GetUser(token string) (User, Roles, error) {
var user User var user User
var roles Roles
req, err := http.NewRequest("GET", o.UserURL, nil) req, err := http.NewRequest("GET", o.UserURL, nil)
if err != nil { if err != nil {
return user, err return user, roles, err
} }
if o.TokenStyle == "header" { if o.TokenStyle == "header" {
@ -88,11 +89,11 @@ func (o *GenericOAuth) GetUser(token string) (User, error) {
client := &http.Client{} client := &http.Client{}
res, err := client.Do(req) res, err := client.Do(req)
if err != nil { if err != nil {
return user, err return user, roles, err
} }
defer res.Body.Close() defer res.Body.Close()
err = json.NewDecoder(res.Body).Decode(&user) err = json.NewDecoder(res.Body).Decode(&user)
return user, err return user, roles, err
} }

View File

@ -95,23 +95,24 @@ func (g *Google) ExchangeCode(redirectURI, code string) (string, error) {
} }
// GetUser uses the given token and returns a complete provider.User object // GetUser uses the given token and returns a complete provider.User object
func (g *Google) GetUser(token string) (User, error) { func (g *Google) GetUser(token string) (User, Roles, error) {
var user User var user User
var roles Roles
client := &http.Client{} client := &http.Client{}
req, err := http.NewRequest("GET", g.UserURL.String(), nil) req, err := http.NewRequest("GET", g.UserURL.String(), nil)
if err != nil { if err != nil {
return user, err return user, roles, err
} }
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", token)) req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", token))
res, err := client.Do(req) res, err := client.Do(req)
if err != nil { if err != nil {
return user, err return user, roles, err
} }
defer res.Body.Close() defer res.Body.Close()
err = json.NewDecoder(res.Body).Decode(&user) err = json.NewDecoder(res.Body).Decode(&user)
return user, err return user, roles, err
} }

View File

@ -3,6 +3,7 @@ package provider
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"github.com/coreos/go-oidc" "github.com/coreos/go-oidc"
"golang.org/x/oauth2" "golang.org/x/oauth2"
@ -88,14 +89,14 @@ func (o *OIDC) ExchangeCode(redirectURI, code string) (string, error) {
} }
// GetUser uses the given token and returns a complete provider.User object // GetUser uses the given token and returns a complete provider.User object
func (o *OIDC) GetUser(token string) (User, error) { func (o *OIDC) GetUser(token string) (User, Roles, error) {
var user User var user User
var roles Roles var roles Roles
// Parse & Verify ID Token // Parse & Verify ID Token
idToken, err := o.verifier.Verify(o.ctx, token) idToken, err := o.verifier.Verify(o.ctx, token)
if err != nil { if err != nil {
return user, err return user, roles, err
} }
@ -103,16 +104,18 @@ func (o *OIDC) GetUser(token string) (User, error) {
// Extract custom claims // Extract custom claims
if err := idToken.Claims(&user); err != nil { if err := idToken.Claims(&user); err != nil {
return user, err return user, roles, err
} }
o.log.WithField("user", user).Debug("getUser") o.log.WithField("user", user).Debug("getUser")
if err := idToken.Claims(&roles); err != nil { if err := idToken.Claims(&roles); err != nil {
return user, err return user, roles, err
} }
o.log.WithField("roles", roles).Debug("getUser") o.log.WithField("roles", roles).Debug("getUser")
for i, r := range roles.Roles {
o.log.Debug(fmt.Sprintf("%d, %s", i, r))
return user, errors.New("access denied") }
// return user, nil
return user, roles, nil
} }

View File

@ -20,7 +20,7 @@ type Provider interface {
Name() string Name() string
GetLoginURL(redirectURI, state string) string GetLoginURL(redirectURI, state string) string
ExchangeCode(redirectURI, code string) (string, error) ExchangeCode(redirectURI, code string) (string, error)
GetUser(token string) (User, error) GetUser(token string) (User, Roles, error)
Setup(*logrus.Logger) error Setup(*logrus.Logger) error
} }

View File

@ -178,12 +178,23 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc {
} }
// Get user // Get user
user, err := p.GetUser(token) user, roles, err := p.GetUser(token)
if err != nil { if err != nil {
logger.WithField("error", err).Error("Error getting user") logger.WithField("error", err).Error("Error getting user")
http.Error(w, "Service unavailable", 503) http.Error(w, "Service unavailable", 503)
return return
} }
found := false
for _, r := range roles.Roles {
if r == "whoami_admin" {
found = true
}
}
if ! found {
logger.Debug("required role not found, deny access")
http.Error(w, "Forbidden", 403)
return
}
// Generate cookie // Generate cookie
http.SetCookie(w, MakeCookie(r, user.Email)) http.SetCookie(w, MakeCookie(r, user.Email))
@ -191,6 +202,7 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc {
"provider": providerName, "provider": providerName,
"redirect": redirect, "redirect": redirect,
"user": user.Email, "user": user.Email,
"roles": roles.Roles,
}).Info("Successfully generated auth cookie, redirecting user.") }).Info("Successfully generated auth cookie, redirecting user.")
// Redirect // Redirect