Begin refactor + selective auth

This commit is contained in:
Thom Seddon
2019-01-30 16:52:47 +00:00
parent 5c800a0170
commit ae95e8b2e5
12 changed files with 943 additions and 577 deletions

View File

@ -5,41 +5,31 @@ import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
// "encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
// "net/url"
"strconv"
"strings"
"time"
"github.com/thomseddon/traefik-forward-auth/provider"
)
type ForwardAuthContext int
const (
Nonce ForwardAuthContext = iota
Request
)
// Forward Auth
type ForwardAuth struct {
Path string
Lifetime time.Duration
Secret []byte
}
ClientId string
ClientSecret string `json:"-"`
Scope string
LoginURL *url.URL
TokenURL *url.URL
UserURL *url.URL
AuthHost string
CookieName string
CookieDomains []CookieDomain
CSRFCookieName string
CookieSecure bool
Domain []string
Whitelist []string
Prompt string
func NewForwardAuth() *ForwardAuth {
return &ForwardAuth{}
}
// Request Validation
@ -85,18 +75,18 @@ func (f *ForwardAuth) ValidateCookie(r *http.Request, c *http.Cookie) (bool, str
// Validate email
func (f *ForwardAuth) ValidateEmail(email string) bool {
found := false
if len(f.Whitelist) > 0 {
for _, whitelist := range f.Whitelist {
if len(config.Whitelist) > 0 {
for _, whitelist := range config.Whitelist {
if email == whitelist {
found = true
}
}
} else if len(f.Domain) > 0 {
} else if len(config.Domain) > 0 {
parts := strings.Split(email, "@")
if len(parts) < 2 {
return false
}
for _, domain := range f.Domain {
for _, domain := range config.Domain {
if domain == parts[1] {
found = true
}
@ -114,77 +104,24 @@ func (f *ForwardAuth) ValidateEmail(email string) bool {
func (f *ForwardAuth) GetLoginURL(r *http.Request, nonce string) string {
state := fmt.Sprintf("%s:%s", nonce, f.returnUrl(r))
q := url.Values{}
q.Set("client_id", fw.ClientId)
q.Set("response_type", "code")
q.Set("scope", fw.Scope)
if fw.Prompt != "" {
q.Set("prompt", fw.Prompt)
}
q.Set("redirect_uri", f.redirectUri(r))
q.Set("state", state)
var u url.URL
u = *fw.LoginURL
u.RawQuery = q.Encode()
return u.String()
// TODO: Support multiple providers
return config.Providers.Google.GetLoginURL(f.redirectUri(r), state)
}
// Exchange code for token
type Token struct {
Token string `json:"access_token"`
}
func (f *ForwardAuth) ExchangeCode(r *http.Request) (string, error) {
code := r.URL.Query().Get("code")
func (f *ForwardAuth) ExchangeCode(r *http.Request, code string) (string, error) {
form := url.Values{}
form.Set("client_id", fw.ClientId)
form.Set("client_secret", fw.ClientSecret)
form.Set("grant_type", "authorization_code")
form.Set("redirect_uri", f.redirectUri(r))
form.Set("code", code)
res, err := http.PostForm(fw.TokenURL.String(), form)
if err != nil {
return "", err
}
var token Token
defer res.Body.Close()
err = json.NewDecoder(res.Body).Decode(&token)
return token.Token, err
// TODO: Support multiple providers
return config.Providers.Google.ExchangeCode(f.redirectUri(r), code)
}
// Get user with token
type User struct {
Id string `json:"id"`
Email string `json:"email"`
Verified bool `json:"verified_email"`
Hd string `json:"hd"`
}
func (f *ForwardAuth) GetUser(token string) (User, error) {
var user User
client := &http.Client{}
req, err := http.NewRequest("GET", fw.UserURL.String(), nil)
if err != nil {
return user, err
}
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", token))
res, err := client.Do(req)
if err != nil {
return user, err
}
defer res.Body.Close()
err = json.NewDecoder(res.Body).Decode(&user)
return user, err
func (f *ForwardAuth) GetUser(token string) (provider.User, error) {
// TODO: Support multiple providers
return config.Providers.Google.GetUser(token)
}
// Utility methods
@ -197,7 +134,7 @@ func (f *ForwardAuth) redirectBase(r *http.Request) string {
return fmt.Sprintf("%s://%s", proto, host)
}
// Return url
// // Return url
func (f *ForwardAuth) returnUrl(r *http.Request) string {
path := r.Header.Get("X-Forwarded-Uri")
@ -208,15 +145,15 @@ func (f *ForwardAuth) returnUrl(r *http.Request) string {
func (f *ForwardAuth) redirectUri(r *http.Request) string {
if use, _ := f.useAuthDomain(r); use {
proto := r.Header.Get("X-Forwarded-Proto")
return fmt.Sprintf("%s://%s%s", proto, f.AuthHost, f.Path)
return fmt.Sprintf("%s://%s%s", proto, config.AuthHost, config.Path)
}
return fmt.Sprintf("%s%s", f.redirectBase(r), f.Path)
return fmt.Sprintf("%s%s", f.redirectBase(r), config.Path)
}
// Should we use auth host + what it is
func (f *ForwardAuth) useAuthDomain(r *http.Request) (bool, string) {
if f.AuthHost == "" {
if config.AuthHost == "" {
return false, ""
}
@ -224,7 +161,7 @@ func (f *ForwardAuth) useAuthDomain(r *http.Request) (bool, string) {
reqMatch, reqHost := f.matchCookieDomains(r.Header.Get("X-Forwarded-Host"))
// Do any of the auth hosts match a cookie domain?
authMatch, authHost := f.matchCookieDomains(f.AuthHost)
authMatch, authHost := f.matchCookieDomains(config.AuthHost)
// We need both to match the same domain
return reqMatch && authMatch && reqHost == authHost, reqHost
@ -239,12 +176,12 @@ func (f *ForwardAuth) MakeCookie(r *http.Request, email string) *http.Cookie {
value := fmt.Sprintf("%s|%d|%s", mac, expires.Unix(), email)
return &http.Cookie{
Name: f.CookieName,
Name: config.CookieName,
Value: value,
Path: "/",
Domain: f.cookieDomain(r),
HttpOnly: true,
Secure: f.CookieSecure,
Secure: config.CookieSecure,
Expires: expires,
}
}
@ -252,12 +189,12 @@ func (f *ForwardAuth) MakeCookie(r *http.Request, email string) *http.Cookie {
// Make a CSRF cookie (used during login only)
func (f *ForwardAuth) MakeCSRFCookie(r *http.Request, nonce string) *http.Cookie {
return &http.Cookie{
Name: f.CSRFCookieName,
Name: config.CSRFCookieName,
Value: nonce,
Path: "/",
Domain: f.csrfCookieDomain(r),
HttpOnly: true,
Secure: f.CookieSecure,
Secure: config.CookieSecure,
Expires: f.cookieExpiry(),
}
}
@ -265,18 +202,20 @@ func (f *ForwardAuth) MakeCSRFCookie(r *http.Request, nonce string) *http.Cookie
// Create a cookie to clear csrf cookie
func (f *ForwardAuth) ClearCSRFCookie(r *http.Request) *http.Cookie {
return &http.Cookie{
Name: f.CSRFCookieName,
Name: config.CSRFCookieName,
Value: "",
Path: "/",
Domain: f.csrfCookieDomain(r),
HttpOnly: true,
Secure: f.CookieSecure,
Secure: config.CookieSecure,
Expires: time.Now().Local().Add(time.Hour * -1),
}
}
// Validate the csrf cookie against state
func (f *ForwardAuth) ValidateCSRFCookie(c *http.Cookie, state string) (bool, string, error) {
// // Validate the csrf cookie against state
func (f *ForwardAuth) ValidateCSRFCookie(r *http.Request, c *http.Cookie) (bool, string, error) {
state := r.URL.Query().Get("state")
if len(c.Value) != 32 {
return false, "", errors.New("Invalid CSRF cookie value")
}
@ -333,7 +272,7 @@ func (f *ForwardAuth) matchCookieDomains(domain string) (bool, string) {
// Remove port
p := strings.Split(domain, ":")
for _, d := range f.CookieDomains {
for _, d := range config.CookieDomains {
if d.Match(p[0]) {
return true, d.Domain
}
@ -344,7 +283,7 @@ func (f *ForwardAuth) matchCookieDomains(domain string) (bool, string) {
// Create cookie hmac
func (f *ForwardAuth) cookieSignature(r *http.Request, email, expires string) string {
hash := hmac.New(sha256.New, f.Secret)
hash := hmac.New(sha256.New, config.SecretBytes)
hash.Write([]byte(f.cookieDomain(r)))
hash.Write([]byte(email))
hash.Write([]byte(expires))
@ -353,7 +292,7 @@ func (f *ForwardAuth) cookieSignature(r *http.Request, email, expires string) st
// Get cookie expirary
func (f *ForwardAuth) cookieExpiry() time.Time {
return time.Now().Local().Add(f.Lifetime)
return time.Now().Local().Add(config.Lifetime)
}
// Cookie Domain