use gofmt to update styling

This commit is contained in:
Thom Seddon 2019-01-22 10:50:55 +00:00
parent 6ccd1c6dfc
commit afd8878188
4 changed files with 908 additions and 915 deletions

View File

@ -1,393 +1,390 @@
package main package main
import ( import (
"fmt" "crypto/hmac"
"time" "crypto/rand"
"errors" "crypto/sha256"
"strings" "encoding/base64"
"strconv" "encoding/json"
"net/url" "errors"
"net/http" "fmt"
"crypto/hmac" "net/http"
"crypto/rand" "net/url"
"crypto/sha256" "strconv"
"encoding/json" "strings"
"encoding/base64" "time"
) )
// Forward Auth // Forward Auth
type ForwardAuth struct { type ForwardAuth struct {
Path string Path string
Lifetime time.Duration Lifetime time.Duration
Secret []byte Secret []byte
ClientId string ClientId string
ClientSecret string ClientSecret string
Scope string Scope string
LoginURL *url.URL LoginURL *url.URL
TokenURL *url.URL TokenURL *url.URL
UserURL *url.URL UserURL *url.URL
AuthHost string AuthHost string
CookieName string CookieName string
CookieDomains []CookieDomain CookieDomains []CookieDomain
CSRFCookieName string CSRFCookieName string
CookieSecure bool CookieSecure bool
Domain []string Domain []string
Whitelist []string Whitelist []string
Prompt string Prompt string
} }
// Request Validation // Request Validation
// Cookie = hash(secret, cookie domain, email, expires)|expires|email // Cookie = hash(secret, cookie domain, email, expires)|expires|email
func (f *ForwardAuth) ValidateCookie(r *http.Request, c *http.Cookie) (bool, string, error) { func (f *ForwardAuth) ValidateCookie(r *http.Request, c *http.Cookie) (bool, string, error) {
parts := strings.Split(c.Value, "|") parts := strings.Split(c.Value, "|")
if len(parts) != 3 { if len(parts) != 3 {
return false, "", errors.New("Invalid cookie format") return false, "", errors.New("Invalid cookie format")
} }
mac, err := base64.URLEncoding.DecodeString(parts[0]) mac, err := base64.URLEncoding.DecodeString(parts[0])
if err != nil { if err != nil {
return false, "", errors.New("Unable to decode cookie mac") return false, "", errors.New("Unable to decode cookie mac")
} }
expectedSignature := f.cookieSignature(r, parts[2], parts[1]) expectedSignature := f.cookieSignature(r, parts[2], parts[1])
expected, err := base64.URLEncoding.DecodeString(expectedSignature) expected, err := base64.URLEncoding.DecodeString(expectedSignature)
if err != nil { if err != nil {
return false, "", errors.New("Unable to generate mac") return false, "", errors.New("Unable to generate mac")
} }
// Valid token? // Valid token?
if !hmac.Equal(mac, expected) { if !hmac.Equal(mac, expected) {
return false, "", errors.New("Invalid cookie mac") return false, "", errors.New("Invalid cookie mac")
} }
expires, err := strconv.ParseInt(parts[1], 10, 64) expires, err := strconv.ParseInt(parts[1], 10, 64)
if err != nil { if err != nil {
return false, "", errors.New("Unable to parse cookie expiry") return false, "", errors.New("Unable to parse cookie expiry")
} }
// Has it expired? // Has it expired?
if time.Unix(expires, 0).Before(time.Now()) { if time.Unix(expires, 0).Before(time.Now()) {
return false, "", errors.New("Cookie has expired") return false, "", errors.New("Cookie has expired")
} }
// Looks valid // Looks valid
return true, parts[2], nil return true, parts[2], nil
} }
// Validate email // Validate email
func (f *ForwardAuth) ValidateEmail(email string) bool { func (f *ForwardAuth) ValidateEmail(email string) bool {
found := false found := false
if len(f.Whitelist) > 0 { if len(f.Whitelist) > 0 {
for _, whitelist := range f.Whitelist { for _, whitelist := range f.Whitelist {
if email == whitelist { if email == whitelist {
found = true found = true
} }
} }
} else if len(f.Domain) > 0 { } else if len(f.Domain) > 0 {
parts := strings.Split(email, "@") parts := strings.Split(email, "@")
if len(parts) < 2 { if len(parts) < 2 {
return false return false
} }
for _, domain := range f.Domain { for _, domain := range f.Domain {
if domain == parts[1] { if domain == parts[1] {
found = true found = true
} }
} }
} else { } else {
return true return true
} }
return found return found
} }
// OAuth Methods // OAuth Methods
// Get login url // Get login url
func (f *ForwardAuth) GetLoginURL(r *http.Request, nonce string) string { func (f *ForwardAuth) GetLoginURL(r *http.Request, nonce string) string {
state := fmt.Sprintf("%s:%s", nonce, f.returnUrl(r)) state := fmt.Sprintf("%s:%s", nonce, f.returnUrl(r))
q := url.Values{} q := url.Values{}
q.Set("client_id", fw.ClientId) q.Set("client_id", fw.ClientId)
q.Set("response_type", "code") q.Set("response_type", "code")
q.Set("scope", fw.Scope) q.Set("scope", fw.Scope)
if fw.Prompt != "" { if fw.Prompt != "" {
q.Set("prompt", fw.Prompt) q.Set("prompt", fw.Prompt)
} }
q.Set("redirect_uri", f.redirectUri(r)) q.Set("redirect_uri", f.redirectUri(r))
q.Set("state", state) q.Set("state", state)
var u url.URL var u url.URL
u = *fw.LoginURL u = *fw.LoginURL
u.RawQuery = q.Encode() u.RawQuery = q.Encode()
return u.String() return u.String()
} }
// Exchange code for token // Exchange code for token
type Token struct { type Token struct {
Token string `json:"access_token"` Token string `json:"access_token"`
} }
func (f *ForwardAuth) ExchangeCode(r *http.Request, code string) (string, error) { func (f *ForwardAuth) ExchangeCode(r *http.Request, code string) (string, error) {
form := url.Values{} form := url.Values{}
form.Set("client_id", fw.ClientId) form.Set("client_id", fw.ClientId)
form.Set("client_secret", fw.ClientSecret) form.Set("client_secret", fw.ClientSecret)
form.Set("grant_type", "authorization_code") form.Set("grant_type", "authorization_code")
form.Set("redirect_uri", f.redirectUri(r)) form.Set("redirect_uri", f.redirectUri(r))
form.Set("code", code) form.Set("code", code)
res, err := http.PostForm(fw.TokenURL.String(), form)
if err != nil {
return "", err
}
res, err := http.PostForm(fw.TokenURL.String(), form) var token Token
if err != nil { defer res.Body.Close()
return "", err err = json.NewDecoder(res.Body).Decode(&token)
}
var token Token return token.Token, err
defer res.Body.Close()
err = json.NewDecoder(res.Body).Decode(&token)
return token.Token, err
} }
// Get user with token // Get user with token
type User struct { type User struct {
Id string `json:"id"` Id string `json:"id"`
Email string `json:"email"` Email string `json:"email"`
Verified bool `json:"verified_email"` Verified bool `json:"verified_email"`
Hd string `json:"hd"` Hd string `json:"hd"`
} }
func (f *ForwardAuth) GetUser(token string) (User, error) { func (f *ForwardAuth) GetUser(token string) (User, error) {
var user User var user User
client := &http.Client{} client := &http.Client{}
req, err := http.NewRequest("GET", fw.UserURL.String(), nil) req, err := http.NewRequest("GET", fw.UserURL.String(), nil)
if err != nil { if err != nil {
return user, err return user, err
} }
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", token)) req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", token))
res, err := client.Do(req) res, err := client.Do(req)
if err != nil { if err != nil {
return user, err return user, err
} }
defer res.Body.Close() defer res.Body.Close()
err = json.NewDecoder(res.Body).Decode(&user) err = json.NewDecoder(res.Body).Decode(&user)
return user, err return user, err
} }
// Utility methods // Utility methods
// Get the redirect base // Get the redirect base
func (f *ForwardAuth) redirectBase(r *http.Request) string { func (f *ForwardAuth) redirectBase(r *http.Request) string {
proto := r.Header.Get("X-Forwarded-Proto") proto := r.Header.Get("X-Forwarded-Proto")
host := r.Header.Get("X-Forwarded-Host") host := r.Header.Get("X-Forwarded-Host")
return fmt.Sprintf("%s://%s", proto, host) return fmt.Sprintf("%s://%s", proto, host)
} }
// Return url // Return url
func (f *ForwardAuth) returnUrl(r *http.Request) string { func (f *ForwardAuth) returnUrl(r *http.Request) string {
path := r.Header.Get("X-Forwarded-Uri") path := r.Header.Get("X-Forwarded-Uri")
return fmt.Sprintf("%s%s", f.redirectBase(r), path) return fmt.Sprintf("%s%s", f.redirectBase(r), path)
} }
// Get oauth redirect uri // Get oauth redirect uri
func (f *ForwardAuth) redirectUri(r *http.Request) string { func (f *ForwardAuth) redirectUri(r *http.Request) string {
if use, _ := f.useAuthDomain(r); use { if use, _ := f.useAuthDomain(r); use {
proto := r.Header.Get("X-Forwarded-Proto") proto := r.Header.Get("X-Forwarded-Proto")
return fmt.Sprintf("%s://%s%s", proto, f.AuthHost, f.Path) return fmt.Sprintf("%s://%s%s", proto, f.AuthHost, f.Path)
} }
return fmt.Sprintf("%s%s", f.redirectBase(r), f.Path) return fmt.Sprintf("%s%s", f.redirectBase(r), f.Path)
} }
// Should we use auth host + what it is // Should we use auth host + what it is
func (f *ForwardAuth) useAuthDomain(r *http.Request) (bool, string) { func (f *ForwardAuth) useAuthDomain(r *http.Request) (bool, string) {
if f.AuthHost == "" { if f.AuthHost == "" {
return false, "" return false, ""
} }
// Does the request match a given cookie domain? // Does the request match a given cookie domain?
reqMatch, reqHost := f.matchCookieDomains(r.Header.Get("X-Forwarded-Host")) reqMatch, reqHost := f.matchCookieDomains(r.Header.Get("X-Forwarded-Host"))
// Do any of the auth hosts match a cookie domain? // Do any of the auth hosts match a cookie domain?
authMatch, authHost := f.matchCookieDomains(f.AuthHost) authMatch, authHost := f.matchCookieDomains(f.AuthHost)
// We need both to match the same domain // We need both to match the same domain
return reqMatch && authMatch && reqHost == authHost, reqHost return reqMatch && authMatch && reqHost == authHost, reqHost
} }
// Cookie methods // Cookie methods
// Create an auth cookie // Create an auth cookie
func (f *ForwardAuth) MakeCookie(r *http.Request, email string) *http.Cookie { func (f *ForwardAuth) MakeCookie(r *http.Request, email string) *http.Cookie {
expires := f.cookieExpiry() expires := f.cookieExpiry()
mac := f.cookieSignature(r, email, fmt.Sprintf("%d", expires.Unix())) mac := f.cookieSignature(r, email, fmt.Sprintf("%d", expires.Unix()))
value := fmt.Sprintf("%s|%d|%s", mac, expires.Unix(), email) value := fmt.Sprintf("%s|%d|%s", mac, expires.Unix(), email)
return &http.Cookie{ return &http.Cookie{
Name: f.CookieName, Name: f.CookieName,
Value: value, Value: value,
Path: "/", Path: "/",
Domain: f.cookieDomain(r), Domain: f.cookieDomain(r),
HttpOnly: true, HttpOnly: true,
Secure: f.CookieSecure, Secure: f.CookieSecure,
Expires: expires, Expires: expires,
} }
} }
// Make a CSRF cookie (used during login only) // Make a CSRF cookie (used during login only)
func (f *ForwardAuth) MakeCSRFCookie(r *http.Request, nonce string) *http.Cookie { func (f *ForwardAuth) MakeCSRFCookie(r *http.Request, nonce string) *http.Cookie {
return &http.Cookie{ return &http.Cookie{
Name: f.CSRFCookieName, Name: f.CSRFCookieName,
Value: nonce, Value: nonce,
Path: "/", Path: "/",
Domain: f.csrfCookieDomain(r), Domain: f.csrfCookieDomain(r),
HttpOnly: true, HttpOnly: true,
Secure: f.CookieSecure, Secure: f.CookieSecure,
Expires: f.cookieExpiry(), Expires: f.cookieExpiry(),
} }
} }
// Create a cookie to clear csrf cookie // Create a cookie to clear csrf cookie
func (f *ForwardAuth) ClearCSRFCookie(r *http.Request) *http.Cookie { func (f *ForwardAuth) ClearCSRFCookie(r *http.Request) *http.Cookie {
return &http.Cookie{ return &http.Cookie{
Name: f.CSRFCookieName, Name: f.CSRFCookieName,
Value: "", Value: "",
Path: "/", Path: "/",
Domain: f.csrfCookieDomain(r), Domain: f.csrfCookieDomain(r),
HttpOnly: true, HttpOnly: true,
Secure: f.CookieSecure, Secure: f.CookieSecure,
Expires: time.Now().Local().Add(time.Hour * -1), Expires: time.Now().Local().Add(time.Hour * -1),
} }
} }
// Validate the csrf cookie against state // Validate the csrf cookie against state
func (f *ForwardAuth) ValidateCSRFCookie(c *http.Cookie, state string) (bool, string, error) { func (f *ForwardAuth) ValidateCSRFCookie(c *http.Cookie, state string) (bool, string, error) {
if len(c.Value) != 32 { if len(c.Value) != 32 {
return false, "", errors.New("Invalid CSRF cookie value") return false, "", errors.New("Invalid CSRF cookie value")
} }
if len(state) < 34 { if len(state) < 34 {
return false, "", errors.New("Invalid CSRF state value") return false, "", errors.New("Invalid CSRF state value")
} }
// Check nonce match // Check nonce match
if c.Value != state[:32] { if c.Value != state[:32] {
return false, "", errors.New("CSRF cookie does not match state") return false, "", errors.New("CSRF cookie does not match state")
} }
// Valid, return redirect // Valid, return redirect
return true, state[33:], nil return true, state[33:], nil
} }
func (f *ForwardAuth) Nonce() (error, string) { func (f *ForwardAuth) Nonce() (error, string) {
// Make nonce // Make nonce
nonce := make([]byte, 16) nonce := make([]byte, 16)
_, err := rand.Read(nonce) _, err := rand.Read(nonce)
if err != nil { if err != nil {
return err, "" return err, ""
} }
return nil, fmt.Sprintf("%x", nonce) return nil, fmt.Sprintf("%x", nonce)
} }
// Cookie domain // Cookie domain
func (f *ForwardAuth) cookieDomain(r *http.Request) string { func (f *ForwardAuth) cookieDomain(r *http.Request) string {
host := r.Header.Get("X-Forwarded-Host") host := r.Header.Get("X-Forwarded-Host")
// Check if any of the given cookie domains matches // Check if any of the given cookie domains matches
_, domain := f.matchCookieDomains(host) _, domain := f.matchCookieDomains(host)
return domain return domain
} }
// Cookie domain // Cookie domain
func (f *ForwardAuth) csrfCookieDomain(r *http.Request) string { func (f *ForwardAuth) csrfCookieDomain(r *http.Request) string {
var host string var host string
if use, domain := f.useAuthDomain(r); use { if use, domain := f.useAuthDomain(r); use {
host = domain host = domain
} else { } else {
host = r.Header.Get("X-Forwarded-Host") host = r.Header.Get("X-Forwarded-Host")
} }
// Remove port // Remove port
p := strings.Split(host, ":") p := strings.Split(host, ":")
return p[0] return p[0]
} }
// Return matching cookie domain if exists // Return matching cookie domain if exists
func (f *ForwardAuth) matchCookieDomains(domain string) (bool, string) { func (f *ForwardAuth) matchCookieDomains(domain string) (bool, string) {
// Remove port // Remove port
p := strings.Split(domain, ":") p := strings.Split(domain, ":")
for _, d := range f.CookieDomains { for _, d := range f.CookieDomains {
if d.Match(p[0]) { if d.Match(p[0]) {
return true, d.Domain return true, d.Domain
} }
} }
return false, p[0] return false, p[0]
} }
// Create cookie hmac // Create cookie hmac
func (f *ForwardAuth) cookieSignature(r *http.Request, email, expires string) string { func (f *ForwardAuth) cookieSignature(r *http.Request, email, expires string) string {
hash := hmac.New(sha256.New, f.Secret) hash := hmac.New(sha256.New, f.Secret)
hash.Write([]byte(f.cookieDomain(r))) hash.Write([]byte(f.cookieDomain(r)))
hash.Write([]byte(email)) hash.Write([]byte(email))
hash.Write([]byte(expires)) hash.Write([]byte(expires))
return base64.URLEncoding.EncodeToString(hash.Sum(nil)) return base64.URLEncoding.EncodeToString(hash.Sum(nil))
} }
// Get cookie expirary // Get cookie expirary
func (f *ForwardAuth) cookieExpiry() time.Time { func (f *ForwardAuth) cookieExpiry() time.Time {
return time.Now().Local().Add(f.Lifetime) return time.Now().Local().Add(f.Lifetime)
} }
// Cookie Domain // Cookie Domain
// Cookie Domain // Cookie Domain
type CookieDomain struct { type CookieDomain struct {
Domain string Domain string
DomainLen int DomainLen int
SubDomain string SubDomain string
SubDomainLen int SubDomainLen int
} }
func NewCookieDomain(domain string) *CookieDomain { func NewCookieDomain(domain string) *CookieDomain {
return &CookieDomain{ return &CookieDomain{
Domain: domain, Domain: domain,
DomainLen: len(domain), DomainLen: len(domain),
SubDomain: fmt.Sprintf(".%s", domain), SubDomain: fmt.Sprintf(".%s", domain),
SubDomainLen: len(domain) + 1, SubDomainLen: len(domain) + 1,
} }
} }
func (c *CookieDomain) Match(host string) bool { func (c *CookieDomain) Match(host string) bool {
// Exact domain match? // Exact domain match?
if host == c.Domain { if host == c.Domain {
return true return true
} }
// Subdomain match? // Subdomain match?
if len(host) >= c.SubDomainLen && host[len(host) - c.SubDomainLen:] == c.SubDomain { if len(host) >= c.SubDomainLen && host[len(host)-c.SubDomainLen:] == c.SubDomain {
return true return true
} }
return false return false
} }

