Refactor progress

- move directory structure
- string based rule definition
- use traefik rule parsing
- drop toml config
- new flag library
- implement go dep
This commit is contained in:
Thom Seddon
2019-04-12 16:12:13 +01:00
parent d51b93d4b0
commit 9abe509f66
18 changed files with 547 additions and 464 deletions

329
internal/auth.go Normal file
View File

@ -0,0 +1,329 @@
package tfa
import (
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"net/http"
"strconv"
"strings"
"time"
"github.com/thomseddon/traefik-forward-auth/internal/provider"
)
// Request Validation
// Cookie = hash(secret, cookie domain, email, expires)|expires|email
func ValidateCookie(r *http.Request, c *http.Cookie) (bool, string, error) {
parts := strings.Split(c.Value, "|")
if len(parts) != 3 {
return false, "", errors.New("Invalid cookie format")
}
mac, err := base64.URLEncoding.DecodeString(parts[0])
if err != nil {
return false, "", errors.New("Unable to decode cookie mac")
}
expectedSignature := cookieSignature(r, parts[2], parts[1])
expected, err := base64.URLEncoding.DecodeString(expectedSignature)
if err != nil {
return false, "", errors.New("Unable to generate mac")
}
// Valid token?
if !hmac.Equal(mac, expected) {
return false, "", errors.New("Invalid cookie mac")
}
expires, err := strconv.ParseInt(parts[1], 10, 64)
if err != nil {
return false, "", errors.New("Unable to parse cookie expiry")
}
// Has it expired?
if time.Unix(expires, 0).Before(time.Now()) {
return false, "", errors.New("Cookie has expired")
}
// Looks valid
return true, parts[2], nil
}
// Validate email
func ValidateEmail(email string) bool {
found := false
if len(config.Whitelist) > 0 {
for _, whitelist := range config.Whitelist {
if email == whitelist {
found = true
}
}
} else if len(config.Domains) > 0 {
parts := strings.Split(email, "@")
if len(parts) < 2 {
return false
}
for _, domain := range config.Domains {
if domain == parts[1] {
found = true
}
}
} else {
return true
}
return found
}
// OAuth Methods
// Get login url
func GetLoginURL(r *http.Request, nonce string) string {
state := fmt.Sprintf("%s:%s", nonce, returnUrl(r))
// TODO: Support multiple providers
return config.Providers.Google.GetLoginURL(redirectUri(r), state)
}
// Exchange code for token
func ExchangeCode(r *http.Request) (string, error) {
code := r.URL.Query().Get("code")
// TODO: Support multiple providers
return config.Providers.Google.ExchangeCode(redirectUri(r), code)
}
// Get user with token
func GetUser(token string) (provider.User, error) {
// TODO: Support multiple providers
return config.Providers.Google.GetUser(token)
}
// Utility methods
// Get the redirect base
func redirectBase(r *http.Request) string {
proto := r.Header.Get("X-Forwarded-Proto")
host := r.Header.Get("X-Forwarded-Host")
return fmt.Sprintf("%s://%s", proto, host)
}
// // Return url
func returnUrl(r *http.Request) string {
path := r.Header.Get("X-Forwarded-Uri")
return fmt.Sprintf("%s%s", redirectBase(r), path)
}
// Get oauth redirect uri
func redirectUri(r *http.Request) string {
if use, _ := useAuthDomain(r); use {
proto := r.Header.Get("X-Forwarded-Proto")
return fmt.Sprintf("%s://%s%s", proto, config.AuthHost, config.Path)
}
return fmt.Sprintf("%s%s", redirectBase(r), config.Path)
}
// Should we use auth host + what it is
func useAuthDomain(r *http.Request) (bool, string) {
if config.AuthHost == "" {
return false, ""
}
// Does the request match a given cookie domain?
reqMatch, reqHost := matchCookieDomains(r.Header.Get("X-Forwarded-Host"))
// Do any of the auth hosts match a cookie domain?
authMatch, authHost := matchCookieDomains(config.AuthHost)
// We need both to match the same domain
return reqMatch && authMatch && reqHost == authHost, reqHost
}
// Cookie methods
// Create an auth cookie
func MakeCookie(r *http.Request, email string) *http.Cookie {
expires := cookieExpiry()
mac := cookieSignature(r, email, fmt.Sprintf("%d", expires.Unix()))
value := fmt.Sprintf("%s|%d|%s", mac, expires.Unix(), email)
return &http.Cookie{
Name: config.CookieName,
Value: value,
Path: "/",
Domain: cookieDomain(r),
HttpOnly: true,
Secure: !config.CookieInsecure,
Expires: expires,
}
}
// Make a CSRF cookie (used during login only)
func MakeCSRFCookie(r *http.Request, nonce string) *http.Cookie {
return &http.Cookie{
Name: config.CSRFCookieName,
Value: nonce,
Path: "/",
Domain: csrfCookieDomain(r),
HttpOnly: true,
Secure: !config.CookieInsecure,
Expires: cookieExpiry(),
}
}
// Create a cookie to clear csrf cookie
func ClearCSRFCookie(r *http.Request) *http.Cookie {
return &http.Cookie{
Name: config.CSRFCookieName,
Value: "",
Path: "/",
Domain: csrfCookieDomain(r),
HttpOnly: true,
Secure: !config.CookieInsecure,
Expires: time.Now().Local().Add(time.Hour * -1),
}
}
// Validate the csrf cookie against state
func ValidateCSRFCookie(r *http.Request, c *http.Cookie) (bool, string, error) {
state := r.URL.Query().Get("state")
if len(c.Value) != 32 {
return false, "", errors.New("Invalid CSRF cookie value")
}
if len(state) < 34 {
return false, "", errors.New("Invalid CSRF state value")
}
// Check nonce match
if c.Value != state[:32] {
return false, "", errors.New("CSRF cookie does not match state")
}
// Valid, return redirect
return true, state[33:], nil
}
func Nonce() (error, string) {
// Make nonce
nonce := make([]byte, 16)
_, err := rand.Read(nonce)
if err != nil {
return err, ""
}
return nil, fmt.Sprintf("%x", nonce)
}
// Cookie domain
func cookieDomain(r *http.Request) string {
host := r.Header.Get("X-Forwarded-Host")
// Check if any of the given cookie domains matches
_, domain := matchCookieDomains(host)
return domain
}
// Cookie domain
func csrfCookieDomain(r *http.Request) string {
var host string
if use, domain := useAuthDomain(r); use {
host = domain
} else {
host = r.Header.Get("X-Forwarded-Host")
}
// Remove port
p := strings.Split(host, ":")
return p[0]
}
// Return matching cookie domain if exists
func matchCookieDomains(domain string) (bool, string) {
// Remove port
p := strings.Split(domain, ":")
for _, d := range config.CookieDomains {
if d.Match(p[0]) {
return true, d.Domain
}
}
return false, p[0]
}
// Create cookie hmac
func cookieSignature(r *http.Request, email, expires string) string {
hash := hmac.New(sha256.New, config.Secret)
hash.Write([]byte(cookieDomain(r)))
hash.Write([]byte(email))
hash.Write([]byte(expires))
return base64.URLEncoding.EncodeToString(hash.Sum(nil))
}
// Get cookie expirary
func cookieExpiry() time.Time {
return time.Now().Local().Add(config.Lifetime)
}
// Cookie Domain
// Cookie Domain
type CookieDomain struct {
Domain string `description:"TEST1"`
DomainLen int `description:"TEST2"`
SubDomain string `description:"TEST3"`
SubDomainLen int `description:"TEST4"`
}
func NewCookieDomain(domain string) *CookieDomain {
return &CookieDomain{
Domain: domain,
DomainLen: len(domain),
SubDomain: fmt.Sprintf(".%s", domain),
SubDomainLen: len(domain) + 1,
}
}
func (c *CookieDomain) Match(host string) bool {
// Exact domain match?
if host == c.Domain {
return true
}
// Subdomain match?
if len(host) >= c.SubDomainLen && host[len(host)-c.SubDomainLen:] == c.SubDomain {
return true
}
return false
}
type CookieDomains []CookieDomain
func (c *CookieDomains) UnmarshalFlag(value string) error {
// TODO: test
if len(value) > 0 {
for _, d := range strings.Split(value, ",") {
cookieDomain := NewCookieDomain(d)
*c = append(*c, *cookieDomain)
}
}
return nil
}
func (c *CookieDomains) MarshalFlag() (string, error) {
return fmt.Sprintf("%+v", *c), nil
}

