361 lines
8.7 KiB
Go
Raw Normal View History

package tfa
2018-06-26 12:28:47 +01:00
import (
2019-01-22 10:50:55 +00:00
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"net/http"
"strconv"
"strings"
"time"
2018-06-26 12:28:47 +01:00
"github.com/thomseddon/traefik-forward-auth/internal/provider"
2019-01-30 16:52:47 +00:00
)
2018-06-26 12:28:47 +01:00
// Request Validation
2020-05-11 14:42:33 +01:00
// ValidateCookie verifies that a cookie matches the expected format of:
2018-06-26 12:28:47 +01:00
// Cookie = hash(secret, cookie domain, email, expires)|expires|email
func ValidateCookie(r *http.Request, c *http.Cookie) (string, error) {
2019-01-22 10:50:55 +00:00
parts := strings.Split(c.Value, "|")
if len(parts) != 3 {
return "", errors.New("Invalid cookie format")
2019-01-22 10:50:55 +00:00
}
mac, err := base64.URLEncoding.DecodeString(parts[0])
if err != nil {
return "", errors.New("Unable to decode cookie mac")
2019-01-22 10:50:55 +00:00
}
expectedSignature := cookieSignature(r, parts[2], parts[1])
2019-01-22 10:50:55 +00:00
expected, err := base64.URLEncoding.DecodeString(expectedSignature)
if err != nil {
return "", errors.New("Unable to generate mac")
2019-01-22 10:50:55 +00:00
}
// Valid token?
if !hmac.Equal(mac, expected) {
return "", errors.New("Invalid cookie mac")
2019-01-22 10:50:55 +00:00
}
expires, err := strconv.ParseInt(parts[1], 10, 64)
if err != nil {
return "", errors.New("Unable to parse cookie expiry")
2019-01-22 10:50:55 +00:00
}
// Has it expired?
if time.Unix(expires, 0).Before(time.Now()) {
return "", errors.New("Cookie has expired")
2019-01-22 10:50:55 +00:00
}
// Looks valid
return parts[2], nil
2018-06-26 12:28:47 +01:00
}
// ValidateEmail checks if the given email address matches either a whitelisted
// email address, as defined by the "whitelist" config parameter. Or is part of
// a permitted domain, as defined by the "domains" config parameter
func ValidateEmail(email string) bool {
// Do we have any validation to perform?
if len(config.Whitelist) == 0 && len(config.Domains) == 0 {
return true
}
// Email whitelist validation
2019-01-30 16:52:47 +00:00
if len(config.Whitelist) > 0 {
for _, whitelist := range config.Whitelist {
2019-01-22 10:50:55 +00:00
if email == whitelist {
return true
2019-01-22 10:50:55 +00:00
}
}
// If we're not matching *either*, stop here
if !config.MatchWhitelistOrDomain {
return false
}
}
// Domain validation
if len(config.Domains) > 0 {
2019-01-22 10:50:55 +00:00
parts := strings.Split(email, "@")
if len(parts) < 2 {
return false
}
for _, domain := range config.Domains {
2019-01-22 10:50:55 +00:00
if domain == parts[1] {
return true
2019-01-22 10:50:55 +00:00
}
}
}
return false
2018-06-26 12:28:47 +01:00
}
// Utility methods
// redirectBase builds "<proto>://<host>" from the X-Forwarded-Proto and
// X-Forwarded-Host request headers. Missing headers contribute empty strings.
func redirectBase(r *http.Request) string {
	hdr := r.Header
	return fmt.Sprintf("%s://%s", hdr.Get("X-Forwarded-Proto"), hdr.Get("X-Forwarded-Host"))
}
// Return url
func returnUrl(r *http.Request) string {
2019-01-22 10:50:55 +00:00
path := r.Header.Get("X-Forwarded-Uri")
2018-06-26 12:28:47 +01:00
return fmt.Sprintf("%s%s", redirectBase(r), path)
2018-06-26 12:28:47 +01:00
}
// Get oauth redirect uri
func redirectUri(r *http.Request) string {
if use, _ := useAuthDomain(r); use {
2019-01-22 10:50:55 +00:00
proto := r.Header.Get("X-Forwarded-Proto")
2019-01-30 16:52:47 +00:00
return fmt.Sprintf("%s://%s%s", proto, config.AuthHost, config.Path)
2019-01-22 10:50:55 +00:00
}
return fmt.Sprintf("%s%s", redirectBase(r), config.Path)
2018-06-26 12:28:47 +01:00
}
// Should we use auth host + what it is
func useAuthDomain(r *http.Request) (bool, string) {
2019-01-30 16:52:47 +00:00
if config.AuthHost == "" {
2019-01-22 10:50:55 +00:00
return false, ""
}
2019-01-22 10:50:55 +00:00
// Does the request match a given cookie domain?
reqMatch, reqHost := matchCookieDomains(r.Header.Get("X-Forwarded-Host"))
2019-01-22 10:50:55 +00:00
// Do any of the auth hosts match a cookie domain?
authMatch, authHost := matchCookieDomains(config.AuthHost)
2019-01-22 10:50:55 +00:00
// We need both to match the same domain
return reqMatch && authMatch && reqHost == authHost, reqHost
}
2018-06-26 12:28:47 +01:00
// Cookie methods
2020-05-11 14:42:33 +01:00
// MakeCookie creates an auth cookie
func MakeCookie(r *http.Request, email string) *http.Cookie {
expires := cookieExpiry()
mac := cookieSignature(r, email, fmt.Sprintf("%d", expires.Unix()))
2019-01-22 10:50:55 +00:00
value := fmt.Sprintf("%s|%d|%s", mac, expires.Unix(), email)
return &http.Cookie{
2019-01-30 16:52:47 +00:00
Name: config.CookieName,
2019-01-22 10:50:55 +00:00
Value: value,
Path: "/",
Domain: cookieDomain(r),
2019-01-22 10:50:55 +00:00
HttpOnly: true,
Secure: !config.InsecureCookie,
2019-01-22 10:50:55 +00:00
Expires: expires,
}
2018-06-26 12:28:47 +01:00
}
// ClearCookie clears the auth cookie
func ClearCookie(r *http.Request) *http.Cookie {
return &http.Cookie{
Name: config.CookieName,
Value: "",
Path: "/",
Domain: cookieDomain(r),
HttpOnly: true,
Secure: !config.InsecureCookie,
Expires: time.Now().Local().Add(time.Hour * -1),
}
}
2020-05-11 14:42:33 +01:00
// MakeCSRFCookie makes a csrf cookie (used during login only)
func MakeCSRFCookie(r *http.Request, nonce string) *http.Cookie {
2019-01-22 10:50:55 +00:00
return &http.Cookie{
2019-01-30 16:52:47 +00:00
Name: config.CSRFCookieName,
2019-01-22 10:50:55 +00:00
Value: nonce,
Path: "/",
Domain: csrfCookieDomain(r),
2019-01-22 10:50:55 +00:00
HttpOnly: true,
Secure: !config.InsecureCookie,
Expires: cookieExpiry(),
2019-01-22 10:50:55 +00:00
}
2018-06-26 12:28:47 +01:00
}
2020-05-11 14:42:33 +01:00
// ClearCSRFCookie makes an expired csrf cookie to clear csrf cookie
func ClearCSRFCookie(r *http.Request) *http.Cookie {
2019-01-22 10:50:55 +00:00
return &http.Cookie{
2019-01-30 16:52:47 +00:00
Name: config.CSRFCookieName,
2019-01-22 10:50:55 +00:00
Value: "",
Path: "/",
Domain: csrfCookieDomain(r),
2019-01-22 10:50:55 +00:00
HttpOnly: true,
Secure: !config.InsecureCookie,
2019-01-22 10:50:55 +00:00
Expires: time.Now().Local().Add(time.Hour * -1),
}
2018-06-26 12:28:47 +01:00
}
2020-05-11 14:42:33 +01:00
// ValidateCSRFCookie validates the csrf cookie against state
//
// The "state" query parameter is expected to have the layout produced by
// MakeState: "<32-char nonce>:<provider name>:<redirect url>". The cookie
// value must be exactly 32 characters and equal the nonce prefix.
//
// On success it returns valid=true together with the provider name and the
// redirect URL taken from state. Otherwise it returns false and an error
// naming the check that failed.
func ValidateCSRFCookie(r *http.Request, c *http.Cookie) (valid bool, providerName string, redirect string, err error) {
	state := r.URL.Query().Get("state")

	if len(c.Value) != 32 {
		return false, "", "", errors.New("Invalid CSRF cookie value")
	}

	if len(state) < 34 {
		return false, "", "", errors.New("Invalid CSRF state value")
	}

	// Check nonce match
	if c.Value != state[:32] {
		return false, "", "", errors.New("CSRF cookie does not match state")
	}

	// The nonce must be followed by the ":" separator. Reject anything else
	// rather than silently skipping whatever character sits there.
	if state[32] != ':' {
		return false, "", "", errors.New("Invalid CSRF state format")
	}

	// Split "<provider>:<redirect>" on the first ":" only. The redirect URL
	// itself contains colons (e.g. "https://").
	params := state[33:]
	split := strings.Index(params, ":")
	if split == -1 {
		return false, "", "", errors.New("Invalid CSRF state format")
	}

	// Valid, return provider and redirect
	return true, params[:split], params[split+1:], nil
}
2020-05-11 14:42:33 +01:00
// MakeState generates a state value
func MakeState(r *http.Request, p provider.Provider, nonce string) string {
return fmt.Sprintf("%s:%s:%s", nonce, p.Name(), returnUrl(r))
2018-06-26 12:28:47 +01:00
}
2020-05-11 14:42:33 +01:00
// Nonce generates a random nonce
//
// It reads 16 bytes from crypto/rand and returns them as 32 lowercase hex
// characters. Note the unusual result order: the error comes first, and the
// string is "" when reading random bytes fails.
func Nonce() (error, string) {
	var raw [16]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return err, ""
	}
	encoded := fmt.Sprintf("%x", raw[:])
	return nil, encoded
}
// Cookie domain
func cookieDomain(r *http.Request) string {
2019-01-22 10:50:55 +00:00
host := r.Header.Get("X-Forwarded-Host")
2018-06-26 12:28:47 +01:00
2019-01-22 10:50:55 +00:00
// Check if any of the given cookie domains matches
_, domain := matchCookieDomains(host)
2019-01-22 10:50:55 +00:00
return domain
}
// Cookie domain
func csrfCookieDomain(r *http.Request) string {
2019-01-22 10:50:55 +00:00
var host string
if use, domain := useAuthDomain(r); use {
2019-01-22 10:50:55 +00:00
host = domain
} else {
host = r.Header.Get("X-Forwarded-Host")
}
// Remove port
p := strings.Split(host, ":")
return p[0]
}
2018-06-26 12:28:47 +01:00
// Return matching cookie domain if exists
func matchCookieDomains(domain string) (bool, string) {
2019-01-22 10:50:55 +00:00
// Remove port
p := strings.Split(domain, ":")
2019-01-30 16:52:47 +00:00
for _, d := range config.CookieDomains {
2019-01-22 10:50:55 +00:00
if d.Match(p[0]) {
return true, d.Domain
}
}
2018-06-26 12:28:47 +01:00
2019-01-22 10:50:55 +00:00
return false, p[0]
2018-06-26 12:28:47 +01:00
}
// Create cookie hmac
func cookieSignature(r *http.Request, email, expires string) string {
hash := hmac.New(sha256.New, config.Secret)
hash.Write([]byte(cookieDomain(r)))
2019-01-22 10:50:55 +00:00
hash.Write([]byte(email))
hash.Write([]byte(expires))
return base64.URLEncoding.EncodeToString(hash.Sum(nil))
2018-06-26 12:28:47 +01:00
}
2019-07-08 17:21:08 +01:00
// Get cookie expiry
func cookieExpiry() time.Time {
2019-01-30 16:52:47 +00:00
return time.Now().Local().Add(config.Lifetime)
2018-06-26 12:28:47 +01:00
}
2020-05-11 14:42:33 +01:00
// CookieDomain holds cookie domain info
//
// Domain is the configured domain. SubDomain is the same domain with a
// leading "." and is used to match sub-domains by suffix. DomainLen and
// SubDomainLen hold the byte lengths of those strings as set by
// NewCookieDomain.
type CookieDomain struct {
	Domain       string
	DomainLen    int
	SubDomain    string
	SubDomainLen int
}