View File

@ -1,284 +1,282 @@
package main package main
import ( import (
// "fmt" // "fmt"
"time" "net/http"
"reflect" "net/url"
"testing" "reflect"
"net/url" "testing"
"net/http" "time"
) )
func TestValidateCookie(t *testing.T) { func TestValidateCookie(t *testing.T) {
fw = &ForwardAuth{} fw = &ForwardAuth{}
r, _ := http.NewRequest("GET", "http://example.com", nil) r, _ := http.NewRequest("GET", "http://example.com", nil)
c := &http.Cookie{} c := &http.Cookie{}
// Should require 3 parts // Should require 3 parts
c.Value = "" c.Value = ""
valid, _, err := fw.ValidateCookie(r, c) valid, _, err := fw.ValidateCookie(r, c)
if valid || err.Error() != "Invalid cookie format" { if valid || err.Error() != "Invalid cookie format" {
t.Error("Should get \"Invalid cookie format\", got:", err) t.Error("Should get \"Invalid cookie format\", got:", err)
} }
c.Value = "1|2" c.Value = "1|2"
valid, _, err = fw.ValidateCookie(r, c) valid, _, err = fw.ValidateCookie(r, c)
if valid || err.Error() != "Invalid cookie format" { if valid || err.Error() != "Invalid cookie format" {
t.Error("Should get \"Invalid cookie format\", got:", err) t.Error("Should get \"Invalid cookie format\", got:", err)
} }
c.Value = "1|2|3|4" c.Value = "1|2|3|4"
valid, _, err = fw.ValidateCookie(r, c) valid, _, err = fw.ValidateCookie(r, c)
if valid || err.Error() != "Invalid cookie format" { if valid || err.Error() != "Invalid cookie format" {
t.Error("Should get \"Invalid cookie format\", got:", err) t.Error("Should get \"Invalid cookie format\", got:", err)
} }
// Should catch invalid mac // Should catch invalid mac
c.Value = "MQ==|2|3" c.Value = "MQ==|2|3"
valid, _, err = fw.ValidateCookie(r, c) valid, _, err = fw.ValidateCookie(r, c)
if valid || err.Error() != "Invalid cookie mac" { if valid || err.Error() != "Invalid cookie mac" {
t.Error("Should get \"Invalid cookie mac\", got:", err) t.Error("Should get \"Invalid cookie mac\", got:", err)
} }
// Should catch expired // Should catch expired
fw.Lifetime = time.Second * time.Duration(-1) fw.Lifetime = time.Second * time.Duration(-1)
c = fw.MakeCookie(r, "test@test.com") c = fw.MakeCookie(r, "test@test.com")
valid, _, err = fw.ValidateCookie(r, c) valid, _, err = fw.ValidateCookie(r, c)
if valid || err.Error() != "Cookie has expired" { if valid || err.Error() != "Cookie has expired" {
t.Error("Should get \"Cookie has expired\", got:", err) t.Error("Should get \"Cookie has expired\", got:", err)
} }
// Should accept valid cookie // Should accept valid cookie
fw.Lifetime = time.Second * time.Duration(10) fw.Lifetime = time.Second * time.Duration(10)
c = fw.MakeCookie(r, "test@test.com") c = fw.MakeCookie(r, "test@test.com")
valid, email, err := fw.ValidateCookie(r, c) valid, email, err := fw.ValidateCookie(r, c)
if !valid { if !valid {
t.Error("Valid request should return as valid") t.Error("Valid request should return as valid")
} }
if err != nil { if err != nil {
t.Error("Valid request should not return error, got:", err) t.Error("Valid request should not return error, got:", err)
} }
if email != "test@test.com" { if email != "test@test.com" {
t.Error("Valid request should return user email") t.Error("Valid request should return user email")
} }
} }
func TestValidateEmail(t *testing.T) { func TestValidateEmail(t *testing.T) {
fw = &ForwardAuth{} fw = &ForwardAuth{}
// Should allow any // Should allow any
if !fw.ValidateEmail("test@test.com") || !fw.ValidateEmail("one@two.com") { if !fw.ValidateEmail("test@test.com") || !fw.ValidateEmail("one@two.com") {
t.Error("Should allow any domain if email domain is not defined") t.Error("Should allow any domain if email domain is not defined")
} }
// Should block non matching domain // Should block non matching domain
fw.Domain = []string{"test.com"} fw.Domain = []string{"test.com"}
if fw.ValidateEmail("one@two.com") { if fw.ValidateEmail("one@two.com") {
t.Error("Should not allow user from another domain") t.Error("Should not allow user from another domain")
} }
// Should allow matching domain // Should allow matching domain
fw.Domain = []string{"test.com"} fw.Domain = []string{"test.com"}
if !fw.ValidateEmail("test@test.com") { if !fw.ValidateEmail("test@test.com") {
t.Error("Should allow user from allowed domain") t.Error("Should allow user from allowed domain")
} }
// Should block non whitelisted email address // Should block non whitelisted email address
fw.Domain = []string{} fw.Domain = []string{}
fw.Whitelist = []string{"test@test.com"} fw.Whitelist = []string{"test@test.com"}
if fw.ValidateEmail("one@two.com") { if fw.ValidateEmail("one@two.com") {
t.Error("Should not allow user not in whitelist.") t.Error("Should not allow user not in whitelist.")
} }
// Should allow matching whitelisted email address // Should allow matching whitelisted email address
fw.Domain = []string{} fw.Domain = []string{}
fw.Whitelist = []string{"test@test.com"} fw.Whitelist = []string{"test@test.com"}
if !fw.ValidateEmail("test@test.com") { if !fw.ValidateEmail("test@test.com") {
t.Error("Should allow user in whitelist.") t.Error("Should allow user in whitelist.")
} }
} }
func TestGetLoginURL(t *testing.T) { func TestGetLoginURL(t *testing.T) {
r, _ := http.NewRequest("GET", "http://example.com", nil) r, _ := http.NewRequest("GET", "http://example.com", nil)
r.Header.Add("X-Forwarded-Proto", "http") r.Header.Add("X-Forwarded-Proto", "http")
r.Header.Add("X-Forwarded-Host", "example.com") r.Header.Add("X-Forwarded-Host", "example.com")
r.Header.Add("X-Forwarded-Uri", "/hello") r.Header.Add("X-Forwarded-Uri", "/hello")
fw = &ForwardAuth{ fw = &ForwardAuth{
Path: "/_oauth", Path: "/_oauth",
ClientId: "idtest", ClientId: "idtest",
ClientSecret: "sectest", ClientSecret: "sectest",
Scope: "scopetest", Scope: "scopetest",
LoginURL: &url.URL{ LoginURL: &url.URL{
Scheme: "https", Scheme: "https",
Host: "test.com", Host: "test.com",
Path: "/auth", Path: "/auth",
}, },
} }
// Check url // Check url
uri, err := url.Parse(fw.GetLoginURL(r, "nonce")) uri, err := url.Parse(fw.GetLoginURL(r, "nonce"))
if err != nil { if err != nil {
t.Error("Error parsing login url:", err) t.Error("Error parsing login url:", err)
} }
if uri.Scheme != "https" { if uri.Scheme != "https" {
t.Error("Expected login Scheme to be \"https\", got:", uri.Scheme) t.Error("Expected login Scheme to be \"https\", got:", uri.Scheme)
} }
if uri.Host != "test.com" { if uri.Host != "test.com" {
t.Error("Expected login Host to be \"test.com\", got:", uri.Host) t.Error("Expected login Host to be \"test.com\", got:", uri.Host)
} }
if uri.Path != "/auth" { if uri.Path != "/auth" {
t.Error("Expected login Path to be \"/auth\", got:", uri.Path) t.Error("Expected login Path to be \"/auth\", got:", uri.Path)
} }
// Check query string // Check query string
qs := uri.Query() qs := uri.Query()
expectedQs := url.Values{ expectedQs := url.Values{
"client_id": []string{"idtest"}, "client_id": []string{"idtest"},
"redirect_uri": []string{"http://example.com/_oauth"}, "redirect_uri": []string{"http://example.com/_oauth"},
"response_type": []string{"code"}, "response_type": []string{"code"},
"scope": []string{"scopetest"}, "scope": []string{"scopetest"},
"state": []string{"nonce:http://example.com/hello"}, "state": []string{"nonce:http://example.com/hello"},
} }
if !reflect.DeepEqual(qs, expectedQs) { if !reflect.DeepEqual(qs, expectedQs) {
t.Error("Incorrect login query string:") t.Error("Incorrect login query string:")
qsDiff(expectedQs, qs) qsDiff(expectedQs, qs)
} }
//
// With Auth URL but no matching cookie domain
// - will not use auth host
//
fw = &ForwardAuth{
Path: "/_oauth",
AuthHost: "auth.example.com",
ClientId: "idtest",
ClientSecret: "sectest",
Scope: "scopetest",
LoginURL: &url.URL{
Scheme: "https",
Host: "test.com",
Path: "/auth",
},
Prompt: "consent select_account",
}
// // Check url
// With Auth URL but no matching cookie domain uri, err = url.Parse(fw.GetLoginURL(r, "nonce"))
// - will not use auth host if err != nil {
// t.Error("Error parsing login url:", err)
fw = &ForwardAuth{ }
Path: "/_oauth", if uri.Scheme != "https" {
AuthHost: "auth.example.com", t.Error("Expected login Scheme to be \"https\", got:", uri.Scheme)
ClientId: "idtest", }
ClientSecret: "sectest", if uri.Host != "test.com" {
Scope: "scopetest", t.Error("Expected login Host to be \"test.com\", got:", uri.Host)
LoginURL: &url.URL{ }
Scheme: "https", if uri.Path != "/auth" {
Host: "test.com", t.Error("Expected login Path to be \"/auth\", got:", uri.Path)
Path: "/auth", }
},
Prompt: "consent select_account",
}
// Check url // Check query string
uri, err = url.Parse(fw.GetLoginURL(r, "nonce")) qs = uri.Query()
if err != nil { expectedQs = url.Values{
t.Error("Error parsing login url:", err) "client_id": []string{"idtest"},
} "redirect_uri": []string{"http://example.com/_oauth"},
if uri.Scheme != "https" { "response_type": []string{"code"},
t.Error("Expected login Scheme to be \"https\", got:", uri.Scheme) "scope": []string{"scopetest"},
} "prompt": []string{"consent select_account"},
if uri.Host != "test.com" { "state": []string{"nonce:http://example.com/hello"},
t.Error("Expected login Host to be \"test.com\", got:", uri.Host) }
} if !reflect.DeepEqual(qs, expectedQs) {
if uri.Path != "/auth" { t.Error("Incorrect login query string:")
t.Error("Expected login Path to be \"/auth\", got:", uri.Path) qsDiff(expectedQs, qs)
} }
// Check query string //
qs = uri.Query() // With correct Auth URL + cookie domain
expectedQs = url.Values{ //
"client_id": []string{"idtest"}, cookieDomain := NewCookieDomain("example.com")
"redirect_uri": []string{"http://example.com/_oauth"}, fw = &ForwardAuth{
"response_type": []string{"code"}, Path: "/_oauth",
"scope": []string{"scopetest"}, AuthHost: "auth.example.com",
"prompt": []string{"consent select_account"}, ClientId: "idtest",
"state": []string{"nonce:http://example.com/hello"}, ClientSecret: "sectest",
} Scope: "scopetest",
if !reflect.DeepEqual(qs, expectedQs) { LoginURL: &url.URL{
t.Error("Incorrect login query string:") Scheme: "https",
qsDiff(expectedQs, qs) Host: "test.com",
} Path: "/auth",
},
CookieDomains: []CookieDomain{*cookieDomain},
}
// // Check url
// With correct Auth URL + cookie domain uri, err = url.Parse(fw.GetLoginURL(r, "nonce"))
// if err != nil {
cookieDomain := NewCookieDomain("example.com") t.Error("Error parsing login url:", err)
fw = &ForwardAuth{ }
Path: "/_oauth", if uri.Scheme != "https" {
AuthHost: "auth.example.com", t.Error("Expected login Scheme to be \"https\", got:", uri.Scheme)
ClientId: "idtest", }
ClientSecret: "sectest", if uri.Host != "test.com" {
Scope: "scopetest", t.Error("Expected login Host to be \"test.com\", got:", uri.Host)
LoginURL: &url.URL{ }
Scheme: "https", if uri.Path != "/auth" {
Host: "test.com", t.Error("Expected login Path to be \"/auth\", got:", uri.Path)
Path: "/auth", }
},
CookieDomains: []CookieDomain{*cookieDomain},
}
// Check url // Check query string
uri, err = url.Parse(fw.GetLoginURL(r, "nonce")) qs = uri.Query()
if err != nil { expectedQs = url.Values{
t.Error("Error parsing login url:", err) "client_id": []string{"idtest"},
} "redirect_uri": []string{"http://auth.example.com/_oauth"},
if uri.Scheme != "https" { "response_type": []string{"code"},
t.Error("Expected login Scheme to be \"https\", got:", uri.Scheme) "scope": []string{"scopetest"},
} "state": []string{"nonce:http://example.com/hello"},
if uri.Host != "test.com" { }
t.Error("Expected login Host to be \"test.com\", got:", uri.Host) qsDiff(expectedQs, qs)
} if !reflect.DeepEqual(qs, expectedQs) {
if uri.Path != "/auth" { t.Error("Incorrect login query string:")
t.Error("Expected login Path to be \"/auth\", got:", uri.Path) qsDiff(expectedQs, qs)
} }
// Check query string //
qs = uri.Query() // With Auth URL + cookie domain, but from different domain
expectedQs = url.Values{ // - will not use auth host
"client_id": []string{"idtest"}, //
"redirect_uri": []string{"http://auth.example.com/_oauth"}, r, _ = http.NewRequest("GET", "http://another.com", nil)
"response_type": []string{"code"}, r.Header.Add("X-Forwarded-Proto", "http")
"scope": []string{"scopetest"}, r.Header.Add("X-Forwarded-Host", "another.com")
"state": []string{"nonce:http://example.com/hello"}, r.Header.Add("X-Forwarded-Uri", "/hello")
}
qsDiff(expectedQs, qs)
if !reflect.DeepEqual(qs, expectedQs) {
t.Error("Incorrect login query string:")
qsDiff(expectedQs, qs)
}
// // Check url
// With Auth URL + cookie domain, but from different domain uri, err = url.Parse(fw.GetLoginURL(r, "nonce"))
// - will not use auth host if err != nil {
// t.Error("Error parsing login url:", err)
r, _ = http.NewRequest("GET", "http://another.com", nil) }
r.Header.Add("X-Forwarded-Proto", "http") if uri.Scheme != "https" {
r.Header.Add("X-Forwarded-Host", "another.com") t.Error("Expected login Scheme to be \"https\", got:", uri.Scheme)
r.Header.Add("X-Forwarded-Uri", "/hello") }
if uri.Host != "test.com" {
t.Error("Expected login Host to be \"test.com\", got:", uri.Host)
}
if uri.Path != "/auth" {
t.Error("Expected login Path to be \"/auth\", got:", uri.Path)
}
// Check url // Check query string
uri, err = url.Parse(fw.GetLoginURL(r, "nonce")) qs = uri.Query()
if err != nil { expectedQs = url.Values{
t.Error("Error parsing login url:", err) "client_id": []string{"idtest"},
} "redirect_uri": []string{"http://another.com/_oauth"},
if uri.Scheme != "https" { "response_type": []string{"code"},
t.Error("Expected login Scheme to be \"https\", got:", uri.Scheme) "scope": []string{"scopetest"},
} "state": []string{"nonce:http://another.com/hello"},
if uri.Host != "test.com" { }
t.Error("Expected login Host to be \"test.com\", got:", uri.Host) qsDiff(expectedQs, qs)
} if !reflect.DeepEqual(qs, expectedQs) {
if uri.Path != "/auth" { t.Error("Incorrect login query string:")
t.Error("Expected login Path to be \"/auth\", got:", uri.Path) qsDiff(expectedQs, qs)
} }
// Check query string
qs = uri.Query()
expectedQs = url.Values{
"client_id": []string{"idtest"},
"redirect_uri": []string{"http://another.com/_oauth"},
"response_type": []string{"code"},
"scope": []string{"scopetest"},
"state": []string{"nonce:http://another.com/hello"},
}
qsDiff(expectedQs, qs)
if !reflect.DeepEqual(qs, expectedQs) {
t.Error("Incorrect login query string:")
qsDiff(expectedQs, qs)
}
} }
// TODO // TODO
@ -294,123 +292,123 @@ func TestGetLoginURL(t *testing.T) {
// } // }
func TestMakeCSRFCookie(t *testing.T) { func TestMakeCSRFCookie(t *testing.T) {
r, _ := http.NewRequest("GET", "http://app.example.com", nil) r, _ := http.NewRequest("GET", "http://app.example.com", nil)
r.Header.Add("X-Forwarded-Host", "app.example.com") r.Header.Add("X-Forwarded-Host", "app.example.com")
// No cookie domain or auth url // No cookie domain or auth url
fw = &ForwardAuth{} fw = &ForwardAuth{}
c := fw.MakeCSRFCookie(r, "12345678901234567890123456789012") c := fw.MakeCSRFCookie(r, "12345678901234567890123456789012")
if c.Domain != "app.example.com" { if c.Domain != "app.example.com" {
t.Error("Cookie Domain should match request domain, got:", c.Domain) t.Error("Cookie Domain should match request domain, got:", c.Domain)
} }
// With cookie domain but no auth url // With cookie domain but no auth url
cookieDomain := NewCookieDomain("example.com") cookieDomain := NewCookieDomain("example.com")
fw = &ForwardAuth{CookieDomains: []CookieDomain{*cookieDomain},} fw = &ForwardAuth{CookieDomains: []CookieDomain{*cookieDomain}}
c = fw.MakeCSRFCookie(r, "12345678901234567890123456789012") c = fw.MakeCSRFCookie(r, "12345678901234567890123456789012")
if c.Domain != "app.example.com" { if c.Domain != "app.example.com" {
t.Error("Cookie Domain should match request domain, got:", c.Domain) t.Error("Cookie Domain should match request domain, got:", c.Domain)
} }
// With cookie domain and auth url // With cookie domain and auth url
fw = &ForwardAuth{ fw = &ForwardAuth{
AuthHost: "auth.example.com", AuthHost: "auth.example.com",
CookieDomains: []CookieDomain{*cookieDomain}, CookieDomains: []CookieDomain{*cookieDomain},
} }
c = fw.MakeCSRFCookie(r, "12345678901234567890123456789012") c = fw.MakeCSRFCookie(r, "12345678901234567890123456789012")
if c.Domain != "example.com" { if c.Domain != "example.com" {
t.Error("Cookie Domain should match request domain, got:", c.Domain) t.Error("Cookie Domain should match request domain, got:", c.Domain)
} }
} }
func TestClearCSRFCookie(t *testing.T) { func TestClearCSRFCookie(t *testing.T) {
fw = &ForwardAuth{} fw = &ForwardAuth{}
r, _ := http.NewRequest("GET", "http://example.com", nil) r, _ := http.NewRequest("GET", "http://example.com", nil)
c := fw.ClearCSRFCookie(r) c := fw.ClearCSRFCookie(r)
if c.Value != "" { if c.Value != "" {
t.Error("ClearCSRFCookie should create cookie with empty value") t.Error("ClearCSRFCookie should create cookie with empty value")
} }
} }
func TestValidateCSRFCookie(t *testing.T) { func TestValidateCSRFCookie(t *testing.T) {
fw = &ForwardAuth{} fw = &ForwardAuth{}
c := &http.Cookie{} c := &http.Cookie{}
// Should require 32 char string // Should require 32 char string
c.Value = "" c.Value = ""
valid, _, err := fw.ValidateCSRFCookie(c, "") valid, _, err := fw.ValidateCSRFCookie(c, "")
if valid || err.Error() != "Invalid CSRF cookie value" { if valid || err.Error() != "Invalid CSRF cookie value" {
t.Error("Should get \"Invalid CSRF cookie value\", got:", err) t.Error("Should get \"Invalid CSRF cookie value\", got:", err)
} }
c.Value = "123456789012345678901234567890123" c.Value = "123456789012345678901234567890123"
valid, _, err = fw.ValidateCSRFCookie(c, "") valid, _, err = fw.ValidateCSRFCookie(c, "")
if valid || err.Error() != "Invalid CSRF cookie value" { if valid || err.Error() != "Invalid CSRF cookie value" {
t.Error("Should get \"Invalid CSRF cookie value\", got:", err) t.Error("Should get \"Invalid CSRF cookie value\", got:", err)
} }
// Should require valid state // Should require valid state
c.Value = "12345678901234567890123456789012" c.Value = "12345678901234567890123456789012"
valid, _, err = fw.ValidateCSRFCookie(c, "12345678901234567890123456789012:") valid, _, err = fw.ValidateCSRFCookie(c, "12345678901234567890123456789012:")
if valid || err.Error() != "Invalid CSRF state value" { if valid || err.Error() != "Invalid CSRF state value" {
t.Error("Should get \"Invalid CSRF state value\", got:", err) t.Error("Should get \"Invalid CSRF state value\", got:", err)
} }
// Should allow valid state // Should allow valid state
c.Value = "12345678901234567890123456789012" c.Value = "12345678901234567890123456789012"
valid, state, err := fw.ValidateCSRFCookie(c, "12345678901234567890123456789012:99") valid, state, err := fw.ValidateCSRFCookie(c, "12345678901234567890123456789012:99")
if !valid { if !valid {
t.Error("Valid request should return as valid") t.Error("Valid request should return as valid")
} }
if err != nil { if err != nil {
t.Error("Valid request should not return error, got:", err) t.Error("Valid request should not return error, got:", err)
} }
if state != "99" { if state != "99" {
t.Error("Valid request should return correct state, got:", state) t.Error("Valid request should return correct state, got:", state)
} }
} }
func TestNonce(t *testing.T) { func TestNonce(t *testing.T) {
fw = &ForwardAuth{} fw = &ForwardAuth{}
err, nonce1 := fw.Nonce() err, nonce1 := fw.Nonce()
if err != nil { if err != nil {
t.Error("Error generation nonce:", err) t.Error("Error generation nonce:", err)
} }
err, nonce2 := fw.Nonce() err, nonce2 := fw.Nonce()
if err != nil { if err != nil {
t.Error("Error generation nonce:", err) t.Error("Error generation nonce:", err)
} }
if len(nonce1) != 32 || len(nonce2) != 32 { if len(nonce1) != 32 || len(nonce2) != 32 {
t.Error("Nonce should be 32 chars") t.Error("Nonce should be 32 chars")
} }
if nonce1 == nonce2 { if nonce1 == nonce2 {
t.Error("Nonce should not be equal") t.Error("Nonce should not be equal")
} }
} }
func TestCookieDomainMatch(t *testing.T) { func TestCookieDomainMatch(t *testing.T) {
cd := NewCookieDomain("example.com") cd := NewCookieDomain("example.com")
// Exact should match // Exact should match
if !cd.Match("example.com") { if !cd.Match("example.com") {
t.Error("Exact domain should match") t.Error("Exact domain should match")
} }
// Subdomain should match // Subdomain should match
if !cd.Match("test.example.com") { if !cd.Match("test.example.com") {
t.Error("Subdomain should match") t.Error("Subdomain should match")
} }
// Derived domain should not match // Derived domain should not match
if cd.Match("testexample.com") { if cd.Match("testexample.com") {
t.Error("Derived domain should not match") t.Error("Derived domain should not match")
} }
// Other domain should not match // Other domain should not match
if cd.Match("test.com") { if cd.Match("test.com") {
t.Error("Other domain should not match") t.Error("Other domain should not match")
} }
} }

363
main.go
View File

@ -1,229 +1,226 @@
package main package main
import ( import (
"fmt" "fmt"
"time" "net/http"
"strings" "net/url"
"net/url" "strings"
"net/http" "time"
"github.com/namsral/flag" "github.com/namsral/flag"
"github.com/op/go-logging" "github.com/op/go-logging"
) )
// Vars // Vars
var fw *ForwardAuth; var fw *ForwardAuth
var log = logging.MustGetLogger("traefik-forward-auth") var log = logging.MustGetLogger("traefik-forward-auth")
// Primary handler // Primary handler
func handler(w http.ResponseWriter, r *http.Request) { func handler(w http.ResponseWriter, r *http.Request) {
// Parse uri // Parse uri
uri, err := url.Parse(r.Header.Get("X-Forwarded-Uri")) uri, err := url.Parse(r.Header.Get("X-Forwarded-Uri"))
if err != nil { if err != nil {
log.Error("Error parsing url") log.Error("Error parsing url")
http.Error(w, "Service unavailable", 503) http.Error(w, "Service unavailable", 503)
return return
} }
// Handle callback // Handle callback
if uri.Path == fw.Path { if uri.Path == fw.Path {
handleCallback(w, r, uri.Query()) handleCallback(w, r, uri.Query())
return return
} }
// Get auth cookie // Get auth cookie
c, err := r.Cookie(fw.CookieName) c, err := r.Cookie(fw.CookieName)
if err != nil { if err != nil {
// Error indicates no cookie, generate nonce // Error indicates no cookie, generate nonce
err, nonce := fw.Nonce() err, nonce := fw.Nonce()
if err != nil { if err != nil {
log.Error("Error generating nonce") log.Error("Error generating nonce")
http.Error(w, "Service unavailable", 503) http.Error(w, "Service unavailable", 503)
return return
} }
// Set the CSRF cookie // Set the CSRF cookie
http.SetCookie(w, fw.MakeCSRFCookie(r, nonce)) http.SetCookie(w, fw.MakeCSRFCookie(r, nonce))
log.Debug("Set CSRF cookie and redirecting to google login") log.Debug("Set CSRF cookie and redirecting to google login")
// Forward them on // Forward them on
http.Redirect(w, r, fw.GetLoginURL(r, nonce), http.StatusTemporaryRedirect) http.Redirect(w, r, fw.GetLoginURL(r, nonce), http.StatusTemporaryRedirect)
return return
} }
// Validate cookie // Validate cookie
valid, email, err := fw.ValidateCookie(r, c) valid, email, err := fw.ValidateCookie(r, c)
if !valid { if !valid {
log.Debugf("Invalid cookie: %s", err) log.Debugf("Invalid cookie: %s", err)
http.Error(w, "Not authorized", 401) http.Error(w, "Not authorized", 401)
return return
} }
// Validate user // Validate user
valid = fw.ValidateEmail(email) valid = fw.ValidateEmail(email)
if !valid { if !valid {
log.Debugf("Invalid email: %s", email) log.Debugf("Invalid email: %s", email)
http.Error(w, "Not authorized", 401) http.Error(w, "Not authorized", 401)
return return
} }
// Valid request // Valid request
w.Header().Set("X-Forwarded-User", email) w.Header().Set("X-Forwarded-User", email)
w.WriteHeader(200) w.WriteHeader(200)
} }
// Authenticate user after they have come back from google // Authenticate user after they have come back from google
func handleCallback(w http.ResponseWriter, r *http.Request, qs url.Values) { func handleCallback(w http.ResponseWriter, r *http.Request, qs url.Values) {
// Check for CSRF cookie // Check for CSRF cookie
csrfCookie, err := r.Cookie(fw.CSRFCookieName) csrfCookie, err := r.Cookie(fw.CSRFCookieName)
if err != nil { if err != nil {
log.Debug("Missing csrf cookie") log.Debug("Missing csrf cookie")
http.Error(w, "Not authorized", 401) http.Error(w, "Not authorized", 401)
return return
} }
// Validate state // Validate state
state := qs.Get("state") state := qs.Get("state")
valid, redirect, err := fw.ValidateCSRFCookie(csrfCookie, state) valid, redirect, err := fw.ValidateCSRFCookie(csrfCookie, state)
if !valid { if !valid {
log.Debugf("Invalid oauth state, expected '%s', got '%s'\n", csrfCookie.Value, state) log.Debugf("Invalid oauth state, expected '%s', got '%s'\n", csrfCookie.Value, state)
http.Error(w, "Not authorized", 401) http.Error(w, "Not authorized", 401)
return return
} }
// Clear CSRF cookie // Clear CSRF cookie
http.SetCookie(w, fw.ClearCSRFCookie(r)) http.SetCookie(w, fw.ClearCSRFCookie(r))
// Exchange code for token // Exchange code for token
token, err := fw.ExchangeCode(r, qs.Get("code")) token, err := fw.ExchangeCode(r, qs.Get("code"))
if err != nil { if err != nil {
log.Debugf("Code exchange failed with: %s\n", err) log.Debugf("Code exchange failed with: %s\n", err)
http.Error(w, "Service unavailable", 503) http.Error(w, "Service unavailable", 503)
return return
} }
// Get user // Get user
user, err := fw.GetUser(token) user, err := fw.GetUser(token)
if err != nil { if err != nil {
log.Debugf("Error getting user: %s\n", err) log.Debugf("Error getting user: %s\n", err)
return return
} }
// Generate cookie // Generate cookie
http.SetCookie(w, fw.MakeCookie(r, user.Email)) http.SetCookie(w, fw.MakeCookie(r, user.Email))
log.Debugf("Generated auth cookie for %s\n", user.Email) log.Debugf("Generated auth cookie for %s\n", user.Email)
// Redirect // Redirect
http.Redirect(w, r, redirect, http.StatusTemporaryRedirect) http.Redirect(w, r, redirect, http.StatusTemporaryRedirect)
} }
// Main // Main
func main() { func main() {
// Parse options // Parse options
flag.String(flag.DefaultConfigFlagname, "", "Path to config file") flag.String(flag.DefaultConfigFlagname, "", "Path to config file")
path := flag.String("url-path", "_oauth", "Callback URL") path := flag.String("url-path", "_oauth", "Callback URL")
lifetime := flag.Int("lifetime", 43200, "Session length in seconds") lifetime := flag.Int("lifetime", 43200, "Session length in seconds")
secret := flag.String("secret", "", "*Secret used for signing (required)") secret := flag.String("secret", "", "*Secret used for signing (required)")
authHost := flag.String("auth-host", "", "Central auth login") authHost := flag.String("auth-host", "", "Central auth login")
clientId := flag.String("client-id", "", "*Google Client ID (required)") clientId := flag.String("client-id", "", "*Google Client ID (required)")
clientSecret := flag.String("client-secret", "", "*Google Client Secret (required)") clientSecret := flag.String("client-secret", "", "*Google Client Secret (required)")
cookieName := flag.String("cookie-name", "_forward_auth", "Cookie Name") cookieName := flag.String("cookie-name", "_forward_auth", "Cookie Name")
cSRFCookieName := flag.String("csrf-cookie-name", "_forward_auth_csrf", "CSRF Cookie Name") cSRFCookieName := flag.String("csrf-cookie-name", "_forward_auth_csrf", "CSRF Cookie Name")
cookieDomainList := flag.String("cookie-domains", "", "Comma separated list of cookie domains") //todo cookieDomainList := flag.String("cookie-domains", "", "Comma separated list of cookie domains") //todo
cookieSecret := flag.String("cookie-secret", "", "Deprecated") cookieSecret := flag.String("cookie-secret", "", "Deprecated")
cookieSecure := flag.Bool("cookie-secure", true, "Use secure cookies") cookieSecure := flag.Bool("cookie-secure", true, "Use secure cookies")
domainList := flag.String("domain", "", "Comma separated list of email domains to allow") domainList := flag.String("domain", "", "Comma separated list of email domains to allow")
emailWhitelist := flag.String("whitelist", "", "Comma separated list of emails to allow") emailWhitelist := flag.String("whitelist", "", "Comma separated list of emails to allow")
prompt := flag.String("prompt", "", "Space separated list of OpenID prompt options") prompt := flag.String("prompt", "", "Space separated list of OpenID prompt options")
flag.Parse() flag.Parse()
// Backwards compatability // Backwards compatability
if *secret == "" && *cookieSecret != "" { if *secret == "" && *cookieSecret != "" {
*secret = *cookieSecret *secret = *cookieSecret
} }
// Check for show stopper errors // Check for show stopper errors
stop := false stop := false
if *clientId == "" { if *clientId == "" {
stop = true stop = true
log.Critical("client-id must be set") log.Critical("client-id must be set")
} }
if *clientSecret == "" { if *clientSecret == "" {
stop = true stop = true
log.Critical("client-secret must be set") log.Critical("client-secret must be set")
} }
if *secret == "" { if *secret == "" {
stop = true stop = true
log.Critical("secret must be set") log.Critical("secret must be set")
} }
if stop { if stop {
return return
} }
// Parse lists // Parse lists
var cookieDomains []CookieDomain var cookieDomains []CookieDomain
if *cookieDomainList != "" { if *cookieDomainList != "" {
for _, d := range strings.Split(*cookieDomainList, ",") { for _, d := range strings.Split(*cookieDomainList, ",") {
cookieDomain := NewCookieDomain(d) cookieDomain := NewCookieDomain(d)
cookieDomains = append(cookieDomains, *cookieDomain) cookieDomains = append(cookieDomains, *cookieDomain)
} }
} }
var domain []string var domain []string
if *domainList != "" { if *domainList != "" {
domain = strings.Split(*domainList, ",") domain = strings.Split(*domainList, ",")
} }
var whitelist []string var whitelist []string
if *emailWhitelist != "" { if *emailWhitelist != "" {
whitelist = strings.Split(*emailWhitelist, ",") whitelist = strings.Split(*emailWhitelist, ",")
} }
// Setup // Setup
fw = &ForwardAuth{ fw = &ForwardAuth{
Path: fmt.Sprintf("/%s", *path), Path: fmt.Sprintf("/%s", *path),
Lifetime: time.Second * time.Duration(*lifetime), Lifetime: time.Second * time.Duration(*lifetime),
Secret: []byte(*secret), Secret: []byte(*secret),
AuthHost: *authHost, AuthHost: *authHost,
ClientId: *clientId, ClientId: *clientId,
ClientSecret: *clientSecret, ClientSecret: *clientSecret,
Scope: "https://www.googleapis.com/auth/userinfo.profile https://www.googleapis.com/auth/userinfo.email", Scope: "https://www.googleapis.com/auth/userinfo.profile https://www.googleapis.com/auth/userinfo.email",
LoginURL: &url.URL{ LoginURL: &url.URL{
Scheme: "https", Scheme: "https",
Host: "accounts.google.com", Host: "accounts.google.com",
Path: "/o/oauth2/auth", Path: "/o/oauth2/auth",
}, },
TokenURL: &url.URL{ TokenURL: &url.URL{
Scheme: "https", Scheme: "https",
Host: "www.googleapis.com", Host: "www.googleapis.com",
Path: "/oauth2/v3/token", Path: "/oauth2/v3/token",
}, },
UserURL: &url.URL{ UserURL: &url.URL{
Scheme: "https", Scheme: "https",
Host: "www.googleapis.com", Host: "www.googleapis.com",
Path: "/oauth2/v2/userinfo", Path: "/oauth2/v2/userinfo",
}, },
CookieName: *cookieName, CookieName: *cookieName,
CSRFCookieName: *cSRFCookieName, CSRFCookieName: *cSRFCookieName,
CookieDomains: cookieDomains, CookieDomains: cookieDomains,
CookieSecure: *cookieSecure, CookieSecure: *cookieSecure,
Domain: domain, Domain: domain,
Whitelist: whitelist, Whitelist: whitelist,
Prompt: *prompt, Prompt: *prompt,
} }
// Attach handler // Attach handler
http.HandleFunc("/", handler) http.HandleFunc("/", handler)
log.Debugf("Starting with options: %#v", fw) log.Debugf("Starting with options: %#v", fw)
log.Notice("Listening on :4181") log.Notice("Listening on :4181")
log.Notice(http.ListenAndServe(":4181", nil)) log.Notice(http.ListenAndServe(":4181", nil))
} }

View File

@ -1,32 +1,33 @@
package main package main
import ( import (
"fmt" "fmt"
"time" "time"
// "reflect" // "reflect"
"strings" "io/ioutil"
"testing" "net/http"
"net/url" "net/http/httptest"
"net/http" "net/url"
"io/ioutil" "strings"
"net/http/httptest" "testing"
"github.com/op/go-logging" "github.com/op/go-logging"
) )
/** /**
* Utilities * Utilities
*/ */
type TokenServerHandler struct {} type TokenServerHandler struct{}
func (t *TokenServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (t *TokenServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, `{"access_token":"123456789"}`) fmt.Fprint(w, `{"access_token":"123456789"}`)
} }
type UserServerHandler struct {} type UserServerHandler struct{}
func (t *UserServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (t *UserServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, `{ fmt.Fprint(w, `{
"id":"1", "id":"1",
"email":"example@example.com", "email":"example@example.com",
"verified_email":true, "verified_email":true,
@ -35,51 +36,51 @@ func (t *UserServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
func init() { func init() {
// Remove for debugging // Remove for debugging
logging.SetLevel(logging.INFO, "traefik-forward-auth") logging.SetLevel(logging.INFO, "traefik-forward-auth")
} }
func httpRequest(r *http.Request, c *http.Cookie) (*http.Response, string) { func httpRequest(r *http.Request, c *http.Cookie) (*http.Response, string) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
// Set cookies on recorder // Set cookies on recorder
if c != nil { if c != nil {
http.SetCookie(w, c) http.SetCookie(w, c)
} }
// Copy into request // Copy into request
for _, c := range w.HeaderMap["Set-Cookie"] { for _, c := range w.HeaderMap["Set-Cookie"] {
r.Header.Add("Cookie", c) r.Header.Add("Cookie", c)
} }
handler(w, r) handler(w, r)
res := w.Result() res := w.Result()
body, _ := ioutil.ReadAll(res.Body) body, _ := ioutil.ReadAll(res.Body)
return res, string(body) return res, string(body)
} }
func newHttpRequest(uri string) *http.Request { func newHttpRequest(uri string) *http.Request {
r := httptest.NewRequest("", "http://example.com", nil) r := httptest.NewRequest("", "http://example.com", nil)
r.Header.Add("X-Forwarded-Uri", uri) r.Header.Add("X-Forwarded-Uri", uri)
return r return r
} }
func qsDiff(one, two url.Values) { func qsDiff(one, two url.Values) {
for k, _ := range one { for k, _ := range one {
if two.Get(k) == "" { if two.Get(k) == "" {
fmt.Printf("Key missing: %s\n", k) fmt.Printf("Key missing: %s\n", k)
} }
if one.Get(k) != two.Get(k) { if one.Get(k) != two.Get(k) {
fmt.Printf("Value different for %s: expected: '%s' got: '%s'\n", k, one.Get(k), two.Get(k)) fmt.Printf("Value different for %s: expected: '%s' got: '%s'\n", k, one.Get(k), two.Get(k))
} }
} }
for k, _ := range two { for k, _ := range two {
if one.Get(k) == "" { if one.Get(k) == "" {
fmt.Printf("Extra key: %s\n", k) fmt.Printf("Extra key: %s\n", k)
} }
} }
} }
/** /**
@ -87,126 +88,126 @@ func qsDiff(one, two url.Values) {
*/ */
func TestHandler(t *testing.T) { func TestHandler(t *testing.T) {
fw = &ForwardAuth{ fw = &ForwardAuth{
Path: "_oauth", Path: "_oauth",
ClientId: "idtest", ClientId: "idtest",
ClientSecret: "sectest", ClientSecret: "sectest",
Scope: "scopetest", Scope: "scopetest",
LoginURL: &url.URL{ LoginURL: &url.URL{
Scheme: "http", Scheme: "http",
Host: "test.com", Host: "test.com",
Path: "/auth", Path: "/auth",
}, },
CookieName: "cookie_test", CookieName: "cookie_test",
Lifetime: time.Second * time.Duration(10), Lifetime: time.Second * time.Duration(10),
} }
// Should redirect vanilla request to login url // Should redirect vanilla request to login url
req := newHttpRequest("foo") req := newHttpRequest("foo")
res, _ := httpRequest(req, nil) res, _ := httpRequest(req, nil)
if res.StatusCode != 307 { if res.StatusCode != 307 {
t.Error("Vanilla request should be redirected with 307, got:", res.StatusCode) t.Error("Vanilla request should be redirected with 307, got:", res.StatusCode)
} }
fwd, _ := res.Location() fwd, _ := res.Location()
if fwd.Scheme != "http" || fwd.Host != "test.com" || fwd.Path != "/auth" { if fwd.Scheme != "http" || fwd.Host != "test.com" || fwd.Path != "/auth" {
t.Error("Vanilla request should be redirected to login url, got:", fwd) t.Error("Vanilla request should be redirected to login url, got:", fwd)
} }
// Should catch invalid cookie // Should catch invalid cookie
req = newHttpRequest("foo") req = newHttpRequest("foo")
c := fw.MakeCookie(req, "test@example.com") c := fw.MakeCookie(req, "test@example.com")
parts := strings.Split(c.Value, "|") parts := strings.Split(c.Value, "|")
c.Value = fmt.Sprintf("bad|%s|%s", parts[1], parts[2]) c.Value = fmt.Sprintf("bad|%s|%s", parts[1], parts[2])
res, _ = httpRequest(req, c) res, _ = httpRequest(req, c)
if res.StatusCode != 401 { if res.StatusCode != 401 {
t.Error("Request with invalid cookie shound't be authorised", res.StatusCode) t.Error("Request with invalid cookie shound't be authorised", res.StatusCode)
} }
// Should validate email // Should validate email
req = newHttpRequest("foo") req = newHttpRequest("foo")
c = fw.MakeCookie(req, "test@example.com") c = fw.MakeCookie(req, "test@example.com")
fw.Domain = []string{"test.com"} fw.Domain = []string{"test.com"}
res, _ = httpRequest(req, c) res, _ = httpRequest(req, c)
if res.StatusCode != 401 { if res.StatusCode != 401 {
t.Error("Request with invalid cookie shound't be authorised", res.StatusCode) t.Error("Request with invalid cookie shound't be authorised", res.StatusCode)
} }
// Should allow valid request email // Should allow valid request email
req = newHttpRequest("foo") req = newHttpRequest("foo")
c = fw.MakeCookie(req, "test@example.com") c = fw.MakeCookie(req, "test@example.com")
fw.Domain = []string{} fw.Domain = []string{}
res, _ = httpRequest(req, c) res, _ = httpRequest(req, c)
if res.StatusCode != 200 { if res.StatusCode != 200 {
t.Error("Valid request should be allowed, got:", res.StatusCode) t.Error("Valid request should be allowed, got:", res.StatusCode)
} }
// Should pass through user // Should pass through user
users := res.Header["X-Forwarded-User"]; users := res.Header["X-Forwarded-User"]
if len(users) != 1 { if len(users) != 1 {
t.Error("Valid request missing X-Forwarded-User header") t.Error("Valid request missing X-Forwarded-User header")
} else if users[0] != "test@example.com" { } else if users[0] != "test@example.com" {
t.Error("X-Forwarded-User should match user, got: ", users) t.Error("X-Forwarded-User should match user, got: ", users)
} }
} }
func TestCallback(t *testing.T) { func TestCallback(t *testing.T) {
fw = &ForwardAuth{ fw = &ForwardAuth{
Path: "_oauth", Path: "_oauth",
ClientId: "idtest", ClientId: "idtest",
ClientSecret: "sectest", ClientSecret: "sectest",
Scope: "scopetest", Scope: "scopetest",
LoginURL: &url.URL{ LoginURL: &url.URL{
Scheme: "http", Scheme: "http",
Host: "test.com", Host: "test.com",
Path: "/auth", Path: "/auth",
}, },
CSRFCookieName: "csrf_test", CSRFCookieName: "csrf_test",
} }
// Setup token server // Setup token server
tokenServerHandler := &TokenServerHandler{} tokenServerHandler := &TokenServerHandler{}
tokenServer := httptest.NewServer(tokenServerHandler) tokenServer := httptest.NewServer(tokenServerHandler)
defer tokenServer.Close() defer tokenServer.Close()
tokenUrl, _ := url.Parse(tokenServer.URL) tokenUrl, _ := url.Parse(tokenServer.URL)
fw.TokenURL = tokenUrl fw.TokenURL = tokenUrl
// Setup user server // Setup user server
userServerHandler := &UserServerHandler{} userServerHandler := &UserServerHandler{}
userServer := httptest.NewServer(userServerHandler) userServer := httptest.NewServer(userServerHandler)
defer userServer.Close() defer userServer.Close()
userUrl, _ := url.Parse(userServer.URL) userUrl, _ := url.Parse(userServer.URL)
fw.UserURL = userUrl fw.UserURL = userUrl
// Should pass auth response request to callback // Should pass auth response request to callback
req := newHttpRequest("_oauth") req := newHttpRequest("_oauth")
res, _ := httpRequest(req, nil) res, _ := httpRequest(req, nil)
if res.StatusCode != 401 { if res.StatusCode != 401 {
t.Error("Auth callback without cookie shound't be authorised, got:", res.StatusCode) t.Error("Auth callback without cookie shound't be authorised, got:", res.StatusCode)
} }
// Should catch invalid csrf cookie // Should catch invalid csrf cookie
req = newHttpRequest("_oauth?state=12345678901234567890123456789012:http://redirect") req = newHttpRequest("_oauth?state=12345678901234567890123456789012:http://redirect")
c := fw.MakeCSRFCookie(req, "nononononononononononononononono") c := fw.MakeCSRFCookie(req, "nononononononononononononononono")
res, _ = httpRequest(req, c) res, _ = httpRequest(req, c)
if res.StatusCode != 401 { if res.StatusCode != 401 {
t.Error("Auth callback with invalid cookie shound't be authorised, got:", res.StatusCode) t.Error("Auth callback with invalid cookie shound't be authorised, got:", res.StatusCode)
} }
// Should redirect valid request // Should redirect valid request
req = newHttpRequest("_oauth?state=12345678901234567890123456789012:http://redirect") req = newHttpRequest("_oauth?state=12345678901234567890123456789012:http://redirect")
c = fw.MakeCSRFCookie(req, "12345678901234567890123456789012") c = fw.MakeCSRFCookie(req, "12345678901234567890123456789012")
res, _ = httpRequest(req, c) res, _ = httpRequest(req, c)
if res.StatusCode != 307 { if res.StatusCode != 307 {
t.Error("Valid callback should be allowed, got:", res.StatusCode) t.Error("Valid callback should be allowed, got:", res.StatusCode)
} }
fwd, _ := res.Location() fwd, _ := res.Location()
if fwd.Scheme != "http" || fwd.Host != "redirect" || fwd.Path != "" { if fwd.Scheme != "http" || fwd.Host != "redirect" || fwd.Path != "" {
t.Error("Valid request should be redirected to return url, got:", fwd) t.Error("Valid request should be redirected to return url, got:", fwd)
} }
} }