Refactor progress
- move directory structure - string based rule definition - use traefik rule parsing - drop toml config - new flag library - implement go dep
This commit is contained in:
parent
d51b93d4b0
commit
9abe509f66
@ -1,10 +1,5 @@
|
|||||||
language: go
|
language: go
|
||||||
sudo: false
|
sudo: false
|
||||||
go:
|
go:
|
||||||
- "1.10"
|
- "1.12"
|
||||||
install:
|
script: env GO111MODULE=on go test -v ./...
|
||||||
- go get github.com/BurntSushi/toml
|
|
||||||
- go get github.com/gorilla/mux
|
|
||||||
- go get github.com/namsral/flag
|
|
||||||
- go get github.com/sirupsen/logrus
|
|
||||||
script: go test -v ./...
|
|
||||||
|
2
Makefile
2
Makefile
@ -1,5 +1,5 @@
|
|||||||
|
|
||||||
format:
|
format:
|
||||||
gofmt -w -s *.go provider/*.go
|
gofmt -w -s internal/*.go cmd/*.go
|
||||||
|
|
||||||
.PHONY: format
|
.PHONY: format
|
||||||
|
30
cmd/main.go
Normal file
30
cmd/main.go
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
internal "github.com/thomseddon/traefik-forward-auth/internal"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Main
|
||||||
|
func main() {
|
||||||
|
// Parse options
|
||||||
|
config := internal.NewGlobalConfig()
|
||||||
|
|
||||||
|
// Setup logger
|
||||||
|
log := internal.NewDefaultLogger()
|
||||||
|
|
||||||
|
// Perform config checks
|
||||||
|
config.Checks()
|
||||||
|
|
||||||
|
// Build server
|
||||||
|
server := internal.NewServer()
|
||||||
|
|
||||||
|
// Attach router to default server
|
||||||
|
http.HandleFunc("/", server.RootHandler)
|
||||||
|
|
||||||
|
// Start
|
||||||
|
log.Debugf("Starting with options: %s", config.Serialise())
|
||||||
|
log.Info("Listening on :4181")
|
||||||
|
log.Info(http.ListenAndServe(":4181", nil))
|
||||||
|
}
|
204
config.go
204
config.go
@ -1,204 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net/url"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/BurntSushi/toml"
|
|
||||||
"github.com/namsral/flag"
|
|
||||||
"github.com/thomseddon/traefik-forward-auth/provider"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Config struct {
|
|
||||||
DefaultAction string
|
|
||||||
Path string
|
|
||||||
Lifetime time.Duration
|
|
||||||
Secret string
|
|
||||||
SecretBytes []byte
|
|
||||||
AuthHost string
|
|
||||||
|
|
||||||
LogLevel *string
|
|
||||||
LogFormat *string
|
|
||||||
TomlConfig *string // temp
|
|
||||||
|
|
||||||
CookieName string
|
|
||||||
CookieDomains []CookieDomain
|
|
||||||
CSRFCookieName string
|
|
||||||
CookieSecure bool
|
|
||||||
|
|
||||||
Domain []string
|
|
||||||
Whitelist []string
|
|
||||||
|
|
||||||
Providers provider.Providers
|
|
||||||
Rules map[string]Rules
|
|
||||||
}
|
|
||||||
|
|
||||||
type Rules struct {
|
|
||||||
Action string
|
|
||||||
Match []Match
|
|
||||||
}
|
|
||||||
|
|
||||||
type Match struct {
|
|
||||||
Host []string
|
|
||||||
PathPrefix []string
|
|
||||||
Header [][]string
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewConfig() *Config {
|
|
||||||
c := &Config{}
|
|
||||||
c.parseFlags()
|
|
||||||
c.applyDefaults()
|
|
||||||
return c
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Fix
|
|
||||||
// At the moment any flag value will overwrite the toml config
|
|
||||||
// Need to put the flag default values in applyDefaults & empty the flag
|
|
||||||
// defaults so we can check if they're being passed and set accordingly
|
|
||||||
// Ideally we also need to remove the two calls to parseFlags
|
|
||||||
//
|
|
||||||
// We also need to check the default -config flag for toml suffix and
|
|
||||||
// parse that as needed
|
|
||||||
//
|
|
||||||
// Ideally we'd also support multiple config files
|
|
||||||
|
|
||||||
func NewParsedConfig() *Config {
|
|
||||||
c := &Config{}
|
|
||||||
|
|
||||||
// Temp
|
|
||||||
c.parseFlags()
|
|
||||||
|
|
||||||
// Parse toml
|
|
||||||
if *c.TomlConfig != "" {
|
|
||||||
if _, err := toml.DecodeFile(*c.TomlConfig, &c); err != nil {
|
|
||||||
panic(err)
|
|
||||||
return c
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
c.applyDefaults()
|
|
||||||
|
|
||||||
// Conversions
|
|
||||||
c.SecretBytes = []byte(c.Secret)
|
|
||||||
|
|
||||||
return c
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Config) Checks() {
|
|
||||||
// Check for show stopper errors
|
|
||||||
if c.Providers.Google.ClientId == "" || c.Providers.Google.ClientSecret == "" || len(c.Secret) == 0 {
|
|
||||||
log.Fatal("client-id, client-secret and secret must all be set")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Config) applyDefaults() {
|
|
||||||
// Providers
|
|
||||||
// Google
|
|
||||||
if c.Providers.Google.Scope == "" {
|
|
||||||
c.Providers.Google.Scope = "https://www.googleapis.com/auth/userinfo.profile https://www.googleapis.com/auth/userinfo.email"
|
|
||||||
}
|
|
||||||
if c.Providers.Google.LoginURL == nil {
|
|
||||||
c.Providers.Google.LoginURL = &url.URL{
|
|
||||||
Scheme: "https",
|
|
||||||
Host: "accounts.google.com",
|
|
||||||
Path: "/o/oauth2/auth",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if c.Providers.Google.TokenURL == nil {
|
|
||||||
c.Providers.Google.TokenURL = &url.URL{
|
|
||||||
Scheme: "https",
|
|
||||||
Host: "www.googleapis.com",
|
|
||||||
Path: "/oauth2/v3/token",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if c.Providers.Google.UserURL == nil {
|
|
||||||
c.Providers.Google.UserURL = &url.URL{
|
|
||||||
Scheme: "https",
|
|
||||||
Host: "www.googleapis.com",
|
|
||||||
Path: "/oauth2/v2/userinfo",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Config) parseFlags() {
|
|
||||||
c.LogLevel = flag.String("log-level", "warn", "Log level: trace, debug, info, warn, error, fatal, panic")
|
|
||||||
c.LogFormat = flag.String("log-format", "text", "Log format: text, json, pretty")
|
|
||||||
c.TomlConfig = flag.String("toml-config", "", "TEMP")
|
|
||||||
|
|
||||||
// Legacy?
|
|
||||||
path := flag.String("url-path", "_oauth", "Callback URL")
|
|
||||||
lifetime := flag.Int("lifetime", 43200, "Session length in seconds")
|
|
||||||
secret := flag.String("secret", "", "*Secret used for signing (required)")
|
|
||||||
authHost := flag.String("auth-host", "", "Central auth login")
|
|
||||||
clientId := flag.String("client-id", "", "*Google Client ID (required)")
|
|
||||||
clientSecret := flag.String("client-secret", "", "*Google Client Secret (required)")
|
|
||||||
cookieName := flag.String("cookie-name", "_forward_auth", "Cookie Name")
|
|
||||||
cSRFCookieName := flag.String("csrf-cookie-name", "_forward_auth_csrf", "CSRF Cookie Name")
|
|
||||||
cookieDomainList := flag.String("cookie-domains", "", "Comma separated list of cookie domains") //todo
|
|
||||||
cookieSecret := flag.String("cookie-secret", "", "Deprecated")
|
|
||||||
cookieSecure := flag.Bool("cookie-secure", true, "Use secure cookies")
|
|
||||||
domainList := flag.String("domain", "", "Comma separated list of email domains to allow")
|
|
||||||
emailWhitelist := flag.String("whitelist", "", "Comma separated list of emails to allow")
|
|
||||||
prompt := flag.String("prompt", "", "Space separated list of OpenID prompt options")
|
|
||||||
|
|
||||||
flag.Parse()
|
|
||||||
|
|
||||||
// Add to config
|
|
||||||
c.Path = fmt.Sprintf("/%s", *path)
|
|
||||||
c.Lifetime = time.Second * time.Duration(*lifetime)
|
|
||||||
c.AuthHost = *authHost
|
|
||||||
c.Providers.Google.ClientId = *clientId
|
|
||||||
c.Providers.Google.ClientSecret = *clientSecret
|
|
||||||
c.Providers.Google.Prompt = *prompt
|
|
||||||
c.CookieName = *cookieName
|
|
||||||
c.CSRFCookieName = *cSRFCookieName
|
|
||||||
c.CookieSecure = *cookieSecure
|
|
||||||
|
|
||||||
// Backwards compatibility
|
|
||||||
if *secret == "" && *cookieSecret != "" {
|
|
||||||
*secret = *cookieSecret
|
|
||||||
}
|
|
||||||
c.Secret = *secret
|
|
||||||
|
|
||||||
// Parse lists
|
|
||||||
if *cookieDomainList != "" {
|
|
||||||
for _, d := range strings.Split(*cookieDomainList, ",") {
|
|
||||||
cookieDomain := NewCookieDomain(d)
|
|
||||||
c.CookieDomains = append(c.CookieDomains, *cookieDomain)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if *domainList != "" {
|
|
||||||
c.Domain = strings.Split(*domainList, ",")
|
|
||||||
}
|
|
||||||
|
|
||||||
if *emailWhitelist != "" {
|
|
||||||
c.Whitelist = strings.Split(*emailWhitelist, ",")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Temp
|
|
||||||
func (c Config) Walk() {
|
|
||||||
for name, rule := range c.Rules {
|
|
||||||
fmt.Printf("Rule: %s\n", name)
|
|
||||||
for _, match := range rule.Match {
|
|
||||||
if len(match.Host) > 0 {
|
|
||||||
for _, val := range match.Host {
|
|
||||||
fmt.Printf(" - Host: %s\n", val)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(match.PathPrefix) > 0 {
|
|
||||||
for _, val := range match.PathPrefix {
|
|
||||||
fmt.Printf(" - PathPrefix: %s\n", val)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(match.Header) > 0 {
|
|
||||||
for _, val := range match.Header {
|
|
||||||
fmt.Printf(" - Header: %s: %s\n", val[0], val[1])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,13 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
// import (
|
|
||||||
// "testing"
|
|
||||||
// )
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Tests
|
|
||||||
*/
|
|
||||||
|
|
||||||
// func TestMain(t *testing.T) {
|
|
||||||
|
|
||||||
// }
|
|
31
go.mod
Normal file
31
go.mod
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
module github.com/thomseddon/traefik-forward-auth
|
||||||
|
|
||||||
|
go 1.12
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/VividCortex/gohistogram v1.0.0 // indirect
|
||||||
|
github.com/cenkalti/backoff v2.1.1+incompatible // indirect
|
||||||
|
github.com/containous/alice v0.0.0-20181107144136-d83ebdd94cbd // indirect
|
||||||
|
github.com/containous/flaeg v1.4.1 // indirect
|
||||||
|
github.com/containous/mux v0.0.0-20181024131434-c33f32e26898 // indirect
|
||||||
|
github.com/containous/traefik v2.0.0-alpha2+incompatible
|
||||||
|
github.com/go-acme/lego v2.4.0+incompatible // indirect
|
||||||
|
github.com/go-kit/kit v0.8.0 // indirect
|
||||||
|
github.com/gorilla/context v1.1.1 // indirect
|
||||||
|
github.com/gravitational/trace v0.0.0-20190409171327-f30095ced5ff // indirect
|
||||||
|
github.com/jessevdk/go-flags v1.4.0
|
||||||
|
github.com/jonboulle/clockwork v0.1.0 // indirect
|
||||||
|
github.com/kr/pretty v0.1.0 // indirect
|
||||||
|
github.com/miekg/dns v1.1.8 // indirect
|
||||||
|
github.com/patrickmn/go-cache v2.1.0+incompatible // indirect
|
||||||
|
github.com/pkg/errors v0.8.1 // indirect
|
||||||
|
github.com/ryanuber/go-glob v1.0.0 // indirect
|
||||||
|
github.com/sirupsen/logrus v1.4.1
|
||||||
|
github.com/stretchr/testify v1.3.0 // indirect
|
||||||
|
github.com/vulcand/predicate v1.1.0 // indirect
|
||||||
|
golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a // indirect
|
||||||
|
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 // indirect
|
||||||
|
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f // indirect
|
||||||
|
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
|
||||||
|
gopkg.in/square/go-jose.v2 v2.3.1 // indirect
|
||||||
|
)
|
70
go.sum
Normal file
70
go.sum
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
github.com/VividCortex/gohistogram v1.0.0 h1:6+hBz+qvs0JOrrNhhmR7lFxo5sINxBCGXrdtl/UvroE=
|
||||||
|
github.com/VividCortex/gohistogram v1.0.0/go.mod h1:Pf5mBqqDxYaXu3hDrrU+w6nw50o/4+TcAqDqk/vUH7g=
|
||||||
|
github.com/cenkalti/backoff v2.1.1+incompatible h1:tKJnvO2kl0zmb/jA5UKAt4VoEVw1qxKWjE/Bpp46npY=
|
||||||
|
github.com/cenkalti/backoff v2.1.1+incompatible/go.mod h1:90ReRw6GdpyfrHakVjL/QHaoyV4aDUVVkXQJJJ3NXXM=
|
||||||
|
github.com/containous/alice v0.0.0-20181107144136-d83ebdd94cbd h1:0n+lFLh5zU0l6KSk3KpnDwfbPGAR44aRLgTbCnhRBHU=
|
||||||
|
github.com/containous/alice v0.0.0-20181107144136-d83ebdd94cbd/go.mod h1:BbQgeDS5i0tNvypwEoF1oNjOJw8knRAE1DnVvjDstcQ=
|
||||||
|
github.com/containous/flaeg v1.4.1 h1:VTouP7EF2JeowNvknpP3fJAJLUDsQ1lDHq/QQTQc1xc=
|
||||||
|
github.com/containous/flaeg v1.4.1/go.mod h1:wgw6PDtRURXHKFFV6HOqQxWhUc3k3Hmq22jw+n2qDro=
|
||||||
|
github.com/containous/mux v0.0.0-20181024131434-c33f32e26898 h1:1srn9voikJGofblBhWy3WuZWqo14Ou7NaswNG/I2yWc=
|
||||||
|
github.com/containous/mux v0.0.0-20181024131434-c33f32e26898/go.mod h1:z8WW7n06n8/1xF9Jl9WmuDeZuHAhfL+bwarNjsciwwg=
|
||||||
|
github.com/containous/traefik v2.0.0-alpha2+incompatible h1:5RS6mUAOPQCy1jAmcmxLj2nChIcs3fKuxZxH9AF6ih8=
|
||||||
|
github.com/containous/traefik v2.0.0-alpha2+incompatible/go.mod h1:epDRqge3JzKOhlSWzOpNYEEKXmM6yfN5tPzDGKk3ljo=
|
||||||
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/go-acme/lego v2.4.0+incompatible h1:+BTLUfLtDc5qQauyiTCXH6lupEUOCvXyGlEjdeU0YQI=
|
||||||
|
github.com/go-acme/lego v2.4.0+incompatible/go.mod h1:yzMNe9CasVUhkquNvti5nAtPmG94USbYxYrZfTkIn0M=
|
||||||
|
github.com/go-kit/kit v0.8.0 h1:Wz+5lgoB0kkuqLEc6NVmwRknTKP6dTGbSqvhZtBI/j0=
|
||||||
|
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
|
||||||
|
github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8=
|
||||||
|
github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
|
||||||
|
github.com/gravitational/trace v0.0.0-20190409171327-f30095ced5ff h1:xL/fJdlTJL6R/6Qk2tPu3EP1NsXgap9hXLvxKH0Ytko=
|
||||||
|
github.com/gravitational/trace v0.0.0-20190409171327-f30095ced5ff/go.mod h1:RvdOUHE4SHqR3oXlFFKnGzms8a5dugHygGw1bqDstYI=
|
||||||
|
github.com/jessevdk/go-flags v1.4.0 h1:4IU2WS7AumrZ/40jfhf4QVDMsQwqA7VEHozFRrGARJA=
|
||||||
|
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
|
||||||
|
github.com/jonboulle/clockwork v0.1.0 h1:VKV+ZcuP6l3yW9doeqz6ziZGgcynBVQO+obU0+0hcPo=
|
||||||
|
github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo=
|
||||||
|
github.com/konsorten/go-windows-terminal-sequences v1.0.1 h1:mweAR1A6xJ3oS2pRaGiHgQ4OO8tzTaLawm8vnODuwDk=
|
||||||
|
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
||||||
|
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
|
||||||
|
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||||
|
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||||
|
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
|
||||||
|
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||||
|
github.com/miekg/dns v1.1.8 h1:1QYRAKU3lN5cRfLCkPU08hwvLJFhvjP6MqNMmQz6ZVI=
|
||||||
|
github.com/miekg/dns v1.1.8/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg=
|
||||||
|
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
|
||||||
|
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
|
||||||
|
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
|
||||||
|
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/ryanuber/go-glob v1.0.0 h1:iQh3xXAumdQ+4Ufa5b25cRpC5TYKlno6hsv6Cb3pkBk=
|
||||||
|
github.com/ryanuber/go-glob v1.0.0/go.mod h1:807d1WSdnB0XRJzKNil9Om6lcp/3a0v4qIHxIXzX/Yc=
|
||||||
|
github.com/sirupsen/logrus v1.4.1 h1:GL2rEmy6nsikmW0r8opw9JIRScdMF5hA8cOYLH7In1k=
|
||||||
|
github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q=
|
||||||
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
|
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
|
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||||
|
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
|
||||||
|
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||||
|
github.com/vulcand/predicate v1.1.0 h1:Gq/uWopa4rx/tnZu2opOSBqHK63Yqlou/SzrbwdJiNg=
|
||||||
|
github.com/vulcand/predicate v1.1.0/go.mod h1:mlccC5IRBoc2cIFmCB8ZM62I3VDb6p2GXESMHa3CnZg=
|
||||||
|
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||||
|
golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a h1:Igim7XhdOpBnWPuYJ70XcNpq8q3BCACtVgNfoJxOV7g=
|
||||||
|
golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE=
|
||||||
|
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 h1:0GoQqolDA55aaLxZyTzK/Y2ePZzZTUrRacwib7cNsYQ=
|
||||||
|
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||||
|
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f h1:wMNYb4v58l5UBM7MYRLPG6ZhfOqbKu7X5eyFl8ZhKvA=
|
||||||
|
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
|
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33 h1:I6FyU15t786LL7oL/hn43zqTuEGr4PN7F4XJ1p4E3Y8=
|
||||||
|
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
|
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
|
golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e h1:nFYrTHrdrAOpShe27kaFHjsqYSEQ0KWqdWLu3xuZJts=
|
||||||
|
golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
|
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
|
||||||
|
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/square/go-jose.v2 v2.3.1 h1:SK5KegNXmKmqE342YYN2qPHEnUYeoMiXXl1poUlI+o4=
|
||||||
|
gopkg.in/square/go-jose.v2 v2.3.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
|
@ -1,41 +1,24 @@
|
|||||||
package main
|
package tfa
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/hmac"
|
"crypto/hmac"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
// "encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
// "net/url"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/thomseddon/traefik-forward-auth/provider"
|
"github.com/thomseddon/traefik-forward-auth/internal/provider"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ForwardAuthContext int
|
|
||||||
|
|
||||||
const (
|
|
||||||
Nonce ForwardAuthContext = iota
|
|
||||||
Request
|
|
||||||
)
|
|
||||||
|
|
||||||
// Forward Auth
|
|
||||||
type ForwardAuth struct {
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewForwardAuth() *ForwardAuth {
|
|
||||||
return &ForwardAuth{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Request Validation
|
// Request Validation
|
||||||
|
|
||||||
// Cookie = hash(secret, cookie domain, email, expires)|expires|email
|
// Cookie = hash(secret, cookie domain, email, expires)|expires|email
|
||||||
func (f *ForwardAuth) ValidateCookie(r *http.Request, c *http.Cookie) (bool, string, error) {
|
func ValidateCookie(r *http.Request, c *http.Cookie) (bool, string, error) {
|
||||||
parts := strings.Split(c.Value, "|")
|
parts := strings.Split(c.Value, "|")
|
||||||
|
|
||||||
if len(parts) != 3 {
|
if len(parts) != 3 {
|
||||||
@ -47,7 +30,7 @@ func (f *ForwardAuth) ValidateCookie(r *http.Request, c *http.Cookie) (bool, str
|
|||||||
return false, "", errors.New("Unable to decode cookie mac")
|
return false, "", errors.New("Unable to decode cookie mac")
|
||||||
}
|
}
|
||||||
|
|
||||||
expectedSignature := f.cookieSignature(r, parts[2], parts[1])
|
expectedSignature := cookieSignature(r, parts[2], parts[1])
|
||||||
expected, err := base64.URLEncoding.DecodeString(expectedSignature)
|
expected, err := base64.URLEncoding.DecodeString(expectedSignature)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, "", errors.New("Unable to generate mac")
|
return false, "", errors.New("Unable to generate mac")
|
||||||
@ -73,7 +56,7 @@ func (f *ForwardAuth) ValidateCookie(r *http.Request, c *http.Cookie) (bool, str
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Validate email
|
// Validate email
|
||||||
func (f *ForwardAuth) ValidateEmail(email string) bool {
|
func ValidateEmail(email string) bool {
|
||||||
found := false
|
found := false
|
||||||
if len(config.Whitelist) > 0 {
|
if len(config.Whitelist) > 0 {
|
||||||
for _, whitelist := range config.Whitelist {
|
for _, whitelist := range config.Whitelist {
|
||||||
@ -81,12 +64,12 @@ func (f *ForwardAuth) ValidateEmail(email string) bool {
|
|||||||
found = true
|
found = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if len(config.Domain) > 0 {
|
} else if len(config.Domains) > 0 {
|
||||||
parts := strings.Split(email, "@")
|
parts := strings.Split(email, "@")
|
||||||
if len(parts) < 2 {
|
if len(parts) < 2 {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
for _, domain := range config.Domain {
|
for _, domain := range config.Domains {
|
||||||
if domain == parts[1] {
|
if domain == parts[1] {
|
||||||
found = true
|
found = true
|
||||||
}
|
}
|
||||||
@ -101,25 +84,25 @@ func (f *ForwardAuth) ValidateEmail(email string) bool {
|
|||||||
// OAuth Methods
|
// OAuth Methods
|
||||||
|
|
||||||
// Get login url
|
// Get login url
|
||||||
func (f *ForwardAuth) GetLoginURL(r *http.Request, nonce string) string {
|
func GetLoginURL(r *http.Request, nonce string) string {
|
||||||
state := fmt.Sprintf("%s:%s", nonce, f.returnUrl(r))
|
state := fmt.Sprintf("%s:%s", nonce, returnUrl(r))
|
||||||
|
|
||||||
// TODO: Support multiple providers
|
// TODO: Support multiple providers
|
||||||
return config.Providers.Google.GetLoginURL(f.redirectUri(r), state)
|
return config.Providers.Google.GetLoginURL(redirectUri(r), state)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Exchange code for token
|
// Exchange code for token
|
||||||
|
|
||||||
func (f *ForwardAuth) ExchangeCode(r *http.Request) (string, error) {
|
func ExchangeCode(r *http.Request) (string, error) {
|
||||||
code := r.URL.Query().Get("code")
|
code := r.URL.Query().Get("code")
|
||||||
|
|
||||||
// TODO: Support multiple providers
|
// TODO: Support multiple providers
|
||||||
return config.Providers.Google.ExchangeCode(f.redirectUri(r), code)
|
return config.Providers.Google.ExchangeCode(redirectUri(r), code)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get user with token
|
// Get user with token
|
||||||
|
|
||||||
func (f *ForwardAuth) GetUser(token string) (provider.User, error) {
|
func GetUser(token string) (provider.User, error) {
|
||||||
// TODO: Support multiple providers
|
// TODO: Support multiple providers
|
||||||
return config.Providers.Google.GetUser(token)
|
return config.Providers.Google.GetUser(token)
|
||||||
}
|
}
|
||||||
@ -127,7 +110,7 @@ func (f *ForwardAuth) GetUser(token string) (provider.User, error) {
|
|||||||
// Utility methods
|
// Utility methods
|
||||||
|
|
||||||
// Get the redirect base
|
// Get the redirect base
|
||||||
func (f *ForwardAuth) redirectBase(r *http.Request) string {
|
func redirectBase(r *http.Request) string {
|
||||||
proto := r.Header.Get("X-Forwarded-Proto")
|
proto := r.Header.Get("X-Forwarded-Proto")
|
||||||
host := r.Header.Get("X-Forwarded-Host")
|
host := r.Header.Get("X-Forwarded-Host")
|
||||||
|
|
||||||
@ -135,33 +118,33 @@ func (f *ForwardAuth) redirectBase(r *http.Request) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// // Return url
|
// // Return url
|
||||||
func (f *ForwardAuth) returnUrl(r *http.Request) string {
|
func returnUrl(r *http.Request) string {
|
||||||
path := r.Header.Get("X-Forwarded-Uri")
|
path := r.Header.Get("X-Forwarded-Uri")
|
||||||
|
|
||||||
return fmt.Sprintf("%s%s", f.redirectBase(r), path)
|
return fmt.Sprintf("%s%s", redirectBase(r), path)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get oauth redirect uri
|
// Get oauth redirect uri
|
||||||
func (f *ForwardAuth) redirectUri(r *http.Request) string {
|
func redirectUri(r *http.Request) string {
|
||||||
if use, _ := f.useAuthDomain(r); use {
|
if use, _ := useAuthDomain(r); use {
|
||||||
proto := r.Header.Get("X-Forwarded-Proto")
|
proto := r.Header.Get("X-Forwarded-Proto")
|
||||||
return fmt.Sprintf("%s://%s%s", proto, config.AuthHost, config.Path)
|
return fmt.Sprintf("%s://%s%s", proto, config.AuthHost, config.Path)
|
||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Sprintf("%s%s", f.redirectBase(r), config.Path)
|
return fmt.Sprintf("%s%s", redirectBase(r), config.Path)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should we use auth host + what it is
|
// Should we use auth host + what it is
|
||||||
func (f *ForwardAuth) useAuthDomain(r *http.Request) (bool, string) {
|
func useAuthDomain(r *http.Request) (bool, string) {
|
||||||
if config.AuthHost == "" {
|
if config.AuthHost == "" {
|
||||||
return false, ""
|
return false, ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// Does the request match a given cookie domain?
|
// Does the request match a given cookie domain?
|
||||||
reqMatch, reqHost := f.matchCookieDomains(r.Header.Get("X-Forwarded-Host"))
|
reqMatch, reqHost := matchCookieDomains(r.Header.Get("X-Forwarded-Host"))
|
||||||
|
|
||||||
// Do any of the auth hosts match a cookie domain?
|
// Do any of the auth hosts match a cookie domain?
|
||||||
authMatch, authHost := f.matchCookieDomains(config.AuthHost)
|
authMatch, authHost := matchCookieDomains(config.AuthHost)
|
||||||
|
|
||||||
// We need both to match the same domain
|
// We need both to match the same domain
|
||||||
return reqMatch && authMatch && reqHost == authHost, reqHost
|
return reqMatch && authMatch && reqHost == authHost, reqHost
|
||||||
@ -170,50 +153,50 @@ func (f *ForwardAuth) useAuthDomain(r *http.Request) (bool, string) {
|
|||||||
// Cookie methods
|
// Cookie methods
|
||||||
|
|
||||||
// Create an auth cookie
|
// Create an auth cookie
|
||||||
func (f *ForwardAuth) MakeCookie(r *http.Request, email string) *http.Cookie {
|
func MakeCookie(r *http.Request, email string) *http.Cookie {
|
||||||
expires := f.cookieExpiry()
|
expires := cookieExpiry()
|
||||||
mac := f.cookieSignature(r, email, fmt.Sprintf("%d", expires.Unix()))
|
mac := cookieSignature(r, email, fmt.Sprintf("%d", expires.Unix()))
|
||||||
value := fmt.Sprintf("%s|%d|%s", mac, expires.Unix(), email)
|
value := fmt.Sprintf("%s|%d|%s", mac, expires.Unix(), email)
|
||||||
|
|
||||||
return &http.Cookie{
|
return &http.Cookie{
|
||||||
Name: config.CookieName,
|
Name: config.CookieName,
|
||||||
Value: value,
|
Value: value,
|
||||||
Path: "/",
|
Path: "/",
|
||||||
Domain: f.cookieDomain(r),
|
Domain: cookieDomain(r),
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
Secure: config.CookieSecure,
|
Secure: !config.CookieInsecure,
|
||||||
Expires: expires,
|
Expires: expires,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make a CSRF cookie (used during login only)
|
// Make a CSRF cookie (used during login only)
|
||||||
func (f *ForwardAuth) MakeCSRFCookie(r *http.Request, nonce string) *http.Cookie {
|
func MakeCSRFCookie(r *http.Request, nonce string) *http.Cookie {
|
||||||
return &http.Cookie{
|
return &http.Cookie{
|
||||||
Name: config.CSRFCookieName,
|
Name: config.CSRFCookieName,
|
||||||
Value: nonce,
|
Value: nonce,
|
||||||
Path: "/",
|
Path: "/",
|
||||||
Domain: f.csrfCookieDomain(r),
|
Domain: csrfCookieDomain(r),
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
Secure: config.CookieSecure,
|
Secure: !config.CookieInsecure,
|
||||||
Expires: f.cookieExpiry(),
|
Expires: cookieExpiry(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a cookie to clear csrf cookie
|
// Create a cookie to clear csrf cookie
|
||||||
func (f *ForwardAuth) ClearCSRFCookie(r *http.Request) *http.Cookie {
|
func ClearCSRFCookie(r *http.Request) *http.Cookie {
|
||||||
return &http.Cookie{
|
return &http.Cookie{
|
||||||
Name: config.CSRFCookieName,
|
Name: config.CSRFCookieName,
|
||||||
Value: "",
|
Value: "",
|
||||||
Path: "/",
|
Path: "/",
|
||||||
Domain: f.csrfCookieDomain(r),
|
Domain: csrfCookieDomain(r),
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
Secure: config.CookieSecure,
|
Secure: !config.CookieInsecure,
|
||||||
Expires: time.Now().Local().Add(time.Hour * -1),
|
Expires: time.Now().Local().Add(time.Hour * -1),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// // Validate the csrf cookie against state
|
// Validate the csrf cookie against state
|
||||||
func (f *ForwardAuth) ValidateCSRFCookie(r *http.Request, c *http.Cookie) (bool, string, error) {
|
func ValidateCSRFCookie(r *http.Request, c *http.Cookie) (bool, string, error) {
|
||||||
state := r.URL.Query().Get("state")
|
state := r.URL.Query().Get("state")
|
||||||
|
|
||||||
if len(c.Value) != 32 {
|
if len(c.Value) != 32 {
|
||||||
@ -233,7 +216,7 @@ func (f *ForwardAuth) ValidateCSRFCookie(r *http.Request, c *http.Cookie) (bool,
|
|||||||
return true, state[33:], nil
|
return true, state[33:], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *ForwardAuth) Nonce() (error, string) {
|
func Nonce() (error, string) {
|
||||||
// Make nonce
|
// Make nonce
|
||||||
nonce := make([]byte, 16)
|
nonce := make([]byte, 16)
|
||||||
_, err := rand.Read(nonce)
|
_, err := rand.Read(nonce)
|
||||||
@ -245,18 +228,18 @@ func (f *ForwardAuth) Nonce() (error, string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Cookie domain
|
// Cookie domain
|
||||||
func (f *ForwardAuth) cookieDomain(r *http.Request) string {
|
func cookieDomain(r *http.Request) string {
|
||||||
host := r.Header.Get("X-Forwarded-Host")
|
host := r.Header.Get("X-Forwarded-Host")
|
||||||
|
|
||||||
// Check if any of the given cookie domains matches
|
// Check if any of the given cookie domains matches
|
||||||
_, domain := f.matchCookieDomains(host)
|
_, domain := matchCookieDomains(host)
|
||||||
return domain
|
return domain
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cookie domain
|
// Cookie domain
|
||||||
func (f *ForwardAuth) csrfCookieDomain(r *http.Request) string {
|
func csrfCookieDomain(r *http.Request) string {
|
||||||
var host string
|
var host string
|
||||||
if use, domain := f.useAuthDomain(r); use {
|
if use, domain := useAuthDomain(r); use {
|
||||||
host = domain
|
host = domain
|
||||||
} else {
|
} else {
|
||||||
host = r.Header.Get("X-Forwarded-Host")
|
host = r.Header.Get("X-Forwarded-Host")
|
||||||
@ -268,7 +251,7 @@ func (f *ForwardAuth) csrfCookieDomain(r *http.Request) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Return matching cookie domain if exists
|
// Return matching cookie domain if exists
|
||||||
func (f *ForwardAuth) matchCookieDomains(domain string) (bool, string) {
|
func matchCookieDomains(domain string) (bool, string) {
|
||||||
// Remove port
|
// Remove port
|
||||||
p := strings.Split(domain, ":")
|
p := strings.Split(domain, ":")
|
||||||
|
|
||||||
@ -282,16 +265,16 @@ func (f *ForwardAuth) matchCookieDomains(domain string) (bool, string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create cookie hmac
|
// Create cookie hmac
|
||||||
func (f *ForwardAuth) cookieSignature(r *http.Request, email, expires string) string {
|
func cookieSignature(r *http.Request, email, expires string) string {
|
||||||
hash := hmac.New(sha256.New, config.SecretBytes)
|
hash := hmac.New(sha256.New, config.Secret)
|
||||||
hash.Write([]byte(f.cookieDomain(r)))
|
hash.Write([]byte(cookieDomain(r)))
|
||||||
hash.Write([]byte(email))
|
hash.Write([]byte(email))
|
||||||
hash.Write([]byte(expires))
|
hash.Write([]byte(expires))
|
||||||
return base64.URLEncoding.EncodeToString(hash.Sum(nil))
|
return base64.URLEncoding.EncodeToString(hash.Sum(nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get cookie expirary
|
// Get cookie expirary
|
||||||
func (f *ForwardAuth) cookieExpiry() time.Time {
|
func cookieExpiry() time.Time {
|
||||||
return time.Now().Local().Add(config.Lifetime)
|
return time.Now().Local().Add(config.Lifetime)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -299,10 +282,10 @@ func (f *ForwardAuth) cookieExpiry() time.Time {
|
|||||||
|
|
||||||
// Cookie Domain
|
// Cookie Domain
|
||||||
type CookieDomain struct {
|
type CookieDomain struct {
|
||||||
Domain string
|
Domain string `description:"TEST1"`
|
||||||
DomainLen int
|
DomainLen int `description:"TEST2"`
|
||||||
SubDomain string
|
SubDomain string `description:"TEST3"`
|
||||||
SubDomainLen int
|
SubDomainLen int `description:"TEST4"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewCookieDomain(domain string) *CookieDomain {
|
func NewCookieDomain(domain string) *CookieDomain {
|
||||||
@ -327,3 +310,20 @@ func (c *CookieDomain) Match(host string) bool {
|
|||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type CookieDomains []CookieDomain
|
||||||
|
|
||||||
|
func (c *CookieDomains) UnmarshalFlag(value string) error {
|
||||||
|
// TODO: test
|
||||||
|
if len(value) > 0 {
|
||||||
|
for _, d := range strings.Split(value, ",") {
|
||||||
|
cookieDomain := NewCookieDomain(d)
|
||||||
|
*c = append(*c, *cookieDomain)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CookieDomains) MarshalFlag() (string, error) {
|
||||||
|
return fmt.Sprintf("%+v", *c), nil
|
||||||
|
}
|
@ -1,4 +1,4 @@
|
|||||||
package main
|
package tfa
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -8,7 +8,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/thomseddon/traefik-forward-auth/provider"
|
"github.com/thomseddon/traefik-forward-auth/internal/provider"
|
||||||
)
|
)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -16,7 +16,7 @@ import (
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
fw = &ForwardAuth{}
|
// fw = &ForwardAuth{}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -24,46 +24,46 @@ func init() {
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
func TestValidateCookie(t *testing.T) {
|
func TestValidateCookie(t *testing.T) {
|
||||||
config = &Config{}
|
config = Config{}
|
||||||
r, _ := http.NewRequest("GET", "http://example.com", nil)
|
r, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||||
c := &http.Cookie{}
|
c := &http.Cookie{}
|
||||||
|
|
||||||
// Should require 3 parts
|
// Should require 3 parts
|
||||||
c.Value = ""
|
c.Value = ""
|
||||||
valid, _, err := fw.ValidateCookie(r, c)
|
valid, _, err := ValidateCookie(r, c)
|
||||||
if valid || err.Error() != "Invalid cookie format" {
|
if valid || err.Error() != "Invalid cookie format" {
|
||||||
t.Error("Should get \"Invalid cookie format\", got:", err)
|
t.Error("Should get \"Invalid cookie format\", got:", err)
|
||||||
}
|
}
|
||||||
c.Value = "1|2"
|
c.Value = "1|2"
|
||||||
valid, _, err = fw.ValidateCookie(r, c)
|
valid, _, err = ValidateCookie(r, c)
|
||||||
if valid || err.Error() != "Invalid cookie format" {
|
if valid || err.Error() != "Invalid cookie format" {
|
||||||
t.Error("Should get \"Invalid cookie format\", got:", err)
|
t.Error("Should get \"Invalid cookie format\", got:", err)
|
||||||
}
|
}
|
||||||
c.Value = "1|2|3|4"
|
c.Value = "1|2|3|4"
|
||||||
valid, _, err = fw.ValidateCookie(r, c)
|
valid, _, err = ValidateCookie(r, c)
|
||||||
if valid || err.Error() != "Invalid cookie format" {
|
if valid || err.Error() != "Invalid cookie format" {
|
||||||
t.Error("Should get \"Invalid cookie format\", got:", err)
|
t.Error("Should get \"Invalid cookie format\", got:", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should catch invalid mac
|
// Should catch invalid mac
|
||||||
c.Value = "MQ==|2|3"
|
c.Value = "MQ==|2|3"
|
||||||
valid, _, err = fw.ValidateCookie(r, c)
|
valid, _, err = ValidateCookie(r, c)
|
||||||
if valid || err.Error() != "Invalid cookie mac" {
|
if valid || err.Error() != "Invalid cookie mac" {
|
||||||
t.Error("Should get \"Invalid cookie mac\", got:", err)
|
t.Error("Should get \"Invalid cookie mac\", got:", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should catch expired
|
// Should catch expired
|
||||||
config.Lifetime = time.Second * time.Duration(-1)
|
config.Lifetime = time.Second * time.Duration(-1)
|
||||||
c = fw.MakeCookie(r, "test@test.com")
|
c = MakeCookie(r, "test@test.com")
|
||||||
valid, _, err = fw.ValidateCookie(r, c)
|
valid, _, err = ValidateCookie(r, c)
|
||||||
if valid || err.Error() != "Cookie has expired" {
|
if valid || err.Error() != "Cookie has expired" {
|
||||||
t.Error("Should get \"Cookie has expired\", got:", err)
|
t.Error("Should get \"Cookie has expired\", got:", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should accept valid cookie
|
// Should accept valid cookie
|
||||||
config.Lifetime = time.Second * time.Duration(10)
|
config.Lifetime = time.Second * time.Duration(10)
|
||||||
c = fw.MakeCookie(r, "test@test.com")
|
c = MakeCookie(r, "test@test.com")
|
||||||
valid, email, err := fw.ValidateCookie(r, c)
|
valid, email, err := ValidateCookie(r, c)
|
||||||
if !valid {
|
if !valid {
|
||||||
t.Error("Valid request should return as valid")
|
t.Error("Valid request should return as valid")
|
||||||
}
|
}
|
||||||
@ -76,36 +76,36 @@ func TestValidateCookie(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateEmail(t *testing.T) {
|
func TestValidateEmail(t *testing.T) {
|
||||||
config = &Config{}
|
config = Config{}
|
||||||
|
|
||||||
// Should allow any
|
// Should allow any
|
||||||
if !fw.ValidateEmail("test@test.com") || !fw.ValidateEmail("one@two.com") {
|
if !ValidateEmail("test@test.com") || !ValidateEmail("one@two.com") {
|
||||||
t.Error("Should allow any domain if email domain is not defined")
|
t.Error("Should allow any domain if email domain is not defined")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should block non matching domain
|
// Should block non matching domain
|
||||||
config.Domain = []string{"test.com"}
|
config.Domains = []string{"test.com"}
|
||||||
if fw.ValidateEmail("one@two.com") {
|
if ValidateEmail("one@two.com") {
|
||||||
t.Error("Should not allow user from another domain")
|
t.Error("Should not allow user from another domain")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should allow matching domain
|
// Should allow matching domain
|
||||||
config.Domain = []string{"test.com"}
|
config.Domains = []string{"test.com"}
|
||||||
if !fw.ValidateEmail("test@test.com") {
|
if !ValidateEmail("test@test.com") {
|
||||||
t.Error("Should allow user from allowed domain")
|
t.Error("Should allow user from allowed domain")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should block non whitelisted email address
|
// Should block non whitelisted email address
|
||||||
config.Domain = []string{}
|
config.Domains = []string{}
|
||||||
config.Whitelist = []string{"test@test.com"}
|
config.Whitelist = []string{"test@test.com"}
|
||||||
if fw.ValidateEmail("one@two.com") {
|
if ValidateEmail("one@two.com") {
|
||||||
t.Error("Should not allow user not in whitelist.")
|
t.Error("Should not allow user not in whitelist.")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should allow matching whitelisted email address
|
// Should allow matching whitelisted email address
|
||||||
config.Domain = []string{}
|
config.Domains = []string{}
|
||||||
config.Whitelist = []string{"test@test.com"}
|
config.Whitelist = []string{"test@test.com"}
|
||||||
if !fw.ValidateEmail("test@test.com") {
|
if !ValidateEmail("test@test.com") {
|
||||||
t.Error("Should allow user in whitelist.")
|
t.Error("Should allow user in whitelist.")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -116,7 +116,7 @@ func TestGetLoginURL(t *testing.T) {
|
|||||||
r.Header.Add("X-Forwarded-Host", "example.com")
|
r.Header.Add("X-Forwarded-Host", "example.com")
|
||||||
r.Header.Add("X-Forwarded-Uri", "/hello")
|
r.Header.Add("X-Forwarded-Uri", "/hello")
|
||||||
|
|
||||||
config = &Config{
|
config = Config{
|
||||||
Path: "/_oauth",
|
Path: "/_oauth",
|
||||||
Providers: provider.Providers{
|
Providers: provider.Providers{
|
||||||
Google: provider.Google{
|
Google: provider.Google{
|
||||||
@ -133,7 +133,7 @@ func TestGetLoginURL(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check url
|
// Check url
|
||||||
uri, err := url.Parse(fw.GetLoginURL(r, "nonce"))
|
uri, err := url.Parse(GetLoginURL(r, "nonce"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error("Error parsing login url:", err)
|
t.Error("Error parsing login url:", err)
|
||||||
}
|
}
|
||||||
@ -165,7 +165,7 @@ func TestGetLoginURL(t *testing.T) {
|
|||||||
// With Auth URL but no matching cookie domain
|
// With Auth URL but no matching cookie domain
|
||||||
// - will not use auth host
|
// - will not use auth host
|
||||||
//
|
//
|
||||||
config = &Config{
|
config = Config{
|
||||||
Path: "/_oauth",
|
Path: "/_oauth",
|
||||||
AuthHost: "auth.example.com",
|
AuthHost: "auth.example.com",
|
||||||
Providers: provider.Providers{
|
Providers: provider.Providers{
|
||||||
@ -184,7 +184,7 @@ func TestGetLoginURL(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check url
|
// Check url
|
||||||
uri, err = url.Parse(fw.GetLoginURL(r, "nonce"))
|
uri, err = url.Parse(GetLoginURL(r, "nonce"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error("Error parsing login url:", err)
|
t.Error("Error parsing login url:", err)
|
||||||
}
|
}
|
||||||
@ -217,7 +217,7 @@ func TestGetLoginURL(t *testing.T) {
|
|||||||
// With correct Auth URL + cookie domain
|
// With correct Auth URL + cookie domain
|
||||||
//
|
//
|
||||||
cookieDomain := NewCookieDomain("example.com")
|
cookieDomain := NewCookieDomain("example.com")
|
||||||
config = &Config{
|
config = Config{
|
||||||
Path: "/_oauth",
|
Path: "/_oauth",
|
||||||
AuthHost: "auth.example.com",
|
AuthHost: "auth.example.com",
|
||||||
CookieDomains: []CookieDomain{*cookieDomain},
|
CookieDomains: []CookieDomain{*cookieDomain},
|
||||||
@ -237,7 +237,7 @@ func TestGetLoginURL(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check url
|
// Check url
|
||||||
uri, err = url.Parse(fw.GetLoginURL(r, "nonce"))
|
uri, err = url.Parse(GetLoginURL(r, "nonce"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error("Error parsing login url:", err)
|
t.Error("Error parsing login url:", err)
|
||||||
}
|
}
|
||||||
@ -277,7 +277,7 @@ func TestGetLoginURL(t *testing.T) {
|
|||||||
r.Header.Add("X-Forwarded-Uri", "/hello")
|
r.Header.Add("X-Forwarded-Uri", "/hello")
|
||||||
|
|
||||||
// Check url
|
// Check url
|
||||||
uri, err = url.Parse(fw.GetLoginURL(r, "nonce"))
|
uri, err = url.Parse(GetLoginURL(r, "nonce"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error("Error parsing login url:", err)
|
t.Error("Error parsing login url:", err)
|
||||||
}
|
}
|
||||||
@ -321,49 +321,49 @@ func TestGetLoginURL(t *testing.T) {
|
|||||||
// }
|
// }
|
||||||
|
|
||||||
func TestMakeCSRFCookie(t *testing.T) {
|
func TestMakeCSRFCookie(t *testing.T) {
|
||||||
config = &Config{}
|
config = Config{}
|
||||||
r, _ := http.NewRequest("GET", "http://app.example.com", nil)
|
r, _ := http.NewRequest("GET", "http://app.example.com", nil)
|
||||||
r.Header.Add("X-Forwarded-Host", "app.example.com")
|
r.Header.Add("X-Forwarded-Host", "app.example.com")
|
||||||
|
|
||||||
// No cookie domain or auth url
|
// No cookie domain or auth url
|
||||||
c := fw.MakeCSRFCookie(r, "12345678901234567890123456789012")
|
c := MakeCSRFCookie(r, "12345678901234567890123456789012")
|
||||||
if c.Domain != "app.example.com" {
|
if c.Domain != "app.example.com" {
|
||||||
t.Error("Cookie Domain should match request domain, got:", c.Domain)
|
t.Error("Cookie Domain should match request domain, got:", c.Domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
// With cookie domain but no auth url
|
// With cookie domain but no auth url
|
||||||
cookieDomain := NewCookieDomain("example.com")
|
cookieDomain := NewCookieDomain("example.com")
|
||||||
config = &Config{
|
config = Config{
|
||||||
CookieDomains: []CookieDomain{*cookieDomain},
|
CookieDomains: []CookieDomain{*cookieDomain},
|
||||||
}
|
}
|
||||||
c = fw.MakeCSRFCookie(r, "12345678901234567890123456789012")
|
c = MakeCSRFCookie(r, "12345678901234567890123456789012")
|
||||||
if c.Domain != "app.example.com" {
|
if c.Domain != "app.example.com" {
|
||||||
t.Error("Cookie Domain should match request domain, got:", c.Domain)
|
t.Error("Cookie Domain should match request domain, got:", c.Domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
// With cookie domain and auth url
|
// With cookie domain and auth url
|
||||||
config = &Config{
|
config = Config{
|
||||||
AuthHost: "auth.example.com",
|
AuthHost: "auth.example.com",
|
||||||
CookieDomains: []CookieDomain{*cookieDomain},
|
CookieDomains: []CookieDomain{*cookieDomain},
|
||||||
}
|
}
|
||||||
c = fw.MakeCSRFCookie(r, "12345678901234567890123456789012")
|
c = MakeCSRFCookie(r, "12345678901234567890123456789012")
|
||||||
if c.Domain != "example.com" {
|
if c.Domain != "example.com" {
|
||||||
t.Error("Cookie Domain should match request domain, got:", c.Domain)
|
t.Error("Cookie Domain should match request domain, got:", c.Domain)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClearCSRFCookie(t *testing.T) {
|
func TestClearCSRFCookie(t *testing.T) {
|
||||||
config = &Config{}
|
config = Config{}
|
||||||
r, _ := http.NewRequest("GET", "http://example.com", nil)
|
r, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||||
|
|
||||||
c := fw.ClearCSRFCookie(r)
|
c := ClearCSRFCookie(r)
|
||||||
if c.Value != "" {
|
if c.Value != "" {
|
||||||
t.Error("ClearCSRFCookie should create cookie with empty value")
|
t.Error("ClearCSRFCookie should create cookie with empty value")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateCSRFCookie(t *testing.T) {
|
func TestValidateCSRFCookie(t *testing.T) {
|
||||||
config = &Config{}
|
config = Config{}
|
||||||
c := &http.Cookie{}
|
c := &http.Cookie{}
|
||||||
|
|
||||||
newCsrfRequest := func(state string) *http.Request {
|
newCsrfRequest := func(state string) *http.Request {
|
||||||
@ -375,12 +375,12 @@ func TestValidateCSRFCookie(t *testing.T) {
|
|||||||
// Should require 32 char string
|
// Should require 32 char string
|
||||||
r := newCsrfRequest("")
|
r := newCsrfRequest("")
|
||||||
c.Value = ""
|
c.Value = ""
|
||||||
valid, _, err := fw.ValidateCSRFCookie(r, c)
|
valid, _, err := ValidateCSRFCookie(r, c)
|
||||||
if valid || err.Error() != "Invalid CSRF cookie value" {
|
if valid || err.Error() != "Invalid CSRF cookie value" {
|
||||||
t.Error("Should get \"Invalid CSRF cookie value\", got:", err)
|
t.Error("Should get \"Invalid CSRF cookie value\", got:", err)
|
||||||
}
|
}
|
||||||
c.Value = "123456789012345678901234567890123"
|
c.Value = "123456789012345678901234567890123"
|
||||||
valid, _, err = fw.ValidateCSRFCookie(r, c)
|
valid, _, err = ValidateCSRFCookie(r, c)
|
||||||
if valid || err.Error() != "Invalid CSRF cookie value" {
|
if valid || err.Error() != "Invalid CSRF cookie value" {
|
||||||
t.Error("Should get \"Invalid CSRF cookie value\", got:", err)
|
t.Error("Should get \"Invalid CSRF cookie value\", got:", err)
|
||||||
}
|
}
|
||||||
@ -388,7 +388,7 @@ func TestValidateCSRFCookie(t *testing.T) {
|
|||||||
// Should require valid state
|
// Should require valid state
|
||||||
r = newCsrfRequest("12345678901234567890123456789012:")
|
r = newCsrfRequest("12345678901234567890123456789012:")
|
||||||
c.Value = "12345678901234567890123456789012"
|
c.Value = "12345678901234567890123456789012"
|
||||||
valid, _, err = fw.ValidateCSRFCookie(r, c)
|
valid, _, err = ValidateCSRFCookie(r, c)
|
||||||
if valid || err.Error() != "Invalid CSRF state value" {
|
if valid || err.Error() != "Invalid CSRF state value" {
|
||||||
t.Error("Should get \"Invalid CSRF state value\", got:", err)
|
t.Error("Should get \"Invalid CSRF state value\", got:", err)
|
||||||
}
|
}
|
||||||
@ -396,7 +396,7 @@ func TestValidateCSRFCookie(t *testing.T) {
|
|||||||
// Should allow valid state
|
// Should allow valid state
|
||||||
r = newCsrfRequest("12345678901234567890123456789012:99")
|
r = newCsrfRequest("12345678901234567890123456789012:99")
|
||||||
c.Value = "12345678901234567890123456789012"
|
c.Value = "12345678901234567890123456789012"
|
||||||
valid, state, err := fw.ValidateCSRFCookie(r, c)
|
valid, state, err := ValidateCSRFCookie(r, c)
|
||||||
if !valid {
|
if !valid {
|
||||||
t.Error("Valid request should return as valid")
|
t.Error("Valid request should return as valid")
|
||||||
}
|
}
|
||||||
@ -409,12 +409,12 @@ func TestValidateCSRFCookie(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestNonce(t *testing.T) {
|
func TestNonce(t *testing.T) {
|
||||||
err, nonce1 := fw.Nonce()
|
err, nonce1 := Nonce()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error("Error generation nonce:", err)
|
t.Error("Error generation nonce:", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err, nonce2 := fw.Nonce()
|
err, nonce2 := Nonce()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error("Error generation nonce:", err)
|
t.Error("Error generation nonce:", err)
|
||||||
}
|
}
|
146
internal/config.go
Normal file
146
internal/config.go
Normal file
@ -0,0 +1,146 @@
|
|||||||
|
package tfa
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jessevdk/go-flags"
|
||||||
|
"github.com/thomseddon/traefik-forward-auth/internal/provider"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
LogLevel string `long:"log-level" default:"warn" description:"Log level: trace, debug, info, warn, error, fatal, panic"`
|
||||||
|
LogFormat string `long:"log-format" default:"text" description:"Log format: text, json, pretty"`
|
||||||
|
|
||||||
|
AuthHost string `long:"auth-host" description:"Host for central auth login"`
|
||||||
|
ConfigFile string `long:"config-file" description:"Config File"`
|
||||||
|
CookieDomains CookieDomains `long:"cookie-domains" description:"Comma separated list of cookie domains"`
|
||||||
|
CookieInsecure bool `long:"cookie-insecure" description:"Use secure cookies"`
|
||||||
|
CookieName string `long:"cookie-name" default:"_forward_auth" description:"Cookie Name"`
|
||||||
|
CSRFCookieName string `long:"csrf-cookie-name" default:"_forward_auth_csrf" description:"CSRF Cookie Name"`
|
||||||
|
DefaultAction string `long:"default-action" default:"allow" description:"Default Action"`
|
||||||
|
Domains CommaSeparatedList `long:"domains" description:"Comma separated list of email domains to allow"`
|
||||||
|
LifetimeString int `long:"lifetime" default:"43200" description:"Lifetime in seconds"`
|
||||||
|
Path string `long:"path" default:"_oauth" description:"Callback URL Path"`
|
||||||
|
SecretString string `long:"secret" description:"*Secret used for signing (required)"`
|
||||||
|
Whitelist CommaSeparatedList `long:"whitelist" description:"Comma separated list of email addresses to allow"`
|
||||||
|
|
||||||
|
Providers provider.Providers
|
||||||
|
Rules []Rule `long:"rule"`
|
||||||
|
|
||||||
|
Secret []byte
|
||||||
|
Lifetime time.Duration
|
||||||
|
|
||||||
|
Prompt string `long:"prompt" description:"DEPRECATED - Use providers.google.prompt"`
|
||||||
|
// TODO: Need to mimick the default behaviour of bool flags
|
||||||
|
CookieSecure string `long:"cookie-secure" default:"true" description:"DEPRECATED - Use \"cookie-insecure\""`
|
||||||
|
|
||||||
|
flags []string
|
||||||
|
usingToml bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type CommaSeparatedList []string
|
||||||
|
|
||||||
|
func (c *CommaSeparatedList) UnmarshalFlag(value string) error {
|
||||||
|
*c = strings.Split(value, ",")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CommaSeparatedList) MarshalFlag() (string, error) {
|
||||||
|
return strings.Join(*c, ","), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type Rule struct {
|
||||||
|
Action string
|
||||||
|
Rule string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Rule) UnmarshalFlag(value string) error {
|
||||||
|
// Format is "action:rule"
|
||||||
|
parts := strings.SplitN(value, ":", 2)
|
||||||
|
|
||||||
|
if len(parts) != 2 {
|
||||||
|
return errors.New("Invalid rule format, should be \"action:rule\"")
|
||||||
|
}
|
||||||
|
|
||||||
|
if parts[0] != "auth" && parts[0] != "allow" {
|
||||||
|
return errors.New("Invalid rule action, must be \"auth\" or \"allow\"")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse rule
|
||||||
|
*r = Rule{
|
||||||
|
Action: parts[0],
|
||||||
|
Rule: parts[1],
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Rule) MarshalFlag() (string, error) {
|
||||||
|
// TODO: format correctly
|
||||||
|
return fmt.Sprintf("%+v", *r), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var config Config
|
||||||
|
|
||||||
|
// TODO:
|
||||||
|
// - parse ini
|
||||||
|
// - parse env vars
|
||||||
|
// - parse env var file
|
||||||
|
// - support multiple config files
|
||||||
|
// - maintain backwards compat
|
||||||
|
|
||||||
|
func NewGlobalConfig() Config {
|
||||||
|
return NewGlobalConfigWithArgs(os.Args[1:])
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewGlobalConfigWithArgs(args []string) Config {
|
||||||
|
config = Config{}
|
||||||
|
|
||||||
|
config.parseFlags(args)
|
||||||
|
|
||||||
|
// Struct defaults
|
||||||
|
config.Providers.Google.Build()
|
||||||
|
|
||||||
|
// Transformations
|
||||||
|
config.Path = fmt.Sprintf("/%s", config.Path)
|
||||||
|
config.Secret = []byte(config.SecretString)
|
||||||
|
config.Lifetime = time.Second * time.Duration(config.LifetimeString)
|
||||||
|
|
||||||
|
// TODO: Backwards compatability
|
||||||
|
// "secret" used to be "cookie-secret"
|
||||||
|
|
||||||
|
return config
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) parseFlags(args []string) {
|
||||||
|
if _, err := flags.ParseArgs(c, args); err != nil {
|
||||||
|
flagsErr, ok := err.(*flags.Error)
|
||||||
|
if ok && flagsErr.Type == flags.ErrHelp {
|
||||||
|
os.Exit(0)
|
||||||
|
} else {
|
||||||
|
fmt.Printf("%+v", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) Checks() {
|
||||||
|
// Check for show stopper errors
|
||||||
|
if len(c.Secret) == 0 {
|
||||||
|
log.Fatal("\"secret\" option must be set.")
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.Providers.Google.ClientId == "" || c.Providers.Google.ClientSecret == "" {
|
||||||
|
log.Fatal("google.providers.client-id, google.providers.client-secret must be set")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c Config) Serialise() string {
|
||||||
|
jsonConf, _ := json.Marshal(c)
|
||||||
|
return string(jsonConf)
|
||||||
|
}
|
81
internal/config_test.go
Normal file
81
internal/config_test.go
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
package tfa
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
// "github.com/jessevdk/go-flags"
|
||||||
|
// "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Tests
|
||||||
|
*/
|
||||||
|
|
||||||
|
func TestConfigDefaults(t *testing.T) {
|
||||||
|
// Check defaults
|
||||||
|
c := NewGlobalConfigWithArgs([]string{})
|
||||||
|
|
||||||
|
if c.LogLevel != "warn" {
|
||||||
|
t.Error("LogLevel default should be warn, got", c.LogLevel)
|
||||||
|
}
|
||||||
|
if c.LogFormat != "text" {
|
||||||
|
t.Error("LogFormat default should be text, got", c.LogFormat)
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.AuthHost != "" {
|
||||||
|
t.Error("AuthHost default should be empty, got", c.AuthHost)
|
||||||
|
}
|
||||||
|
if c.ConfigFile != "" {
|
||||||
|
t.Error("ConfigFile default should be empty, got", c.ConfigFile)
|
||||||
|
}
|
||||||
|
if len(c.CookieDomains) != 0 {
|
||||||
|
t.Error("CookieDomains default should be empty, got", c.CookieDomains)
|
||||||
|
}
|
||||||
|
if c.CookieInsecure != false {
|
||||||
|
t.Error("CookieInsecure default should be false, got", c.CookieInsecure)
|
||||||
|
}
|
||||||
|
if c.CookieName != "_forward_auth" {
|
||||||
|
t.Error("CookieName default should be _forward_auth, got", c.CookieName)
|
||||||
|
}
|
||||||
|
if c.CSRFCookieName != "_forward_auth_csrf" {
|
||||||
|
t.Error("CSRFCookieName default should be _forward_auth_csrf, got", c.CSRFCookieName)
|
||||||
|
}
|
||||||
|
if c.DefaultAction != "allow" {
|
||||||
|
t.Error("DefaultAction default should be allow, got", c.DefaultAction)
|
||||||
|
}
|
||||||
|
if len(c.Domains) != 0 {
|
||||||
|
t.Error("Domain default should be empty, got", c.Domains)
|
||||||
|
}
|
||||||
|
if c.Lifetime != time.Second*time.Duration(43200) {
|
||||||
|
t.Error("Lifetime default should be 43200, got", c.Lifetime)
|
||||||
|
}
|
||||||
|
if c.Path != "/_oauth" {
|
||||||
|
t.Error("Path default should be /_oauth, got", c.Path)
|
||||||
|
}
|
||||||
|
if len(c.Whitelist) != 0 {
|
||||||
|
t.Error("Whitelist default should be empty, got", c.Whitelist)
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.Providers.Google.Prompt != "" {
|
||||||
|
t.Error("Providers.Google.Prompt default should be empty, got", c.Providers.Google.Prompt)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated options
|
||||||
|
if c.CookieSecure != "true" {
|
||||||
|
t.Error("CookieSecure default should be true, got", c.CookieSecure)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// func TestConfigToml(t *testing.T) {
|
||||||
|
// logrus.SetLevel(logrus.DebugLevel)
|
||||||
|
// flag.CommandLine = flag.NewFlagSet("tfa-test", flag.ContinueOnError)
|
||||||
|
|
||||||
|
// flags := []string{
|
||||||
|
// "-config=../test/config.toml",
|
||||||
|
// }
|
||||||
|
// c := NewDefaultConfigWithFlags(flags)
|
||||||
|
|
||||||
|
// if c == nil {
|
||||||
|
// t.Error(c)
|
||||||
|
// }
|
||||||
|
// }
|
@ -1,4 +1,4 @@
|
|||||||
package main
|
package tfa
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
@ -6,13 +6,15 @@ import (
|
|||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewLogger() logrus.FieldLogger {
|
var log logrus.FieldLogger
|
||||||
|
|
||||||
|
func NewDefaultLogger() logrus.FieldLogger {
|
||||||
// Setup logger
|
// Setup logger
|
||||||
log := logrus.StandardLogger()
|
log = logrus.StandardLogger()
|
||||||
logrus.SetOutput(os.Stdout)
|
logrus.SetOutput(os.Stdout)
|
||||||
|
|
||||||
// Set logger format
|
// Set logger format
|
||||||
switch *config.LogFormat {
|
switch config.LogFormat {
|
||||||
case "pretty":
|
case "pretty":
|
||||||
break
|
break
|
||||||
case "json":
|
case "json":
|
||||||
@ -26,7 +28,7 @@ func NewLogger() logrus.FieldLogger {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Set logger level
|
// Set logger level
|
||||||
switch *config.LogLevel {
|
switch config.LogLevel {
|
||||||
case "trace":
|
case "trace":
|
||||||
logrus.SetLevel(logrus.TraceLevel)
|
logrus.SetLevel(logrus.TraceLevel)
|
||||||
case "debug":
|
case "debug":
|
@ -8,16 +8,34 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Google struct {
|
type Google struct {
|
||||||
ClientId string
|
ClientId string `long:"providers.google.client-id" description:"Client ID"`
|
||||||
ClientSecret string `json:"-"`
|
ClientSecret string `long:"providers.google.client-secret" description:"Client Secret" json:"-"`
|
||||||
Scope string
|
Scope string
|
||||||
Prompt string
|
Prompt string `long:"providers.google.prompt" description:"Space separated list of OpenID prompt options"`
|
||||||
|
|
||||||
LoginURL *url.URL
|
LoginURL *url.URL
|
||||||
TokenURL *url.URL
|
TokenURL *url.URL
|
||||||
UserURL *url.URL
|
UserURL *url.URL
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (g *Google) Build() {
|
||||||
|
g.LoginURL = &url.URL{
|
||||||
|
Scheme: "https",
|
||||||
|
Host: "accounts.google.com",
|
||||||
|
Path: "/o/oauth2/auth",
|
||||||
|
}
|
||||||
|
g.TokenURL = &url.URL{
|
||||||
|
Scheme: "https",
|
||||||
|
Host: "www.googleapis.com",
|
||||||
|
Path: "/oauth2/v3/token",
|
||||||
|
}
|
||||||
|
g.UserURL = &url.URL{
|
||||||
|
Scheme: "https",
|
||||||
|
Host: "www.googleapis.com",
|
||||||
|
Path: "/oauth2/v2/userinfo",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (g *Google) GetLoginURL(redirectUri, state string) string {
|
func (g *Google) GetLoginURL(redirectUri, state string) string {
|
||||||
q := url.Values{}
|
q := url.Values{}
|
||||||
q.Set("client_id", g.ClientId)
|
q.Set("client_id", g.ClientId)
|
@ -1,7 +1,7 @@
|
|||||||
package provider
|
package provider
|
||||||
|
|
||||||
type Providers struct {
|
type Providers struct {
|
||||||
Google Google
|
Google Google `group:"Google Provider"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Token struct {
|
type Token struct {
|
@ -1,16 +1,15 @@
|
|||||||
package main
|
package tfa
|
||||||
|
|
||||||
import (
|
import (
|
||||||
// "fmt"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/containous/traefik/pkg/rules"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Server struct {
|
type Server struct {
|
||||||
mux *mux.Router
|
router *rules.Router
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServer() *Server {
|
func NewServer() *Server {
|
||||||
@ -20,21 +19,26 @@ func NewServer() *Server {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) buildRoutes() {
|
func (s *Server) buildRoutes() {
|
||||||
s.mux = mux.NewRouter()
|
var err error
|
||||||
|
s.router, err = rules.NewRouter()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
// Let's build a server
|
// Let's build a router
|
||||||
for _, rules := range config.Rules {
|
for _, rule := range config.Rules {
|
||||||
// fmt.Printf("Rule: %s\n", name)
|
if rule.Action == "allow" {
|
||||||
for _, match := range rules.Match {
|
s.router.AddRoute(rule.Rule, 1, s.AllowHandler())
|
||||||
s.attachHandler(&match, rules.Action)
|
} else {
|
||||||
|
s.router.AddRoute(rule.Rule, 1, s.AuthHandler())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add callback handler
|
// Add callback handler
|
||||||
s.mux.Handle(config.Path, s.AuthCallbackHandler())
|
s.router.Handle(config.Path, s.AuthCallbackHandler())
|
||||||
|
|
||||||
// Add a default handler
|
// Add a default handler
|
||||||
s.mux.NewRoute().Handler(s.AuthHandler())
|
s.router.NewRoute().Handler(s.AuthHandler())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) RootHandler(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) RootHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
@ -42,7 +46,7 @@ func (s *Server) RootHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
r.URL, _ = url.Parse(r.Header.Get("X-Forwarded-Uri"))
|
r.URL, _ = url.Parse(r.Header.Get("X-Forwarded-Uri"))
|
||||||
|
|
||||||
// Pass to mux
|
// Pass to mux
|
||||||
s.mux.ServeHTTP(w, r)
|
s.router.ServeHTTP(w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handler that allows requests
|
// Handler that allows requests
|
||||||
@ -63,7 +67,7 @@ func (s *Server) AuthHandler() http.HandlerFunc {
|
|||||||
c, err := r.Cookie(config.CookieName)
|
c, err := r.Cookie(config.CookieName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Error indicates no cookie, generate nonce
|
// Error indicates no cookie, generate nonce
|
||||||
err, nonce := fw.Nonce()
|
err, nonce := Nonce()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("Error generating nonce, %v", err)
|
logger.Errorf("Error generating nonce, %v", err)
|
||||||
http.Error(w, "Service unavailable", 503)
|
http.Error(w, "Service unavailable", 503)
|
||||||
@ -71,17 +75,17 @@ func (s *Server) AuthHandler() http.HandlerFunc {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Set the CSRF cookie
|
// Set the CSRF cookie
|
||||||
http.SetCookie(w, fw.MakeCSRFCookie(r, nonce))
|
http.SetCookie(w, MakeCSRFCookie(r, nonce))
|
||||||
logger.Debug("Set CSRF cookie and redirecting to google login")
|
logger.Debug("Set CSRF cookie and redirecting to google login")
|
||||||
|
|
||||||
// Forward them on
|
// Forward them on
|
||||||
http.Redirect(w, r, fw.GetLoginURL(r, nonce), http.StatusTemporaryRedirect)
|
http.Redirect(w, r, GetLoginURL(r, nonce), http.StatusTemporaryRedirect)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate cookie
|
// Validate cookie
|
||||||
valid, email, err := fw.ValidateCookie(r, c)
|
valid, email, err := ValidateCookie(r, c)
|
||||||
if !valid {
|
if !valid {
|
||||||
logger.Errorf("Invalid cookie: %v", err)
|
logger.Errorf("Invalid cookie: %v", err)
|
||||||
http.Error(w, "Not authorized", 401)
|
http.Error(w, "Not authorized", 401)
|
||||||
@ -89,7 +93,7 @@ func (s *Server) AuthHandler() http.HandlerFunc {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Validate user
|
// Validate user
|
||||||
valid = fw.ValidateEmail(email)
|
valid = ValidateEmail(email)
|
||||||
if !valid {
|
if !valid {
|
||||||
logger.WithFields(logrus.Fields{
|
logger.WithFields(logrus.Fields{
|
||||||
"email": email,
|
"email": email,
|
||||||
@ -120,7 +124,7 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Validate state
|
// Validate state
|
||||||
valid, redirect, err := fw.ValidateCSRFCookie(r, c)
|
valid, redirect, err := ValidateCSRFCookie(r, c)
|
||||||
if !valid {
|
if !valid {
|
||||||
logger.Warnf("Error validating csrf cookie: %v", err)
|
logger.Warnf("Error validating csrf cookie: %v", err)
|
||||||
http.Error(w, "Not authorized", 401)
|
http.Error(w, "Not authorized", 401)
|
||||||
@ -128,10 +132,10 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Clear CSRF cookie
|
// Clear CSRF cookie
|
||||||
http.SetCookie(w, fw.ClearCSRFCookie(r))
|
http.SetCookie(w, ClearCSRFCookie(r))
|
||||||
|
|
||||||
// Exchange code for token
|
// Exchange code for token
|
||||||
token, err := fw.ExchangeCode(r)
|
token, err := ExchangeCode(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("Code exchange failed with: %v", err)
|
logger.Errorf("Code exchange failed with: %v", err)
|
||||||
http.Error(w, "Service unavailable", 503)
|
http.Error(w, "Service unavailable", 503)
|
||||||
@ -139,14 +143,14 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get user
|
// Get user
|
||||||
user, err := fw.GetUser(token)
|
user, err := GetUser(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("Error getting user: %s", err)
|
logger.Errorf("Error getting user: %s", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate cookie
|
// Generate cookie
|
||||||
http.SetCookie(w, fw.MakeCookie(r, user.Email))
|
http.SetCookie(w, MakeCookie(r, user.Email))
|
||||||
logger.WithFields(logrus.Fields{
|
logger.WithFields(logrus.Fields{
|
||||||
"user": user.Email,
|
"user": user.Email,
|
||||||
}).Infof("Generated auth cookie")
|
}).Infof("Generated auth cookie")
|
||||||
@ -156,35 +160,6 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build a handler for a given matcher
|
|
||||||
func (s *Server) attachHandler(m *Match, action string) {
|
|
||||||
// Build a new route matcher
|
|
||||||
route := s.mux.NewRoute()
|
|
||||||
|
|
||||||
for _, host := range m.Host {
|
|
||||||
route.Host(host)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, pathPrefix := range m.PathPrefix {
|
|
||||||
route.PathPrefix(pathPrefix)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, header := range m.Header {
|
|
||||||
if len(header) != 2 {
|
|
||||||
panic("todo")
|
|
||||||
}
|
|
||||||
|
|
||||||
route.Headers(header[0], header[1])
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add handler to new route
|
|
||||||
if action == "allow" {
|
|
||||||
route.Handler(s.AllowHandler())
|
|
||||||
} else {
|
|
||||||
route.Handler(s.AuthHandler())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) logger(r *http.Request, msg string) *logrus.Entry {
|
func (s *Server) logger(r *http.Request, msg string) *logrus.Entry {
|
||||||
// Create logger
|
// Create logger
|
||||||
logger := log.WithFields(logrus.Fields{
|
logger := log.WithFields(logrus.Fields{
|
@ -1,4 +1,4 @@
|
|||||||
package main
|
package tfa
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -11,7 +11,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/thomseddon/traefik-forward-auth/provider"
|
"github.com/thomseddon/traefik-forward-auth/internal/provider"
|
||||||
)
|
)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -19,12 +19,8 @@ import (
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
fw = &ForwardAuth{}
|
config.LogLevel = "panic"
|
||||||
config = NewConfig()
|
log = NewDefaultLogger()
|
||||||
|
|
||||||
logLevel := "panic"
|
|
||||||
config.LogLevel = &logLevel
|
|
||||||
log = NewLogger()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -102,7 +98,7 @@ func qsDiff(one, two url.Values) {
|
|||||||
func TestServerHandler(t *testing.T) {
|
func TestServerHandler(t *testing.T) {
|
||||||
server := NewServer()
|
server := NewServer()
|
||||||
|
|
||||||
config = &Config{
|
config = Config{
|
||||||
Path: "/_oauth",
|
Path: "/_oauth",
|
||||||
CookieName: "cookie_test",
|
CookieName: "cookie_test",
|
||||||
Lifetime: time.Second * time.Duration(10),
|
Lifetime: time.Second * time.Duration(10),
|
||||||
@ -134,7 +130,7 @@ func TestServerHandler(t *testing.T) {
|
|||||||
// Should catch invalid cookie
|
// Should catch invalid cookie
|
||||||
req = newHttpRequest("/foo")
|
req = newHttpRequest("/foo")
|
||||||
|
|
||||||
c := fw.MakeCookie(req, "test@example.com")
|
c := MakeCookie(req, "test@example.com")
|
||||||
parts := strings.Split(c.Value, "|")
|
parts := strings.Split(c.Value, "|")
|
||||||
c.Value = fmt.Sprintf("bad|%s|%s", parts[1], parts[2])
|
c.Value = fmt.Sprintf("bad|%s|%s", parts[1], parts[2])
|
||||||
|
|
||||||
@ -146,8 +142,8 @@ func TestServerHandler(t *testing.T) {
|
|||||||
// Should validate email
|
// Should validate email
|
||||||
req = newHttpRequest("/foo")
|
req = newHttpRequest("/foo")
|
||||||
|
|
||||||
c = fw.MakeCookie(req, "test@example.com")
|
c = MakeCookie(req, "test@example.com")
|
||||||
config.Domain = []string{"test.com"}
|
config.Domains = []string{"test.com"}
|
||||||
|
|
||||||
res, _ = httpRequest(server, req, c)
|
res, _ = httpRequest(server, req, c)
|
||||||
if res.StatusCode != 401 {
|
if res.StatusCode != 401 {
|
||||||
@ -157,8 +153,8 @@ func TestServerHandler(t *testing.T) {
|
|||||||
// Should allow valid request email
|
// Should allow valid request email
|
||||||
req = newHttpRequest("/foo")
|
req = newHttpRequest("/foo")
|
||||||
|
|
||||||
c = fw.MakeCookie(req, "test@example.com")
|
c = MakeCookie(req, "test@example.com")
|
||||||
config.Domain = []string{}
|
config.Domains = []string{}
|
||||||
|
|
||||||
res, _ = httpRequest(server, req, c)
|
res, _ = httpRequest(server, req, c)
|
||||||
if res.StatusCode != 200 {
|
if res.StatusCode != 200 {
|
||||||
@ -176,7 +172,7 @@ func TestServerHandler(t *testing.T) {
|
|||||||
|
|
||||||
func TestServerAuthCallback(t *testing.T) {
|
func TestServerAuthCallback(t *testing.T) {
|
||||||
server := NewServer()
|
server := NewServer()
|
||||||
config = &Config{
|
config = Config{
|
||||||
Path: "/_oauth",
|
Path: "/_oauth",
|
||||||
CookieName: "cookie_test",
|
CookieName: "cookie_test",
|
||||||
Lifetime: time.Second * time.Duration(10),
|
Lifetime: time.Second * time.Duration(10),
|
||||||
@ -218,7 +214,7 @@ func TestServerAuthCallback(t *testing.T) {
|
|||||||
|
|
||||||
// Should catch invalid csrf cookie
|
// Should catch invalid csrf cookie
|
||||||
req = newHttpRequest("/_oauth?state=12345678901234567890123456789012:http://redirect")
|
req = newHttpRequest("/_oauth?state=12345678901234567890123456789012:http://redirect")
|
||||||
c := fw.MakeCSRFCookie(req, "nononononononononononononononono")
|
c := MakeCSRFCookie(req, "nononononononononononononononono")
|
||||||
res, _ = httpRequest(server, req, c)
|
res, _ = httpRequest(server, req, c)
|
||||||
if res.StatusCode != 401 {
|
if res.StatusCode != 401 {
|
||||||
t.Error("Auth callback with invalid cookie shound't be authorised, got:", res.StatusCode)
|
t.Error("Auth callback with invalid cookie shound't be authorised, got:", res.StatusCode)
|
||||||
@ -226,7 +222,7 @@ func TestServerAuthCallback(t *testing.T) {
|
|||||||
|
|
||||||
// Should redirect valid request
|
// Should redirect valid request
|
||||||
req = newHttpRequest("/_oauth?state=12345678901234567890123456789012:http://redirect")
|
req = newHttpRequest("/_oauth?state=12345678901234567890123456789012:http://redirect")
|
||||||
c = fw.MakeCSRFCookie(req, "12345678901234567890123456789012")
|
c = MakeCSRFCookie(req, "12345678901234567890123456789012")
|
||||||
res, _ = httpRequest(server, req, c)
|
res, _ = httpRequest(server, req, c)
|
||||||
if res.StatusCode != 307 {
|
if res.StatusCode != 307 {
|
||||||
t.Error("Valid callback should be allowed, got:", res.StatusCode)
|
t.Error("Valid callback should be allowed, got:", res.StatusCode)
|
||||||
@ -237,9 +233,9 @@ func TestServerAuthCallback(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServerMatcherPathPrefix(t *testing.T) {
|
func TestServerRoutePathPrefix(t *testing.T) {
|
||||||
server := NewServer()
|
server := NewServer()
|
||||||
config = &Config{
|
config = Config{
|
||||||
Path: "/_oauth",
|
Path: "/_oauth",
|
||||||
CookieName: "cookie_test",
|
CookieName: "cookie_test",
|
||||||
Lifetime: time.Second * time.Duration(10),
|
Lifetime: time.Second * time.Duration(10),
|
||||||
@ -255,21 +251,17 @@ func TestServerMatcherPathPrefix(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Rules: map[string]Rules{
|
Rules: []Rule{
|
||||||
"rule1": {
|
{
|
||||||
Action: "allow",
|
Action: "allow",
|
||||||
Match: []Match{
|
Rule: "PathPrefix(`/api`)",
|
||||||
{
|
|
||||||
PathPrefix: []string{"/api"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should allow /api request
|
// Should allow /api request
|
||||||
req := newHttpRequest("/api")
|
req := newHttpRequest("/api")
|
||||||
c := fw.MakeCookie(req, "test@example.com")
|
c := MakeCookie(req, "test@example.com")
|
||||||
res, _ := httpRequest(server, req, c)
|
res, _ := httpRequest(server, req, c)
|
||||||
if res.StatusCode != 200 {
|
if res.StatusCode != 200 {
|
||||||
t.Error("Request matching allowed rule should be allowed, got:", res.StatusCode)
|
t.Error("Request matching allowed rule should be allowed, got:", res.StatusCode)
|
40
main.go
40
main.go
@ -1,40 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Vars
|
|
||||||
var fw *ForwardAuth
|
|
||||||
var log logrus.FieldLogger
|
|
||||||
var config *Config
|
|
||||||
|
|
||||||
// Main
|
|
||||||
func main() {
|
|
||||||
// Parse config
|
|
||||||
config = NewParsedConfig()
|
|
||||||
|
|
||||||
// Setup logger
|
|
||||||
log = NewLogger()
|
|
||||||
|
|
||||||
// Perform config checks
|
|
||||||
config.Checks()
|
|
||||||
|
|
||||||
// Build forward auth handler
|
|
||||||
fw = NewForwardAuth()
|
|
||||||
|
|
||||||
// Build server
|
|
||||||
server := NewServer()
|
|
||||||
|
|
||||||
// Attach router to default server
|
|
||||||
http.HandleFunc("/", server.RootHandler)
|
|
||||||
|
|
||||||
// Start
|
|
||||||
jsonConf, _ := json.Marshal(config)
|
|
||||||
log.Debugf("Starting with options: %s", string(jsonConf))
|
|
||||||
log.Info("Listening on :4181")
|
|
||||||
log.Info(http.ListenAndServe(":4181", nil))
|
|
||||||
}
|
|
Loading…
x
Reference in New Issue
Block a user