diff --git a/internal/config.go b/internal/config.go index 7692500..123271f 100644 --- a/internal/config.go +++ b/internal/config.go @@ -276,6 +276,12 @@ func NewRule() *Rule { } } +func (r *Rule) formattedRule() string { + // Traefik implements their own "Host" matcher and then offers "HostRegexp" + // to invoke the mux "Host" matcher. This ensures the mux version is used + return strings.ReplaceAll(r.Rule, "Host(", "HostRegexp(") +} + func (r *Rule) Validate() { if r.Action != "auth" && r.Action != "allow" { log.Fatal("invalid rule action, must be \"auth\" or \"allow\"") diff --git a/internal/server.go b/internal/server.go index 81edd78..1de9480 100644 --- a/internal/server.go +++ b/internal/server.go @@ -28,9 +28,9 @@ func (s *Server) buildRoutes() { // Let's build a router for name, rule := range config.Rules { if rule.Action == "allow" { - s.router.AddRoute(rule.Rule, 1, s.AllowHandler(name)) + s.router.AddRoute(rule.formattedRule(), 1, s.AllowHandler(name)) } else { - s.router.AddRoute(rule.Rule, 1, s.AuthHandler(name)) + s.router.AddRoute(rule.formattedRule(), 1, s.AuthHandler(name)) } } @@ -47,6 +47,8 @@ func (s *Server) buildRoutes() { func (s *Server) RootHandler(w http.ResponseWriter, r *http.Request) { // Modify request + r.Method = r.Header.Get("X-Forwarded-Method") + r.Host = r.Header.Get("X-Forwarded-Host") r.URL, _ = url.Parse(r.Header.Get("X-Forwarded-Uri")) // Pass to mux diff --git a/internal/server_test.go b/internal/server_test.go index b2d3bb3..816cb4c 100644 --- a/internal/server_test.go +++ b/internal/server_test.go @@ -32,7 +32,7 @@ func TestServerAuthHandler(t *testing.T) { config, _ = NewConfig([]string{}) // Should redirect vanilla request to login url - req := newHttpRequest("/foo") + req := newDefaultHttpRequest("/foo") res, _ := doHttpRequest(req, nil) assert.Equal(307, res.StatusCode, "vanilla request should be redirected") @@ -42,7 +42,7 @@ func TestServerAuthHandler(t *testing.T) { assert.Equal("/o/oauth2/auth", fwd.Path, "vanilla request should be redirected to google") // Should catch invalid cookie - req = newHttpRequest("/foo") + req = newDefaultHttpRequest("/foo") c := MakeCookie(req, "test@example.com") parts := strings.Split(c.Value, "|") c.Value = fmt.Sprintf("bad|%s|%s", parts[1], parts[2]) @@ -51,7 +51,7 @@ func TestServerAuthHandler(t *testing.T) { assert.Equal(401, res.StatusCode, "invalid cookie should not be authorised") // Should validate email - req = newHttpRequest("/foo") + req = newDefaultHttpRequest("/foo") c = MakeCookie(req, "test@example.com") config.Domains = []string{"test.com"} @@ -59,7 +59,7 @@ func TestServerAuthHandler(t *testing.T) { assert.Equal(401, res.StatusCode, "invalid email should not be authorised") // Should allow valid request email - req = newHttpRequest("/foo") + req = newDefaultHttpRequest("/foo") c = MakeCookie(req, "test@example.com") config.Domains = []string{} @@ -91,18 +91,18 @@ func TestServerAuthCallback(t *testing.T) { config.Providers.Google.UserURL = userUrl // Should pass auth response request to callback - req := newHttpRequest("/_oauth") + req := newDefaultHttpRequest("/_oauth") res, _ := doHttpRequest(req, nil) assert.Equal(401, res.StatusCode, "auth callback without cookie shouldn't be authorised") // Should catch invalid csrf cookie - req = newHttpRequest("/_oauth?state=12345678901234567890123456789012:http://redirect") + req = newDefaultHttpRequest("/_oauth?state=12345678901234567890123456789012:http://redirect") c := MakeCSRFCookie(req, "nononononononononononononononono") res, _ = doHttpRequest(req, c) assert.Equal(401, res.StatusCode, "auth callback with invalid cookie shouldn't be authorised") // Should redirect valid request - req = newHttpRequest("/_oauth?state=12345678901234567890123456789012:http://redirect") + req = newDefaultHttpRequest("/_oauth?state=12345678901234567890123456789012:http://redirect") c = MakeCSRFCookie(req, "12345678901234567890123456789012") res, _ = doHttpRequest(req, c) assert.Equal(307, res.StatusCode, "valid auth callback should be allowed") @@ -117,33 +117,151 @@ func TestServerDefaultAction(t *testing.T) { assert := assert.New(t) config, _ = NewConfig([]string{}) - req := newHttpRequest("/random") + req := newDefaultHttpRequest("/random") res, _ := doHttpRequest(req, nil) assert.Equal(307, res.StatusCode, "request should require auth with auth default handler") config.DefaultAction = "allow" - req = newHttpRequest("/random") + req = newDefaultHttpRequest("/random") res, _ = doHttpRequest(req, nil) assert.Equal(200, res.StatusCode, "request should be allowed with default handler") } -func TestServerRoutePathPrefix(t *testing.T) { +func TestServerRouteHeaders(t *testing.T) { assert := assert.New(t) config, _ = NewConfig([]string{}) config.Rules = map[string]*Rule{ - "web1": { + "1": { Action: "allow", - Rule: "PathPrefix(`/api`)", + Rule: "Headers(`X-Test`, `test123`)", + }, + "2": { + Action: "allow", + Rule: "HeadersRegexp(`X-Test`, `test(456|789)`)", }, } // Should block any request - req := newHttpRequest("/random") + req := newDefaultHttpRequest("/random") + req.Header.Add("X-Random", "hello") + res, _ := doHttpRequest(req, nil) + assert.Equal(307, res.StatusCode, "request not matching any rule should require auth") + + // Should allow matching + req = newDefaultHttpRequest("/api") + req.Header.Add("X-Test", "test123") + res, _ = doHttpRequest(req, nil) + assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed") + + // Should allow matching + req = newDefaultHttpRequest("/api") + req.Header.Add("X-Test", "test789") + res, _ = doHttpRequest(req, nil) + assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed") +} + +func TestServerRouteHost(t *testing.T) { + assert := assert.New(t) + config, _ = NewConfig([]string{}) + config.Rules = map[string]*Rule{ + "1": { + Action: "allow", + Rule: "Host(`api.example.com`)", + }, + "2": { + Action: "allow", + Rule: "HostRegexp(`sub{num:[0-9]}.example.com`)", + }, + } + + // Should block any request + req := newHttpRequest("GET", "https://example.com/", "/") + res, _ := doHttpRequest(req, nil) + assert.Equal(307, res.StatusCode, "request not matching any rule should require auth") + + // Should allow matching request + req = newHttpRequest("GET", "https://api.example.com/", "/") + res, _ = doHttpRequest(req, nil) + assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed") + + // Should allow matching request + req = newHttpRequest("GET", "https://sub8.example.com/", "/") + res, _ = doHttpRequest(req, nil) + assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed") +} + +func TestServerRouteMethod(t *testing.T) { + assert := assert.New(t) + config, _ = NewConfig([]string{}) + config.Rules = map[string]*Rule{ + "1": { + Action: "allow", + Rule: "Method(`PUT`)", + }, + } + + // Should block any request + req := newHttpRequest("GET", "https://example.com/", "/") + res, _ := doHttpRequest(req, nil) + assert.Equal(307, res.StatusCode, "request not matching any rule should require auth") + + // Should allow matching request + req = newHttpRequest("PUT", "https://example.com/", "/") + res, _ = doHttpRequest(req, nil) + assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed") +} + +func TestServerRoutePath(t *testing.T) { + assert := assert.New(t) + config, _ = NewConfig([]string{}) + config.Rules = map[string]*Rule{ + "1": { + Action: "allow", + Rule: "Path(`/api`)", + }, + "2": { + Action: "allow", + Rule: "PathPrefix(`/private`)", + }, + } + + // Should block any request + req := newDefaultHttpRequest("/random") res, _ := doHttpRequest(req, nil) assert.Equal(307, res.StatusCode, "request not matching any rule should require auth") // Should allow /api request - req = newHttpRequest("/api") + req = newDefaultHttpRequest("/api") + res, _ = doHttpRequest(req, nil) + assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed") + + // Should allow /private request + req = newDefaultHttpRequest("/private") + res, _ = doHttpRequest(req, nil) + assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed") + + req = newDefaultHttpRequest("/private/path") + res, _ = doHttpRequest(req, nil) + assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed") +} + +func TestServerRouteQuery(t *testing.T) { + assert := assert.New(t) + config, _ = NewConfig([]string{}) + config.Rules = map[string]*Rule{ + "1": { + Action: "allow", + Rule: "Query(`q=test123`)", + }, + } + + // Should block any request + req := newHttpRequest("GET", "https://example.com/", "/?q=no") + res, _ := doHttpRequest(req, nil) + assert.Equal(307, res.StatusCode, "request not matching any rule should require auth") + + // Should allow matching request + req = newHttpRequest("GET", "https://api.example.com/", "/?q=test123") res, _ = doHttpRequest(req, nil) assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed") } @@ -194,8 +312,15 @@ func doHttpRequest(r *http.Request, c *http.Cookie) (*http.Response, string) { return res, string(body) } -func newHttpRequest(uri string) *http.Request { - r := httptest.NewRequest("", "http://example.com/", nil) +func newDefaultHttpRequest(uri string) *http.Request { + return newHttpRequest("", "http://example.com/", uri) +} + +func newHttpRequest(method, dest, uri string) *http.Request { + r := httptest.NewRequest("", "http://should-use-x-forwarded.com", nil) + p, _ := url.Parse(dest) + r.Header.Add("X-Forwarded-Method", method) + r.Header.Add("X-Forwarded-Host", p.Host) r.Header.Add("X-Forwarded-Uri", uri) return r }