452
internal/auth_test.go Normal file
View File

@ -0,0 +1,452 @@
package tfa
import (
"fmt"
"net/http"
"net/url"
"reflect"
"testing"
"time"
"github.com/thomseddon/traefik-forward-auth/internal/provider"
)
/**
* Setup
*/
func init() {
// fw = &ForwardAuth{}
}
/**
* Tests
*/
func TestValidateCookie(t *testing.T) {
config = Config{}
r, _ := http.NewRequest("GET", "http://example.com", nil)
c := &http.Cookie{}
// Should require 3 parts
c.Value = ""
valid, _, err := ValidateCookie(r, c)
if valid || err.Error() != "Invalid cookie format" {
t.Error("Should get \"Invalid cookie format\", got:", err)
}
c.Value = "1|2"
valid, _, err = ValidateCookie(r, c)
if valid || err.Error() != "Invalid cookie format" {
t.Error("Should get \"Invalid cookie format\", got:", err)
}
c.Value = "1|2|3|4"
valid, _, err = ValidateCookie(r, c)
if valid || err.Error() != "Invalid cookie format" {
t.Error("Should get \"Invalid cookie format\", got:", err)
}
// Should catch invalid mac
c.Value = "MQ==|2|3"
valid, _, err = ValidateCookie(r, c)
if valid || err.Error() != "Invalid cookie mac" {
t.Error("Should get \"Invalid cookie mac\", got:", err)
}
// Should catch expired
config.Lifetime = time.Second * time.Duration(-1)
c = MakeCookie(r, "test@test.com")
valid, _, err = ValidateCookie(r, c)
if valid || err.Error() != "Cookie has expired" {
t.Error("Should get \"Cookie has expired\", got:", err)
}
// Should accept valid cookie
config.Lifetime = time.Second * time.Duration(10)
c = MakeCookie(r, "test@test.com")
valid, email, err := ValidateCookie(r, c)
if !valid {
t.Error("Valid request should return as valid")
}
if err != nil {
t.Error("Valid request should not return error, got:", err)
}
if email != "test@test.com" {
t.Error("Valid request should return user email")
}
}
func TestValidateEmail(t *testing.T) {
config = Config{}
// Should allow any
if !ValidateEmail("test@test.com") || !ValidateEmail("one@two.com") {
t.Error("Should allow any domain if email domain is not defined")
}
// Should block non matching domain
config.Domains = []string{"test.com"}
if ValidateEmail("one@two.com") {
t.Error("Should not allow user from another domain")
}
// Should allow matching domain
config.Domains = []string{"test.com"}
if !ValidateEmail("test@test.com") {
t.Error("Should allow user from allowed domain")
}
// Should block non whitelisted email address
config.Domains = []string{}
config.Whitelist = []string{"test@test.com"}
if ValidateEmail("one@two.com") {
t.Error("Should not allow user not in whitelist.")
}
// Should allow matching whitelisted email address
config.Domains = []string{}
config.Whitelist = []string{"test@test.com"}
if !ValidateEmail("test@test.com") {
t.Error("Should allow user in whitelist.")
}
}
func TestGetLoginURL(t *testing.T) {
r, _ := http.NewRequest("GET", "http://example.com", nil)
r.Header.Add("X-Forwarded-Proto", "http")
r.Header.Add("X-Forwarded-Host", "example.com")
r.Header.Add("X-Forwarded-Uri", "/hello")
config = Config{
Path: "/_oauth",
Providers: provider.Providers{
Google: provider.Google{
ClientId: "idtest",
ClientSecret: "sectest",
Scope: "scopetest",
LoginURL: &url.URL{
Scheme: "https",
Host: "test.com",
Path: "/auth",
},
},
},
}
// Check url
uri, err := url.Parse(GetLoginURL(r, "nonce"))
if err != nil {
t.Error("Error parsing login url:", err)
}
if uri.Scheme != "https" {
t.Error("Expected login Scheme to be \"https\", got:", uri.Scheme)
}
if uri.Host != "test.com" {
t.Error("Expected login Host to be \"test.com\", got:", uri.Host)
}
if uri.Path != "/auth" {
t.Error("Expected login Path to be \"/auth\", got:", uri.Path)
}
// Check query string
qs := uri.Query()
expectedQs := url.Values{
"client_id": []string{"idtest"},
"redirect_uri": []string{"http://example.com/_oauth"},
"response_type": []string{"code"},
"scope": []string{"scopetest"},
"state": []string{"nonce:http://example.com/hello"},
}
if !reflect.DeepEqual(qs, expectedQs) {
t.Error("Incorrect login query string:")
qsDiff(expectedQs, qs)
}
//
// With Auth URL but no matching cookie domain
// - will not use auth host
//
config = Config{
Path: "/_oauth",
AuthHost: "auth.example.com",
Providers: provider.Providers{
Google: provider.Google{
ClientId: "idtest",
ClientSecret: "sectest",
Scope: "scopetest",
Prompt: "consent select_account",
LoginURL: &url.URL{
Scheme: "https",
Host: "test.com",
Path: "/auth",
},
},
},
}
// Check url
uri, err = url.Parse(GetLoginURL(r, "nonce"))
if err != nil {
t.Error("Error parsing login url:", err)
}
if uri.Scheme != "https" {
t.Error("Expected login Scheme to be \"https\", got:", uri.Scheme)
}
if uri.Host != "test.com" {
t.Error("Expected login Host to be \"test.com\", got:", uri.Host)
}
if uri.Path != "/auth" {
t.Error("Expected login Path to be \"/auth\", got:", uri.Path)
}
// Check query string
qs = uri.Query()
expectedQs = url.Values{
"client_id": []string{"idtest"},
"redirect_uri": []string{"http://example.com/_oauth"},
"response_type": []string{"code"},
"scope": []string{"scopetest"},
"prompt": []string{"consent select_account"},
"state": []string{"nonce:http://example.com/hello"},
}
if !reflect.DeepEqual(qs, expectedQs) {
t.Error("Incorrect login query string:")
qsDiff(expectedQs, qs)
}
//
// With correct Auth URL + cookie domain
//
cookieDomain := NewCookieDomain("example.com")
config = Config{
Path: "/_oauth",
AuthHost: "auth.example.com",
CookieDomains: []CookieDomain{*cookieDomain},
Providers: provider.Providers{
Google: provider.Google{
ClientId: "idtest",
ClientSecret: "sectest",
Scope: "scopetest",
Prompt: "consent select_account",
LoginURL: &url.URL{
Scheme: "https",
Host: "test.com",
Path: "/auth",
},
},
},
}
// Check url
uri, err = url.Parse(GetLoginURL(r, "nonce"))
if err != nil {
t.Error("Error parsing login url:", err)
}
if uri.Scheme != "https" {
t.Error("Expected login Scheme to be \"https\", got:", uri.Scheme)
}
if uri.Host != "test.com" {
t.Error("Expected login Host to be \"test.com\", got:", uri.Host)
}
if uri.Path != "/auth" {
t.Error("Expected login Path to be \"/auth\", got:", uri.Path)
}
// Check query string
qs = uri.Query()
expectedQs = url.Values{
"client_id": []string{"idtest"},
"redirect_uri": []string{"http://auth.example.com/_oauth"},
"response_type": []string{"code"},
"scope": []string{"scopetest"},
"state": []string{"nonce:http://example.com/hello"},
"prompt": []string{"consent select_account"},
}
qsDiff(expectedQs, qs)
if !reflect.DeepEqual(qs, expectedQs) {
t.Error("Incorrect login query string:")
qsDiff(expectedQs, qs)
}
//
// With Auth URL + cookie domain, but from different domain
// - will not use auth host
//
r, _ = http.NewRequest("GET", "http://another.com", nil)
r.Header.Add("X-Forwarded-Proto", "http")
r.Header.Add("X-Forwarded-Host", "another.com")
r.Header.Add("X-Forwarded-Uri", "/hello")
// Check url
uri, err = url.Parse(GetLoginURL(r, "nonce"))
if err != nil {
t.Error("Error parsing login url:", err)
}
if uri.Scheme != "https" {
t.Error("Expected login Scheme to be \"https\", got:", uri.Scheme)
}
if uri.Host != "test.com" {
t.Error("Expected login Host to be \"test.com\", got:", uri.Host)
}
if uri.Path != "/auth" {
t.Error("Expected login Path to be \"/auth\", got:", uri.Path)
}
// Check query string
qs = uri.Query()
expectedQs = url.Values{
"client_id": []string{"idtest"},
"redirect_uri": []string{"http://another.com/_oauth"},
"response_type": []string{"code"},
"scope": []string{"scopetest"},
"state": []string{"nonce:http://another.com/hello"},
"prompt": []string{"consent select_account"},
}
qsDiff(expectedQs, qs)
if !reflect.DeepEqual(qs, expectedQs) {
t.Error("Incorrect login query string:")
qsDiff(expectedQs, qs)
}
}
// TODO
// func TestExchangeCode(t *testing.T) {
// }
// TODO
// func TestGetUser(t *testing.T) {
// }
// TODO? Tested in TestValidateCookie
// func TestMakeCookie(t *testing.T) {
// }
func TestMakeCSRFCookie(t *testing.T) {
config = Config{}
r, _ := http.NewRequest("GET", "http://app.example.com", nil)
r.Header.Add("X-Forwarded-Host", "app.example.com")
// No cookie domain or auth url
c := MakeCSRFCookie(r, "12345678901234567890123456789012")
if c.Domain != "app.example.com" {
t.Error("Cookie Domain should match request domain, got:", c.Domain)
}
// With cookie domain but no auth url
cookieDomain := NewCookieDomain("example.com")
config = Config{
CookieDomains: []CookieDomain{*cookieDomain},
}
c = MakeCSRFCookie(r, "12345678901234567890123456789012")
if c.Domain != "app.example.com" {
t.Error("Cookie Domain should match request domain, got:", c.Domain)
}
// With cookie domain and auth url
config = Config{
AuthHost: "auth.example.com",
CookieDomains: []CookieDomain{*cookieDomain},
}
c = MakeCSRFCookie(r, "12345678901234567890123456789012")
if c.Domain != "example.com" {
t.Error("Cookie Domain should match request domain, got:", c.Domain)
}
}
func TestClearCSRFCookie(t *testing.T) {
config = Config{}
r, _ := http.NewRequest("GET", "http://example.com", nil)
c := ClearCSRFCookie(r)
if c.Value != "" {
t.Error("ClearCSRFCookie should create cookie with empty value")
}
}
func TestValidateCSRFCookie(t *testing.T) {
config = Config{}
c := &http.Cookie{}
newCsrfRequest := func(state string) *http.Request {
u := fmt.Sprintf("http://example.com?state=%s", state)
r, _ := http.NewRequest("GET", u, nil)
return r
}
// Should require 32 char string
r := newCsrfRequest("")
c.Value = ""
valid, _, err := ValidateCSRFCookie(r, c)
if valid || err.Error() != "Invalid CSRF cookie value" {
t.Error("Should get \"Invalid CSRF cookie value\", got:", err)
}
c.Value = "123456789012345678901234567890123"
valid, _, err = ValidateCSRFCookie(r, c)
if valid || err.Error() != "Invalid CSRF cookie value" {
t.Error("Should get \"Invalid CSRF cookie value\", got:", err)
}
// Should require valid state
r = newCsrfRequest("12345678901234567890123456789012:")
c.Value = "12345678901234567890123456789012"
valid, _, err = ValidateCSRFCookie(r, c)
if valid || err.Error() != "Invalid CSRF state value" {
t.Error("Should get \"Invalid CSRF state value\", got:", err)
}
// Should allow valid state
r = newCsrfRequest("12345678901234567890123456789012:99")
c.Value = "12345678901234567890123456789012"
valid, state, err := ValidateCSRFCookie(r, c)
if !valid {
t.Error("Valid request should return as valid")
}
if err != nil {
t.Error("Valid request should not return error, got:", err)
}
if state != "99" {
t.Error("Valid request should return correct state, got:", state)
}
}
func TestNonce(t *testing.T) {
err, nonce1 := Nonce()
if err != nil {
t.Error("Error generation nonce:", err)
}
err, nonce2 := Nonce()
if err != nil {
t.Error("Error generation nonce:", err)
}
if len(nonce1) != 32 || len(nonce2) != 32 {
t.Error("Nonce should be 32 chars")
}
if nonce1 == nonce2 {
t.Error("Nonce should not be equal")
}
}
func TestCookieDomainMatch(t *testing.T) {
cd := NewCookieDomain("example.com")
// Exact should match
if !cd.Match("example.com") {
t.Error("Exact domain should match")
}
// Subdomain should match
if !cd.Match("test.example.com") {
t.Error("Subdomain should match")
}
// Derived domain should not match
if cd.Match("testexample.com") {
t.Error("Derived domain should not match")
}
// Other domain should not match
if cd.Match("test.com") {
t.Error("Other domain should not match")
}
}

