use gofmt to update styling
This commit is contained in:
parent
6ccd1c6dfc
commit
afd8878188
479
forwardauth.go
479
forwardauth.go
@ -1,393 +1,390 @@
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
"errors"
|
||||
"strings"
|
||||
"strconv"
|
||||
"net/url"
|
||||
"net/http"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"encoding/base64"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Forward Auth
|
||||
type ForwardAuth struct {
|
||||
Path string
|
||||
Lifetime time.Duration
|
||||
Secret []byte
|
||||
Path string
|
||||
Lifetime time.Duration
|
||||
Secret []byte
|
||||
|
||||
ClientId string
|
||||
ClientSecret string
|
||||
Scope string
|
||||
ClientId string
|
||||
ClientSecret string
|
||||
Scope string
|
||||
|
||||
LoginURL *url.URL
|
||||
TokenURL *url.URL
|
||||
UserURL *url.URL
|
||||
LoginURL *url.URL
|
||||
TokenURL *url.URL
|
||||
UserURL *url.URL
|
||||
|
||||
AuthHost string
|
||||
AuthHost string
|
||||
|
||||
CookieName string
|
||||
CookieDomains []CookieDomain
|
||||
CSRFCookieName string
|
||||
CookieSecure bool
|
||||
CookieName string
|
||||
CookieDomains []CookieDomain
|
||||
CSRFCookieName string
|
||||
CookieSecure bool
|
||||
|
||||
Domain []string
|
||||
Whitelist []string
|
||||
Domain []string
|
||||
Whitelist []string
|
||||
|
||||
Prompt string
|
||||
Prompt string
|
||||
}
|
||||
|
||||
// Request Validation
|
||||
|
||||
// Cookie = hash(secret, cookie domain, email, expires)|expires|email
|
||||
func (f *ForwardAuth) ValidateCookie(r *http.Request, c *http.Cookie) (bool, string, error) {
|
||||
parts := strings.Split(c.Value, "|")
|
||||
parts := strings.Split(c.Value, "|")
|
||||
|
||||
if len(parts) != 3 {
|
||||
return false, "", errors.New("Invalid cookie format")
|
||||
}
|
||||
if len(parts) != 3 {
|
||||
return false, "", errors.New("Invalid cookie format")
|
||||
}
|
||||
|
||||
mac, err := base64.URLEncoding.DecodeString(parts[0])
|
||||
if err != nil {
|
||||
return false, "", errors.New("Unable to decode cookie mac")
|
||||
}
|
||||
mac, err := base64.URLEncoding.DecodeString(parts[0])
|
||||
if err != nil {
|
||||
return false, "", errors.New("Unable to decode cookie mac")
|
||||
}
|
||||
|
||||
expectedSignature := f.cookieSignature(r, parts[2], parts[1])
|
||||
expected, err := base64.URLEncoding.DecodeString(expectedSignature)
|
||||
if err != nil {
|
||||
return false, "", errors.New("Unable to generate mac")
|
||||
}
|
||||
expectedSignature := f.cookieSignature(r, parts[2], parts[1])
|
||||
expected, err := base64.URLEncoding.DecodeString(expectedSignature)
|
||||
if err != nil {
|
||||
return false, "", errors.New("Unable to generate mac")
|
||||
}
|
||||
|
||||
// Valid token?
|
||||
if !hmac.Equal(mac, expected) {
|
||||
return false, "", errors.New("Invalid cookie mac")
|
||||
}
|
||||
// Valid token?
|
||||
if !hmac.Equal(mac, expected) {
|
||||
return false, "", errors.New("Invalid cookie mac")
|
||||
}
|
||||
|
||||
expires, err := strconv.ParseInt(parts[1], 10, 64)
|
||||
if err != nil {
|
||||
return false, "", errors.New("Unable to parse cookie expiry")
|
||||
}
|
||||
expires, err := strconv.ParseInt(parts[1], 10, 64)
|
||||
if err != nil {
|
||||
return false, "", errors.New("Unable to parse cookie expiry")
|
||||
}
|
||||
|
||||
// Has it expired?
|
||||
if time.Unix(expires, 0).Before(time.Now()) {
|
||||
return false, "", errors.New("Cookie has expired")
|
||||
}
|
||||
// Has it expired?
|
||||
if time.Unix(expires, 0).Before(time.Now()) {
|
||||
return false, "", errors.New("Cookie has expired")
|
||||
}
|
||||
|
||||
// Looks valid
|
||||
return true, parts[2], nil
|
||||
// Looks valid
|
||||
return true, parts[2], nil
|
||||
}
|
||||
|
||||
// Validate email
|
||||
func (f *ForwardAuth) ValidateEmail(email string) bool {
|
||||
found := false
|
||||
if len(f.Whitelist) > 0 {
|
||||
for _, whitelist := range f.Whitelist {
|
||||
if email == whitelist {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
} else if len(f.Domain) > 0 {
|
||||
parts := strings.Split(email, "@")
|
||||
if len(parts) < 2 {
|
||||
return false
|
||||
}
|
||||
for _, domain := range f.Domain {
|
||||
if domain == parts[1] {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return true
|
||||
}
|
||||
found := false
|
||||
if len(f.Whitelist) > 0 {
|
||||
for _, whitelist := range f.Whitelist {
|
||||
if email == whitelist {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
} else if len(f.Domain) > 0 {
|
||||
parts := strings.Split(email, "@")
|
||||
if len(parts) < 2 {
|
||||
return false
|
||||
}
|
||||
for _, domain := range f.Domain {
|
||||
if domain == parts[1] {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return true
|
||||
}
|
||||
|
||||
return found
|
||||
return found
|
||||
}
|
||||
|
||||
|
||||
// OAuth Methods
|
||||
|
||||
// Get login url
|
||||
func (f *ForwardAuth) GetLoginURL(r *http.Request, nonce string) string {
|
||||
state := fmt.Sprintf("%s:%s", nonce, f.returnUrl(r))
|
||||
state := fmt.Sprintf("%s:%s", nonce, f.returnUrl(r))
|
||||
|
||||
q := url.Values{}
|
||||
q.Set("client_id", fw.ClientId)
|
||||
q.Set("response_type", "code")
|
||||
q.Set("scope", fw.Scope)
|
||||
if fw.Prompt != "" {
|
||||
q.Set("prompt", fw.Prompt)
|
||||
}
|
||||
q.Set("redirect_uri", f.redirectUri(r))
|
||||
q.Set("state", state)
|
||||
q := url.Values{}
|
||||
q.Set("client_id", fw.ClientId)
|
||||
q.Set("response_type", "code")
|
||||
q.Set("scope", fw.Scope)
|
||||
if fw.Prompt != "" {
|
||||
q.Set("prompt", fw.Prompt)
|
||||
}
|
||||
q.Set("redirect_uri", f.redirectUri(r))
|
||||
q.Set("state", state)
|
||||
|
||||
var u url.URL
|
||||
u = *fw.LoginURL
|
||||
u.RawQuery = q.Encode()
|
||||
var u url.URL
|
||||
u = *fw.LoginURL
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
return u.String()
|
||||
return u.String()
|
||||
}
|
||||
|
||||
// Exchange code for token
|
||||
|
||||
type Token struct {
|
||||
Token string `json:"access_token"`
|
||||
Token string `json:"access_token"`
|
||||
}
|
||||
|
||||
func (f *ForwardAuth) ExchangeCode(r *http.Request, code string) (string, error) {
|
||||
form := url.Values{}
|
||||
form.Set("client_id", fw.ClientId)
|
||||
form.Set("client_secret", fw.ClientSecret)
|
||||
form.Set("grant_type", "authorization_code")
|
||||
form.Set("redirect_uri", f.redirectUri(r))
|
||||
form.Set("code", code)
|
||||
form := url.Values{}
|
||||
form.Set("client_id", fw.ClientId)
|
||||
form.Set("client_secret", fw.ClientSecret)
|
||||
form.Set("grant_type", "authorization_code")
|
||||
form.Set("redirect_uri", f.redirectUri(r))
|
||||
form.Set("code", code)
|
||||
|
||||
res, err := http.PostForm(fw.TokenURL.String(), form)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
res, err := http.PostForm(fw.TokenURL.String(), form)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
var token Token
|
||||
defer res.Body.Close()
|
||||
err = json.NewDecoder(res.Body).Decode(&token)
|
||||
|
||||
var token Token
|
||||
defer res.Body.Close()
|
||||
err = json.NewDecoder(res.Body).Decode(&token)
|
||||
|
||||
return token.Token, err
|
||||
return token.Token, err
|
||||
}
|
||||
|
||||
// Get user with token
|
||||
|
||||
type User struct {
|
||||
Id string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Verified bool `json:"verified_email"`
|
||||
Hd string `json:"hd"`
|
||||
Id string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Verified bool `json:"verified_email"`
|
||||
Hd string `json:"hd"`
|
||||
}
|
||||
|
||||
func (f *ForwardAuth) GetUser(token string) (User, error) {
|
||||
var user User
|
||||
var user User
|
||||
|
||||
client := &http.Client{}
|
||||
req, err := http.NewRequest("GET", fw.UserURL.String(), nil)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
client := &http.Client{}
|
||||
req, err := http.NewRequest("GET", fw.UserURL.String(), nil)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
|
||||
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", token))
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", token))
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
|
||||
defer res.Body.Close()
|
||||
err = json.NewDecoder(res.Body).Decode(&user)
|
||||
defer res.Body.Close()
|
||||
err = json.NewDecoder(res.Body).Decode(&user)
|
||||
|
||||
return user, err
|
||||
return user, err
|
||||
}
|
||||
|
||||
// Utility methods
|
||||
|
||||
// Get the redirect base
|
||||
func (f *ForwardAuth) redirectBase(r *http.Request) string {
|
||||
proto := r.Header.Get("X-Forwarded-Proto")
|
||||
host := r.Header.Get("X-Forwarded-Host")
|
||||
proto := r.Header.Get("X-Forwarded-Proto")
|
||||
host := r.Header.Get("X-Forwarded-Host")
|
||||
|
||||
return fmt.Sprintf("%s://%s", proto, host)
|
||||
return fmt.Sprintf("%s://%s", proto, host)
|
||||
}
|
||||
|
||||
// Return url
|
||||
func (f *ForwardAuth) returnUrl(r *http.Request) string {
|
||||
path := r.Header.Get("X-Forwarded-Uri")
|
||||
path := r.Header.Get("X-Forwarded-Uri")
|
||||
|
||||
return fmt.Sprintf("%s%s", f.redirectBase(r), path)
|
||||
return fmt.Sprintf("%s%s", f.redirectBase(r), path)
|
||||
}
|
||||
|
||||
// Get oauth redirect uri
|
||||
func (f *ForwardAuth) redirectUri(r *http.Request) string {
|
||||
if use, _ := f.useAuthDomain(r); use {
|
||||
proto := r.Header.Get("X-Forwarded-Proto")
|
||||
return fmt.Sprintf("%s://%s%s", proto, f.AuthHost, f.Path)
|
||||
}
|
||||
if use, _ := f.useAuthDomain(r); use {
|
||||
proto := r.Header.Get("X-Forwarded-Proto")
|
||||
return fmt.Sprintf("%s://%s%s", proto, f.AuthHost, f.Path)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s%s", f.redirectBase(r), f.Path)
|
||||
return fmt.Sprintf("%s%s", f.redirectBase(r), f.Path)
|
||||
}
|
||||
|
||||
// Should we use auth host + what it is
|
||||
func (f *ForwardAuth) useAuthDomain(r *http.Request) (bool, string) {
|
||||
if f.AuthHost == "" {
|
||||
return false, ""
|
||||
}
|
||||
if f.AuthHost == "" {
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// Does the request match a given cookie domain?
|
||||
reqMatch, reqHost := f.matchCookieDomains(r.Header.Get("X-Forwarded-Host"))
|
||||
// Does the request match a given cookie domain?
|
||||
reqMatch, reqHost := f.matchCookieDomains(r.Header.Get("X-Forwarded-Host"))
|
||||
|
||||
// Do any of the auth hosts match a cookie domain?
|
||||
authMatch, authHost := f.matchCookieDomains(f.AuthHost)
|
||||
// Do any of the auth hosts match a cookie domain?
|
||||
authMatch, authHost := f.matchCookieDomains(f.AuthHost)
|
||||
|
||||
// We need both to match the same domain
|
||||
return reqMatch && authMatch && reqHost == authHost, reqHost
|
||||
// We need both to match the same domain
|
||||
return reqMatch && authMatch && reqHost == authHost, reqHost
|
||||
}
|
||||
|
||||
// Cookie methods
|
||||
|
||||
// Create an auth cookie
|
||||
func (f *ForwardAuth) MakeCookie(r *http.Request, email string) *http.Cookie {
|
||||
expires := f.cookieExpiry()
|
||||
mac := f.cookieSignature(r, email, fmt.Sprintf("%d", expires.Unix()))
|
||||
value := fmt.Sprintf("%s|%d|%s", mac, expires.Unix(), email)
|
||||
expires := f.cookieExpiry()
|
||||
mac := f.cookieSignature(r, email, fmt.Sprintf("%d", expires.Unix()))
|
||||
value := fmt.Sprintf("%s|%d|%s", mac, expires.Unix(), email)
|
||||
|
||||
return &http.Cookie{
|
||||
Name: f.CookieName,
|
||||
Value: value,
|
||||
Path: "/",
|
||||
Domain: f.cookieDomain(r),
|
||||
HttpOnly: true,
|
||||
Secure: f.CookieSecure,
|
||||
Expires: expires,
|
||||
}
|
||||
return &http.Cookie{
|
||||
Name: f.CookieName,
|
||||
Value: value,
|
||||
Path: "/",
|
||||
Domain: f.cookieDomain(r),
|
||||
HttpOnly: true,
|
||||
Secure: f.CookieSecure,
|
||||
Expires: expires,
|
||||
}
|
||||
}
|
||||
|
||||
// Make a CSRF cookie (used during login only)
|
||||
func (f *ForwardAuth) MakeCSRFCookie(r *http.Request, nonce string) *http.Cookie {
|
||||
return &http.Cookie{
|
||||
Name: f.CSRFCookieName,
|
||||
Value: nonce,
|
||||
Path: "/",
|
||||
Domain: f.csrfCookieDomain(r),
|
||||
HttpOnly: true,
|
||||
Secure: f.CookieSecure,
|
||||
Expires: f.cookieExpiry(),
|
||||
}
|
||||
return &http.Cookie{
|
||||
Name: f.CSRFCookieName,
|
||||
Value: nonce,
|
||||
Path: "/",
|
||||
Domain: f.csrfCookieDomain(r),
|
||||
HttpOnly: true,
|
||||
Secure: f.CookieSecure,
|
||||
Expires: f.cookieExpiry(),
|
||||
}
|
||||
}
|
||||
|
||||
// Create a cookie to clear csrf cookie
|
||||
func (f *ForwardAuth) ClearCSRFCookie(r *http.Request) *http.Cookie {
|
||||
return &http.Cookie{
|
||||
Name: f.CSRFCookieName,
|
||||
Value: "",
|
||||
Path: "/",
|
||||
Domain: f.csrfCookieDomain(r),
|
||||
HttpOnly: true,
|
||||
Secure: f.CookieSecure,
|
||||
Expires: time.Now().Local().Add(time.Hour * -1),
|
||||
}
|
||||
return &http.Cookie{
|
||||
Name: f.CSRFCookieName,
|
||||
Value: "",
|
||||
Path: "/",
|
||||
Domain: f.csrfCookieDomain(r),
|
||||
HttpOnly: true,
|
||||
Secure: f.CookieSecure,
|
||||
Expires: time.Now().Local().Add(time.Hour * -1),
|
||||
}
|
||||
}
|
||||
|
||||
// Validate the csrf cookie against state
|
||||
func (f *ForwardAuth) ValidateCSRFCookie(c *http.Cookie, state string) (bool, string, error) {
|
||||
if len(c.Value) != 32 {
|
||||
return false, "", errors.New("Invalid CSRF cookie value")
|
||||
}
|
||||
if len(c.Value) != 32 {
|
||||
return false, "", errors.New("Invalid CSRF cookie value")
|
||||
}
|
||||
|
||||
if len(state) < 34 {
|
||||
return false, "", errors.New("Invalid CSRF state value")
|
||||
}
|
||||
if len(state) < 34 {
|
||||
return false, "", errors.New("Invalid CSRF state value")
|
||||
}
|
||||
|
||||
// Check nonce match
|
||||
if c.Value != state[:32] {
|
||||
return false, "", errors.New("CSRF cookie does not match state")
|
||||
}
|
||||
// Check nonce match
|
||||
if c.Value != state[:32] {
|
||||
return false, "", errors.New("CSRF cookie does not match state")
|
||||
}
|
||||
|
||||
// Valid, return redirect
|
||||
return true, state[33:], nil
|
||||
// Valid, return redirect
|
||||
return true, state[33:], nil
|
||||
}
|
||||
|
||||
func (f *ForwardAuth) Nonce() (error, string) {
|
||||
// Make nonce
|
||||
nonce := make([]byte, 16)
|
||||
_, err := rand.Read(nonce)
|
||||
if err != nil {
|
||||
return err, ""
|
||||
}
|
||||
// Make nonce
|
||||
nonce := make([]byte, 16)
|
||||
_, err := rand.Read(nonce)
|
||||
if err != nil {
|
||||
return err, ""
|
||||
}
|
||||
|
||||
return nil, fmt.Sprintf("%x", nonce)
|
||||
return nil, fmt.Sprintf("%x", nonce)
|
||||
}
|
||||
|
||||
// Cookie domain
|
||||
func (f *ForwardAuth) cookieDomain(r *http.Request) string {
|
||||
host := r.Header.Get("X-Forwarded-Host")
|
||||
host := r.Header.Get("X-Forwarded-Host")
|
||||
|
||||
// Check if any of the given cookie domains matches
|
||||
_, domain := f.matchCookieDomains(host)
|
||||
return domain
|
||||
// Check if any of the given cookie domains matches
|
||||
_, domain := f.matchCookieDomains(host)
|
||||
return domain
|
||||
}
|
||||
|
||||
// Cookie domain
|
||||
func (f *ForwardAuth) csrfCookieDomain(r *http.Request) string {
|
||||
var host string
|
||||
if use, domain := f.useAuthDomain(r); use {
|
||||
host = domain
|
||||
} else {
|
||||
host = r.Header.Get("X-Forwarded-Host")
|
||||
}
|
||||
var host string
|
||||
if use, domain := f.useAuthDomain(r); use {
|
||||
host = domain
|
||||
} else {
|
||||
host = r.Header.Get("X-Forwarded-Host")
|
||||
}
|
||||
|
||||
// Remove port
|
||||
p := strings.Split(host, ":")
|
||||
return p[0]
|
||||
// Remove port
|
||||
p := strings.Split(host, ":")
|
||||
return p[0]
|
||||
}
|
||||
|
||||
// Return matching cookie domain if exists
|
||||
func (f *ForwardAuth) matchCookieDomains(domain string) (bool, string) {
|
||||
// Remove port
|
||||
p := strings.Split(domain, ":")
|
||||
// Remove port
|
||||
p := strings.Split(domain, ":")
|
||||
|
||||
for _, d := range f.CookieDomains {
|
||||
if d.Match(p[0]) {
|
||||
return true, d.Domain
|
||||
}
|
||||
}
|
||||
for _, d := range f.CookieDomains {
|
||||
if d.Match(p[0]) {
|
||||
return true, d.Domain
|
||||
}
|
||||
}
|
||||
|
||||
return false, p[0]
|
||||
return false, p[0]
|
||||
}
|
||||
|
||||
// Create cookie hmac
|
||||
func (f *ForwardAuth) cookieSignature(r *http.Request, email, expires string) string {
|
||||
hash := hmac.New(sha256.New, f.Secret)
|
||||
hash.Write([]byte(f.cookieDomain(r)))
|
||||
hash.Write([]byte(email))
|
||||
hash.Write([]byte(expires))
|
||||
return base64.URLEncoding.EncodeToString(hash.Sum(nil))
|
||||
hash := hmac.New(sha256.New, f.Secret)
|
||||
hash.Write([]byte(f.cookieDomain(r)))
|
||||
hash.Write([]byte(email))
|
||||
hash.Write([]byte(expires))
|
||||
return base64.URLEncoding.EncodeToString(hash.Sum(nil))
|
||||
}
|
||||
|
||||
// Get cookie expirary
|
||||
func (f *ForwardAuth) cookieExpiry() time.Time {
|
||||
return time.Now().Local().Add(f.Lifetime)
|
||||
return time.Now().Local().Add(f.Lifetime)
|
||||
}
|
||||
|
||||
// Cookie Domain
|
||||
|
||||
// Cookie Domain
|
||||
type CookieDomain struct {
|
||||
Domain string
|
||||
DomainLen int
|
||||
SubDomain string
|
||||
SubDomainLen int
|
||||
Domain string
|
||||
DomainLen int
|
||||
SubDomain string
|
||||
SubDomainLen int
|
||||
}
|
||||
|
||||
func NewCookieDomain(domain string) *CookieDomain {
|
||||
return &CookieDomain{
|
||||
Domain: domain,
|
||||
DomainLen: len(domain),
|
||||
SubDomain: fmt.Sprintf(".%s", domain),
|
||||
SubDomainLen: len(domain) + 1,
|
||||
}
|
||||
return &CookieDomain{
|
||||
Domain: domain,
|
||||
DomainLen: len(domain),
|
||||
SubDomain: fmt.Sprintf(".%s", domain),
|
||||
SubDomainLen: len(domain) + 1,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *CookieDomain) Match(host string) bool {
|
||||
// Exact domain match?
|
||||
if host == c.Domain {
|
||||
return true
|
||||
}
|
||||
// Exact domain match?
|
||||
if host == c.Domain {
|
||||
return true
|
||||
}
|
||||
|
||||
// Subdomain match?
|
||||
if len(host) >= c.SubDomainLen && host[len(host) - c.SubDomainLen:] == c.SubDomain {
|
||||
return true
|
||||
}
|
||||
// Subdomain match?
|
||||
if len(host) >= c.SubDomainLen && host[len(host)-c.SubDomainLen:] == c.SubDomain {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
return false
|
||||
}
|
||||
|
@ -1,284 +1,282 @@
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
// "fmt"
|
||||
"time"
|
||||
"reflect"
|
||||
"testing"
|
||||
"net/url"
|
||||
"net/http"
|
||||
// "fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestValidateCookie(t *testing.T) {
|
||||
fw = &ForwardAuth{}
|
||||
r, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
c := &http.Cookie{}
|
||||
fw = &ForwardAuth{}
|
||||
r, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
c := &http.Cookie{}
|
||||
|
||||
// Should require 3 parts
|
||||
c.Value = ""
|
||||
valid, _, err := fw.ValidateCookie(r, c)
|
||||
if valid || err.Error() != "Invalid cookie format" {
|
||||
t.Error("Should get \"Invalid cookie format\", got:", err)
|
||||
}
|
||||
c.Value = "1|2"
|
||||
valid, _, err = fw.ValidateCookie(r, c)
|
||||
if valid || err.Error() != "Invalid cookie format" {
|
||||
t.Error("Should get \"Invalid cookie format\", got:", err)
|
||||
}
|
||||
c.Value = "1|2|3|4"
|
||||
valid, _, err = fw.ValidateCookie(r, c)
|
||||
if valid || err.Error() != "Invalid cookie format" {
|
||||
t.Error("Should get \"Invalid cookie format\", got:", err)
|
||||
}
|
||||
// Should require 3 parts
|
||||
c.Value = ""
|
||||
valid, _, err := fw.ValidateCookie(r, c)
|
||||
if valid || err.Error() != "Invalid cookie format" {
|
||||
t.Error("Should get \"Invalid cookie format\", got:", err)
|
||||
}
|
||||
c.Value = "1|2"
|
||||
valid, _, err = fw.ValidateCookie(r, c)
|
||||
if valid || err.Error() != "Invalid cookie format" {
|
||||
t.Error("Should get \"Invalid cookie format\", got:", err)
|
||||
}
|
||||
c.Value = "1|2|3|4"
|
||||
valid, _, err = fw.ValidateCookie(r, c)
|
||||
if valid || err.Error() != "Invalid cookie format" {
|
||||
t.Error("Should get \"Invalid cookie format\", got:", err)
|
||||
}
|
||||
|
||||
// Should catch invalid mac
|
||||
c.Value = "MQ==|2|3"
|
||||
valid, _, err = fw.ValidateCookie(r, c)
|
||||
if valid || err.Error() != "Invalid cookie mac" {
|
||||
t.Error("Should get \"Invalid cookie mac\", got:", err)
|
||||
}
|
||||
// Should catch invalid mac
|
||||
c.Value = "MQ==|2|3"
|
||||
valid, _, err = fw.ValidateCookie(r, c)
|
||||
if valid || err.Error() != "Invalid cookie mac" {
|
||||
t.Error("Should get \"Invalid cookie mac\", got:", err)
|
||||
}
|
||||
|
||||
// Should catch expired
|
||||
fw.Lifetime = time.Second * time.Duration(-1)
|
||||
c = fw.MakeCookie(r, "test@test.com")
|
||||
valid, _, err = fw.ValidateCookie(r, c)
|
||||
if valid || err.Error() != "Cookie has expired" {
|
||||
t.Error("Should get \"Cookie has expired\", got:", err)
|
||||
}
|
||||
// Should catch expired
|
||||
fw.Lifetime = time.Second * time.Duration(-1)
|
||||
c = fw.MakeCookie(r, "test@test.com")
|
||||
valid, _, err = fw.ValidateCookie(r, c)
|
||||
if valid || err.Error() != "Cookie has expired" {
|
||||
t.Error("Should get \"Cookie has expired\", got:", err)
|
||||
}
|
||||
|
||||
// Should accept valid cookie
|
||||
fw.Lifetime = time.Second * time.Duration(10)
|
||||
c = fw.MakeCookie(r, "test@test.com")
|
||||
valid, email, err := fw.ValidateCookie(r, c)
|
||||
if !valid {
|
||||
t.Error("Valid request should return as valid")
|
||||
}
|
||||
if err != nil {
|
||||
t.Error("Valid request should not return error, got:", err)
|
||||
}
|
||||
if email != "test@test.com" {
|
||||
t.Error("Valid request should return user email")
|
||||
}
|
||||
// Should accept valid cookie
|
||||
fw.Lifetime = time.Second * time.Duration(10)
|
||||
c = fw.MakeCookie(r, "test@test.com")
|
||||
valid, email, err := fw.ValidateCookie(r, c)
|
||||
if !valid {
|
||||
t.Error("Valid request should return as valid")
|
||||
}
|
||||
if err != nil {
|
||||
t.Error("Valid request should not return error, got:", err)
|
||||
}
|
||||
if email != "test@test.com" {
|
||||
t.Error("Valid request should return user email")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateEmail(t *testing.T) {
|
||||
fw = &ForwardAuth{}
|
||||
fw = &ForwardAuth{}
|
||||
|
||||
// Should allow any
|
||||
if !fw.ValidateEmail("test@test.com") || !fw.ValidateEmail("one@two.com") {
|
||||
t.Error("Should allow any domain if email domain is not defined")
|
||||
}
|
||||
// Should allow any
|
||||
if !fw.ValidateEmail("test@test.com") || !fw.ValidateEmail("one@two.com") {
|
||||
t.Error("Should allow any domain if email domain is not defined")
|
||||
}
|
||||
|
||||
// Should block non matching domain
|
||||
fw.Domain = []string{"test.com"}
|
||||
if fw.ValidateEmail("one@two.com") {
|
||||
t.Error("Should not allow user from another domain")
|
||||
}
|
||||
// Should block non matching domain
|
||||
fw.Domain = []string{"test.com"}
|
||||
if fw.ValidateEmail("one@two.com") {
|
||||
t.Error("Should not allow user from another domain")
|
||||
}
|
||||
|
||||
// Should allow matching domain
|
||||
fw.Domain = []string{"test.com"}
|
||||
if !fw.ValidateEmail("test@test.com") {
|
||||
t.Error("Should allow user from allowed domain")
|
||||
}
|
||||
// Should allow matching domain
|
||||
fw.Domain = []string{"test.com"}
|
||||
if !fw.ValidateEmail("test@test.com") {
|
||||
t.Error("Should allow user from allowed domain")
|
||||
}
|
||||
|
||||
// Should block non whitelisted email address
|
||||
fw.Domain = []string{}
|
||||
fw.Whitelist = []string{"test@test.com"}
|
||||
if fw.ValidateEmail("one@two.com") {
|
||||
t.Error("Should not allow user not in whitelist.")
|
||||
}
|
||||
// Should block non whitelisted email address
|
||||
fw.Domain = []string{}
|
||||
fw.Whitelist = []string{"test@test.com"}
|
||||
if fw.ValidateEmail("one@two.com") {
|
||||
t.Error("Should not allow user not in whitelist.")
|
||||
}
|
||||
|
||||
// Should allow matching whitelisted email address
|
||||
fw.Domain = []string{}
|
||||
fw.Whitelist = []string{"test@test.com"}
|
||||
if !fw.ValidateEmail("test@test.com") {
|
||||
t.Error("Should allow user in whitelist.")
|
||||
}
|
||||
// Should allow matching whitelisted email address
|
||||
fw.Domain = []string{}
|
||||
fw.Whitelist = []string{"test@test.com"}
|
||||
if !fw.ValidateEmail("test@test.com") {
|
||||
t.Error("Should allow user in whitelist.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetLoginURL(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
r.Header.Add("X-Forwarded-Proto", "http")
|
||||
r.Header.Add("X-Forwarded-Host", "example.com")
|
||||
r.Header.Add("X-Forwarded-Uri", "/hello")
|
||||
r, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
r.Header.Add("X-Forwarded-Proto", "http")
|
||||
r.Header.Add("X-Forwarded-Host", "example.com")
|
||||
r.Header.Add("X-Forwarded-Uri", "/hello")
|
||||
|
||||
fw = &ForwardAuth{
|
||||
Path: "/_oauth",
|
||||
ClientId: "idtest",
|
||||
ClientSecret: "sectest",
|
||||
Scope: "scopetest",
|
||||
LoginURL: &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "test.com",
|
||||
Path: "/auth",
|
||||
},
|
||||
}
|
||||
fw = &ForwardAuth{
|
||||
Path: "/_oauth",
|
||||
ClientId: "idtest",
|
||||
ClientSecret: "sectest",
|
||||
Scope: "scopetest",
|
||||
LoginURL: &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "test.com",
|
||||
Path: "/auth",
|
||||
},
|
||||
}
|
||||
|
||||
// Check url
|
||||
uri, err := url.Parse(fw.GetLoginURL(r, "nonce"))
|
||||
if err != nil {
|
||||
t.Error("Error parsing login url:", err)
|
||||
}
|
||||
if uri.Scheme != "https" {
|
||||
t.Error("Expected login Scheme to be \"https\", got:", uri.Scheme)
|
||||
}
|
||||
if uri.Host != "test.com" {
|
||||
t.Error("Expected login Host to be \"test.com\", got:", uri.Host)
|
||||
}
|
||||
if uri.Path != "/auth" {
|
||||
t.Error("Expected login Path to be \"/auth\", got:", uri.Path)
|
||||
}
|
||||
// Check url
|
||||
uri, err := url.Parse(fw.GetLoginURL(r, "nonce"))
|
||||
if err != nil {
|
||||
t.Error("Error parsing login url:", err)
|
||||
}
|
||||
if uri.Scheme != "https" {
|
||||
t.Error("Expected login Scheme to be \"https\", got:", uri.Scheme)
|
||||
}
|
||||
if uri.Host != "test.com" {
|
||||
t.Error("Expected login Host to be \"test.com\", got:", uri.Host)
|
||||
}
|
||||
if uri.Path != "/auth" {
|
||||
t.Error("Expected login Path to be \"/auth\", got:", uri.Path)
|
||||
}
|
||||
|
||||
// Check query string
|
||||
qs := uri.Query()
|
||||
expectedQs := url.Values{
|
||||
"client_id": []string{"idtest"},
|
||||
"redirect_uri": []string{"http://example.com/_oauth"},
|
||||
"response_type": []string{"code"},
|
||||
"scope": []string{"scopetest"},
|
||||
"state": []string{"nonce:http://example.com/hello"},
|
||||
}
|
||||
if !reflect.DeepEqual(qs, expectedQs) {
|
||||
t.Error("Incorrect login query string:")
|
||||
qsDiff(expectedQs, qs)
|
||||
}
|
||||
// Check query string
|
||||
qs := uri.Query()
|
||||
expectedQs := url.Values{
|
||||
"client_id": []string{"idtest"},
|
||||
"redirect_uri": []string{"http://example.com/_oauth"},
|
||||
"response_type": []string{"code"},
|
||||
"scope": []string{"scopetest"},
|
||||
"state": []string{"nonce:http://example.com/hello"},
|
||||
}
|
||||
if !reflect.DeepEqual(qs, expectedQs) {
|
||||
t.Error("Incorrect login query string:")
|
||||
qsDiff(expectedQs, qs)
|
||||
}
|
||||
|
||||
//
|
||||
// With Auth URL but no matching cookie domain
|
||||
// - will not use auth host
|
||||
//
|
||||
fw = &ForwardAuth{
|
||||
Path: "/_oauth",
|
||||
AuthHost: "auth.example.com",
|
||||
ClientId: "idtest",
|
||||
ClientSecret: "sectest",
|
||||
Scope: "scopetest",
|
||||
LoginURL: &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "test.com",
|
||||
Path: "/auth",
|
||||
},
|
||||
Prompt: "consent select_account",
|
||||
}
|
||||
|
||||
//
|
||||
// With Auth URL but no matching cookie domain
|
||||
// - will not use auth host
|
||||
//
|
||||
fw = &ForwardAuth{
|
||||
Path: "/_oauth",
|
||||
AuthHost: "auth.example.com",
|
||||
ClientId: "idtest",
|
||||
ClientSecret: "sectest",
|
||||
Scope: "scopetest",
|
||||
LoginURL: &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "test.com",
|
||||
Path: "/auth",
|
||||
},
|
||||
Prompt: "consent select_account",
|
||||
}
|
||||
// Check url
|
||||
uri, err = url.Parse(fw.GetLoginURL(r, "nonce"))
|
||||
if err != nil {
|
||||
t.Error("Error parsing login url:", err)
|
||||
}
|
||||
if uri.Scheme != "https" {
|
||||
t.Error("Expected login Scheme to be \"https\", got:", uri.Scheme)
|
||||
}
|
||||
if uri.Host != "test.com" {
|
||||
t.Error("Expected login Host to be \"test.com\", got:", uri.Host)
|
||||
}
|
||||
if uri.Path != "/auth" {
|
||||
t.Error("Expected login Path to be \"/auth\", got:", uri.Path)
|
||||
}
|
||||
|
||||
// Check url
|
||||
uri, err = url.Parse(fw.GetLoginURL(r, "nonce"))
|
||||
if err != nil {
|
||||
t.Error("Error parsing login url:", err)
|
||||
}
|
||||
if uri.Scheme != "https" {
|
||||
t.Error("Expected login Scheme to be \"https\", got:", uri.Scheme)
|
||||
}
|
||||
if uri.Host != "test.com" {
|
||||
t.Error("Expected login Host to be \"test.com\", got:", uri.Host)
|
||||
}
|
||||
if uri.Path != "/auth" {
|
||||
t.Error("Expected login Path to be \"/auth\", got:", uri.Path)
|
||||
}
|
||||
// Check query string
|
||||
qs = uri.Query()
|
||||
expectedQs = url.Values{
|
||||
"client_id": []string{"idtest"},
|
||||
"redirect_uri": []string{"http://example.com/_oauth"},
|
||||
"response_type": []string{"code"},
|
||||
"scope": []string{"scopetest"},
|
||||
"prompt": []string{"consent select_account"},
|
||||
"state": []string{"nonce:http://example.com/hello"},
|
||||
}
|
||||
if !reflect.DeepEqual(qs, expectedQs) {
|
||||
t.Error("Incorrect login query string:")
|
||||
qsDiff(expectedQs, qs)
|
||||
}
|
||||
|
||||
// Check query string
|
||||
qs = uri.Query()
|
||||
expectedQs = url.Values{
|
||||
"client_id": []string{"idtest"},
|
||||
"redirect_uri": []string{"http://example.com/_oauth"},
|
||||
"response_type": []string{"code"},
|
||||
"scope": []string{"scopetest"},
|
||||
"prompt": []string{"consent select_account"},
|
||||
"state": []string{"nonce:http://example.com/hello"},
|
||||
}
|
||||
if !reflect.DeepEqual(qs, expectedQs) {
|
||||
t.Error("Incorrect login query string:")
|
||||
qsDiff(expectedQs, qs)
|
||||
}
|
||||
//
|
||||
// With correct Auth URL + cookie domain
|
||||
//
|
||||
cookieDomain := NewCookieDomain("example.com")
|
||||
fw = &ForwardAuth{
|
||||
Path: "/_oauth",
|
||||
AuthHost: "auth.example.com",
|
||||
ClientId: "idtest",
|
||||
ClientSecret: "sectest",
|
||||
Scope: "scopetest",
|
||||
LoginURL: &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "test.com",
|
||||
Path: "/auth",
|
||||
},
|
||||
CookieDomains: []CookieDomain{*cookieDomain},
|
||||
}
|
||||
|
||||
//
|
||||
// With correct Auth URL + cookie domain
|
||||
//
|
||||
cookieDomain := NewCookieDomain("example.com")
|
||||
fw = &ForwardAuth{
|
||||
Path: "/_oauth",
|
||||
AuthHost: "auth.example.com",
|
||||
ClientId: "idtest",
|
||||
ClientSecret: "sectest",
|
||||
Scope: "scopetest",
|
||||
LoginURL: &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "test.com",
|
||||
Path: "/auth",
|
||||
},
|
||||
CookieDomains: []CookieDomain{*cookieDomain},
|
||||
}
|
||||
// Check url
|
||||
uri, err = url.Parse(fw.GetLoginURL(r, "nonce"))
|
||||
if err != nil {
|
||||
t.Error("Error parsing login url:", err)
|
||||
}
|
||||
if uri.Scheme != "https" {
|
||||
t.Error("Expected login Scheme to be \"https\", got:", uri.Scheme)
|
||||
}
|
||||
if uri.Host != "test.com" {
|
||||
t.Error("Expected login Host to be \"test.com\", got:", uri.Host)
|
||||
}
|
||||
if uri.Path != "/auth" {
|
||||
t.Error("Expected login Path to be \"/auth\", got:", uri.Path)
|
||||
}
|
||||
|
||||
// Check url
|
||||
uri, err = url.Parse(fw.GetLoginURL(r, "nonce"))
|
||||
if err != nil {
|
||||
t.Error("Error parsing login url:", err)
|
||||
}
|
||||
if uri.Scheme != "https" {
|
||||
t.Error("Expected login Scheme to be \"https\", got:", uri.Scheme)
|
||||
}
|
||||
if uri.Host != "test.com" {
|
||||
t.Error("Expected login Host to be \"test.com\", got:", uri.Host)
|
||||
}
|
||||
if uri.Path != "/auth" {
|
||||
t.Error("Expected login Path to be \"/auth\", got:", uri.Path)
|
||||
}
|
||||
// Check query string
|
||||
qs = uri.Query()
|
||||
expectedQs = url.Values{
|
||||
"client_id": []string{"idtest"},
|
||||
"redirect_uri": []string{"http://auth.example.com/_oauth"},
|
||||
"response_type": []string{"code"},
|
||||
"scope": []string{"scopetest"},
|
||||
"state": []string{"nonce:http://example.com/hello"},
|
||||
}
|
||||
qsDiff(expectedQs, qs)
|
||||
if !reflect.DeepEqual(qs, expectedQs) {
|
||||
t.Error("Incorrect login query string:")
|
||||
qsDiff(expectedQs, qs)
|
||||
}
|
||||
|
||||
// Check query string
|
||||
qs = uri.Query()
|
||||
expectedQs = url.Values{
|
||||
"client_id": []string{"idtest"},
|
||||
"redirect_uri": []string{"http://auth.example.com/_oauth"},
|
||||
"response_type": []string{"code"},
|
||||
"scope": []string{"scopetest"},
|
||||
"state": []string{"nonce:http://example.com/hello"},
|
||||
}
|
||||
qsDiff(expectedQs, qs)
|
||||
if !reflect.DeepEqual(qs, expectedQs) {
|
||||
t.Error("Incorrect login query string:")
|
||||
qsDiff(expectedQs, qs)
|
||||
}
|
||||
//
|
||||
// With Auth URL + cookie domain, but from different domain
|
||||
// - will not use auth host
|
||||
//
|
||||
r, _ = http.NewRequest("GET", "http://another.com", nil)
|
||||
r.Header.Add("X-Forwarded-Proto", "http")
|
||||
r.Header.Add("X-Forwarded-Host", "another.com")
|
||||
r.Header.Add("X-Forwarded-Uri", "/hello")
|
||||
|
||||
//
|
||||
// With Auth URL + cookie domain, but from different domain
|
||||
// - will not use auth host
|
||||
//
|
||||
r, _ = http.NewRequest("GET", "http://another.com", nil)
|
||||
r.Header.Add("X-Forwarded-Proto", "http")
|
||||
r.Header.Add("X-Forwarded-Host", "another.com")
|
||||
r.Header.Add("X-Forwarded-Uri", "/hello")
|
||||
// Check url
|
||||
uri, err = url.Parse(fw.GetLoginURL(r, "nonce"))
|
||||
if err != nil {
|
||||
t.Error("Error parsing login url:", err)
|
||||
}
|
||||
if uri.Scheme != "https" {
|
||||
t.Error("Expected login Scheme to be \"https\", got:", uri.Scheme)
|
||||
}
|
||||
if uri.Host != "test.com" {
|
||||
t.Error("Expected login Host to be \"test.com\", got:", uri.Host)
|
||||
}
|
||||
if uri.Path != "/auth" {
|
||||
t.Error("Expected login Path to be \"/auth\", got:", uri.Path)
|
||||
}
|
||||
|
||||
// Check url
|
||||
uri, err = url.Parse(fw.GetLoginURL(r, "nonce"))
|
||||
if err != nil {
|
||||
t.Error("Error parsing login url:", err)
|
||||
}
|
||||
if uri.Scheme != "https" {
|
||||
t.Error("Expected login Scheme to be \"https\", got:", uri.Scheme)
|
||||
}
|
||||
if uri.Host != "test.com" {
|
||||
t.Error("Expected login Host to be \"test.com\", got:", uri.Host)
|
||||
}
|
||||
if uri.Path != "/auth" {
|
||||
t.Error("Expected login Path to be \"/auth\", got:", uri.Path)
|
||||
}
|
||||
|
||||
// Check query string
|
||||
qs = uri.Query()
|
||||
expectedQs = url.Values{
|
||||
"client_id": []string{"idtest"},
|
||||
"redirect_uri": []string{"http://another.com/_oauth"},
|
||||
"response_type": []string{"code"},
|
||||
"scope": []string{"scopetest"},
|
||||
"state": []string{"nonce:http://another.com/hello"},
|
||||
}
|
||||
qsDiff(expectedQs, qs)
|
||||
if !reflect.DeepEqual(qs, expectedQs) {
|
||||
t.Error("Incorrect login query string:")
|
||||
qsDiff(expectedQs, qs)
|
||||
}
|
||||
// Check query string
|
||||
qs = uri.Query()
|
||||
expectedQs = url.Values{
|
||||
"client_id": []string{"idtest"},
|
||||
"redirect_uri": []string{"http://another.com/_oauth"},
|
||||
"response_type": []string{"code"},
|
||||
"scope": []string{"scopetest"},
|
||||
"state": []string{"nonce:http://another.com/hello"},
|
||||
}
|
||||
qsDiff(expectedQs, qs)
|
||||
if !reflect.DeepEqual(qs, expectedQs) {
|
||||
t.Error("Incorrect login query string:")
|
||||
qsDiff(expectedQs, qs)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO
|
||||
@ -294,123 +292,123 @@ func TestGetLoginURL(t *testing.T) {
|
||||
// }
|
||||
|
||||
func TestMakeCSRFCookie(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "http://app.example.com", nil)
|
||||
r.Header.Add("X-Forwarded-Host", "app.example.com")
|
||||
r, _ := http.NewRequest("GET", "http://app.example.com", nil)
|
||||
r.Header.Add("X-Forwarded-Host", "app.example.com")
|
||||
|
||||
// No cookie domain or auth url
|
||||
fw = &ForwardAuth{}
|
||||
c := fw.MakeCSRFCookie(r, "12345678901234567890123456789012")
|
||||
if c.Domain != "app.example.com" {
|
||||
t.Error("Cookie Domain should match request domain, got:", c.Domain)
|
||||
}
|
||||
// No cookie domain or auth url
|
||||
fw = &ForwardAuth{}
|
||||
c := fw.MakeCSRFCookie(r, "12345678901234567890123456789012")
|
||||
if c.Domain != "app.example.com" {
|
||||
t.Error("Cookie Domain should match request domain, got:", c.Domain)
|
||||
}
|
||||
|
||||
// With cookie domain but no auth url
|
||||
cookieDomain := NewCookieDomain("example.com")
|
||||
fw = &ForwardAuth{CookieDomains: []CookieDomain{*cookieDomain},}
|
||||
c = fw.MakeCSRFCookie(r, "12345678901234567890123456789012")
|
||||
if c.Domain != "app.example.com" {
|
||||
t.Error("Cookie Domain should match request domain, got:", c.Domain)
|
||||
}
|
||||
// With cookie domain but no auth url
|
||||
cookieDomain := NewCookieDomain("example.com")
|
||||
fw = &ForwardAuth{CookieDomains: []CookieDomain{*cookieDomain}}
|
||||
c = fw.MakeCSRFCookie(r, "12345678901234567890123456789012")
|
||||
if c.Domain != "app.example.com" {
|
||||
t.Error("Cookie Domain should match request domain, got:", c.Domain)
|
||||
}
|
||||
|
||||
// With cookie domain and auth url
|
||||
fw = &ForwardAuth{
|
||||
AuthHost: "auth.example.com",
|
||||
CookieDomains: []CookieDomain{*cookieDomain},
|
||||
}
|
||||
c = fw.MakeCSRFCookie(r, "12345678901234567890123456789012")
|
||||
if c.Domain != "example.com" {
|
||||
t.Error("Cookie Domain should match request domain, got:", c.Domain)
|
||||
}
|
||||
// With cookie domain and auth url
|
||||
fw = &ForwardAuth{
|
||||
AuthHost: "auth.example.com",
|
||||
CookieDomains: []CookieDomain{*cookieDomain},
|
||||
}
|
||||
c = fw.MakeCSRFCookie(r, "12345678901234567890123456789012")
|
||||
if c.Domain != "example.com" {
|
||||
t.Error("Cookie Domain should match request domain, got:", c.Domain)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearCSRFCookie(t *testing.T) {
|
||||
fw = &ForwardAuth{}
|
||||
r, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
fw = &ForwardAuth{}
|
||||
r, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
|
||||
c := fw.ClearCSRFCookie(r)
|
||||
if c.Value != "" {
|
||||
t.Error("ClearCSRFCookie should create cookie with empty value")
|
||||
}
|
||||
c := fw.ClearCSRFCookie(r)
|
||||
if c.Value != "" {
|
||||
t.Error("ClearCSRFCookie should create cookie with empty value")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateCSRFCookie(t *testing.T) {
|
||||
fw = &ForwardAuth{}
|
||||
c := &http.Cookie{}
|
||||
fw = &ForwardAuth{}
|
||||
c := &http.Cookie{}
|
||||
|
||||
// Should require 32 char string
|
||||
c.Value = ""
|
||||
valid, _, err := fw.ValidateCSRFCookie(c, "")
|
||||
if valid || err.Error() != "Invalid CSRF cookie value" {
|
||||
t.Error("Should get \"Invalid CSRF cookie value\", got:", err)
|
||||
}
|
||||
c.Value = "123456789012345678901234567890123"
|
||||
valid, _, err = fw.ValidateCSRFCookie(c, "")
|
||||
if valid || err.Error() != "Invalid CSRF cookie value" {
|
||||
t.Error("Should get \"Invalid CSRF cookie value\", got:", err)
|
||||
}
|
||||
// Should require 32 char string
|
||||
c.Value = ""
|
||||
valid, _, err := fw.ValidateCSRFCookie(c, "")
|
||||
if valid || err.Error() != "Invalid CSRF cookie value" {
|
||||
t.Error("Should get \"Invalid CSRF cookie value\", got:", err)
|
||||
}
|
||||
c.Value = "123456789012345678901234567890123"
|
||||
valid, _, err = fw.ValidateCSRFCookie(c, "")
|
||||
if valid || err.Error() != "Invalid CSRF cookie value" {
|
||||
t.Error("Should get \"Invalid CSRF cookie value\", got:", err)
|
||||
}
|
||||
|
||||
// Should require valid state
|
||||
c.Value = "12345678901234567890123456789012"
|
||||
valid, _, err = fw.ValidateCSRFCookie(c, "12345678901234567890123456789012:")
|
||||
if valid || err.Error() != "Invalid CSRF state value" {
|
||||
t.Error("Should get \"Invalid CSRF state value\", got:", err)
|
||||
}
|
||||
// Should require valid state
|
||||
c.Value = "12345678901234567890123456789012"
|
||||
valid, _, err = fw.ValidateCSRFCookie(c, "12345678901234567890123456789012:")
|
||||
if valid || err.Error() != "Invalid CSRF state value" {
|
||||
t.Error("Should get \"Invalid CSRF state value\", got:", err)
|
||||
}
|
||||
|
||||
// Should allow valid state
|
||||
c.Value = "12345678901234567890123456789012"
|
||||
valid, state, err := fw.ValidateCSRFCookie(c, "12345678901234567890123456789012:99")
|
||||
if !valid {
|
||||
t.Error("Valid request should return as valid")
|
||||
}
|
||||
if err != nil {
|
||||
t.Error("Valid request should not return error, got:", err)
|
||||
}
|
||||
if state != "99" {
|
||||
t.Error("Valid request should return correct state, got:", state)
|
||||
}
|
||||
// Should allow valid state
|
||||
c.Value = "12345678901234567890123456789012"
|
||||
valid, state, err := fw.ValidateCSRFCookie(c, "12345678901234567890123456789012:99")
|
||||
if !valid {
|
||||
t.Error("Valid request should return as valid")
|
||||
}
|
||||
if err != nil {
|
||||
t.Error("Valid request should not return error, got:", err)
|
||||
}
|
||||
if state != "99" {
|
||||
t.Error("Valid request should return correct state, got:", state)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNonce(t *testing.T) {
|
||||
fw = &ForwardAuth{}
|
||||
fw = &ForwardAuth{}
|
||||
|
||||
err, nonce1 := fw.Nonce()
|
||||
if err != nil {
|
||||
t.Error("Error generation nonce:", err)
|
||||
}
|
||||
err, nonce1 := fw.Nonce()
|
||||
if err != nil {
|
||||
t.Error("Error generation nonce:", err)
|
||||
}
|
||||
|
||||
err, nonce2 := fw.Nonce()
|
||||
if err != nil {
|
||||
t.Error("Error generation nonce:", err)
|
||||
}
|
||||
err, nonce2 := fw.Nonce()
|
||||
if err != nil {
|
||||
t.Error("Error generation nonce:", err)
|
||||
}
|
||||
|
||||
if len(nonce1) != 32 || len(nonce2) != 32 {
|
||||
t.Error("Nonce should be 32 chars")
|
||||
}
|
||||
if nonce1 == nonce2 {
|
||||
t.Error("Nonce should not be equal")
|
||||
}
|
||||
if len(nonce1) != 32 || len(nonce2) != 32 {
|
||||
t.Error("Nonce should be 32 chars")
|
||||
}
|
||||
if nonce1 == nonce2 {
|
||||
t.Error("Nonce should not be equal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCookieDomainMatch(t *testing.T) {
|
||||
cd := NewCookieDomain("example.com")
|
||||
cd := NewCookieDomain("example.com")
|
||||
|
||||
// Exact should match
|
||||
if !cd.Match("example.com") {
|
||||
t.Error("Exact domain should match")
|
||||
}
|
||||
// Exact should match
|
||||
if !cd.Match("example.com") {
|
||||
t.Error("Exact domain should match")
|
||||
}
|
||||
|
||||
// Subdomain should match
|
||||
if !cd.Match("test.example.com") {
|
||||
t.Error("Subdomain should match")
|
||||
}
|
||||
// Subdomain should match
|
||||
if !cd.Match("test.example.com") {
|
||||
t.Error("Subdomain should match")
|
||||
}
|
||||
|
||||
// Derived domain should not match
|
||||
if cd.Match("testexample.com") {
|
||||
t.Error("Derived domain should not match")
|
||||
}
|
||||
// Derived domain should not match
|
||||
if cd.Match("testexample.com") {
|
||||
t.Error("Derived domain should not match")
|
||||
}
|
||||
|
||||
// Other domain should not match
|
||||
if cd.Match("test.com") {
|
||||
t.Error("Other domain should not match")
|
||||
}
|
||||
// Other domain should not match
|
||||
if cd.Match("test.com") {
|
||||
t.Error("Other domain should not match")
|
||||
}
|
||||
}
|
||||
|
363
main.go
363
main.go
@ -1,229 +1,226 @@
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
"strings"
|
||||
"net/url"
|
||||
"net/http"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/namsral/flag"
|
||||
"github.com/op/go-logging"
|
||||
"github.com/namsral/flag"
|
||||
"github.com/op/go-logging"
|
||||
)
|
||||
|
||||
// Vars
|
||||
var fw *ForwardAuth;
|
||||
var fw *ForwardAuth
|
||||
var log = logging.MustGetLogger("traefik-forward-auth")
|
||||
|
||||
// Primary handler
|
||||
func handler(w http.ResponseWriter, r *http.Request) {
|
||||
// Parse uri
|
||||
uri, err := url.Parse(r.Header.Get("X-Forwarded-Uri"))
|
||||
if err != nil {
|
||||
log.Error("Error parsing url")
|
||||
http.Error(w, "Service unavailable", 503)
|
||||
return
|
||||
}
|
||||
// Parse uri
|
||||
uri, err := url.Parse(r.Header.Get("X-Forwarded-Uri"))
|
||||
if err != nil {
|
||||
log.Error("Error parsing url")
|
||||
http.Error(w, "Service unavailable", 503)
|
||||
return
|
||||
}
|
||||
|
||||
// Handle callback
|
||||
if uri.Path == fw.Path {
|
||||
handleCallback(w, r, uri.Query())
|
||||
return
|
||||
}
|
||||
// Handle callback
|
||||
if uri.Path == fw.Path {
|
||||
handleCallback(w, r, uri.Query())
|
||||
return
|
||||
}
|
||||
|
||||
// Get auth cookie
|
||||
c, err := r.Cookie(fw.CookieName)
|
||||
if err != nil {
|
||||
// Error indicates no cookie, generate nonce
|
||||
err, nonce := fw.Nonce()
|
||||
if err != nil {
|
||||
log.Error("Error generating nonce")
|
||||
http.Error(w, "Service unavailable", 503)
|
||||
return
|
||||
}
|
||||
// Get auth cookie
|
||||
c, err := r.Cookie(fw.CookieName)
|
||||
if err != nil {
|
||||
// Error indicates no cookie, generate nonce
|
||||
err, nonce := fw.Nonce()
|
||||
if err != nil {
|
||||
log.Error("Error generating nonce")
|
||||
http.Error(w, "Service unavailable", 503)
|
||||
return
|
||||
}
|
||||
|
||||
// Set the CSRF cookie
|
||||
http.SetCookie(w, fw.MakeCSRFCookie(r, nonce))
|
||||
log.Debug("Set CSRF cookie and redirecting to google login")
|
||||
// Set the CSRF cookie
|
||||
http.SetCookie(w, fw.MakeCSRFCookie(r, nonce))
|
||||
log.Debug("Set CSRF cookie and redirecting to google login")
|
||||
|
||||
// Forward them on
|
||||
http.Redirect(w, r, fw.GetLoginURL(r, nonce), http.StatusTemporaryRedirect)
|
||||
// Forward them on
|
||||
http.Redirect(w, r, fw.GetLoginURL(r, nonce), http.StatusTemporaryRedirect)
|
||||
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Validate cookie
|
||||
valid, email, err := fw.ValidateCookie(r, c)
|
||||
if !valid {
|
||||
log.Debugf("Invalid cookie: %s", err)
|
||||
http.Error(w, "Not authorized", 401)
|
||||
return
|
||||
}
|
||||
// Validate cookie
|
||||
valid, email, err := fw.ValidateCookie(r, c)
|
||||
if !valid {
|
||||
log.Debugf("Invalid cookie: %s", err)
|
||||
http.Error(w, "Not authorized", 401)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate user
|
||||
valid = fw.ValidateEmail(email)
|
||||
if !valid {
|
||||
log.Debugf("Invalid email: %s", email)
|
||||
http.Error(w, "Not authorized", 401)
|
||||
return
|
||||
}
|
||||
// Validate user
|
||||
valid = fw.ValidateEmail(email)
|
||||
if !valid {
|
||||
log.Debugf("Invalid email: %s", email)
|
||||
http.Error(w, "Not authorized", 401)
|
||||
return
|
||||
}
|
||||
|
||||
// Valid request
|
||||
w.Header().Set("X-Forwarded-User", email)
|
||||
w.WriteHeader(200)
|
||||
// Valid request
|
||||
w.Header().Set("X-Forwarded-User", email)
|
||||
w.WriteHeader(200)
|
||||
}
|
||||
|
||||
|
||||
// Authenticate user after they have come back from google
|
||||
func handleCallback(w http.ResponseWriter, r *http.Request, qs url.Values) {
|
||||
// Check for CSRF cookie
|
||||
csrfCookie, err := r.Cookie(fw.CSRFCookieName)
|
||||
if err != nil {
|
||||
log.Debug("Missing csrf cookie")
|
||||
http.Error(w, "Not authorized", 401)
|
||||
return
|
||||
}
|
||||
// Check for CSRF cookie
|
||||
csrfCookie, err := r.Cookie(fw.CSRFCookieName)
|
||||
if err != nil {
|
||||
log.Debug("Missing csrf cookie")
|
||||
http.Error(w, "Not authorized", 401)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate state
|
||||
state := qs.Get("state")
|
||||
valid, redirect, err := fw.ValidateCSRFCookie(csrfCookie, state)
|
||||
if !valid {
|
||||
log.Debugf("Invalid oauth state, expected '%s', got '%s'\n", csrfCookie.Value, state)
|
||||
http.Error(w, "Not authorized", 401)
|
||||
return
|
||||
}
|
||||
// Validate state
|
||||
state := qs.Get("state")
|
||||
valid, redirect, err := fw.ValidateCSRFCookie(csrfCookie, state)
|
||||
if !valid {
|
||||
log.Debugf("Invalid oauth state, expected '%s', got '%s'\n", csrfCookie.Value, state)
|
||||
http.Error(w, "Not authorized", 401)
|
||||
return
|
||||
}
|
||||
|
||||
// Clear CSRF cookie
|
||||
http.SetCookie(w, fw.ClearCSRFCookie(r))
|
||||
// Clear CSRF cookie
|
||||
http.SetCookie(w, fw.ClearCSRFCookie(r))
|
||||
|
||||
// Exchange code for token
|
||||
token, err := fw.ExchangeCode(r, qs.Get("code"))
|
||||
if err != nil {
|
||||
log.Debugf("Code exchange failed with: %s\n", err)
|
||||
http.Error(w, "Service unavailable", 503)
|
||||
return
|
||||
}
|
||||
// Exchange code for token
|
||||
token, err := fw.ExchangeCode(r, qs.Get("code"))
|
||||
if err != nil {
|
||||
log.Debugf("Code exchange failed with: %s\n", err)
|
||||
http.Error(w, "Service unavailable", 503)
|
||||
return
|
||||
}
|
||||
|
||||
// Get user
|
||||
user, err := fw.GetUser(token)
|
||||
if err != nil {
|
||||
log.Debugf("Error getting user: %s\n", err)
|
||||
return
|
||||
}
|
||||
// Get user
|
||||
user, err := fw.GetUser(token)
|
||||
if err != nil {
|
||||
log.Debugf("Error getting user: %s\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Generate cookie
|
||||
http.SetCookie(w, fw.MakeCookie(r, user.Email))
|
||||
log.Debugf("Generated auth cookie for %s\n", user.Email)
|
||||
// Generate cookie
|
||||
http.SetCookie(w, fw.MakeCookie(r, user.Email))
|
||||
log.Debugf("Generated auth cookie for %s\n", user.Email)
|
||||
|
||||
// Redirect
|
||||
http.Redirect(w, r, redirect, http.StatusTemporaryRedirect)
|
||||
// Redirect
|
||||
http.Redirect(w, r, redirect, http.StatusTemporaryRedirect)
|
||||
}
|
||||
|
||||
|
||||
// Main
|
||||
func main() {
|
||||
// Parse options
|
||||
flag.String(flag.DefaultConfigFlagname, "", "Path to config file")
|
||||
path := flag.String("url-path", "_oauth", "Callback URL")
|
||||
lifetime := flag.Int("lifetime", 43200, "Session length in seconds")
|
||||
secret := flag.String("secret", "", "*Secret used for signing (required)")
|
||||
authHost := flag.String("auth-host", "", "Central auth login")
|
||||
clientId := flag.String("client-id", "", "*Google Client ID (required)")
|
||||
clientSecret := flag.String("client-secret", "", "*Google Client Secret (required)")
|
||||
cookieName := flag.String("cookie-name", "_forward_auth", "Cookie Name")
|
||||
cSRFCookieName := flag.String("csrf-cookie-name", "_forward_auth_csrf", "CSRF Cookie Name")
|
||||
cookieDomainList := flag.String("cookie-domains", "", "Comma separated list of cookie domains") //todo
|
||||
cookieSecret := flag.String("cookie-secret", "", "Deprecated")
|
||||
cookieSecure := flag.Bool("cookie-secure", true, "Use secure cookies")
|
||||
domainList := flag.String("domain", "", "Comma separated list of email domains to allow")
|
||||
emailWhitelist := flag.String("whitelist", "", "Comma separated list of emails to allow")
|
||||
prompt := flag.String("prompt", "", "Space separated list of OpenID prompt options")
|
||||
// Parse options
|
||||
flag.String(flag.DefaultConfigFlagname, "", "Path to config file")
|
||||
path := flag.String("url-path", "_oauth", "Callback URL")
|
||||
lifetime := flag.Int("lifetime", 43200, "Session length in seconds")
|
||||
secret := flag.String("secret", "", "*Secret used for signing (required)")
|
||||
authHost := flag.String("auth-host", "", "Central auth login")
|
||||
clientId := flag.String("client-id", "", "*Google Client ID (required)")
|
||||
clientSecret := flag.String("client-secret", "", "*Google Client Secret (required)")
|
||||
cookieName := flag.String("cookie-name", "_forward_auth", "Cookie Name")
|
||||
cSRFCookieName := flag.String("csrf-cookie-name", "_forward_auth_csrf", "CSRF Cookie Name")
|
||||
cookieDomainList := flag.String("cookie-domains", "", "Comma separated list of cookie domains") //todo
|
||||
cookieSecret := flag.String("cookie-secret", "", "Deprecated")
|
||||
cookieSecure := flag.Bool("cookie-secure", true, "Use secure cookies")
|
||||
domainList := flag.String("domain", "", "Comma separated list of email domains to allow")
|
||||
emailWhitelist := flag.String("whitelist", "", "Comma separated list of emails to allow")
|
||||
prompt := flag.String("prompt", "", "Space separated list of OpenID prompt options")
|
||||
|
||||
flag.Parse()
|
||||
flag.Parse()
|
||||
|
||||
// Backwards compatability
|
||||
if *secret == "" && *cookieSecret != "" {
|
||||
*secret = *cookieSecret
|
||||
}
|
||||
// Backwards compatability
|
||||
if *secret == "" && *cookieSecret != "" {
|
||||
*secret = *cookieSecret
|
||||
}
|
||||
|
||||
// Check for show stopper errors
|
||||
stop := false
|
||||
if *clientId == "" {
|
||||
stop = true
|
||||
log.Critical("client-id must be set")
|
||||
}
|
||||
if *clientSecret == "" {
|
||||
stop = true
|
||||
log.Critical("client-secret must be set")
|
||||
}
|
||||
if *secret == "" {
|
||||
stop = true
|
||||
log.Critical("secret must be set")
|
||||
}
|
||||
if stop {
|
||||
return
|
||||
}
|
||||
// Check for show stopper errors
|
||||
stop := false
|
||||
if *clientId == "" {
|
||||
stop = true
|
||||
log.Critical("client-id must be set")
|
||||
}
|
||||
if *clientSecret == "" {
|
||||
stop = true
|
||||
log.Critical("client-secret must be set")
|
||||
}
|
||||
if *secret == "" {
|
||||
stop = true
|
||||
log.Critical("secret must be set")
|
||||
}
|
||||
if stop {
|
||||
return
|
||||
}
|
||||
|
||||
// Parse lists
|
||||
var cookieDomains []CookieDomain
|
||||
if *cookieDomainList != "" {
|
||||
for _, d := range strings.Split(*cookieDomainList, ",") {
|
||||
cookieDomain := NewCookieDomain(d)
|
||||
cookieDomains = append(cookieDomains, *cookieDomain)
|
||||
}
|
||||
}
|
||||
// Parse lists
|
||||
var cookieDomains []CookieDomain
|
||||
if *cookieDomainList != "" {
|
||||
for _, d := range strings.Split(*cookieDomainList, ",") {
|
||||
cookieDomain := NewCookieDomain(d)
|
||||
cookieDomains = append(cookieDomains, *cookieDomain)
|
||||
}
|
||||
}
|
||||
|
||||
var domain []string
|
||||
if *domainList != "" {
|
||||
domain = strings.Split(*domainList, ",")
|
||||
}
|
||||
var whitelist []string
|
||||
if *emailWhitelist != "" {
|
||||
whitelist = strings.Split(*emailWhitelist, ",")
|
||||
}
|
||||
var domain []string
|
||||
if *domainList != "" {
|
||||
domain = strings.Split(*domainList, ",")
|
||||
}
|
||||
var whitelist []string
|
||||
if *emailWhitelist != "" {
|
||||
whitelist = strings.Split(*emailWhitelist, ",")
|
||||
}
|
||||
|
||||
// Setup
|
||||
fw = &ForwardAuth{
|
||||
Path: fmt.Sprintf("/%s", *path),
|
||||
Lifetime: time.Second * time.Duration(*lifetime),
|
||||
Secret: []byte(*secret),
|
||||
AuthHost: *authHost,
|
||||
// Setup
|
||||
fw = &ForwardAuth{
|
||||
Path: fmt.Sprintf("/%s", *path),
|
||||
Lifetime: time.Second * time.Duration(*lifetime),
|
||||
Secret: []byte(*secret),
|
||||
AuthHost: *authHost,
|
||||
|
||||
ClientId: *clientId,
|
||||
ClientSecret: *clientSecret,
|
||||
Scope: "https://www.googleapis.com/auth/userinfo.profile https://www.googleapis.com/auth/userinfo.email",
|
||||
LoginURL: &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "accounts.google.com",
|
||||
Path: "/o/oauth2/auth",
|
||||
},
|
||||
TokenURL: &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "www.googleapis.com",
|
||||
Path: "/oauth2/v3/token",
|
||||
},
|
||||
UserURL: &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "www.googleapis.com",
|
||||
Path: "/oauth2/v2/userinfo",
|
||||
},
|
||||
ClientId: *clientId,
|
||||
ClientSecret: *clientSecret,
|
||||
Scope: "https://www.googleapis.com/auth/userinfo.profile https://www.googleapis.com/auth/userinfo.email",
|
||||
LoginURL: &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "accounts.google.com",
|
||||
Path: "/o/oauth2/auth",
|
||||
},
|
||||
TokenURL: &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "www.googleapis.com",
|
||||
Path: "/oauth2/v3/token",
|
||||
},
|
||||
UserURL: &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "www.googleapis.com",
|
||||
Path: "/oauth2/v2/userinfo",
|
||||
},
|
||||
|
||||
CookieName: *cookieName,
|
||||
CSRFCookieName: *cSRFCookieName,
|
||||
CookieDomains: cookieDomains,
|
||||
CookieSecure: *cookieSecure,
|
||||
CookieName: *cookieName,
|
||||
CSRFCookieName: *cSRFCookieName,
|
||||
CookieDomains: cookieDomains,
|
||||
CookieSecure: *cookieSecure,
|
||||
|
||||
Domain: domain,
|
||||
Whitelist: whitelist,
|
||||
Domain: domain,
|
||||
Whitelist: whitelist,
|
||||
|
||||
Prompt: *prompt,
|
||||
}
|
||||
Prompt: *prompt,
|
||||
}
|
||||
|
||||
// Attach handler
|
||||
http.HandleFunc("/", handler)
|
||||
// Attach handler
|
||||
http.HandleFunc("/", handler)
|
||||
|
||||
log.Debugf("Starting with options: %#v", fw)
|
||||
log.Notice("Listening on :4181")
|
||||
log.Notice(http.ListenAndServe(":4181", nil))
|
||||
log.Debugf("Starting with options: %#v", fw)
|
||||
log.Notice("Listening on :4181")
|
||||
log.Notice(http.ListenAndServe(":4181", nil))
|
||||
}
|
||||
|
301
main_test.go
301
main_test.go
@ -1,32 +1,33 @@
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
// "reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"net/url"
|
||||
"net/http"
|
||||
"io/ioutil"
|
||||
"net/http/httptest"
|
||||
"fmt"
|
||||
"time"
|
||||
// "reflect"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/op/go-logging"
|
||||
"github.com/op/go-logging"
|
||||
)
|
||||
|
||||
/**
|
||||
* Utilities
|
||||
*/
|
||||
|
||||
type TokenServerHandler struct {}
|
||||
type TokenServerHandler struct{}
|
||||
|
||||
func (t *TokenServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprint(w, `{"access_token":"123456789"}`)
|
||||
fmt.Fprint(w, `{"access_token":"123456789"}`)
|
||||
}
|
||||
|
||||
type UserServerHandler struct {}
|
||||
type UserServerHandler struct{}
|
||||
|
||||
func (t *UserServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprint(w, `{
|
||||
fmt.Fprint(w, `{
|
||||
"id":"1",
|
||||
"email":"example@example.com",
|
||||
"verified_email":true,
|
||||
@ -35,51 +36,51 @@ func (t *UserServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func init() {
|
||||
// Remove for debugging
|
||||
logging.SetLevel(logging.INFO, "traefik-forward-auth")
|
||||
// Remove for debugging
|
||||
logging.SetLevel(logging.INFO, "traefik-forward-auth")
|
||||
}
|
||||
|
||||
func httpRequest(r *http.Request, c *http.Cookie) (*http.Response, string) {
|
||||
w := httptest.NewRecorder()
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Set cookies on recorder
|
||||
if c != nil {
|
||||
http.SetCookie(w, c)
|
||||
}
|
||||
// Set cookies on recorder
|
||||
if c != nil {
|
||||
http.SetCookie(w, c)
|
||||
}
|
||||
|
||||
// Copy into request
|
||||
for _, c := range w.HeaderMap["Set-Cookie"] {
|
||||
r.Header.Add("Cookie", c)
|
||||
}
|
||||
// Copy into request
|
||||
for _, c := range w.HeaderMap["Set-Cookie"] {
|
||||
r.Header.Add("Cookie", c)
|
||||
}
|
||||
|
||||
handler(w, r)
|
||||
handler(w, r)
|
||||
|
||||
res := w.Result()
|
||||
body, _ := ioutil.ReadAll(res.Body)
|
||||
res := w.Result()
|
||||
body, _ := ioutil.ReadAll(res.Body)
|
||||
|
||||
return res, string(body)
|
||||
return res, string(body)
|
||||
}
|
||||
|
||||
func newHttpRequest(uri string) *http.Request {
|
||||
r := httptest.NewRequest("", "http://example.com", nil)
|
||||
r.Header.Add("X-Forwarded-Uri", uri)
|
||||
return r
|
||||
r := httptest.NewRequest("", "http://example.com", nil)
|
||||
r.Header.Add("X-Forwarded-Uri", uri)
|
||||
return r
|
||||
}
|
||||
|
||||
func qsDiff(one, two url.Values) {
|
||||
for k, _ := range one {
|
||||
if two.Get(k) == "" {
|
||||
fmt.Printf("Key missing: %s\n", k)
|
||||
}
|
||||
if one.Get(k) != two.Get(k) {
|
||||
fmt.Printf("Value different for %s: expected: '%s' got: '%s'\n", k, one.Get(k), two.Get(k))
|
||||
}
|
||||
}
|
||||
for k, _ := range two {
|
||||
if one.Get(k) == "" {
|
||||
fmt.Printf("Extra key: %s\n", k)
|
||||
}
|
||||
}
|
||||
for k, _ := range one {
|
||||
if two.Get(k) == "" {
|
||||
fmt.Printf("Key missing: %s\n", k)
|
||||
}
|
||||
if one.Get(k) != two.Get(k) {
|
||||
fmt.Printf("Value different for %s: expected: '%s' got: '%s'\n", k, one.Get(k), two.Get(k))
|
||||
}
|
||||
}
|
||||
for k, _ := range two {
|
||||
if one.Get(k) == "" {
|
||||
fmt.Printf("Extra key: %s\n", k)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@ -87,126 +88,126 @@ func qsDiff(one, two url.Values) {
|
||||
*/
|
||||
|
||||
func TestHandler(t *testing.T) {
|
||||
fw = &ForwardAuth{
|
||||
Path: "_oauth",
|
||||
ClientId: "idtest",
|
||||
ClientSecret: "sectest",
|
||||
Scope: "scopetest",
|
||||
LoginURL: &url.URL{
|
||||
Scheme: "http",
|
||||
Host: "test.com",
|
||||
Path: "/auth",
|
||||
},
|
||||
CookieName: "cookie_test",
|
||||
Lifetime: time.Second * time.Duration(10),
|
||||
}
|
||||
fw = &ForwardAuth{
|
||||
Path: "_oauth",
|
||||
ClientId: "idtest",
|
||||
ClientSecret: "sectest",
|
||||
Scope: "scopetest",
|
||||
LoginURL: &url.URL{
|
||||
Scheme: "http",
|
||||
Host: "test.com",
|
||||
Path: "/auth",
|
||||
},
|
||||
CookieName: "cookie_test",
|
||||
Lifetime: time.Second * time.Duration(10),
|
||||
}
|
||||
|
||||
// Should redirect vanilla request to login url
|
||||
req := newHttpRequest("foo")
|
||||
res, _ := httpRequest(req, nil)
|
||||
if res.StatusCode != 307 {
|
||||
t.Error("Vanilla request should be redirected with 307, got:", res.StatusCode)
|
||||
}
|
||||
fwd, _ := res.Location()
|
||||
if fwd.Scheme != "http" || fwd.Host != "test.com" || fwd.Path != "/auth" {
|
||||
t.Error("Vanilla request should be redirected to login url, got:", fwd)
|
||||
}
|
||||
// Should redirect vanilla request to login url
|
||||
req := newHttpRequest("foo")
|
||||
res, _ := httpRequest(req, nil)
|
||||
if res.StatusCode != 307 {
|
||||
t.Error("Vanilla request should be redirected with 307, got:", res.StatusCode)
|
||||
}
|
||||
fwd, _ := res.Location()
|
||||
if fwd.Scheme != "http" || fwd.Host != "test.com" || fwd.Path != "/auth" {
|
||||
t.Error("Vanilla request should be redirected to login url, got:", fwd)
|
||||
}
|
||||
|
||||
// Should catch invalid cookie
|
||||
req = newHttpRequest("foo")
|
||||
// Should catch invalid cookie
|
||||
req = newHttpRequest("foo")
|
||||
|
||||
c := fw.MakeCookie(req, "test@example.com")
|
||||
parts := strings.Split(c.Value, "|")
|
||||
c.Value = fmt.Sprintf("bad|%s|%s", parts[1], parts[2])
|
||||
c := fw.MakeCookie(req, "test@example.com")
|
||||
parts := strings.Split(c.Value, "|")
|
||||
c.Value = fmt.Sprintf("bad|%s|%s", parts[1], parts[2])
|
||||
|
||||
res, _ = httpRequest(req, c)
|
||||
if res.StatusCode != 401 {
|
||||
t.Error("Request with invalid cookie shound't be authorised", res.StatusCode)
|
||||
}
|
||||
res, _ = httpRequest(req, c)
|
||||
if res.StatusCode != 401 {
|
||||
t.Error("Request with invalid cookie shound't be authorised", res.StatusCode)
|
||||
}
|
||||
|
||||
// Should validate email
|
||||
req = newHttpRequest("foo")
|
||||
// Should validate email
|
||||
req = newHttpRequest("foo")
|
||||
|
||||
c = fw.MakeCookie(req, "test@example.com")
|
||||
fw.Domain = []string{"test.com"}
|
||||
c = fw.MakeCookie(req, "test@example.com")
|
||||
fw.Domain = []string{"test.com"}
|
||||
|
||||
res, _ = httpRequest(req, c)
|
||||
if res.StatusCode != 401 {
|
||||
t.Error("Request with invalid cookie shound't be authorised", res.StatusCode)
|
||||
}
|
||||
res, _ = httpRequest(req, c)
|
||||
if res.StatusCode != 401 {
|
||||
t.Error("Request with invalid cookie shound't be authorised", res.StatusCode)
|
||||
}
|
||||
|
||||
// Should allow valid request email
|
||||
req = newHttpRequest("foo")
|
||||
// Should allow valid request email
|
||||
req = newHttpRequest("foo")
|
||||
|
||||
c = fw.MakeCookie(req, "test@example.com")
|
||||
fw.Domain = []string{}
|
||||
c = fw.MakeCookie(req, "test@example.com")
|
||||
fw.Domain = []string{}
|
||||
|
||||
res, _ = httpRequest(req, c)
|
||||
if res.StatusCode != 200 {
|
||||
t.Error("Valid request should be allowed, got:", res.StatusCode)
|
||||
}
|
||||
res, _ = httpRequest(req, c)
|
||||
if res.StatusCode != 200 {
|
||||
t.Error("Valid request should be allowed, got:", res.StatusCode)
|
||||
}
|
||||
|
||||
// Should pass through user
|
||||
users := res.Header["X-Forwarded-User"];
|
||||
if len(users) != 1 {
|
||||
t.Error("Valid request missing X-Forwarded-User header")
|
||||
} else if users[0] != "test@example.com" {
|
||||
t.Error("X-Forwarded-User should match user, got: ", users)
|
||||
}
|
||||
// Should pass through user
|
||||
users := res.Header["X-Forwarded-User"]
|
||||
if len(users) != 1 {
|
||||
t.Error("Valid request missing X-Forwarded-User header")
|
||||
} else if users[0] != "test@example.com" {
|
||||
t.Error("X-Forwarded-User should match user, got: ", users)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallback(t *testing.T) {
|
||||
fw = &ForwardAuth{
|
||||
Path: "_oauth",
|
||||
ClientId: "idtest",
|
||||
ClientSecret: "sectest",
|
||||
Scope: "scopetest",
|
||||
LoginURL: &url.URL{
|
||||
Scheme: "http",
|
||||
Host: "test.com",
|
||||
Path: "/auth",
|
||||
},
|
||||
CSRFCookieName: "csrf_test",
|
||||
}
|
||||
fw = &ForwardAuth{
|
||||
Path: "_oauth",
|
||||
ClientId: "idtest",
|
||||
ClientSecret: "sectest",
|
||||
Scope: "scopetest",
|
||||
LoginURL: &url.URL{
|
||||
Scheme: "http",
|
||||
Host: "test.com",
|
||||
Path: "/auth",
|
||||
},
|
||||
CSRFCookieName: "csrf_test",
|
||||
}
|
||||
|
||||
// Setup token server
|
||||
tokenServerHandler := &TokenServerHandler{}
|
||||
tokenServer := httptest.NewServer(tokenServerHandler)
|
||||
defer tokenServer.Close()
|
||||
tokenUrl, _ := url.Parse(tokenServer.URL)
|
||||
fw.TokenURL = tokenUrl
|
||||
// Setup token server
|
||||
tokenServerHandler := &TokenServerHandler{}
|
||||
tokenServer := httptest.NewServer(tokenServerHandler)
|
||||
defer tokenServer.Close()
|
||||
tokenUrl, _ := url.Parse(tokenServer.URL)
|
||||
fw.TokenURL = tokenUrl
|
||||
|
||||
// Setup user server
|
||||
userServerHandler := &UserServerHandler{}
|
||||
userServer := httptest.NewServer(userServerHandler)
|
||||
defer userServer.Close()
|
||||
userUrl, _ := url.Parse(userServer.URL)
|
||||
fw.UserURL = userUrl
|
||||
// Setup user server
|
||||
userServerHandler := &UserServerHandler{}
|
||||
userServer := httptest.NewServer(userServerHandler)
|
||||
defer userServer.Close()
|
||||
userUrl, _ := url.Parse(userServer.URL)
|
||||
fw.UserURL = userUrl
|
||||
|
||||
// Should pass auth response request to callback
|
||||
req := newHttpRequest("_oauth")
|
||||
res, _ := httpRequest(req, nil)
|
||||
if res.StatusCode != 401 {
|
||||
t.Error("Auth callback without cookie shound't be authorised, got:", res.StatusCode)
|
||||
}
|
||||
// Should pass auth response request to callback
|
||||
req := newHttpRequest("_oauth")
|
||||
res, _ := httpRequest(req, nil)
|
||||
if res.StatusCode != 401 {
|
||||
t.Error("Auth callback without cookie shound't be authorised, got:", res.StatusCode)
|
||||
}
|
||||
|
||||
// Should catch invalid csrf cookie
|
||||
req = newHttpRequest("_oauth?state=12345678901234567890123456789012:http://redirect")
|
||||
c := fw.MakeCSRFCookie(req, "nononononononononononononononono")
|
||||
res, _ = httpRequest(req, c)
|
||||
if res.StatusCode != 401 {
|
||||
t.Error("Auth callback with invalid cookie shound't be authorised, got:", res.StatusCode)
|
||||
}
|
||||
// Should catch invalid csrf cookie
|
||||
req = newHttpRequest("_oauth?state=12345678901234567890123456789012:http://redirect")
|
||||
c := fw.MakeCSRFCookie(req, "nononononononononononononononono")
|
||||
res, _ = httpRequest(req, c)
|
||||
if res.StatusCode != 401 {
|
||||
t.Error("Auth callback with invalid cookie shound't be authorised, got:", res.StatusCode)
|
||||
}
|
||||
|
||||
// Should redirect valid request
|
||||
req = newHttpRequest("_oauth?state=12345678901234567890123456789012:http://redirect")
|
||||
c = fw.MakeCSRFCookie(req, "12345678901234567890123456789012")
|
||||
res, _ = httpRequest(req, c)
|
||||
if res.StatusCode != 307 {
|
||||
t.Error("Valid callback should be allowed, got:", res.StatusCode)
|
||||
}
|
||||
fwd, _ := res.Location()
|
||||
if fwd.Scheme != "http" || fwd.Host != "redirect" || fwd.Path != "" {
|
||||
t.Error("Valid request should be redirected to return url, got:", fwd)
|
||||
}
|
||||
}
|
||||
// Should redirect valid request
|
||||
req = newHttpRequest("_oauth?state=12345678901234567890123456789012:http://redirect")
|
||||
c = fw.MakeCSRFCookie(req, "12345678901234567890123456789012")
|
||||
res, _ = httpRequest(req, c)
|
||||
if res.StatusCode != 307 {
|
||||
t.Error("Valid callback should be allowed, got:", res.StatusCode)
|
||||
}
|
||||
fwd, _ := res.Location()
|
||||
if fwd.Scheme != "http" || fwd.Host != "redirect" || fwd.Path != "" {
|
||||
t.Error("Valid request should be redirected to return url, got:", fwd)
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user