Support concurrent CSRF cookies by using a prefix of nonce (#187)
* Support concurrent CSRF cookies by using a prefix of nonce. * Move ValidateState out and make CSRF cookies last 1h * add tests to check csrf cookie nam + minor tweaks Co-authored-by: Michal Witkowski <michal@cerberus>
This commit is contained in:
parent
1743537438
commit
41560feaa7
@ -170,23 +170,31 @@ func ClearCookie(r *http.Request) *http.Cookie {
|
||||
}
|
||||
}
|
||||
|
||||
func buildCSRFCookieName(nonce string) string {
|
||||
return config.CSRFCookieName + "_" + nonce[:6]
|
||||
}
|
||||
|
||||
// MakeCSRFCookie makes a csrf cookie (used during login only)
|
||||
//
|
||||
// Note, CSRF cookies live shorter than auth cookies, a fixed 1h.
|
||||
// That's because some CSRF cookies may belong to auth flows that don't complete
|
||||
// and thus may not get cleared by ClearCookie.
|
||||
func MakeCSRFCookie(r *http.Request, nonce string) *http.Cookie {
|
||||
return &http.Cookie{
|
||||
Name: config.CSRFCookieName,
|
||||
Name: buildCSRFCookieName(nonce),
|
||||
Value: nonce,
|
||||
Path: "/",
|
||||
Domain: csrfCookieDomain(r),
|
||||
HttpOnly: true,
|
||||
Secure: !config.InsecureCookie,
|
||||
Expires: cookieExpiry(),
|
||||
Expires: time.Now().Local().Add(time.Hour * 1),
|
||||
}
|
||||
}
|
||||
|
||||
// ClearCSRFCookie makes an expired csrf cookie to clear csrf cookie
|
||||
func ClearCSRFCookie(r *http.Request) *http.Cookie {
|
||||
func ClearCSRFCookie(r *http.Request, c *http.Cookie) *http.Cookie {
|
||||
return &http.Cookie{
|
||||
Name: config.CSRFCookieName,
|
||||
Name: c.Name,
|
||||
Value: "",
|
||||
Path: "/",
|
||||
Domain: csrfCookieDomain(r),
|
||||
@ -196,18 +204,18 @@ func ClearCSRFCookie(r *http.Request) *http.Cookie {
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateCSRFCookie validates the csrf cookie against state
|
||||
func ValidateCSRFCookie(r *http.Request, c *http.Cookie) (valid bool, provider string, redirect string, err error) {
|
||||
state := r.URL.Query().Get("state")
|
||||
// FindCSRFCookie extracts the CSRF cookie from the request based on state.
|
||||
func FindCSRFCookie(r *http.Request, state string) (c *http.Cookie, err error) {
|
||||
// Check for CSRF cookie
|
||||
return r.Cookie(buildCSRFCookieName(state))
|
||||
}
|
||||
|
||||
// ValidateCSRFCookie validates the csrf cookie against state
|
||||
func ValidateCSRFCookie(c *http.Cookie, state string) (valid bool, provider string, redirect string, err error) {
|
||||
if len(c.Value) != 32 {
|
||||
return false, "", "", errors.New("Invalid CSRF cookie value")
|
||||
}
|
||||
|
||||
if len(state) < 34 {
|
||||
return false, "", "", errors.New("Invalid CSRF state value")
|
||||
}
|
||||
|
||||
// Check nonce match
|
||||
if c.Value != state[:32] {
|
||||
return false, "", "", errors.New("CSRF cookie does not match state")
|
||||
@ -229,6 +237,14 @@ func MakeState(r *http.Request, p provider.Provider, nonce string) string {
|
||||
return fmt.Sprintf("%s:%s:%s", nonce, p.Name(), returnUrl(r))
|
||||
}
|
||||
|
||||
// ValidateState checks whether the state is of right length.
|
||||
func ValidateState(state string) error {
|
||||
if len(state) < 34 {
|
||||
return errors.New("Invalid CSRF state value")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Nonce generates a random nonce
|
||||
func Nonce() (error, string) {
|
||||
nonce := make([]byte, 16)
|
||||
|
@ -1,7 +1,6 @@
|
||||
package tfa
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
@ -217,29 +216,30 @@ func TestAuthMakeCSRFCookie(t *testing.T) {
|
||||
|
||||
// No cookie domain or auth url
|
||||
c := MakeCSRFCookie(r, "12345678901234567890123456789012")
|
||||
assert.Equal("_forward_auth_csrf_123456", c.Name)
|
||||
assert.Equal("app.example.com", c.Domain)
|
||||
|
||||
// With cookie domain but no auth url
|
||||
config = &Config{
|
||||
CookieDomains: []CookieDomain{*NewCookieDomain("example.com")},
|
||||
}
|
||||
c = MakeCSRFCookie(r, "12345678901234567890123456789012")
|
||||
config.CookieDomains = []CookieDomain{*NewCookieDomain("example.com")}
|
||||
c = MakeCSRFCookie(r, "12222278901234567890123456789012")
|
||||
assert.Equal("_forward_auth_csrf_122222", c.Name)
|
||||
assert.Equal("app.example.com", c.Domain)
|
||||
|
||||
// With cookie domain and auth url
|
||||
config = &Config{
|
||||
AuthHost: "auth.example.com",
|
||||
CookieDomains: []CookieDomain{*NewCookieDomain("example.com")},
|
||||
}
|
||||
c = MakeCSRFCookie(r, "12345678901234567890123456789012")
|
||||
config.AuthHost = "auth.example.com"
|
||||
config.CookieDomains = []CookieDomain{*NewCookieDomain("example.com")}
|
||||
c = MakeCSRFCookie(r, "12333378901234567890123456789012")
|
||||
assert.Equal("_forward_auth_csrf_123333", c.Name)
|
||||
assert.Equal("example.com", c.Domain)
|
||||
}
|
||||
|
||||
func TestAuthClearCSRFCookie(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
config, _ = NewConfig([]string{})
|
||||
r, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
|
||||
c := ClearCSRFCookie(r)
|
||||
c := ClearCSRFCookie(r, &http.Cookie{Name: "someCsrfCookie"})
|
||||
assert.Equal("someCsrfCookie", c.Name)
|
||||
if c.Value != "" {
|
||||
t.Error("ClearCSRFCookie should create cookie with empty value")
|
||||
}
|
||||
@ -249,56 +249,57 @@ func TestAuthValidateCSRFCookie(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
config, _ = NewConfig([]string{})
|
||||
c := &http.Cookie{}
|
||||
|
||||
newCsrfRequest := func(state string) *http.Request {
|
||||
u := fmt.Sprintf("http://example.com?state=%s", state)
|
||||
r, _ := http.NewRequest("GET", u, nil)
|
||||
return r
|
||||
}
|
||||
state := ""
|
||||
|
||||
// Should require 32 char string
|
||||
r := newCsrfRequest("")
|
||||
state = ""
|
||||
c.Value = ""
|
||||
valid, _, _, err := ValidateCSRFCookie(r, c)
|
||||
valid, _, _, err := ValidateCSRFCookie(c, state)
|
||||
assert.False(valid)
|
||||
if assert.Error(err) {
|
||||
assert.Equal("Invalid CSRF cookie value", err.Error())
|
||||
}
|
||||
c.Value = "123456789012345678901234567890123"
|
||||
valid, _, _, err = ValidateCSRFCookie(r, c)
|
||||
valid, _, _, err = ValidateCSRFCookie(c, state)
|
||||
assert.False(valid)
|
||||
if assert.Error(err) {
|
||||
assert.Equal("Invalid CSRF cookie value", err.Error())
|
||||
}
|
||||
|
||||
// Should require valid state
|
||||
r = newCsrfRequest("12345678901234567890123456789012:")
|
||||
c.Value = "12345678901234567890123456789012"
|
||||
valid, _, _, err = ValidateCSRFCookie(r, c)
|
||||
assert.False(valid)
|
||||
if assert.Error(err) {
|
||||
assert.Equal("Invalid CSRF state value", err.Error())
|
||||
}
|
||||
|
||||
// Should require provider
|
||||
r = newCsrfRequest("12345678901234567890123456789012:99")
|
||||
state = "12345678901234567890123456789012:99"
|
||||
c.Value = "12345678901234567890123456789012"
|
||||
valid, _, _, err = ValidateCSRFCookie(r, c)
|
||||
valid, _, _, err = ValidateCSRFCookie(c, state)
|
||||
assert.False(valid)
|
||||
if assert.Error(err) {
|
||||
assert.Equal("Invalid CSRF state format", err.Error())
|
||||
}
|
||||
|
||||
// Should allow valid state
|
||||
r = newCsrfRequest("12345678901234567890123456789012:p99:url123")
|
||||
state = "12345678901234567890123456789012:p99:url123"
|
||||
c.Value = "12345678901234567890123456789012"
|
||||
valid, provider, redirect, err := ValidateCSRFCookie(r, c)
|
||||
valid, provider, redirect, err := ValidateCSRFCookie(c, state)
|
||||
assert.True(valid, "valid request should return valid")
|
||||
assert.Nil(err, "valid request should not return an error")
|
||||
assert.Equal("p99", provider, "valid request should return correct provider")
|
||||
assert.Equal("url123", redirect, "valid request should return correct redirect")
|
||||
}
|
||||
|
||||
func TestValidateState(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
// Should require valid state
|
||||
state := "12345678901234567890123456789012:"
|
||||
err := ValidateState(state)
|
||||
if assert.Error(err) {
|
||||
assert.Equal("Invalid CSRF state value", err.Error())
|
||||
}
|
||||
// Should pass this state
|
||||
state = "12345678901234567890123456789012:p99:url123"
|
||||
err = ValidateState(state)
|
||||
assert.Nil(err, "valid request should not return an error")
|
||||
}
|
||||
|
||||
func TestMakeState(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
|
@ -121,16 +121,26 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc {
|
||||
// Logging setup
|
||||
logger := s.logger(r, "AuthCallback", "default", "Handling callback")
|
||||
|
||||
// Check state
|
||||
state := r.URL.Query().Get("state")
|
||||
if err := ValidateState(state); err != nil {
|
||||
logger.WithFields(logrus.Fields{
|
||||
"error": err,
|
||||
}).Warn("Error validating state")
|
||||
http.Error(w, "Not authorized", 401)
|
||||
return
|
||||
}
|
||||
|
||||
// Check for CSRF cookie
|
||||
c, err := r.Cookie(config.CSRFCookieName)
|
||||
c, err := FindCSRFCookie(r, state)
|
||||
if err != nil {
|
||||
logger.Info("Missing csrf cookie")
|
||||
http.Error(w, "Not authorized", 401)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate state
|
||||
valid, providerName, redirect, err := ValidateCSRFCookie(r, c)
|
||||
// Validate CSRF cookie against state
|
||||
valid, providerName, redirect, err := ValidateCSRFCookie(c, state)
|
||||
if !valid {
|
||||
logger.WithFields(logrus.Fields{
|
||||
"error": err,
|
||||
@ -153,7 +163,7 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc {
|
||||
}
|
||||
|
||||
// Clear CSRF cookie
|
||||
http.SetCookie(w, ClearCSRFCookie(r))
|
||||
http.SetCookie(w, ClearCSRFCookie(r, c))
|
||||
|
||||
// Exchange code for token
|
||||
token, err := p.ExchangeCode(redirectUri(r), r.URL.Query().Get("code"))
|
||||
|
@ -98,7 +98,7 @@ func TestServerAuthHandlerExpired(t *testing.T) {
|
||||
// Check for CSRF cookie
|
||||
var cookie *http.Cookie
|
||||
for _, c := range res.Cookies() {
|
||||
if c.Name == config.CSRFCookieName {
|
||||
if strings.HasPrefix(c.Name, config.CSRFCookieName) {
|
||||
cookie = c
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user