146
internal/config.go Normal file
View File

@ -0,0 +1,146 @@
package tfa
import (
"encoding/json"
"errors"
"fmt"
"os"
"strings"
"time"
"github.com/jessevdk/go-flags"
"github.com/thomseddon/traefik-forward-auth/internal/provider"
)
type Config struct {
LogLevel string `long:"log-level" default:"warn" description:"Log level: trace, debug, info, warn, error, fatal, panic"`
LogFormat string `long:"log-format" default:"text" description:"Log format: text, json, pretty"`
AuthHost string `long:"auth-host" description:"Host for central auth login"`
ConfigFile string `long:"config-file" description:"Config File"`
CookieDomains CookieDomains `long:"cookie-domains" description:"Comma separated list of cookie domains"`
CookieInsecure bool `long:"cookie-insecure" description:"Use secure cookies"`
CookieName string `long:"cookie-name" default:"_forward_auth" description:"Cookie Name"`
CSRFCookieName string `long:"csrf-cookie-name" default:"_forward_auth_csrf" description:"CSRF Cookie Name"`
DefaultAction string `long:"default-action" default:"allow" description:"Default Action"`
Domains CommaSeparatedList `long:"domains" description:"Comma separated list of email domains to allow"`
LifetimeString int `long:"lifetime" default:"43200" description:"Lifetime in seconds"`
Path string `long:"path" default:"_oauth" description:"Callback URL Path"`
SecretString string `long:"secret" description:"*Secret used for signing (required)"`
Whitelist CommaSeparatedList `long:"whitelist" description:"Comma separated list of email addresses to allow"`
Providers provider.Providers
Rules []Rule `long:"rule"`
Secret []byte
Lifetime time.Duration
Prompt string `long:"prompt" description:"DEPRECATED - Use providers.google.prompt"`
// TODO: Need to mimick the default behaviour of bool flags
CookieSecure string `long:"cookie-secure" default:"true" description:"DEPRECATED - Use \"cookie-insecure\""`
flags []string
usingToml bool
}
type CommaSeparatedList []string
func (c *CommaSeparatedList) UnmarshalFlag(value string) error {
*c = strings.Split(value, ",")
return nil
}
func (c *CommaSeparatedList) MarshalFlag() (string, error) {
return strings.Join(*c, ","), nil
}
type Rule struct {
Action string
Rule string
}
func (r *Rule) UnmarshalFlag(value string) error {
// Format is "action:rule"
parts := strings.SplitN(value, ":", 2)
if len(parts) != 2 {
return errors.New("Invalid rule format, should be \"action:rule\"")
}
if parts[0] != "auth" && parts[0] != "allow" {
return errors.New("Invalid rule action, must be \"auth\" or \"allow\"")
}
// Parse rule
*r = Rule{
Action: parts[0],
Rule: parts[1],
}
return nil
}
func (r *Rule) MarshalFlag() (string, error) {
// TODO: format correctly
return fmt.Sprintf("%+v", *r), nil
}
var config Config
// TODO:
// - parse ini
// - parse env vars
// - parse env var file
// - support multiple config files
// - maintain backwards compat
func NewGlobalConfig() Config {
return NewGlobalConfigWithArgs(os.Args[1:])
}
func NewGlobalConfigWithArgs(args []string) Config {
config = Config{}
config.parseFlags(args)
// Struct defaults
config.Providers.Google.Build()
// Transformations
config.Path = fmt.Sprintf("/%s", config.Path)
config.Secret = []byte(config.SecretString)
config.Lifetime = time.Second * time.Duration(config.LifetimeString)
// TODO: Backwards compatability
// "secret" used to be "cookie-secret"
return config
}
func (c *Config) parseFlags(args []string) {
if _, err := flags.ParseArgs(c, args); err != nil {
flagsErr, ok := err.(*flags.Error)
if ok && flagsErr.Type == flags.ErrHelp {
os.Exit(0)
} else {
fmt.Printf("%+v", err)
os.Exit(1)
}
}
}
func (c *Config) Checks() {
// Check for show stopper errors
if len(c.Secret) == 0 {
log.Fatal("\"secret\" option must be set.")
}
if c.Providers.Google.ClientId == "" || c.Providers.Google.ClientSecret == "" {
log.Fatal("google.providers.client-id, google.providers.client-secret must be set")
}
}
func (c Config) Serialise() string {
jsonConf, _ := json.Marshal(c)
return string(jsonConf)
}

