9 Commits

Author SHA1 Message Date
d890a4aad6 Add more v2 tests + fixes + improve legacy config parsing 2019-04-18 15:27:41 +01:00
43775591fa Use new rule config + tidy ups 2019-04-17 11:29:35 +01:00
daec9f591a Improve qsdiff error reporting 2019-04-15 12:03:31 +01:00
091590d391 another attempt at dev travis 2019-04-12 16:47:36 +01:00
8ca16a88d2 fix dev travis 2019-04-12 16:34:23 +01:00
814892a88b Refactor progress
- move directory structure
- string based rule definition
- use traefik rule parsing
- drop toml config
- new flag library
- implement go dep
2019-04-12 16:12:13 +01:00
19c249a6d1 fix dockerfile 2019-01-30 18:06:46 +00:00
0f278d516b fix tests 2019-01-30 17:11:41 +00:00
ae95e8b2e5 Begin refactor + selective auth 2019-01-30 16:52:47 +00:00
20 changed files with 1657 additions and 896 deletions

View File

@ -1,8 +1,5 @@
language: go
sudo: false
go:
- "1.10"
install:
- go get github.com/namsral/flag
- go get github.com/sirupsen/logrus
script: go test -v ./...
- "1.12"
script: env GO111MODULE=on go test -v ./...

View File

@ -1,18 +1,15 @@
FROM golang:1.10-alpine as builder
FROM golang:1.12-alpine as builder
# Setup
RUN mkdir /app
WORKDIR /app
RUN mkdir -p /go/src/github.com/thomseddon/traefik-forward-auth
WORKDIR /go/src/github.com/thomseddon/traefik-forward-auth
# Add libraries
RUN apk add --no-cache git && \
go get "github.com/namsral/flag" && \
go get "github.com/sirupsen/logrus" && \
apk del git
RUN apk add --no-cache git
# Copy & build
ADD . /app/
RUN CGO_ENABLED=0 GOOS=linux go build -a -installsuffix nocgo -o /traefik-forward-auth .
ADD . /go/src/github.com/thomseddon/traefik-forward-auth/
RUN CGO_ENABLED=0 GOOS=linux GO111MODULE=on go build -a -installsuffix nocgo -o /traefik-forward-auth github.com/thomseddon/traefik-forward-auth/cmd
# Copy into scratch container
FROM scratch

View File

