Add more v2 tests + fixes + improve legacy config parsing

This commit is contained in:
Thom Seddon
2019-04-18 15:07:39 +01:00
parent 5597b7268b
commit 6968f6181b
13 changed files with 538 additions and 407 deletions

View File

@ -1,10 +1,14 @@
package tfa
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"os"
"regexp"
"strconv"
"strings"
"time"
@ -16,48 +20,41 @@ import (
var config Config
type Config struct {
LogLevel string `long:"log-level" default:"warn" description:"Log level: trace, debug, info, warn, error, fatal, panic"`
LogFormat string `long:"log-format" default:"text" description:"Log format: text, json, pretty"`
LogLevel string `long:"log-level" env:"LOG_LEVEL" default:"warn" choice:"trace" choice:"debug" choice:"info" choice:"warn" choice:"error" choice:"fatal" choice:"panic" description:"Log level"`
LogFormat string `long:"log-format" env:"LOG_FORMAT" default:"text" choice:"text" choice:"json" choice:"pretty" description:"Log format"`
AuthHost string `long:"auth-host" description:"Host for central auth login"`
ConfigFile string `long:"config-file" description:"Config File"`
CookieDomains CookieDomains `long:"cookie-domains" description:"Comma separated list of cookie domains"`
CookieInsecure bool `long:"cookie-insecure" description:"Use secure cookies"`
CookieName string `long:"cookie-name" default:"_forward_auth" description:"Cookie Name"`
CSRFCookieName string `long:"csrf-cookie-name" default:"_forward_auth_csrf" description:"CSRF Cookie Name"`
DefaultAction string `long:"default-action" default:"allow" description:"Default Action"`
Domains CommaSeparatedList `long:"domains" description:"Comma separated list of email domains to allow"`
LifetimeString int `long:"lifetime" default:"43200" description:"Lifetime in seconds"`
Path string `long:"path" default:"_oauth" description:"Callback URL Path"`
SecretString string `long:"secret" description:"*Secret used for signing (required)"`
Whitelist CommaSeparatedList `long:"whitelist" description:"Comma separated list of email addresses to allow"`
AuthHost string `long:"auth-host" env:"AUTH_HOST" description:"Host for central auth login"`
Config func(s string) error `long:"config" env:"CONFIG" description:"Config file"`
CookieDomains CookieDomains `long:"cookie-domains" env:"COOKIE_DOMAINS" description:"Comma separated list of cookie domains"`
InsecureCookie bool `long:"insecure-cookie" env:"INSECURE_COOKIE" description:"Use insecure cookies"`
CookieName string `long:"cookie-name" env:"COOKIE_NAME" default:"_forward_auth" description:"Cookie Name"`
CSRFCookieName string `long:"csrf-cookie-name" env:"CSRF_COOKIE_NAME" default:"_forward_auth_csrf" description:"CSRF Cookie Name"`
DefaultAction string `long:"default-action" env:"DEFAULT_ACTION" default:"auth" choice:"auth" choice:"allow" description:"Default Action"`
Domains CommaSeparatedList `long:"domains" env:"DOMAINS" description:"Comma separated list of email domains to allow"`
LifetimeString int `long:"lifetime" env:"LIFETIME" default:"43200" description:"Lifetime in seconds"`
Path string `long:"url-path" env:"URL_PATH" default:"_oauth" description:"Callback URL Path"`
SecretString string `long:"secret" env:"SECRET" description:"*Secret used for signing (required)"`
Whitelist CommaSeparatedList `long:"whitelist" env:"WHITELIST" description:"Comma separated list of email addresses to allow"`
Providers provider.Providers
Rules map[string]*Rule `long:"rule"`
Providers provider.Providers `group:"providers" namespace:"providers" env-namespace:"PROVIDERS"`
Rules map[string]*Rule `long:"rules.<name>.<param>" description:"Rule definitions, see docs, param can be: \"action\", \"rule\""`
// Filled during transformations
Secret []byte
Lifetime time.Duration
Prompt string `long:"prompt" description:"DEPRECATED - Use providers.google.prompt"`
// TODO: Need to mimick the default behaviour of bool flags
CookieSecure string `long:"cookie-secure" default:"true" description:"DEPRECATED - Use \"cookie-insecure\""`
flags []string
usingToml bool
// Legacy
ClientIdLegacy string `long:"client-id" env:"CLIENT_ID" group:"DEPs" description:"DEPRECATED - Use \"providers.google.client-id\""`
ClientSecretLegacy string `long:"client-secret" env:"CLIENT_SECRET" description:"DEPRECATED - Use \"providers.google.client-id\""`
PromptLegacy string `long:"prompt" env:"PROMPT" description:"DEPRECATED - Use \"providers.google.prompt\""`
CookieSecureLegacy string `long:"cookie-secure" env:"COOKIE_SECURE" namespace:"DERPS" description:"DEPRECATED - Use \"insecure-cookie\""`
}
// TODO:
// - parse ini
// - parse env vars
// - parse env var file
// - support multiple config files
// - maintain backwards compat
func NewGlobalConfig() Config {
var err error
config, err = NewConfig(os.Args[1:])
if err != nil {
fmt.Printf("startup error: %+v", err)
fmt.Printf("%+v\n", err)
os.Exit(1)
}
@ -74,7 +71,28 @@ func NewConfig(args []string) (Config, error) {
return c, err
}
// Struct defaults
// TODO: as log flags have now been parsed maybe we should return here so
// any further errors can be logged via logrus instead of printed?
// Backwards compatability
if c.ClientIdLegacy != "" {
c.Providers.Google.ClientId = c.ClientIdLegacy
}
if c.ClientSecretLegacy != "" {
c.Providers.Google.ClientSecret = c.ClientSecretLegacy
}
if c.PromptLegacy != "" {
c.Providers.Google.Prompt = c.PromptLegacy
}
if c.CookieSecureLegacy != "" {
secure, err := strconv.ParseBool(c.CookieSecureLegacy)
if err != nil {
return c, err
}
c.InsecureCookie = !secure
}
// Provider defaults
c.Providers.Google.Build()
// Transformations
@ -82,25 +100,35 @@ func NewConfig(args []string) (Config, error) {
c.Secret = []byte(c.SecretString)
c.Lifetime = time.Second * time.Duration(c.LifetimeString)
// TODO: Backwards compatability
// "secret" used to be "cookie-secret"
return c, nil
}
func (c *Config) parseFlags(args []string) error {
parser := flags.NewParser(c, flags.Default)
parser.UnknownOptionHandler = c.parseUnknownFlag
p := flags.NewParser(c, flags.Default)
p.UnknownOptionHandler = c.parseUnknownFlag
_, err := parser.ParseArgs(args)
if err != nil {
flagsErr, ok := err.(*flags.Error)
if ok && flagsErr.Type == flags.ErrHelp {
// Library has just printed cli help
os.Exit(0)
} else {
return err
i := flags.NewIniParser(p)
c.Config = func(s string) error {
// Try parsing at as an ini
err := i.ParseFile(s)
// If it fails with a syntax error, try converting legacy to ini
if err != nil && strings.Contains(err.Error(), "malformed key=value") {
converted, convertErr := convertLegacyToIni(s)
if convertErr != nil {
// If conversion fails, return the original error
return err
}
return i.Parse(converted)
}
return err
}
_, err := p.ParseArgs(args)
if err != nil {
return handlFlagError(err)
}
return nil
@ -139,7 +167,7 @@ func (c *Config) parseUnknownFlag(option string, arg flags.SplitArgument, args [
}
// Add param value to rule
switch(parts[2]) {
switch parts[2] {
case "action":
rule.Action = val
case "rule":
@ -156,6 +184,27 @@ func (c *Config) parseUnknownFlag(option string, arg flags.SplitArgument, args [
return args, nil
}
func handlFlagError(err error) error {
flagsErr, ok := err.(*flags.Error)
if ok && flagsErr.Type == flags.ErrHelp {
// Library has just printed cli help
os.Exit(0)
}
return err
}
var legacyFileFormat = regexp.MustCompile(`^([a-z-]+) ([\w\W]+)$`)
func convertLegacyToIni(name string) (io.Reader, error) {
b, err := ioutil.ReadFile(name)
if err != nil {
return nil, err
}
return bytes.NewReader(legacyFileFormat.ReplaceAll(b, []byte("$1=$2"))), nil
}
func (c *Config) Validate() {
// Check for show stopper errors
if len(c.Secret) == 0 {
@ -185,7 +234,7 @@ type Rule struct {
func NewRule() *Rule {
return &Rule{
Action: "auth",
Action: "auth",
Provider: "google", // TODO: Use default provider
}
}
@ -201,32 +250,6 @@ func (r *Rule) Validate() {
}
}
func (r *Rule) UnmarshalFlag(value string) error {
// Format is "action:rule"
parts := strings.SplitN(value, ":", 2)
if len(parts) != 2 {
return errors.New("invalid rule format, should be \"action:rule\"")
}
if parts[0] != "auth" && parts[0] != "allow" {
return errors.New("invalid rule action, must be \"auth\" or \"allow\"")
}
// Parse rule
*r = Rule{
Action: parts[0],
Rule: parts[1],
}
return nil
}
func (r *Rule) MarshalFlag() (string, error) {
// TODO: format correctly
return fmt.Sprintf("%+v", *r), nil
}
type CommaSeparatedList []string
func (c *CommaSeparatedList) UnmarshalFlag(value string) error {