Fix host/method rule matching + tests
This commit is contained in:
parent
6f3ac5efe5
commit
d1b12e4ffb
@ -276,6 +276,12 @@ func NewRule() *Rule {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *Rule) formattedRule() string {
|
||||||
|
// Traefik implements their own "Host" matcher and then offers "HostRegexp"
|
||||||
|
// to invoke the mux "Host" matcher. This ensures the mux version is used
|
||||||
|
return strings.ReplaceAll(r.Rule, "Host(", "HostRegexp(")
|
||||||
|
}
|
||||||
|
|
||||||
func (r *Rule) Validate() {
|
func (r *Rule) Validate() {
|
||||||
if r.Action != "auth" && r.Action != "allow" {
|
if r.Action != "auth" && r.Action != "allow" {
|
||||||
log.Fatal("invalid rule action, must be \"auth\" or \"allow\"")
|
log.Fatal("invalid rule action, must be \"auth\" or \"allow\"")
|
||||||
|
@ -28,9 +28,9 @@ func (s *Server) buildRoutes() {
|
|||||||
// Let's build a router
|
// Let's build a router
|
||||||
for name, rule := range config.Rules {
|
for name, rule := range config.Rules {
|
||||||
if rule.Action == "allow" {
|
if rule.Action == "allow" {
|
||||||
s.router.AddRoute(rule.Rule, 1, s.AllowHandler(name))
|
s.router.AddRoute(rule.formattedRule(), 1, s.AllowHandler(name))
|
||||||
} else {
|
} else {
|
||||||
s.router.AddRoute(rule.Rule, 1, s.AuthHandler(name))
|
s.router.AddRoute(rule.formattedRule(), 1, s.AuthHandler(name))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -47,6 +47,8 @@ func (s *Server) buildRoutes() {
|
|||||||
|
|
||||||
func (s *Server) RootHandler(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) RootHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
// Modify request
|
// Modify request
|
||||||
|
r.Method = r.Header.Get("X-Forwarded-Method")
|
||||||
|
r.Host = r.Header.Get("X-Forwarded-Host")
|
||||||
r.URL, _ = url.Parse(r.Header.Get("X-Forwarded-Uri"))
|
r.URL, _ = url.Parse(r.Header.Get("X-Forwarded-Uri"))
|
||||||
|
|
||||||
// Pass to mux
|
// Pass to mux
|
||||||
|
@ -32,7 +32,7 @@ func TestServerAuthHandler(t *testing.T) {
|
|||||||
config, _ = NewConfig([]string{})
|
config, _ = NewConfig([]string{})
|
||||||
|
|
||||||
// Should redirect vanilla request to login url
|
// Should redirect vanilla request to login url
|
||||||
req := newHttpRequest("/foo")
|
req := newDefaultHttpRequest("/foo")
|
||||||
res, _ := doHttpRequest(req, nil)
|
res, _ := doHttpRequest(req, nil)
|
||||||
assert.Equal(307, res.StatusCode, "vanilla request should be redirected")
|
assert.Equal(307, res.StatusCode, "vanilla request should be redirected")
|
||||||
|
|
||||||
@ -42,7 +42,7 @@ func TestServerAuthHandler(t *testing.T) {
|
|||||||
assert.Equal("/o/oauth2/auth", fwd.Path, "vanilla request should be redirected to google")
|
assert.Equal("/o/oauth2/auth", fwd.Path, "vanilla request should be redirected to google")
|
||||||
|
|
||||||
// Should catch invalid cookie
|
// Should catch invalid cookie
|
||||||
req = newHttpRequest("/foo")
|
req = newDefaultHttpRequest("/foo")
|
||||||
c := MakeCookie(req, "test@example.com")
|
c := MakeCookie(req, "test@example.com")
|
||||||
parts := strings.Split(c.Value, "|")
|
parts := strings.Split(c.Value, "|")
|
||||||
c.Value = fmt.Sprintf("bad|%s|%s", parts[1], parts[2])
|
c.Value = fmt.Sprintf("bad|%s|%s", parts[1], parts[2])
|
||||||
@ -51,7 +51,7 @@ func TestServerAuthHandler(t *testing.T) {
|
|||||||
assert.Equal(401, res.StatusCode, "invalid cookie should not be authorised")
|
assert.Equal(401, res.StatusCode, "invalid cookie should not be authorised")
|
||||||
|
|
||||||
// Should validate email
|
// Should validate email
|
||||||
req = newHttpRequest("/foo")
|
req = newDefaultHttpRequest("/foo")
|
||||||
c = MakeCookie(req, "test@example.com")
|
c = MakeCookie(req, "test@example.com")
|
||||||
config.Domains = []string{"test.com"}
|
config.Domains = []string{"test.com"}
|
||||||
|
|
||||||
@ -59,7 +59,7 @@ func TestServerAuthHandler(t *testing.T) {
|
|||||||
assert.Equal(401, res.StatusCode, "invalid email should not be authorised")
|
assert.Equal(401, res.StatusCode, "invalid email should not be authorised")
|
||||||
|
|
||||||
// Should allow valid request email
|
// Should allow valid request email
|
||||||
req = newHttpRequest("/foo")
|
req = newDefaultHttpRequest("/foo")
|
||||||
c = MakeCookie(req, "test@example.com")
|
c = MakeCookie(req, "test@example.com")
|
||||||
config.Domains = []string{}
|
config.Domains = []string{}
|
||||||
|
|
||||||
@ -91,18 +91,18 @@ func TestServerAuthCallback(t *testing.T) {
|
|||||||
config.Providers.Google.UserURL = userUrl
|
config.Providers.Google.UserURL = userUrl
|
||||||
|
|
||||||
// Should pass auth response request to callback
|
// Should pass auth response request to callback
|
||||||
req := newHttpRequest("/_oauth")
|
req := newDefaultHttpRequest("/_oauth")
|
||||||
res, _ := doHttpRequest(req, nil)
|
res, _ := doHttpRequest(req, nil)
|
||||||
assert.Equal(401, res.StatusCode, "auth callback without cookie shouldn't be authorised")
|
assert.Equal(401, res.StatusCode, "auth callback without cookie shouldn't be authorised")
|
||||||
|
|
||||||
// Should catch invalid csrf cookie
|
// Should catch invalid csrf cookie
|
||||||
req = newHttpRequest("/_oauth?state=12345678901234567890123456789012:http://redirect")
|
req = newDefaultHttpRequest("/_oauth?state=12345678901234567890123456789012:http://redirect")
|
||||||
c := MakeCSRFCookie(req, "nononononononononononononononono")
|
c := MakeCSRFCookie(req, "nononononononononononononononono")
|
||||||
res, _ = doHttpRequest(req, c)
|
res, _ = doHttpRequest(req, c)
|
||||||
assert.Equal(401, res.StatusCode, "auth callback with invalid cookie shouldn't be authorised")
|
assert.Equal(401, res.StatusCode, "auth callback with invalid cookie shouldn't be authorised")
|
||||||
|
|
||||||
// Should redirect valid request
|
// Should redirect valid request
|
||||||
req = newHttpRequest("/_oauth?state=12345678901234567890123456789012:http://redirect")
|
req = newDefaultHttpRequest("/_oauth?state=12345678901234567890123456789012:http://redirect")
|
||||||
c = MakeCSRFCookie(req, "12345678901234567890123456789012")
|
c = MakeCSRFCookie(req, "12345678901234567890123456789012")
|
||||||
res, _ = doHttpRequest(req, c)
|
res, _ = doHttpRequest(req, c)
|
||||||
assert.Equal(307, res.StatusCode, "valid auth callback should be allowed")
|
assert.Equal(307, res.StatusCode, "valid auth callback should be allowed")
|
||||||
@ -117,33 +117,151 @@ func TestServerDefaultAction(t *testing.T) {
|
|||||||
assert := assert.New(t)
|
assert := assert.New(t)
|
||||||
config, _ = NewConfig([]string{})
|
config, _ = NewConfig([]string{})
|
||||||
|
|
||||||
req := newHttpRequest("/random")
|
req := newDefaultHttpRequest("/random")
|
||||||
res, _ := doHttpRequest(req, nil)
|
res, _ := doHttpRequest(req, nil)
|
||||||
assert.Equal(307, res.StatusCode, "request should require auth with auth default handler")
|
assert.Equal(307, res.StatusCode, "request should require auth with auth default handler")
|
||||||
|
|
||||||
config.DefaultAction = "allow"
|
config.DefaultAction = "allow"
|
||||||
req = newHttpRequest("/random")
|
req = newDefaultHttpRequest("/random")
|
||||||
res, _ = doHttpRequest(req, nil)
|
res, _ = doHttpRequest(req, nil)
|
||||||
assert.Equal(200, res.StatusCode, "request should be allowed with default handler")
|
assert.Equal(200, res.StatusCode, "request should be allowed with default handler")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServerRoutePathPrefix(t *testing.T) {
|
func TestServerRouteHeaders(t *testing.T) {
|
||||||
assert := assert.New(t)
|
assert := assert.New(t)
|
||||||
config, _ = NewConfig([]string{})
|
config, _ = NewConfig([]string{})
|
||||||
config.Rules = map[string]*Rule{
|
config.Rules = map[string]*Rule{
|
||||||
"web1": {
|
"1": {
|
||||||
Action: "allow",
|
Action: "allow",
|
||||||
Rule: "PathPrefix(`/api`)",
|
Rule: "Headers(`X-Test`, `test123`)",
|
||||||
|
},
|
||||||
|
"2": {
|
||||||
|
Action: "allow",
|
||||||
|
Rule: "HeadersRegexp(`X-Test`, `test(456|789)`)",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should block any request
|
// Should block any request
|
||||||
req := newHttpRequest("/random")
|
req := newDefaultHttpRequest("/random")
|
||||||
|
req.Header.Add("X-Random", "hello")
|
||||||
|
res, _ := doHttpRequest(req, nil)
|
||||||
|
assert.Equal(307, res.StatusCode, "request not matching any rule should require auth")
|
||||||
|
|
||||||
|
// Should allow matching
|
||||||
|
req = newDefaultHttpRequest("/api")
|
||||||
|
req.Header.Add("X-Test", "test123")
|
||||||
|
res, _ = doHttpRequest(req, nil)
|
||||||
|
assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed")
|
||||||
|
|
||||||
|
// Should allow matching
|
||||||
|
req = newDefaultHttpRequest("/api")
|
||||||
|
req.Header.Add("X-Test", "test789")
|
||||||
|
res, _ = doHttpRequest(req, nil)
|
||||||
|
assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerRouteHost(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
config, _ = NewConfig([]string{})
|
||||||
|
config.Rules = map[string]*Rule{
|
||||||
|
"1": {
|
||||||
|
Action: "allow",
|
||||||
|
Rule: "Host(`api.example.com`)",
|
||||||
|
},
|
||||||
|
"2": {
|
||||||
|
Action: "allow",
|
||||||
|
Rule: "HostRegexp(`sub{num:[0-9]}.example.com`)",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should block any request
|
||||||
|
req := newHttpRequest("GET", "https://example.com/", "/")
|
||||||
|
res, _ := doHttpRequest(req, nil)
|
||||||
|
assert.Equal(307, res.StatusCode, "request not matching any rule should require auth")
|
||||||
|
|
||||||
|
// Should allow matching request
|
||||||
|
req = newHttpRequest("GET", "https://api.example.com/", "/")
|
||||||
|
res, _ = doHttpRequest(req, nil)
|
||||||
|
assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed")
|
||||||
|
|
||||||
|
// Should allow matching request
|
||||||
|
req = newHttpRequest("GET", "https://sub8.example.com/", "/")
|
||||||
|
res, _ = doHttpRequest(req, nil)
|
||||||
|
assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerRouteMethod(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
config, _ = NewConfig([]string{})
|
||||||
|
config.Rules = map[string]*Rule{
|
||||||
|
"1": {
|
||||||
|
Action: "allow",
|
||||||
|
Rule: "Method(`PUT`)",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should block any request
|
||||||
|
req := newHttpRequest("GET", "https://example.com/", "/")
|
||||||
|
res, _ := doHttpRequest(req, nil)
|
||||||
|
assert.Equal(307, res.StatusCode, "request not matching any rule should require auth")
|
||||||
|
|
||||||
|
// Should allow matching request
|
||||||
|
req = newHttpRequest("PUT", "https://example.com/", "/")
|
||||||
|
res, _ = doHttpRequest(req, nil)
|
||||||
|
assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerRoutePath(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
config, _ = NewConfig([]string{})
|
||||||
|
config.Rules = map[string]*Rule{
|
||||||
|
"1": {
|
||||||
|
Action: "allow",
|
||||||
|
Rule: "Path(`/api`)",
|
||||||
|
},
|
||||||
|
"2": {
|
||||||
|
Action: "allow",
|
||||||
|
Rule: "PathPrefix(`/private`)",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should block any request
|
||||||
|
req := newDefaultHttpRequest("/random")
|
||||||
res, _ := doHttpRequest(req, nil)
|
res, _ := doHttpRequest(req, nil)
|
||||||
assert.Equal(307, res.StatusCode, "request not matching any rule should require auth")
|
assert.Equal(307, res.StatusCode, "request not matching any rule should require auth")
|
||||||
|
|
||||||
// Should allow /api request
|
// Should allow /api request
|
||||||
req = newHttpRequest("/api")
|
req = newDefaultHttpRequest("/api")
|
||||||
|
res, _ = doHttpRequest(req, nil)
|
||||||
|
assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed")
|
||||||
|
|
||||||
|
// Should allow /private request
|
||||||
|
req = newDefaultHttpRequest("/private")
|
||||||
|
res, _ = doHttpRequest(req, nil)
|
||||||
|
assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed")
|
||||||
|
|
||||||
|
req = newDefaultHttpRequest("/private/path")
|
||||||
|
res, _ = doHttpRequest(req, nil)
|
||||||
|
assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerRouteQuery(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
config, _ = NewConfig([]string{})
|
||||||
|
config.Rules = map[string]*Rule{
|
||||||
|
"1": {
|
||||||
|
Action: "allow",
|
||||||
|
Rule: "Query(`q=test123`)",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should block any request
|
||||||
|
req := newHttpRequest("GET", "https://example.com/", "/?q=no")
|
||||||
|
res, _ := doHttpRequest(req, nil)
|
||||||
|
assert.Equal(307, res.StatusCode, "request not matching any rule should require auth")
|
||||||
|
|
||||||
|
// Should allow matching request
|
||||||
|
req = newHttpRequest("GET", "https://api.example.com/", "/?q=test123")
|
||||||
res, _ = doHttpRequest(req, nil)
|
res, _ = doHttpRequest(req, nil)
|
||||||
assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed")
|
assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed")
|
||||||
}
|
}
|
||||||
@ -194,8 +312,15 @@ func doHttpRequest(r *http.Request, c *http.Cookie) (*http.Response, string) {
|
|||||||
return res, string(body)
|
return res, string(body)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHttpRequest(uri string) *http.Request {
|
func newDefaultHttpRequest(uri string) *http.Request {
|
||||||
r := httptest.NewRequest("", "http://example.com/", nil)
|
return newHttpRequest("", "http://example.com/", uri)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHttpRequest(method, dest, uri string) *http.Request {
|
||||||
|
r := httptest.NewRequest("", "http://should-use-x-forwarded.com", nil)
|
||||||
|
p, _ := url.Parse(dest)
|
||||||
|
r.Header.Add("X-Forwarded-Method", method)
|
||||||
|
r.Header.Add("X-Forwarded-Host", p.Host)
|
||||||
r.Header.Add("X-Forwarded-Uri", uri)
|
r.Header.Add("X-Forwarded-Uri", uri)
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user