Refactor logging

Fixes #18
This commit is contained in:
Thom Seddon 2019-01-22 12:40:14 +00:00
parent 1a3a099ac1
commit b3b31e2193
5 changed files with 103 additions and 33 deletions

5
Makefile Normal file
View File

@ -0,0 +1,5 @@
format:
gofmt -w -s *.go
.PHONY: format

View File

@ -36,6 +36,8 @@ The following configuration is supported:
|-lifetime|int|Session length in seconds (default 43200)| |-lifetime|int|Session length in seconds (default 43200)|
|-url-path|string|Callback URL (default "_oauth")| |-url-path|string|Callback URL (default "_oauth")|
|-prompt|string|Space separated list of [OpenID prompt options](https://developers.google.com/identity/protocols/OpenIDConnect#prompt)| |-prompt|string|Space separated list of [OpenID prompt options](https://developers.google.com/identity/protocols/OpenIDConnect#prompt)|
|-log-level|string|Log level: trace, debug, info, warn, error, fatal, panic (default "warn")|
|-log-format|string|Log format: text, json, pretty (default "text")|
Configuration can also be supplied as environment variables (use upper case and swap `-`'s for `_`'s e.g. `-client-id` becomes `CLIENT_ID`) Configuration can also be supplied as environment variables (use upper case and swap `-`'s for `_`'s e.g. `-client-id` becomes `CLIENT_ID`)

View File

@ -22,7 +22,7 @@ type ForwardAuth struct {
Secret []byte Secret []byte
ClientId string ClientId string
ClientSecret string ClientSecret string `json:"-"`
Scope string Scope string
LoginURL *url.URL LoginURL *url.URL

48
log.go Normal file
View File

@ -0,0 +1,48 @@
package main
import (
"os"
"github.com/sirupsen/logrus"
)
func CreateLogger(logLevel, logFormat string) logrus.FieldLogger {
// Setup logger
log := logrus.StandardLogger()
logrus.SetOutput(os.Stdout)
// Set logger format
switch logFormat {
case "pretty":
break
case "json":
logrus.SetFormatter(&logrus.JSONFormatter{})
// "text" is the default
default:
logrus.SetFormatter(&logrus.TextFormatter{
DisableColors: true,
FullTimestamp: true,
})
}
// Set logger level
switch logLevel {
case "trace":
logrus.SetLevel(logrus.TraceLevel)
case "debug":
logrus.SetLevel(logrus.DebugLevel)
case "info":
logrus.SetLevel(logrus.InfoLevel)
case "error":
logrus.SetLevel(logrus.ErrorLevel)
case "fatal":
logrus.SetLevel(logrus.FatalLevel)
case "panic":
logrus.SetLevel(logrus.PanicLevel)
// warn is the default
default:
logrus.SetLevel(logrus.WarnLevel)
}
return log
}

79
main.go
View File

@ -1,6 +1,7 @@
package main package main
import ( import (
"encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"net/url" "net/url"
@ -8,26 +9,37 @@ import (
"time" "time"
"github.com/namsral/flag" "github.com/namsral/flag"
"github.com/op/go-logging" "github.com/sirupsen/logrus"
) )
// Vars // Vars
var fw *ForwardAuth var fw *ForwardAuth
var log = logging.MustGetLogger("traefik-forward-auth") var log logrus.FieldLogger
// Primary handler // Primary handler
func handler(w http.ResponseWriter, r *http.Request) { func handler(w http.ResponseWriter, r *http.Request) {
logger := log
if logrus.GetLevel() >= logrus.DebugLevel {
logger = log.WithFields(logrus.Fields{
"RemoteAddr": r.RemoteAddr,
"X-Forwarded-Uri": r.Header.Get("X-Forwarded-Uri"),
})
}
logger.Debugf("Handling request")
// Parse uri // Parse uri
uri, err := url.Parse(r.Header.Get("X-Forwarded-Uri")) uri, err := url.Parse(r.Header.Get("X-Forwarded-Uri"))
if err != nil { if err != nil {
log.Error("Error parsing url") logger.Errorf("Error parsing X-Forwarded-Uri, %v", err)
http.Error(w, "Service unavailable", 503) http.Error(w, "Service unavailable", 503)
return return
} }
// Handle callback // Handle callback
if uri.Path == fw.Path { if uri.Path == fw.Path {
handleCallback(w, r, uri.Query()) logger.Debugf("Passing request to auth callback")
handleCallback(w, r, uri.Query(), logger)
return return
} }
@ -37,14 +49,14 @@ func handler(w http.ResponseWriter, r *http.Request) {
// Error indicates no cookie, generate nonce // Error indicates no cookie, generate nonce
err, nonce := fw.Nonce() err, nonce := fw.Nonce()
if err != nil { if err != nil {
log.Error("Error generating nonce") logger.Errorf("Error generating nonce, %v", err)
http.Error(w, "Service unavailable", 503) http.Error(w, "Service unavailable", 503)
return return
} }
// Set the CSRF cookie // Set the CSRF cookie
http.SetCookie(w, fw.MakeCSRFCookie(r, nonce)) http.SetCookie(w, fw.MakeCSRFCookie(r, nonce))
log.Debug("Set CSRF cookie and redirecting to google login") logger.Debug("Set CSRF cookie and redirecting to google login")
// Forward them on // Forward them on
http.Redirect(w, r, fw.GetLoginURL(r, nonce), http.StatusTemporaryRedirect) http.Redirect(w, r, fw.GetLoginURL(r, nonce), http.StatusTemporaryRedirect)
@ -55,7 +67,7 @@ func handler(w http.ResponseWriter, r *http.Request) {
// Validate cookie // Validate cookie
valid, email, err := fw.ValidateCookie(r, c) valid, email, err := fw.ValidateCookie(r, c)
if !valid { if !valid {
log.Debugf("Invalid cookie: %s", err) logger.Errorf("Invalid cookie: %v", err)
http.Error(w, "Not authorized", 401) http.Error(w, "Not authorized", 401)
return return
} }
@ -63,22 +75,26 @@ func handler(w http.ResponseWriter, r *http.Request) {
// Validate user // Validate user
valid = fw.ValidateEmail(email) valid = fw.ValidateEmail(email)
if !valid { if !valid {
log.Debugf("Invalid email: %s", email) logger.WithFields(logrus.Fields{
"email": email,
}).Errorf("Invalid email")
http.Error(w, "Not authorized", 401) http.Error(w, "Not authorized", 401)
return return
} }
// Valid request // Valid request
logger.Debugf("Allowing valid request ")
w.Header().Set("X-Forwarded-User", email) w.Header().Set("X-Forwarded-User", email)
w.WriteHeader(200) w.WriteHeader(200)
} }
// Authenticate user after they have come back from google // Authenticate user after they have come back from google
func handleCallback(w http.ResponseWriter, r *http.Request, qs url.Values) { func handleCallback(w http.ResponseWriter, r *http.Request, qs url.Values,
logger logrus.FieldLogger) {
// Check for CSRF cookie // Check for CSRF cookie
csrfCookie, err := r.Cookie(fw.CSRFCookieName) csrfCookie, err := r.Cookie(fw.CSRFCookieName)
if err != nil { if err != nil {
log.Debug("Missing csrf cookie") logger.Warn("Missing csrf cookie")
http.Error(w, "Not authorized", 401) http.Error(w, "Not authorized", 401)
return return
} }
@ -87,7 +103,10 @@ func handleCallback(w http.ResponseWriter, r *http.Request, qs url.Values) {
state := qs.Get("state") state := qs.Get("state")
valid, redirect, err := fw.ValidateCSRFCookie(csrfCookie, state) valid, redirect, err := fw.ValidateCSRFCookie(csrfCookie, state)
if !valid { if !valid {
log.Debugf("Invalid oauth state, expected '%s', got '%s'\n", csrfCookie.Value, state) logger.WithFields(logrus.Fields{
"csrf": csrfCookie.Value,
"state": state,
}).Warnf("CSRF cookie does not match state")
http.Error(w, "Not authorized", 401) http.Error(w, "Not authorized", 401)
return return
} }
@ -98,7 +117,7 @@ func handleCallback(w http.ResponseWriter, r *http.Request, qs url.Values) {
// Exchange code for token // Exchange code for token
token, err := fw.ExchangeCode(r, qs.Get("code")) token, err := fw.ExchangeCode(r, qs.Get("code"))
if err != nil { if err != nil {
log.Debugf("Code exchange failed with: %s\n", err) logger.Errorf("Code exchange failed with: %v", err)
http.Error(w, "Service unavailable", 503) http.Error(w, "Service unavailable", 503)
return return
} }
@ -106,13 +125,15 @@ func handleCallback(w http.ResponseWriter, r *http.Request, qs url.Values) {
// Get user // Get user
user, err := fw.GetUser(token) user, err := fw.GetUser(token)
if err != nil { if err != nil {
log.Debugf("Error getting user: %s\n", err) logger.Errorf("Error getting user: %s", err)
return return
} }
// Generate cookie // Generate cookie
http.SetCookie(w, fw.MakeCookie(r, user.Email)) http.SetCookie(w, fw.MakeCookie(r, user.Email))
log.Debugf("Generated auth cookie for %s\n", user.Email) logger.WithFields(logrus.Fields{
"user": user.Email,
}).Infof("Generated auth cookie")
// Redirect // Redirect
http.Redirect(w, r, redirect, http.StatusTemporaryRedirect) http.Redirect(w, r, redirect, http.StatusTemporaryRedirect)
@ -136,30 +157,22 @@ func main() {
domainList := flag.String("domain", "", "Comma separated list of email domains to allow") domainList := flag.String("domain", "", "Comma separated list of email domains to allow")
emailWhitelist := flag.String("whitelist", "", "Comma separated list of emails to allow") emailWhitelist := flag.String("whitelist", "", "Comma separated list of emails to allow")
prompt := flag.String("prompt", "", "Space separated list of OpenID prompt options") prompt := flag.String("prompt", "", "Space separated list of OpenID prompt options")
logLevel := flag.String("log-level", "warn", "Log level: trace, debug, info, warn, error, fatal, panic")
logFormat := flag.String("log-format", "text", "Log format: text, json, pretty")
flag.Parse() flag.Parse()
// Setup logger
log = CreateLogger(*logLevel, *logFormat)
// Backwards compatability // Backwards compatability
if *secret == "" && *cookieSecret != "" { if *secret == "" && *cookieSecret != "" {
*secret = *cookieSecret *secret = *cookieSecret
} }
// Check for show stopper errors // Check for show stopper errors
stop := false if *clientId == "" || *clientSecret == "" || *secret == "" {
if *clientId == "" { log.Fatal("client-id, client-secret and secret must all be set")
stop = true
log.Critical("client-id must be set")
}
if *clientSecret == "" {
stop = true
log.Critical("client-secret must be set")
}
if *secret == "" {
stop = true
log.Critical("secret must be set")
}
if stop {
return
} }
// Parse lists // Parse lists
@ -220,7 +233,9 @@ func main() {
// Attach handler // Attach handler
http.HandleFunc("/", handler) http.HandleFunc("/", handler)
log.Debugf("Starting with options: %#v", fw) // Start
log.Notice("Listening on :4181") jsonConf, _ := json.Marshal(fw)
log.Notice(http.ListenAndServe(":4181", nil)) log.Debugf("Starting with options: %s", string(jsonConf))
log.Info("Listening on :4181")
log.Info(http.ListenAndServe(":4181", nil))
} }