Improve internal function docs
This commit is contained in:
parent
f7a94e7db9
commit
7381450015
@ -17,6 +17,7 @@ import (
|
|||||||
|
|
||||||
// Request Validation
|
// Request Validation
|
||||||
|
|
||||||
|
// ValidateCookie verifies that a cookie matches the expected format of:
|
||||||
// Cookie = hash(secret, cookie domain, email, expires)|expires|email
|
// Cookie = hash(secret, cookie domain, email, expires)|expires|email
|
||||||
func ValidateCookie(r *http.Request, c *http.Cookie) (string, error) {
|
func ValidateCookie(r *http.Request, c *http.Cookie) (string, error) {
|
||||||
parts := strings.Split(c.Value, "|")
|
parts := strings.Split(c.Value, "|")
|
||||||
@ -55,7 +56,7 @@ func ValidateCookie(r *http.Request, c *http.Cookie) (string, error) {
|
|||||||
return parts[2], nil
|
return parts[2], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate email
|
// ValidateEmail verifies that an email is permitted by the current config
|
||||||
func ValidateEmail(email string) bool {
|
func ValidateEmail(email string) bool {
|
||||||
found := false
|
found := false
|
||||||
if len(config.Whitelist) > 0 {
|
if len(config.Whitelist) > 0 {
|
||||||
@ -126,7 +127,7 @@ func useAuthDomain(r *http.Request) (bool, string) {
|
|||||||
|
|
||||||
// Cookie methods
|
// Cookie methods
|
||||||
|
|
||||||
// Create an auth cookie
|
// MakeCookie creates an auth cookie
|
||||||
func MakeCookie(r *http.Request, email string) *http.Cookie {
|
func MakeCookie(r *http.Request, email string) *http.Cookie {
|
||||||
expires := cookieExpiry()
|
expires := cookieExpiry()
|
||||||
mac := cookieSignature(r, email, fmt.Sprintf("%d", expires.Unix()))
|
mac := cookieSignature(r, email, fmt.Sprintf("%d", expires.Unix()))
|
||||||
@ -143,7 +144,7 @@ func MakeCookie(r *http.Request, email string) *http.Cookie {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make a CSRF cookie (used during login only)
|
// MakeCSRFCookie makes a csrf cookie (used during login only)
|
||||||
func MakeCSRFCookie(r *http.Request, nonce string) *http.Cookie {
|
func MakeCSRFCookie(r *http.Request, nonce string) *http.Cookie {
|
||||||
return &http.Cookie{
|
return &http.Cookie{
|
||||||
Name: config.CSRFCookieName,
|
Name: config.CSRFCookieName,
|
||||||
@ -156,7 +157,7 @@ func MakeCSRFCookie(r *http.Request, nonce string) *http.Cookie {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a cookie to clear csrf cookie
|
// ClearCSRFCookie makes an expired csrf cookie to clear csrf cookie
|
||||||
func ClearCSRFCookie(r *http.Request) *http.Cookie {
|
func ClearCSRFCookie(r *http.Request) *http.Cookie {
|
||||||
return &http.Cookie{
|
return &http.Cookie{
|
||||||
Name: config.CSRFCookieName,
|
Name: config.CSRFCookieName,
|
||||||
@ -169,7 +170,7 @@ func ClearCSRFCookie(r *http.Request) *http.Cookie {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate the csrf cookie against state
|
// ValidateCSRFCookie validates the csrf cookie against state
|
||||||
func ValidateCSRFCookie(r *http.Request, c *http.Cookie) (valid bool, provider string, redirect string, err error) {
|
func ValidateCSRFCookie(r *http.Request, c *http.Cookie) (valid bool, provider string, redirect string, err error) {
|
||||||
state := r.URL.Query().Get("state")
|
state := r.URL.Query().Get("state")
|
||||||
|
|
||||||
@ -197,12 +198,13 @@ func ValidateCSRFCookie(r *http.Request, c *http.Cookie) (valid bool, provider s
|
|||||||
return true, params[:split], params[split+1:], nil
|
return true, params[:split], params[split+1:], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MakeState generates a state value
|
||||||
func MakeState(r *http.Request, p provider.Provider, nonce string) string {
|
func MakeState(r *http.Request, p provider.Provider, nonce string) string {
|
||||||
return fmt.Sprintf("%s:%s:%s", nonce, p.Name(), returnUrl(r))
|
return fmt.Sprintf("%s:%s:%s", nonce, p.Name(), returnUrl(r))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Nonce generates a random nonce
|
||||||
func Nonce() (error, string) {
|
func Nonce() (error, string) {
|
||||||
// Make nonce
|
|
||||||
nonce := make([]byte, 16)
|
nonce := make([]byte, 16)
|
||||||
_, err := rand.Read(nonce)
|
_, err := rand.Read(nonce)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -263,9 +265,7 @@ func cookieExpiry() time.Time {
|
|||||||
return time.Now().Local().Add(config.Lifetime)
|
return time.Now().Local().Add(config.Lifetime)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cookie Domain
|
// CookieDomain holds cookie domain info
|
||||||
|
|
||||||
// Cookie Domain
|
|
||||||
type CookieDomain struct {
|
type CookieDomain struct {
|
||||||
Domain string
|
Domain string
|
||||||
DomainLen int
|
DomainLen int
|
||||||
@ -273,6 +273,7 @@ type CookieDomain struct {
|
|||||||
SubDomainLen int
|
SubDomainLen int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewCookieDomain creates a new CookieDomain from the given domain string
|
||||||
func NewCookieDomain(domain string) *CookieDomain {
|
func NewCookieDomain(domain string) *CookieDomain {
|
||||||
return &CookieDomain{
|
return &CookieDomain{
|
||||||
Domain: domain,
|
Domain: domain,
|
||||||
@ -282,6 +283,7 @@ func NewCookieDomain(domain string) *CookieDomain {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Match checks if the given host matches this CookieDomain
|
||||||
func (c *CookieDomain) Match(host string) bool {
|
func (c *CookieDomain) Match(host string) bool {
|
||||||
// Exact domain match?
|
// Exact domain match?
|
||||||
if host == c.Domain {
|
if host == c.Domain {
|
||||||
@ -296,19 +298,22 @@ func (c *CookieDomain) Match(host string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UnmarshalFlag converts a string to a CookieDomain
|
||||||
func (c *CookieDomain) UnmarshalFlag(value string) error {
|
func (c *CookieDomain) UnmarshalFlag(value string) error {
|
||||||
*c = *NewCookieDomain(value)
|
*c = *NewCookieDomain(value)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MarshalFlag converts a CookieDomain to a string
|
||||||
func (c *CookieDomain) MarshalFlag() (string, error) {
|
func (c *CookieDomain) MarshalFlag() (string, error) {
|
||||||
return c.Domain, nil
|
return c.Domain, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Legacy support for comma separated list of cookie domains
|
// CookieDomains provides legacy sypport for comma separated list of cookie domains
|
||||||
|
|
||||||
type CookieDomains []CookieDomain
|
type CookieDomains []CookieDomain
|
||||||
|
|
||||||
|
// UnmarshalFlag converts a comma separated list of cookie domains to an array
|
||||||
|
// of CookieDomains
|
||||||
func (c *CookieDomains) UnmarshalFlag(value string) error {
|
func (c *CookieDomains) UnmarshalFlag(value string) error {
|
||||||
if len(value) > 0 {
|
if len(value) > 0 {
|
||||||
for _, d := range strings.Split(value, ",") {
|
for _, d := range strings.Split(value, ",") {
|
||||||
@ -319,6 +324,7 @@ func (c *CookieDomains) UnmarshalFlag(value string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MarshalFlag converts an array of CookieDomain to a comma seperated list
|
||||||
func (c *CookieDomains) MarshalFlag() (string, error) {
|
func (c *CookieDomains) MarshalFlag() (string, error) {
|
||||||
var domains []string
|
var domains []string
|
||||||
for _, d := range *c {
|
for _, d := range *c {
|
||||||
|
@ -19,6 +19,7 @@ import (
|
|||||||
|
|
||||||
var config *Config
|
var config *Config
|
||||||
|
|
||||||
|
// Config holds the runtime application config
|
||||||
type Config struct {
|
type Config struct {
|
||||||
LogLevel string `long:"log-level" env:"LOG_LEVEL" default:"warn" choice:"trace" choice:"debug" choice:"info" choice:"warn" choice:"error" choice:"fatal" choice:"panic" description:"Log level"`
|
LogLevel string `long:"log-level" env:"LOG_LEVEL" default:"warn" choice:"trace" choice:"debug" choice:"info" choice:"warn" choice:"error" choice:"fatal" choice:"panic" description:"Log level"`
|
||||||
LogFormat string `long:"log-format" env:"LOG_FORMAT" default:"text" choice:"text" choice:"json" choice:"pretty" description:"Log format"`
|
LogFormat string `long:"log-format" env:"LOG_FORMAT" default:"text" choice:"text" choice:"json" choice:"pretty" description:"Log format"`
|
||||||
@ -53,6 +54,7 @@ type Config struct {
|
|||||||
PromptLegacy string `long:"prompt" env:"PROMPT" description:"DEPRECATED - Use \"providers.google.prompt\""`
|
PromptLegacy string `long:"prompt" env:"PROMPT" description:"DEPRECATED - Use \"providers.google.prompt\""`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewGlobalConfig creates a new global config, parsed from command arguments
|
||||||
func NewGlobalConfig() *Config {
|
func NewGlobalConfig() *Config {
|
||||||
var err error
|
var err error
|
||||||
config, err = NewConfig(os.Args[1:])
|
config, err = NewConfig(os.Args[1:])
|
||||||
@ -66,6 +68,7 @@ func NewGlobalConfig() *Config {
|
|||||||
|
|
||||||
// TODO: move config parsing into new func "NewParsedConfig"
|
// TODO: move config parsing into new func "NewParsedConfig"
|
||||||
|
|
||||||
|
// NewConfig parses and validates provided configuration into a config object
|
||||||
func NewConfig(args []string) (*Config, error) {
|
func NewConfig(args []string) (*Config, error) {
|
||||||
c := &Config{
|
c := &Config{
|
||||||
Rules: map[string]*Rule{},
|
Rules: map[string]*Rule{},
|
||||||
@ -236,6 +239,7 @@ func convertLegacyToIni(name string) (io.Reader, error) {
|
|||||||
return bytes.NewReader(legacyFileFormat.ReplaceAll(b, []byte("$1=$2"))), nil
|
return bytes.NewReader(legacyFileFormat.ReplaceAll(b, []byte("$1=$2"))), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate validates a config object
|
||||||
func (c *Config) Validate() {
|
func (c *Config) Validate() {
|
||||||
// Check for show stopper errors
|
// Check for show stopper errors
|
||||||
if len(c.Secret) == 0 {
|
if len(c.Secret) == 0 {
|
||||||
@ -317,12 +321,14 @@ func (c *Config) setupProvider(name string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Rule holds defined rules
|
||||||
type Rule struct {
|
type Rule struct {
|
||||||
Action string
|
Action string
|
||||||
Rule string
|
Rule string
|
||||||
Provider string
|
Provider string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewRule creates a new rule object
|
||||||
func NewRule() *Rule {
|
func NewRule() *Rule {
|
||||||
return &Rule{
|
return &Rule{
|
||||||
Action: "auth",
|
Action: "auth",
|
||||||
@ -335,6 +341,7 @@ func (r *Rule) formattedRule() string {
|
|||||||
return strings.ReplaceAll(r.Rule, "Host(", "HostRegexp(")
|
return strings.ReplaceAll(r.Rule, "Host(", "HostRegexp(")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate validates a rule
|
||||||
func (r *Rule) Validate(c *Config) error {
|
func (r *Rule) Validate(c *Config) error {
|
||||||
if r.Action != "auth" && r.Action != "allow" {
|
if r.Action != "auth" && r.Action != "allow" {
|
||||||
return errors.New("invalid rule action, must be \"auth\" or \"allow\"")
|
return errors.New("invalid rule action, must be \"auth\" or \"allow\"")
|
||||||
@ -345,13 +352,16 @@ func (r *Rule) Validate(c *Config) error {
|
|||||||
|
|
||||||
// Legacy support for comma separated lists
|
// Legacy support for comma separated lists
|
||||||
|
|
||||||
|
// CommaSeparatedList provides legacy support for config values provided as csv
|
||||||
type CommaSeparatedList []string
|
type CommaSeparatedList []string
|
||||||
|
|
||||||
|
// UnmarshalFlag converts a comma separated list to an array
|
||||||
func (c *CommaSeparatedList) UnmarshalFlag(value string) error {
|
func (c *CommaSeparatedList) UnmarshalFlag(value string) error {
|
||||||
*c = append(*c, strings.Split(value, ",")...)
|
*c = append(*c, strings.Split(value, ",")...)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MarshalFlag converts an array back to a comma separated list
|
||||||
func (c *CommaSeparatedList) MarshalFlag() (string, error) {
|
func (c *CommaSeparatedList) MarshalFlag() (string, error) {
|
||||||
return strings.Join(*c, ","), nil
|
return strings.Join(*c, ","), nil
|
||||||
}
|
}
|
||||||
|
@ -8,6 +8,7 @@ import (
|
|||||||
|
|
||||||
var log *logrus.Logger
|
var log *logrus.Logger
|
||||||
|
|
||||||
|
// NewDefaultLogger creates a new logger based on the current configuration
|
||||||
func NewDefaultLogger() *logrus.Logger {
|
func NewDefaultLogger() *logrus.Logger {
|
||||||
// Setup logger
|
// Setup logger
|
||||||
log = logrus.StandardLogger()
|
log = logrus.StandardLogger()
|
||||||
|
@ -9,10 +9,12 @@ import (
|
|||||||
"github.com/thomseddon/traefik-forward-auth/internal/provider"
|
"github.com/thomseddon/traefik-forward-auth/internal/provider"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Server contains router and handler methods
|
||||||
type Server struct {
|
type Server struct {
|
||||||
router *rules.Router
|
router *rules.Router
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewServer creates a new server object and builds router
|
||||||
func NewServer() *Server {
|
func NewServer() *Server {
|
||||||
s := &Server{}
|
s := &Server{}
|
||||||
s.buildRoutes()
|
s.buildRoutes()
|
||||||
@ -47,6 +49,8 @@ func (s *Server) buildRoutes() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RootHandler Overwrites the request method, host and URL with those from the
|
||||||
|
// forwarded request so it's correctly routed by mux
|
||||||
func (s *Server) RootHandler(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) RootHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
// Modify request
|
// Modify request
|
||||||
r.Method = r.Header.Get("X-Forwarded-Method")
|
r.Method = r.Header.Get("X-Forwarded-Method")
|
||||||
@ -57,7 +61,7 @@ func (s *Server) RootHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
s.router.ServeHTTP(w, r)
|
s.router.ServeHTTP(w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handler that allows requests
|
// AllowHandler Allows requests
|
||||||
func (s *Server) AllowHandler(rule string) http.HandlerFunc {
|
func (s *Server) AllowHandler(rule string) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
s.logger(r, rule, "Allowing request")
|
s.logger(r, rule, "Allowing request")
|
||||||
@ -65,7 +69,7 @@ func (s *Server) AllowHandler(rule string) http.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Authenticate requests
|
// AuthHandler Authenticates requests
|
||||||
func (s *Server) AuthHandler(providerName, rule string) http.HandlerFunc {
|
func (s *Server) AuthHandler(providerName, rule string) http.HandlerFunc {
|
||||||
p, _ := config.GetConfiguredProvider(providerName)
|
p, _ := config.GetConfiguredProvider(providerName)
|
||||||
|
|
||||||
@ -110,7 +114,7 @@ func (s *Server) AuthHandler(providerName, rule string) http.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle auth callback
|
// AuthCallbackHandler Handles auth callback request
|
||||||
func (s *Server) AuthCallbackHandler() http.HandlerFunc {
|
func (s *Server) AuthCallbackHandler() http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
// Logging setup
|
// Logging setup
|
||||||
|
Loading…
x
Reference in New Issue
Block a user