package tfa import ( "bytes" "encoding/json" "errors" "fmt" "io" "io/ioutil" "os" "regexp" "strconv" "strings" "time" "github.com/thomseddon/go-flags" "github.com/thomseddon/traefik-forward-auth/internal/provider" ) var config *Config // Config holds the runtime application config type Config struct { LogLevel string `long:"log-level" env:"LOG_LEVEL" default:"warn" choice:"trace" choice:"debug" choice:"info" choice:"warn" choice:"error" choice:"fatal" choice:"panic" description:"Log level"` LogFormat string `long:"log-format" env:"LOG_FORMAT" default:"text" choice:"text" choice:"json" choice:"pretty" description:"Log format"` AuthHost string `long:"auth-host" env:"AUTH_HOST" description:"Single host to use when returning from 3rd party auth"` Config func(s string) error `long:"config" env:"CONFIG" description:"Path to config file" json:"-"` CookieDomains []CookieDomain `long:"cookie-domain" env:"COOKIE_DOMAIN" env-delim:"," description:"Domain to set auth cookie on, can be set multiple times"` InsecureCookie bool `long:"insecure-cookie" env:"INSECURE_COOKIE" description:"Use insecure cookies"` CookieName string `long:"cookie-name" env:"COOKIE_NAME" default:"_forward_auth" description:"Cookie Name"` CSRFCookieName string `long:"csrf-cookie-name" env:"CSRF_COOKIE_NAME" default:"_forward_auth_csrf" description:"CSRF Cookie Name"` DefaultAction string `long:"default-action" env:"DEFAULT_ACTION" default:"auth" choice:"auth" choice:"allow" description:"Default action"` DefaultProvider string `long:"default-provider" env:"DEFAULT_PROVIDER" default:"google" choice:"google" choice:"oidc" choice:"generic-oauth" description:"Default provider"` Domains CommaSeparatedList `long:"domain" env:"DOMAIN" env-delim:"," description:"Only allow given email domains, can be set multiple times"` LifetimeString int `long:"lifetime" env:"LIFETIME" default:"43200" description:"Lifetime in seconds"` LogoutRedirect string `long:"logout-redirect" env:"LOGOUT_REDIRECT" description:"URL to redirect to following logout"` MatchWhitelistOrDomain bool `long:"match-whitelist-or-domain" env:"MATCH_WHITELIST_OR_DOMAIN" description:"Allow users that match *either* whitelist or domain (enabled by default in v3)"` Path string `long:"url-path" env:"URL_PATH" default:"/_oauth" description:"Callback URL Path"` SecretString string `long:"secret" env:"SECRET" description:"Secret used for signing (required)" json:"-"` Whitelist CommaSeparatedList `long:"whitelist" env:"WHITELIST" env-delim:"," description:"Only allow given email addresses, can be set multiple times"` Port int `long:"port" env:"PORT" default:"4181" description:"Port to listen on"` Providers provider.Providers `group:"providers" namespace:"providers" env-namespace:"PROVIDERS"` Rules map[string]*Rule `long:"rule.." description:"Rule definitions, param can be: \"action\", \"rule\" or \"provider\""` // Filled during transformations Secret []byte `json:"-"` Lifetime time.Duration // Authorization RequiredRole string `long:"required-role" env:"REQUIRED_ROLE" description:"Required role to verify authorization"` // Legacy CookieDomainsLegacy CookieDomains `long:"cookie-domains" env:"COOKIE_DOMAINS" description:"DEPRECATED - Use \"cookie-domain\""` CookieSecretLegacy string `long:"cookie-secret" env:"COOKIE_SECRET" description:"DEPRECATED - Use \"secret\"" json:"-"` CookieSecureLegacy string `long:"cookie-secure" env:"COOKIE_SECURE" description:"DEPRECATED - Use \"insecure-cookie\""` ClientIdLegacy string `long:"client-id" env:"CLIENT_ID" description:"DEPRECATED - Use \"providers.google.client-id\""` ClientSecretLegacy string `long:"client-secret" env:"CLIENT_SECRET" description:"DEPRECATED - Use \"providers.google.client-id\"" json:"-"` PromptLegacy string `long:"prompt" env:"PROMPT" description:"DEPRECATED - Use \"providers.google.prompt\""` } // NewGlobalConfig creates a new global config, parsed from command arguments func NewGlobalConfig() *Config { var err error config, err = NewConfig(os.Args[1:]) if err != nil { fmt.Printf("%+v\n", err) os.Exit(1) } return config } // TODO: move config parsing into new func "NewParsedConfig" // NewConfig parses and validates provided configuration into a config object func NewConfig(args []string) (*Config, error) { c := &Config{ Rules: map[string]*Rule{}, } err := c.parseFlags(args) if err != nil { return c, err } // TODO: as log flags have now been parsed maybe we should return here so // any further errors can be logged via logrus instead of printed? // TODO: Rename "Validate" method to "Setup" and move all below logic // Setup // Set default provider on any rules where it's not specified for _, rule := range c.Rules { if rule.Provider == "" { rule.Provider = c.DefaultProvider } } // Backwards compatability if c.CookieSecretLegacy != "" && c.SecretString == "" { fmt.Println("cookie-secret config option is deprecated, please use secret") c.SecretString = c.CookieSecretLegacy } if c.ClientIdLegacy != "" { c.Providers.Google.ClientID = c.ClientIdLegacy } if c.ClientSecretLegacy != "" { c.Providers.Google.ClientSecret = c.ClientSecretLegacy } if c.PromptLegacy != "" { fmt.Println("prompt config option is deprecated, please use providers.google.prompt") c.Providers.Google.Prompt = c.PromptLegacy } if c.CookieSecureLegacy != "" { fmt.Println("cookie-secure config option is deprecated, please use insecure-cookie") secure, err := strconv.ParseBool(c.CookieSecureLegacy) if err != nil { return c, err } c.InsecureCookie = !secure } if len(c.CookieDomainsLegacy) > 0 { fmt.Println("cookie-domains config option is deprecated, please use cookie-domain") c.CookieDomains = append(c.CookieDomains, c.CookieDomainsLegacy...) } // Transformations if len(c.Path) > 0 && c.Path[0] != '/' { c.Path = "/" + c.Path } c.Secret = []byte(c.SecretString) c.Lifetime = time.Second * time.Duration(c.LifetimeString) return c, nil } func (c *Config) parseFlags(args []string) error { p := flags.NewParser(c, flags.Default|flags.IniUnknownOptionHandler) p.UnknownOptionHandler = c.parseUnknownFlag i := flags.NewIniParser(p) c.Config = func(s string) error { // Try parsing at as an ini err := i.ParseFile(s) // If it fails with a syntax error, try converting legacy to ini if err != nil && strings.Contains(err.Error(), "malformed key=value") { converted, convertErr := convertLegacyToIni(s) if convertErr != nil { // If conversion fails, return the original error return err } fmt.Println("config format deprecated, please use ini format") return i.Parse(converted) } return err } _, err := p.ParseArgs(args) if err != nil { return handleFlagError(err) } return nil } func (c *Config) parseUnknownFlag(option string, arg flags.SplitArgument, args []string) ([]string, error) { // Parse rules in the format "rule.." parts := strings.Split(option, ".") if len(parts) == 3 && parts[0] == "rule" { // Ensure there is a name name := parts[1] if len(name) == 0 { return args, errors.New("route name is required") } // Get value, or pop the next arg val, ok := arg.Value() if !ok && len(args) > 1 { val = args[0] args = args[1:] } // Check value if len(val) == 0 { return args, errors.New("route param value is required") } // Unquote if required if val[0] == '"' { var err error val, err = strconv.Unquote(val) if err != nil { return args, err } } // Get or create rule rule, ok := c.Rules[name] if !ok { rule = NewRule() c.Rules[name] = rule } // Add param value to rule switch parts[2] { case "action": rule.Action = val case "rule": rule.Rule = val case "provider": rule.Provider = val case "whitelist": list := CommaSeparatedList{} list.UnmarshalFlag(val) rule.Whitelist = list case "domains": list := CommaSeparatedList{} list.UnmarshalFlag(val) rule.Domains = list default: return args, fmt.Errorf("invalid route param: %v", option) } } else { return args, fmt.Errorf("unknown flag: %v", option) } return args, nil } func handleFlagError(err error) error { flagsErr, ok := err.(*flags.Error) if ok && flagsErr.Type == flags.ErrHelp { // Library has just printed cli help os.Exit(0) } return err } var legacyFileFormat = regexp.MustCompile(`(?m)^([a-z-]+) (.*)$`) func convertLegacyToIni(name string) (io.Reader, error) { b, err := ioutil.ReadFile(name) if err != nil { return nil, err } return bytes.NewReader(legacyFileFormat.ReplaceAll(b, []byte("$1=$2"))), nil } // Validate validates a config object func (c *Config) Validate() { // Check for show stopper errors if len(c.Secret) == 0 { log.Fatal("\"secret\" option must be set") } // Setup default provider err := c.setupProvider(c.DefaultProvider) if err != nil { log.Fatal(err) } // Check rules (validates the rule and the rule provider) for _, rule := range c.Rules { err = rule.Validate(c) if err != nil { log.Fatal(err) } } } func (c Config) String() string { jsonConf, _ := json.Marshal(c) return string(jsonConf) } // GetProvider returns the provider of the given name func (c *Config) GetProvider(name string) (provider.Provider, error) { switch name { case "google": return &c.Providers.Google, nil case "oidc": return &c.Providers.OIDC, nil case "generic-oauth": return &c.Providers.GenericOAuth, nil } return nil, fmt.Errorf("Unknown provider: %s", name) } // GetConfiguredProvider returns the provider of the given name, if it has been // configured. Returns an error if the provider is unknown, or hasn't been configured func (c *Config) GetConfiguredProvider(name string) (provider.Provider, error) { // Check the provider has been configured if !c.providerConfigured(name) { return nil, fmt.Errorf("Unconfigured provider: %s", name) } return c.GetProvider(name) } func (c *Config) providerConfigured(name string) bool { // Check default provider if name == c.DefaultProvider { return true } // Check rule providers for _, rule := range c.Rules { if name == rule.Provider { return true } } return false } func (c *Config) setupProvider(name string) error { // Check provider exists p, err := c.GetProvider(name) if err != nil { return err } // Setup err = p.Setup(log) if err != nil { return err } return nil } // Rule holds defined rules type Rule struct { Action string Rule string Provider string Whitelist CommaSeparatedList Domains CommaSeparatedList } // NewRule creates a new rule object func NewRule() *Rule { return &Rule{ Action: "auth", } } func (r *Rule) formattedRule() string { // Traefik implements their own "Host" matcher and then offers "HostRegexp" // to invoke the mux "Host" matcher. This ensures the mux version is used return strings.ReplaceAll(r.Rule, "Host(", "HostRegexp(") } // Validate validates a rule func (r *Rule) Validate(c *Config) error { if r.Action != "auth" && r.Action != "allow" { return errors.New("invalid rule action, must be \"auth\" or \"allow\"") } return c.setupProvider(r.Provider) } // Legacy support for comma separated lists // CommaSeparatedList provides legacy support for config values provided as csv type CommaSeparatedList []string // UnmarshalFlag converts a comma separated list to an array func (c *CommaSeparatedList) UnmarshalFlag(value string) error { *c = append(*c, strings.Split(value, ",")...) return nil } // MarshalFlag converts an array back to a comma separated list func (c *CommaSeparatedList) MarshalFlag() (string, error) { return strings.Join(*c, ","), nil }