Redirect to login on cookie expiry + simplify ValidateCookie function

Possible fix for #31
This commit is contained in:
Thom Seddon 2019-06-13 15:13:52 +01:00
parent 3e92400202
commit 3e6ccc8f45
5 changed files with 86 additions and 48 deletions

View File

@ -18,41 +18,41 @@ import (
// Request Validation // Request Validation
// Cookie = hash(secret, cookie domain, email, expires)|expires|email // Cookie = hash(secret, cookie domain, email, expires)|expires|email
func ValidateCookie(r *http.Request, c *http.Cookie) (bool, string, error) { func ValidateCookie(r *http.Request, c *http.Cookie) (string, error) {
parts := strings.Split(c.Value, "|") parts := strings.Split(c.Value, "|")
if len(parts) != 3 { if len(parts) != 3 {
return false, "", errors.New("Invalid cookie format") return "", errors.New("Invalid cookie format")
} }
mac, err := base64.URLEncoding.DecodeString(parts[0]) mac, err := base64.URLEncoding.DecodeString(parts[0])
if err != nil { if err != nil {
return false, "", errors.New("Unable to decode cookie mac") return "", errors.New("Unable to decode cookie mac")
} }
expectedSignature := cookieSignature(r, parts[2], parts[1]) expectedSignature := cookieSignature(r, parts[2], parts[1])
expected, err := base64.URLEncoding.DecodeString(expectedSignature) expected, err := base64.URLEncoding.DecodeString(expectedSignature)
if err != nil { if err != nil {
return false, "", errors.New("Unable to generate mac") return "", errors.New("Unable to generate mac")
} }
// Valid token? // Valid token?
if !hmac.Equal(mac, expected) { if !hmac.Equal(mac, expected) {
return false, "", errors.New("Invalid cookie mac") return "", errors.New("Invalid cookie mac")
} }
expires, err := strconv.ParseInt(parts[1], 10, 64) expires, err := strconv.ParseInt(parts[1], 10, 64)
if err != nil { if err != nil {
return false, "", errors.New("Unable to parse cookie expiry") return "", errors.New("Unable to parse cookie expiry")
} }
// Has it expired? // Has it expired?
if time.Unix(expires, 0).Before(time.Now()) { if time.Unix(expires, 0).Before(time.Now()) {
return false, "", errors.New("Cookie has expired") return "", errors.New("Cookie has expired")
} }
// Looks valid // Looks valid
return true, parts[2], nil return parts[2], nil
} }
// Validate email // Validate email

View File

@ -24,28 +24,24 @@ func TestAuthValidateCookie(t *testing.T) {
// Should require 3 parts // Should require 3 parts
c.Value = "" c.Value = ""
valid, _, err := ValidateCookie(r, c) _, err := ValidateCookie(r, c)
assert.False(valid)
if assert.Error(err) { if assert.Error(err) {
assert.Equal("Invalid cookie format", err.Error()) assert.Equal("Invalid cookie format", err.Error())
} }
c.Value = "1|2" c.Value = "1|2"
valid, _, err = ValidateCookie(r, c) _, err = ValidateCookie(r, c)
assert.False(valid)
if assert.Error(err) { if assert.Error(err) {
assert.Equal("Invalid cookie format", err.Error()) assert.Equal("Invalid cookie format", err.Error())
} }
c.Value = "1|2|3|4" c.Value = "1|2|3|4"
valid, _, err = ValidateCookie(r, c) _, err = ValidateCookie(r, c)
assert.False(valid)
if assert.Error(err) { if assert.Error(err) {
assert.Equal("Invalid cookie format", err.Error()) assert.Equal("Invalid cookie format", err.Error())
} }
// Should catch invalid mac // Should catch invalid mac
c.Value = "MQ==|2|3" c.Value = "MQ==|2|3"
valid, _, err = ValidateCookie(r, c) _, err = ValidateCookie(r, c)
assert.False(valid)
if assert.Error(err) { if assert.Error(err) {
assert.Equal("Invalid cookie mac", err.Error()) assert.Equal("Invalid cookie mac", err.Error())
} }
@ -53,8 +49,7 @@ func TestAuthValidateCookie(t *testing.T) {
// Should catch expired // Should catch expired
config.Lifetime = time.Second * time.Duration(-1) config.Lifetime = time.Second * time.Duration(-1)
c = MakeCookie(r, "test@test.com") c = MakeCookie(r, "test@test.com")
valid, _, err = ValidateCookie(r, c) _, err = ValidateCookie(r, c)
assert.False(valid)
if assert.Error(err) { if assert.Error(err) {
assert.Equal("Cookie has expired", err.Error()) assert.Equal("Cookie has expired", err.Error())
} }
@ -62,8 +57,7 @@ func TestAuthValidateCookie(t *testing.T) {
// Should accept valid cookie // Should accept valid cookie
config.Lifetime = time.Second * time.Duration(10) config.Lifetime = time.Second * time.Duration(10)
c = MakeCookie(r, "test@test.com") c = MakeCookie(r, "test@test.com")
valid, email, err := ValidateCookie(r, c) email, err := ValidateCookie(r, c)
assert.True(valid, "valid request should return valid")
assert.Nil(err, "valid request should not return an error") assert.Nil(err, "valid request should not return an error")
assert.Equal("test@test.com", email, "valid request should return user email") assert.Equal("test@test.com", email, "valid request should return user email")
} }
@ -244,8 +238,8 @@ func TestAuthMakeCookie(t *testing.T) {
assert.Equal("_forward_auth", c.Name) assert.Equal("_forward_auth", c.Name)
parts := strings.Split(c.Value, "|") parts := strings.Split(c.Value, "|")
assert.Len(parts, 3, "cookie should be 3 parts") assert.Len(parts, 3, "cookie should be 3 parts")
valid, _, _ := ValidateCookie(r, c) _, err := ValidateCookie(r, c)
assert.True(valid, "should generate valid cookie") assert.Nil(err, "should generate valid cookie")
assert.Equal("/", c.Path) assert.Equal("/", c.Path)
assert.Equal("app.example.com", c.Domain) assert.Equal("app.example.com", c.Domain)
assert.True(c.Secure) assert.True(c.Secure)

View File

@ -237,7 +237,7 @@ func TestConfigParseEnvironmentBackwardsCompatability(t *testing.T) {
"COOKIE_SECURE": "false", "COOKIE_SECURE": "false",
"COOKIE_DOMAINS": "test1.com,example.org", "COOKIE_DOMAINS": "test1.com,example.org",
"COOKIE_DOMAIN": "another1.net", "COOKIE_DOMAIN": "another1.net",
"DOMAIN": "test2.com,example.org", "DOMAIN": "test2.com,example.org",
"WHITELIST": "test3.com,example.org", "WHITELIST": "test3.com,example.org",
} }
for k, v := range vars { for k, v := range vars {

View File

@ -72,35 +72,25 @@ func (s *Server) AuthHandler(rule string) http.HandlerFunc {
// Get auth cookie // Get auth cookie
c, err := r.Cookie(config.CookieName) c, err := r.Cookie(config.CookieName)
if err != nil { if err != nil {
// Error indicates no cookie, generate nonce s.authRedirect(logger, w, r)
err, nonce := Nonce()
if err != nil {
logger.Errorf("Error generating nonce, %v", err)
http.Error(w, "Service unavailable", 503)
return
}
// Set the CSRF cookie
http.SetCookie(w, MakeCSRFCookie(r, nonce))
logger.Debug("Set CSRF cookie and redirecting to google login")
// Forward them on
http.Redirect(w, r, GetLoginURL(r, nonce), http.StatusTemporaryRedirect)
logger.Debug("Done")
return return
} }
// Validate cookie // Validate cookie
valid, email, err := ValidateCookie(r, c) email, err := ValidateCookie(r, c)
if !valid { if err != nil {
logger.Errorf("Invalid cookie: %v", err) if err.Error() == "Cookie has expired" {
http.Error(w, "Not authorized", 401) logger.Info("Cookie has expired")
s.authRedirect(logger, w, r)
} else {
logger.Errorf("Invalid cookie: %v", err)
http.Error(w, "Not authorized", 401)
}
return return
} }
// Validate user // Validate user
valid = ValidateEmail(email) valid := ValidateEmail(email)
if !valid { if !valid {
logger.WithFields(logrus.Fields{ logger.WithFields(logrus.Fields{
"email": email, "email": email,
@ -167,6 +157,26 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc {
} }
} }
func (s *Server) authRedirect(logger *logrus.Entry, w http.ResponseWriter, r *http.Request) {
// Error indicates no cookie, generate nonce
err, nonce := Nonce()
if err != nil {
logger.Errorf("Error generating nonce, %v", err)
http.Error(w, "Service unavailable", 503)
return
}
// Set the CSRF cookie
http.SetCookie(w, MakeCSRFCookie(r, nonce))
logger.Debug("Set CSRF cookie and redirecting to google login")
// Forward them on
http.Redirect(w, r, GetLoginURL(r, nonce), http.StatusTemporaryRedirect)
logger.Debug("Done")
return
}
func (s *Server) logger(r *http.Request, rule, msg string) *logrus.Entry { func (s *Server) logger(r *http.Request, rule, msg string) *logrus.Entry {
// Create logger // Create logger
logger := log.WithFields(logrus.Fields{ logger := log.WithFields(logrus.Fields{

View File

@ -8,6 +8,7 @@ import (
"net/url" "net/url"
"strings" "strings"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -27,7 +28,7 @@ func init() {
* Tests * Tests
*/ */
func TestServerAuthHandler(t *testing.T) { func TestServerAuthHandlerInvalid(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
config, _ = NewConfig([]string{}) config, _ = NewConfig([]string{})
@ -57,13 +58,46 @@ func TestServerAuthHandler(t *testing.T) {
res, _ = doHttpRequest(req, c) res, _ = doHttpRequest(req, c)
assert.Equal(401, res.StatusCode, "invalid email should not be authorised") assert.Equal(401, res.StatusCode, "invalid email should not be authorised")
}
func TestServerAuthHandlerExpired(t *testing.T) {
assert := assert.New(t)
config, _ = NewConfig([]string{})
config.Lifetime = time.Second * time.Duration(-1)
config.Domains = []string{"test.com"}
// Should redirect expired cookie
req := newDefaultHttpRequest("/foo")
c := MakeCookie(req, "test@example.com")
res, _ := doHttpRequest(req, c)
assert.Equal(307, res.StatusCode, "request with expired cookie should be redirected")
// Check for CSRF cookie
var cookie *http.Cookie
for _, c := range res.Cookies() {
if c.Name == config.CSRFCookieName {
cookie = c
}
}
assert.NotNil(cookie)
// Check redirection location
fwd, _ := res.Location()
assert.Equal("https", fwd.Scheme, "request with expired cookie should be redirected to google")
assert.Equal("accounts.google.com", fwd.Host, "request with expired cookie should be redirected to google")
assert.Equal("/o/oauth2/auth", fwd.Path, "request with expired cookie should be redirected to google")
}
func TestServerAuthHandlerValid(t *testing.T) {
assert := assert.New(t)
config, _ = NewConfig([]string{})
// Should allow valid request email // Should allow valid request email
req = newDefaultHttpRequest("/foo") req := newDefaultHttpRequest("/foo")
c = MakeCookie(req, "test@example.com") c := MakeCookie(req, "test@example.com")
config.Domains = []string{} config.Domains = []string{}
res, _ = doHttpRequest(req, c) res, _ := doHttpRequest(req, c)
assert.Equal(200, res.StatusCode, "valid request should be allowed") assert.Equal(200, res.StatusCode, "valid request should be allowed")
// Should pass through user // Should pass through user