@ -1,5 +1,5 @@
format:
gofmt -w -s *.go
gofmt -w -s internal/*.go cmd/*.go
.PHONY: format

30
cmd/main.go Normal file
View File

@ -0,0 +1,30 @@
package main
import (
"net/http"
internal "github.com/thomseddon/traefik-forward-auth/internal"
)
// Main
func main() {
// Parse options
config := internal.NewGlobalConfig()
// Setup logger
log := internal.NewDefaultLogger()
// Perform config validation
config.Validate()
// Build server
server := internal.NewServer()
// Attach router to default server
http.HandleFunc("/", server.RootHandler)
// Start
log.Debugf("Starting with options: %s", config)
log.Info("Listening on :4181")
log.Info(http.ListenAndServe(":4181", nil))
}

View File

@ -1,390 +0,0 @@
package main
import (
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
"time"
)
// Forward Auth
type ForwardAuth struct {
Path string
Lifetime time.Duration
Secret []byte
ClientId string
ClientSecret string `json:"-"`
Scope string
LoginURL *url.URL
TokenURL *url.URL
UserURL *url.URL
AuthHost string
CookieName string
CookieDomains []CookieDomain
CSRFCookieName string
CookieSecure bool
Domain []string
Whitelist []string
Prompt string
}
// Request Validation
// Cookie = hash(secret, cookie domain, email, expires)|expires|email
func (f *ForwardAuth) ValidateCookie(r *http.Request, c *http.Cookie) (bool, string, error) {
parts := strings.Split(c.Value, "|")
if len(parts) != 3 {
return false, "", errors.New("Invalid cookie format")
}
mac, err := base64.URLEncoding.DecodeString(parts[0])
if err != nil {
return false, "", errors.New("Unable to decode cookie mac")
}
expectedSignature := f.cookieSignature(r, parts[2], parts[1])
expected, err := base64.URLEncoding.DecodeString(expectedSignature)
if err != nil {
return false, "", errors.New("Unable to generate mac")
}
// Valid token?
if !hmac.Equal(mac, expected) {
return false, "", errors.New("Invalid cookie mac")
}
expires, err := strconv.ParseInt(parts[1], 10, 64)
if err != nil {
return false, "", errors.New("Unable to parse cookie expiry")
}
// Has it expired?
if time.Unix(expires, 0).Before(time.Now()) {
return false, "", errors.New("Cookie has expired")
}
// Looks valid
return true, parts[2], nil
}
// Validate email
func (f *ForwardAuth) ValidateEmail(email string) bool {
found := false
if len(f.Whitelist) > 0 {
for _, whitelist := range f.Whitelist {
if email == whitelist {
found = true
}
}
} else if len(f.Domain) > 0 {
parts := strings.Split(email, "@")
if len(parts) < 2 {
return false
}
for _, domain := range f.Domain {
if domain == parts[1] {
found = true
}
}
} else {
return true
}
return found
}
// OAuth Methods
// Get login url
func (f *ForwardAuth) GetLoginURL(r *http.Request, nonce string) string {
state := fmt.Sprintf("%s:%s", nonce, f.returnUrl(r))
q := url.Values{}
q.Set("client_id", fw.ClientId)
q.Set("response_type", "code")
q.Set("scope", fw.Scope)
if fw.Prompt != "" {
q.Set("prompt", fw.Prompt)
}
q.Set("redirect_uri", f.redirectUri(r))
q.Set("state", state)
var u url.URL
u = *fw.LoginURL
u.RawQuery = q.Encode()
return u.String()
}
// Exchange code for token
type Token struct {
Token string `json:"access_token"`
}
func (f *ForwardAuth) ExchangeCode(r *http.Request, code string) (string, error) {
form := url.Values{}
form.Set("client_id", fw.ClientId)
form.Set("client_secret", fw.ClientSecret)
form.Set("grant_type", "authorization_code")
form.Set("redirect_uri", f.redirectUri(r))
form.Set("code", code)
res, err := http.PostForm(fw.TokenURL.String(), form)
if err != nil {
return "", err
}
var token Token
defer res.Body.Close()
err = json.NewDecoder(res.Body).Decode(&token)
return token.Token, err
}
// Get user with token
type User struct {
Id string `json:"id"`
Email string `json:"email"`
Verified bool `json:"verified_email"`
Hd string `json:"hd"`
}
func (f *ForwardAuth) GetUser(token string) (User, error) {
var user User
client := &http.Client{}
req, err := http.NewRequest("GET", fw.UserURL.String(), nil)
if err != nil {
return user, err
}
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", token))
res, err := client.Do(req)
if err != nil {
return user, err
}
defer res.Body.Close()
err = json.NewDecoder(res.Body).Decode(&user)
return user, err
}
// Utility methods
// Get the redirect base
func (f *ForwardAuth) redirectBase(r *http.Request) string {
proto := r.Header.Get("X-Forwarded-Proto")
host := r.Header.Get("X-Forwarded-Host")
return fmt.Sprintf("%s://%s", proto, host)
}
// Return url
func (f *ForwardAuth) returnUrl(r *http.Request) string {
path := r.Header.Get("X-Forwarded-Uri")
return fmt.Sprintf("%s%s", f.redirectBase(r), path)
}
// Get oauth redirect uri
func (f *ForwardAuth) redirectUri(r *http.Request) string {
if use, _ := f.useAuthDomain(r); use {
proto := r.Header.Get("X-Forwarded-Proto")
return fmt.Sprintf("%s://%s%s", proto, f.AuthHost, f.Path)
}
return fmt.Sprintf("%s%s", f.redirectBase(r), f.Path)
}
// Should we use auth host + what it is
func (f *ForwardAuth) useAuthDomain(r *http.Request) (bool, string) {
if f.AuthHost == "" {
return false, ""
}
// Does the request match a given cookie domain?
reqMatch, reqHost := f.matchCookieDomains(r.Header.Get("X-Forwarded-Host"))
// Do any of the auth hosts match a cookie domain?
authMatch, authHost := f.matchCookieDomains(f.AuthHost)
// We need both to match the same domain
return reqMatch && authMatch && reqHost == authHost, reqHost
}
// Cookie methods
// Create an auth cookie
func (f *ForwardAuth) MakeCookie(r *http.Request, email string) *http.Cookie {
expires := f.cookieExpiry()
mac := f.cookieSignature(r, email, fmt.Sprintf("%d", expires.Unix()))
value := fmt.Sprintf("%s|%d|%s", mac, expires.Unix(), email)
return &http.Cookie{
Name: f.CookieName,
Value: value,
Path: "/",
Domain: f.cookieDomain(r),
HttpOnly: true,
Secure: f.CookieSecure,
Expires: expires,
}
}
// Make a CSRF cookie (used during login only)
func (f *ForwardAuth) MakeCSRFCookie(r *http.Request, nonce string) *http.Cookie {
return &http.Cookie{
Name: f.CSRFCookieName,
Value: nonce,
Path: "/",
Domain: f.csrfCookieDomain(r),
HttpOnly: true,
Secure: f.CookieSecure,
Expires: f.cookieExpiry(),
}
}
// Create a cookie to clear csrf cookie
func (f *ForwardAuth) ClearCSRFCookie(r *http.Request) *http.Cookie {
return &http.Cookie{
Name: f.CSRFCookieName,
Value: "",
Path: "/",
Domain: f.csrfCookieDomain(r),
HttpOnly: true,
Secure: f.CookieSecure,
Expires: time.Now().Local().Add(time.Hour * -1),
}
}
// Validate the csrf cookie against state
func (f *ForwardAuth) ValidateCSRFCookie(c *http.Cookie, state string) (bool, string, error) {
if len(c.Value) != 32 {
return false, "", errors.New("Invalid CSRF cookie value")
}
if len(state) < 34 {
return false, "", errors.New("Invalid CSRF state value")
}
// Check nonce match
if c.Value != state[:32] {
return false, "", errors.New("CSRF cookie does not match state")
}
// Valid, return redirect
return true, state[33:], nil
}
func (f *ForwardAuth) Nonce() (error, string) {
// Make nonce
nonce := make([]byte, 16)
_, err := rand.Read(nonce)
if err != nil {
return err, ""
}
return nil, fmt.Sprintf("%x", nonce)
}
// Cookie domain
func (f *ForwardAuth) cookieDomain(r *http.Request) string {
host := r.Header.Get("X-Forwarded-Host")
// Check if any of the given cookie domains matches
_, domain := f.matchCookieDomains(host)
return domain
}
// Cookie domain
func (f *ForwardAuth) csrfCookieDomain(r *http.Request) string {
var host string
if use, domain := f.useAuthDomain(r); use {
host = domain
} else {
host = r.Header.Get("X-Forwarded-Host")
}
// Remove port
p := strings.Split(host, ":")
return p[0]
}
// Return matching cookie domain if exists
func (f *ForwardAuth) matchCookieDomains(domain string) (bool, string) {
// Remove port
p := strings.Split(domain, ":")
for _, d := range f.CookieDomains {
if d.Match(p[0]) {
return true, d.Domain
}
}
return false, p[0]
}
// Create cookie hmac
func (f *ForwardAuth) cookieSignature(r *http.Request, email, expires string) string {
hash := hmac.New(sha256.New, f.Secret)
hash.Write([]byte(f.cookieDomain(r)))
hash.Write([]byte(email))
hash.Write([]byte(expires))
return base64.URLEncoding.EncodeToString(hash.Sum(nil))
}
// Get cookie expirary
func (f *ForwardAuth) cookieExpiry() time.Time {
return time.Now().Local().Add(f.Lifetime)
}
// Cookie Domain
// Cookie Domain
type CookieDomain struct {
Domain string
DomainLen int
SubDomain string
SubDomainLen int
}
func NewCookieDomain(domain string) *CookieDomain {
return &CookieDomain{
Domain: domain,
DomainLen: len(domain),
SubDomain: fmt.Sprintf(".%s", domain),
SubDomainLen: len(domain) + 1,
}
}
func (c *CookieDomain) Match(host string) bool {
// Exact domain match?
if host == c.Domain {
return true
}
// Subdomain match?
if len(host) >= c.SubDomainLen && host[len(host)-c.SubDomainLen:] == c.SubDomain {
return true
}
return false
}

31
go.mod Normal file
View File

@ -0,0 +1,31 @@
module github.com/thomseddon/traefik-forward-auth
go 1.12
require (
github.com/VividCortex/gohistogram v1.0.0 // indirect
github.com/cenkalti/backoff v2.1.1+incompatible // indirect
github.com/containous/alice v0.0.0-20181107144136-d83ebdd94cbd // indirect
github.com/containous/flaeg v1.4.1 // indirect
github.com/containous/mux v0.0.0-20181024131434-c33f32e26898 // indirect
github.com/containous/traefik v2.0.0-alpha2+incompatible
github.com/go-acme/lego v2.4.0+incompatible // indirect
github.com/go-kit/kit v0.8.0 // indirect
github.com/gorilla/context v1.1.1 // indirect
github.com/gravitational/trace v0.0.0-20190409171327-f30095ced5ff // indirect
github.com/jessevdk/go-flags v1.4.0
github.com/jonboulle/clockwork v0.1.0 // indirect
github.com/kr/pretty v0.1.0 // indirect
github.com/miekg/dns v1.1.8 // indirect
github.com/patrickmn/go-cache v2.1.0+incompatible // indirect
github.com/pkg/errors v0.8.1 // indirect
github.com/ryanuber/go-glob v1.0.0 // indirect
github.com/sirupsen/logrus v1.4.1
github.com/stretchr/testify v1.3.0 // indirect
github.com/vulcand/predicate v1.1.0 // indirect
golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a // indirect
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 // indirect
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f // indirect
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
gopkg.in/square/go-jose.v2 v2.3.1 // indirect
)

70
go.sum Normal file
View File

@ -0,0 +1,70 @@
github.com/VividCortex/gohistogram v1.0.0 h1:6+hBz+qvs0JOrrNhhmR7lFxo5sINxBCGXrdtl/UvroE=
github.com/VividCortex/gohistogram v1.0.0/go.mod h1:Pf5mBqqDxYaXu3hDrrU+w6nw50o/4+TcAqDqk/vUH7g=
github.com/cenkalti/backoff v2.1.1+incompatible h1:tKJnvO2kl0zmb/jA5UKAt4VoEVw1qxKWjE/Bpp46npY=
github.com/cenkalti/backoff v2.1.1+incompatible/go.mod h1:90ReRw6GdpyfrHakVjL/QHaoyV4aDUVVkXQJJJ3NXXM=
github.com/containous/alice v0.0.0-20181107144136-d83ebdd94cbd h1:0n+lFLh5zU0l6KSk3KpnDwfbPGAR44aRLgTbCnhRBHU=
github.com/containous/alice v0.0.0-20181107144136-d83ebdd94cbd/go.mod h1:BbQgeDS5i0tNvypwEoF1oNjOJw8knRAE1DnVvjDstcQ=
github.com/containous/flaeg v1.4.1 h1:VTouP7EF2JeowNvknpP3fJAJLUDsQ1lDHq/QQTQc1xc=
github.com/containous/flaeg v1.4.1/go.mod h1:wgw6PDtRURXHKFFV6HOqQxWhUc3k3Hmq22jw+n2qDro=
github.com/containous/mux v0.0.0-20181024131434-c33f32e26898 h1:1srn9voikJGofblBhWy3WuZWqo14Ou7NaswNG/I2yWc=
github.com/containous/mux v0.0.0-20181024131434-c33f32e26898/go.mod h1:z8WW7n06n8/1xF9Jl9WmuDeZuHAhfL+bwarNjsciwwg=
github.com/containous/traefik v2.0.0-alpha2+incompatible h1:5RS6mUAOPQCy1jAmcmxLj2nChIcs3fKuxZxH9AF6ih8=
github.com/containous/traefik v2.0.0-alpha2+incompatible/go.mod h1:epDRqge3JzKOhlSWzOpNYEEKXmM6yfN5tPzDGKk3ljo=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-acme/lego v2.4.0+incompatible h1:+BTLUfLtDc5qQauyiTCXH6lupEUOCvXyGlEjdeU0YQI=
github.com/go-acme/lego v2.4.0+incompatible/go.mod h1:yzMNe9CasVUhkquNvti5nAtPmG94USbYxYrZfTkIn0M=
github.com/go-kit/kit v0.8.0 h1:Wz+5lgoB0kkuqLEc6NVmwRknTKP6dTGbSqvhZtBI/j0=
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8=
github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
github.com/gravitational/trace v0.0.0-20190409171327-f30095ced5ff h1:xL/fJdlTJL6R/6Qk2tPu3EP1NsXgap9hXLvxKH0Ytko=
github.com/gravitational/trace v0.0.0-20190409171327-f30095ced5ff/go.mod h1:RvdOUHE4SHqR3oXlFFKnGzms8a5dugHygGw1bqDstYI=
github.com/jessevdk/go-flags v1.4.0 h1:4IU2WS7AumrZ/40jfhf4QVDMsQwqA7VEHozFRrGARJA=
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
github.com/jonboulle/clockwork v0.1.0 h1:VKV+ZcuP6l3yW9doeqz6ziZGgcynBVQO+obU0+0hcPo=
github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo=
github.com/konsorten/go-windows-terminal-sequences v1.0.1 h1:mweAR1A6xJ3oS2pRaGiHgQ4OO8tzTaLawm8vnODuwDk=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/miekg/dns v1.1.8 h1:1QYRAKU3lN5cRfLCkPU08hwvLJFhvjP6MqNMmQz6ZVI=
github.com/miekg/dns v1.1.8/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg=
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/ryanuber/go-glob v1.0.0 h1:iQh3xXAumdQ+4Ufa5b25cRpC5TYKlno6hsv6Cb3pkBk=
github.com/ryanuber/go-glob v1.0.0/go.mod h1:807d1WSdnB0XRJzKNil9Om6lcp/3a0v4qIHxIXzX/Yc=
github.com/sirupsen/logrus v1.4.1 h1:GL2rEmy6nsikmW0r8opw9JIRScdMF5hA8cOYLH7In1k=
github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/vulcand/predicate v1.1.0 h1:Gq/uWopa4rx/tnZu2opOSBqHK63Yqlou/SzrbwdJiNg=
github.com/vulcand/predicate v1.1.0/go.mod h1:mlccC5IRBoc2cIFmCB8ZM62I3VDb6p2GXESMHa3CnZg=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a h1:Igim7XhdOpBnWPuYJ70XcNpq8q3BCACtVgNfoJxOV7g=
golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 h1:0GoQqolDA55aaLxZyTzK/Y2ePZzZTUrRacwib7cNsYQ=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f h1:wMNYb4v58l5UBM7MYRLPG6ZhfOqbKu7X5eyFl8ZhKvA=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33 h1:I6FyU15t786LL7oL/hn43zqTuEGr4PN7F4XJ1p4E3Y8=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e h1:nFYrTHrdrAOpShe27kaFHjsqYSEQ0KWqdWLu3xuZJts=
golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/square/go-jose.v2 v2.3.1 h1:SK5KegNXmKmqE342YYN2qPHEnUYeoMiXXl1poUlI+o4=
gopkg.in/square/go-jose.v2 v2.3.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=

333
internal/auth.go Normal file
View File

@ -0,0 +1,333 @@
package tfa
import (
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"net/http"
"strconv"
"strings"
"time"
"github.com/thomseddon/traefik-forward-auth/internal/provider"
)
// Request Validation
// Cookie = hash(secret, cookie domain, email, expires)|expires|email
func ValidateCookie(r *http.Request, c *http.Cookie) (bool, string, error) {
parts := strings.Split(c.Value, "|")
if len(parts) != 3 {
return false, "", errors.New("Invalid cookie format")
}
mac, err := base64.URLEncoding.DecodeString(parts[0])
if err != nil {
return false, "", errors.New("Unable to decode cookie mac")
}
expectedSignature := cookieSignature(r, parts[2], parts[1])
expected, err := base64.URLEncoding.DecodeString(expectedSignature)
if err != nil {
return false, "", errors.New("Unable to generate mac")
}
// Valid token?
if !hmac.Equal(mac, expected) {
return false, "", errors.New("Invalid cookie mac")
}
expires, err := strconv.ParseInt(parts[1], 10, 64)
if err != nil {
return false, "", errors.New("Unable to parse cookie expiry")
}
// Has it expired?
if time.Unix(expires, 0).Before(time.Now()) {
return false, "", errors.New("Cookie has expired")
}
// Looks valid
return true, parts[2], nil
}
// Validate email
func ValidateEmail(email string) bool {
found := false
if len(config.Whitelist) > 0 {
for _, whitelist := range config.Whitelist {
if email == whitelist {
found = true
}
}
} else if len(config.Domains) > 0 {
parts := strings.Split(email, "@")
if len(parts) < 2 {
return false
}
for _, domain := range config.Domains {
if domain == parts[1] {
found = true
}
}
} else {
return true
}
return found
}
// OAuth Methods
// Get login url
func GetLoginURL(r *http.Request, nonce string) string {
state := fmt.Sprintf("%s:%s", nonce, returnUrl(r))
// TODO: Support multiple providers
return config.Providers.Google.GetLoginURL(redirectUri(r), state)
}
// Exchange code for token
func ExchangeCode(r *http.Request) (string, error) {
code := r.URL.Query().Get("code")
// TODO: Support multiple providers
return config.Providers.Google.ExchangeCode(redirectUri(r), code)
}
// Get user with token
func GetUser(token string) (provider.User, error) {
// TODO: Support multiple providers
return config.Providers.Google.GetUser(token)
}
// Utility methods
// Get the redirect base
func redirectBase(r *http.Request) string {
proto := r.Header.Get("X-Forwarded-Proto")
host := r.Header.Get("X-Forwarded-Host")
return fmt.Sprintf("%s://%s", proto, host)
}
// // Return url
func returnUrl(r *http.Request) string {
path := r.Header.Get("X-Forwarded-Uri")
return fmt.Sprintf("%s%s", redirectBase(r), path)
}
// Get oauth redirect uri
func redirectUri(r *http.Request) string {
if use, _ := useAuthDomain(r); use {
proto := r.Header.Get("X-Forwarded-Proto")
return fmt.Sprintf("%s://%s%s", proto, config.AuthHost, config.Path)
}
return fmt.Sprintf("%s%s", redirectBase(r), config.Path)
}
// Should we use auth host + what it is
func useAuthDomain(r *http.Request) (bool, string) {
if config.AuthHost == "" {
return false, ""
}
// Does the request match a given cookie domain?
reqMatch, reqHost := matchCookieDomains(r.Header.Get("X-Forwarded-Host"))
// Do any of the auth hosts match a cookie domain?
authMatch, authHost := matchCookieDomains(config.AuthHost)
// We need both to match the same domain
return reqMatch && authMatch && reqHost == authHost, reqHost
}
// Cookie methods
// Create an auth cookie
func MakeCookie(r *http.Request, email string) *http.Cookie {
expires := cookieExpiry()
mac := cookieSignature(r, email, fmt.Sprintf("%d", expires.Unix()))
value := fmt.Sprintf("%s|%d|%s", mac, expires.Unix(), email)
return &http.Cookie{
Name: config.CookieName,
Value: value,
Path: "/",
Domain: cookieDomain(r),
HttpOnly: true,
Secure: !config.InsecureCookie,
Expires: expires,
}
}
// Make a CSRF cookie (used during login only)
func MakeCSRFCookie(r *http.Request, nonce string) *http.Cookie {
return &http.Cookie{
Name: config.CSRFCookieName,
Value: nonce,
Path: "/",
Domain: csrfCookieDomain(r),
HttpOnly: true,
Secure: !config.InsecureCookie,
Expires: cookieExpiry(),
}
}
// Create a cookie to clear csrf cookie
func ClearCSRFCookie(r *http.Request) *http.Cookie {
return &http.Cookie{
Name: config.CSRFCookieName,
Value: "",
Path: "/",
Domain: csrfCookieDomain(r),
HttpOnly: true,
Secure: !config.InsecureCookie,
Expires: time.Now().Local().Add(time.Hour * -1),
}
}
// Validate the csrf cookie against state
func ValidateCSRFCookie(r *http.Request, c *http.Cookie) (bool, string, error) {
state := r.URL.Query().Get("state")
if len(c.Value) != 32 {
return false, "", errors.New("Invalid CSRF cookie value")
}
if len(state) < 34 {
return false, "", errors.New("Invalid CSRF state value")
}
// Check nonce match
if c.Value != state[:32] {
return false, "", errors.New("CSRF cookie does not match state")
}
// Valid, return redirect
return true, state[33:], nil
}
func Nonce() (error, string) {
// Make nonce
nonce := make([]byte, 16)
_, err := rand.Read(nonce)
if err != nil {
return err, ""
}
return nil, fmt.Sprintf("%x", nonce)
}
// Cookie domain
func cookieDomain(r *http.Request) string {
host := r.Header.Get("X-Forwarded-Host")
// Check if any of the given cookie domains matches
_, domain := matchCookieDomains(host)
return domain
}
// Cookie domain
func csrfCookieDomain(r *http.Request) string {
var host string
if use, domain := useAuthDomain(r); use {
host = domain
} else {
host = r.Header.Get("X-Forwarded-Host")
}
// Remove port
p := strings.Split(host, ":")
return p[0]
}
// Return matching cookie domain if exists
func matchCookieDomains(domain string) (bool, string) {
// Remove port
p := strings.Split(domain, ":")
for _, d := range config.CookieDomains {
if d.Match(p[0]) {
return true, d.Domain
}
}
return false, p[0]
}
// Create cookie hmac
func cookieSignature(r *http.Request, email, expires string) string {
hash := hmac.New(sha256.New, config.Secret)
hash.Write([]byte(cookieDomain(r)))
hash.Write([]byte(email))
hash.Write([]byte(expires))
return base64.URLEncoding.EncodeToString(hash.Sum(nil))
}
// Get cookie expirary
func cookieExpiry() time.Time {
return time.Now().Local().Add(config.Lifetime)
}
// Cookie Domain
// Cookie Domain
type CookieDomain struct {
Domain string `description:"TEST1"`
DomainLen int `description:"TEST2"`
SubDomain string `description:"TEST3"`
SubDomainLen int `description:"TEST4"`
}
func NewCookieDomain(domain string) *CookieDomain {
return &CookieDomain{
Domain: domain,
DomainLen: len(domain),
SubDomain: fmt.Sprintf(".%s", domain),
SubDomainLen: len(domain) + 1,
}
}
func (c *CookieDomain) Match(host string) bool {
// Exact domain match?
if host == c.Domain {
return true
}
// Subdomain match?
if len(host) >= c.SubDomainLen && host[len(host)-c.SubDomainLen:] == c.SubDomain {
return true
}
return false
}
type CookieDomains []CookieDomain
func (c *CookieDomains) UnmarshalFlag(value string) error {
// TODO: test
if len(value) > 0 {
for _, d := range strings.Split(value, ",") {
cookieDomain := NewCookieDomain(d)
*c = append(*c, *cookieDomain)
}
}
return nil
}
func (c *CookieDomains) MarshalFlag() (string, error) {
var domains []string
for _, d := range *c {
domains = append(domains, d.Domain)
}
return strings.Join(domains, ","), nil
}

View File

@ -1,55 +1,62 @@
package main
package tfa
import (
// "fmt"
"fmt"
"net/http"
"net/url"
"reflect"
"strings"
"testing"
"time"
"github.com/thomseddon/traefik-forward-auth/internal/provider"
)
func TestValidateCookie(t *testing.T) {
fw = &ForwardAuth{}
/**
* Tests
*/
func TestAuthValidateCookie(t *testing.T) {
config, _ = NewConfig([]string{})
r, _ := http.NewRequest("GET", "http://example.com", nil)
c := &http.Cookie{}
// Should require 3 parts
c.Value = ""
valid, _, err := fw.ValidateCookie(r, c)
valid, _, err := ValidateCookie(r, c)
if valid || err.Error() != "Invalid cookie format" {
t.Error("Should get \"Invalid cookie format\", got:", err)
}
c.Value = "1|2"
valid, _, err = fw.ValidateCookie(r, c)
valid, _, err = ValidateCookie(r, c)
if valid || err.Error() != "Invalid cookie format" {
t.Error("Should get \"Invalid cookie format\", got:", err)
}
c.Value = "1|2|3|4"
valid, _, err = fw.ValidateCookie(r, c)
valid, _, err = ValidateCookie(r, c)
if valid || err.Error() != "Invalid cookie format" {
t.Error("Should get \"Invalid cookie format\", got:", err)
}
// Should catch invalid mac
c.Value = "MQ==|2|3"
valid, _, err = fw.ValidateCookie(r, c)
valid, _, err = ValidateCookie(r, c)
if valid || err.Error() != "Invalid cookie mac" {
t.Error("Should get \"Invalid cookie mac\", got:", err)
}
// Should catch expired
fw.Lifetime = time.Second * time.Duration(-1)
c = fw.MakeCookie(r, "test@test.com")
valid, _, err = fw.ValidateCookie(r, c)
config.Lifetime = time.Second * time.Duration(-1)
c = MakeCookie(r, "test@test.com")
valid, _, err = ValidateCookie(r, c)
if valid || err.Error() != "Cookie has expired" {
t.Error("Should get \"Cookie has expired\", got:", err)
}
// Should accept valid cookie
fw.Lifetime = time.Second * time.Duration(10)
c = fw.MakeCookie(r, "test@test.com")
valid, email, err := fw.ValidateCookie(r, c)
config.Lifetime = time.Second * time.Duration(10)
c = MakeCookie(r, "test@test.com")
valid, email, err := ValidateCookie(r, c)
if !valid {
t.Error("Valid request should return as valid")
}
@ -61,52 +68,48 @@ func TestValidateCookie(t *testing.T) {
}
}
func TestValidateEmail(t *testing.T) {
fw = &ForwardAuth{}
func TestAuthValidateEmail(t *testing.T) {
config, _ = NewConfig([]string{})
// Should allow any
if !fw.ValidateEmail("test@test.com") || !fw.ValidateEmail("one@two.com") {
if !ValidateEmail("test@test.com") || !ValidateEmail("one@two.com") {
t.Error("Should allow any domain if email domain is not defined")
}
// Should block non matching domain
fw.Domain = []string{"test.com"}
if fw.ValidateEmail("one@two.com") {
config.Domains = []string{"test.com"}
if ValidateEmail("one@two.com") {
t.Error("Should not allow user from another domain")
}
// Should allow matching domain
fw.Domain = []string{"test.com"}
if !fw.ValidateEmail("test@test.com") {
config.Domains = []string{"test.com"}
if !ValidateEmail("test@test.com") {
t.Error("Should allow user from allowed domain")
}
// Should block non whitelisted email address
fw.Domain = []string{}
fw.Whitelist = []string{"test@test.com"}
if fw.ValidateEmail("one@two.com") {
config.Domains = []string{}
config.Whitelist = []string{"test@test.com"}
if ValidateEmail("one@two.com") {
t.Error("Should not allow user not in whitelist.")
}
// Should allow matching whitelisted email address
fw.Domain = []string{}
fw.Whitelist = []string{"test@test.com"}
if !fw.ValidateEmail("test@test.com") {
config.Domains = []string{}
config.Whitelist = []string{"test@test.com"}
if !ValidateEmail("test@test.com") {
t.Error("Should allow user in whitelist.")
}
}
func TestGetLoginURL(t *testing.T) {
r, _ := http.NewRequest("GET", "http://example.com", nil)
r.Header.Add("X-Forwarded-Proto", "http")
r.Header.Add("X-Forwarded-Host", "example.com")
r.Header.Add("X-Forwarded-Uri", "/hello")
fw = &ForwardAuth{
Path: "/_oauth",
// TODO: Split google tests out
func TestAuthGetLoginURL(t *testing.T) {
google := provider.Google{
ClientId: "idtest",
ClientSecret: "sectest",
Scope: "scopetest",
Prompt: "consent select_account",
LoginURL: &url.URL{
Scheme: "https",
Host: "test.com",
@ -114,8 +117,16 @@ func TestGetLoginURL(t *testing.T) {
},
}
config, _ = NewConfig([]string{})
config.Providers.Google = google
r, _ := http.NewRequest("GET", "http://example.com", nil)
r.Header.Add("X-Forwarded-Proto", "http")
r.Header.Add("X-Forwarded-Host", "example.com")
r.Header.Add("X-Forwarded-Uri", "/hello")
// Check url
uri, err := url.Parse(fw.GetLoginURL(r, "nonce"))
uri, err := url.Parse(GetLoginURL(r, "nonce"))
if err != nil {
t.Error("Error parsing login url:", err)
}
@ -136,33 +147,25 @@ func TestGetLoginURL(t *testing.T) {
"redirect_uri": []string{"http://example.com/_oauth"},
"response_type": []string{"code"},
"scope": []string{"scopetest"},
"prompt": []string{"consent select_account"},
"state": []string{"nonce:http://example.com/hello"},
}
if !reflect.DeepEqual(qs, expectedQs) {
t.Error("Incorrect login query string:")
qsDiff(expectedQs, qs)
for _, err := range qsDiff(t, expectedQs, qs) {
t.Error(err)
}
}
//
// With Auth URL but no matching cookie domain
// - will not use auth host
//
fw = &ForwardAuth{
Path: "/_oauth",
AuthHost: "auth.example.com",
ClientId: "idtest",
ClientSecret: "sectest",
Scope: "scopetest",
LoginURL: &url.URL{
Scheme: "https",
Host: "test.com",
Path: "/auth",
},
Prompt: "consent select_account",
}
config, _ = NewConfig([]string{})
config.AuthHost = "auth.example.com"
config.Providers.Google = google
// Check url
uri, err = url.Parse(fw.GetLoginURL(r, "nonce"))
uri, err = url.Parse(GetLoginURL(r, "nonce"))
if err != nil {
t.Error("Error parsing login url:", err)
}
@ -187,30 +190,21 @@ func TestGetLoginURL(t *testing.T) {
"state": []string{"nonce:http://example.com/hello"},
}
if !reflect.DeepEqual(qs, expectedQs) {
t.Error("Incorrect login query string:")
qsDiff(expectedQs, qs)
for _, err := range qsDiff(t, expectedQs, qs) {
t.Error(err)
}
}
//
// With correct Auth URL + cookie domain
//
cookieDomain := NewCookieDomain("example.com")
fw = &ForwardAuth{
Path: "/_oauth",
AuthHost: "auth.example.com",
ClientId: "idtest",
ClientSecret: "sectest",
Scope: "scopetest",
LoginURL: &url.URL{
Scheme: "https",
Host: "test.com",
Path: "/auth",
},
CookieDomains: []CookieDomain{*cookieDomain},
}
config, _ = NewConfig([]string{})
config.AuthHost = "auth.example.com"
config.CookieDomains = []CookieDomain{*NewCookieDomain("example.com")}
config.Providers.Google = google
// Check url
uri, err = url.Parse(fw.GetLoginURL(r, "nonce"))
uri, err = url.Parse(GetLoginURL(r, "nonce"))
if err != nil {
t.Error("Error parsing login url:", err)
}
@ -232,11 +226,15 @@ func TestGetLoginURL(t *testing.T) {
"response_type": []string{"code"},
"scope": []string{"scopetest"},
"state": []string{"nonce:http://example.com/hello"},
"prompt": []string{"consent select_account"},
}
for _, err := range qsDiff(t, expectedQs, qs) {
t.Error(err)
}
qsDiff(expectedQs, qs)
if !reflect.DeepEqual(qs, expectedQs) {
t.Error("Incorrect login query string:")
qsDiff(expectedQs, qs)
for _, err := range qsDiff(t, expectedQs, qs) {
t.Error(err)
}
}
//
@ -249,7 +247,7 @@ func TestGetLoginURL(t *testing.T) {
r.Header.Add("X-Forwarded-Uri", "/hello")
// Check url
uri, err = url.Parse(fw.GetLoginURL(r, "nonce"))
uri, err = url.Parse(GetLoginURL(r, "nonce"))
if err != nil {
t.Error("Error parsing login url:", err)
}
@ -271,92 +269,146 @@ func TestGetLoginURL(t *testing.T) {
"response_type": []string{"code"},
"scope": []string{"scopetest"},
"state": []string{"nonce:http://another.com/hello"},
"prompt": []string{"consent select_account"},
}
for _, err := range qsDiff(t, expectedQs, qs) {
t.Error(err)
}
qsDiff(expectedQs, qs)
if !reflect.DeepEqual(qs, expectedQs) {
t.Error("Incorrect login query string:")
qsDiff(expectedQs, qs)
for _, err := range qsDiff(t, expectedQs, qs) {
t.Error(err)
}
}
}
// TODO
// func TestExchangeCode(t *testing.T) {
// func TestAuthExchangeCode(t *testing.T) {
// }
// TODO
// func TestGetUser(t *testing.T) {
// func TestAuthGetUser(t *testing.T) {
// }
// TODO? Tested in TestValidateCookie
// func TestMakeCookie(t *testing.T) {
// }
func TestAuthMakeCookie(t *testing.T) {
config, _ = NewConfig([]string{})
r, _ := http.NewRequest("GET", "http://app.example.com", nil)
r.Header.Add("X-Forwarded-Host", "app.example.com")
func TestMakeCSRFCookie(t *testing.T) {
c := MakeCookie(r, "test@example.com")
if c.Name != "_forward_auth" {
t.Error("Cookie name should be \"_forward_auth\", got:", c.Name)
}
parts := strings.Split(c.Value, "|")
if len(parts) != 3 {
t.Error("Cookie should be in 3 parts, got:", c.Value)
}
valid, _, _ := ValidateCookie(r, c)
if !valid {
t.Error("Should generate valid cookie:", c.Value)
}
if c.Path != "/" {
t.Error("Cookie path should be \"/\", got:", c.Path)
}
if c.Domain != "app.example.com" {
t.Error("Cookie domain should be \"app.example.com\", got:", c.Domain)
}
if c.Secure != true {
t.Error("Cookie domain should be true, got:", c.Secure)
}
if !c.Expires.After(time.Now().Local()) {
t.Error("Expires should be after now, got:", c.Expires)
}
if !c.Expires.Before(time.Now().Local().Add(config.Lifetime).Add(10 * time.Second)) {
t.Error("Expires should be before lifetime + 10 seconds, got:", c.Expires)
}
config.CookieName = "testname"
config.InsecureCookie = true
c = MakeCookie(r, "test@example.com")
if c.Name != "testname" {
t.Error("Cookie name should be \"testname\", got:", c.Name)
}
if c.Secure != false {
t.Error("Cookie domain should be false, got:", c.Secure)
}
}
func TestAuthMakeCSRFCookie(t *testing.T) {
config, _ = NewConfig([]string{})
r, _ := http.NewRequest("GET", "http://app.example.com", nil)
r.Header.Add("X-Forwarded-Host", "app.example.com")
// No cookie domain or auth url
fw = &ForwardAuth{}
c := fw.MakeCSRFCookie(r, "12345678901234567890123456789012")
c := MakeCSRFCookie(r, "12345678901234567890123456789012")
if c.Domain != "app.example.com" {
t.Error("Cookie Domain should match request domain, got:", c.Domain)
}
// With cookie domain but no auth url
cookieDomain := NewCookieDomain("example.com")
fw = &ForwardAuth{CookieDomains: []CookieDomain{*cookieDomain}}
c = fw.MakeCSRFCookie(r, "12345678901234567890123456789012")
config = Config{
CookieDomains: []CookieDomain{*NewCookieDomain("example.com")},
}
c = MakeCSRFCookie(r, "12345678901234567890123456789012")
if c.Domain != "app.example.com" {
t.Error("Cookie Domain should match request domain, got:", c.Domain)
}
// With cookie domain and auth url
fw = &ForwardAuth{
config = Config{
AuthHost: "auth.example.com",
CookieDomains: []CookieDomain{*cookieDomain},
CookieDomains: []CookieDomain{*NewCookieDomain("example.com")},
}
c = fw.MakeCSRFCookie(r, "12345678901234567890123456789012")
c = MakeCSRFCookie(r, "12345678901234567890123456789012")
if c.Domain != "example.com" {
t.Error("Cookie Domain should match request domain, got:", c.Domain)
}
}
func TestClearCSRFCookie(t *testing.T) {
fw = &ForwardAuth{}
func TestAuthClearCSRFCookie(t *testing.T) {
config, _ = NewConfig([]string{})
r, _ := http.NewRequest("GET", "http://example.com", nil)
c := fw.ClearCSRFCookie(r)
c := ClearCSRFCookie(r)
if c.Value != "" {
t.Error("ClearCSRFCookie should create cookie with empty value")
}
}
func TestValidateCSRFCookie(t *testing.T) {
fw = &ForwardAuth{}
func TestAuthValidateCSRFCookie(t *testing.T) {
config, _ = NewConfig([]string{})
c := &http.Cookie{}
newCsrfRequest := func(state string) *http.Request {
u := fmt.Sprintf("http://example.com?state=%s", state)
r, _ := http.NewRequest("GET", u, nil)
return r
}
// Should require 32 char string
r := newCsrfRequest("")
c.Value = ""
valid, _, err := fw.ValidateCSRFCookie(c, "")
valid, _, err := ValidateCSRFCookie(r, c)
if valid || err.Error() != "Invalid CSRF cookie value" {
t.Error("Should get \"Invalid CSRF cookie value\", got:", err)
}
c.Value = "123456789012345678901234567890123"
valid, _, err = fw.ValidateCSRFCookie(c, "")
valid, _, err = ValidateCSRFCookie(r, c)
if valid || err.Error() != "Invalid CSRF cookie value" {
t.Error("Should get \"Invalid CSRF cookie value\", got:", err)
}
// Should require valid state
r = newCsrfRequest("12345678901234567890123456789012:")
c.Value = "12345678901234567890123456789012"
valid, _, err = fw.ValidateCSRFCookie(c, "12345678901234567890123456789012:")
valid, _, err = ValidateCSRFCookie(r, c)
if valid || err.Error() != "Invalid CSRF state value" {
t.Error("Should get \"Invalid CSRF state value\", got:", err)
}
// Should allow valid state
r = newCsrfRequest("12345678901234567890123456789012:99")
c.Value = "12345678901234567890123456789012"
valid, state, err := fw.ValidateCSRFCookie(c, "12345678901234567890123456789012:99")
valid, state, err := ValidateCSRFCookie(r, c)
if !valid {
t.Error("Valid request should return as valid")
}
@ -368,15 +420,13 @@ func TestValidateCSRFCookie(t *testing.T) {
}
}
func TestNonce(t *testing.T) {
fw = &ForwardAuth{}
err, nonce1 := fw.Nonce()
func TestAuthNonce(t *testing.T) {
err, nonce1 := Nonce()
if err != nil {
t.Error("Error generation nonce:", err)
}
err, nonce2 := fw.Nonce()
err, nonce2 := Nonce()
if err != nil {
t.Error("Error generation nonce:", err)
}
@ -389,7 +439,7 @@ func TestNonce(t *testing.T) {
}
}
func TestCookieDomainMatch(t *testing.T) {
func TestAuthCookieDomainMatch(t *testing.T) {
cd := NewCookieDomain("example.com")
// Exact should match
@ -412,3 +462,29 @@ func TestCookieDomainMatch(t *testing.T) {
t.Error("Other domain should not match")
}
}
func TestAuthCookieDomains(t *testing.T) {
cds := CookieDomains{}
err := cds.UnmarshalFlag("one.com,two.org")
if err != nil {
t.Error(err)
}
if len(cds) != 2 {
t.Error("Expected UnmarshalFlag to provide 2 CookieDomains, got", cds)
}
if cds[0].Domain != "one.com" || cds[0].SubDomain != ".one.com" {
t.Error("Expected UnmarshalFlag to provide one.com, got", cds[0])
}
if cds[1].Domain != "two.org" || cds[1].SubDomain != ".two.org" {
t.Error("Expected UnmarshalFlag to provide two.org, got", cds[1])
}
marshal, err := cds.MarshalFlag()
if err != nil {
t.Error(err)
}
if marshal != "one.com,two.org" {
t.Error("Expected MarshalFlag to provide \"one.com,two.org\", got", cds)
}
}

262
internal/config.go Normal file
View File

@ -0,0 +1,262 @@
package tfa
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"os"
"regexp"
"strconv"
"strings"
"time"
"github.com/jessevdk/go-flags"
"github.com/thomseddon/traefik-forward-auth/internal/provider"
)
var config Config
type Config struct {
LogLevel string `long:"log-level" env:"LOG_LEVEL" default:"warn" choice:"trace" choice:"debug" choice:"info" choice:"warn" choice:"error" choice:"fatal" choice:"panic" description:"Log level"`
LogFormat string `long:"log-format" env:"LOG_FORMAT" default:"text" choice:"text" choice:"json" choice:"pretty" description:"Log format"`
AuthHost string `long:"auth-host" env:"AUTH_HOST" description:"Host for central auth login"`
Config func(s string) error `long:"config" env:"CONFIG" description:"Config file"`
CookieDomains CookieDomains `long:"cookie-domains" env:"COOKIE_DOMAINS" description:"Comma separated list of cookie domains"`
InsecureCookie bool `long:"insecure-cookie" env:"INSECURE_COOKIE" description:"Use insecure cookies"`
CookieName string `long:"cookie-name" env:"COOKIE_NAME" default:"_forward_auth" description:"Cookie Name"`
CSRFCookieName string `long:"csrf-cookie-name" env:"CSRF_COOKIE_NAME" default:"_forward_auth_csrf" description:"CSRF Cookie Name"`
DefaultAction string `long:"default-action" env:"DEFAULT_ACTION" default:"auth" choice:"auth" choice:"allow" description:"Default Action"`
Domains CommaSeparatedList `long:"domains" env:"DOMAINS" description:"Comma separated list of email domains to allow"`
LifetimeString int `long:"lifetime" env:"LIFETIME" default:"43200" description:"Lifetime in seconds"`
Path string `long:"url-path" env:"URL_PATH" default:"_oauth" description:"Callback URL Path"`
SecretString string `long:"secret" env:"SECRET" description:"*Secret used for signing (required)"`
Whitelist CommaSeparatedList `long:"whitelist" env:"WHITELIST" description:"Comma separated list of email addresses to allow"`
Providers provider.Providers `group:"providers" namespace:"providers" env-namespace:"PROVIDERS"`
Rules map[string]*Rule `long:"rules.<name>.<param>" description:"Rule definitions, see docs, param can be: \"action\", \"rule\""`
// Filled during transformations
Secret []byte
Lifetime time.Duration
// Legacy
ClientIdLegacy string `long:"client-id" env:"CLIENT_ID" group:"DEPs" description:"DEPRECATED - Use \"providers.google.client-id\""`
ClientSecretLegacy string `long:"client-secret" env:"CLIENT_SECRET" description:"DEPRECATED - Use \"providers.google.client-id\""`
PromptLegacy string `long:"prompt" env:"PROMPT" description:"DEPRECATED - Use \"providers.google.prompt\""`
CookieSecureLegacy string `long:"cookie-secure" env:"COOKIE_SECURE" namespace:"DERPS" description:"DEPRECATED - Use \"insecure-cookie\""`
}
func NewGlobalConfig() Config {
var err error
config, err = NewConfig(os.Args[1:])
if err != nil {
fmt.Printf("%+v\n", err)
os.Exit(1)
}
return config
}
func NewConfig(args []string) (Config, error) {
c := Config{
Rules: map[string]*Rule{},
}
err := c.parseFlags(args)
if err != nil {
return c, err
}
// TODO: as log flags have now been parsed maybe we should return here so
// any further errors can be logged via logrus instead of printed?
// Backwards compatability
if c.ClientIdLegacy != "" {
c.Providers.Google.ClientId = c.ClientIdLegacy
}
if c.ClientSecretLegacy != "" {
c.Providers.Google.ClientSecret = c.ClientSecretLegacy
}
if c.PromptLegacy != "" {
c.Providers.Google.Prompt = c.PromptLegacy
}
if c.CookieSecureLegacy != "" {
secure, err := strconv.ParseBool(c.CookieSecureLegacy)
if err != nil {
return c, err
}
c.InsecureCookie = !secure
}
// Provider defaults
c.Providers.Google.Build()
// Transformations
c.Path = fmt.Sprintf("/%s", c.Path)
c.Secret = []byte(c.SecretString)
c.Lifetime = time.Second * time.Duration(c.LifetimeString)
return c, nil
}
func (c *Config) parseFlags(args []string) error {
p := flags.NewParser(c, flags.Default)
p.UnknownOptionHandler = c.parseUnknownFlag
i := flags.NewIniParser(p)
c.Config = func(s string) error {
// Try parsing at as an ini
err := i.ParseFile(s)
// If it fails with a syntax error, try converting legacy to ini
if err != nil && strings.Contains(err.Error(), "malformed key=value") {
converted, convertErr := convertLegacyToIni(s)
if convertErr != nil {
// If conversion fails, return the original error
return err
}
return i.Parse(converted)
}
return err
}
_, err := p.ParseArgs(args)
if err != nil {
return handlFlagError(err)
}
return nil
}
func (c *Config) parseUnknownFlag(option string, arg flags.SplitArgument, args []string) ([]string, error) {
// Parse rules in the format "rule.<name>.<param>"
parts := strings.Split(option, ".")
if len(parts) == 3 && parts[0] == "rule" {
// Get or create rule
rule, ok := c.Rules[parts[1]]
if !ok {
rule = NewRule()
c.Rules[parts[1]] = rule
}
// Get value, or pop the next arg
val, ok := arg.Value()
if !ok {
val = args[0]
args = args[1:]
}
// Check value
if len(val) == 0 {
return args, errors.New("route param value is required")
}
// Unquote if required
if val[0] == '"' {
var err error
val, err = strconv.Unquote(val)
if err != nil {
return args, err
}
}
// Add param value to rule
switch parts[2] {
case "action":
rule.Action = val
case "rule":
rule.Rule = val
case "provider":
rule.Provider = val
default:
return args, fmt.Errorf("inavlid route param: %v", option)
}
} else {
return args, fmt.Errorf("unknown flag: %v", option)
}
return args, nil
}
func handlFlagError(err error) error {
flagsErr, ok := err.(*flags.Error)
if ok && flagsErr.Type == flags.ErrHelp {
// Library has just printed cli help
os.Exit(0)
}
return err
}
var legacyFileFormat = regexp.MustCompile(`^([a-z-]+) ([\w\W]+)$`)
func convertLegacyToIni(name string) (io.Reader, error) {
b, err := ioutil.ReadFile(name)
if err != nil {
return nil, err
}
return bytes.NewReader(legacyFileFormat.ReplaceAll(b, []byte("$1=$2"))), nil
}
func (c *Config) Validate() {
// Check for show stopper errors
if len(c.Secret) == 0 {
log.Fatal("\"secret\" option must be set.")
}
if c.Providers.Google.ClientId == "" || c.Providers.Google.ClientSecret == "" {
log.Fatal("google.providers.client-id, google.providers.client-secret must be set")
}
// Check rules
for _, rule := range c.Rules {
rule.Validate()
}
}
func (c Config) String() string {
jsonConf, _ := json.Marshal(c)
return string(jsonConf)
}
type Rule struct {
Action string
Rule string
Provider string
}
func NewRule() *Rule {
return &Rule{
Action: "auth",
Provider: "google", // TODO: Use default provider
}
}
func (r *Rule) Validate() {
if r.Action != "auth" && r.Action != "allow" {
log.Fatal("invalid rule action, must be \"auth\" or \"allow\"")
}
// TODO: Update with more provider support
if r.Provider != "google" {
log.Fatal("invalid rule provider, must be \"google\"")
}
}
type CommaSeparatedList []string
func (c *CommaSeparatedList) UnmarshalFlag(value string) error {
*c = strings.Split(value, ",")
return nil
}
func (c *CommaSeparatedList) MarshalFlag() (string, error) {
return strings.Join(*c, ","), nil
}

269
internal/config_test.go Normal file
View File

@ -0,0 +1,269 @@
package tfa
import (
"bytes"
"os"
"testing"
"time"
)
/**
* Tests
*/
func TestConfigDefaults(t *testing.T) {
// Check defaults
c, err := NewConfig([]string{})
if err != nil {
t.Error(err)
}
if c.LogLevel != "warn" {
t.Error("LogLevel default should be warn, got", c.LogLevel)
}
if c.LogFormat != "text" {
t.Error("LogFormat default should be text, got", c.LogFormat)
}
if c.AuthHost != "" {
t.Error("AuthHost default should be empty, got", c.AuthHost)
}
if len(c.CookieDomains) != 0 {
t.Error("CookieDomains default should be empty, got", c.CookieDomains)
}
if c.InsecureCookie != false {
t.Error("InsecureCookie default should be false, got", c.InsecureCookie)
}
if c.CookieName != "_forward_auth" {
t.Error("CookieName default should be _forward_auth, got", c.CookieName)
}
if c.CSRFCookieName != "_forward_auth_csrf" {
t.Error("CSRFCookieName default should be _forward_auth_csrf, got", c.CSRFCookieName)
}
if c.DefaultAction != "auth" {
t.Error("DefaultAction default should be auth, got", c.DefaultAction)
}
if len(c.Domains) != 0 {
t.Error("Domain default should be empty, got", c.Domains)
}
if c.Lifetime != time.Second*time.Duration(43200) {
t.Error("Lifetime default should be 43200, got", c.Lifetime)
}
if c.Path != "/_oauth" {
t.Error("Path default should be /_oauth, got", c.Path)
}
if len(c.Whitelist) != 0 {
t.Error("Whitelist default should be empty, got", c.Whitelist)
}
if c.Providers.Google.Prompt != "" {
t.Error("Providers.Google.Prompt default should be empty, got", c.Providers.Google.Prompt)
}
}
func TestConfigParseArgs(t *testing.T) {
c, err := NewConfig([]string{
"--cookie-name=cookiename",
"--csrf-cookie-name", "\"csrfcookiename\"",
"--rule.1.action=allow",
"--rule.1.rule=PathPrefix(`/one`)",
"--rule.two.action=auth",
"--rule.two.rule=\"Host(`two.com`) && Path(`/two`)\"",
})
if err != nil {
t.Error(err)
}
// Check normal flags
if c.CookieName != "cookiename" {
t.Error("CookieName default should be cookiename, got", c.CookieName)
}
if c.CSRFCookieName != "csrfcookiename" {
t.Error("CSRFCookieName default should be csrfcookiename, got", c.CSRFCookieName)
}
// Check rules
if len(c.Rules) != 2 {
t.Error("Should create 2 rules, got:", len(c.Rules))
}
// First rule
if rule, ok := c.Rules["1"]; !ok {
t.Error("Could not find rule key '1'")
} else {
if rule.Action != "allow" {
t.Error("First rule action should be allow, got:", rule.Action)
}
if rule.Rule != "PathPrefix(`/one`)" {
t.Error("First rule rule should be PathPrefix(`/one`), got:", rule.Rule)
}
if rule.Provider != "google" {
t.Error("First rule provider should be google, got:", rule.Provider)
}
}
// Second rule
if rule, ok := c.Rules["two"]; !ok {
t.Error("Could not find rule key '1'")
} else {
if rule.Action != "auth" {
t.Error("Second rule action should be auth, got:", rule.Action)
}
if rule.Rule != "Host(`two.com`) && Path(`/two`)" {
t.Error("Second rule rule should be Host(`two.com`) && Path(`/two`), got:", rule.Rule)
}
if rule.Provider != "google" {
t.Error("Second rule provider should be google, got:", rule.Provider)
}
}
}
func TestConfigParseUnknownFlags(t *testing.T) {
_, err := NewConfig([]string{
"--unknown=_oauthpath2",
})
if err.Error() != "unknown flag: unknown" {
t.Error("Error should be \"unknown flag: unknown\", got:", err)
}
}
func TestConfigFlagBackwardsCompatability(t *testing.T) {
c, err := NewConfig([]string{
"--client-id=clientid",
"--client-secret=verysecret",
"--prompt=prompt",
"--lifetime=200",
"--cookie-secure=false",
})
if err != nil {
t.Error(err)
}
if c.ClientIdLegacy != "clientid" {
t.Error("ClientIdLegacy should be clientid, got", c.ClientIdLegacy)
}
if c.Providers.Google.ClientId != "clientid" {
t.Error("Providers.Google.ClientId should be clientid, got", c.Providers.Google.ClientId)
}
if c.ClientSecretLegacy != "verysecret" {
t.Error("ClientSecretLegacy should be verysecret, got", c.ClientSecretLegacy)
}
if c.Providers.Google.ClientSecret != "verysecret" {
t.Error("Providers.Google.ClientSecret should be verysecret, got", c.Providers.Google.ClientSecret)
}
if c.PromptLegacy != "prompt" {
t.Error("PromptLegacy should be prompt, got", c.PromptLegacy)
}
if c.Providers.Google.Prompt != "prompt" {
t.Error("Providers.Google.Prompt should be prompt, got", c.Providers.Google.Prompt)
}
// "cookie-secure" used to be a standard go bool flag that could take
// true, TRUE, 1, false, FALSE, 0 etc. values.
// Here we're checking that format is still suppoted
if c.CookieSecureLegacy != "false" || c.InsecureCookie != true {
t.Error("Setting cookie-secure=false should set InsecureCookie true, got", c.InsecureCookie)
}
c, err = NewConfig([]string{"--cookie-secure=TRUE"})
if err != nil {
t.Error(err)
}
if c.CookieSecureLegacy != "TRUE" || c.InsecureCookie != false {
t.Error("Setting cookie-secure=TRUE should set InsecureCookie false, got", c.InsecureCookie)
}
}
func TestConfigParseIni(t *testing.T) {
c, err := NewConfig([]string{
"--config=../test/config0",
"--config=../test/config1",
"--csrf-cookie-name=csrfcookiename",
})
if err != nil {
t.Error(err)
}
if c.CookieName != "inicookiename" {
t.Error("CookieName should be read as inicookiename from ini file, got", c.CookieName)
}
if c.CSRFCookieName != "csrfcookiename" {
t.Error("CSRFCookieName argument should override ini file, got", c.CSRFCookieName)
}
if c.Path != "/two" {
t.Error("Path in second ini file should override first ini file, got", c.Path)
}
}
func TestConfigFileBackwardsCompatability(t *testing.T) {
c, err := NewConfig([]string{
"--config=../test/config-legacy",
})
if err != nil {
t.Error(err)
}
if c.Path != "/two" {
t.Error("Path in legacy config file should be read, got", c.Path)
}
}
func TestConfigParseEnvironment(t *testing.T) {
os.Setenv("COOKIE_NAME", "env_cookie_name")
c, err := NewConfig([]string{})
if err != nil {
t.Error(err)
}
if c.CookieName != "env_cookie_name" {
t.Error("CookieName should be read as env_cookie_name from environment, got", c.CookieName)
}
}
func TestConfigTransformation(t *testing.T) {
c, err := NewConfig([]string{
"--url-path=_oauthpath",
"--secret=verysecret",
"--lifetime=200",
})
if err != nil {
t.Error(err)
}
if c.Path != "/_oauthpath" {
t.Error("Path should add slash to front to get /_oauthpath, got:", c.Path)
}
if c.SecretString != "verysecret" {
t.Error("SecretString should be verysecret, got:", c.SecretString)
}
if bytes.Compare(c.Secret, []byte("verysecret")) != 0 {
t.Error("Secret should be []byte(verysecret), got:", string(c.Secret))
}
if c.LifetimeString != 200 {
t.Error("LifetimeString should be 200, got:", c.LifetimeString)
}
if c.Lifetime != time.Second*time.Duration(200) {
t.Error("Lifetime default should be 200, got", c.Lifetime)
}
}
func TestConfigCommaSeparatedList(t *testing.T) {
list := CommaSeparatedList{}
err := list.UnmarshalFlag("one,two")
if err != nil {
t.Error(err)
}
if len(list) != 2 || list[0] != "one" || list[1] != "two" {
t.Error("Expected UnmarshalFlag to provide CommaSeparatedList{one,two}, got", list)
}
marshal, err := list.MarshalFlag()
if err != nil {
t.Error(err)
}
if marshal != "one,two" {
t.Error("Expected MarshalFlag to provide \"one,two\", got", list)
}
}

View File

@ -1,4 +1,4 @@
package main
package tfa
import (
"os"
@ -6,13 +6,15 @@ import (
"github.com/sirupsen/logrus"
)
func CreateLogger(logLevel, logFormat string) logrus.FieldLogger {
var log logrus.FieldLogger
func NewDefaultLogger() logrus.FieldLogger {
// Setup logger
log := logrus.StandardLogger()
log = logrus.StandardLogger()
logrus.SetOutput(os.Stdout)
// Set logger format
switch logFormat {
switch config.LogFormat {
case "pretty":
break
case "json":
@ -26,7 +28,7 @@ func CreateLogger(logLevel, logFormat string) logrus.FieldLogger {
}
// Set logger level
switch logLevel {
switch config.LogLevel {
case "trace":
logrus.SetLevel(logrus.TraceLevel)
case "debug":

View File

@ -0,0 +1,96 @@
package provider
import (
"encoding/json"
"fmt"
"net/http"
"net/url"
)
type Google struct {
ClientId string `long:"client-id" env:"CLIENT_ID" description:"Client ID"`
ClientSecret string `long:"client-secret" env:"CLIENT_SECRET" description:"Client Secret" json:"-"`
Scope string
Prompt string `long:"prompt" env:"PROMPT" description:"Space separated list of OpenID prompt options"`
LoginURL *url.URL
TokenURL *url.URL
UserURL *url.URL
}
func (g *Google) Build() {
g.LoginURL = &url.URL{
Scheme: "https",
Host: "accounts.google.com",
Path: "/o/oauth2/auth",
}
g.TokenURL = &url.URL{
Scheme: "https",
Host: "www.googleapis.com",
Path: "/oauth2/v3/token",
}
g.UserURL = &url.URL{
Scheme: "https",
Host: "www.googleapis.com",
Path: "/oauth2/v2/userinfo",
}
}
func (g *Google) GetLoginURL(redirectUri, state string) string {
q := url.Values{}
q.Set("client_id", g.ClientId)
q.Set("response_type", "code")
q.Set("scope", g.Scope)
if g.Prompt != "" {
q.Set("prompt", g.Prompt)
}
q.Set("redirect_uri", redirectUri)
q.Set("state", state)
var u url.URL
u = *g.LoginURL
u.RawQuery = q.Encode()
return u.String()
}
func (g *Google) ExchangeCode(redirectUri, code string) (string, error) {
form := url.Values{}
form.Set("client_id", g.ClientId)
form.Set("client_secret", g.ClientSecret)
form.Set("grant_type", "authorization_code")
form.Set("redirect_uri", redirectUri)
form.Set("code", code)
res, err := http.PostForm(g.TokenURL.String(), form)
if err != nil {
return "", err
}
var token Token
defer res.Body.Close()
err = json.NewDecoder(res.Body).Decode(&token)
return token.Token, err
}
func (g *Google) GetUser(token string) (User, error) {
var user User
client := &http.Client{}
req, err := http.NewRequest("GET", g.UserURL.String(), nil)
if err != nil {
return user, err
}
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", token))
res, err := client.Do(req)
if err != nil {
return user, err
}
defer res.Body.Close()
err = json.NewDecoder(res.Body).Decode(&user)
return user, err
}

View File

@ -0,0 +1,16 @@
package provider
type Providers struct {
Google Google `group:"Google Provider" namespace:"google" env-namespace:"GOOGLE"`
}
type Token struct {
Token string `json:"access_token"`
}
type User struct {
Id string `json:"id"`
Email string `json:"email"`
Verified bool `json:"verified_email"`
Hd string `json:"hd"`
}

179
internal/server.go Normal file
View File

@ -0,0 +1,179 @@
package tfa
import (
"net/http"
"net/url"
"github.com/containous/traefik/pkg/rules"
"github.com/sirupsen/logrus"
)
type Server struct {
router *rules.Router
}
func NewServer() *Server {
s := &Server{}
s.buildRoutes()
return s
}
func (s *Server) buildRoutes() {
var err error
s.router, err = rules.NewRouter()
if err != nil {
log.Fatal(err)
}
// Let's build a router
for _, rule := range config.Rules {
if rule.Action == "allow" {
s.router.AddRoute(rule.Rule, 1, s.AllowHandler())
} else {
s.router.AddRoute(rule.Rule, 1, s.AuthHandler())
}
}
// Add callback handler
s.router.Handle(config.Path, s.AuthCallbackHandler())
// Add a default handler
if config.DefaultAction == "allow" {
s.router.NewRoute().Handler(s.AllowHandler())
} else {
s.router.NewRoute().Handler(s.AuthHandler())
}
}
func (s *Server) RootHandler(w http.ResponseWriter, r *http.Request) {
// Modify request
r.URL, _ = url.Parse(r.Header.Get("X-Forwarded-Uri"))
// Pass to mux
s.router.ServeHTTP(w, r)
}
// Handler that allows requests
func (s *Server) AllowHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
s.logger(r, "Allowing request")
w.WriteHeader(200)
}
}
// Authenticate requests
func (s *Server) AuthHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// Logging setup
logger := s.logger(r, "Authenticating request")
// Get auth cookie
c, err := r.Cookie(config.CookieName)
if err != nil {
// Error indicates no cookie, generate nonce
err, nonce := Nonce()
if err != nil {
logger.Errorf("Error generating nonce, %v", err)
http.Error(w, "Service unavailable", 503)
return
}
// Set the CSRF cookie
http.SetCookie(w, MakeCSRFCookie(r, nonce))
logger.Debug("Set CSRF cookie and redirecting to google login")
// Forward them on
http.Redirect(w, r, GetLoginURL(r, nonce), http.StatusTemporaryRedirect)
return
}
// Validate cookie
valid, email, err := ValidateCookie(r, c)
if !valid {
logger.Errorf("Invalid cookie: %v", err)
http.Error(w, "Not authorized", 401)
return
}
// Validate user
valid = ValidateEmail(email)
if !valid {
logger.WithFields(logrus.Fields{
"email": email,
}).Errorf("Invalid email")
http.Error(w, "Not authorized", 401)
return
}
// Valid request
logger.Debugf("Allowing valid request ")
w.Header().Set("X-Forwarded-User", email)
w.WriteHeader(200)
}
}
// Handle auth callback
func (s *Server) AuthCallbackHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// Logging setup
logger := s.logger(r, "Handling callback")
// Check for CSRF cookie
c, err := r.Cookie(config.CSRFCookieName)
if err != nil {
logger.Warn("Missing csrf cookie")
http.Error(w, "Not authorized", 401)
return
}
// Validate state
valid, redirect, err := ValidateCSRFCookie(r, c)
if !valid {
logger.Warnf("Error validating csrf cookie: %v", err)
http.Error(w, "Not authorized", 401)
return
}
// Clear CSRF cookie
http.SetCookie(w, ClearCSRFCookie(r))
// Exchange code for token
token, err := ExchangeCode(r)
if err != nil {
logger.Errorf("Code exchange failed with: %v", err)
http.Error(w, "Service unavailable", 503)
return
}
// Get user
user, err := GetUser(token)
if err != nil {
logger.Errorf("Error getting user: %s", err)
return
}
// Generate cookie
http.SetCookie(w, MakeCookie(r, user.Email))
logger.WithFields(logrus.Fields{
"user": user.Email,
}).Infof("Generated auth cookie")
// Redirect
http.Redirect(w, r, redirect, http.StatusTemporaryRedirect)
}
}
func (s *Server) logger(r *http.Request, msg string) *logrus.Entry {
// Create logger
logger := log.WithFields(logrus.Fields{
"RemoteAddr": r.RemoteAddr,
})
// Log request
logger.WithFields(logrus.Fields{
"Headers": r.Header,
}).Debugf(msg)
return logger
}

View File

@ -1,9 +1,7 @@
package main
package tfa
import (
"fmt"
"time"
// "reflect"
"io/ioutil"
"net/http"
"net/http/httptest"
@ -12,6 +10,161 @@ import (
"testing"
)
/**
* Setup
*/
func init() {
config.LogLevel = "panic"
log = NewDefaultLogger()
}
/**
* Tests
*/
func TestServerAuthHandler(t *testing.T) {
config, _ = NewConfig([]string{})
// Should redirect vanilla request to login url
req := newHttpRequest("/foo")
res, _ := httpRequest(req, nil)
if res.StatusCode != 307 {
t.Error("Vanilla request should be redirected with 307, got:", res.StatusCode)
}
fwd, _ := res.Location()
if fwd.Scheme != "https" || fwd.Host != "accounts.google.com" || fwd.Path != "/o/oauth2/auth" {
t.Error("Vanilla request should be redirected to login url, got:", fwd)
}
// Should catch invalid cookie
req = newHttpRequest("/foo")
c := MakeCookie(req, "test@example.com")
parts := strings.Split(c.Value, "|")
c.Value = fmt.Sprintf("bad|%s|%s", parts[1], parts[2])
res, _ = httpRequest(req, c)
if res.StatusCode != 401 {
t.Error("Request with invalid cookie shound't be authorised", res.StatusCode)
}
// Should validate email
req = newHttpRequest("/foo")
c = MakeCookie(req, "test@example.com")
config.Domains = []string{"test.com"}
res, _ = httpRequest(req, c)
if res.StatusCode != 401 {
t.Error("Request with invalid email shound't be authorised", res.StatusCode)
}
// Should allow valid request email
req = newHttpRequest("/foo")
c = MakeCookie(req, "test@example.com")
config.Domains = []string{}
res, _ = httpRequest(req, c)
if res.StatusCode != 200 {
t.Error("Valid request should be allowed, got:", res.StatusCode)
}
// Should pass through user
users := res.Header["X-Forwarded-User"]
if len(users) != 1 {
t.Error("Valid request missing X-Forwarded-User header")
} else if users[0] != "test@example.com" {
t.Error("X-Forwarded-User should match user, got: ", users)
}
}
func TestServerAuthCallback(t *testing.T) {
config, _ = NewConfig([]string{})
// Setup token server
tokenServerHandler := &TokenServerHandler{}
tokenServer := httptest.NewServer(tokenServerHandler)
defer tokenServer.Close()
tokenUrl, _ := url.Parse(tokenServer.URL)
config.Providers.Google.TokenURL = tokenUrl
// Setup user server
userServerHandler := &UserServerHandler{}
userServer := httptest.NewServer(userServerHandler)
defer userServer.Close()
userUrl, _ := url.Parse(userServer.URL)
config.Providers.Google.UserURL = userUrl
// Should pass auth response request to callback
req := newHttpRequest("/_oauth")
res, _ := httpRequest(req, nil)
if res.StatusCode != 401 {
t.Error("Auth callback without cookie shound't be authorised, got:", res.StatusCode)
}
// Should catch invalid csrf cookie
req = newHttpRequest("/_oauth?state=12345678901234567890123456789012:http://redirect")
c := MakeCSRFCookie(req, "nononononononononononononononono")
res, _ = httpRequest(req, c)
if res.StatusCode != 401 {
t.Error("Auth callback with invalid cookie shound't be authorised, got:", res.StatusCode)
}
// Should redirect valid request
req = newHttpRequest("/_oauth?state=12345678901234567890123456789012:http://redirect")
c = MakeCSRFCookie(req, "12345678901234567890123456789012")
res, _ = httpRequest(req, c)
if res.StatusCode != 307 {
t.Error("Valid callback should be allowed, got:", res.StatusCode)
}
fwd, _ := res.Location()
if fwd.Scheme != "http" || fwd.Host != "redirect" || fwd.Path != "" {
t.Error("Valid request should be redirected to return url, got:", fwd)
}
}
func TestServerDefaultAction(t *testing.T) {
config, _ = NewConfig([]string{})
req := newHttpRequest("/random")
res, _ := httpRequest(req, nil)
if res.StatusCode != 307 {
t.Error("Request should require auth with auth default handler, got:", res.StatusCode)
}
config.DefaultAction = "allow"
req = newHttpRequest("/random")
res, _ = httpRequest(req, nil)
if res.StatusCode != 200 {
t.Error("Request should be allowed with allow default handler, got:", res.StatusCode)
}
}
func TestServerRoutePathPrefix(t *testing.T) {
config, _ = NewConfig([]string{})
config.Rules = map[string]*Rule{
"web1": {
Action: "allow",
Rule: "PathPrefix(`/api`)",
},
}
// Should block any request
req := newHttpRequest("/random")
res, _ := httpRequest(req, nil)
if res.StatusCode != 307 {
t.Error("Request not matching any rule should require auth, got:", res.StatusCode)
}
// Should allow /api request
req = newHttpRequest("/api")
res, _ = httpRequest(req, nil)
if res.StatusCode != 200 {
t.Error("Request matching allowed rule should be allowed, got:", res.StatusCode)
}
}
/**
* Utilities
*/
@ -33,10 +186,6 @@ func (t *UserServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}`)
}
func init() {
log = CreateLogger("panic", "")
}
func httpRequest(r *http.Request, c *http.Cookie) (*http.Response, string) {
w := httptest.NewRecorder()
@ -50,161 +199,39 @@ func httpRequest(r *http.Request, c *http.Cookie) (*http.Response, string) {
r.Header.Add("Cookie", c)
}
handler(w, r)
NewServer().RootHandler(w, r)
res := w.Result()
body, _ := ioutil.ReadAll(res.Body)
// if res.StatusCode > 300 && res.StatusCode < 400 {
// fmt.Printf("%#v", res.Header)
// }
return res, string(body)
}
func newHttpRequest(uri string) *http.Request {
r := httptest.NewRequest("", "http://example.com", nil)
r := httptest.NewRequest("", "http://example.com/", nil)
r.Header.Add("X-Forwarded-Uri", uri)
return r
}
func qsDiff(one, two url.Values) {
func qsDiff(t *testing.T, one, two url.Values) []string {
errs := make([]string, 0)
for k := range one {
if two.Get(k) == "" {
fmt.Printf("Key missing: %s\n", k)
errs = append(errs, fmt.Sprintf("Key missing: %s", k))
}
if one.Get(k) != two.Get(k) {
fmt.Printf("Value different for %s: expected: '%s' got: '%s'\n", k, one.Get(k), two.Get(k))
errs = append(errs, fmt.Sprintf("Value different for %s: expected: '%s' got: '%s'", k, one.Get(k), two.Get(k)))
}
}
for k := range two {
if one.Get(k) == "" {
fmt.Printf("Extra key: %s\n", k)
errs = append(errs, fmt.Sprintf("Extra key: %s", k))
}
}
}
/**
* Tests
*/
func TestHandler(t *testing.T) {
fw = &ForwardAuth{
Path: "_oauth",
ClientId: "idtest",
ClientSecret: "sectest",
Scope: "scopetest",
LoginURL: &url.URL{
Scheme: "http",
Host: "test.com",
Path: "/auth",
},
CookieName: "cookie_test",
Lifetime: time.Second * time.Duration(10),
}
// Should redirect vanilla request to login url
req := newHttpRequest("foo")
res, _ := httpRequest(req, nil)
if res.StatusCode != 307 {
t.Error("Vanilla request should be redirected with 307, got:", res.StatusCode)
}
fwd, _ := res.Location()
if fwd.Scheme != "http" || fwd.Host != "test.com" || fwd.Path != "/auth" {
t.Error("Vanilla request should be redirected to login url, got:", fwd)
}
// Should catch invalid cookie
req = newHttpRequest("foo")
c := fw.MakeCookie(req, "test@example.com")
parts := strings.Split(c.Value, "|")
c.Value = fmt.Sprintf("bad|%s|%s", parts[1], parts[2])
res, _ = httpRequest(req, c)
if res.StatusCode != 401 {
t.Error("Request with invalid cookie shound't be authorised", res.StatusCode)
}
// Should validate email
req = newHttpRequest("foo")
c = fw.MakeCookie(req, "test@example.com")
fw.Domain = []string{"test.com"}
res, _ = httpRequest(req, c)
if res.StatusCode != 401 {
t.Error("Request with invalid cookie shound't be authorised", res.StatusCode)
}
// Should allow valid request email
req = newHttpRequest("foo")
c = fw.MakeCookie(req, "test@example.com")
fw.Domain = []string{}
res, _ = httpRequest(req, c)
if res.StatusCode != 200 {
t.Error("Valid request should be allowed, got:", res.StatusCode)
}
// Should pass through user
users := res.Header["X-Forwarded-User"]
if len(users) != 1 {
t.Error("Valid request missing X-Forwarded-User header")
} else if users[0] != "test@example.com" {
t.Error("X-Forwarded-User should match user, got: ", users)
}
}
func TestCallback(t *testing.T) {
fw = &ForwardAuth{
Path: "_oauth",
ClientId: "idtest",
ClientSecret: "sectest",
Scope: "scopetest",
LoginURL: &url.URL{
Scheme: "http",
Host: "test.com",
Path: "/auth",
},
CSRFCookieName: "csrf_test",
}
// Setup token server
tokenServerHandler := &TokenServerHandler{}
tokenServer := httptest.NewServer(tokenServerHandler)
defer tokenServer.Close()
tokenUrl, _ := url.Parse(tokenServer.URL)
fw.TokenURL = tokenUrl
// Setup user server
userServerHandler := &UserServerHandler{}
userServer := httptest.NewServer(userServerHandler)
defer userServer.Close()
userUrl, _ := url.Parse(userServer.URL)
fw.UserURL = userUrl
// Should pass auth response request to callback
req := newHttpRequest("_oauth")
res, _ := httpRequest(req, nil)
if res.StatusCode != 401 {
t.Error("Auth callback without cookie shound't be authorised, got:", res.StatusCode)
}
// Should catch invalid csrf cookie
req = newHttpRequest("_oauth?state=12345678901234567890123456789012:http://redirect")
c := fw.MakeCSRFCookie(req, "nononononononononononononononono")
res, _ = httpRequest(req, c)
if res.StatusCode != 401 {
t.Error("Auth callback with invalid cookie shound't be authorised, got:", res.StatusCode)
}
// Should redirect valid request
req = newHttpRequest("_oauth?state=12345678901234567890123456789012:http://redirect")
c = fw.MakeCSRFCookie(req, "12345678901234567890123456789012")
res, _ = httpRequest(req, c)
if res.StatusCode != 307 {
t.Error("Valid callback should be allowed, got:", res.StatusCode)
}
fwd, _ := res.Location()
if fwd.Scheme != "http" || fwd.Host != "redirect" || fwd.Path != "" {
t.Error("Valid request should be redirected to return url, got:", fwd)
}
return errs
}

239
main.go
View File

@ -1,239 +0,0 @@
package main
import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"time"
"github.com/namsral/flag"
"github.com/sirupsen/logrus"
)
// Vars
var fw *ForwardAuth
var log logrus.FieldLogger
// Primary handler
func handler(w http.ResponseWriter, r *http.Request) {
// Logging setup
logger := log.WithFields(logrus.Fields{
"RemoteAddr": r.RemoteAddr,
})
logger.WithFields(logrus.Fields{
"Headers": r.Header,
}).Debugf("Handling request")
// Parse uri
uri, err := url.Parse(r.Header.Get("X-Forwarded-Uri"))
if err != nil {
logger.Errorf("Error parsing X-Forwarded-Uri, %v", err)
http.Error(w, "Service unavailable", 503)
return
}
// Handle callback
if uri.Path == fw.Path {
logger.Debugf("Passing request to auth callback")
handleCallback(w, r, uri.Query(), logger)
return
}
// Get auth cookie
c, err := r.Cookie(fw.CookieName)
if err != nil {
// Error indicates no cookie, generate nonce
err, nonce := fw.Nonce()
if err != nil {
logger.Errorf("Error generating nonce, %v", err)
http.Error(w, "Service unavailable", 503)
return
}
// Set the CSRF cookie
http.SetCookie(w, fw.MakeCSRFCookie(r, nonce))
logger.Debug("Set CSRF cookie and redirecting to google login")
// Forward them on
http.Redirect(w, r, fw.GetLoginURL(r, nonce), http.StatusTemporaryRedirect)
return
}
// Validate cookie
valid, email, err := fw.ValidateCookie(r, c)
if !valid {
logger.Errorf("Invalid cookie: %v", err)
http.Error(w, "Not authorized", 401)
return
}
// Validate user
valid = fw.ValidateEmail(email)
if !valid {
logger.WithFields(logrus.Fields{
"email": email,
}).Errorf("Invalid email")
http.Error(w, "Not authorized", 401)
return
}
// Valid request
logger.Debugf("Allowing valid request ")
w.Header().Set("X-Forwarded-User", email)
w.WriteHeader(200)
}
// Authenticate user after they have come back from google
func handleCallback(w http.ResponseWriter, r *http.Request, qs url.Values,
logger logrus.FieldLogger) {
// Check for CSRF cookie
csrfCookie, err := r.Cookie(fw.CSRFCookieName)
if err != nil {
logger.Warn("Missing csrf cookie")
http.Error(w, "Not authorized", 401)
return
}
// Validate state
state := qs.Get("state")
valid, redirect, err := fw.ValidateCSRFCookie(csrfCookie, state)
if !valid {
logger.WithFields(logrus.Fields{
"csrf": csrfCookie.Value,
"state": state,
}).Warnf("Error validating csrf cookie: %v", err)
http.Error(w, "Not authorized", 401)
return
}
// Clear CSRF cookie
http.SetCookie(w, fw.ClearCSRFCookie(r))
// Exchange code for token
token, err := fw.ExchangeCode(r, qs.Get("code"))
if err != nil {
logger.Errorf("Code exchange failed with: %v", err)
http.Error(w, "Service unavailable", 503)
return
}
// Get user
user, err := fw.GetUser(token)
if err != nil {
logger.Errorf("Error getting user: %s", err)
return
}
// Generate cookie
http.SetCookie(w, fw.MakeCookie(r, user.Email))
logger.WithFields(logrus.Fields{
"user": user.Email,
}).Infof("Generated auth cookie")
// Redirect
http.Redirect(w, r, redirect, http.StatusTemporaryRedirect)
}
// Main
func main() {
// Parse options
flag.String(flag.DefaultConfigFlagname, "", "Path to config file")
path := flag.String("url-path", "_oauth", "Callback URL")
lifetime := flag.Int("lifetime", 43200, "Session length in seconds")
secret := flag.String("secret", "", "*Secret used for signing (required)")
authHost := flag.String("auth-host", "", "Central auth login")
clientId := flag.String("client-id", "", "*Google Client ID (required)")
clientSecret := flag.String("client-secret", "", "*Google Client Secret (required)")
cookieName := flag.String("cookie-name", "_forward_auth", "Cookie Name")
cSRFCookieName := flag.String("csrf-cookie-name", "_forward_auth_csrf", "CSRF Cookie Name")
cookieDomainList := flag.String("cookie-domains", "", "Comma separated list of cookie domains") //todo
cookieSecret := flag.String("cookie-secret", "", "Deprecated")
cookieSecure := flag.Bool("cookie-secure", true, "Use secure cookies")
domainList := flag.String("domain", "", "Comma separated list of email domains to allow")
emailWhitelist := flag.String("whitelist", "", "Comma separated list of emails to allow")
prompt := flag.String("prompt", "", "Space separated list of OpenID prompt options")
logLevel := flag.String("log-level", "warn", "Log level: trace, debug, info, warn, error, fatal, panic")
logFormat := flag.String("log-format", "text", "Log format: text, json, pretty")
flag.Parse()
// Setup logger
log = CreateLogger(*logLevel, *logFormat)
// Backwards compatibility
if *secret == "" && *cookieSecret != "" {
*secret = *cookieSecret
}
// Check for show stopper errors
if *clientId == "" || *clientSecret == "" || *secret == "" {
log.Fatal("client-id, client-secret and secret must all be set")
}
// Parse lists
var cookieDomains []CookieDomain
if *cookieDomainList != "" {
for _, d := range strings.Split(*cookieDomainList, ",") {
cookieDomain := NewCookieDomain(d)
cookieDomains = append(cookieDomains, *cookieDomain)
}
}
var domain []string
if *domainList != "" {
domain = strings.Split(*domainList, ",")
}
var whitelist []string
if *emailWhitelist != "" {
whitelist = strings.Split(*emailWhitelist, ",")
}
// Setup
fw = &ForwardAuth{
Path: fmt.Sprintf("/%s", *path),
Lifetime: time.Second * time.Duration(*lifetime),
Secret: []byte(*secret),
AuthHost: *authHost,
ClientId: *clientId,
ClientSecret: *clientSecret,
Scope: "https://www.googleapis.com/auth/userinfo.profile https://www.googleapis.com/auth/userinfo.email",
LoginURL: &url.URL{
Scheme: "https",
Host: "accounts.google.com",
Path: "/o/oauth2/auth",
},
TokenURL: &url.URL{
Scheme: "https",
Host: "www.googleapis.com",
Path: "/oauth2/v3/token",
},
UserURL: &url.URL{
Scheme: "https",
Host: "www.googleapis.com",
Path: "/oauth2/v2/userinfo",
},
CookieName: *cookieName,
CSRFCookieName: *cSRFCookieName,
CookieDomains: cookieDomains,
CookieSecure: *cookieSecure,
Domain: domain,
Whitelist: whitelist,
Prompt: *prompt,
}
// Attach handler
http.HandleFunc("/", handler)
// Start
jsonConf, _ := json.Marshal(fw)
log.Debugf("Starting with options: %s", string(jsonConf))
log.Info("Listening on :4181")
log.Info(http.ListenAndServe(":4181", nil))
}

1
test/config-legacy Normal file
View File

@ -0,0 +1 @@
url-path two

3
test/config0 Normal file
View File

@ -0,0 +1,3 @@
cookie-name=inicookiename
csrf-cookie-name=inicsrfcookiename
url-path=one

1
test/config1 Normal file
View File

@ -0,0 +1 @@
url-path=two