81
internal/config_test.go Normal file
View File

@ -0,0 +1,81 @@
package tfa
import (
"testing"
"time"
// "github.com/jessevdk/go-flags"
// "github.com/sirupsen/logrus"
)
/**
* Tests
*/
func TestConfigDefaults(t *testing.T) {
// Check defaults
c := NewGlobalConfigWithArgs([]string{})
if c.LogLevel != "warn" {
t.Error("LogLevel default should be warn, got", c.LogLevel)
}
if c.LogFormat != "text" {
t.Error("LogFormat default should be text, got", c.LogFormat)
}
if c.AuthHost != "" {
t.Error("AuthHost default should be empty, got", c.AuthHost)
}
if c.ConfigFile != "" {
t.Error("ConfigFile default should be empty, got", c.ConfigFile)
}
if len(c.CookieDomains) != 0 {
t.Error("CookieDomains default should be empty, got", c.CookieDomains)
}
if c.CookieInsecure != false {
t.Error("CookieInsecure default should be false, got", c.CookieInsecure)
}
if c.CookieName != "_forward_auth" {
t.Error("CookieName default should be _forward_auth, got", c.CookieName)
}
if c.CSRFCookieName != "_forward_auth_csrf" {
t.Error("CSRFCookieName default should be _forward_auth_csrf, got", c.CSRFCookieName)
}
if c.DefaultAction != "allow" {
t.Error("DefaultAction default should be allow, got", c.DefaultAction)
}
if len(c.Domains) != 0 {
t.Error("Domain default should be empty, got", c.Domains)
}
if c.Lifetime != time.Second*time.Duration(43200) {
t.Error("Lifetime default should be 43200, got", c.Lifetime)
}
if c.Path != "/_oauth" {
t.Error("Path default should be /_oauth, got", c.Path)
}
if len(c.Whitelist) != 0 {
t.Error("Whitelist default should be empty, got", c.Whitelist)
}
if c.Providers.Google.Prompt != "" {
t.Error("Providers.Google.Prompt default should be empty, got", c.Providers.Google.Prompt)
}
// Deprecated options
if c.CookieSecure != "true" {
t.Error("CookieSecure default should be true, got", c.CookieSecure)
}
}
// func TestConfigToml(t *testing.T) {
// logrus.SetLevel(logrus.DebugLevel)
// flag.CommandLine = flag.NewFlagSet("tfa-test", flag.ContinueOnError)
// flags := []string{
// "-config=../test/config.toml",
// }
// c := NewDefaultConfigWithFlags(flags)
// if c == nil {
// t.Error(c)
// }
// }

