Support concurrent CSRF cookies by using a prefix of nonce (#187)

* Support concurrent CSRF cookies by using a prefix of nonce.
* Move ValidateState out and make CSRF cookies last 1h
* add tests to check csrf cookie nam + minor tweaks

Co-authored-by: Michal Witkowski <michal@cerberus>
This commit is contained in:
Thom Seddon 2020-09-23 14:48:04 +01:00 committed by GitHub
parent 1743537438
commit 41560feaa7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 76 additions and 49 deletions

View File

@ -170,23 +170,31 @@ func ClearCookie(r *http.Request) *http.Cookie {
} }
} }
func buildCSRFCookieName(nonce string) string {
return config.CSRFCookieName + "_" + nonce[:6]
}
// MakeCSRFCookie makes a csrf cookie (used during login only) // MakeCSRFCookie makes a csrf cookie (used during login only)
//
// Note, CSRF cookies live shorter than auth cookies, a fixed 1h.
// That's because some CSRF cookies may belong to auth flows that don't complete
// and thus may not get cleared by ClearCookie.
func MakeCSRFCookie(r *http.Request, nonce string) *http.Cookie { func MakeCSRFCookie(r *http.Request, nonce string) *http.Cookie {
return &http.Cookie{ return &http.Cookie{
Name: config.CSRFCookieName, Name: buildCSRFCookieName(nonce),
Value: nonce, Value: nonce,
Path: "/", Path: "/",
Domain: csrfCookieDomain(r), Domain: csrfCookieDomain(r),
HttpOnly: true, HttpOnly: true,
Secure: !config.InsecureCookie, Secure: !config.InsecureCookie,
Expires: cookieExpiry(), Expires: time.Now().Local().Add(time.Hour * 1),
} }
} }
// ClearCSRFCookie makes an expired csrf cookie to clear csrf cookie // ClearCSRFCookie makes an expired csrf cookie to clear csrf cookie
func ClearCSRFCookie(r *http.Request) *http.Cookie { func ClearCSRFCookie(r *http.Request, c *http.Cookie) *http.Cookie {
return &http.Cookie{ return &http.Cookie{
Name: config.CSRFCookieName, Name: c.Name,
Value: "", Value: "",
Path: "/", Path: "/",
Domain: csrfCookieDomain(r), Domain: csrfCookieDomain(r),
@ -196,16 +204,16 @@ func ClearCSRFCookie(r *http.Request) *http.Cookie {
} }
} }
// ValidateCSRFCookie validates the csrf cookie against state // FindCSRFCookie extracts the CSRF cookie from the request based on state.
func ValidateCSRFCookie(r *http.Request, c *http.Cookie) (valid bool, provider string, redirect string, err error) { func FindCSRFCookie(r *http.Request, state string) (c *http.Cookie, err error) {
state := r.URL.Query().Get("state") // Check for CSRF cookie
return r.Cookie(buildCSRFCookieName(state))
if len(c.Value) != 32 {
return false, "", "", errors.New("Invalid CSRF cookie value")
} }
if len(state) < 34 { // ValidateCSRFCookie validates the csrf cookie against state
return false, "", "", errors.New("Invalid CSRF state value") func ValidateCSRFCookie(c *http.Cookie, state string) (valid bool, provider string, redirect string, err error) {
if len(c.Value) != 32 {
return false, "", "", errors.New("Invalid CSRF cookie value")
} }
// Check nonce match // Check nonce match
@ -229,6 +237,14 @@ func MakeState(r *http.Request, p provider.Provider, nonce string) string {
return fmt.Sprintf("%s:%s:%s", nonce, p.Name(), returnUrl(r)) return fmt.Sprintf("%s:%s:%s", nonce, p.Name(), returnUrl(r))
} }
// ValidateState checks whether the state is of right length.
func ValidateState(state string) error {
if len(state) < 34 {
return errors.New("Invalid CSRF state value")
}
return nil
}
// Nonce generates a random nonce // Nonce generates a random nonce
func Nonce() (error, string) { func Nonce() (error, string) {
nonce := make([]byte, 16) nonce := make([]byte, 16)

View File

@ -1,7 +1,6 @@
package tfa package tfa
import ( import (
"fmt"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
@ -217,29 +216,30 @@ func TestAuthMakeCSRFCookie(t *testing.T) {
// No cookie domain or auth url // No cookie domain or auth url
c := MakeCSRFCookie(r, "12345678901234567890123456789012") c := MakeCSRFCookie(r, "12345678901234567890123456789012")
assert.Equal("_forward_auth_csrf_123456", c.Name)
assert.Equal("app.example.com", c.Domain) assert.Equal("app.example.com", c.Domain)
// With cookie domain but no auth url // With cookie domain but no auth url
config = &Config{ config.CookieDomains = []CookieDomain{*NewCookieDomain("example.com")}
CookieDomains: []CookieDomain{*NewCookieDomain("example.com")}, c = MakeCSRFCookie(r, "12222278901234567890123456789012")
} assert.Equal("_forward_auth_csrf_122222", c.Name)
c = MakeCSRFCookie(r, "12345678901234567890123456789012")
assert.Equal("app.example.com", c.Domain) assert.Equal("app.example.com", c.Domain)
// With cookie domain and auth url // With cookie domain and auth url
config = &Config{ config.AuthHost = "auth.example.com"
AuthHost: "auth.example.com", config.CookieDomains = []CookieDomain{*NewCookieDomain("example.com")}
CookieDomains: []CookieDomain{*NewCookieDomain("example.com")}, c = MakeCSRFCookie(r, "12333378901234567890123456789012")
} assert.Equal("_forward_auth_csrf_123333", c.Name)
c = MakeCSRFCookie(r, "12345678901234567890123456789012")
assert.Equal("example.com", c.Domain) assert.Equal("example.com", c.Domain)
} }
func TestAuthClearCSRFCookie(t *testing.T) { func TestAuthClearCSRFCookie(t *testing.T) {
assert := assert.New(t)
config, _ = NewConfig([]string{}) config, _ = NewConfig([]string{})
r, _ := http.NewRequest("GET", "http://example.com", nil) r, _ := http.NewRequest("GET", "http://example.com", nil)
c := ClearCSRFCookie(r) c := ClearCSRFCookie(r, &http.Cookie{Name: "someCsrfCookie"})
assert.Equal("someCsrfCookie", c.Name)
if c.Value != "" { if c.Value != "" {
t.Error("ClearCSRFCookie should create cookie with empty value") t.Error("ClearCSRFCookie should create cookie with empty value")
} }
@ -249,56 +249,57 @@ func TestAuthValidateCSRFCookie(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
config, _ = NewConfig([]string{}) config, _ = NewConfig([]string{})
c := &http.Cookie{} c := &http.Cookie{}
state := ""
newCsrfRequest := func(state string) *http.Request {
u := fmt.Sprintf("http://example.com?state=%s", state)
r, _ := http.NewRequest("GET", u, nil)
return r
}
// Should require 32 char string // Should require 32 char string
r := newCsrfRequest("") state = ""
c.Value = "" c.Value = ""
valid, _, _, err := ValidateCSRFCookie(r, c) valid, _, _, err := ValidateCSRFCookie(c, state)
assert.False(valid) assert.False(valid)
if assert.Error(err) { if assert.Error(err) {
assert.Equal("Invalid CSRF cookie value", err.Error()) assert.Equal("Invalid CSRF cookie value", err.Error())
} }
c.Value = "123456789012345678901234567890123" c.Value = "123456789012345678901234567890123"
valid, _, _, err = ValidateCSRFCookie(r, c) valid, _, _, err = ValidateCSRFCookie(c, state)
assert.False(valid) assert.False(valid)
if assert.Error(err) { if assert.Error(err) {
assert.Equal("Invalid CSRF cookie value", err.Error()) assert.Equal("Invalid CSRF cookie value", err.Error())
} }
// Should require valid state
r = newCsrfRequest("12345678901234567890123456789012:")
c.Value = "12345678901234567890123456789012"
valid, _, _, err = ValidateCSRFCookie(r, c)
assert.False(valid)
if assert.Error(err) {
assert.Equal("Invalid CSRF state value", err.Error())
}
// Should require provider // Should require provider
r = newCsrfRequest("12345678901234567890123456789012:99") state = "12345678901234567890123456789012:99"
c.Value = "12345678901234567890123456789012" c.Value = "12345678901234567890123456789012"
valid, _, _, err = ValidateCSRFCookie(r, c) valid, _, _, err = ValidateCSRFCookie(c, state)
assert.False(valid) assert.False(valid)
if assert.Error(err) { if assert.Error(err) {
assert.Equal("Invalid CSRF state format", err.Error()) assert.Equal("Invalid CSRF state format", err.Error())
} }
// Should allow valid state // Should allow valid state
r = newCsrfRequest("12345678901234567890123456789012:p99:url123") state = "12345678901234567890123456789012:p99:url123"
c.Value = "12345678901234567890123456789012" c.Value = "12345678901234567890123456789012"
valid, provider, redirect, err := ValidateCSRFCookie(r, c) valid, provider, redirect, err := ValidateCSRFCookie(c, state)
assert.True(valid, "valid request should return valid") assert.True(valid, "valid request should return valid")
assert.Nil(err, "valid request should not return an error") assert.Nil(err, "valid request should not return an error")
assert.Equal("p99", provider, "valid request should return correct provider") assert.Equal("p99", provider, "valid request should return correct provider")
assert.Equal("url123", redirect, "valid request should return correct redirect") assert.Equal("url123", redirect, "valid request should return correct redirect")
} }
func TestValidateState(t *testing.T) {
assert := assert.New(t)
// Should require valid state
state := "12345678901234567890123456789012:"
err := ValidateState(state)
if assert.Error(err) {
assert.Equal("Invalid CSRF state value", err.Error())
}
// Should pass this state
state = "12345678901234567890123456789012:p99:url123"
err = ValidateState(state)
assert.Nil(err, "valid request should not return an error")
}
func TestMakeState(t *testing.T) { func TestMakeState(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)

View File

@ -121,16 +121,26 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc {
// Logging setup // Logging setup
logger := s.logger(r, "AuthCallback", "default", "Handling callback") logger := s.logger(r, "AuthCallback", "default", "Handling callback")
// Check state
state := r.URL.Query().Get("state")
if err := ValidateState(state); err != nil {
logger.WithFields(logrus.Fields{
"error": err,
}).Warn("Error validating state")
http.Error(w, "Not authorized", 401)
return
}
// Check for CSRF cookie // Check for CSRF cookie
c, err := r.Cookie(config.CSRFCookieName) c, err := FindCSRFCookie(r, state)
if err != nil { if err != nil {
logger.Info("Missing csrf cookie") logger.Info("Missing csrf cookie")
http.Error(w, "Not authorized", 401) http.Error(w, "Not authorized", 401)
return return
} }
// Validate state // Validate CSRF cookie against state
valid, providerName, redirect, err := ValidateCSRFCookie(r, c) valid, providerName, redirect, err := ValidateCSRFCookie(c, state)
if !valid { if !valid {
logger.WithFields(logrus.Fields{ logger.WithFields(logrus.Fields{
"error": err, "error": err,
@ -153,7 +163,7 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc {
} }
// Clear CSRF cookie // Clear CSRF cookie
http.SetCookie(w, ClearCSRFCookie(r)) http.SetCookie(w, ClearCSRFCookie(r, c))
// Exchange code for token // Exchange code for token
token, err := p.ExchangeCode(redirectUri(r), r.URL.Query().Get("code")) token, err := p.ExchangeCode(redirectUri(r), r.URL.Query().Get("code"))

View File

@ -98,7 +98,7 @@ func TestServerAuthHandlerExpired(t *testing.T) {
// Check for CSRF cookie // Check for CSRF cookie
var cookie *http.Cookie var cookie *http.Cookie
for _, c := range res.Cookies() { for _, c := range res.Cookies() {
if c.Name == config.CSRFCookieName { if strings.HasPrefix(c.Name, config.CSRFCookieName) {
cookie = c cookie = c
} }
} }