25 Commits
0.1.0 ... 0.1.1

Author SHA1 Message Date
5c800a0170 Remove old logger from tests 2019-01-22 14:19:17 +00:00
b1fdcc7f56 Fix travis build 2019-01-22 13:55:49 +00:00
db31b09a72 Add report card to README 2019-01-22 13:12:25 +00:00
e1d518db11 Minor logging + comment fix 2019-01-22 13:10:03 +00:00
67339ae79a Include logrus in docker build 2019-01-22 12:59:29 +00:00
0b2889935e Log all request headers at debug level 2019-01-22 12:58:24 +00:00
b3b31e2193 Refactor logging
Fixes #18
2019-01-22 12:46:58 +00:00
1a3a099ac1 use gofmt to simplify code 2019-01-22 10:51:41 +00:00
afd8878188 use gofmt to update styling 2019-01-22 10:50:55 +00:00
6ccd1c6dfc Add documentation for X-Forwarded-User 2018-12-10 12:48:45 +00:00
df81be1147 Pass on authenticated user via X-Forwarded-User header
Fixes #13
2018-12-10 12:44:13 +00:00
5dcf889efe Merge pull request #16 from nicoulaj/patch-1
Fix some typos in logs
2018-12-04 13:32:36 +00:00
92d72dcdd2 Fix some typos in logs 2018-12-02 19:09:49 +01:00
4c1874b786 add auth host example + update examples 2018-11-06 14:45:56 +00:00
dcf4f6574d remove direct mode + add example development compose 2018-11-06 14:17:40 +00:00
91775ff0a8 Merge branch 'lammensj-whitelist' 2018-11-06 14:04:07 +00:00
1832672f5e Modify whitelist implementation + expand docs
Closes #4
2018-11-06 14:02:18 +00:00
eaad0a9054 Allow a whitelist of email addresses 2018-11-06 14:02:18 +00:00
36fffd2382 Fix demo config
COOKIE_SECRET was renamed SECRET
2018-11-06 14:02:18 +00:00
ccbda4ec8c Merge pull request #10 from mathcantin/master
Fix demo config
2018-11-06 13:45:21 +00:00
b014c5638a Merge pull request #12 from forMetris/master
Add -prompt flag
2018-11-06 13:44:01 +00:00
c897bc8387 Add -prompt flag
Space separated list of OpenID prompt options (https://developers.google.com/identity/protocols/OpenIDConnect#prompt)
2018-11-05 16:43:30 +01:00
96f9469abd Fix demo config
COOKIE_SECRET was renamed SECRET
2018-10-29 14:21:29 -04:00
b54871391f Add example dir to dockerignore 2018-10-29 17:52:04 +00:00
d230572879 Add auth host feature
Allow central host for use as base for redirect_uri

Closes #3
2018-10-29 17:42:13 +00:00
14 changed files with 1174 additions and 720 deletions

2
.dockerignore Normal file
View File

@ -0,0 +1,2 @@
example
.travis.yml

View File

@ -4,4 +4,5 @@ go:
- "1.10" - "1.10"
install: install:
- go get github.com/namsral/flag - go get github.com/namsral/flag
- go get github.com/op/go-logging - go get github.com/sirupsen/logrus
script: go test -v ./...

View File

@ -7,7 +7,7 @@ WORKDIR /app
# Add libraries # Add libraries
RUN apk add --no-cache git && \ RUN apk add --no-cache git && \
go get "github.com/namsral/flag" && \ go get "github.com/namsral/flag" && \
go get "github.com/op/go-logging" && \ go get "github.com/sirupsen/logrus" && \
apk del git apk del git
# Copy & build # Copy & build

5
Makefile Normal file
View File

@ -0,0 +1,5 @@
format:
gofmt -w -s *.go
.PHONY: format

View File

@ -1,5 +1,5 @@
# Traefik Forward Auth [![Build Status](https://travis-ci.org/thomseddon/traefik-forward-auth.svg?branch=master)](https://travis-ci.org/thomseddon/traefik-forward-auth) # Traefik Forward Auth [![Build Status](https://travis-ci.org/thomseddon/traefik-forward-auth.svg?branch=master)](https://travis-ci.org/thomseddon/traefik-forward-auth) [![Go Report Card](https://goreportcard.com/badge/github.com/thomseddon/traefik-forward-auth)](https://goreportcard.com/badge/github.com/thomseddon/traefik-forward-auth)
A minimal forward authentication service that provides Google oauth based login and authentication for the traefik reverse proxy. A minimal forward authentication service that provides Google oauth based login and authentication for the traefik reverse proxy.
@ -24,16 +24,20 @@ The following configuration is supported:
|-----------------------|------|-----------| |-----------------------|------|-----------|
|-client-id|string|*Google Client ID (required)| |-client-id|string|*Google Client ID (required)|
|-client-secret|string|*Google Client Secret (required)| |-client-secret|string|*Google Client Secret (required)|
|-secret|string|*Secret used for signing (required)|
|-config|string|Path to config file| |-config|string|Path to config file|
|-cookie-domains|string|Comma separated list of cookie domains| |-auth-host|string|Central auth login (see below)|
|-cookie-domains|string|Comma separated list of cookie domains (see below)|
|-cookie-name|string|Cookie Name (default "_forward_auth")| |-cookie-name|string|Cookie Name (default "_forward_auth")|
|-cookie-secret|string|*Cookie secret (required)|
|-cookie-secure|bool|Use secure cookies (default true)| |-cookie-secure|bool|Use secure cookies (default true)|
|-csrf-cookie-name|string|CSRF Cookie Name (default "_forward_auth_csrf")| |-csrf-cookie-name|string|CSRF Cookie Name (default "_forward_auth_csrf")|
|-direct|bool|Run in direct mode (use own hostname as oppose to <br>X-Forwarded-Host, used for testing/development)
|-domain|string|Comma separated list of email domains to allow| |-domain|string|Comma separated list of email domains to allow|
|-whitelist|string|Comma separated list of email addresses to allow|
|-lifetime|int|Session length in seconds (default 43200)| |-lifetime|int|Session length in seconds (default 43200)|
|-url-path|string|Callback URL (default "_oauth")| |-url-path|string|Callback URL (default "_oauth")|
|-prompt|string|Space separated list of [OpenID prompt options](https://developers.google.com/identity/protocols/OpenIDConnect#prompt)|
|-log-level|string|Log level: trace, debug, info, warn, error, fatal, panic (default "warn")|
|-log-format|string|Log format: text, json, pretty (default "text")|
Configuration can also be supplied as environment variables (use upper case and swap `-`'s for `_`'s e.g. `-client-id` becomes `CLIENT_ID`) Configuration can also be supplied as environment variables (use upper case and swap `-`'s for `_`'s e.g. `-client-id` becomes `CLIENT_ID`)
@ -47,6 +51,19 @@ Create a new project then search for and select "Credentials" in the search bar.
Click, "Create Credentials" > "OAuth client ID". Select "Web Application", fill in the name of your app, skip "Authorized JavaScript origins" and fill "Authorized redirect URIs" with all the domains you will allow authentication from, appended with the `url-path` (e.g. https://app.test.com/_oauth) Click, "Create Credentials" > "OAuth client ID". Select "Web Application", fill in the name of your app, skip "Authorized JavaScript origins" and fill "Authorized redirect URIs" with all the domains you will allow authentication from, appended with the `url-path` (e.g. https://app.test.com/_oauth)
## Usage
The authenticated user is set in the `X-Forwarded-User` header, to pass this on add this to the `authResponseHeaders` as shown [here](https://github.com/thomseddon/traefik-forward-auth/blob/master/example/docker-compose-dev.yml).
## User Restriction
You can restrict who can login with the following parameters:
* `-domain` - Use this to limit logins to a specific domain, e.g. test.com only
* `-whitelist` - Use this to only allow specific users to login e.g. thom@test.com only
Note, if you pass `whitelist` then only this is checked and `domain` is effectively ignored.
## Cookie Domains ## Cookie Domains
You can supply a comma separated list of cookie domains, if the host of the original request is a subdomain of any given cookie domain, the authentication cookie will set with the given domain. You can supply a comma separated list of cookie domains, if the host of the original request is a subdomain of any given cookie domain, the authentication cookie will set with the given domain.
@ -55,6 +72,39 @@ For example, if cookie domain is `test.com` and a request comes in on `app1.test
Beware however, if using cookie domains whilst running multiple instances of traefik/traefik-forward-auth for the same domain, the cookies will clash. You can fix this by using the same `cookie-secret` in both instances, or using a different `cookie-name` on each. Beware however, if using cookie domains whilst running multiple instances of traefik/traefik-forward-auth for the same domain, the cookies will clash. You can fix this by using the same `cookie-secret` in both instances, or using a different `cookie-name` on each.
## Operation Modes
#### Overlay
Overlay is the default operation mode, in this mode the authorisation endpoint is overlayed onto any domain. By default the `/_oauth` path is used, this can be customised using the `-url-path` option.
If a request comes in for `www.myapp.com/home` then the user will be redirected to the google login, following this they will be sent back to `www.myapp.com/_oauth`, where their token will be validated (this request will not be forwarded to your application). Following successful authoristion, the user will return to their originally requested url of `www.myapp.com/home`.
As the hostname in the `redirect_uri` is dynamically generated based on the orignal request, every hostname must be permitted in the Google OAuth console (e.g. `www.myappp.com` would need to be added in the above example)
#### Auth Host
This is an optional mode of operation that is useful when dealing with a large number of subdomains, it is activated by using the `-auth-host` config option (see [this example docker-compose.yml](https://github.com/thomseddon/traefik-forward-auth/blob/master/example/docker-compose-auth-host.yml)).
For example, if you have a few applications: `app1.test.com`, `app2.test.com`, `appN.test.com`, adding every domain to Google's console can become laborious.
To utilise an auth host, permit domain level cookies by setting the cookie domain to `test.com` then set the `auth-host` to: `auth.test.com`.
The user flow will then be:
1. Request to `app10.test.com/home/page`
2. User redirected to Google login
3. After Google login, user is redirected to `auth.test.com/_oauth`
4. Token, user and CSRF cookie is validated, auth cookie is set to `test.com`
5. User is redirected to `app10.test.com/home/page`
6. Request is allowed
With this setup, only `auth.test.com` must be permitted in the Google console.
Two criteria must be met for an `auth-host` to be used:
1. Request matches given `cookie-domain`
2. `auth-host` is also subdomain of same `cookie-domain`
## Copyright ## Copyright
2018 Thom Seddon 2018 Thom Seddon

View File

@ -0,0 +1,44 @@
version: '3'
services:
traefik:
image: traefik
command: -c /traefik.toml --logLevel=DEBUG
ports:
- "8085:80"
- "8086:8080"
networks:
- traefik
volumes:
- ./traefik.toml:/traefik.toml
- /var/run/docker.sock:/var/run/docker.sock
whoami1:
image: emilevauge/whoami
networks:
- traefik
labels:
- "traefik.backend=whoami"
- "traefik.enable=true"
- "traefik.frontend.rule=Host:whoami.yourdomain.com"
traefik-forward-auth:
image: thomseddon/traefik-forward-auth
environment:
- CLIENT_ID=your-client-id
- CLIENT_SECRET=your-client-secret
- SECRET=something-random
- COOKIE_SECURE=false
- DOMAIN=yourcompany.com
- AUTH_HOST=auth.yourdomain.com
networks:
- traefik
# When using an auth host, adding it here prompts traefik to generate certs
labels:
- traefik.enable=true
- traefik.port=4181
- traefik.backend=traefik-forward-auth
- traefik.frontend.rule=Host:auth.yourdomain.com
networks:
traefik:

View File

@ -0,0 +1,48 @@
version: '3'
services:
traefik:
image: traefik
command: -c /traefik.toml
# command: -c /traefik.toml --logLevel=DEBUG
ports:
- "8085:80"
- "8086:8080"
networks:
- traefik
volumes:
- ./traefik.toml:/traefik.toml
- /var/run/docker.sock:/var/run/docker.sock
whoami1:
image: emilevauge/whoami
networks:
- traefik
labels:
- "traefik.backend=whoami1"
- "traefik.enable=true"
- "traefik.frontend.rule=Host:whoami.localhost.com"
whoami2:
image: emilevauge/whoami
networks:
- traefik
labels:
- "traefik.backend=whoami2"
- "traefik.enable=true"
- "traefik.frontend.rule=Host:whoami.localhost.org"
traefik-forward-auth:
build: ../
environment:
- CLIENT_ID=test
- CLIENT_SECRET=test
- COOKIE_SECRET=something-random
- COOKIE_SECURE=false
- COOKIE_DOMAINS=localhost.com
- AUTH_URL=http://auth.localhost.com:8085/_oauth
networks:
- traefik
networks:
traefik:

View File

@ -22,12 +22,12 @@ services:
- "traefik.enable=true" - "traefik.enable=true"
- "traefik.frontend.rule=Host:whoami.localhost.com" - "traefik.frontend.rule=Host:whoami.localhost.com"
forward-oauth: traefik-forward-auth:
image: thomseddon/traefik-forward-auth image: thomseddon/traefik-forward-auth
environment: environment:
- CLIENT_ID=your-client-id - CLIENT_ID=your-client-id
- CLIENT_SECRET=your-client-secret - CLIENT_SECRET=your-client-secret
- COOKIE_SECRET=something-random - SECRET=something-random
- COOKIE_SECURE=false - COOKIE_SECURE=false
- DOMAIN=yourcompany.com - DOMAIN=yourcompany.com
networks: networks:

View File

@ -37,7 +37,8 @@
address = ":80" address = ":80"
[entryPoints.http.auth.forward] [entryPoints.http.auth.forward]
address = "http://forward-oauth:4181" address = "http://traefik-forward-auth:4181"
authResponseHeaders = ["X-Forwarded-User"]
################################################################ ################################################################
# Traefik logs configuration # Traefik logs configuration

View File

@ -1,358 +1,390 @@
package main package main
import ( import (
"fmt" "crypto/hmac"
"time" "crypto/rand"
"errors" "crypto/sha256"
"strings" "encoding/base64"
"strconv" "encoding/json"
"net/url" "errors"
"net/http" "fmt"
"crypto/hmac" "net/http"
"crypto/rand" "net/url"
"crypto/sha256" "strconv"
"encoding/json" "strings"
"encoding/base64" "time"
) )
// Forward Auth // Forward Auth
type ForwardAuth struct { type ForwardAuth struct {
Path string Path string
Lifetime time.Duration Lifetime time.Duration
Secret []byte
ClientId string ClientId string
ClientSecret string ClientSecret string `json:"-"`
Scope string Scope string
LoginURL *url.URL LoginURL *url.URL
TokenURL *url.URL TokenURL *url.URL
UserURL *url.URL UserURL *url.URL
CookieName string AuthHost string
CookieDomains []CookieDomain
CSRFCookieName string
CookieSecret []byte
CookieSecure bool
Domain []string CookieName string
CookieDomains []CookieDomain
CSRFCookieName string
CookieSecure bool
Direct bool Domain []string
Whitelist []string
Prompt string
} }
// Request Validation // Request Validation
// Cookie = hash(secret, cookie domain, email, expires)|expires|email // Cookie = hash(secret, cookie domain, email, expires)|expires|email
func (f *ForwardAuth) ValidateCookie(r *http.Request, c *http.Cookie) (bool, string, error) { func (f *ForwardAuth) ValidateCookie(r *http.Request, c *http.Cookie) (bool, string, error) {
parts := strings.Split(c.Value, "|") parts := strings.Split(c.Value, "|")
if len(parts) != 3 { if len(parts) != 3 {
return false, "", errors.New("Invalid cookie format") return false, "", errors.New("Invalid cookie format")
} }
mac, err := base64.URLEncoding.DecodeString(parts[0]) mac, err := base64.URLEncoding.DecodeString(parts[0])
if err != nil { if err != nil {
return false, "", errors.New("Unable to decode cookie mac") return false, "", errors.New("Unable to decode cookie mac")
} }
expectedSignature := f.cookieSignature(r, parts[2], parts[1]) expectedSignature := f.cookieSignature(r, parts[2], parts[1])
expected, err := base64.URLEncoding.DecodeString(expectedSignature) expected, err := base64.URLEncoding.DecodeString(expectedSignature)
if err != nil { if err != nil {
return false, "", errors.New("Unable to generate mac") return false, "", errors.New("Unable to generate mac")
} }
// Valid token? // Valid token?
if !hmac.Equal(mac, expected) { if !hmac.Equal(mac, expected) {
return false, "", errors.New("Invalid cookie mac") return false, "", errors.New("Invalid cookie mac")
} }
expires, err := strconv.ParseInt(parts[1], 10, 64) expires, err := strconv.ParseInt(parts[1], 10, 64)
if err != nil { if err != nil {
return false, "", errors.New("Unable to parse cookie expiry") return false, "", errors.New("Unable to parse cookie expiry")
} }
// Has it expired? // Has it expired?
if time.Unix(expires, 0).Before(time.Now()) { if time.Unix(expires, 0).Before(time.Now()) {
return false, "", errors.New("Cookie has expired") return false, "", errors.New("Cookie has expired")
} }
// Looks valid // Looks valid
return true, parts[2], nil return true, parts[2], nil
} }
// Validate email // Validate email
func (f *ForwardAuth) ValidateEmail(email string) bool { func (f *ForwardAuth) ValidateEmail(email string) bool {
if len(f.Domain) > 0 { found := false
parts := strings.Split(email, "@") if len(f.Whitelist) > 0 {
if len(parts) < 2 { for _, whitelist := range f.Whitelist {
return false if email == whitelist {
} found = true
found := false }
for _, domain := range f.Domain { }
if domain == parts[1] { } else if len(f.Domain) > 0 {
found = true parts := strings.Split(email, "@")
} if len(parts) < 2 {
} return false
if !found { }
return false for _, domain := range f.Domain {
} if domain == parts[1] {
} found = true
}
}
} else {
return true
}
return true return found
} }
// OAuth Methods // OAuth Methods
// Get login url // Get login url
func (f *ForwardAuth) GetLoginURL(r *http.Request, nonce string) string { func (f *ForwardAuth) GetLoginURL(r *http.Request, nonce string) string {
state := fmt.Sprintf("%s:%s", nonce, f.returnUrl(r)) state := fmt.Sprintf("%s:%s", nonce, f.returnUrl(r))
q := url.Values{} q := url.Values{}
q.Set("client_id", fw.ClientId) q.Set("client_id", fw.ClientId)
q.Set("response_type", "code") q.Set("response_type", "code")
q.Set("scope", fw.Scope) q.Set("scope", fw.Scope)
// q.Set("approval_prompt", fw.ClientId) if fw.Prompt != "" {
q.Set("redirect_uri", f.redirectUri(r)) q.Set("prompt", fw.Prompt)
q.Set("state", state) }
q.Set("redirect_uri", f.redirectUri(r))
q.Set("state", state)
var u url.URL var u url.URL
u = *fw.LoginURL u = *fw.LoginURL
u.RawQuery = q.Encode() u.RawQuery = q.Encode()
return u.String() return u.String()
} }
// Exchange code for token // Exchange code for token
type Token struct { type Token struct {
Token string `json:"access_token"` Token string `json:"access_token"`
} }
func (f *ForwardAuth) ExchangeCode(r *http.Request, code string) (string, error) { func (f *ForwardAuth) ExchangeCode(r *http.Request, code string) (string, error) {
form := url.Values{} form := url.Values{}
form.Set("client_id", fw.ClientId) form.Set("client_id", fw.ClientId)
form.Set("client_secret", fw.ClientSecret) form.Set("client_secret", fw.ClientSecret)
form.Set("grant_type", "authorization_code") form.Set("grant_type", "authorization_code")
form.Set("redirect_uri", f.redirectUri(r)) form.Set("redirect_uri", f.redirectUri(r))
form.Set("code", code) form.Set("code", code)
res, err := http.PostForm(fw.TokenURL.String(), form)
if err != nil {
return "", err
}
res, err := http.PostForm(fw.TokenURL.String(), form) var token Token
if err != nil { defer res.Body.Close()
return "", err err = json.NewDecoder(res.Body).Decode(&token)
}
var token Token return token.Token, err
defer res.Body.Close()
err = json.NewDecoder(res.Body).Decode(&token)
return token.Token, err
} }
// Get user with token // Get user with token
type User struct { type User struct {
Id string `json:"id"` Id string `json:"id"`
Email string `json:"email"` Email string `json:"email"`
Verified bool `json:"verified_email"` Verified bool `json:"verified_email"`
Hd string `json:"hd"` Hd string `json:"hd"`
} }
func (f *ForwardAuth) GetUser(token string) (User, error) { func (f *ForwardAuth) GetUser(token string) (User, error) {
var user User var user User
client := &http.Client{} client := &http.Client{}
req, err := http.NewRequest("GET", fw.UserURL.String(), nil) req, err := http.NewRequest("GET", fw.UserURL.String(), nil)
if err != nil { if err != nil {
return user, err return user, err
} }
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", token)) req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", token))
res, err := client.Do(req) res, err := client.Do(req)
if err != nil { if err != nil {
return user, err return user, err
} }
defer res.Body.Close() defer res.Body.Close()
err = json.NewDecoder(res.Body).Decode(&user) err = json.NewDecoder(res.Body).Decode(&user)
return user, err return user, err
} }
// Utility methods // Utility methods
// Get the redirect base // Get the redirect base
func (f *ForwardAuth) redirectBase(r *http.Request) string { func (f *ForwardAuth) redirectBase(r *http.Request) string {
proto := r.Header.Get("X-Forwarded-Proto") proto := r.Header.Get("X-Forwarded-Proto")
host := r.Header.Get("X-Forwarded-Host") host := r.Header.Get("X-Forwarded-Host")
// Direct mode return fmt.Sprintf("%s://%s", proto, host)
if f.Direct {
proto = "http"
host = r.Host
}
return fmt.Sprintf("%s://%s", proto, host)
} }
// Return url // Return url
func (f *ForwardAuth) returnUrl(r *http.Request) string { func (f *ForwardAuth) returnUrl(r *http.Request) string {
path := r.Header.Get("X-Forwarded-Uri") path := r.Header.Get("X-Forwarded-Uri")
// Testing return fmt.Sprintf("%s%s", f.redirectBase(r), path)
if f.Direct {
path = r.URL.String()
}
return fmt.Sprintf("%s%s", f.redirectBase(r), path)
} }
// Get oauth redirect uri // Get oauth redirect uri
func (f *ForwardAuth) redirectUri(r *http.Request) string { func (f *ForwardAuth) redirectUri(r *http.Request) string {
return fmt.Sprintf("%s%s", f.redirectBase(r), f.Path) if use, _ := f.useAuthDomain(r); use {
proto := r.Header.Get("X-Forwarded-Proto")
return fmt.Sprintf("%s://%s%s", proto, f.AuthHost, f.Path)
}
return fmt.Sprintf("%s%s", f.redirectBase(r), f.Path)
}
// Should we use auth host + what it is
func (f *ForwardAuth) useAuthDomain(r *http.Request) (bool, string) {
if f.AuthHost == "" {
return false, ""
}
// Does the request match a given cookie domain?
reqMatch, reqHost := f.matchCookieDomains(r.Header.Get("X-Forwarded-Host"))
// Do any of the auth hosts match a cookie domain?
authMatch, authHost := f.matchCookieDomains(f.AuthHost)
// We need both to match the same domain
return reqMatch && authMatch && reqHost == authHost, reqHost
} }
// Cookie methods // Cookie methods
// Create an auth cookie // Create an auth cookie
func (f *ForwardAuth) MakeCookie(r *http.Request, email string) *http.Cookie { func (f *ForwardAuth) MakeCookie(r *http.Request, email string) *http.Cookie {
expires := f.cookieExpiry() expires := f.cookieExpiry()
mac := f.cookieSignature(r, email, fmt.Sprintf("%d", expires.Unix())) mac := f.cookieSignature(r, email, fmt.Sprintf("%d", expires.Unix()))
value := fmt.Sprintf("%s|%d|%s", mac, expires.Unix(), email) value := fmt.Sprintf("%s|%d|%s", mac, expires.Unix(), email)
return &http.Cookie{ return &http.Cookie{
Name: f.CookieName, Name: f.CookieName,
Value: value, Value: value,
Path: "/", Path: "/",
Domain: f.cookieDomain(r), Domain: f.cookieDomain(r),
HttpOnly: true, HttpOnly: true,
Secure: f.CookieSecure, Secure: f.CookieSecure,
Expires: expires, Expires: expires,
} }
} }
// Make a CSRF cookie (used during login only) // Make a CSRF cookie (used during login only)
func (f *ForwardAuth) MakeCSRFCookie(r *http.Request, nonce string) *http.Cookie { func (f *ForwardAuth) MakeCSRFCookie(r *http.Request, nonce string) *http.Cookie {
return &http.Cookie{ return &http.Cookie{
Name: f.CSRFCookieName, Name: f.CSRFCookieName,
Value: nonce, Value: nonce,
Path: "/", Path: "/",
Domain: f.cookieDomain(r), Domain: f.csrfCookieDomain(r),
HttpOnly: true, HttpOnly: true,
Secure: f.CookieSecure, Secure: f.CookieSecure,
Expires: f.cookieExpiry(), Expires: f.cookieExpiry(),
} }
} }
// Create a cookie to clear csrf cookie // Create a cookie to clear csrf cookie
func (f *ForwardAuth) ClearCSRFCookie(r *http.Request) *http.Cookie { func (f *ForwardAuth) ClearCSRFCookie(r *http.Request) *http.Cookie {
return &http.Cookie{ return &http.Cookie{
Name: f.CSRFCookieName, Name: f.CSRFCookieName,
Value: "", Value: "",
Path: "/", Path: "/",
Domain: f.cookieDomain(r), Domain: f.csrfCookieDomain(r),
HttpOnly: true, HttpOnly: true,
Secure: f.CookieSecure, Secure: f.CookieSecure,
Expires: time.Now().Local().Add(time.Hour * -1), Expires: time.Now().Local().Add(time.Hour * -1),
} }
} }
// Validate the csrf cookie against state // Validate the csrf cookie against state
func (f *ForwardAuth) ValidateCSRFCookie(c *http.Cookie, state string) (bool, string, error) { func (f *ForwardAuth) ValidateCSRFCookie(c *http.Cookie, state string) (bool, string, error) {
if len(c.Value) != 32 { if len(c.Value) != 32 {
return false, "", errors.New("Invalid CSRF cookie value") return false, "", errors.New("Invalid CSRF cookie value")
} }
if len(state) < 34 { if len(state) < 34 {
return false, "", errors.New("Invalid CSRF state value") return false, "", errors.New("Invalid CSRF state value")
} }
// Check nonce match // Check nonce match
if c.Value != state[:32] { if c.Value != state[:32] {
return false, "", errors.New("CSRF cookie does not match state") return false, "", errors.New("CSRF cookie does not match state")
} }
// Valid, return redirect // Valid, return redirect
return true, state[33:], nil return true, state[33:], nil
} }
func (f *ForwardAuth) Nonce() (error, string) { func (f *ForwardAuth) Nonce() (error, string) {
// Make nonce // Make nonce
nonce := make([]byte, 16) nonce := make([]byte, 16)
_, err := rand.Read(nonce) _, err := rand.Read(nonce)
if err != nil { if err != nil {
return err, "" return err, ""
} }
return nil, fmt.Sprintf("%x", nonce) return nil, fmt.Sprintf("%x", nonce)
} }
// Cookie domain // Cookie domain
func (f *ForwardAuth) cookieDomain(r *http.Request) string { func (f *ForwardAuth) cookieDomain(r *http.Request) string {
host := r.Header.Get("X-Forwarded-Host") host := r.Header.Get("X-Forwarded-Host")
// Direct mode // Check if any of the given cookie domains matches
if f.Direct { _, domain := f.matchCookieDomains(host)
host = r.Host return domain
} }
// Remove port for matching // Cookie domain
p := strings.Split(host, ":") func (f *ForwardAuth) csrfCookieDomain(r *http.Request) string {
var host string
if use, domain := f.useAuthDomain(r); use {
host = domain
} else {
host = r.Header.Get("X-Forwarded-Host")
}
// Check if any of the given cookie domains matches // Remove port
for _, domain := range f.CookieDomains { p := strings.Split(host, ":")
if domain.Match(p[0]) { return p[0]
return domain.Domain }
}
}
return p[0] // Return matching cookie domain if exists
func (f *ForwardAuth) matchCookieDomains(domain string) (bool, string) {
// Remove port
p := strings.Split(domain, ":")
for _, d := range f.CookieDomains {
if d.Match(p[0]) {
return true, d.Domain
}
}
return false, p[0]
} }
// Create cookie hmac // Create cookie hmac
func (f *ForwardAuth) cookieSignature(r *http.Request, email, expires string) string { func (f *ForwardAuth) cookieSignature(r *http.Request, email, expires string) string {
hash := hmac.New(sha256.New, f.CookieSecret) hash := hmac.New(sha256.New, f.Secret)
hash.Write([]byte(f.cookieDomain(r))) hash.Write([]byte(f.cookieDomain(r)))
hash.Write([]byte(email)) hash.Write([]byte(email))
hash.Write([]byte(expires)) hash.Write([]byte(expires))
return base64.URLEncoding.EncodeToString(hash.Sum(nil)) return base64.URLEncoding.EncodeToString(hash.Sum(nil))
} }
// Get cookie expirary // Get cookie expirary
func (f *ForwardAuth) cookieExpiry() time.Time { func (f *ForwardAuth) cookieExpiry() time.Time {
return time.Now().Local().Add(f.Lifetime) return time.Now().Local().Add(f.Lifetime)
} }
// Cookie Domain // Cookie Domain
// Cookie Domain // Cookie Domain
type CookieDomain struct { type CookieDomain struct {
Domain string Domain string
DomainLen int DomainLen int
SubDomain string SubDomain string
SubDomainLen int SubDomainLen int
} }
func NewCookieDomain(domain string) *CookieDomain { func NewCookieDomain(domain string) *CookieDomain {
return &CookieDomain{ return &CookieDomain{
Domain: domain, Domain: domain,
DomainLen: len(domain), DomainLen: len(domain),
SubDomain: fmt.Sprintf(".%s", domain), SubDomain: fmt.Sprintf(".%s", domain),
SubDomainLen: len(domain) + 1, SubDomainLen: len(domain) + 1,
} }
} }
func (c *CookieDomain) Match(host string) bool { func (c *CookieDomain) Match(host string) bool {
// Exact domain match? // Exact domain match?
if host == c.Domain { if host == c.Domain {
return true return true
} }
// Subdomain match? // Subdomain match?
if len(host) >= c.SubDomainLen && host[len(host) - c.SubDomainLen:] == c.SubDomain { if len(host) >= c.SubDomainLen && host[len(host)-c.SubDomainLen:] == c.SubDomain {
return true return true
} }
return false return false
} }

View File

@ -1,138 +1,284 @@
package main package main
import ( import (
// "fmt" // "fmt"
"time" "net/http"
"reflect" "net/url"
"testing" "reflect"
"net/url" "testing"
"net/http" "time"
) )
func TestValidateCookie(t *testing.T) { func TestValidateCookie(t *testing.T) {
fw = &ForwardAuth{} fw = &ForwardAuth{}
r, _ := http.NewRequest("GET", "http://example.com", nil) r, _ := http.NewRequest("GET", "http://example.com", nil)
c := &http.Cookie{} c := &http.Cookie{}
// Should require 3 parts // Should require 3 parts
c.Value = "" c.Value = ""
valid, _, err := fw.ValidateCookie(r, c) valid, _, err := fw.ValidateCookie(r, c)
if valid || err.Error() != "Invalid cookie format" { if valid || err.Error() != "Invalid cookie format" {
t.Error("Should get \"Invalid cookie format\", got:", err) t.Error("Should get \"Invalid cookie format\", got:", err)
} }
c.Value = "1|2" c.Value = "1|2"
valid, _, err = fw.ValidateCookie(r, c) valid, _, err = fw.ValidateCookie(r, c)
if valid || err.Error() != "Invalid cookie format" { if valid || err.Error() != "Invalid cookie format" {
t.Error("Should get \"Invalid cookie format\", got:", err) t.Error("Should get \"Invalid cookie format\", got:", err)
} }
c.Value = "1|2|3|4" c.Value = "1|2|3|4"
valid, _, err = fw.ValidateCookie(r, c) valid, _, err = fw.ValidateCookie(r, c)
if valid || err.Error() != "Invalid cookie format" { if valid || err.Error() != "Invalid cookie format" {
t.Error("Should get \"Invalid cookie format\", got:", err) t.Error("Should get \"Invalid cookie format\", got:", err)
} }
// Should catch invalid mac // Should catch invalid mac
c.Value = "MQ==|2|3" c.Value = "MQ==|2|3"
valid, _, err = fw.ValidateCookie(r, c) valid, _, err = fw.ValidateCookie(r, c)
if valid || err.Error() != "Invalid cookie mac" { if valid || err.Error() != "Invalid cookie mac" {
t.Error("Should get \"Invalid cookie mac\", got:", err) t.Error("Should get \"Invalid cookie mac\", got:", err)
} }
// Should catch expired // Should catch expired
fw.Lifetime = time.Second * time.Duration(-1) fw.Lifetime = time.Second * time.Duration(-1)
c = fw.MakeCookie(r, "test@test.com") c = fw.MakeCookie(r, "test@test.com")
valid, _, err = fw.ValidateCookie(r, c) valid, _, err = fw.ValidateCookie(r, c)
if valid || err.Error() != "Cookie has expired" { if valid || err.Error() != "Cookie has expired" {
t.Error("Should get \"Cookie has expired\", got:", err) t.Error("Should get \"Cookie has expired\", got:", err)
} }
// Should accept valid cookie // Should accept valid cookie
fw.Lifetime = time.Second * time.Duration(10) fw.Lifetime = time.Second * time.Duration(10)
c = fw.MakeCookie(r, "test@test.com") c = fw.MakeCookie(r, "test@test.com")
valid, email, err := fw.ValidateCookie(r, c) valid, email, err := fw.ValidateCookie(r, c)
if !valid { if !valid {
t.Error("Valid request should return as valid") t.Error("Valid request should return as valid")
} }
if err != nil { if err != nil {
t.Error("Valid request should not return error, got:", err) t.Error("Valid request should not return error, got:", err)
} }
if email != "test@test.com" { if email != "test@test.com" {
t.Error("Valid request should return user email") t.Error("Valid request should return user email")
} }
} }
func TestValidateEmail(t *testing.T) { func TestValidateEmail(t *testing.T) {
fw = &ForwardAuth{} fw = &ForwardAuth{}
// Should allow any // Should allow any
if !fw.ValidateEmail("test@test.com") || !fw.ValidateEmail("one@two.com") { if !fw.ValidateEmail("test@test.com") || !fw.ValidateEmail("one@two.com") {
t.Error("Should allow any domain if email domain is not defined") t.Error("Should allow any domain if email domain is not defined")
} }
// Should block non matching domain // Should block non matching domain
fw.Domain = []string{"test.com"} fw.Domain = []string{"test.com"}
if fw.ValidateEmail("one@two.com") { if fw.ValidateEmail("one@two.com") {
t.Error("Should not allow user from another domain") t.Error("Should not allow user from another domain")
} }
// Should allow matching domain // Should allow matching domain
fw.Domain = []string{"test.com"} fw.Domain = []string{"test.com"}
if !fw.ValidateEmail("test@test.com") { if !fw.ValidateEmail("test@test.com") {
t.Error("Should allow user from allowed domain") t.Error("Should allow user from allowed domain")
} }
// Should block non whitelisted email address
fw.Domain = []string{}
fw.Whitelist = []string{"test@test.com"}
if fw.ValidateEmail("one@two.com") {
t.Error("Should not allow user not in whitelist.")
}
// Should allow matching whitelisted email address
fw.Domain = []string{}
fw.Whitelist = []string{"test@test.com"}
if !fw.ValidateEmail("test@test.com") {
t.Error("Should allow user in whitelist.")
}
} }
func TestGetLoginURL(t *testing.T) { func TestGetLoginURL(t *testing.T) {
fw = &ForwardAuth{ r, _ := http.NewRequest("GET", "http://example.com", nil)
Path: "/_oauth", r.Header.Add("X-Forwarded-Proto", "http")
ClientId: "idtest", r.Header.Add("X-Forwarded-Host", "example.com")
ClientSecret: "sectest", r.Header.Add("X-Forwarded-Uri", "/hello")
Scope: "scopetest",
LoginURL: &url.URL{
Scheme: "https",
Host: "test.com",
Path: "/auth",
},
}
r, _ := http.NewRequest("GET", "http://example.com", nil)
r.Header.Add("X-Forwarded-Proto", "http")
r.Header.Add("X-Forwarded-Host", "example.com")
r.Header.Add("X-Forwarded-Uri", "/hello")
// Check url fw = &ForwardAuth{
uri, err := url.Parse(fw.GetLoginURL(r, "nonce")) Path: "/_oauth",
if err != nil { ClientId: "idtest",
t.Error("Error parsing login url:", err) ClientSecret: "sectest",
} Scope: "scopetest",
if uri.Scheme != "https" { LoginURL: &url.URL{
t.Error("Expected login Scheme to be \"https\", got:", uri.Scheme) Scheme: "https",
} Host: "test.com",
if uri.Host != "test.com" { Path: "/auth",
t.Error("Expected login Host to be \"test.com\", got:", uri.Host) },
} }
if uri.Path != "/auth" {
t.Error("Expected login Path to be \"/auth\", got:", uri.Path)
}
// Check query string // Check url
qs := uri.Query() uri, err := url.Parse(fw.GetLoginURL(r, "nonce"))
expectedQs := url.Values{ if err != nil {
"client_id": []string{"idtest"}, t.Error("Error parsing login url:", err)
"redirect_uri": []string{"http://example.com/_oauth"}, }
"response_type": []string{"code"}, if uri.Scheme != "https" {
"scope": []string{"scopetest"}, t.Error("Expected login Scheme to be \"https\", got:", uri.Scheme)
"state": []string{"nonce:http://example.com/hello"}, }
} if uri.Host != "test.com" {
if !reflect.DeepEqual(qs, expectedQs) { t.Error("Expected login Host to be \"test.com\", got:", uri.Host)
t.Error("Incorrect login query string, expected:") }
t.Error(expectedQs) if uri.Path != "/auth" {
t.Error("Got:") t.Error("Expected login Path to be \"/auth\", got:", uri.Path)
t.Error(qs) }
}
// Check query string
qs := uri.Query()
expectedQs := url.Values{
"client_id": []string{"idtest"},
"redirect_uri": []string{"http://example.com/_oauth"},
"response_type": []string{"code"},
"scope": []string{"scopetest"},
"state": []string{"nonce:http://example.com/hello"},
}
if !reflect.DeepEqual(qs, expectedQs) {
t.Error("Incorrect login query string:")
qsDiff(expectedQs, qs)
}
//
// With Auth URL but no matching cookie domain
// - will not use auth host
//
fw = &ForwardAuth{
Path: "/_oauth",
AuthHost: "auth.example.com",
ClientId: "idtest",
ClientSecret: "sectest",
Scope: "scopetest",
LoginURL: &url.URL{
Scheme: "https",
Host: "test.com",
Path: "/auth",
},
Prompt: "consent select_account",
}
// Check url
uri, err = url.Parse(fw.GetLoginURL(r, "nonce"))
if err != nil {
t.Error("Error parsing login url:", err)
}
if uri.Scheme != "https" {
t.Error("Expected login Scheme to be \"https\", got:", uri.Scheme)
}
if uri.Host != "test.com" {
t.Error("Expected login Host to be \"test.com\", got:", uri.Host)
}
if uri.Path != "/auth" {
t.Error("Expected login Path to be \"/auth\", got:", uri.Path)
}
// Check query string
qs = uri.Query()
expectedQs = url.Values{
"client_id": []string{"idtest"},
"redirect_uri": []string{"http://example.com/_oauth"},
"response_type": []string{"code"},
"scope": []string{"scopetest"},
"prompt": []string{"consent select_account"},
"state": []string{"nonce:http://example.com/hello"},
}
if !reflect.DeepEqual(qs, expectedQs) {
t.Error("Incorrect login query string:")
qsDiff(expectedQs, qs)
}
//
// With correct Auth URL + cookie domain
//
cookieDomain := NewCookieDomain("example.com")
fw = &ForwardAuth{
Path: "/_oauth",
AuthHost: "auth.example.com",
ClientId: "idtest",
ClientSecret: "sectest",
Scope: "scopetest",
LoginURL: &url.URL{
Scheme: "https",
Host: "test.com",
Path: "/auth",
},
CookieDomains: []CookieDomain{*cookieDomain},
}
// Check url
uri, err = url.Parse(fw.GetLoginURL(r, "nonce"))
if err != nil {
t.Error("Error parsing login url:", err)
}
if uri.Scheme != "https" {
t.Error("Expected login Scheme to be \"https\", got:", uri.Scheme)
}
if uri.Host != "test.com" {
t.Error("Expected login Host to be \"test.com\", got:", uri.Host)
}
if uri.Path != "/auth" {
t.Error("Expected login Path to be \"/auth\", got:", uri.Path)
}
// Check query string
qs = uri.Query()
expectedQs = url.Values{
"client_id": []string{"idtest"},
"redirect_uri": []string{"http://auth.example.com/_oauth"},
"response_type": []string{"code"},
"scope": []string{"scopetest"},
"state": []string{"nonce:http://example.com/hello"},
}
qsDiff(expectedQs, qs)
if !reflect.DeepEqual(qs, expectedQs) {
t.Error("Incorrect login query string:")
qsDiff(expectedQs, qs)
}
//
// With Auth URL + cookie domain, but from different domain
// - will not use auth host
//
r, _ = http.NewRequest("GET", "http://another.com", nil)
r.Header.Add("X-Forwarded-Proto", "http")
r.Header.Add("X-Forwarded-Host", "another.com")
r.Header.Add("X-Forwarded-Uri", "/hello")
// Check url
uri, err = url.Parse(fw.GetLoginURL(r, "nonce"))
if err != nil {
t.Error("Error parsing login url:", err)
}
if uri.Scheme != "https" {
t.Error("Expected login Scheme to be \"https\", got:", uri.Scheme)
}
if uri.Host != "test.com" {
t.Error("Expected login Host to be \"test.com\", got:", uri.Host)
}
if uri.Path != "/auth" {
t.Error("Expected login Path to be \"/auth\", got:", uri.Path)
}
// Check query string
qs = uri.Query()
expectedQs = url.Values{
"client_id": []string{"idtest"},
"redirect_uri": []string{"http://another.com/_oauth"},
"response_type": []string{"code"},
"scope": []string{"scopetest"},
"state": []string{"nonce:http://another.com/hello"},
}
qsDiff(expectedQs, qs)
if !reflect.DeepEqual(qs, expectedQs) {
t.Error("Incorrect login query string:")
qsDiff(expectedQs, qs)
}
} }
// TODO // TODO
// func TestExchangeCode(t *testing.T) { // func TestExchangeCode(t *testing.T) {
// } // }
@ -145,98 +291,124 @@ func TestGetLoginURL(t *testing.T) {
// func TestMakeCookie(t *testing.T) { // func TestMakeCookie(t *testing.T) {
// } // }
// func TestMakeCSRFCookie(t *testing.T) { func TestMakeCSRFCookie(t *testing.T) {
// t.Log("TODO") r, _ := http.NewRequest("GET", "http://app.example.com", nil)
// } r.Header.Add("X-Forwarded-Host", "app.example.com")
// No cookie domain or auth url
fw = &ForwardAuth{}
c := fw.MakeCSRFCookie(r, "12345678901234567890123456789012")
if c.Domain != "app.example.com" {
t.Error("Cookie Domain should match request domain, got:", c.Domain)
}
// With cookie domain but no auth url
cookieDomain := NewCookieDomain("example.com")
fw = &ForwardAuth{CookieDomains: []CookieDomain{*cookieDomain}}
c = fw.MakeCSRFCookie(r, "12345678901234567890123456789012")
if c.Domain != "app.example.com" {
t.Error("Cookie Domain should match request domain, got:", c.Domain)
}
// With cookie domain and auth url
fw = &ForwardAuth{
AuthHost: "auth.example.com",
CookieDomains: []CookieDomain{*cookieDomain},
}
c = fw.MakeCSRFCookie(r, "12345678901234567890123456789012")
if c.Domain != "example.com" {
t.Error("Cookie Domain should match request domain, got:", c.Domain)
}
}
func TestClearCSRFCookie(t *testing.T) { func TestClearCSRFCookie(t *testing.T) {
fw = &ForwardAuth{} fw = &ForwardAuth{}
r, _ := http.NewRequest("GET", "http://example.com", nil) r, _ := http.NewRequest("GET", "http://example.com", nil)
c := fw.ClearCSRFCookie(r) c := fw.ClearCSRFCookie(r)
if c.Value != "" { if c.Value != "" {
t.Error("ClearCSRFCookie should create cookie with empty value") t.Error("ClearCSRFCookie should create cookie with empty value")
} }
} }
func TestValidateCSRFCookie(t *testing.T) { func TestValidateCSRFCookie(t *testing.T) {
fw = &ForwardAuth{} fw = &ForwardAuth{}
c := &http.Cookie{} c := &http.Cookie{}
// Should require 32 char string // Should require 32 char string
c.Value = "" c.Value = ""
valid, _, err := fw.ValidateCSRFCookie(c, "") valid, _, err := fw.ValidateCSRFCookie(c, "")
if valid || err.Error() != "Invalid CSRF cookie value" { if valid || err.Error() != "Invalid CSRF cookie value" {
t.Error("Should get \"Invalid CSRF cookie value\", got:", err) t.Error("Should get \"Invalid CSRF cookie value\", got:", err)
} }
c.Value = "123456789012345678901234567890123" c.Value = "123456789012345678901234567890123"
valid, _, err = fw.ValidateCSRFCookie(c, "") valid, _, err = fw.ValidateCSRFCookie(c, "")
if valid || err.Error() != "Invalid CSRF cookie value" { if valid || err.Error() != "Invalid CSRF cookie value" {
t.Error("Should get \"Invalid CSRF cookie value\", got:", err) t.Error("Should get \"Invalid CSRF cookie value\", got:", err)
} }
// Should require valid state // Should require valid state
c.Value = "12345678901234567890123456789012" c.Value = "12345678901234567890123456789012"
valid, _, err = fw.ValidateCSRFCookie(c, "12345678901234567890123456789012:") valid, _, err = fw.ValidateCSRFCookie(c, "12345678901234567890123456789012:")
if valid || err.Error() != "Invalid CSRF state value" { if valid || err.Error() != "Invalid CSRF state value" {
t.Error("Should get \"Invalid CSRF state value\", got:", err) t.Error("Should get \"Invalid CSRF state value\", got:", err)
} }
// Should allow valid state // Should allow valid state
c.Value = "12345678901234567890123456789012" c.Value = "12345678901234567890123456789012"
valid, state, err := fw.ValidateCSRFCookie(c, "12345678901234567890123456789012:99") valid, state, err := fw.ValidateCSRFCookie(c, "12345678901234567890123456789012:99")
if !valid { if !valid {
t.Error("Valid request should return as valid") t.Error("Valid request should return as valid")
} }
if err != nil { if err != nil {
t.Error("Valid request should not return error, got:", err) t.Error("Valid request should not return error, got:", err)
} }
if state != "99" { if state != "99" {
t.Error("Valid request should return correct state, got:", state) t.Error("Valid request should return correct state, got:", state)
} }
} }
func TestNonce(t *testing.T) { func TestNonce(t *testing.T) {
fw = &ForwardAuth{} fw = &ForwardAuth{}
err, nonce1 := fw.Nonce() err, nonce1 := fw.Nonce()
if err != nil { if err != nil {
t.Error("Error generation nonce:", err) t.Error("Error generation nonce:", err)
} }
err, nonce2 := fw.Nonce() err, nonce2 := fw.Nonce()
if err != nil { if err != nil {
t.Error("Error generation nonce:", err) t.Error("Error generation nonce:", err)
} }
if len(nonce1) != 32 || len(nonce2) != 32 { if len(nonce1) != 32 || len(nonce2) != 32 {
t.Error("Nonce should be 32 chars") t.Error("Nonce should be 32 chars")
} }
if nonce1 == nonce2 { if nonce1 == nonce2 {
t.Error("Nonce should not be equal") t.Error("Nonce should not be equal")
} }
} }
func TestCookieDomainMatch(t *testing.T) { func TestCookieDomainMatch(t *testing.T) {
cd := NewCookieDomain("example.com") cd := NewCookieDomain("example.com")
// Exact should match // Exact should match
if !cd.Match("example.com") { if !cd.Match("example.com") {
t.Error("Exact domain should match") t.Error("Exact domain should match")
} }
// Subdomain should match // Subdomain should match
if !cd.Match("test.example.com") { if !cd.Match("test.example.com") {
t.Error("Subdomain should match") t.Error("Subdomain should match")
} }
// Derived domain should not match // Derived domain should not match
if cd.Match("testexample.com") { if cd.Match("testexample.com") {
t.Error("Derived domain should not match") t.Error("Derived domain should not match")
} }
// Other domain should not match // Other domain should not match
if cd.Match("test.com") { if cd.Match("test.com") {
t.Error("Other domain should not match") t.Error("Other domain should not match")
} }
} }

48
log.go Normal file
View File

@ -0,0 +1,48 @@
package main
import (
"os"
"github.com/sirupsen/logrus"
)
func CreateLogger(logLevel, logFormat string) logrus.FieldLogger {
// Setup logger
log := logrus.StandardLogger()
logrus.SetOutput(os.Stdout)
// Set logger format
switch logFormat {
case "pretty":
break
case "json":
logrus.SetFormatter(&logrus.JSONFormatter{})
// "text" is the default
default:
logrus.SetFormatter(&logrus.TextFormatter{
DisableColors: true,
FullTimestamp: true,
})
}
// Set logger level
switch logLevel {
case "trace":
logrus.SetLevel(logrus.TraceLevel)
case "debug":
logrus.SetLevel(logrus.DebugLevel)
case "info":
logrus.SetLevel(logrus.InfoLevel)
case "error":
logrus.SetLevel(logrus.ErrorLevel)
case "fatal":
logrus.SetLevel(logrus.FatalLevel)
case "panic":
logrus.SetLevel(logrus.PanicLevel)
// warn is the default
default:
logrus.SetLevel(logrus.WarnLevel)
}
return log
}

368
main.go
View File

@ -1,217 +1,239 @@
package main package main
import ( import (
"fmt" "encoding/json"
"time" "fmt"
"strings" "net/http"
"net/url" "net/url"
"net/http" "strings"
"time"
"github.com/namsral/flag" "github.com/namsral/flag"
"github.com/op/go-logging" "github.com/sirupsen/logrus"
) )
// Vars // Vars
var fw *ForwardAuth; var fw *ForwardAuth
var log = logging.MustGetLogger("traefik-forward-auth") var log logrus.FieldLogger
// Primary handler // Primary handler
func handler(w http.ResponseWriter, r *http.Request) { func handler(w http.ResponseWriter, r *http.Request) {
// Parse uri // Logging setup
uri, err := url.Parse(r.Header.Get("X-Forwarded-Uri")) logger := log.WithFields(logrus.Fields{
if err != nil { "RemoteAddr": r.RemoteAddr,
log.Error("Error parsing url") })
http.Error(w, "Service unavailable", 503) logger.WithFields(logrus.Fields{
return "Headers": r.Header,
} }).Debugf("Handling request")
// Direct mode // Parse uri
if fw.Direct { uri, err := url.Parse(r.Header.Get("X-Forwarded-Uri"))
uri = r.URL if err != nil {
} logger.Errorf("Error parsing X-Forwarded-Uri, %v", err)
http.Error(w, "Service unavailable", 503)
return
}
// Handle callback // Handle callback
if uri.Path == fw.Path { if uri.Path == fw.Path {
handleCallback(w, r, uri.Query()) logger.Debugf("Passing request to auth callback")
return handleCallback(w, r, uri.Query(), logger)
} return
}
c, err := r.Cookie(fw.CookieName) // Get auth cookie
if err != nil { c, err := r.Cookie(fw.CookieName)
// Error indicates no cookie, generate nonce if err != nil {
err, nonce := fw.Nonce() // Error indicates no cookie, generate nonce
if err != nil { err, nonce := fw.Nonce()
log.Error("Error generating nonce") if err != nil {
http.Error(w, "Service unavailable", 503) logger.Errorf("Error generating nonce, %v", err)
return http.Error(w, "Service unavailable", 503)
} return
}
// Set the CSRF cookie // Set the CSRF cookie
http.SetCookie(w, fw.MakeCSRFCookie(r, nonce)) http.SetCookie(w, fw.MakeCSRFCookie(r, nonce))
log.Debug("Set CSRF cookie and redirecting to google login") logger.Debug("Set CSRF cookie and redirecting to google login")
// Forward them on // Forward them on
http.Redirect(w, r, fw.GetLoginURL(r, nonce), http.StatusTemporaryRedirect) http.Redirect(w, r, fw.GetLoginURL(r, nonce), http.StatusTemporaryRedirect)
return return
} }
// Validate cookie // Validate cookie
valid, email, err := fw.ValidateCookie(r, c) valid, email, err := fw.ValidateCookie(r, c)
if !valid { if !valid {
log.Debugf("Invlaid cookie: %s", err) logger.Errorf("Invalid cookie: %v", err)
http.Error(w, "Not authorized", 401) http.Error(w, "Not authorized", 401)
return return
} }
// Validate user // Validate user
valid = fw.ValidateEmail(email) valid = fw.ValidateEmail(email)
if !valid { if !valid {
log.Debugf("Invalid email: %s", email) logger.WithFields(logrus.Fields{
http.Error(w, "Not authorized", 401) "email": email,
return }).Errorf("Invalid email")
} http.Error(w, "Not authorized", 401)
return
}
// Valid request // Valid request
w.WriteHeader(200) logger.Debugf("Allowing valid request ")
w.Header().Set("X-Forwarded-User", email)
w.WriteHeader(200)
} }
// Authenticate user after they have come back from google // Authenticate user after they have come back from google
func handleCallback(w http.ResponseWriter, r *http.Request, qs url.Values) { func handleCallback(w http.ResponseWriter, r *http.Request, qs url.Values,
// Check for CSRF cookie logger logrus.FieldLogger) {
csrfCookie, err := r.Cookie(fw.CSRFCookieName) // Check for CSRF cookie
if err != nil { csrfCookie, err := r.Cookie(fw.CSRFCookieName)
log.Debug("Missing csrf cookie") if err != nil {
http.Error(w, "Not authorized", 401) logger.Warn("Missing csrf cookie")
return http.Error(w, "Not authorized", 401)
} return
}
// Validate state // Validate state
state := qs.Get("state") state := qs.Get("state")
valid, redirect, err := fw.ValidateCSRFCookie(csrfCookie, state) valid, redirect, err := fw.ValidateCSRFCookie(csrfCookie, state)
if !valid { if !valid {
log.Debugf("Invalid oauth state, expected '%s', got '%s'\n", csrfCookie.Value, state) logger.WithFields(logrus.Fields{
http.Error(w, "Not authorized", 401) "csrf": csrfCookie.Value,
return "state": state,
} }).Warnf("Error validating csrf cookie: %v", err)
http.Error(w, "Not authorized", 401)
return
}
// Clear CSRF cookie // Clear CSRF cookie
http.SetCookie(w, fw.ClearCSRFCookie(r)) http.SetCookie(w, fw.ClearCSRFCookie(r))
// Exchange code for token // Exchange code for token
token, err := fw.ExchangeCode(r, qs.Get("code")) token, err := fw.ExchangeCode(r, qs.Get("code"))
if err != nil { if err != nil {
log.Debugf("Code exchange failed with: %s\n", err) logger.Errorf("Code exchange failed with: %v", err)
http.Error(w, "Service unavailable", 503) http.Error(w, "Service unavailable", 503)
return return
} }
// Get user // Get user
user, err := fw.GetUser(token) user, err := fw.GetUser(token)
if err != nil { if err != nil {
log.Debugf("Error getting user: %s\n", err) logger.Errorf("Error getting user: %s", err)
return return
} }
// Generate cookie // Generate cookie
http.SetCookie(w, fw.MakeCookie(r, user.Email)) http.SetCookie(w, fw.MakeCookie(r, user.Email))
log.Debugf("Generated auth cookie for %s\n", user.Email) logger.WithFields(logrus.Fields{
"user": user.Email,
}).Infof("Generated auth cookie")
// Redirect // Redirect
http.Redirect(w, r, redirect, http.StatusTemporaryRedirect) http.Redirect(w, r, redirect, http.StatusTemporaryRedirect)
} }
// Main // Main
func main() { func main() {
// Parse options // Parse options
flag.String(flag.DefaultConfigFlagname, "", "Path to config file") flag.String(flag.DefaultConfigFlagname, "", "Path to config file")
path := flag.String("url-path", "_oauth", "Callback URL") path := flag.String("url-path", "_oauth", "Callback URL")
lifetime := flag.Int("lifetime", 43200, "Session length in seconds") lifetime := flag.Int("lifetime", 43200, "Session length in seconds")
clientId := flag.String("client-id", "", "*Google Client ID (required)") secret := flag.String("secret", "", "*Secret used for signing (required)")
clientSecret := flag.String("client-secret", "", "*Google Client Secret (required)") authHost := flag.String("auth-host", "", "Central auth login")
cookieName := flag.String("cookie-name", "_forward_auth", "Cookie Name") clientId := flag.String("client-id", "", "*Google Client ID (required)")
cSRFCookieName := flag.String("csrf-cookie-name", "_forward_auth_csrf", "CSRF Cookie Name") clientSecret := flag.String("client-secret", "", "*Google Client Secret (required)")
cookieDomainList := flag.String("cookie-domains", "", "Comma separated list of cookie domains") //todo cookieName := flag.String("cookie-name", "_forward_auth", "Cookie Name")
cookieSecret := flag.String("cookie-secret", "", "*Cookie secret (required)") cSRFCookieName := flag.String("csrf-cookie-name", "_forward_auth_csrf", "CSRF Cookie Name")
cookieSecure := flag.Bool("cookie-secure", true, "Use secure cookies") cookieDomainList := flag.String("cookie-domains", "", "Comma separated list of cookie domains") //todo
domainList := flag.String("domain", "", "Comma separated list of email domains to allow") cookieSecret := flag.String("cookie-secret", "", "Deprecated")
direct := flag.Bool("direct", false, "Run in direct mode (use own hostname as oppose to X-Forwarded-Host, used for testing/development)") cookieSecure := flag.Bool("cookie-secure", true, "Use secure cookies")
domainList := flag.String("domain", "", "Comma separated list of email domains to allow")
emailWhitelist := flag.String("whitelist", "", "Comma separated list of emails to allow")
prompt := flag.String("prompt", "", "Space separated list of OpenID prompt options")
logLevel := flag.String("log-level", "warn", "Log level: trace, debug, info, warn, error, fatal, panic")
logFormat := flag.String("log-format", "text", "Log format: text, json, pretty")
flag.Parse() flag.Parse()
// Check for show stopper errors // Setup logger
err := false log = CreateLogger(*logLevel, *logFormat)
if *clientId == "" {
err = true
log.Critical("client-id must be set")
}
if *clientSecret == "" {
err = true
log.Critical("client-secret must be set")
}
if *cookieSecret == "" {
err = true
log.Critical("cookie-secret must be set")
}
if err {
return
}
// Parse lists // Backwards compatibility
var cookieDomains []CookieDomain if *secret == "" && *cookieSecret != "" {
if *cookieDomainList != "" { *secret = *cookieSecret
for _, d := range strings.Split(*cookieDomainList, ",") { }
cookieDomain := NewCookieDomain(d)
cookieDomains = append(cookieDomains, *cookieDomain)
}
}
var domain []string // Check for show stopper errors
if *domainList != "" { if *clientId == "" || *clientSecret == "" || *secret == "" {
domain = strings.Split(*domainList, ",") log.Fatal("client-id, client-secret and secret must all be set")
} }
// Setup // Parse lists
fw = &ForwardAuth{ var cookieDomains []CookieDomain
Path: fmt.Sprintf("/%s", *path), if *cookieDomainList != "" {
Lifetime: time.Second * time.Duration(*lifetime), for _, d := range strings.Split(*cookieDomainList, ",") {
cookieDomain := NewCookieDomain(d)
cookieDomains = append(cookieDomains, *cookieDomain)
}
}
ClientId: *clientId, var domain []string
ClientSecret: *clientSecret, if *domainList != "" {
Scope: "https://www.googleapis.com/auth/userinfo.profile https://www.googleapis.com/auth/userinfo.email", domain = strings.Split(*domainList, ",")
LoginURL: &url.URL{ }
Scheme: "https", var whitelist []string
Host: "accounts.google.com", if *emailWhitelist != "" {
Path: "/o/oauth2/auth", whitelist = strings.Split(*emailWhitelist, ",")
}, }
TokenURL: &url.URL{
Scheme: "https",
Host: "www.googleapis.com",
Path: "/oauth2/v3/token",
},
UserURL: &url.URL{
Scheme: "https",
Host: "www.googleapis.com",
Path: "/oauth2/v2/userinfo",
},
CookieName: *cookieName, // Setup
CSRFCookieName: *cSRFCookieName, fw = &ForwardAuth{
CookieDomains: cookieDomains, Path: fmt.Sprintf("/%s", *path),
CookieSecret: []byte(*cookieSecret), Lifetime: time.Second * time.Duration(*lifetime),
CookieSecure: *cookieSecure, Secret: []byte(*secret),
AuthHost: *authHost,
Domain: domain, ClientId: *clientId,
ClientSecret: *clientSecret,
Scope: "https://www.googleapis.com/auth/userinfo.profile https://www.googleapis.com/auth/userinfo.email",
LoginURL: &url.URL{
Scheme: "https",
Host: "accounts.google.com",
Path: "/o/oauth2/auth",
},
TokenURL: &url.URL{
Scheme: "https",
Host: "www.googleapis.com",
Path: "/oauth2/v3/token",
},
UserURL: &url.URL{
Scheme: "https",
Host: "www.googleapis.com",
Path: "/oauth2/v2/userinfo",
},
Direct: *direct, CookieName: *cookieName,
} CSRFCookieName: *cSRFCookieName,
CookieDomains: cookieDomains,
CookieSecure: *cookieSecure,
// Attach handler Domain: domain,
http.HandleFunc("/", handler) Whitelist: whitelist,
log.Notice("Litening on :4181") Prompt: *prompt,
log.Notice(http.ListenAndServe(":4181", nil)) }
// Attach handler
http.HandleFunc("/", handler)
// Start
jsonConf, _ := json.Marshal(fw)
log.Debugf("Starting with options: %s", string(jsonConf))
log.Info("Listening on :4181")
log.Info(http.ListenAndServe(":4181", nil))
} }

View File

@ -1,29 +1,31 @@
package main package main
import ( import (
"fmt" "fmt"
"time" "time"
// "reflect" // "reflect"
"strings" "io/ioutil"
"testing" "net/http"
"net/url" "net/http/httptest"
"net/http" "net/url"
"io/ioutil" "strings"
"net/http/httptest" "testing"
"github.com/op/go-logging"
) )
/**
* Utilities
*/
type TokenServerHandler struct{}
type TokenServerHandler struct {}
func (t *TokenServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (t *TokenServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, `{"access_token":"123456789"}`) fmt.Fprint(w, `{"access_token":"123456789"}`)
} }
type UserServerHandler struct {} type UserServerHandler struct{}
func (t *UserServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (t *UserServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, `{ fmt.Fprint(w, `{
"id":"1", "id":"1",
"email":"example@example.com", "email":"example@example.com",
"verified_email":true, "verified_email":true,
@ -32,150 +34,177 @@ func (t *UserServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
func init() { func init() {
// Remove for debugging log = CreateLogger("panic", "")
logging.SetLevel(logging.INFO, "traefik-forward-auth")
} }
func httpRequest(r *http.Request, c *http.Cookie) (*http.Response, string) { func httpRequest(r *http.Request, c *http.Cookie) (*http.Response, string) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
// Set cookies on recorder // Set cookies on recorder
if c != nil { if c != nil {
http.SetCookie(w, c) http.SetCookie(w, c)
} }
// Copy into request // Copy into request
for _, c := range w.HeaderMap["Set-Cookie"] { for _, c := range w.HeaderMap["Set-Cookie"] {
r.Header.Add("Cookie", c) r.Header.Add("Cookie", c)
} }
handler(w, r) handler(w, r)
res := w.Result() res := w.Result()
body, _ := ioutil.ReadAll(res.Body) body, _ := ioutil.ReadAll(res.Body)
return res, string(body) return res, string(body)
} }
func newHttpRequest(uri string) *http.Request { func newHttpRequest(uri string) *http.Request {
r := httptest.NewRequest("", "http://example.com", nil) r := httptest.NewRequest("", "http://example.com", nil)
r.Header.Add("X-Forwarded-Uri", uri) r.Header.Add("X-Forwarded-Uri", uri)
return r return r
} }
func qsDiff(one, two url.Values) {
for k := range one {
if two.Get(k) == "" {
fmt.Printf("Key missing: %s\n", k)
}
if one.Get(k) != two.Get(k) {
fmt.Printf("Value different for %s: expected: '%s' got: '%s'\n", k, one.Get(k), two.Get(k))
}
}
for k := range two {
if one.Get(k) == "" {
fmt.Printf("Extra key: %s\n", k)
}
}
}
/**
* Tests
*/
func TestHandler(t *testing.T) { func TestHandler(t *testing.T) {
fw = &ForwardAuth{ fw = &ForwardAuth{
Path: "_oauth", Path: "_oauth",
ClientId: "idtest", ClientId: "idtest",
ClientSecret: "sectest", ClientSecret: "sectest",
Scope: "scopetest", Scope: "scopetest",
LoginURL: &url.URL{ LoginURL: &url.URL{
Scheme: "http", Scheme: "http",
Host: "test.com", Host: "test.com",
Path: "/auth", Path: "/auth",
}, },
CookieName: "cookie_test", CookieName: "cookie_test",
Lifetime: time.Second * time.Duration(10), Lifetime: time.Second * time.Duration(10),
} }
// Should redirect vanilla request to login url // Should redirect vanilla request to login url
req := newHttpRequest("foo") req := newHttpRequest("foo")
res, _ := httpRequest(req, nil) res, _ := httpRequest(req, nil)
if res.StatusCode != 307 { if res.StatusCode != 307 {
t.Error("Vanilla request should be redirected with 307, got:", res.StatusCode) t.Error("Vanilla request should be redirected with 307, got:", res.StatusCode)
} }
fwd, _ := res.Location() fwd, _ := res.Location()
if fwd.Scheme != "http" || fwd.Host != "test.com" || fwd.Path != "/auth" { if fwd.Scheme != "http" || fwd.Host != "test.com" || fwd.Path != "/auth" {
t.Error("Vanilla request should be redirected to login url, got:", fwd) t.Error("Vanilla request should be redirected to login url, got:", fwd)
} }
// Should catch invalid cookie // Should catch invalid cookie
req = newHttpRequest("foo") req = newHttpRequest("foo")
c := fw.MakeCookie(req, "test@example.com") c := fw.MakeCookie(req, "test@example.com")
parts := strings.Split(c.Value, "|") parts := strings.Split(c.Value, "|")
c.Value = fmt.Sprintf("bad|%s|%s", parts[1], parts[2]) c.Value = fmt.Sprintf("bad|%s|%s", parts[1], parts[2])
res, _ = httpRequest(req, c) res, _ = httpRequest(req, c)
if res.StatusCode != 401 { if res.StatusCode != 401 {
t.Error("Request with invalid cookie shound't be authorised", res.StatusCode) t.Error("Request with invalid cookie shound't be authorised", res.StatusCode)
} }
// Should validate email // Should validate email
req = newHttpRequest("foo") req = newHttpRequest("foo")
c = fw.MakeCookie(req, "test@example.com") c = fw.MakeCookie(req, "test@example.com")
fw.Domain = []string{"test.com"} fw.Domain = []string{"test.com"}
res, _ = httpRequest(req, c) res, _ = httpRequest(req, c)
if res.StatusCode != 401 { if res.StatusCode != 401 {
t.Error("Request with invalid cookie shound't be authorised", res.StatusCode) t.Error("Request with invalid cookie shound't be authorised", res.StatusCode)
} }
// Should allow valid request email // Should allow valid request email
req = newHttpRequest("foo") req = newHttpRequest("foo")
c = fw.MakeCookie(req, "test@example.com") c = fw.MakeCookie(req, "test@example.com")
fw.Domain = []string{} fw.Domain = []string{}
res, _ = httpRequest(req, c) res, _ = httpRequest(req, c)
if res.StatusCode != 200 { if res.StatusCode != 200 {
t.Error("Valid request should be allowed, got:", res.StatusCode) t.Error("Valid request should be allowed, got:", res.StatusCode)
} }
// Should pass through user
users := res.Header["X-Forwarded-User"]
if len(users) != 1 {
t.Error("Valid request missing X-Forwarded-User header")
} else if users[0] != "test@example.com" {
t.Error("X-Forwarded-User should match user, got: ", users)
}
} }
func TestCallback(t *testing.T) { func TestCallback(t *testing.T) {
fw = &ForwardAuth{ fw = &ForwardAuth{
Path: "_oauth", Path: "_oauth",
ClientId: "idtest", ClientId: "idtest",
ClientSecret: "sectest", ClientSecret: "sectest",
Scope: "scopetest", Scope: "scopetest",
LoginURL: &url.URL{ LoginURL: &url.URL{
Scheme: "http", Scheme: "http",
Host: "test.com", Host: "test.com",
Path: "/auth", Path: "/auth",
}, },
CSRFCookieName: "csrf_test", CSRFCookieName: "csrf_test",
} }
// Setup token server // Setup token server
tokenServerHandler := &TokenServerHandler{} tokenServerHandler := &TokenServerHandler{}
tokenServer := httptest.NewServer(tokenServerHandler) tokenServer := httptest.NewServer(tokenServerHandler)
defer tokenServer.Close() defer tokenServer.Close()
tokenUrl, _ := url.Parse(tokenServer.URL) tokenUrl, _ := url.Parse(tokenServer.URL)
fw.TokenURL = tokenUrl fw.TokenURL = tokenUrl
// Setup user server // Setup user server
userServerHandler := &UserServerHandler{} userServerHandler := &UserServerHandler{}
userServer := httptest.NewServer(userServerHandler) userServer := httptest.NewServer(userServerHandler)
defer userServer.Close() defer userServer.Close()
userUrl, _ := url.Parse(userServer.URL) userUrl, _ := url.Parse(userServer.URL)
fw.UserURL = userUrl fw.UserURL = userUrl
// Should pass auth response request to callback // Should pass auth response request to callback
req := newHttpRequest("_oauth") req := newHttpRequest("_oauth")
res, _ := httpRequest(req, nil) res, _ := httpRequest(req, nil)
if res.StatusCode != 401 { if res.StatusCode != 401 {
t.Error("Auth callback without cookie shound't be authorised, got:", res.StatusCode) t.Error("Auth callback without cookie shound't be authorised, got:", res.StatusCode)
} }
// Should catch invalid csrf cookie // Should catch invalid csrf cookie
req = newHttpRequest("_oauth?state=12345678901234567890123456789012:http://redirect") req = newHttpRequest("_oauth?state=12345678901234567890123456789012:http://redirect")
c := fw.MakeCSRFCookie(req, "nononononononononononononononono") c := fw.MakeCSRFCookie(req, "nononononononononononononononono")
res, _ = httpRequest(req, c) res, _ = httpRequest(req, c)
if res.StatusCode != 401 { if res.StatusCode != 401 {
t.Error("Auth callback with invalid cookie shound't be authorised, got:", res.StatusCode) t.Error("Auth callback with invalid cookie shound't be authorised, got:", res.StatusCode)
} }
// Should redirect valid request // Should redirect valid request
req = newHttpRequest("_oauth?state=12345678901234567890123456789012:http://redirect") req = newHttpRequest("_oauth?state=12345678901234567890123456789012:http://redirect")
c = fw.MakeCSRFCookie(req, "12345678901234567890123456789012") c = fw.MakeCSRFCookie(req, "12345678901234567890123456789012")
res, _ = httpRequest(req, c) res, _ = httpRequest(req, c)
if res.StatusCode != 307 { if res.StatusCode != 307 {
t.Error("Valid callback should be allowed, got:", res.StatusCode) t.Error("Valid callback should be allowed, got:", res.StatusCode)
} }
fwd, _ := res.Location() fwd, _ := res.Location()
if fwd.Scheme != "http" || fwd.Host != "redirect" || fwd.Path != "" { if fwd.Scheme != "http" || fwd.Host != "redirect" || fwd.Path != "" {
t.Error("Valid request should be redirected to return url, got:", fwd) t.Error("Valid request should be redirected to return url, got:", fwd)
} }
} }