parent
1a3a099ac1
commit
b3b31e2193
@ -36,6 +36,8 @@ The following configuration is supported:
|
|||||||
|-lifetime|int|Session length in seconds (default 43200)|
|
|-lifetime|int|Session length in seconds (default 43200)|
|
||||||
|-url-path|string|Callback URL (default "_oauth")|
|
|-url-path|string|Callback URL (default "_oauth")|
|
||||||
|-prompt|string|Space separated list of [OpenID prompt options](https://developers.google.com/identity/protocols/OpenIDConnect#prompt)|
|
|-prompt|string|Space separated list of [OpenID prompt options](https://developers.google.com/identity/protocols/OpenIDConnect#prompt)|
|
||||||
|
|-log-level|string|Log level: trace, debug, info, warn, error, fatal, panic (default "warn")|
|
||||||
|
|-log-format|string|Log format: text, json, pretty (default "text")|
|
||||||
|
|
||||||
Configuration can also be supplied as environment variables (use upper case and swap `-`'s for `_`'s e.g. `-client-id` becomes `CLIENT_ID`)
|
Configuration can also be supplied as environment variables (use upper case and swap `-`'s for `_`'s e.g. `-client-id` becomes `CLIENT_ID`)
|
||||||
|
|
||||||
|
@ -22,7 +22,7 @@ type ForwardAuth struct {
|
|||||||
Secret []byte
|
Secret []byte
|
||||||
|
|
||||||
ClientId string
|
ClientId string
|
||||||
ClientSecret string
|
ClientSecret string `json:"-"`
|
||||||
Scope string
|
Scope string
|
||||||
|
|
||||||
LoginURL *url.URL
|
LoginURL *url.URL
|
||||||
|
48
log.go
Normal file
48
log.go
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
func CreateLogger(logLevel, logFormat string) logrus.FieldLogger {
|
||||||
|
// Setup logger
|
||||||
|
log := logrus.StandardLogger()
|
||||||
|
logrus.SetOutput(os.Stdout)
|
||||||
|
|
||||||
|
// Set logger format
|
||||||
|
switch logFormat {
|
||||||
|
case "pretty":
|
||||||
|
break
|
||||||
|
case "json":
|
||||||
|
logrus.SetFormatter(&logrus.JSONFormatter{})
|
||||||
|
// "text" is the default
|
||||||
|
default:
|
||||||
|
logrus.SetFormatter(&logrus.TextFormatter{
|
||||||
|
DisableColors: true,
|
||||||
|
FullTimestamp: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set logger level
|
||||||
|
switch logLevel {
|
||||||
|
case "trace":
|
||||||
|
logrus.SetLevel(logrus.TraceLevel)
|
||||||
|
case "debug":
|
||||||
|
logrus.SetLevel(logrus.DebugLevel)
|
||||||
|
case "info":
|
||||||
|
logrus.SetLevel(logrus.InfoLevel)
|
||||||
|
case "error":
|
||||||
|
logrus.SetLevel(logrus.ErrorLevel)
|
||||||
|
case "fatal":
|
||||||
|
logrus.SetLevel(logrus.FatalLevel)
|
||||||
|
case "panic":
|
||||||
|
logrus.SetLevel(logrus.PanicLevel)
|
||||||
|
// warn is the default
|
||||||
|
default:
|
||||||
|
logrus.SetLevel(logrus.WarnLevel)
|
||||||
|
}
|
||||||
|
|
||||||
|
return log
|
||||||
|
}
|
79
main.go
79
main.go
@ -1,6 +1,7 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
@ -8,26 +9,37 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/namsral/flag"
|
"github.com/namsral/flag"
|
||||||
"github.com/op/go-logging"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Vars
|
// Vars
|
||||||
var fw *ForwardAuth
|
var fw *ForwardAuth
|
||||||
var log = logging.MustGetLogger("traefik-forward-auth")
|
var log logrus.FieldLogger
|
||||||
|
|
||||||
// Primary handler
|
// Primary handler
|
||||||
func handler(w http.ResponseWriter, r *http.Request) {
|
func handler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
logger := log
|
||||||
|
if logrus.GetLevel() >= logrus.DebugLevel {
|
||||||
|
logger = log.WithFields(logrus.Fields{
|
||||||
|
"RemoteAddr": r.RemoteAddr,
|
||||||
|
"X-Forwarded-Uri": r.Header.Get("X-Forwarded-Uri"),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debugf("Handling request")
|
||||||
|
|
||||||
// Parse uri
|
// Parse uri
|
||||||
uri, err := url.Parse(r.Header.Get("X-Forwarded-Uri"))
|
uri, err := url.Parse(r.Header.Get("X-Forwarded-Uri"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("Error parsing url")
|
logger.Errorf("Error parsing X-Forwarded-Uri, %v", err)
|
||||||
http.Error(w, "Service unavailable", 503)
|
http.Error(w, "Service unavailable", 503)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle callback
|
// Handle callback
|
||||||
if uri.Path == fw.Path {
|
if uri.Path == fw.Path {
|
||||||
handleCallback(w, r, uri.Query())
|
logger.Debugf("Passing request to auth callback")
|
||||||
|
handleCallback(w, r, uri.Query(), logger)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -37,14 +49,14 @@ func handler(w http.ResponseWriter, r *http.Request) {
|
|||||||
// Error indicates no cookie, generate nonce
|
// Error indicates no cookie, generate nonce
|
||||||
err, nonce := fw.Nonce()
|
err, nonce := fw.Nonce()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("Error generating nonce")
|
logger.Errorf("Error generating nonce, %v", err)
|
||||||
http.Error(w, "Service unavailable", 503)
|
http.Error(w, "Service unavailable", 503)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set the CSRF cookie
|
// Set the CSRF cookie
|
||||||
http.SetCookie(w, fw.MakeCSRFCookie(r, nonce))
|
http.SetCookie(w, fw.MakeCSRFCookie(r, nonce))
|
||||||
log.Debug("Set CSRF cookie and redirecting to google login")
|
logger.Debug("Set CSRF cookie and redirecting to google login")
|
||||||
|
|
||||||
// Forward them on
|
// Forward them on
|
||||||
http.Redirect(w, r, fw.GetLoginURL(r, nonce), http.StatusTemporaryRedirect)
|
http.Redirect(w, r, fw.GetLoginURL(r, nonce), http.StatusTemporaryRedirect)
|
||||||
@ -55,7 +67,7 @@ func handler(w http.ResponseWriter, r *http.Request) {
|
|||||||
// Validate cookie
|
// Validate cookie
|
||||||
valid, email, err := fw.ValidateCookie(r, c)
|
valid, email, err := fw.ValidateCookie(r, c)
|
||||||
if !valid {
|
if !valid {
|
||||||
log.Debugf("Invalid cookie: %s", err)
|
logger.Errorf("Invalid cookie: %v", err)
|
||||||
http.Error(w, "Not authorized", 401)
|
http.Error(w, "Not authorized", 401)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -63,22 +75,26 @@ func handler(w http.ResponseWriter, r *http.Request) {
|
|||||||
// Validate user
|
// Validate user
|
||||||
valid = fw.ValidateEmail(email)
|
valid = fw.ValidateEmail(email)
|
||||||
if !valid {
|
if !valid {
|
||||||
log.Debugf("Invalid email: %s", email)
|
logger.WithFields(logrus.Fields{
|
||||||
|
"email": email,
|
||||||
|
}).Errorf("Invalid email")
|
||||||
http.Error(w, "Not authorized", 401)
|
http.Error(w, "Not authorized", 401)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Valid request
|
// Valid request
|
||||||
|
logger.Debugf("Allowing valid request ")
|
||||||
w.Header().Set("X-Forwarded-User", email)
|
w.Header().Set("X-Forwarded-User", email)
|
||||||
w.WriteHeader(200)
|
w.WriteHeader(200)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Authenticate user after they have come back from google
|
// Authenticate user after they have come back from google
|
||||||
func handleCallback(w http.ResponseWriter, r *http.Request, qs url.Values) {
|
func handleCallback(w http.ResponseWriter, r *http.Request, qs url.Values,
|
||||||
|
logger logrus.FieldLogger) {
|
||||||
// Check for CSRF cookie
|
// Check for CSRF cookie
|
||||||
csrfCookie, err := r.Cookie(fw.CSRFCookieName)
|
csrfCookie, err := r.Cookie(fw.CSRFCookieName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debug("Missing csrf cookie")
|
logger.Warn("Missing csrf cookie")
|
||||||
http.Error(w, "Not authorized", 401)
|
http.Error(w, "Not authorized", 401)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -87,7 +103,10 @@ func handleCallback(w http.ResponseWriter, r *http.Request, qs url.Values) {
|
|||||||
state := qs.Get("state")
|
state := qs.Get("state")
|
||||||
valid, redirect, err := fw.ValidateCSRFCookie(csrfCookie, state)
|
valid, redirect, err := fw.ValidateCSRFCookie(csrfCookie, state)
|
||||||
if !valid {
|
if !valid {
|
||||||
log.Debugf("Invalid oauth state, expected '%s', got '%s'\n", csrfCookie.Value, state)
|
logger.WithFields(logrus.Fields{
|
||||||
|
"csrf": csrfCookie.Value,
|
||||||
|
"state": state,
|
||||||
|
}).Warnf("CSRF cookie does not match state")
|
||||||
http.Error(w, "Not authorized", 401)
|
http.Error(w, "Not authorized", 401)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -98,7 +117,7 @@ func handleCallback(w http.ResponseWriter, r *http.Request, qs url.Values) {
|
|||||||
// Exchange code for token
|
// Exchange code for token
|
||||||
token, err := fw.ExchangeCode(r, qs.Get("code"))
|
token, err := fw.ExchangeCode(r, qs.Get("code"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("Code exchange failed with: %s\n", err)
|
logger.Errorf("Code exchange failed with: %v", err)
|
||||||
http.Error(w, "Service unavailable", 503)
|
http.Error(w, "Service unavailable", 503)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -106,13 +125,15 @@ func handleCallback(w http.ResponseWriter, r *http.Request, qs url.Values) {
|
|||||||
// Get user
|
// Get user
|
||||||
user, err := fw.GetUser(token)
|
user, err := fw.GetUser(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("Error getting user: %s\n", err)
|
logger.Errorf("Error getting user: %s", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate cookie
|
// Generate cookie
|
||||||
http.SetCookie(w, fw.MakeCookie(r, user.Email))
|
http.SetCookie(w, fw.MakeCookie(r, user.Email))
|
||||||
log.Debugf("Generated auth cookie for %s\n", user.Email)
|
logger.WithFields(logrus.Fields{
|
||||||
|
"user": user.Email,
|
||||||
|
}).Infof("Generated auth cookie")
|
||||||
|
|
||||||
// Redirect
|
// Redirect
|
||||||
http.Redirect(w, r, redirect, http.StatusTemporaryRedirect)
|
http.Redirect(w, r, redirect, http.StatusTemporaryRedirect)
|
||||||
@ -136,30 +157,22 @@ func main() {
|
|||||||
domainList := flag.String("domain", "", "Comma separated list of email domains to allow")
|
domainList := flag.String("domain", "", "Comma separated list of email domains to allow")
|
||||||
emailWhitelist := flag.String("whitelist", "", "Comma separated list of emails to allow")
|
emailWhitelist := flag.String("whitelist", "", "Comma separated list of emails to allow")
|
||||||
prompt := flag.String("prompt", "", "Space separated list of OpenID prompt options")
|
prompt := flag.String("prompt", "", "Space separated list of OpenID prompt options")
|
||||||
|
logLevel := flag.String("log-level", "warn", "Log level: trace, debug, info, warn, error, fatal, panic")
|
||||||
|
logFormat := flag.String("log-format", "text", "Log format: text, json, pretty")
|
||||||
|
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
|
// Setup logger
|
||||||
|
log = CreateLogger(*logLevel, *logFormat)
|
||||||
|
|
||||||
// Backwards compatability
|
// Backwards compatability
|
||||||
if *secret == "" && *cookieSecret != "" {
|
if *secret == "" && *cookieSecret != "" {
|
||||||
*secret = *cookieSecret
|
*secret = *cookieSecret
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for show stopper errors
|
// Check for show stopper errors
|
||||||
stop := false
|
if *clientId == "" || *clientSecret == "" || *secret == "" {
|
||||||
if *clientId == "" {
|
log.Fatal("client-id, client-secret and secret must all be set")
|
||||||
stop = true
|
|
||||||
log.Critical("client-id must be set")
|
|
||||||
}
|
|
||||||
if *clientSecret == "" {
|
|
||||||
stop = true
|
|
||||||
log.Critical("client-secret must be set")
|
|
||||||
}
|
|
||||||
if *secret == "" {
|
|
||||||
stop = true
|
|
||||||
log.Critical("secret must be set")
|
|
||||||
}
|
|
||||||
if stop {
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse lists
|
// Parse lists
|
||||||
@ -220,7 +233,9 @@ func main() {
|
|||||||
// Attach handler
|
// Attach handler
|
||||||
http.HandleFunc("/", handler)
|
http.HandleFunc("/", handler)
|
||||||
|
|
||||||
log.Debugf("Starting with options: %#v", fw)
|
// Start
|
||||||
log.Notice("Listening on :4181")
|
jsonConf, _ := json.Marshal(fw)
|
||||||
log.Notice(http.ListenAndServe(":4181", nil))
|
log.Debugf("Starting with options: %s", string(jsonConf))
|
||||||
|
log.Info("Listening on :4181")
|
||||||
|
log.Info(http.ListenAndServe(":4181", nil))
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user