
The previous behaviour would ignore domains if the whitelist parameter was provided; however, if both parameters are provided, matching either is more likely the intent.
370 lines
11 KiB
Go
370 lines
11 KiB
Go
package tfa
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"io/ioutil"
|
|
"os"
|
|
"regexp"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/thomseddon/go-flags"
|
|
"github.com/thomseddon/traefik-forward-auth/internal/provider"
|
|
)
|
|
|
|
// config is the package-level configuration; NewGlobalConfig assigns it.
var config *Config
|
|
|
|
// Config holds the runtime application config
|
|
type Config struct {
|
|
LogLevel string `long:"log-level" env:"LOG_LEVEL" default:"warn" choice:"trace" choice:"debug" choice:"info" choice:"warn" choice:"error" choice:"fatal" choice:"panic" description:"Log level"`
|
|
LogFormat string `long:"log-format" env:"LOG_FORMAT" default:"text" choice:"text" choice:"json" choice:"pretty" description:"Log format"`
|
|
|
|
AuthHost string `long:"auth-host" env:"AUTH_HOST" description:"Single host to use when returning from 3rd party auth"`
|
|
Config func(s string) error `long:"config" env:"CONFIG" description:"Path to config file" json:"-"`
|
|
CookieDomains []CookieDomain `long:"cookie-domain" env:"COOKIE_DOMAIN" env-delim:"," description:"Domain to set auth cookie on, can be set multiple times"`
|
|
InsecureCookie bool `long:"insecure-cookie" env:"INSECURE_COOKIE" description:"Use insecure cookies"`
|
|
CookieName string `long:"cookie-name" env:"COOKIE_NAME" default:"_forward_auth" description:"Cookie Name"`
|
|
CSRFCookieName string `long:"csrf-cookie-name" env:"CSRF_COOKIE_NAME" default:"_forward_auth_csrf" description:"CSRF Cookie Name"`
|
|
DefaultAction string `long:"default-action" env:"DEFAULT_ACTION" default:"auth" choice:"auth" choice:"allow" description:"Default action"`
|
|
DefaultProvider string `long:"default-provider" env:"DEFAULT_PROVIDER" default:"google" choice:"google" choice:"oidc" description:"Default provider"`
|
|
Domains CommaSeparatedList `long:"domain" env:"DOMAIN" env-delim:"," description:"Only allow given email domains, can be set multiple times"`
|
|
LifetimeString int `long:"lifetime" env:"LIFETIME" default:"43200" description:"Lifetime in seconds"`
|
|
LogoutRedirect string `long:"logout-redirect" env:"LOGOUT_REDIRECT" description:"URL to redirect to following logout"`
|
|
MatchWhitelistOrDomain bool `long:"match-whitelist-or-domain" env:"MATCH_WHITELIST_OR_DOMAIN" description:"Allow users that match *either* whitelist or domain (enabled by default in v3)"`
|
|
Path string `long:"url-path" env:"URL_PATH" default:"/_oauth" description:"Callback URL Path"`
|
|
SecretString string `long:"secret" env:"SECRET" description:"Secret used for signing (required)" json:"-"`
|
|
Whitelist CommaSeparatedList `long:"whitelist" env:"WHITELIST" env-delim:"," description:"Only allow given email addresses, can be set multiple times"`
|
|
|
|
Providers provider.Providers `group:"providers" namespace:"providers" env-namespace:"PROVIDERS"`
|
|
Rules map[string]*Rule `long:"rule.<name>.<param>" description:"Rule definitions, param can be: \"action\", \"rule\" or \"provider\""`
|
|
|
|
// Filled during transformations
|
|
Secret []byte `json:"-"`
|
|
Lifetime time.Duration
|
|
|
|
// Legacy
|
|
CookieDomainsLegacy CookieDomains `long:"cookie-domains" env:"COOKIE_DOMAINS" description:"DEPRECATED - Use \"cookie-domain\""`
|
|
CookieSecretLegacy string `long:"cookie-secret" env:"COOKIE_SECRET" description:"DEPRECATED - Use \"secret\"" json:"-"`
|
|
CookieSecureLegacy string `long:"cookie-secure" env:"COOKIE_SECURE" description:"DEPRECATED - Use \"insecure-cookie\""`
|
|
ClientIdLegacy string `long:"client-id" env:"CLIENT_ID" description:"DEPRECATED - Use \"providers.google.client-id\""`
|
|
ClientSecretLegacy string `long:"client-secret" env:"CLIENT_SECRET" description:"DEPRECATED - Use \"providers.google.client-id\"" json:"-"`
|
|
PromptLegacy string `long:"prompt" env:"PROMPT" description:"DEPRECATED - Use \"providers.google.prompt\""`
|
|
}
|
|
|
|
// NewGlobalConfig creates a new global config, parsed from command arguments
|
|
func NewGlobalConfig() *Config {
|
|
var err error
|
|
config, err = NewConfig(os.Args[1:])
|
|
if err != nil {
|
|
fmt.Printf("%+v\n", err)
|
|
os.Exit(1)
|
|
}
|
|
|
|
return config
|
|
}
|
|
|
|
// TODO: move config parsing into new func "NewParsedConfig"
|
|
|
|
// NewConfig parses and validates provided configuration into a config object
|
|
func NewConfig(args []string) (*Config, error) {
|
|
c := &Config{
|
|
Rules: map[string]*Rule{},
|
|
}
|
|
|
|
err := c.parseFlags(args)
|
|
if err != nil {
|
|
return c, err
|
|
}
|
|
|
|
// TODO: as log flags have now been parsed maybe we should return here so
|
|
// any further errors can be logged via logrus instead of printed?
|
|
|
|
// TODO: Rename "Validate" method to "Setup" and move all below logic
|
|
|
|
// Setup
|
|
// Set default provider on any rules where it's not specified
|
|
for _, rule := range c.Rules {
|
|
if rule.Provider == "" {
|
|
rule.Provider = c.DefaultProvider
|
|
}
|
|
}
|
|
|
|
// Backwards compatability
|
|
if c.CookieSecretLegacy != "" && c.SecretString == "" {
|
|
fmt.Println("cookie-secret config option is deprecated, please use secret")
|
|
c.SecretString = c.CookieSecretLegacy
|
|
}
|
|
if c.ClientIdLegacy != "" {
|
|
c.Providers.Google.ClientID = c.ClientIdLegacy
|
|
}
|
|
if c.ClientSecretLegacy != "" {
|
|
c.Providers.Google.ClientSecret = c.ClientSecretLegacy
|
|
}
|
|
if c.PromptLegacy != "" {
|
|
fmt.Println("prompt config option is deprecated, please use providers.google.prompt")
|
|
c.Providers.Google.Prompt = c.PromptLegacy
|
|
}
|
|
if c.CookieSecureLegacy != "" {
|
|
fmt.Println("cookie-secure config option is deprecated, please use insecure-cookie")
|
|
secure, err := strconv.ParseBool(c.CookieSecureLegacy)
|
|
if err != nil {
|
|
return c, err
|
|
}
|
|
c.InsecureCookie = !secure
|
|
}
|
|
if len(c.CookieDomainsLegacy) > 0 {
|
|
fmt.Println("cookie-domains config option is deprecated, please use cookie-domain")
|
|
c.CookieDomains = append(c.CookieDomains, c.CookieDomainsLegacy...)
|
|
}
|
|
|
|
// Transformations
|
|
if len(c.Path) > 0 && c.Path[0] != '/' {
|
|
c.Path = "/" + c.Path
|
|
}
|
|
c.Secret = []byte(c.SecretString)
|
|
c.Lifetime = time.Second * time.Duration(c.LifetimeString)
|
|
|
|
return c, nil
|
|
}
|
|
|
|
func (c *Config) parseFlags(args []string) error {
|
|
p := flags.NewParser(c, flags.Default|flags.IniUnknownOptionHandler)
|
|
p.UnknownOptionHandler = c.parseUnknownFlag
|
|
|
|
i := flags.NewIniParser(p)
|
|
c.Config = func(s string) error {
|
|
// Try parsing at as an ini
|
|
err := i.ParseFile(s)
|
|
|
|
// If it fails with a syntax error, try converting legacy to ini
|
|
if err != nil && strings.Contains(err.Error(), "malformed key=value") {
|
|
converted, convertErr := convertLegacyToIni(s)
|
|
if convertErr != nil {
|
|
// If conversion fails, return the original error
|
|
return err
|
|
}
|
|
|
|
fmt.Println("config format deprecated, please use ini format")
|
|
return i.Parse(converted)
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
_, err := p.ParseArgs(args)
|
|
if err != nil {
|
|
return handleFlagError(err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *Config) parseUnknownFlag(option string, arg flags.SplitArgument, args []string) ([]string, error) {
|
|
// Parse rules in the format "rule.<name>.<param>"
|
|
parts := strings.Split(option, ".")
|
|
if len(parts) == 3 && parts[0] == "rule" {
|
|
// Ensure there is a name
|
|
name := parts[1]
|
|
if len(name) == 0 {
|
|
return args, errors.New("route name is required")
|
|
}
|
|
|
|
// Get value, or pop the next arg
|
|
val, ok := arg.Value()
|
|
if !ok && len(args) > 1 {
|
|
val = args[0]
|
|
args = args[1:]
|
|
}
|
|
|
|
// Check value
|
|
if len(val) == 0 {
|
|
return args, errors.New("route param value is required")
|
|
}
|
|
|
|
// Unquote if required
|
|
if val[0] == '"' {
|
|
var err error
|
|
val, err = strconv.Unquote(val)
|
|
if err != nil {
|
|
return args, err
|
|
}
|
|
}
|
|
|
|
// Get or create rule
|
|
rule, ok := c.Rules[name]
|
|
if !ok {
|
|
rule = NewRule()
|
|
c.Rules[name] = rule
|
|
}
|
|
|
|
// Add param value to rule
|
|
switch parts[2] {
|
|
case "action":
|
|
rule.Action = val
|
|
case "rule":
|
|
rule.Rule = val
|
|
case "provider":
|
|
rule.Provider = val
|
|
default:
|
|
return args, fmt.Errorf("invalid route param: %v", option)
|
|
}
|
|
} else {
|
|
return args, fmt.Errorf("unknown flag: %v", option)
|
|
}
|
|
|
|
return args, nil
|
|
}
|
|
|
|
func handleFlagError(err error) error {
|
|
flagsErr, ok := err.(*flags.Error)
|
|
if ok && flagsErr.Type == flags.ErrHelp {
|
|
// Library has just printed cli help
|
|
os.Exit(0)
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
// legacyFileFormat matches a legacy config line: a key made of lowercase
// letters and hyphens, a single space, then the rest of the line as the
// value. Multi-line mode makes ^ and $ anchor at every line.
var legacyFileFormat = regexp.MustCompile(`(?m)^([a-z-]+) (.*)$`)

// convertLegacyToIni reads the file called name and returns a reader over its
// contents with every legacy "key value" line rewritten as "key=value". Lines
// that do not match the legacy format are passed through untouched. A read
// failure is returned as-is, with a nil reader.
func convertLegacyToIni(name string) (io.Reader, error) {
	contents, err := ioutil.ReadFile(name)
	if err != nil {
		return nil, err
	}

	ini := legacyFileFormat.ReplaceAll(contents, []byte("$1=$2"))
	return bytes.NewReader(ini), nil
}
|
|
|
|
// Validate validates a config object
|
|
func (c *Config) Validate() {
|
|
// Check for show stopper errors
|
|
if len(c.Secret) == 0 {
|
|
log.Fatal("\"secret\" option must be set")
|
|
}
|
|
|
|
// Setup default provider
|
|
err := c.setupProvider(c.DefaultProvider)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
// Check rules (validates the rule and the rule provider)
|
|
for _, rule := range c.Rules {
|
|
err = rule.Validate(c)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c Config) String() string {
|
|
jsonConf, _ := json.Marshal(c)
|
|
return string(jsonConf)
|
|
}
|
|
|
|
// GetProvider returns the provider of the given name
|
|
func (c *Config) GetProvider(name string) (provider.Provider, error) {
|
|
switch name {
|
|
case "google":
|
|
return &c.Providers.Google, nil
|
|
case "oidc":
|
|
return &c.Providers.OIDC, nil
|
|
}
|
|
|
|
return nil, fmt.Errorf("Unknown provider: %s", name)
|
|
}
|
|
|
|
// GetConfiguredProvider returns the provider of the given name, if it has been
|
|
// configured. Returns an error if the provider is unknown, or hasn't been configured
|
|
func (c *Config) GetConfiguredProvider(name string) (provider.Provider, error) {
|
|
// Check the provider has been configured
|
|
if !c.providerConfigured(name) {
|
|
return nil, fmt.Errorf("Unconfigured provider: %s", name)
|
|
}
|
|
|
|
return c.GetProvider(name)
|
|
}
|
|
|
|
func (c *Config) providerConfigured(name string) bool {
|
|
// Check default provider
|
|
if name == c.DefaultProvider {
|
|
return true
|
|
}
|
|
|
|
// Check rule providers
|
|
for _, rule := range c.Rules {
|
|
if name == rule.Provider {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
func (c *Config) setupProvider(name string) error {
|
|
// Check provider exists
|
|
p, err := c.GetProvider(name)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Setup
|
|
err = p.Setup()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Rule holds a single rule definition.
type Rule struct {
	// Action must be "auth" or "allow" to pass Validate
	Action string
	// Rule is the rule expression; see formattedRule
	Rule string
	// Provider is the name of the provider used by this rule
	Provider string
}

// NewRule returns a rule whose action is "auth" and whose other fields are
// empty.
func NewRule() *Rule {
	r := new(Rule)
	r.Action = "auth"
	return r
}

// formattedRule returns the rule expression with every "Host(" rewritten to
// "HostRegexp(".
func (r *Rule) formattedRule() string {
	// Traefik implements their own "Host" matcher and then offers "HostRegexp"
	// to invoke the mux "Host" matcher. This ensures the mux version is used
	return strings.Replace(r.Rule, "Host(", "HostRegexp(", -1)
}
|
|
|
|
// Validate validates a rule
|
|
func (r *Rule) Validate(c *Config) error {
|
|
if r.Action != "auth" && r.Action != "allow" {
|
|
return errors.New("invalid rule action, must be \"auth\" or \"allow\"")
|
|
}
|
|
|
|
return c.setupProvider(r.Provider)
|
|
}
|
|
|
|
// Legacy support for comma separated lists
|
|
|
|
// CommaSeparatedList provides legacy support for config values provided as csv
type CommaSeparatedList []string

// UnmarshalFlag splits value on commas and appends the entries to the list,
// so repeated calls accumulate. Whitespace around each entry is removed so
// that "a, b" yields "a" and "b" rather than "a" and " b". Empty entries are
// kept, as before. It always returns nil.
func (c *CommaSeparatedList) UnmarshalFlag(value string) error {
	for _, entry := range strings.Split(value, ",") {
		*c = append(*c, strings.TrimSpace(entry))
	}
	return nil
}

// MarshalFlag converts the list back to a single comma separated string
func (c *CommaSeparatedList) MarshalFlag() (string, error) {
	return strings.Join(*c, ","), nil
}
|