50
internal/log.go Normal file
View File

@ -0,0 +1,50 @@
package tfa
import (
"os"
"github.com/sirupsen/logrus"
)
var log logrus.FieldLogger
func NewDefaultLogger() logrus.FieldLogger {
// Setup logger
log = logrus.StandardLogger()
logrus.SetOutput(os.Stdout)
// Set logger format
switch config.LogFormat {
case "pretty":
break
case "json":
logrus.SetFormatter(&logrus.JSONFormatter{})
// "text" is the default
default:
logrus.SetFormatter(&logrus.TextFormatter{
DisableColors: true,
FullTimestamp: true,
})
}
// Set logger level
switch config.LogLevel {
case "trace":
logrus.SetLevel(logrus.TraceLevel)
case "debug":
logrus.SetLevel(logrus.DebugLevel)
case "info":
logrus.SetLevel(logrus.InfoLevel)
case "error":
logrus.SetLevel(logrus.ErrorLevel)
case "fatal":
logrus.SetLevel(logrus.FatalLevel)
case "panic":
logrus.SetLevel(logrus.PanicLevel)
// warn is the default
default:
logrus.SetLevel(logrus.WarnLevel)
}
return log
}

View File

@ -0,0 +1,96 @@
package provider
import (
"encoding/json"
"fmt"
"net/http"
"net/url"
)
type Google struct {
ClientId string `long:"providers.google.client-id" description:"Client ID"`
ClientSecret string `long:"providers.google.client-secret" description:"Client Secret" json:"-"`
Scope string
Prompt string `long:"providers.google.prompt" description:"Space separated list of OpenID prompt options"`
LoginURL *url.URL
TokenURL *url.URL
UserURL *url.URL
}
func (g *Google) Build() {
g.LoginURL = &url.URL{
Scheme: "https",
Host: "accounts.google.com",
Path: "/o/oauth2/auth",
}
g.TokenURL = &url.URL{
Scheme: "https",
Host: "www.googleapis.com",
Path: "/oauth2/v3/token",
}
g.UserURL = &url.URL{
Scheme: "https",
Host: "www.googleapis.com",
Path: "/oauth2/v2/userinfo",
}
}
func (g *Google) GetLoginURL(redirectUri, state string) string {
q := url.Values{}
q.Set("client_id", g.ClientId)
q.Set("response_type", "code")
q.Set("scope", g.Scope)
if g.Prompt != "" {
q.Set("prompt", g.Prompt)
}
q.Set("redirect_uri", redirectUri)
q.Set("state", state)
var u url.URL
u = *g.LoginURL
u.RawQuery = q.Encode()
return u.String()
}
func (g *Google) ExchangeCode(redirectUri, code string) (string, error) {
form := url.Values{}
form.Set("client_id", g.ClientId)
form.Set("client_secret", g.ClientSecret)
form.Set("grant_type", "authorization_code")
form.Set("redirect_uri", redirectUri)
form.Set("code", code)
res, err := http.PostForm(g.TokenURL.String(), form)
if err != nil {
return "", err
}
var token Token
defer res.Body.Close()
err = json.NewDecoder(res.Body).Decode(&token)
return token.Token, err
}
func (g *Google) GetUser(token string) (User, error) {
var user User
client := &http.Client{}
req, err := http.NewRequest("GET", g.UserURL.String(), nil)
if err != nil {
return user, err
}
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", token))
res, err := client.Do(req)
if err != nil {
return user, err
}
defer res.Body.Close()
err = json.NewDecoder(res.Body).Decode(&user)
return user, err
}

