Begin refactor + selective auth
This commit is contained in:
151
forwardauth.go
151
forwardauth.go
@ -5,41 +5,31 @@ import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
// "encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
// "net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/thomseddon/traefik-forward-auth/provider"
|
||||
)
|
||||
|
||||
type ForwardAuthContext int
|
||||
|
||||
const (
|
||||
Nonce ForwardAuthContext = iota
|
||||
Request
|
||||
)
|
||||
|
||||
// Forward Auth
|
||||
type ForwardAuth struct {
|
||||
Path string
|
||||
Lifetime time.Duration
|
||||
Secret []byte
|
||||
}
|
||||
|
||||
ClientId string
|
||||
ClientSecret string `json:"-"`
|
||||
Scope string
|
||||
|
||||
LoginURL *url.URL
|
||||
TokenURL *url.URL
|
||||
UserURL *url.URL
|
||||
|
||||
AuthHost string
|
||||
|
||||
CookieName string
|
||||
CookieDomains []CookieDomain
|
||||
CSRFCookieName string
|
||||
CookieSecure bool
|
||||
|
||||
Domain []string
|
||||
Whitelist []string
|
||||
|
||||
Prompt string
|
||||
func NewForwardAuth() *ForwardAuth {
|
||||
return &ForwardAuth{}
|
||||
}
|
||||
|
||||
// Request Validation
|
||||
@ -85,18 +75,18 @@ func (f *ForwardAuth) ValidateCookie(r *http.Request, c *http.Cookie) (bool, str
|
||||
// Validate email
|
||||
func (f *ForwardAuth) ValidateEmail(email string) bool {
|
||||
found := false
|
||||
if len(f.Whitelist) > 0 {
|
||||
for _, whitelist := range f.Whitelist {
|
||||
if len(config.Whitelist) > 0 {
|
||||
for _, whitelist := range config.Whitelist {
|
||||
if email == whitelist {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
} else if len(f.Domain) > 0 {
|
||||
} else if len(config.Domain) > 0 {
|
||||
parts := strings.Split(email, "@")
|
||||
if len(parts) < 2 {
|
||||
return false
|
||||
}
|
||||
for _, domain := range f.Domain {
|
||||
for _, domain := range config.Domain {
|
||||
if domain == parts[1] {
|
||||
found = true
|
||||
}
|
||||
@ -114,77 +104,24 @@ func (f *ForwardAuth) ValidateEmail(email string) bool {
|
||||
func (f *ForwardAuth) GetLoginURL(r *http.Request, nonce string) string {
|
||||
state := fmt.Sprintf("%s:%s", nonce, f.returnUrl(r))
|
||||
|
||||
q := url.Values{}
|
||||
q.Set("client_id", fw.ClientId)
|
||||
q.Set("response_type", "code")
|
||||
q.Set("scope", fw.Scope)
|
||||
if fw.Prompt != "" {
|
||||
q.Set("prompt", fw.Prompt)
|
||||
}
|
||||
q.Set("redirect_uri", f.redirectUri(r))
|
||||
q.Set("state", state)
|
||||
|
||||
var u url.URL
|
||||
u = *fw.LoginURL
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
return u.String()
|
||||
// TODO: Support multiple providers
|
||||
return config.Providers.Google.GetLoginURL(f.redirectUri(r), state)
|
||||
}
|
||||
|
||||
// Exchange code for token
|
||||
|
||||
type Token struct {
|
||||
Token string `json:"access_token"`
|
||||
}
|
||||
func (f *ForwardAuth) ExchangeCode(r *http.Request) (string, error) {
|
||||
code := r.URL.Query().Get("code")
|
||||
|
||||
func (f *ForwardAuth) ExchangeCode(r *http.Request, code string) (string, error) {
|
||||
form := url.Values{}
|
||||
form.Set("client_id", fw.ClientId)
|
||||
form.Set("client_secret", fw.ClientSecret)
|
||||
form.Set("grant_type", "authorization_code")
|
||||
form.Set("redirect_uri", f.redirectUri(r))
|
||||
form.Set("code", code)
|
||||
|
||||
res, err := http.PostForm(fw.TokenURL.String(), form)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var token Token
|
||||
defer res.Body.Close()
|
||||
err = json.NewDecoder(res.Body).Decode(&token)
|
||||
|
||||
return token.Token, err
|
||||
// TODO: Support multiple providers
|
||||
return config.Providers.Google.ExchangeCode(f.redirectUri(r), code)
|
||||
}
|
||||
|
||||
// Get user with token
|
||||
|
||||
type User struct {
|
||||
Id string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Verified bool `json:"verified_email"`
|
||||
Hd string `json:"hd"`
|
||||
}
|
||||
|
||||
func (f *ForwardAuth) GetUser(token string) (User, error) {
|
||||
var user User
|
||||
|
||||
client := &http.Client{}
|
||||
req, err := http.NewRequest("GET", fw.UserURL.String(), nil)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
|
||||
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", token))
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
|
||||
defer res.Body.Close()
|
||||
err = json.NewDecoder(res.Body).Decode(&user)
|
||||
|
||||
return user, err
|
||||
func (f *ForwardAuth) GetUser(token string) (provider.User, error) {
|
||||
// TODO: Support multiple providers
|
||||
return config.Providers.Google.GetUser(token)
|
||||
}
|
||||
|
||||
// Utility methods
|
||||
@ -197,7 +134,7 @@ func (f *ForwardAuth) redirectBase(r *http.Request) string {
|
||||
return fmt.Sprintf("%s://%s", proto, host)
|
||||
}
|
||||
|
||||
// Return url
|
||||
// // Return url
|
||||
func (f *ForwardAuth) returnUrl(r *http.Request) string {
|
||||
path := r.Header.Get("X-Forwarded-Uri")
|
||||
|
||||
@ -208,15 +145,15 @@ func (f *ForwardAuth) returnUrl(r *http.Request) string {
|
||||
func (f *ForwardAuth) redirectUri(r *http.Request) string {
|
||||
if use, _ := f.useAuthDomain(r); use {
|
||||
proto := r.Header.Get("X-Forwarded-Proto")
|
||||
return fmt.Sprintf("%s://%s%s", proto, f.AuthHost, f.Path)
|
||||
return fmt.Sprintf("%s://%s%s", proto, config.AuthHost, config.Path)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s%s", f.redirectBase(r), f.Path)
|
||||
return fmt.Sprintf("%s%s", f.redirectBase(r), config.Path)
|
||||
}
|
||||
|
||||
// Should we use auth host + what it is
|
||||
func (f *ForwardAuth) useAuthDomain(r *http.Request) (bool, string) {
|
||||
if f.AuthHost == "" {
|
||||
if config.AuthHost == "" {
|
||||
return false, ""
|
||||
}
|
||||
|
||||
@ -224,7 +161,7 @@ func (f *ForwardAuth) useAuthDomain(r *http.Request) (bool, string) {
|
||||
reqMatch, reqHost := f.matchCookieDomains(r.Header.Get("X-Forwarded-Host"))
|
||||
|
||||
// Do any of the auth hosts match a cookie domain?
|
||||
authMatch, authHost := f.matchCookieDomains(f.AuthHost)
|
||||
authMatch, authHost := f.matchCookieDomains(config.AuthHost)
|
||||
|
||||
// We need both to match the same domain
|
||||
return reqMatch && authMatch && reqHost == authHost, reqHost
|
||||
@ -239,12 +176,12 @@ func (f *ForwardAuth) MakeCookie(r *http.Request, email string) *http.Cookie {
|
||||
value := fmt.Sprintf("%s|%d|%s", mac, expires.Unix(), email)
|
||||
|
||||
return &http.Cookie{
|
||||
Name: f.CookieName,
|
||||
Name: config.CookieName,
|
||||
Value: value,
|
||||
Path: "/",
|
||||
Domain: f.cookieDomain(r),
|
||||
HttpOnly: true,
|
||||
Secure: f.CookieSecure,
|
||||
Secure: config.CookieSecure,
|
||||
Expires: expires,
|
||||
}
|
||||
}
|
||||
@ -252,12 +189,12 @@ func (f *ForwardAuth) MakeCookie(r *http.Request, email string) *http.Cookie {
|
||||
// Make a CSRF cookie (used during login only)
|
||||
func (f *ForwardAuth) MakeCSRFCookie(r *http.Request, nonce string) *http.Cookie {
|
||||
return &http.Cookie{
|
||||
Name: f.CSRFCookieName,
|
||||
Name: config.CSRFCookieName,
|
||||
Value: nonce,
|
||||
Path: "/",
|
||||
Domain: f.csrfCookieDomain(r),
|
||||
HttpOnly: true,
|
||||
Secure: f.CookieSecure,
|
||||
Secure: config.CookieSecure,
|
||||
Expires: f.cookieExpiry(),
|
||||
}
|
||||
}
|
||||
@ -265,18 +202,20 @@ func (f *ForwardAuth) MakeCSRFCookie(r *http.Request, nonce string) *http.Cookie
|
||||
// Create a cookie to clear csrf cookie
|
||||
func (f *ForwardAuth) ClearCSRFCookie(r *http.Request) *http.Cookie {
|
||||
return &http.Cookie{
|
||||
Name: f.CSRFCookieName,
|
||||
Name: config.CSRFCookieName,
|
||||
Value: "",
|
||||
Path: "/",
|
||||
Domain: f.csrfCookieDomain(r),
|
||||
HttpOnly: true,
|
||||
Secure: f.CookieSecure,
|
||||
Secure: config.CookieSecure,
|
||||
Expires: time.Now().Local().Add(time.Hour * -1),
|
||||
}
|
||||
}
|
||||
|
||||
// Validate the csrf cookie against state
|
||||
func (f *ForwardAuth) ValidateCSRFCookie(c *http.Cookie, state string) (bool, string, error) {
|
||||
// // Validate the csrf cookie against state
|
||||
func (f *ForwardAuth) ValidateCSRFCookie(r *http.Request, c *http.Cookie) (bool, string, error) {
|
||||
state := r.URL.Query().Get("state")
|
||||
|
||||
if len(c.Value) != 32 {
|
||||
return false, "", errors.New("Invalid CSRF cookie value")
|
||||
}
|
||||
@ -333,7 +272,7 @@ func (f *ForwardAuth) matchCookieDomains(domain string) (bool, string) {
|
||||
// Remove port
|
||||
p := strings.Split(domain, ":")
|
||||
|
||||
for _, d := range f.CookieDomains {
|
||||
for _, d := range config.CookieDomains {
|
||||
if d.Match(p[0]) {
|
||||
return true, d.Domain
|
||||
}
|
||||
@ -344,7 +283,7 @@ func (f *ForwardAuth) matchCookieDomains(domain string) (bool, string) {
|
||||
|
||||
// Create cookie hmac
|
||||
func (f *ForwardAuth) cookieSignature(r *http.Request, email, expires string) string {
|
||||
hash := hmac.New(sha256.New, f.Secret)
|
||||
hash := hmac.New(sha256.New, config.SecretBytes)
|
||||
hash.Write([]byte(f.cookieDomain(r)))
|
||||
hash.Write([]byte(email))
|
||||
hash.Write([]byte(expires))
|
||||
@ -353,7 +292,7 @@ func (f *ForwardAuth) cookieSignature(r *http.Request, email, expires string) st
|
||||
|
||||
// Get cookie expirary
|
||||
func (f *ForwardAuth) cookieExpiry() time.Time {
|
||||
return time.Now().Local().Add(f.Lifetime)
|
||||
return time.Now().Local().Add(config.Lifetime)
|
||||
}
|
||||
|
||||
// Cookie Domain
|
||||
|
Reference in New Issue
Block a user