Use new rule config + tidy ups
This commit is contained in:
parent
e057f2d63a
commit
5597b7268b
@ -14,8 +14,8 @@ func main() {
|
|||||||
// Setup logger
|
// Setup logger
|
||||||
log := internal.NewDefaultLogger()
|
log := internal.NewDefaultLogger()
|
||||||
|
|
||||||
// Perform config checks
|
// Perform config validation
|
||||||
config.Checks()
|
config.Validate()
|
||||||
|
|
||||||
// Build server
|
// Build server
|
||||||
server := internal.NewServer()
|
server := internal.NewServer()
|
||||||
@ -24,7 +24,7 @@ func main() {
|
|||||||
http.HandleFunc("/", server.RootHandler)
|
http.HandleFunc("/", server.RootHandler)
|
||||||
|
|
||||||
// Start
|
// Start
|
||||||
log.Debugf("Starting with options: %s", config.Serialise())
|
log.Debugf("Starting with options: %s", config)
|
||||||
log.Info("Listening on :4181")
|
log.Info("Listening on :4181")
|
||||||
log.Info(http.ListenAndServe(":4181", nil))
|
log.Info(http.ListenAndServe(":4181", nil))
|
||||||
}
|
}
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -12,6 +13,8 @@ import (
|
|||||||
"github.com/thomseddon/traefik-forward-auth/internal/provider"
|
"github.com/thomseddon/traefik-forward-auth/internal/provider"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var config Config
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
LogLevel string `long:"log-level" default:"warn" description:"Log level: trace, debug, info, warn, error, fatal, panic"`
|
LogLevel string `long:"log-level" default:"warn" description:"Log level: trace, debug, info, warn, error, fatal, panic"`
|
||||||
LogFormat string `long:"log-format" default:"text" description:"Log format: text, json, pretty"`
|
LogFormat string `long:"log-format" default:"text" description:"Log format: text, json, pretty"`
|
||||||
@ -30,7 +33,7 @@ type Config struct {
|
|||||||
Whitelist CommaSeparatedList `long:"whitelist" description:"Comma separated list of email addresses to allow"`
|
Whitelist CommaSeparatedList `long:"whitelist" description:"Comma separated list of email addresses to allow"`
|
||||||
|
|
||||||
Providers provider.Providers
|
Providers provider.Providers
|
||||||
Rules []Rule `long:"rule"`
|
Rules map[string]*Rule `long:"rule"`
|
||||||
|
|
||||||
Secret []byte
|
Secret []byte
|
||||||
Lifetime time.Duration
|
Lifetime time.Duration
|
||||||
@ -43,20 +46,159 @@ type Config struct {
|
|||||||
usingToml bool
|
usingToml bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type CommaSeparatedList []string
|
// TODO:
|
||||||
|
// - parse ini
|
||||||
|
// - parse env vars
|
||||||
|
// - parse env var file
|
||||||
|
// - support multiple config files
|
||||||
|
// - maintain backwards compat
|
||||||
|
|
||||||
|
func NewGlobalConfig() Config {
|
||||||
|
var err error
|
||||||
|
config, err = NewConfig(os.Args[1:])
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("startup error: %+v", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
return config
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConfig(args []string) (Config, error) {
|
||||||
|
c := Config{
|
||||||
|
Rules: map[string]*Rule{},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := c.parseFlags(args)
|
||||||
|
if err != nil {
|
||||||
|
return c, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Struct defaults
|
||||||
|
c.Providers.Google.Build()
|
||||||
|
|
||||||
|
// Transformations
|
||||||
|
c.Path = fmt.Sprintf("/%s", c.Path)
|
||||||
|
c.Secret = []byte(c.SecretString)
|
||||||
|
c.Lifetime = time.Second * time.Duration(c.LifetimeString)
|
||||||
|
|
||||||
|
// TODO: Backwards compatability
|
||||||
|
// "secret" used to be "cookie-secret"
|
||||||
|
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) parseFlags(args []string) error {
|
||||||
|
parser := flags.NewParser(c, flags.Default)
|
||||||
|
parser.UnknownOptionHandler = c.parseUnknownFlag
|
||||||
|
|
||||||
|
_, err := parser.ParseArgs(args)
|
||||||
|
if err != nil {
|
||||||
|
flagsErr, ok := err.(*flags.Error)
|
||||||
|
if ok && flagsErr.Type == flags.ErrHelp {
|
||||||
|
// Library has just printed cli help
|
||||||
|
os.Exit(0)
|
||||||
|
} else {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (c *CommaSeparatedList) UnmarshalFlag(value string) error {
|
|
||||||
*c = strings.Split(value, ",")
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *CommaSeparatedList) MarshalFlag() (string, error) {
|
func (c *Config) parseUnknownFlag(option string, arg flags.SplitArgument, args []string) ([]string, error) {
|
||||||
return strings.Join(*c, ","), nil
|
// Parse rules in the format "rule.<name>.<param>"
|
||||||
|
parts := strings.Split(option, ".")
|
||||||
|
if len(parts) == 3 && parts[0] == "rule" {
|
||||||
|
// Get or create rule
|
||||||
|
rule, ok := c.Rules[parts[1]]
|
||||||
|
if !ok {
|
||||||
|
rule = NewRule()
|
||||||
|
c.Rules[parts[1]] = rule
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get value, or pop the next arg
|
||||||
|
val, ok := arg.Value()
|
||||||
|
if !ok {
|
||||||
|
val = args[0]
|
||||||
|
args = args[1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check value
|
||||||
|
if len(val) == 0 {
|
||||||
|
return args, errors.New("route param value is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unquote if required
|
||||||
|
if val[0] == '"' {
|
||||||
|
var err error
|
||||||
|
val, err = strconv.Unquote(val)
|
||||||
|
if err != nil {
|
||||||
|
return args, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add param value to rule
|
||||||
|
switch(parts[2]) {
|
||||||
|
case "action":
|
||||||
|
rule.Action = val
|
||||||
|
case "rule":
|
||||||
|
rule.Rule = val
|
||||||
|
case "provider":
|
||||||
|
rule.Provider = val
|
||||||
|
default:
|
||||||
|
return args, fmt.Errorf("inavlid route param: %v", option)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return args, fmt.Errorf("unknown flag: %v", option)
|
||||||
|
}
|
||||||
|
|
||||||
|
return args, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) Validate() {
|
||||||
|
// Check for show stopper errors
|
||||||
|
if len(c.Secret) == 0 {
|
||||||
|
log.Fatal("\"secret\" option must be set.")
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.Providers.Google.ClientId == "" || c.Providers.Google.ClientSecret == "" {
|
||||||
|
log.Fatal("google.providers.client-id, google.providers.client-secret must be set")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check rules
|
||||||
|
for _, rule := range c.Rules {
|
||||||
|
rule.Validate()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c Config) String() string {
|
||||||
|
jsonConf, _ := json.Marshal(c)
|
||||||
|
return string(jsonConf)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Rule struct {
|
type Rule struct {
|
||||||
Action string
|
Action string
|
||||||
Rule string
|
Rule string
|
||||||
|
Provider string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRule() *Rule {
|
||||||
|
return &Rule{
|
||||||
|
Action: "auth",
|
||||||
|
Provider: "google", // TODO: Use default provider
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Rule) Validate() {
|
||||||
|
if r.Action != "auth" && r.Action != "allow" {
|
||||||
|
log.Fatal("invalid rule action, must be \"auth\" or \"allow\"")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Update with more provider support
|
||||||
|
if r.Provider != "google" {
|
||||||
|
log.Fatal("invalid rule provider, must be \"google\"")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Rule) UnmarshalFlag(value string) error {
|
func (r *Rule) UnmarshalFlag(value string) error {
|
||||||
@ -64,11 +206,11 @@ func (r *Rule) UnmarshalFlag(value string) error {
|
|||||||
parts := strings.SplitN(value, ":", 2)
|
parts := strings.SplitN(value, ":", 2)
|
||||||
|
|
||||||
if len(parts) != 2 {
|
if len(parts) != 2 {
|
||||||
return errors.New("Invalid rule format, should be \"action:rule\"")
|
return errors.New("invalid rule format, should be \"action:rule\"")
|
||||||
}
|
}
|
||||||
|
|
||||||
if parts[0] != "auth" && parts[0] != "allow" {
|
if parts[0] != "auth" && parts[0] != "allow" {
|
||||||
return errors.New("Invalid rule action, must be \"auth\" or \"allow\"")
|
return errors.New("invalid rule action, must be \"auth\" or \"allow\"")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse rule
|
// Parse rule
|
||||||
@ -85,62 +227,13 @@ func (r *Rule) MarshalFlag() (string, error) {
|
|||||||
return fmt.Sprintf("%+v", *r), nil
|
return fmt.Sprintf("%+v", *r), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var config Config
|
type CommaSeparatedList []string
|
||||||
|
|
||||||
// TODO:
|
func (c *CommaSeparatedList) UnmarshalFlag(value string) error {
|
||||||
// - parse ini
|
*c = strings.Split(value, ",")
|
||||||
// - parse env vars
|
return nil
|
||||||
// - parse env var file
|
|
||||||
// - support multiple config files
|
|
||||||
// - maintain backwards compat
|
|
||||||
|
|
||||||
func NewGlobalConfig() Config {
|
|
||||||
return NewGlobalConfigWithArgs(os.Args[1:])
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewGlobalConfigWithArgs(args []string) Config {
|
func (c *CommaSeparatedList) MarshalFlag() (string, error) {
|
||||||
config = Config{}
|
return strings.Join(*c, ","), nil
|
||||||
|
|
||||||
config.parseFlags(args)
|
|
||||||
|
|
||||||
// Struct defaults
|
|
||||||
config.Providers.Google.Build()
|
|
||||||
|
|
||||||
// Transformations
|
|
||||||
config.Path = fmt.Sprintf("/%s", config.Path)
|
|
||||||
config.Secret = []byte(config.SecretString)
|
|
||||||
config.Lifetime = time.Second * time.Duration(config.LifetimeString)
|
|
||||||
|
|
||||||
// TODO: Backwards compatability
|
|
||||||
// "secret" used to be "cookie-secret"
|
|
||||||
|
|
||||||
return config
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Config) parseFlags(args []string) {
|
|
||||||
if _, err := flags.ParseArgs(c, args); err != nil {
|
|
||||||
flagsErr, ok := err.(*flags.Error)
|
|
||||||
if ok && flagsErr.Type == flags.ErrHelp {
|
|
||||||
os.Exit(0)
|
|
||||||
} else {
|
|
||||||
fmt.Printf("%+v", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Config) Checks() {
|
|
||||||
// Check for show stopper errors
|
|
||||||
if len(c.Secret) == 0 {
|
|
||||||
log.Fatal("\"secret\" option must be set.")
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.Providers.Google.ClientId == "" || c.Providers.Google.ClientSecret == "" {
|
|
||||||
log.Fatal("google.providers.client-id, google.providers.client-secret must be set")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c Config) Serialise() string {
|
|
||||||
jsonConf, _ := json.Marshal(c)
|
|
||||||
return string(jsonConf)
|
|
||||||
}
|
}
|
||||||
|
@ -3,8 +3,6 @@ package tfa
|
|||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
// "github.com/jessevdk/go-flags"
|
|
||||||
// "github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -13,7 +11,10 @@ import (
|
|||||||
|
|
||||||
func TestConfigDefaults(t *testing.T) {
|
func TestConfigDefaults(t *testing.T) {
|
||||||
// Check defaults
|
// Check defaults
|
||||||
c := NewGlobalConfigWithArgs([]string{})
|
c, err := NewConfig([]string{})
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
if c.LogLevel != "warn" {
|
if c.LogLevel != "warn" {
|
||||||
t.Error("LogLevel default should be warn, got", c.LogLevel)
|
t.Error("LogLevel default should be warn, got", c.LogLevel)
|
||||||
@ -46,7 +47,7 @@ func TestConfigDefaults(t *testing.T) {
|
|||||||
if len(c.Domains) != 0 {
|
if len(c.Domains) != 0 {
|
||||||
t.Error("Domain default should be empty, got", c.Domains)
|
t.Error("Domain default should be empty, got", c.Domains)
|
||||||
}
|
}
|
||||||
if c.Lifetime != time.Second*time.Duration(43200) {
|
if c.Lifetime != time.Second * time.Duration(43200) {
|
||||||
t.Error("Lifetime default should be 43200, got", c.Lifetime)
|
t.Error("Lifetime default should be 43200, got", c.Lifetime)
|
||||||
}
|
}
|
||||||
if c.Path != "/_oauth" {
|
if c.Path != "/_oauth" {
|
||||||
@ -66,6 +67,70 @@ func TestConfigDefaults(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConfigParseFlags(t *testing.T) {
|
||||||
|
c, err := NewConfig([]string{
|
||||||
|
"--path=_oauthpath",
|
||||||
|
"--cookie-name", "\"cookiename\"",
|
||||||
|
"--rule.1.action=allow",
|
||||||
|
"--rule.1.rule=PathPrefix(`/one`)",
|
||||||
|
"--rule.two.action=auth",
|
||||||
|
"--rule.two.rule=\"Host(`two.com`) && Path(`/two`)\"",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check normal flags
|
||||||
|
if c.Path != "/_oauthpath" {
|
||||||
|
t.Error("Path default should be /_oauthpath, got", c.Path)
|
||||||
|
}
|
||||||
|
if c.CookieName != "cookiename" {
|
||||||
|
t.Error("CookieName default should be cookiename, got", c.CookieName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check rules
|
||||||
|
if len(c.Rules) != 2 {
|
||||||
|
t.Error("Should create 2 rules, got:", len(c.Rules))
|
||||||
|
}
|
||||||
|
|
||||||
|
// First rule
|
||||||
|
if rule, ok := c.Rules["1"]; !ok {
|
||||||
|
t.Error("Could not find rule key '1'")
|
||||||
|
} else {
|
||||||
|
if rule.Action != "allow" {
|
||||||
|
t.Error("First rule action should be allow, got:", rule.Action)
|
||||||
|
}
|
||||||
|
if rule.Rule != "PathPrefix(`/one`)" {
|
||||||
|
t.Error("First rule rule should be PathPrefix(`/one`), got:", rule.Rule)
|
||||||
|
}
|
||||||
|
if rule.Provider != "google" {
|
||||||
|
t.Error("First rule provider should be google, got:", rule.Provider)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second rule
|
||||||
|
if rule, ok := c.Rules["two"]; !ok {
|
||||||
|
t.Error("Could not find rule key '1'")
|
||||||
|
} else {
|
||||||
|
if rule.Action != "auth" {
|
||||||
|
t.Error("Second rule action should be auth, got:", rule.Action)
|
||||||
|
}
|
||||||
|
if rule.Rule != "Host(`two.com`) && Path(`/two`)" {
|
||||||
|
t.Error("Second rule rule should be Host(`two.com`) && Path(`/two`), got:", rule.Rule)
|
||||||
|
}
|
||||||
|
if rule.Provider != "google" {
|
||||||
|
t.Error("Second rule provider should be google, got:", rule.Provider)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// func TestConfigParseUnknownFlags(t *testing.T) {
|
||||||
|
// c := NewConfig([]string{
|
||||||
|
// "--unknown=_oauthpath",
|
||||||
|
// })
|
||||||
|
|
||||||
|
// }
|
||||||
|
|
||||||
// func TestConfigToml(t *testing.T) {
|
// func TestConfigToml(t *testing.T) {
|
||||||
// logrus.SetLevel(logrus.DebugLevel)
|
// logrus.SetLevel(logrus.DebugLevel)
|
||||||
// flag.CommandLine = flag.NewFlagSet("tfa-test", flag.ContinueOnError)
|
// flag.CommandLine = flag.NewFlagSet("tfa-test", flag.ContinueOnError)
|
||||||
|
@ -236,7 +236,6 @@ func TestServerAuthCallback(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestServerRoutePathPrefix(t *testing.T) {
|
func TestServerRoutePathPrefix(t *testing.T) {
|
||||||
server := NewServer()
|
|
||||||
config = Config{
|
config = Config{
|
||||||
Path: "/_oauth",
|
Path: "/_oauth",
|
||||||
CookieName: "cookie_test",
|
CookieName: "cookie_test",
|
||||||
@ -253,18 +252,25 @@ func TestServerRoutePathPrefix(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Rules: []Rule{
|
Rules: map[string]*Rule{
|
||||||
{
|
"web1": &Rule{
|
||||||
Action: "allow",
|
Action: "allow",
|
||||||
Rule: "PathPrefix(`/api`)",
|
Rule: "PathPrefix(`/api`)",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
server := NewServer()
|
||||||
|
|
||||||
|
// Should block any request
|
||||||
|
req := newHttpRequest("/random")
|
||||||
|
res, _ := httpRequest(server, req, nil)
|
||||||
|
if res.StatusCode != 307 {
|
||||||
|
t.Error("Request not matching any rule should require auth, got:", res.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
// Should allow /api request
|
// Should allow /api request
|
||||||
req := newHttpRequest("/api")
|
req = newHttpRequest("/api")
|
||||||
c := MakeCookie(req, "test@example.com")
|
res, _ = httpRequest(server, req, nil)
|
||||||
res, _ := httpRequest(server, req, c)
|
|
||||||
if res.StatusCode != 200 {
|
if res.StatusCode != 200 {
|
||||||
t.Error("Request matching allowed rule should be allowed, got:", res.StatusCode)
|
t.Error("Request matching allowed rule should be allowed, got:", res.StatusCode)
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user