diff --git a/internal/auth.go b/internal/auth.go index 1962e3c..b9b1907 100644 --- a/internal/auth.go +++ b/internal/auth.go @@ -17,6 +17,7 @@ import ( // Request Validation +// ValidateCookie verifies that a cookie matches the expected format of: // Cookie = hash(secret, cookie domain, email, expires)|expires|email func ValidateCookie(r *http.Request, c *http.Cookie) (string, error) { parts := strings.Split(c.Value, "|") @@ -55,7 +56,7 @@ func ValidateCookie(r *http.Request, c *http.Cookie) (string, error) { return parts[2], nil } -// Validate email +// ValidateEmail verifies that an email is permitted by the current config func ValidateEmail(email string) bool { found := false if len(config.Whitelist) > 0 { @@ -126,7 +127,7 @@ func useAuthDomain(r *http.Request) (bool, string) { // Cookie methods -// Create an auth cookie +// MakeCookie creates an auth cookie func MakeCookie(r *http.Request, email string) *http.Cookie { expires := cookieExpiry() mac := cookieSignature(r, email, fmt.Sprintf("%d", expires.Unix())) @@ -143,7 +144,7 @@ func MakeCookie(r *http.Request, email string) *http.Cookie { } } -// Make a CSRF cookie (used during login only) +// MakeCSRFCookie makes a csrf cookie (used during login only) func MakeCSRFCookie(r *http.Request, nonce string) *http.Cookie { return &http.Cookie{ Name: config.CSRFCookieName, @@ -156,7 +157,7 @@ func MakeCSRFCookie(r *http.Request, nonce string) *http.Cookie { } } -// Create a cookie to clear csrf cookie +// ClearCSRFCookie makes an expired csrf cookie to clear csrf cookie func ClearCSRFCookie(r *http.Request) *http.Cookie { return &http.Cookie{ Name: config.CSRFCookieName, @@ -169,7 +170,7 @@ func ClearCSRFCookie(r *http.Request) *http.Cookie { } } -// Validate the csrf cookie against state +// ValidateCSRFCookie validates the csrf cookie against state func ValidateCSRFCookie(r *http.Request, c *http.Cookie) (valid bool, provider string, redirect string, err error) { state := r.URL.Query().Get("state") @@ -197,12 +198,13 @@ func ValidateCSRFCookie(r *http.Request, c *http.Cookie) (valid bool, provider s return true, params[:split], params[split+1:], nil } +// MakeState generates a state value func MakeState(r *http.Request, p provider.Provider, nonce string) string { return fmt.Sprintf("%s:%s:%s", nonce, p.Name(), returnUrl(r)) } +// Nonce generates a random nonce func Nonce() (error, string) { - // Make nonce nonce := make([]byte, 16) _, err := rand.Read(nonce) if err != nil { @@ -263,9 +265,7 @@ func cookieExpiry() time.Time { return time.Now().Local().Add(config.Lifetime) } -// Cookie Domain - -// Cookie Domain +// CookieDomain holds cookie domain info type CookieDomain struct { Domain string DomainLen int @@ -273,6 +273,7 @@ type CookieDomain struct { SubDomainLen int } +// NewCookieDomain creates a new CookieDomain from the given domain string func NewCookieDomain(domain string) *CookieDomain { return &CookieDomain{ Domain: domain, @@ -282,6 +283,7 @@ func NewCookieDomain(domain string) *CookieDomain { } } +// Match checks if the given host matches this CookieDomain func (c *CookieDomain) Match(host string) bool { // Exact domain match? if host == c.Domain { @@ -296,19 +298,22 @@ func (c *CookieDomain) Match(host string) bool { return false } +// UnmarshalFlag converts a string to a CookieDomain func (c *CookieDomain) UnmarshalFlag(value string) error { *c = *NewCookieDomain(value) return nil } +// MarshalFlag converts a CookieDomain to a string func (c *CookieDomain) MarshalFlag() (string, error) { return c.Domain, nil } -// Legacy support for comma separated list of cookie domains - +// CookieDomains provides legacy sypport for comma separated list of cookie domains type CookieDomains []CookieDomain +// UnmarshalFlag converts a comma separated list of cookie domains to an array +// of CookieDomains func (c *CookieDomains) UnmarshalFlag(value string) error { if len(value) > 0 { for _, d := range strings.Split(value, ",") { @@ -319,6 +324,7 @@ func (c *CookieDomains) UnmarshalFlag(value string) error { return nil } +// MarshalFlag converts an array of CookieDomain to a comma seperated list func (c *CookieDomains) MarshalFlag() (string, error) { var domains []string for _, d := range *c { diff --git a/internal/config.go b/internal/config.go index c95ad8b..17850cd 100644 --- a/internal/config.go +++ b/internal/config.go @@ -19,6 +19,7 @@ import ( var config *Config +// Config holds the runtime application config type Config struct { LogLevel string `long:"log-level" env:"LOG_LEVEL" default:"warn" choice:"trace" choice:"debug" choice:"info" choice:"warn" choice:"error" choice:"fatal" choice:"panic" description:"Log level"` LogFormat string `long:"log-format" env:"LOG_FORMAT" default:"text" choice:"text" choice:"json" choice:"pretty" description:"Log format"` @@ -53,6 +54,7 @@ type Config struct { PromptLegacy string `long:"prompt" env:"PROMPT" description:"DEPRECATED - Use \"providers.google.prompt\""` } +// NewGlobalConfig creates a new global config, parsed from command arguments func NewGlobalConfig() *Config { var err error config, err = NewConfig(os.Args[1:]) @@ -66,6 +68,7 @@ func NewGlobalConfig() *Config { // TODO: move config parsing into new func "NewParsedConfig" +// NewConfig parses and validates provided configuration into a config object func NewConfig(args []string) (*Config, error) { c := &Config{ Rules: map[string]*Rule{}, @@ -236,6 +239,7 @@ func convertLegacyToIni(name string) (io.Reader, error) { return bytes.NewReader(legacyFileFormat.ReplaceAll(b, []byte("$1=$2"))), nil } +// Validate validates a config object func (c *Config) Validate() { // Check for show stopper errors if len(c.Secret) == 0 { @@ -317,12 +321,14 @@ func (c *Config) setupProvider(name string) error { return nil } +// Rule holds defined rules type Rule struct { Action string Rule string Provider string } +// NewRule creates a new rule object func NewRule() *Rule { return &Rule{ Action: "auth", @@ -335,6 +341,7 @@ func (r *Rule) formattedRule() string { return strings.ReplaceAll(r.Rule, "Host(", "HostRegexp(") } +// Validate validates a rule func (r *Rule) Validate(c *Config) error { if r.Action != "auth" && r.Action != "allow" { return errors.New("invalid rule action, must be \"auth\" or \"allow\"") @@ -345,13 +352,16 @@ func (r *Rule) Validate(c *Config) error { // Legacy support for comma separated lists +// CommaSeparatedList provides legacy support for config values provided as csv type CommaSeparatedList []string +// UnmarshalFlag converts a comma separated list to an array func (c *CommaSeparatedList) UnmarshalFlag(value string) error { *c = append(*c, strings.Split(value, ",")...) return nil } +// MarshalFlag converts an array back to a comma separated list func (c *CommaSeparatedList) MarshalFlag() (string, error) { return strings.Join(*c, ","), nil } diff --git a/internal/log.go b/internal/log.go index 10a10d4..7afb336 100644 --- a/internal/log.go +++ b/internal/log.go @@ -8,6 +8,7 @@ import ( var log *logrus.Logger +// NewDefaultLogger creates a new logger based on the current configuration func NewDefaultLogger() *logrus.Logger { // Setup logger log = logrus.StandardLogger() diff --git a/internal/server.go b/internal/server.go index 9027a47..be3363f 100644 --- a/internal/server.go +++ b/internal/server.go @@ -9,10 +9,12 @@ import ( "github.com/thomseddon/traefik-forward-auth/internal/provider" ) +// Server contains router and handler methods type Server struct { router *rules.Router } +// NewServer creates a new server object and builds router func NewServer() *Server { s := &Server{} s.buildRoutes() @@ -47,6 +49,8 @@ func (s *Server) buildRoutes() { } } +// RootHandler Overwrites the request method, host and URL with those from the +// forwarded request so it's correctly routed by mux func (s *Server) RootHandler(w http.ResponseWriter, r *http.Request) { // Modify request r.Method = r.Header.Get("X-Forwarded-Method") @@ -57,7 +61,7 @@ func (s *Server) RootHandler(w http.ResponseWriter, r *http.Request) { s.router.ServeHTTP(w, r) } -// Handler that allows requests +// AllowHandler Allows requests func (s *Server) AllowHandler(rule string) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { s.logger(r, rule, "Allowing request") @@ -65,7 +69,7 @@ func (s *Server) AllowHandler(rule string) http.HandlerFunc { } } -// Authenticate requests +// AuthHandler Authenticates requests func (s *Server) AuthHandler(providerName, rule string) http.HandlerFunc { p, _ := config.GetConfiguredProvider(providerName) @@ -110,7 +114,7 @@ func (s *Server) AuthHandler(providerName, rule string) http.HandlerFunc { } } -// Handle auth callback +// AuthCallbackHandler Handles auth callback request func (s *Server) AuthCallbackHandler() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // Logging setup