View File

@ -0,0 +1,16 @@
package provider
type Providers struct {
Google Google `group:"Google Provider"`
}
type Token struct {
Token string `json:"access_token"`
}
type User struct {
Id string `json:"id"`
Email string `json:"email"`
Verified bool `json:"verified_email"`
Hd string `json:"hd"`
}

175
internal/server.go Normal file
View File

@ -0,0 +1,175 @@
package tfa
import (
"net/http"
"net/url"
"github.com/containous/traefik/pkg/rules"
"github.com/sirupsen/logrus"
)
type Server struct {
router *rules.Router
}
func NewServer() *Server {
s := &Server{}
s.buildRoutes()
return s
}
func (s *Server) buildRoutes() {
var err error
s.router, err = rules.NewRouter()
if err != nil {
log.Fatal(err)
}
// Let's build a router
for _, rule := range config.Rules {
if rule.Action == "allow" {
s.router.AddRoute(rule.Rule, 1, s.AllowHandler())
} else {
s.router.AddRoute(rule.Rule, 1, s.AuthHandler())
}
}
// Add callback handler
s.router.Handle(config.Path, s.AuthCallbackHandler())
// Add a default handler
s.router.NewRoute().Handler(s.AuthHandler())
}
func (s *Server) RootHandler(w http.ResponseWriter, r *http.Request) {
// Modify request
r.URL, _ = url.Parse(r.Header.Get("X-Forwarded-Uri"))
// Pass to mux
s.router.ServeHTTP(w, r)
}
// Handler that allows requests
func (s *Server) AllowHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
s.logger(r, "Allowing request")
w.WriteHeader(200)
}
}
// Authenticate requests
func (s *Server) AuthHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// Logging setup
logger := s.logger(r, "Authenticating request")
// Get auth cookie
c, err := r.Cookie(config.CookieName)
if err != nil {
// Error indicates no cookie, generate nonce
err, nonce := Nonce()
if err != nil {
logger.Errorf("Error generating nonce, %v", err)
http.Error(w, "Service unavailable", 503)
return
}
// Set the CSRF cookie
http.SetCookie(w, MakeCSRFCookie(r, nonce))
logger.Debug("Set CSRF cookie and redirecting to google login")
// Forward them on
http.Redirect(w, r, GetLoginURL(r, nonce), http.StatusTemporaryRedirect)
return
}
// Validate cookie
valid, email, err := ValidateCookie(r, c)
if !valid {
logger.Errorf("Invalid cookie: %v", err)
http.Error(w, "Not authorized", 401)
return
}
// Validate user
valid = ValidateEmail(email)
if !valid {
logger.WithFields(logrus.Fields{
"email": email,
}).Errorf("Invalid email")
http.Error(w, "Not authorized", 401)
return
}
// Valid request
logger.Debugf("Allowing valid request ")
w.Header().Set("X-Forwarded-User", email)
w.WriteHeader(200)
}
}
// Handle auth callback
func (s *Server) AuthCallbackHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// Logging setup
logger := s.logger(r, "Handling callback")
// Check for CSRF cookie
c, err := r.Cookie(config.CSRFCookieName)
if err != nil {
logger.Warn("Missing csrf cookie")
http.Error(w, "Not authorized", 401)
return
}
// Validate state
valid, redirect, err := ValidateCSRFCookie(r, c)
if !valid {
logger.Warnf("Error validating csrf cookie: %v", err)
http.Error(w, "Not authorized", 401)
return
}
// Clear CSRF cookie
http.SetCookie(w, ClearCSRFCookie(r))
// Exchange code for token
token, err := ExchangeCode(r)
if err != nil {
logger.Errorf("Code exchange failed with: %v", err)
http.Error(w, "Service unavailable", 503)
return
}
// Get user
user, err := GetUser(token)
if err != nil {
logger.Errorf("Error getting user: %s", err)
return
}
// Generate cookie
http.SetCookie(w, MakeCookie(r, user.Email))
logger.WithFields(logrus.Fields{
"user": user.Email,
}).Infof("Generated auth cookie")
// Redirect
http.Redirect(w, r, redirect, http.StatusTemporaryRedirect)
}
}
func (s *Server) logger(r *http.Request, msg string) *logrus.Entry {
// Create logger
logger := log.WithFields(logrus.Fields{
"RemoteAddr": r.RemoteAddr,
})
// Log request
logger.WithFields(logrus.Fields{
"Headers": r.Header,
}).Debugf(msg)
return logger
}