// NewCookieDomain creates a new CookieDomain from the given domain string
func NewCookieDomain(domain string) *CookieDomain {
	return &CookieDomain{
		Domain:       domain,
		DomainLen:    len(domain),
		SubDomain:    "." + domain,
		SubDomainLen: len(domain) + 1,
	}
}

// Match checks if the given host matches this CookieDomain
//
// A host matches when it equals Domain exactly or ends with "."+Domain. For
// example, "example.com" matches "example.com" and "a.example.com" but not
// "badexample.com". The comparison is case-sensitive, and no port stripping
// is done here.
func (c *CookieDomain) Match(host string) bool {
	return host == c.Domain || strings.HasSuffix(host, c.SubDomain)
}

// UnmarshalFlag converts a string to a CookieDomain
func (c *CookieDomain) UnmarshalFlag(value string) error {
	*c = *NewCookieDomain(value)
	return nil
}

// MarshalFlag converts a CookieDomain to a string
func (c *CookieDomain) MarshalFlag() (string, error) {
	return c.Domain, nil
}

// CookieDomains provides legacy support for comma separated list of cookie domains
type CookieDomains []CookieDomain

// UnmarshalFlag converts a comma separated list of cookie domains to an array
// of CookieDomains
//
// Each entry is trimmed of surrounding whitespace, and empty entries are
// skipped, so "a.com, b.com," yields two domains. Without trimming,
// " b.com" could never match a host. Parsed domains are appended to any
// already held by c.
func (c *CookieDomains) UnmarshalFlag(value string) error {
	for _, d := range strings.Split(value, ",") {
		d = strings.TrimSpace(d)
		if d == "" {
			continue
		}
		*c = append(*c, *NewCookieDomain(d))
	}
	return nil
}

// MarshalFlag converts an array of CookieDomain to a comma separated list
func (c *CookieDomains) MarshalFlag() (string, error) {
	domains := make([]string, 0, len(*c))
	for _, d := range *c {
		domains = append(domains, d.Domain)
	}
	return strings.Join(domains, ","), nil
}