add outline tests + add validation errors + make more FA methods private
This commit is contained in:
parent
55e8b3064c
commit
18865e956c
@ -4,6 +4,7 @@ package main
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
"errors"
|
||||
"strings"
|
||||
"strconv"
|
||||
"net/url"
|
||||
@ -43,41 +44,41 @@ type ForwardAuth struct {
|
||||
// Request Validation
|
||||
|
||||
// Cookie = hash(secret, cookie domain, email, expires)|expires|email
|
||||
func (f *ForwardAuth) ValidateCookie(r *http.Request, c *http.Cookie) (bool, string) {
|
||||
func (f *ForwardAuth) ValidateCookie(r *http.Request, c *http.Cookie) (bool, string, error) {
|
||||
parts := strings.Split(c.Value, "|")
|
||||
|
||||
if len(parts) != 3 {
|
||||
return false, ""
|
||||
return false, "", errors.New("Invalid cookie format")
|
||||
}
|
||||
|
||||
mac, err := base64.URLEncoding.DecodeString(parts[0])
|
||||
if err != nil {
|
||||
return false, ""
|
||||
return false, "", errors.New("Unable to decode cookie mac")
|
||||
}
|
||||
|
||||
expectedSignature := f.cookieSignature(r, parts[2], parts[1])
|
||||
expected, err := base64.URLEncoding.DecodeString(expectedSignature)
|
||||
if err != nil {
|
||||
return false, ""
|
||||
return false, "", errors.New("Unable to generate mac")
|
||||
}
|
||||
|
||||
// Valid token?
|
||||
if !hmac.Equal(mac, expected) {
|
||||
return false, ""
|
||||
return false, "", errors.New("Invalid cookie mac")
|
||||
}
|
||||
|
||||
expires, err := strconv.ParseInt(parts[1], 10, 64)
|
||||
if err != nil {
|
||||
return false, ""
|
||||
return false, "", errors.New("Unable to parse cookie expiry")
|
||||
}
|
||||
|
||||
// Has it expired?
|
||||
if time.Unix(expires, 0).Before(time.Now()) {
|
||||
return false, ""
|
||||
return false, "", errors.New("Cookie has expired")
|
||||
}
|
||||
|
||||
// Looks valid
|
||||
return true, parts[2]
|
||||
return true, parts[2], nil
|
||||
}
|
||||
|
||||
// Validate email
|
||||
@ -106,14 +107,14 @@ func (f *ForwardAuth) ValidateEmail(email string) bool {
|
||||
|
||||
// Get login url
|
||||
func (f *ForwardAuth) GetLoginURL(r *http.Request, nonce string) string {
|
||||
state := fmt.Sprintf("%s:%s", nonce, f.ReturnUrl(r))
|
||||
state := fmt.Sprintf("%s:%s", nonce, f.returnUrl(r))
|
||||
|
||||
q := url.Values{}
|
||||
q.Set("client_id", fw.ClientId)
|
||||
q.Set("response_type", "code")
|
||||
q.Set("scope", fw.Scope)
|
||||
// q.Set("approval_prompt", fw.ClientId)
|
||||
q.Set("redirect_uri", f.RedirectUri(r))
|
||||
q.Set("redirect_uri", f.redirectUri(r))
|
||||
q.Set("state", state)
|
||||
|
||||
var u url.URL
|
||||
@ -129,12 +130,12 @@ type Token struct {
|
||||
Token string `json:"access_token"`
|
||||
}
|
||||
|
||||
func (f *ForwardAuth) ExchangeCode(r *http.Request, code, redirect string) (string, error) {
|
||||
func (f *ForwardAuth) ExchangeCode(r *http.Request, code string) (string, error) {
|
||||
form := url.Values{}
|
||||
form.Set("client_id", fw.ClientId)
|
||||
form.Set("client_secret", fw.ClientSecret)
|
||||
form.Set("grant_type", "authorization_code")
|
||||
form.Set("redirect_uri", f.RedirectUri(r))
|
||||
form.Set("redirect_uri", f.redirectUri(r))
|
||||
form.Set("code", code)
|
||||
|
||||
|
||||
@ -183,7 +184,7 @@ func (f *ForwardAuth) GetUser(token string) (User, error) {
|
||||
// Utility methods
|
||||
|
||||
// Get the redirect base
|
||||
func (f *ForwardAuth) RedirectBase(r *http.Request) string {
|
||||
func (f *ForwardAuth) redirectBase(r *http.Request) string {
|
||||
proto := r.Header.Get("X-Forwarded-Proto")
|
||||
host := r.Header.Get("X-Forwarded-Host")
|
||||
|
||||
@ -197,7 +198,7 @@ func (f *ForwardAuth) RedirectBase(r *http.Request) string {
|
||||
}
|
||||
|
||||
// Return url
|
||||
func (f *ForwardAuth) ReturnUrl(r *http.Request) string {
|
||||
func (f *ForwardAuth) returnUrl(r *http.Request) string {
|
||||
path := r.Header.Get("X-Forwarded-Uri")
|
||||
|
||||
// Testing
|
||||
@ -205,12 +206,12 @@ func (f *ForwardAuth) ReturnUrl(r *http.Request) string {
|
||||
path = r.URL.String()
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s%s", f.RedirectBase(r), path)
|
||||
return fmt.Sprintf("%s%s", f.redirectBase(r), path)
|
||||
}
|
||||
|
||||
// Get oauth redirect uri
|
||||
func (f *ForwardAuth) RedirectUri(r *http.Request) string {
|
||||
return fmt.Sprintf("%s%s", f.RedirectBase(r), f.Path)
|
||||
func (f *ForwardAuth) redirectUri(r *http.Request) string {
|
||||
return fmt.Sprintf("%s%s", f.redirectBase(r), f.Path)
|
||||
}
|
||||
|
||||
// Cookie methods
|
||||
@ -259,22 +260,22 @@ func (f *ForwardAuth) ClearCSRFCookie(r *http.Request) *http.Cookie {
|
||||
}
|
||||
|
||||
// Validate the csrf cookie against state
|
||||
func (f *ForwardAuth) ValidateCSRFCookie(c *http.Cookie, state string) (bool, string) {
|
||||
func (f *ForwardAuth) ValidateCSRFCookie(c *http.Cookie, state string) (bool, string, error) {
|
||||
if len(c.Value) != 32 {
|
||||
return false, ""
|
||||
return false, "", errors.New("Invalid CSRF cookie value")
|
||||
}
|
||||
|
||||
if len(state) < 34 {
|
||||
return false, ""
|
||||
return false, "", errors.New("Invalid CSRF state value")
|
||||
}
|
||||
|
||||
// Check nonce match
|
||||
if c.Value != state[:32] {
|
||||
return false, ""
|
||||
return false, "", errors.New("CSRF cookie does not match state")
|
||||
}
|
||||
|
||||
// Valid, return redirect
|
||||
return true, state[33:]
|
||||
return true, state[33:], nil
|
||||
}
|
||||
|
||||
func (f *ForwardAuth) Nonce() (error, string) {
|
||||
|
242
forwardauth_test.go
Normal file
242
forwardauth_test.go
Normal file
@ -0,0 +1,242 @@
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
// "fmt"
|
||||
"time"
|
||||
"reflect"
|
||||
"testing"
|
||||
"net/url"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func TestValidateCookie(t *testing.T) {
|
||||
fw = &ForwardAuth{}
|
||||
r, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
c := &http.Cookie{}
|
||||
|
||||
// Should require 3 parts
|
||||
c.Value = ""
|
||||
valid, _, err := fw.ValidateCookie(r, c)
|
||||
if valid || err.Error() != "Invalid cookie format" {
|
||||
t.Error("Should get \"Invalid cookie format\", got:", err)
|
||||
}
|
||||
c.Value = "1|2"
|
||||
valid, _, err = fw.ValidateCookie(r, c)
|
||||
if valid || err.Error() != "Invalid cookie format" {
|
||||
t.Error("Should get \"Invalid cookie format\", got:", err)
|
||||
}
|
||||
c.Value = "1|2|3|4"
|
||||
valid, _, err = fw.ValidateCookie(r, c)
|
||||
if valid || err.Error() != "Invalid cookie format" {
|
||||
t.Error("Should get \"Invalid cookie format\", got:", err)
|
||||
}
|
||||
|
||||
// Should catch invalid mac
|
||||
c.Value = "MQ==|2|3"
|
||||
valid, _, err = fw.ValidateCookie(r, c)
|
||||
if valid || err.Error() != "Invalid cookie mac" {
|
||||
t.Error("Should get \"Invalid cookie mac\", got:", err)
|
||||
}
|
||||
|
||||
// Should catch expired
|
||||
fw.Lifetime = time.Second * time.Duration(-1)
|
||||
c = fw.MakeCookie(r, "test@test.com")
|
||||
valid, _, err = fw.ValidateCookie(r, c)
|
||||
if valid || err.Error() != "Cookie has expired" {
|
||||
t.Error("Should get \"Cookie has expired\", got:", err)
|
||||
}
|
||||
|
||||
// Should accept valid cookie
|
||||
fw.Lifetime = time.Second * time.Duration(10)
|
||||
c = fw.MakeCookie(r, "test@test.com")
|
||||
valid, email, err := fw.ValidateCookie(r, c)
|
||||
if !valid {
|
||||
t.Error("Valid request should return as valid")
|
||||
}
|
||||
if err != nil {
|
||||
t.Error("Valid request should not return error, got:", err)
|
||||
}
|
||||
if email != "test@test.com" {
|
||||
t.Error("Valid request should return user email")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateEmail(t *testing.T) {
|
||||
fw = &ForwardAuth{}
|
||||
|
||||
// Should allow any
|
||||
if !fw.ValidateEmail("test@test.com") || !fw.ValidateEmail("one@two.com") {
|
||||
t.Error("Should allow any domain if email domain is not defined")
|
||||
}
|
||||
|
||||
// Should block non matching domain
|
||||
fw.Domain = []string{"test.com"}
|
||||
if fw.ValidateEmail("one@two.com") {
|
||||
t.Error("Should not allow user from another domain")
|
||||
}
|
||||
|
||||
// Should allow matching domain
|
||||
fw.Domain = []string{"test.com"}
|
||||
if !fw.ValidateEmail("test@test.com") {
|
||||
t.Error("Should allow user from allowed domain")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetLoginURL(t *testing.T) {
|
||||
fw = &ForwardAuth{
|
||||
Path: "/_oauth",
|
||||
ClientId: "idtest",
|
||||
ClientSecret: "sectest",
|
||||
Scope: "scopetest",
|
||||
LoginURL: &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "test.com",
|
||||
Path: "/auth",
|
||||
},
|
||||
}
|
||||
r, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
r.Header.Add("X-Forwarded-Proto", "http")
|
||||
r.Header.Add("X-Forwarded-Host", "example.com")
|
||||
r.Header.Add("X-Forwarded-Uri", "/hello")
|
||||
|
||||
// Check url
|
||||
uri, err := url.Parse(fw.GetLoginURL(r, "nonce"))
|
||||
if err != nil {
|
||||
t.Error("Error parsing login url:", err)
|
||||
}
|
||||
if uri.Scheme != "https" {
|
||||
t.Error("Expected login Scheme to be \"https\", got:", uri.Scheme)
|
||||
}
|
||||
if uri.Host != "test.com" {
|
||||
t.Error("Expected login Host to be \"test.com\", got:", uri.Host)
|
||||
}
|
||||
if uri.Path != "/auth" {
|
||||
t.Error("Expected login Path to be \"/auth\", got:", uri.Path)
|
||||
}
|
||||
|
||||
// Check query string
|
||||
qs := uri.Query()
|
||||
expectedQs := url.Values{
|
||||
"client_id": []string{"idtest"},
|
||||
"redirect_uri": []string{"http://example.com/_oauth"},
|
||||
"response_type": []string{"code"},
|
||||
"scope": []string{"scopetest"},
|
||||
"state": []string{"nonce:http://example.com/hello"},
|
||||
}
|
||||
if !reflect.DeepEqual(qs, expectedQs) {
|
||||
t.Error("Incorrect login query string, expected:")
|
||||
t.Error(expectedQs)
|
||||
t.Error("Got:")
|
||||
t.Error(qs)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// TODO
|
||||
// func TestExchangeCode(t *testing.T) {
|
||||
// }
|
||||
|
||||
// TODO
|
||||
// func TestGetUser(t *testing.T) {
|
||||
// }
|
||||
|
||||
// TODO? Tested in TestValidateCookie
|
||||
// func TestMakeCookie(t *testing.T) {
|
||||
// }
|
||||
|
||||
// func TestMakeCSRFCookie(t *testing.T) {
|
||||
// t.Log("TODO")
|
||||
// }
|
||||
|
||||
func TestClearCSRFCookie(t *testing.T) {
|
||||
fw = &ForwardAuth{}
|
||||
r, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
|
||||
c := fw.ClearCSRFCookie(r)
|
||||
if c.Value != "" {
|
||||
t.Error("ClearCSRFCookie should create cookie with empty value")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateCSRFCookie(t *testing.T) {
|
||||
fw = &ForwardAuth{}
|
||||
c := &http.Cookie{}
|
||||
|
||||
// Should require 32 char string
|
||||
c.Value = ""
|
||||
valid, _, err := fw.ValidateCSRFCookie(c, "")
|
||||
if valid || err.Error() != "Invalid CSRF cookie value" {
|
||||
t.Error("Should get \"Invalid CSRF cookie value\", got:", err)
|
||||
}
|
||||
c.Value = "123456789012345678901234567890123"
|
||||
valid, _, err = fw.ValidateCSRFCookie(c, "")
|
||||
if valid || err.Error() != "Invalid CSRF cookie value" {
|
||||
t.Error("Should get \"Invalid CSRF cookie value\", got:", err)
|
||||
}
|
||||
|
||||
// Should require valid state
|
||||
c.Value = "12345678901234567890123456789012"
|
||||
valid, _, err = fw.ValidateCSRFCookie(c, "12345678901234567890123456789012:")
|
||||
if valid || err.Error() != "Invalid CSRF state value" {
|
||||
t.Error("Should get \"Invalid CSRF state value\", got:", err)
|
||||
}
|
||||
|
||||
// Should allow valid state
|
||||
c.Value = "12345678901234567890123456789012"
|
||||
valid, state, err := fw.ValidateCSRFCookie(c, "12345678901234567890123456789012:99")
|
||||
if !valid {
|
||||
t.Error("Valid request should return as valid")
|
||||
}
|
||||
if err != nil {
|
||||
t.Error("Valid request should not return error, got:", err)
|
||||
}
|
||||
if state != "99" {
|
||||
t.Error("Valid request should return correct state, got:", state)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNonce(t *testing.T) {
|
||||
fw = &ForwardAuth{}
|
||||
|
||||
err, nonce1 := fw.Nonce()
|
||||
if err != nil {
|
||||
t.Error("Error generation nonce:", err)
|
||||
}
|
||||
|
||||
err, nonce2 := fw.Nonce()
|
||||
if err != nil {
|
||||
t.Error("Error generation nonce:", err)
|
||||
}
|
||||
|
||||
if len(nonce1) != 32 || len(nonce2) != 32 {
|
||||
t.Error("Nonce should be 32 chars")
|
||||
}
|
||||
if nonce1 == nonce2 {
|
||||
t.Error("Nonce should not be equal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCookieDomainMatch(t *testing.T) {
|
||||
cd := NewCookieDomain("example.com")
|
||||
|
||||
// Exact should match
|
||||
if !cd.Match("example.com") {
|
||||
t.Error("Exact domain should match")
|
||||
}
|
||||
|
||||
// Subdomain should match
|
||||
if !cd.Match("test.example.com") {
|
||||
t.Error("Subdomain should match")
|
||||
}
|
||||
|
||||
// Derived domain should not match
|
||||
if cd.Match("testexample.com") {
|
||||
t.Error("Derived domain should not match")
|
||||
}
|
||||
|
||||
// Other domain should not match
|
||||
if cd.Match("test.com") {
|
||||
t.Error("Other domain should not match")
|
||||
}
|
||||
}
|
6
main.go
6
main.go
@ -58,7 +58,7 @@ func handler(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// Validate cookie
|
||||
valid, email := fw.ValidateCookie(r, c)
|
||||
valid, email, _ := fw.ValidateCookie(r, c)
|
||||
if !valid {
|
||||
http.Error(w, "Not authorized", 401)
|
||||
return
|
||||
@ -88,7 +88,7 @@ func handleCallback(w http.ResponseWriter, r *http.Request, qs url.Values) {
|
||||
|
||||
// Validate state
|
||||
state := qs.Get("state")
|
||||
valid, redirect := fw.ValidateCSRFCookie(csrfCookie, state)
|
||||
valid, redirect, err := fw.ValidateCSRFCookie(csrfCookie, state)
|
||||
if !valid && false {
|
||||
log.Debugf("Invalid oauth state, expected '%s', got '%s'\n", csrfCookie.Value, state)
|
||||
http.Error(w, "Not authorized", 401)
|
||||
@ -99,7 +99,7 @@ func handleCallback(w http.ResponseWriter, r *http.Request, qs url.Values) {
|
||||
http.SetCookie(w, fw.ClearCSRFCookie(r))
|
||||
|
||||
// Exchange code for token
|
||||
token, err := fw.ExchangeCode(r, qs.Get("code"), fw.RedirectUri(r))
|
||||
token, err := fw.ExchangeCode(r, qs.Get("code"))
|
||||
if err != nil {
|
||||
log.Debugf("Code exchange failed with: %s\n", err)
|
||||
http.Error(w, "Service unavailable", 503)
|
||||
|
Loading…
x
Reference in New Issue
Block a user