269
internal/server_test.go Normal file
View File

@ -0,0 +1,269 @@
package tfa
import (
"fmt"
"time"
// "reflect"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/thomseddon/traefik-forward-auth/internal/provider"
)
/**
* Setup
*/
func init() {
config.LogLevel = "panic"
log = NewDefaultLogger()
}
/**
* Utilities
*/
type TokenServerHandler struct{}
func (t *TokenServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, `{"access_token":"123456789"}`)
}
type UserServerHandler struct{}
func (t *UserServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, `{
"id":"1",
"email":"example@example.com",
"verified_email":true,
"hd":"example.com"
}`)
}
func httpRequest(s *Server, r *http.Request, c *http.Cookie) (*http.Response, string) {
w := httptest.NewRecorder()
// Set cookies on recorder
if c != nil {
http.SetCookie(w, c)
}
// Copy into request
for _, c := range w.HeaderMap["Set-Cookie"] {
r.Header.Add("Cookie", c)
}
s.RootHandler(w, r)
res := w.Result()
body, _ := ioutil.ReadAll(res.Body)
// if res.StatusCode > 300 && res.StatusCode < 400 {
// fmt.Printf("%#v", res.Header)
// }
return res, string(body)
}
func newHttpRequest(uri string) *http.Request {
r := httptest.NewRequest("", "http://example.com/", nil)
r.Header.Add("X-Forwarded-Uri", uri)
return r
}
func qsDiff(one, two url.Values) {
for k := range one {
if two.Get(k) == "" {
fmt.Printf("Key missing: %s\n", k)
}
if one.Get(k) != two.Get(k) {
fmt.Printf("Value different for %s: expected: '%s' got: '%s'\n", k, one.Get(k), two.Get(k))
}
}
for k := range two {
if one.Get(k) == "" {
fmt.Printf("Extra key: %s\n", k)
}
}
}
/**
* Tests
*/
func TestServerHandler(t *testing.T) {
server := NewServer()
config = Config{
Path: "/_oauth",
CookieName: "cookie_test",
Lifetime: time.Second * time.Duration(10),
Providers: provider.Providers{
Google: provider.Google{
ClientId: "idtest",
ClientSecret: "sectest",
Scope: "scopetest",
LoginURL: &url.URL{
Scheme: "http",
Host: "test.com",
Path: "/auth",
},
},
},
}
// Should redirect vanilla request to login url
req := newHttpRequest("/foo")
res, _ := httpRequest(server, req, nil)
if res.StatusCode != 307 {
t.Error("Vanilla request should be redirected with 307, got:", res.StatusCode)
}
fwd, _ := res.Location()
if fwd.Scheme != "http" || fwd.Host != "test.com" || fwd.Path != "/auth" {
t.Error("Vanilla request should be redirected to login url, got:", fwd)
}
// Should catch invalid cookie
req = newHttpRequest("/foo")
c := MakeCookie(req, "test@example.com")
parts := strings.Split(c.Value, "|")
c.Value = fmt.Sprintf("bad|%s|%s", parts[1], parts[2])
res, _ = httpRequest(server, req, c)
if res.StatusCode != 401 {
t.Error("Request with invalid cookie shound't be authorised", res.StatusCode)
}
// Should validate email
req = newHttpRequest("/foo")
c = MakeCookie(req, "test@example.com")
config.Domains = []string{"test.com"}
res, _ = httpRequest(server, req, c)
if res.StatusCode != 401 {
t.Error("Request with invalid cookie shound't be authorised", res.StatusCode)
}
// Should allow valid request email
req = newHttpRequest("/foo")
c = MakeCookie(req, "test@example.com")
config.Domains = []string{}
res, _ = httpRequest(server, req, c)
if res.StatusCode != 200 {
t.Error("Valid request should be allowed, got:", res.StatusCode)
}
// Should pass through user
users := res.Header["X-Forwarded-User"]
if len(users) != 1 {
t.Error("Valid request missing X-Forwarded-User header")
} else if users[0] != "test@example.com" {
t.Error("X-Forwarded-User should match user, got: ", users)
}
}
func TestServerAuthCallback(t *testing.T) {
server := NewServer()
config = Config{
Path: "/_oauth",
CookieName: "cookie_test",
Lifetime: time.Second * time.Duration(10),
CSRFCookieName: "csrf_test",
Providers: provider.Providers{
Google: provider.Google{
ClientId: "idtest",
ClientSecret: "sectest",
Scope: "scopetest",
LoginURL: &url.URL{
Scheme: "http",
Host: "test.com",
Path: "/auth",
},
},
},
}
// Setup token server
tokenServerHandler := &TokenServerHandler{}
tokenServer := httptest.NewServer(tokenServerHandler)
defer tokenServer.Close()
tokenUrl, _ := url.Parse(tokenServer.URL)
config.Providers.Google.TokenURL = tokenUrl
// Setup user server
userServerHandler := &UserServerHandler{}
userServer := httptest.NewServer(userServerHandler)
defer userServer.Close()
userUrl, _ := url.Parse(userServer.URL)
config.Providers.Google.UserURL = userUrl
// Should pass auth response request to callback
req := newHttpRequest("/_oauth")
res, _ := httpRequest(server, req, nil)
if res.StatusCode != 401 {
t.Error("Auth callback without cookie shound't be authorised, got:", res.StatusCode)
}
// Should catch invalid csrf cookie
req = newHttpRequest("/_oauth?state=12345678901234567890123456789012:http://redirect")
c := MakeCSRFCookie(req, "nononononononononononononononono")
res, _ = httpRequest(server, req, c)
if res.StatusCode != 401 {
t.Error("Auth callback with invalid cookie shound't be authorised, got:", res.StatusCode)
}
// Should redirect valid request
req = newHttpRequest("/_oauth?state=12345678901234567890123456789012:http://redirect")
c = MakeCSRFCookie(req, "12345678901234567890123456789012")
res, _ = httpRequest(server, req, c)
if res.StatusCode != 307 {
t.Error("Valid callback should be allowed, got:", res.StatusCode)
}
fwd, _ := res.Location()
if fwd.Scheme != "http" || fwd.Host != "redirect" || fwd.Path != "" {
t.Error("Valid request should be redirected to return url, got:", fwd)
}
}
func TestServerRoutePathPrefix(t *testing.T) {
server := NewServer()
config = Config{
Path: "/_oauth",
CookieName: "cookie_test",
Lifetime: time.Second * time.Duration(10),
Providers: provider.Providers{
Google: provider.Google{
ClientId: "idtest",
ClientSecret: "sectest",
Scope: "scopetest",
LoginURL: &url.URL{
Scheme: "http",
Host: "test.com",
Path: "/auth",
},
},
},
Rules: []Rule{
{
Action: "allow",
Rule: "PathPrefix(`/api`)",
},
},
}
// Should allow /api request
req := newHttpRequest("/api")
c := MakeCookie(req, "test@example.com")
res, _ := httpRequest(server, req, c)
if res.StatusCode != 200 {
t.Error("Request matching allowed rule should be allowed, got:", res.StatusCode)
}
}