Compare commits
34 Commits
v0.1.0
...
v2.0.0-bet
Author | SHA1 | Date | |
---|---|---|---|
d890a4aad6 | |||
43775591fa | |||
daec9f591a | |||
091590d391 | |||
8ca16a88d2 | |||
814892a88b | |||
19c249a6d1 | |||
0f278d516b | |||
ae95e8b2e5 | |||
5c800a0170 | |||
b1fdcc7f56 | |||
db31b09a72 | |||
e1d518db11 | |||
67339ae79a | |||
0b2889935e | |||
b3b31e2193 | |||
1a3a099ac1 | |||
afd8878188 | |||
6ccd1c6dfc | |||
df81be1147 | |||
5dcf889efe | |||
92d72dcdd2 | |||
4c1874b786 | |||
dcf4f6574d | |||
91775ff0a8 | |||
1832672f5e | |||
eaad0a9054 | |||
36fffd2382 | |||
ccbda4ec8c | |||
b014c5638a | |||
c897bc8387 | |||
96f9469abd | |||
b54871391f | |||
d230572879 |
2
.dockerignore
Normal file
2
.dockerignore
Normal file
@ -0,0 +1,2 @@
|
||||
example
|
||||
.travis.yml
|
@ -1,7 +1,5 @@
|
||||
language: go
|
||||
sudo: false
|
||||
go:
|
||||
- "1.10"
|
||||
install:
|
||||
- go get github.com/namsral/flag
|
||||
- go get github.com/op/go-logging
|
||||
- "1.12"
|
||||
script: env GO111MODULE=on go test -v ./...
|
||||
|
15
Dockerfile
15
Dockerfile
@ -1,18 +1,15 @@
|
||||
FROM golang:1.10-alpine as builder
|
||||
FROM golang:1.12-alpine as builder
|
||||
|
||||
# Setup
|
||||
RUN mkdir /app
|
||||
WORKDIR /app
|
||||
RUN mkdir -p /go/src/github.com/thomseddon/traefik-forward-auth
|
||||
WORKDIR /go/src/github.com/thomseddon/traefik-forward-auth
|
||||
|
||||
# Add libraries
|
||||
RUN apk add --no-cache git && \
|
||||
go get "github.com/namsral/flag" && \
|
||||
go get "github.com/op/go-logging" && \
|
||||
apk del git
|
||||
RUN apk add --no-cache git
|
||||
|
||||
# Copy & build
|
||||
ADD . /app/
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build -a -installsuffix nocgo -o /traefik-forward-auth .
|
||||
ADD . /go/src/github.com/thomseddon/traefik-forward-auth/
|
||||
RUN CGO_ENABLED=0 GOOS=linux GO111MODULE=on go build -a -installsuffix nocgo -o /traefik-forward-auth github.com/thomseddon/traefik-forward-auth/cmd
|
||||
|
||||
# Copy into scratch container
|
||||
FROM scratch
|
||||
|
5
Makefile
Normal file
5
Makefile
Normal file
@ -0,0 +1,5 @@
|
||||
|
||||
format:
|
||||
gofmt -w -s internal/*.go cmd/*.go
|
||||
|
||||
.PHONY: format
|
58
README.md
58
README.md
@ -1,5 +1,5 @@
|
||||
|
||||
# Traefik Forward Auth [](https://travis-ci.org/thomseddon/traefik-forward-auth)
|
||||
# Traefik Forward Auth [](https://travis-ci.org/thomseddon/traefik-forward-auth) [](https://goreportcard.com/badge/github.com/thomseddon/traefik-forward-auth)
|
||||
|
||||
A minimal forward authentication service that provides Google oauth based login and authentication for the traefik reverse proxy.
|
||||
|
||||
@ -24,16 +24,20 @@ The following configuration is supported:
|
||||
|-----------------------|------|-----------|
|
||||
|-client-id|string|*Google Client ID (required)|
|
||||
|-client-secret|string|*Google Client Secret (required)|
|
||||
|-secret|string|*Secret used for signing (required)|
|
||||
|-config|string|Path to config file|
|
||||
|-cookie-domains|string|Comma separated list of cookie domains|
|
||||
|-auth-host|string|Central auth login (see below)|
|
||||
|-cookie-domains|string|Comma separated list of cookie domains (see below)|
|
||||
|-cookie-name|string|Cookie Name (default "_forward_auth")|
|
||||
|-cookie-secret|string|*Cookie secret (required)|
|
||||
|-cookie-secure|bool|Use secure cookies (default true)|
|
||||
|-csrf-cookie-name|string|CSRF Cookie Name (default "_forward_auth_csrf")|
|
||||
|-direct|bool|Run in direct mode (use own hostname as oppose to <br>X-Forwarded-Host, used for testing/development)
|
||||
|-domain|string|Comma separated list of email domains to allow|
|
||||
|-whitelist|string|Comma separated list of email addresses to allow|
|
||||
|-lifetime|int|Session length in seconds (default 43200)|
|
||||
|-url-path|string|Callback URL (default "_oauth")|
|
||||
|-prompt|string|Space separated list of [OpenID prompt options](https://developers.google.com/identity/protocols/OpenIDConnect#prompt)|
|
||||
|-log-level|string|Log level: trace, debug, info, warn, error, fatal, panic (default "warn")|
|
||||
|-log-format|string|Log format: text, json, pretty (default "text")|
|
||||
|
||||
Configuration can also be supplied as environment variables (use upper case and swap `-`'s for `_`'s e.g. `-client-id` becomes `CLIENT_ID`)
|
||||
|
||||
@ -47,6 +51,19 @@ Create a new project then search for and select "Credentials" in the search bar.
|
||||
|
||||
Click, "Create Credentials" > "OAuth client ID". Select "Web Application", fill in the name of your app, skip "Authorized JavaScript origins" and fill "Authorized redirect URIs" with all the domains you will allow authentication from, appended with the `url-path` (e.g. https://app.test.com/_oauth)
|
||||
|
||||
## Usage
|
||||
|
||||
The authenticated user is set in the `X-Forwarded-User` header, to pass this on add this to the `authResponseHeaders` as shown [here](https://github.com/thomseddon/traefik-forward-auth/blob/master/example/docker-compose-dev.yml).
|
||||
|
||||
## User Restriction
|
||||
|
||||
You can restrict who can login with the following parameters:
|
||||
|
||||
* `-domain` - Use this to limit logins to a specific domain, e.g. test.com only
|
||||
* `-whitelist` - Use this to only allow specific users to login e.g. thom@test.com only
|
||||
|
||||
Note, if you pass `whitelist` then only this is checked and `domain` is effectively ignored.
|
||||
|
||||
## Cookie Domains
|
||||
|
||||
You can supply a comma separated list of cookie domains, if the host of the original request is a subdomain of any given cookie domain, the authentication cookie will set with the given domain.
|
||||
@ -55,6 +72,39 @@ For example, if cookie domain is `test.com` and a request comes in on `app1.test
|
||||
|
||||
Beware however, if using cookie domains whilst running multiple instances of traefik/traefik-forward-auth for the same domain, the cookies will clash. You can fix this by using the same `cookie-secret` in both instances, or using a different `cookie-name` on each.
|
||||
|
||||
## Operation Modes
|
||||
|
||||
#### Overlay
|
||||
|
||||
Overlay is the default operation mode, in this mode the authorisation endpoint is overlayed onto any domain. By default the `/_oauth` path is used, this can be customised using the `-url-path` option.
|
||||
|
||||
If a request comes in for `www.myapp.com/home` then the user will be redirected to the google login, following this they will be sent back to `www.myapp.com/_oauth`, where their token will be validated (this request will not be forwarded to your application). Following successful authoristion, the user will return to their originally requested url of `www.myapp.com/home`.
|
||||
|
||||
As the hostname in the `redirect_uri` is dynamically generated based on the orignal request, every hostname must be permitted in the Google OAuth console (e.g. `www.myappp.com` would need to be added in the above example)
|
||||
|
||||
#### Auth Host
|
||||
|
||||
This is an optional mode of operation that is useful when dealing with a large number of subdomains, it is activated by using the `-auth-host` config option (see [this example docker-compose.yml](https://github.com/thomseddon/traefik-forward-auth/blob/master/example/docker-compose-auth-host.yml)).
|
||||
|
||||
For example, if you have a few applications: `app1.test.com`, `app2.test.com`, `appN.test.com`, adding every domain to Google's console can become laborious.
|
||||
To utilise an auth host, permit domain level cookies by setting the cookie domain to `test.com` then set the `auth-host` to: `auth.test.com`.
|
||||
|
||||
The user flow will then be:
|
||||
|
||||
1. Request to `app10.test.com/home/page`
|
||||
2. User redirected to Google login
|
||||
3. After Google login, user is redirected to `auth.test.com/_oauth`
|
||||
4. Token, user and CSRF cookie is validated, auth cookie is set to `test.com`
|
||||
5. User is redirected to `app10.test.com/home/page`
|
||||
6. Request is allowed
|
||||
|
||||
With this setup, only `auth.test.com` must be permitted in the Google console.
|
||||
|
||||
Two criteria must be met for an `auth-host` to be used:
|
||||
|
||||
1. Request matches given `cookie-domain`
|
||||
2. `auth-host` is also subdomain of same `cookie-domain`
|
||||
|
||||
## Copyright
|
||||
|
||||
2018 Thom Seddon
|
||||
|
30
cmd/main.go
Normal file
30
cmd/main.go
Normal file
@ -0,0 +1,30 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
internal "github.com/thomseddon/traefik-forward-auth/internal"
|
||||
)
|
||||
|
||||
// Main
|
||||
func main() {
|
||||
// Parse options
|
||||
config := internal.NewGlobalConfig()
|
||||
|
||||
// Setup logger
|
||||
log := internal.NewDefaultLogger()
|
||||
|
||||
// Perform config validation
|
||||
config.Validate()
|
||||
|
||||
// Build server
|
||||
server := internal.NewServer()
|
||||
|
||||
// Attach router to default server
|
||||
http.HandleFunc("/", server.RootHandler)
|
||||
|
||||
// Start
|
||||
log.Debugf("Starting with options: %s", config)
|
||||
log.Info("Listening on :4181")
|
||||
log.Info(http.ListenAndServe(":4181", nil))
|
||||
}
|
44
example/docker-compose-auth-host.yml
Normal file
44
example/docker-compose-auth-host.yml
Normal file
@ -0,0 +1,44 @@
|
||||
version: '3'
|
||||
|
||||
services:
|
||||
traefik:
|
||||
image: traefik
|
||||
command: -c /traefik.toml --logLevel=DEBUG
|
||||
ports:
|
||||
- "8085:80"
|
||||
- "8086:8080"
|
||||
networks:
|
||||
- traefik
|
||||
volumes:
|
||||
- ./traefik.toml:/traefik.toml
|
||||
- /var/run/docker.sock:/var/run/docker.sock
|
||||
|
||||
whoami1:
|
||||
image: emilevauge/whoami
|
||||
networks:
|
||||
- traefik
|
||||
labels:
|
||||
- "traefik.backend=whoami"
|
||||
- "traefik.enable=true"
|
||||
- "traefik.frontend.rule=Host:whoami.yourdomain.com"
|
||||
|
||||
traefik-forward-auth:
|
||||
image: thomseddon/traefik-forward-auth
|
||||
environment:
|
||||
- CLIENT_ID=your-client-id
|
||||
- CLIENT_SECRET=your-client-secret
|
||||
- SECRET=something-random
|
||||
- COOKIE_SECURE=false
|
||||
- DOMAIN=yourcompany.com
|
||||
- AUTH_HOST=auth.yourdomain.com
|
||||
networks:
|
||||
- traefik
|
||||
# When using an auth host, adding it here prompts traefik to generate certs
|
||||
labels:
|
||||
- traefik.enable=true
|
||||
- traefik.port=4181
|
||||
- traefik.backend=traefik-forward-auth
|
||||
- traefik.frontend.rule=Host:auth.yourdomain.com
|
||||
|
||||
networks:
|
||||
traefik:
|
48
example/docker-compose-dev.yml
Normal file
48
example/docker-compose-dev.yml
Normal file
@ -0,0 +1,48 @@
|
||||
version: '3'
|
||||
|
||||
services:
|
||||
traefik:
|
||||
image: traefik
|
||||
command: -c /traefik.toml
|
||||
# command: -c /traefik.toml --logLevel=DEBUG
|
||||
ports:
|
||||
- "8085:80"
|
||||
- "8086:8080"
|
||||
networks:
|
||||
- traefik
|
||||
volumes:
|
||||
- ./traefik.toml:/traefik.toml
|
||||
- /var/run/docker.sock:/var/run/docker.sock
|
||||
|
||||
whoami1:
|
||||
image: emilevauge/whoami
|
||||
networks:
|
||||
- traefik
|
||||
labels:
|
||||
- "traefik.backend=whoami1"
|
||||
- "traefik.enable=true"
|
||||
- "traefik.frontend.rule=Host:whoami.localhost.com"
|
||||
|
||||
whoami2:
|
||||
image: emilevauge/whoami
|
||||
networks:
|
||||
- traefik
|
||||
labels:
|
||||
- "traefik.backend=whoami2"
|
||||
- "traefik.enable=true"
|
||||
- "traefik.frontend.rule=Host:whoami.localhost.org"
|
||||
|
||||
traefik-forward-auth:
|
||||
build: ../
|
||||
environment:
|
||||
- CLIENT_ID=test
|
||||
- CLIENT_SECRET=test
|
||||
- COOKIE_SECRET=something-random
|
||||
- COOKIE_SECURE=false
|
||||
- COOKIE_DOMAINS=localhost.com
|
||||
- AUTH_URL=http://auth.localhost.com:8085/_oauth
|
||||
networks:
|
||||
- traefik
|
||||
|
||||
networks:
|
||||
traefik:
|
@ -22,12 +22,12 @@ services:
|
||||
- "traefik.enable=true"
|
||||
- "traefik.frontend.rule=Host:whoami.localhost.com"
|
||||
|
||||
forward-oauth:
|
||||
traefik-forward-auth:
|
||||
image: thomseddon/traefik-forward-auth
|
||||
environment:
|
||||
- CLIENT_ID=your-client-id
|
||||
- CLIENT_SECRET=your-client-secret
|
||||
- COOKIE_SECRET=something-random
|
||||
- SECRET=something-random
|
||||
- COOKIE_SECURE=false
|
||||
- DOMAIN=yourcompany.com
|
||||
networks:
|
||||
|
@ -37,7 +37,8 @@
|
||||
address = ":80"
|
||||
|
||||
[entryPoints.http.auth.forward]
|
||||
address = "http://forward-oauth:4181"
|
||||
address = "http://traefik-forward-auth:4181"
|
||||
authResponseHeaders = ["X-Forwarded-User"]
|
||||
|
||||
################################################################
|
||||
# Traefik logs configuration
|
||||
|
358
forwardauth.go
358
forwardauth.go
@ -1,358 +0,0 @@
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
"errors"
|
||||
"strings"
|
||||
"strconv"
|
||||
"net/url"
|
||||
"net/http"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"encoding/base64"
|
||||
)
|
||||
|
||||
// Forward Auth
|
||||
type ForwardAuth struct {
|
||||
Path string
|
||||
Lifetime time.Duration
|
||||
|
||||
ClientId string
|
||||
ClientSecret string
|
||||
Scope string
|
||||
|
||||
LoginURL *url.URL
|
||||
TokenURL *url.URL
|
||||
UserURL *url.URL
|
||||
|
||||
CookieName string
|
||||
CookieDomains []CookieDomain
|
||||
CSRFCookieName string
|
||||
CookieSecret []byte
|
||||
CookieSecure bool
|
||||
|
||||
Domain []string
|
||||
|
||||
Direct bool
|
||||
}
|
||||
|
||||
// Request Validation
|
||||
|
||||
// Cookie = hash(secret, cookie domain, email, expires)|expires|email
|
||||
func (f *ForwardAuth) ValidateCookie(r *http.Request, c *http.Cookie) (bool, string, error) {
|
||||
parts := strings.Split(c.Value, "|")
|
||||
|
||||
if len(parts) != 3 {
|
||||
return false, "", errors.New("Invalid cookie format")
|
||||
}
|
||||
|
||||
mac, err := base64.URLEncoding.DecodeString(parts[0])
|
||||
if err != nil {
|
||||
return false, "", errors.New("Unable to decode cookie mac")
|
||||
}
|
||||
|
||||
expectedSignature := f.cookieSignature(r, parts[2], parts[1])
|
||||
expected, err := base64.URLEncoding.DecodeString(expectedSignature)
|
||||
if err != nil {
|
||||
return false, "", errors.New("Unable to generate mac")
|
||||
}
|
||||
|
||||
// Valid token?
|
||||
if !hmac.Equal(mac, expected) {
|
||||
return false, "", errors.New("Invalid cookie mac")
|
||||
}
|
||||
|
||||
expires, err := strconv.ParseInt(parts[1], 10, 64)
|
||||
if err != nil {
|
||||
return false, "", errors.New("Unable to parse cookie expiry")
|
||||
}
|
||||
|
||||
// Has it expired?
|
||||
if time.Unix(expires, 0).Before(time.Now()) {
|
||||
return false, "", errors.New("Cookie has expired")
|
||||
}
|
||||
|
||||
// Looks valid
|
||||
return true, parts[2], nil
|
||||
}
|
||||
|
||||
// Validate email
|
||||
func (f *ForwardAuth) ValidateEmail(email string) bool {
|
||||
if len(f.Domain) > 0 {
|
||||
parts := strings.Split(email, "@")
|
||||
if len(parts) < 2 {
|
||||
return false
|
||||
}
|
||||
found := false
|
||||
for _, domain := range f.Domain {
|
||||
if domain == parts[1] {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
|
||||
// OAuth Methods
|
||||
|
||||
// Get login url
|
||||
func (f *ForwardAuth) GetLoginURL(r *http.Request, nonce string) string {
|
||||
state := fmt.Sprintf("%s:%s", nonce, f.returnUrl(r))
|
||||
|
||||
q := url.Values{}
|
||||
q.Set("client_id", fw.ClientId)
|
||||
q.Set("response_type", "code")
|
||||
q.Set("scope", fw.Scope)
|
||||
// q.Set("approval_prompt", fw.ClientId)
|
||||
q.Set("redirect_uri", f.redirectUri(r))
|
||||
q.Set("state", state)
|
||||
|
||||
var u url.URL
|
||||
u = *fw.LoginURL
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
return u.String()
|
||||
}
|
||||
|
||||
// Exchange code for token
|
||||
|
||||
type Token struct {
|
||||
Token string `json:"access_token"`
|
||||
}
|
||||
|
||||
func (f *ForwardAuth) ExchangeCode(r *http.Request, code string) (string, error) {
|
||||
form := url.Values{}
|
||||
form.Set("client_id", fw.ClientId)
|
||||
form.Set("client_secret", fw.ClientSecret)
|
||||
form.Set("grant_type", "authorization_code")
|
||||
form.Set("redirect_uri", f.redirectUri(r))
|
||||
form.Set("code", code)
|
||||
|
||||
|
||||
res, err := http.PostForm(fw.TokenURL.String(), form)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var token Token
|
||||
defer res.Body.Close()
|
||||
err = json.NewDecoder(res.Body).Decode(&token)
|
||||
|
||||
return token.Token, err
|
||||
}
|
||||
|
||||
// Get user with token
|
||||
|
||||
type User struct {
|
||||
Id string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Verified bool `json:"verified_email"`
|
||||
Hd string `json:"hd"`
|
||||
}
|
||||
|
||||
func (f *ForwardAuth) GetUser(token string) (User, error) {
|
||||
var user User
|
||||
|
||||
client := &http.Client{}
|
||||
req, err := http.NewRequest("GET", fw.UserURL.String(), nil)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
|
||||
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", token))
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
|
||||
defer res.Body.Close()
|
||||
err = json.NewDecoder(res.Body).Decode(&user)
|
||||
|
||||
return user, err
|
||||
}
|
||||
|
||||
// Utility methods
|
||||
|
||||
// Get the redirect base
|
||||
func (f *ForwardAuth) redirectBase(r *http.Request) string {
|
||||
proto := r.Header.Get("X-Forwarded-Proto")
|
||||
host := r.Header.Get("X-Forwarded-Host")
|
||||
|
||||
// Direct mode
|
||||
if f.Direct {
|
||||
proto = "http"
|
||||
host = r.Host
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s://%s", proto, host)
|
||||
}
|
||||
|
||||
// Return url
|
||||
func (f *ForwardAuth) returnUrl(r *http.Request) string {
|
||||
path := r.Header.Get("X-Forwarded-Uri")
|
||||
|
||||
// Testing
|
||||
if f.Direct {
|
||||
path = r.URL.String()
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s%s", f.redirectBase(r), path)
|
||||
}
|
||||
|
||||
// Get oauth redirect uri
|
||||
func (f *ForwardAuth) redirectUri(r *http.Request) string {
|
||||
return fmt.Sprintf("%s%s", f.redirectBase(r), f.Path)
|
||||
}
|
||||
|
||||
// Cookie methods
|
||||
|
||||
// Create an auth cookie
|
||||
func (f *ForwardAuth) MakeCookie(r *http.Request, email string) *http.Cookie {
|
||||
expires := f.cookieExpiry()
|
||||
mac := f.cookieSignature(r, email, fmt.Sprintf("%d", expires.Unix()))
|
||||
value := fmt.Sprintf("%s|%d|%s", mac, expires.Unix(), email)
|
||||
|
||||
return &http.Cookie{
|
||||
Name: f.CookieName,
|
||||
Value: value,
|
||||
Path: "/",
|
||||
Domain: f.cookieDomain(r),
|
||||
HttpOnly: true,
|
||||
Secure: f.CookieSecure,
|
||||
Expires: expires,
|
||||
}
|
||||
}
|
||||
|
||||
// Make a CSRF cookie (used during login only)
|
||||
func (f *ForwardAuth) MakeCSRFCookie(r *http.Request, nonce string) *http.Cookie {
|
||||
return &http.Cookie{
|
||||
Name: f.CSRFCookieName,
|
||||
Value: nonce,
|
||||
Path: "/",
|
||||
Domain: f.cookieDomain(r),
|
||||
HttpOnly: true,
|
||||
Secure: f.CookieSecure,
|
||||
Expires: f.cookieExpiry(),
|
||||
}
|
||||
}
|
||||
|
||||
// Create a cookie to clear csrf cookie
|
||||
func (f *ForwardAuth) ClearCSRFCookie(r *http.Request) *http.Cookie {
|
||||
return &http.Cookie{
|
||||
Name: f.CSRFCookieName,
|
||||
Value: "",
|
||||
Path: "/",
|
||||
Domain: f.cookieDomain(r),
|
||||
HttpOnly: true,
|
||||
Secure: f.CookieSecure,
|
||||
Expires: time.Now().Local().Add(time.Hour * -1),
|
||||
}
|
||||
}
|
||||
|
||||
// Validate the csrf cookie against state
|
||||
func (f *ForwardAuth) ValidateCSRFCookie(c *http.Cookie, state string) (bool, string, error) {
|
||||
if len(c.Value) != 32 {
|
||||
return false, "", errors.New("Invalid CSRF cookie value")
|
||||
}
|
||||
|
||||
if len(state) < 34 {
|
||||
return false, "", errors.New("Invalid CSRF state value")
|
||||
}
|
||||
|
||||
// Check nonce match
|
||||
if c.Value != state[:32] {
|
||||
return false, "", errors.New("CSRF cookie does not match state")
|
||||
}
|
||||
|
||||
// Valid, return redirect
|
||||
return true, state[33:], nil
|
||||
}
|
||||
|
||||
func (f *ForwardAuth) Nonce() (error, string) {
|
||||
// Make nonce
|
||||
nonce := make([]byte, 16)
|
||||
_, err := rand.Read(nonce)
|
||||
if err != nil {
|
||||
return err, ""
|
||||
}
|
||||
|
||||
return nil, fmt.Sprintf("%x", nonce)
|
||||
}
|
||||
|
||||
// Cookie domain
|
||||
func (f *ForwardAuth) cookieDomain(r *http.Request) string {
|
||||
host := r.Header.Get("X-Forwarded-Host")
|
||||
|
||||
// Direct mode
|
||||
if f.Direct {
|
||||
host = r.Host
|
||||
}
|
||||
|
||||
// Remove port for matching
|
||||
p := strings.Split(host, ":")
|
||||
|
||||
// Check if any of the given cookie domains matches
|
||||
for _, domain := range f.CookieDomains {
|
||||
if domain.Match(p[0]) {
|
||||
return domain.Domain
|
||||
}
|
||||
}
|
||||
|
||||
return p[0]
|
||||
}
|
||||
|
||||
// Create cookie hmac
|
||||
func (f *ForwardAuth) cookieSignature(r *http.Request, email, expires string) string {
|
||||
hash := hmac.New(sha256.New, f.CookieSecret)
|
||||
hash.Write([]byte(f.cookieDomain(r)))
|
||||
hash.Write([]byte(email))
|
||||
hash.Write([]byte(expires))
|
||||
return base64.URLEncoding.EncodeToString(hash.Sum(nil))
|
||||
}
|
||||
|
||||
// Get cookie expirary
|
||||
func (f *ForwardAuth) cookieExpiry() time.Time {
|
||||
return time.Now().Local().Add(f.Lifetime)
|
||||
}
|
||||
|
||||
// Cookie Domain
|
||||
|
||||
// Cookie Domain
|
||||
type CookieDomain struct {
|
||||
Domain string
|
||||
DomainLen int
|
||||
SubDomain string
|
||||
SubDomainLen int
|
||||
}
|
||||
|
||||
func NewCookieDomain(domain string) *CookieDomain {
|
||||
return &CookieDomain{
|
||||
Domain: domain,
|
||||
DomainLen: len(domain),
|
||||
SubDomain: fmt.Sprintf(".%s", domain),
|
||||
SubDomainLen: len(domain) + 1,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *CookieDomain) Match(host string) bool {
|
||||
// Exact domain match?
|
||||
if host == c.Domain {
|
||||
return true
|
||||
}
|
||||
|
||||
// Subdomain match?
|
||||
if len(host) >= c.SubDomainLen && host[len(host) - c.SubDomainLen:] == c.SubDomain {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
@ -1,242 +0,0 @@
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
// "fmt"
|
||||
"time"
|
||||
"reflect"
|
||||
"testing"
|
||||
"net/url"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func TestValidateCookie(t *testing.T) {
|
||||
fw = &ForwardAuth{}
|
||||
r, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
c := &http.Cookie{}
|
||||
|
||||
// Should require 3 parts
|
||||
c.Value = ""
|
||||
valid, _, err := fw.ValidateCookie(r, c)
|
||||
if valid || err.Error() != "Invalid cookie format" {
|
||||
t.Error("Should get \"Invalid cookie format\", got:", err)
|
||||
}
|
||||
c.Value = "1|2"
|
||||
valid, _, err = fw.ValidateCookie(r, c)
|
||||
if valid || err.Error() != "Invalid cookie format" {
|
||||
t.Error("Should get \"Invalid cookie format\", got:", err)
|
||||
}
|
||||
c.Value = "1|2|3|4"
|
||||
valid, _, err = fw.ValidateCookie(r, c)
|
||||
if valid || err.Error() != "Invalid cookie format" {
|
||||
t.Error("Should get \"Invalid cookie format\", got:", err)
|
||||
}
|
||||
|
||||
// Should catch invalid mac
|
||||
c.Value = "MQ==|2|3"
|
||||
valid, _, err = fw.ValidateCookie(r, c)
|
||||
if valid || err.Error() != "Invalid cookie mac" {
|
||||
t.Error("Should get \"Invalid cookie mac\", got:", err)
|
||||
}
|
||||
|
||||
// Should catch expired
|
||||
fw.Lifetime = time.Second * time.Duration(-1)
|
||||
c = fw.MakeCookie(r, "test@test.com")
|
||||
valid, _, err = fw.ValidateCookie(r, c)
|
||||
if valid || err.Error() != "Cookie has expired" {
|
||||
t.Error("Should get \"Cookie has expired\", got:", err)
|
||||
}
|
||||
|
||||
// Should accept valid cookie
|
||||
fw.Lifetime = time.Second * time.Duration(10)
|
||||
c = fw.MakeCookie(r, "test@test.com")
|
||||
valid, email, err := fw.ValidateCookie(r, c)
|
||||
if !valid {
|
||||
t.Error("Valid request should return as valid")
|
||||
}
|
||||
if err != nil {
|
||||
t.Error("Valid request should not return error, got:", err)
|
||||
}
|
||||
if email != "test@test.com" {
|
||||
t.Error("Valid request should return user email")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateEmail(t *testing.T) {
|
||||
fw = &ForwardAuth{}
|
||||
|
||||
// Should allow any
|
||||
if !fw.ValidateEmail("test@test.com") || !fw.ValidateEmail("one@two.com") {
|
||||
t.Error("Should allow any domain if email domain is not defined")
|
||||
}
|
||||
|
||||
// Should block non matching domain
|
||||
fw.Domain = []string{"test.com"}
|
||||
if fw.ValidateEmail("one@two.com") {
|
||||
t.Error("Should not allow user from another domain")
|
||||
}
|
||||
|
||||
// Should allow matching domain
|
||||
fw.Domain = []string{"test.com"}
|
||||
if !fw.ValidateEmail("test@test.com") {
|
||||
t.Error("Should allow user from allowed domain")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetLoginURL(t *testing.T) {
|
||||
fw = &ForwardAuth{
|
||||
Path: "/_oauth",
|
||||
ClientId: "idtest",
|
||||
ClientSecret: "sectest",
|
||||
Scope: "scopetest",
|
||||
LoginURL: &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "test.com",
|
||||
Path: "/auth",
|
||||
},
|
||||
}
|
||||
r, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
r.Header.Add("X-Forwarded-Proto", "http")
|
||||
r.Header.Add("X-Forwarded-Host", "example.com")
|
||||
r.Header.Add("X-Forwarded-Uri", "/hello")
|
||||
|
||||
// Check url
|
||||
uri, err := url.Parse(fw.GetLoginURL(r, "nonce"))
|
||||
if err != nil {
|
||||
t.Error("Error parsing login url:", err)
|
||||
}
|
||||
if uri.Scheme != "https" {
|
||||
t.Error("Expected login Scheme to be \"https\", got:", uri.Scheme)
|
||||
}
|
||||
if uri.Host != "test.com" {
|
||||
t.Error("Expected login Host to be \"test.com\", got:", uri.Host)
|
||||
}
|
||||
if uri.Path != "/auth" {
|
||||
t.Error("Expected login Path to be \"/auth\", got:", uri.Path)
|
||||
}
|
||||
|
||||
// Check query string
|
||||
qs := uri.Query()
|
||||
expectedQs := url.Values{
|
||||
"client_id": []string{"idtest"},
|
||||
"redirect_uri": []string{"http://example.com/_oauth"},
|
||||
"response_type": []string{"code"},
|
||||
"scope": []string{"scopetest"},
|
||||
"state": []string{"nonce:http://example.com/hello"},
|
||||
}
|
||||
if !reflect.DeepEqual(qs, expectedQs) {
|
||||
t.Error("Incorrect login query string, expected:")
|
||||
t.Error(expectedQs)
|
||||
t.Error("Got:")
|
||||
t.Error(qs)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// TODO
|
||||
// func TestExchangeCode(t *testing.T) {
|
||||
// }
|
||||
|
||||
// TODO
|
||||
// func TestGetUser(t *testing.T) {
|
||||
// }
|
||||
|
||||
// TODO? Tested in TestValidateCookie
|
||||
// func TestMakeCookie(t *testing.T) {
|
||||
// }
|
||||
|
||||
// func TestMakeCSRFCookie(t *testing.T) {
|
||||
// t.Log("TODO")
|
||||
// }
|
||||
|
||||
func TestClearCSRFCookie(t *testing.T) {
|
||||
fw = &ForwardAuth{}
|
||||
r, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
|
||||
c := fw.ClearCSRFCookie(r)
|
||||
if c.Value != "" {
|
||||
t.Error("ClearCSRFCookie should create cookie with empty value")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateCSRFCookie(t *testing.T) {
|
||||
fw = &ForwardAuth{}
|
||||
c := &http.Cookie{}
|
||||
|
||||
// Should require 32 char string
|
||||
c.Value = ""
|
||||
valid, _, err := fw.ValidateCSRFCookie(c, "")
|
||||
if valid || err.Error() != "Invalid CSRF cookie value" {
|
||||
t.Error("Should get \"Invalid CSRF cookie value\", got:", err)
|
||||
}
|
||||
c.Value = "123456789012345678901234567890123"
|
||||
valid, _, err = fw.ValidateCSRFCookie(c, "")
|
||||
if valid || err.Error() != "Invalid CSRF cookie value" {
|
||||
t.Error("Should get \"Invalid CSRF cookie value\", got:", err)
|
||||
}
|
||||
|
||||
// Should require valid state
|
||||
c.Value = "12345678901234567890123456789012"
|
||||
valid, _, err = fw.ValidateCSRFCookie(c, "12345678901234567890123456789012:")
|
||||
if valid || err.Error() != "Invalid CSRF state value" {
|
||||
t.Error("Should get \"Invalid CSRF state value\", got:", err)
|
||||
}
|
||||
|
||||
// Should allow valid state
|
||||
c.Value = "12345678901234567890123456789012"
|
||||
valid, state, err := fw.ValidateCSRFCookie(c, "12345678901234567890123456789012:99")
|
||||
if !valid {
|
||||
t.Error("Valid request should return as valid")
|
||||
}
|
||||
if err != nil {
|
||||
t.Error("Valid request should not return error, got:", err)
|
||||
}
|
||||
if state != "99" {
|
||||
t.Error("Valid request should return correct state, got:", state)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNonce(t *testing.T) {
|
||||
fw = &ForwardAuth{}
|
||||
|
||||
err, nonce1 := fw.Nonce()
|
||||
if err != nil {
|
||||
t.Error("Error generation nonce:", err)
|
||||
}
|
||||
|
||||
err, nonce2 := fw.Nonce()
|
||||
if err != nil {
|
||||
t.Error("Error generation nonce:", err)
|
||||
}
|
||||
|
||||
if len(nonce1) != 32 || len(nonce2) != 32 {
|
||||
t.Error("Nonce should be 32 chars")
|
||||
}
|
||||
if nonce1 == nonce2 {
|
||||
t.Error("Nonce should not be equal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCookieDomainMatch(t *testing.T) {
|
||||
cd := NewCookieDomain("example.com")
|
||||
|
||||
// Exact should match
|
||||
if !cd.Match("example.com") {
|
||||
t.Error("Exact domain should match")
|
||||
}
|
||||
|
||||
// Subdomain should match
|
||||
if !cd.Match("test.example.com") {
|
||||
t.Error("Subdomain should match")
|
||||
}
|
||||
|
||||
// Derived domain should not match
|
||||
if cd.Match("testexample.com") {
|
||||
t.Error("Derived domain should not match")
|
||||
}
|
||||
|
||||
// Other domain should not match
|
||||
if cd.Match("test.com") {
|
||||
t.Error("Other domain should not match")
|
||||
}
|
||||
}
|
31
go.mod
Normal file
31
go.mod
Normal file
@ -0,0 +1,31 @@
|
||||
module github.com/thomseddon/traefik-forward-auth
|
||||
|
||||
go 1.12
|
||||
|
||||
require (
|
||||
github.com/VividCortex/gohistogram v1.0.0 // indirect
|
||||
github.com/cenkalti/backoff v2.1.1+incompatible // indirect
|
||||
github.com/containous/alice v0.0.0-20181107144136-d83ebdd94cbd // indirect
|
||||
github.com/containous/flaeg v1.4.1 // indirect
|
||||
github.com/containous/mux v0.0.0-20181024131434-c33f32e26898 // indirect
|
||||
github.com/containous/traefik v2.0.0-alpha2+incompatible
|
||||
github.com/go-acme/lego v2.4.0+incompatible // indirect
|
||||
github.com/go-kit/kit v0.8.0 // indirect
|
||||
github.com/gorilla/context v1.1.1 // indirect
|
||||
github.com/gravitational/trace v0.0.0-20190409171327-f30095ced5ff // indirect
|
||||
github.com/jessevdk/go-flags v1.4.0
|
||||
github.com/jonboulle/clockwork v0.1.0 // indirect
|
||||
github.com/kr/pretty v0.1.0 // indirect
|
||||
github.com/miekg/dns v1.1.8 // indirect
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible // indirect
|
||||
github.com/pkg/errors v0.8.1 // indirect
|
||||
github.com/ryanuber/go-glob v1.0.0 // indirect
|
||||
github.com/sirupsen/logrus v1.4.1
|
||||
github.com/stretchr/testify v1.3.0 // indirect
|
||||
github.com/vulcand/predicate v1.1.0 // indirect
|
||||
golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a // indirect
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 // indirect
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f // indirect
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
|
||||
gopkg.in/square/go-jose.v2 v2.3.1 // indirect
|
||||
)
|
70
go.sum
Normal file
70
go.sum
Normal file
@ -0,0 +1,70 @@
|
||||
github.com/VividCortex/gohistogram v1.0.0 h1:6+hBz+qvs0JOrrNhhmR7lFxo5sINxBCGXrdtl/UvroE=
|
||||
github.com/VividCortex/gohistogram v1.0.0/go.mod h1:Pf5mBqqDxYaXu3hDrrU+w6nw50o/4+TcAqDqk/vUH7g=
|
||||
github.com/cenkalti/backoff v2.1.1+incompatible h1:tKJnvO2kl0zmb/jA5UKAt4VoEVw1qxKWjE/Bpp46npY=
|
||||
github.com/cenkalti/backoff v2.1.1+incompatible/go.mod h1:90ReRw6GdpyfrHakVjL/QHaoyV4aDUVVkXQJJJ3NXXM=
|
||||
github.com/containous/alice v0.0.0-20181107144136-d83ebdd94cbd h1:0n+lFLh5zU0l6KSk3KpnDwfbPGAR44aRLgTbCnhRBHU=
|
||||
github.com/containous/alice v0.0.0-20181107144136-d83ebdd94cbd/go.mod h1:BbQgeDS5i0tNvypwEoF1oNjOJw8knRAE1DnVvjDstcQ=
|
||||
github.com/containous/flaeg v1.4.1 h1:VTouP7EF2JeowNvknpP3fJAJLUDsQ1lDHq/QQTQc1xc=
|
||||
github.com/containous/flaeg v1.4.1/go.mod h1:wgw6PDtRURXHKFFV6HOqQxWhUc3k3Hmq22jw+n2qDro=
|
||||
github.com/containous/mux v0.0.0-20181024131434-c33f32e26898 h1:1srn9voikJGofblBhWy3WuZWqo14Ou7NaswNG/I2yWc=
|
||||
github.com/containous/mux v0.0.0-20181024131434-c33f32e26898/go.mod h1:z8WW7n06n8/1xF9Jl9WmuDeZuHAhfL+bwarNjsciwwg=
|
||||
github.com/containous/traefik v2.0.0-alpha2+incompatible h1:5RS6mUAOPQCy1jAmcmxLj2nChIcs3fKuxZxH9AF6ih8=
|
||||
github.com/containous/traefik v2.0.0-alpha2+incompatible/go.mod h1:epDRqge3JzKOhlSWzOpNYEEKXmM6yfN5tPzDGKk3ljo=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/go-acme/lego v2.4.0+incompatible h1:+BTLUfLtDc5qQauyiTCXH6lupEUOCvXyGlEjdeU0YQI=
|
||||
github.com/go-acme/lego v2.4.0+incompatible/go.mod h1:yzMNe9CasVUhkquNvti5nAtPmG94USbYxYrZfTkIn0M=
|
||||
github.com/go-kit/kit v0.8.0 h1:Wz+5lgoB0kkuqLEc6NVmwRknTKP6dTGbSqvhZtBI/j0=
|
||||
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
|
||||
github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8=
|
||||
github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
|
||||
github.com/gravitational/trace v0.0.0-20190409171327-f30095ced5ff h1:xL/fJdlTJL6R/6Qk2tPu3EP1NsXgap9hXLvxKH0Ytko=
|
||||
github.com/gravitational/trace v0.0.0-20190409171327-f30095ced5ff/go.mod h1:RvdOUHE4SHqR3oXlFFKnGzms8a5dugHygGw1bqDstYI=
|
||||
github.com/jessevdk/go-flags v1.4.0 h1:4IU2WS7AumrZ/40jfhf4QVDMsQwqA7VEHozFRrGARJA=
|
||||
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
|
||||
github.com/jonboulle/clockwork v0.1.0 h1:VKV+ZcuP6l3yW9doeqz6ziZGgcynBVQO+obU0+0hcPo=
|
||||
github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo=
|
||||
github.com/konsorten/go-windows-terminal-sequences v1.0.1 h1:mweAR1A6xJ3oS2pRaGiHgQ4OO8tzTaLawm8vnODuwDk=
|
||||
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
||||
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/miekg/dns v1.1.8 h1:1QYRAKU3lN5cRfLCkPU08hwvLJFhvjP6MqNMmQz6ZVI=
|
||||
github.com/miekg/dns v1.1.8/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg=
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
|
||||
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/ryanuber/go-glob v1.0.0 h1:iQh3xXAumdQ+4Ufa5b25cRpC5TYKlno6hsv6Cb3pkBk=
|
||||
github.com/ryanuber/go-glob v1.0.0/go.mod h1:807d1WSdnB0XRJzKNil9Om6lcp/3a0v4qIHxIXzX/Yc=
|
||||
github.com/sirupsen/logrus v1.4.1 h1:GL2rEmy6nsikmW0r8opw9JIRScdMF5hA8cOYLH7In1k=
|
||||
github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/vulcand/predicate v1.1.0 h1:Gq/uWopa4rx/tnZu2opOSBqHK63Yqlou/SzrbwdJiNg=
|
||||
github.com/vulcand/predicate v1.1.0/go.mod h1:mlccC5IRBoc2cIFmCB8ZM62I3VDb6p2GXESMHa3CnZg=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a h1:Igim7XhdOpBnWPuYJ70XcNpq8q3BCACtVgNfoJxOV7g=
|
||||
golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 h1:0GoQqolDA55aaLxZyTzK/Y2ePZzZTUrRacwib7cNsYQ=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f h1:wMNYb4v58l5UBM7MYRLPG6ZhfOqbKu7X5eyFl8ZhKvA=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33 h1:I6FyU15t786LL7oL/hn43zqTuEGr4PN7F4XJ1p4E3Y8=
|
||||
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e h1:nFYrTHrdrAOpShe27kaFHjsqYSEQ0KWqdWLu3xuZJts=
|
||||
golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/square/go-jose.v2 v2.3.1 h1:SK5KegNXmKmqE342YYN2qPHEnUYeoMiXXl1poUlI+o4=
|
||||
gopkg.in/square/go-jose.v2 v2.3.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
|
333
internal/auth.go
Normal file
333
internal/auth.go
Normal file
@ -0,0 +1,333 @@
|
||||
package tfa
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/thomseddon/traefik-forward-auth/internal/provider"
|
||||
)
|
||||
|
||||
// Request Validation
|
||||
|
||||
// Cookie = hash(secret, cookie domain, email, expires)|expires|email
|
||||
func ValidateCookie(r *http.Request, c *http.Cookie) (bool, string, error) {
|
||||
parts := strings.Split(c.Value, "|")
|
||||
|
||||
if len(parts) != 3 {
|
||||
return false, "", errors.New("Invalid cookie format")
|
||||
}
|
||||
|
||||
mac, err := base64.URLEncoding.DecodeString(parts[0])
|
||||
if err != nil {
|
||||
return false, "", errors.New("Unable to decode cookie mac")
|
||||
}
|
||||
|
||||
expectedSignature := cookieSignature(r, parts[2], parts[1])
|
||||
expected, err := base64.URLEncoding.DecodeString(expectedSignature)
|
||||
if err != nil {
|
||||
return false, "", errors.New("Unable to generate mac")
|
||||
}
|
||||
|
||||
// Valid token?
|
||||
if !hmac.Equal(mac, expected) {
|
||||
return false, "", errors.New("Invalid cookie mac")
|
||||
}
|
||||
|
||||
expires, err := strconv.ParseInt(parts[1], 10, 64)
|
||||
if err != nil {
|
||||
return false, "", errors.New("Unable to parse cookie expiry")
|
||||
}
|
||||
|
||||
// Has it expired?
|
||||
if time.Unix(expires, 0).Before(time.Now()) {
|
||||
return false, "", errors.New("Cookie has expired")
|
||||
}
|
||||
|
||||
// Looks valid
|
||||
return true, parts[2], nil
|
||||
}
|
||||
|
||||
// Validate email
|
||||
func ValidateEmail(email string) bool {
|
||||
found := false
|
||||
if len(config.Whitelist) > 0 {
|
||||
for _, whitelist := range config.Whitelist {
|
||||
if email == whitelist {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
} else if len(config.Domains) > 0 {
|
||||
parts := strings.Split(email, "@")
|
||||
if len(parts) < 2 {
|
||||
return false
|
||||
}
|
||||
for _, domain := range config.Domains {
|
||||
if domain == parts[1] {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return true
|
||||
}
|
||||
|
||||
return found
|
||||
}
|
||||
|
||||
// OAuth Methods
|
||||
|
||||
// Get login url
|
||||
func GetLoginURL(r *http.Request, nonce string) string {
|
||||
state := fmt.Sprintf("%s:%s", nonce, returnUrl(r))
|
||||
|
||||
// TODO: Support multiple providers
|
||||
return config.Providers.Google.GetLoginURL(redirectUri(r), state)
|
||||
}
|
||||
|
||||
// Exchange code for token
|
||||
|
||||
func ExchangeCode(r *http.Request) (string, error) {
|
||||
code := r.URL.Query().Get("code")
|
||||
|
||||
// TODO: Support multiple providers
|
||||
return config.Providers.Google.ExchangeCode(redirectUri(r), code)
|
||||
}
|
||||
|
||||
// Get user with token
|
||||
|
||||
func GetUser(token string) (provider.User, error) {
|
||||
// TODO: Support multiple providers
|
||||
return config.Providers.Google.GetUser(token)
|
||||
}
|
||||
|
||||
// Utility methods
|
||||
|
||||
// Get the redirect base
|
||||
func redirectBase(r *http.Request) string {
|
||||
proto := r.Header.Get("X-Forwarded-Proto")
|
||||
host := r.Header.Get("X-Forwarded-Host")
|
||||
|
||||
return fmt.Sprintf("%s://%s", proto, host)
|
||||
}
|
||||
|
||||
// // Return url
|
||||
func returnUrl(r *http.Request) string {
|
||||
path := r.Header.Get("X-Forwarded-Uri")
|
||||
|
||||
return fmt.Sprintf("%s%s", redirectBase(r), path)
|
||||
}
|
||||
|
||||
// Get oauth redirect uri
|
||||
func redirectUri(r *http.Request) string {
|
||||
if use, _ := useAuthDomain(r); use {
|
||||
proto := r.Header.Get("X-Forwarded-Proto")
|
||||
return fmt.Sprintf("%s://%s%s", proto, config.AuthHost, config.Path)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s%s", redirectBase(r), config.Path)
|
||||
}
|
||||
|
||||
// Should we use auth host + what it is
|
||||
func useAuthDomain(r *http.Request) (bool, string) {
|
||||
if config.AuthHost == "" {
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// Does the request match a given cookie domain?
|
||||
reqMatch, reqHost := matchCookieDomains(r.Header.Get("X-Forwarded-Host"))
|
||||
|
||||
// Do any of the auth hosts match a cookie domain?
|
||||
authMatch, authHost := matchCookieDomains(config.AuthHost)
|
||||
|
||||
// We need both to match the same domain
|
||||
return reqMatch && authMatch && reqHost == authHost, reqHost
|
||||
}
|
||||
|
||||
// Cookie methods
|
||||
|
||||
// Create an auth cookie
|
||||
func MakeCookie(r *http.Request, email string) *http.Cookie {
|
||||
expires := cookieExpiry()
|
||||
mac := cookieSignature(r, email, fmt.Sprintf("%d", expires.Unix()))
|
||||
value := fmt.Sprintf("%s|%d|%s", mac, expires.Unix(), email)
|
||||
|
||||
return &http.Cookie{
|
||||
Name: config.CookieName,
|
||||
Value: value,
|
||||
Path: "/",
|
||||
Domain: cookieDomain(r),
|
||||
HttpOnly: true,
|
||||
Secure: !config.InsecureCookie,
|
||||
Expires: expires,
|
||||
}
|
||||
}
|
||||
|
||||
// Make a CSRF cookie (used during login only)
|
||||
func MakeCSRFCookie(r *http.Request, nonce string) *http.Cookie {
|
||||
return &http.Cookie{
|
||||
Name: config.CSRFCookieName,
|
||||
Value: nonce,
|
||||
Path: "/",
|
||||
Domain: csrfCookieDomain(r),
|
||||
HttpOnly: true,
|
||||
Secure: !config.InsecureCookie,
|
||||
Expires: cookieExpiry(),
|
||||
}
|
||||
}
|
||||
|
||||
// Create a cookie to clear csrf cookie
|
||||
func ClearCSRFCookie(r *http.Request) *http.Cookie {
|
||||
return &http.Cookie{
|
||||
Name: config.CSRFCookieName,
|
||||
Value: "",
|
||||
Path: "/",
|
||||
Domain: csrfCookieDomain(r),
|
||||
HttpOnly: true,
|
||||
Secure: !config.InsecureCookie,
|
||||
Expires: time.Now().Local().Add(time.Hour * -1),
|
||||
}
|
||||
}
|
||||
|
||||
// Validate the csrf cookie against state
|
||||
func ValidateCSRFCookie(r *http.Request, c *http.Cookie) (bool, string, error) {
|
||||
state := r.URL.Query().Get("state")
|
||||
|
||||
if len(c.Value) != 32 {
|
||||
return false, "", errors.New("Invalid CSRF cookie value")
|
||||
}
|
||||
|
||||
if len(state) < 34 {
|
||||
return false, "", errors.New("Invalid CSRF state value")
|
||||
}
|
||||
|
||||
// Check nonce match
|
||||
if c.Value != state[:32] {
|
||||
return false, "", errors.New("CSRF cookie does not match state")
|
||||
}
|
||||
|
||||
// Valid, return redirect
|
||||
return true, state[33:], nil
|
||||
}
|
||||
|
||||
func Nonce() (error, string) {
|
||||
// Make nonce
|
||||
nonce := make([]byte, 16)
|
||||
_, err := rand.Read(nonce)
|
||||
if err != nil {
|
||||
return err, ""
|
||||
}
|
||||
|
||||
return nil, fmt.Sprintf("%x", nonce)
|
||||
}
|
||||
|
||||
// Cookie domain
|
||||
func cookieDomain(r *http.Request) string {
|
||||
host := r.Header.Get("X-Forwarded-Host")
|
||||
|
||||
// Check if any of the given cookie domains matches
|
||||
_, domain := matchCookieDomains(host)
|
||||
return domain
|
||||
}
|
||||
|
||||
// Cookie domain
|
||||
func csrfCookieDomain(r *http.Request) string {
|
||||
var host string
|
||||
if use, domain := useAuthDomain(r); use {
|
||||
host = domain
|
||||
} else {
|
||||
host = r.Header.Get("X-Forwarded-Host")
|
||||
}
|
||||
|
||||
// Remove port
|
||||
p := strings.Split(host, ":")
|
||||
return p[0]
|
||||
}
|
||||
|
||||
// Return matching cookie domain if exists
|
||||
func matchCookieDomains(domain string) (bool, string) {
|
||||
// Remove port
|
||||
p := strings.Split(domain, ":")
|
||||
|
||||
for _, d := range config.CookieDomains {
|
||||
if d.Match(p[0]) {
|
||||
return true, d.Domain
|
||||
}
|
||||
}
|
||||
|
||||
return false, p[0]
|
||||
}
|
||||
|
||||
// Create cookie hmac
|
||||
func cookieSignature(r *http.Request, email, expires string) string {
|
||||
hash := hmac.New(sha256.New, config.Secret)
|
||||
hash.Write([]byte(cookieDomain(r)))
|
||||
hash.Write([]byte(email))
|
||||
hash.Write([]byte(expires))
|
||||
return base64.URLEncoding.EncodeToString(hash.Sum(nil))
|
||||
}
|
||||
|
||||
// Get cookie expirary
|
||||
func cookieExpiry() time.Time {
|
||||
return time.Now().Local().Add(config.Lifetime)
|
||||
}
|
||||
|
||||
// Cookie Domain
|
||||
|
||||
// Cookie Domain
|
||||
type CookieDomain struct {
|
||||
Domain string `description:"TEST1"`
|
||||
DomainLen int `description:"TEST2"`
|
||||
SubDomain string `description:"TEST3"`
|
||||
SubDomainLen int `description:"TEST4"`
|
||||
}
|
||||
|
||||
func NewCookieDomain(domain string) *CookieDomain {
|
||||
return &CookieDomain{
|
||||
Domain: domain,
|
||||
DomainLen: len(domain),
|
||||
SubDomain: fmt.Sprintf(".%s", domain),
|
||||
SubDomainLen: len(domain) + 1,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *CookieDomain) Match(host string) bool {
|
||||
// Exact domain match?
|
||||
if host == c.Domain {
|
||||
return true
|
||||
}
|
||||
|
||||
// Subdomain match?
|
||||
if len(host) >= c.SubDomainLen && host[len(host)-c.SubDomainLen:] == c.SubDomain {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
type CookieDomains []CookieDomain
|
||||
|
||||
func (c *CookieDomains) UnmarshalFlag(value string) error {
|
||||
// TODO: test
|
||||
if len(value) > 0 {
|
||||
for _, d := range strings.Split(value, ",") {
|
||||
cookieDomain := NewCookieDomain(d)
|
||||
*c = append(*c, *cookieDomain)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CookieDomains) MarshalFlag() (string, error) {
|
||||
var domains []string
|
||||
for _, d := range *c {
|
||||
domains = append(domains, d.Domain)
|
||||
}
|
||||
return strings.Join(domains, ","), nil
|
||||
}
|
490
internal/auth_test.go
Normal file
490
internal/auth_test.go
Normal file
@ -0,0 +1,490 @@
|
||||
package tfa
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/thomseddon/traefik-forward-auth/internal/provider"
|
||||
)
|
||||
|
||||
/**
|
||||
* Tests
|
||||
*/
|
||||
|
||||
func TestAuthValidateCookie(t *testing.T) {
|
||||
config, _ = NewConfig([]string{})
|
||||
r, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
c := &http.Cookie{}
|
||||
|
||||
// Should require 3 parts
|
||||
c.Value = ""
|
||||
valid, _, err := ValidateCookie(r, c)
|
||||
if valid || err.Error() != "Invalid cookie format" {
|
||||
t.Error("Should get \"Invalid cookie format\", got:", err)
|
||||
}
|
||||
c.Value = "1|2"
|
||||
valid, _, err = ValidateCookie(r, c)
|
||||
if valid || err.Error() != "Invalid cookie format" {
|
||||
t.Error("Should get \"Invalid cookie format\", got:", err)
|
||||
}
|
||||
c.Value = "1|2|3|4"
|
||||
valid, _, err = ValidateCookie(r, c)
|
||||
if valid || err.Error() != "Invalid cookie format" {
|
||||
t.Error("Should get \"Invalid cookie format\", got:", err)
|
||||
}
|
||||
|
||||
// Should catch invalid mac
|
||||
c.Value = "MQ==|2|3"
|
||||
valid, _, err = ValidateCookie(r, c)
|
||||
if valid || err.Error() != "Invalid cookie mac" {
|
||||
t.Error("Should get \"Invalid cookie mac\", got:", err)
|
||||
}
|
||||
|
||||
// Should catch expired
|
||||
config.Lifetime = time.Second * time.Duration(-1)
|
||||
c = MakeCookie(r, "test@test.com")
|
||||
valid, _, err = ValidateCookie(r, c)
|
||||
if valid || err.Error() != "Cookie has expired" {
|
||||
t.Error("Should get \"Cookie has expired\", got:", err)
|
||||
}
|
||||
|
||||
// Should accept valid cookie
|
||||
config.Lifetime = time.Second * time.Duration(10)
|
||||
c = MakeCookie(r, "test@test.com")
|
||||
valid, email, err := ValidateCookie(r, c)
|
||||
if !valid {
|
||||
t.Error("Valid request should return as valid")
|
||||
}
|
||||
if err != nil {
|
||||
t.Error("Valid request should not return error, got:", err)
|
||||
}
|
||||
if email != "test@test.com" {
|
||||
t.Error("Valid request should return user email")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthValidateEmail(t *testing.T) {
|
||||
config, _ = NewConfig([]string{})
|
||||
|
||||
// Should allow any
|
||||
if !ValidateEmail("test@test.com") || !ValidateEmail("one@two.com") {
|
||||
t.Error("Should allow any domain if email domain is not defined")
|
||||
}
|
||||
|
||||
// Should block non matching domain
|
||||
config.Domains = []string{"test.com"}
|
||||
if ValidateEmail("one@two.com") {
|
||||
t.Error("Should not allow user from another domain")
|
||||
}
|
||||
|
||||
// Should allow matching domain
|
||||
config.Domains = []string{"test.com"}
|
||||
if !ValidateEmail("test@test.com") {
|
||||
t.Error("Should allow user from allowed domain")
|
||||
}
|
||||
|
||||
// Should block non whitelisted email address
|
||||
config.Domains = []string{}
|
||||
config.Whitelist = []string{"test@test.com"}
|
||||
if ValidateEmail("one@two.com") {
|
||||
t.Error("Should not allow user not in whitelist.")
|
||||
}
|
||||
|
||||
// Should allow matching whitelisted email address
|
||||
config.Domains = []string{}
|
||||
config.Whitelist = []string{"test@test.com"}
|
||||
if !ValidateEmail("test@test.com") {
|
||||
t.Error("Should allow user in whitelist.")
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Split google tests out
|
||||
func TestAuthGetLoginURL(t *testing.T) {
|
||||
google := provider.Google{
|
||||
ClientId: "idtest",
|
||||
ClientSecret: "sectest",
|
||||
Scope: "scopetest",
|
||||
Prompt: "consent select_account",
|
||||
LoginURL: &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "test.com",
|
||||
Path: "/auth",
|
||||
},
|
||||
}
|
||||
|
||||
config, _ = NewConfig([]string{})
|
||||
config.Providers.Google = google
|
||||
|
||||
r, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
r.Header.Add("X-Forwarded-Proto", "http")
|
||||
r.Header.Add("X-Forwarded-Host", "example.com")
|
||||
r.Header.Add("X-Forwarded-Uri", "/hello")
|
||||
|
||||
// Check url
|
||||
uri, err := url.Parse(GetLoginURL(r, "nonce"))
|
||||
if err != nil {
|
||||
t.Error("Error parsing login url:", err)
|
||||
}
|
||||
if uri.Scheme != "https" {
|
||||
t.Error("Expected login Scheme to be \"https\", got:", uri.Scheme)
|
||||
}
|
||||
if uri.Host != "test.com" {
|
||||
t.Error("Expected login Host to be \"test.com\", got:", uri.Host)
|
||||
}
|
||||
if uri.Path != "/auth" {
|
||||
t.Error("Expected login Path to be \"/auth\", got:", uri.Path)
|
||||
}
|
||||
|
||||
// Check query string
|
||||
qs := uri.Query()
|
||||
expectedQs := url.Values{
|
||||
"client_id": []string{"idtest"},
|
||||
"redirect_uri": []string{"http://example.com/_oauth"},
|
||||
"response_type": []string{"code"},
|
||||
"scope": []string{"scopetest"},
|
||||
"prompt": []string{"consent select_account"},
|
||||
"state": []string{"nonce:http://example.com/hello"},
|
||||
}
|
||||
if !reflect.DeepEqual(qs, expectedQs) {
|
||||
for _, err := range qsDiff(t, expectedQs, qs) {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// With Auth URL but no matching cookie domain
|
||||
// - will not use auth host
|
||||
//
|
||||
config, _ = NewConfig([]string{})
|
||||
config.AuthHost = "auth.example.com"
|
||||
config.Providers.Google = google
|
||||
|
||||
// Check url
|
||||
uri, err = url.Parse(GetLoginURL(r, "nonce"))
|
||||
if err != nil {
|
||||
t.Error("Error parsing login url:", err)
|
||||
}
|
||||
if uri.Scheme != "https" {
|
||||
t.Error("Expected login Scheme to be \"https\", got:", uri.Scheme)
|
||||
}
|
||||
if uri.Host != "test.com" {
|
||||
t.Error("Expected login Host to be \"test.com\", got:", uri.Host)
|
||||
}
|
||||
if uri.Path != "/auth" {
|
||||
t.Error("Expected login Path to be \"/auth\", got:", uri.Path)
|
||||
}
|
||||
|
||||
// Check query string
|
||||
qs = uri.Query()
|
||||
expectedQs = url.Values{
|
||||
"client_id": []string{"idtest"},
|
||||
"redirect_uri": []string{"http://example.com/_oauth"},
|
||||
"response_type": []string{"code"},
|
||||
"scope": []string{"scopetest"},
|
||||
"prompt": []string{"consent select_account"},
|
||||
"state": []string{"nonce:http://example.com/hello"},
|
||||
}
|
||||
if !reflect.DeepEqual(qs, expectedQs) {
|
||||
for _, err := range qsDiff(t, expectedQs, qs) {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// With correct Auth URL + cookie domain
|
||||
//
|
||||
config, _ = NewConfig([]string{})
|
||||
config.AuthHost = "auth.example.com"
|
||||
config.CookieDomains = []CookieDomain{*NewCookieDomain("example.com")}
|
||||
config.Providers.Google = google
|
||||
|
||||
// Check url
|
||||
uri, err = url.Parse(GetLoginURL(r, "nonce"))
|
||||
if err != nil {
|
||||
t.Error("Error parsing login url:", err)
|
||||
}
|
||||
if uri.Scheme != "https" {
|
||||
t.Error("Expected login Scheme to be \"https\", got:", uri.Scheme)
|
||||
}
|
||||
if uri.Host != "test.com" {
|
||||
t.Error("Expected login Host to be \"test.com\", got:", uri.Host)
|
||||
}
|
||||
if uri.Path != "/auth" {
|
||||
t.Error("Expected login Path to be \"/auth\", got:", uri.Path)
|
||||
}
|
||||
|
||||
// Check query string
|
||||
qs = uri.Query()
|
||||
expectedQs = url.Values{
|
||||
"client_id": []string{"idtest"},
|
||||
"redirect_uri": []string{"http://auth.example.com/_oauth"},
|
||||
"response_type": []string{"code"},
|
||||
"scope": []string{"scopetest"},
|
||||
"state": []string{"nonce:http://example.com/hello"},
|
||||
"prompt": []string{"consent select_account"},
|
||||
}
|
||||
for _, err := range qsDiff(t, expectedQs, qs) {
|
||||
t.Error(err)
|
||||
}
|
||||
if !reflect.DeepEqual(qs, expectedQs) {
|
||||
for _, err := range qsDiff(t, expectedQs, qs) {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// With Auth URL + cookie domain, but from different domain
|
||||
// - will not use auth host
|
||||
//
|
||||
r, _ = http.NewRequest("GET", "http://another.com", nil)
|
||||
r.Header.Add("X-Forwarded-Proto", "http")
|
||||
r.Header.Add("X-Forwarded-Host", "another.com")
|
||||
r.Header.Add("X-Forwarded-Uri", "/hello")
|
||||
|
||||
// Check url
|
||||
uri, err = url.Parse(GetLoginURL(r, "nonce"))
|
||||
if err != nil {
|
||||
t.Error("Error parsing login url:", err)
|
||||
}
|
||||
if uri.Scheme != "https" {
|
||||
t.Error("Expected login Scheme to be \"https\", got:", uri.Scheme)
|
||||
}
|
||||
if uri.Host != "test.com" {
|
||||
t.Error("Expected login Host to be \"test.com\", got:", uri.Host)
|
||||
}
|
||||
if uri.Path != "/auth" {
|
||||
t.Error("Expected login Path to be \"/auth\", got:", uri.Path)
|
||||
}
|
||||
|
||||
// Check query string
|
||||
qs = uri.Query()
|
||||
expectedQs = url.Values{
|
||||
"client_id": []string{"idtest"},
|
||||
"redirect_uri": []string{"http://another.com/_oauth"},
|
||||
"response_type": []string{"code"},
|
||||
"scope": []string{"scopetest"},
|
||||
"state": []string{"nonce:http://another.com/hello"},
|
||||
"prompt": []string{"consent select_account"},
|
||||
}
|
||||
for _, err := range qsDiff(t, expectedQs, qs) {
|
||||
t.Error(err)
|
||||
}
|
||||
if !reflect.DeepEqual(qs, expectedQs) {
|
||||
for _, err := range qsDiff(t, expectedQs, qs) {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO
|
||||
// func TestAuthExchangeCode(t *testing.T) {
|
||||
// }
|
||||
|
||||
// TODO
|
||||
// func TestAuthGetUser(t *testing.T) {
|
||||
// }
|
||||
|
||||
func TestAuthMakeCookie(t *testing.T) {
|
||||
config, _ = NewConfig([]string{})
|
||||
r, _ := http.NewRequest("GET", "http://app.example.com", nil)
|
||||
r.Header.Add("X-Forwarded-Host", "app.example.com")
|
||||
|
||||
c := MakeCookie(r, "test@example.com")
|
||||
if c.Name != "_forward_auth" {
|
||||
t.Error("Cookie name should be \"_forward_auth\", got:", c.Name)
|
||||
}
|
||||
parts := strings.Split(c.Value, "|")
|
||||
if len(parts) != 3 {
|
||||
t.Error("Cookie should be in 3 parts, got:", c.Value)
|
||||
}
|
||||
valid, _, _ := ValidateCookie(r, c)
|
||||
if !valid {
|
||||
t.Error("Should generate valid cookie:", c.Value)
|
||||
}
|
||||
if c.Path != "/" {
|
||||
t.Error("Cookie path should be \"/\", got:", c.Path)
|
||||
}
|
||||
if c.Domain != "app.example.com" {
|
||||
t.Error("Cookie domain should be \"app.example.com\", got:", c.Domain)
|
||||
}
|
||||
if c.Secure != true {
|
||||
t.Error("Cookie domain should be true, got:", c.Secure)
|
||||
}
|
||||
if !c.Expires.After(time.Now().Local()) {
|
||||
t.Error("Expires should be after now, got:", c.Expires)
|
||||
}
|
||||
if !c.Expires.Before(time.Now().Local().Add(config.Lifetime).Add(10 * time.Second)) {
|
||||
t.Error("Expires should be before lifetime + 10 seconds, got:", c.Expires)
|
||||
}
|
||||
|
||||
config.CookieName = "testname"
|
||||
config.InsecureCookie = true
|
||||
c = MakeCookie(r, "test@example.com")
|
||||
if c.Name != "testname" {
|
||||
t.Error("Cookie name should be \"testname\", got:", c.Name)
|
||||
}
|
||||
if c.Secure != false {
|
||||
t.Error("Cookie domain should be false, got:", c.Secure)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMakeCSRFCookie(t *testing.T) {
|
||||
config, _ = NewConfig([]string{})
|
||||
r, _ := http.NewRequest("GET", "http://app.example.com", nil)
|
||||
r.Header.Add("X-Forwarded-Host", "app.example.com")
|
||||
|
||||
// No cookie domain or auth url
|
||||
c := MakeCSRFCookie(r, "12345678901234567890123456789012")
|
||||
if c.Domain != "app.example.com" {
|
||||
t.Error("Cookie Domain should match request domain, got:", c.Domain)
|
||||
}
|
||||
|
||||
// With cookie domain but no auth url
|
||||
config = Config{
|
||||
CookieDomains: []CookieDomain{*NewCookieDomain("example.com")},
|
||||
}
|
||||
c = MakeCSRFCookie(r, "12345678901234567890123456789012")
|
||||
if c.Domain != "app.example.com" {
|
||||
t.Error("Cookie Domain should match request domain, got:", c.Domain)
|
||||
}
|
||||
|
||||
// With cookie domain and auth url
|
||||
config = Config{
|
||||
AuthHost: "auth.example.com",
|
||||
CookieDomains: []CookieDomain{*NewCookieDomain("example.com")},
|
||||
}
|
||||
c = MakeCSRFCookie(r, "12345678901234567890123456789012")
|
||||
if c.Domain != "example.com" {
|
||||
t.Error("Cookie Domain should match request domain, got:", c.Domain)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthClearCSRFCookie(t *testing.T) {
|
||||
config, _ = NewConfig([]string{})
|
||||
r, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
|
||||
c := ClearCSRFCookie(r)
|
||||
if c.Value != "" {
|
||||
t.Error("ClearCSRFCookie should create cookie with empty value")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthValidateCSRFCookie(t *testing.T) {
|
||||
config, _ = NewConfig([]string{})
|
||||
c := &http.Cookie{}
|
||||
|
||||
newCsrfRequest := func(state string) *http.Request {
|
||||
u := fmt.Sprintf("http://example.com?state=%s", state)
|
||||
r, _ := http.NewRequest("GET", u, nil)
|
||||
return r
|
||||
}
|
||||
|
||||
// Should require 32 char string
|
||||
r := newCsrfRequest("")
|
||||
c.Value = ""
|
||||
valid, _, err := ValidateCSRFCookie(r, c)
|
||||
if valid || err.Error() != "Invalid CSRF cookie value" {
|
||||
t.Error("Should get \"Invalid CSRF cookie value\", got:", err)
|
||||
}
|
||||
c.Value = "123456789012345678901234567890123"
|
||||
valid, _, err = ValidateCSRFCookie(r, c)
|
||||
if valid || err.Error() != "Invalid CSRF cookie value" {
|
||||
t.Error("Should get \"Invalid CSRF cookie value\", got:", err)
|
||||
}
|
||||
|
||||
// Should require valid state
|
||||
r = newCsrfRequest("12345678901234567890123456789012:")
|
||||
c.Value = "12345678901234567890123456789012"
|
||||
valid, _, err = ValidateCSRFCookie(r, c)
|
||||
if valid || err.Error() != "Invalid CSRF state value" {
|
||||
t.Error("Should get \"Invalid CSRF state value\", got:", err)
|
||||
}
|
||||
|
||||
// Should allow valid state
|
||||
r = newCsrfRequest("12345678901234567890123456789012:99")
|
||||
c.Value = "12345678901234567890123456789012"
|
||||
valid, state, err := ValidateCSRFCookie(r, c)
|
||||
if !valid {
|
||||
t.Error("Valid request should return as valid")
|
||||
}
|
||||
if err != nil {
|
||||
t.Error("Valid request should not return error, got:", err)
|
||||
}
|
||||
if state != "99" {
|
||||
t.Error("Valid request should return correct state, got:", state)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthNonce(t *testing.T) {
|
||||
err, nonce1 := Nonce()
|
||||
if err != nil {
|
||||
t.Error("Error generation nonce:", err)
|
||||
}
|
||||
|
||||
err, nonce2 := Nonce()
|
||||
if err != nil {
|
||||
t.Error("Error generation nonce:", err)
|
||||
}
|
||||
|
||||
if len(nonce1) != 32 || len(nonce2) != 32 {
|
||||
t.Error("Nonce should be 32 chars")
|
||||
}
|
||||
if nonce1 == nonce2 {
|
||||
t.Error("Nonce should not be equal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthCookieDomainMatch(t *testing.T) {
|
||||
cd := NewCookieDomain("example.com")
|
||||
|
||||
// Exact should match
|
||||
if !cd.Match("example.com") {
|
||||
t.Error("Exact domain should match")
|
||||
}
|
||||
|
||||
// Subdomain should match
|
||||
if !cd.Match("test.example.com") {
|
||||
t.Error("Subdomain should match")
|
||||
}
|
||||
|
||||
// Derived domain should not match
|
||||
if cd.Match("testexample.com") {
|
||||
t.Error("Derived domain should not match")
|
||||
}
|
||||
|
||||
// Other domain should not match
|
||||
if cd.Match("test.com") {
|
||||
t.Error("Other domain should not match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthCookieDomains(t *testing.T) {
|
||||
cds := CookieDomains{}
|
||||
|
||||
err := cds.UnmarshalFlag("one.com,two.org")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if len(cds) != 2 {
|
||||
t.Error("Expected UnmarshalFlag to provide 2 CookieDomains, got", cds)
|
||||
}
|
||||
if cds[0].Domain != "one.com" || cds[0].SubDomain != ".one.com" {
|
||||
t.Error("Expected UnmarshalFlag to provide one.com, got", cds[0])
|
||||
}
|
||||
if cds[1].Domain != "two.org" || cds[1].SubDomain != ".two.org" {
|
||||
t.Error("Expected UnmarshalFlag to provide two.org, got", cds[1])
|
||||
}
|
||||
|
||||
marshal, err := cds.MarshalFlag()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if marshal != "one.com,two.org" {
|
||||
t.Error("Expected MarshalFlag to provide \"one.com,two.org\", got", cds)
|
||||
}
|
||||
}
|
262
internal/config.go
Normal file
262
internal/config.go
Normal file
@ -0,0 +1,262 @@
|
||||
package tfa
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jessevdk/go-flags"
|
||||
"github.com/thomseddon/traefik-forward-auth/internal/provider"
|
||||
)
|
||||
|
||||
var config Config
|
||||
|
||||
type Config struct {
|
||||
LogLevel string `long:"log-level" env:"LOG_LEVEL" default:"warn" choice:"trace" choice:"debug" choice:"info" choice:"warn" choice:"error" choice:"fatal" choice:"panic" description:"Log level"`
|
||||
LogFormat string `long:"log-format" env:"LOG_FORMAT" default:"text" choice:"text" choice:"json" choice:"pretty" description:"Log format"`
|
||||
|
||||
AuthHost string `long:"auth-host" env:"AUTH_HOST" description:"Host for central auth login"`
|
||||
Config func(s string) error `long:"config" env:"CONFIG" description:"Config file"`
|
||||
CookieDomains CookieDomains `long:"cookie-domains" env:"COOKIE_DOMAINS" description:"Comma separated list of cookie domains"`
|
||||
InsecureCookie bool `long:"insecure-cookie" env:"INSECURE_COOKIE" description:"Use insecure cookies"`
|
||||
CookieName string `long:"cookie-name" env:"COOKIE_NAME" default:"_forward_auth" description:"Cookie Name"`
|
||||
CSRFCookieName string `long:"csrf-cookie-name" env:"CSRF_COOKIE_NAME" default:"_forward_auth_csrf" description:"CSRF Cookie Name"`
|
||||
DefaultAction string `long:"default-action" env:"DEFAULT_ACTION" default:"auth" choice:"auth" choice:"allow" description:"Default Action"`
|
||||
Domains CommaSeparatedList `long:"domains" env:"DOMAINS" description:"Comma separated list of email domains to allow"`
|
||||
LifetimeString int `long:"lifetime" env:"LIFETIME" default:"43200" description:"Lifetime in seconds"`
|
||||
Path string `long:"url-path" env:"URL_PATH" default:"_oauth" description:"Callback URL Path"`
|
||||
SecretString string `long:"secret" env:"SECRET" description:"*Secret used for signing (required)"`
|
||||
Whitelist CommaSeparatedList `long:"whitelist" env:"WHITELIST" description:"Comma separated list of email addresses to allow"`
|
||||
|
||||
Providers provider.Providers `group:"providers" namespace:"providers" env-namespace:"PROVIDERS"`
|
||||
Rules map[string]*Rule `long:"rules.<name>.<param>" description:"Rule definitions, see docs, param can be: \"action\", \"rule\""`
|
||||
|
||||
// Filled during transformations
|
||||
Secret []byte
|
||||
Lifetime time.Duration
|
||||
|
||||
// Legacy
|
||||
ClientIdLegacy string `long:"client-id" env:"CLIENT_ID" group:"DEPs" description:"DEPRECATED - Use \"providers.google.client-id\""`
|
||||
ClientSecretLegacy string `long:"client-secret" env:"CLIENT_SECRET" description:"DEPRECATED - Use \"providers.google.client-id\""`
|
||||
PromptLegacy string `long:"prompt" env:"PROMPT" description:"DEPRECATED - Use \"providers.google.prompt\""`
|
||||
CookieSecureLegacy string `long:"cookie-secure" env:"COOKIE_SECURE" namespace:"DERPS" description:"DEPRECATED - Use \"insecure-cookie\""`
|
||||
}
|
||||
|
||||
func NewGlobalConfig() Config {
|
||||
var err error
|
||||
config, err = NewConfig(os.Args[1:])
|
||||
if err != nil {
|
||||
fmt.Printf("%+v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
func NewConfig(args []string) (Config, error) {
|
||||
c := Config{
|
||||
Rules: map[string]*Rule{},
|
||||
}
|
||||
|
||||
err := c.parseFlags(args)
|
||||
if err != nil {
|
||||
return c, err
|
||||
}
|
||||
|
||||
// TODO: as log flags have now been parsed maybe we should return here so
|
||||
// any further errors can be logged via logrus instead of printed?
|
||||
|
||||
// Backwards compatability
|
||||
if c.ClientIdLegacy != "" {
|
||||
c.Providers.Google.ClientId = c.ClientIdLegacy
|
||||
}
|
||||
if c.ClientSecretLegacy != "" {
|
||||
c.Providers.Google.ClientSecret = c.ClientSecretLegacy
|
||||
}
|
||||
if c.PromptLegacy != "" {
|
||||
c.Providers.Google.Prompt = c.PromptLegacy
|
||||
}
|
||||
if c.CookieSecureLegacy != "" {
|
||||
secure, err := strconv.ParseBool(c.CookieSecureLegacy)
|
||||
if err != nil {
|
||||
return c, err
|
||||
}
|
||||
c.InsecureCookie = !secure
|
||||
}
|
||||
|
||||
// Provider defaults
|
||||
c.Providers.Google.Build()
|
||||
|
||||
// Transformations
|
||||
c.Path = fmt.Sprintf("/%s", c.Path)
|
||||
c.Secret = []byte(c.SecretString)
|
||||
c.Lifetime = time.Second * time.Duration(c.LifetimeString)
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *Config) parseFlags(args []string) error {
|
||||
p := flags.NewParser(c, flags.Default)
|
||||
p.UnknownOptionHandler = c.parseUnknownFlag
|
||||
|
||||
i := flags.NewIniParser(p)
|
||||
c.Config = func(s string) error {
|
||||
// Try parsing at as an ini
|
||||
err := i.ParseFile(s)
|
||||
|
||||
// If it fails with a syntax error, try converting legacy to ini
|
||||
if err != nil && strings.Contains(err.Error(), "malformed key=value") {
|
||||
converted, convertErr := convertLegacyToIni(s)
|
||||
if convertErr != nil {
|
||||
// If conversion fails, return the original error
|
||||
return err
|
||||
}
|
||||
|
||||
return i.Parse(converted)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
_, err := p.ParseArgs(args)
|
||||
if err != nil {
|
||||
return handlFlagError(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) parseUnknownFlag(option string, arg flags.SplitArgument, args []string) ([]string, error) {
|
||||
// Parse rules in the format "rule.<name>.<param>"
|
||||
parts := strings.Split(option, ".")
|
||||
if len(parts) == 3 && parts[0] == "rule" {
|
||||
// Get or create rule
|
||||
rule, ok := c.Rules[parts[1]]
|
||||
if !ok {
|
||||
rule = NewRule()
|
||||
c.Rules[parts[1]] = rule
|
||||
}
|
||||
|
||||
// Get value, or pop the next arg
|
||||
val, ok := arg.Value()
|
||||
if !ok {
|
||||
val = args[0]
|
||||
args = args[1:]
|
||||
}
|
||||
|
||||
// Check value
|
||||
if len(val) == 0 {
|
||||
return args, errors.New("route param value is required")
|
||||
}
|
||||
|
||||
// Unquote if required
|
||||
if val[0] == '"' {
|
||||
var err error
|
||||
val, err = strconv.Unquote(val)
|
||||
if err != nil {
|
||||
return args, err
|
||||
}
|
||||
}
|
||||
|
||||
// Add param value to rule
|
||||
switch parts[2] {
|
||||
case "action":
|
||||
rule.Action = val
|
||||
case "rule":
|
||||
rule.Rule = val
|
||||
case "provider":
|
||||
rule.Provider = val
|
||||
default:
|
||||
return args, fmt.Errorf("inavlid route param: %v", option)
|
||||
}
|
||||
} else {
|
||||
return args, fmt.Errorf("unknown flag: %v", option)
|
||||
}
|
||||
|
||||
return args, nil
|
||||
}
|
||||
|
||||
func handlFlagError(err error) error {
|
||||
flagsErr, ok := err.(*flags.Error)
|
||||
if ok && flagsErr.Type == flags.ErrHelp {
|
||||
// Library has just printed cli help
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
var legacyFileFormat = regexp.MustCompile(`^([a-z-]+) ([\w\W]+)$`)
|
||||
|
||||
func convertLegacyToIni(name string) (io.Reader, error) {
|
||||
b, err := ioutil.ReadFile(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return bytes.NewReader(legacyFileFormat.ReplaceAll(b, []byte("$1=$2"))), nil
|
||||
}
|
||||
|
||||
func (c *Config) Validate() {
|
||||
// Check for show stopper errors
|
||||
if len(c.Secret) == 0 {
|
||||
log.Fatal("\"secret\" option must be set.")
|
||||
}
|
||||
|
||||
if c.Providers.Google.ClientId == "" || c.Providers.Google.ClientSecret == "" {
|
||||
log.Fatal("google.providers.client-id, google.providers.client-secret must be set")
|
||||
}
|
||||
|
||||
// Check rules
|
||||
for _, rule := range c.Rules {
|
||||
rule.Validate()
|
||||
}
|
||||
}
|
||||
|
||||
func (c Config) String() string {
|
||||
jsonConf, _ := json.Marshal(c)
|
||||
return string(jsonConf)
|
||||
}
|
||||
|
||||
type Rule struct {
|
||||
Action string
|
||||
Rule string
|
||||
Provider string
|
||||
}
|
||||
|
||||
func NewRule() *Rule {
|
||||
return &Rule{
|
||||
Action: "auth",
|
||||
Provider: "google", // TODO: Use default provider
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Rule) Validate() {
|
||||
if r.Action != "auth" && r.Action != "allow" {
|
||||
log.Fatal("invalid rule action, must be \"auth\" or \"allow\"")
|
||||
}
|
||||
|
||||
// TODO: Update with more provider support
|
||||
if r.Provider != "google" {
|
||||
log.Fatal("invalid rule provider, must be \"google\"")
|
||||
}
|
||||
}
|
||||
|
||||
type CommaSeparatedList []string
|
||||
|
||||
func (c *CommaSeparatedList) UnmarshalFlag(value string) error {
|
||||
*c = strings.Split(value, ",")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CommaSeparatedList) MarshalFlag() (string, error) {
|
||||
return strings.Join(*c, ","), nil
|
||||
}
|
269
internal/config_test.go
Normal file
269
internal/config_test.go
Normal file
@ -0,0 +1,269 @@
|
||||
package tfa
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
/**
|
||||
* Tests
|
||||
*/
|
||||
|
||||
func TestConfigDefaults(t *testing.T) {
|
||||
// Check defaults
|
||||
c, err := NewConfig([]string{})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if c.LogLevel != "warn" {
|
||||
t.Error("LogLevel default should be warn, got", c.LogLevel)
|
||||
}
|
||||
if c.LogFormat != "text" {
|
||||
t.Error("LogFormat default should be text, got", c.LogFormat)
|
||||
}
|
||||
|
||||
if c.AuthHost != "" {
|
||||
t.Error("AuthHost default should be empty, got", c.AuthHost)
|
||||
}
|
||||
if len(c.CookieDomains) != 0 {
|
||||
t.Error("CookieDomains default should be empty, got", c.CookieDomains)
|
||||
}
|
||||
if c.InsecureCookie != false {
|
||||
t.Error("InsecureCookie default should be false, got", c.InsecureCookie)
|
||||
}
|
||||
if c.CookieName != "_forward_auth" {
|
||||
t.Error("CookieName default should be _forward_auth, got", c.CookieName)
|
||||
}
|
||||
if c.CSRFCookieName != "_forward_auth_csrf" {
|
||||
t.Error("CSRFCookieName default should be _forward_auth_csrf, got", c.CSRFCookieName)
|
||||
}
|
||||
if c.DefaultAction != "auth" {
|
||||
t.Error("DefaultAction default should be auth, got", c.DefaultAction)
|
||||
}
|
||||
if len(c.Domains) != 0 {
|
||||
t.Error("Domain default should be empty, got", c.Domains)
|
||||
}
|
||||
if c.Lifetime != time.Second*time.Duration(43200) {
|
||||
t.Error("Lifetime default should be 43200, got", c.Lifetime)
|
||||
}
|
||||
if c.Path != "/_oauth" {
|
||||
t.Error("Path default should be /_oauth, got", c.Path)
|
||||
}
|
||||
if len(c.Whitelist) != 0 {
|
||||
t.Error("Whitelist default should be empty, got", c.Whitelist)
|
||||
}
|
||||
|
||||
if c.Providers.Google.Prompt != "" {
|
||||
t.Error("Providers.Google.Prompt default should be empty, got", c.Providers.Google.Prompt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigParseArgs(t *testing.T) {
|
||||
c, err := NewConfig([]string{
|
||||
"--cookie-name=cookiename",
|
||||
"--csrf-cookie-name", "\"csrfcookiename\"",
|
||||
"--rule.1.action=allow",
|
||||
"--rule.1.rule=PathPrefix(`/one`)",
|
||||
"--rule.two.action=auth",
|
||||
"--rule.two.rule=\"Host(`two.com`) && Path(`/two`)\"",
|
||||
})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
// Check normal flags
|
||||
if c.CookieName != "cookiename" {
|
||||
t.Error("CookieName default should be cookiename, got", c.CookieName)
|
||||
}
|
||||
if c.CSRFCookieName != "csrfcookiename" {
|
||||
t.Error("CSRFCookieName default should be csrfcookiename, got", c.CSRFCookieName)
|
||||
}
|
||||
|
||||
// Check rules
|
||||
if len(c.Rules) != 2 {
|
||||
t.Error("Should create 2 rules, got:", len(c.Rules))
|
||||
}
|
||||
|
||||
// First rule
|
||||
if rule, ok := c.Rules["1"]; !ok {
|
||||
t.Error("Could not find rule key '1'")
|
||||
} else {
|
||||
if rule.Action != "allow" {
|
||||
t.Error("First rule action should be allow, got:", rule.Action)
|
||||
}
|
||||
if rule.Rule != "PathPrefix(`/one`)" {
|
||||
t.Error("First rule rule should be PathPrefix(`/one`), got:", rule.Rule)
|
||||
}
|
||||
if rule.Provider != "google" {
|
||||
t.Error("First rule provider should be google, got:", rule.Provider)
|
||||
}
|
||||
}
|
||||
|
||||
// Second rule
|
||||
if rule, ok := c.Rules["two"]; !ok {
|
||||
t.Error("Could not find rule key '1'")
|
||||
} else {
|
||||
if rule.Action != "auth" {
|
||||
t.Error("Second rule action should be auth, got:", rule.Action)
|
||||
}
|
||||
if rule.Rule != "Host(`two.com`) && Path(`/two`)" {
|
||||
t.Error("Second rule rule should be Host(`two.com`) && Path(`/two`), got:", rule.Rule)
|
||||
}
|
||||
if rule.Provider != "google" {
|
||||
t.Error("Second rule provider should be google, got:", rule.Provider)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigParseUnknownFlags(t *testing.T) {
|
||||
_, err := NewConfig([]string{
|
||||
"--unknown=_oauthpath2",
|
||||
})
|
||||
if err.Error() != "unknown flag: unknown" {
|
||||
t.Error("Error should be \"unknown flag: unknown\", got:", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigFlagBackwardsCompatability(t *testing.T) {
|
||||
c, err := NewConfig([]string{
|
||||
"--client-id=clientid",
|
||||
"--client-secret=verysecret",
|
||||
"--prompt=prompt",
|
||||
"--lifetime=200",
|
||||
"--cookie-secure=false",
|
||||
})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if c.ClientIdLegacy != "clientid" {
|
||||
t.Error("ClientIdLegacy should be clientid, got", c.ClientIdLegacy)
|
||||
}
|
||||
if c.Providers.Google.ClientId != "clientid" {
|
||||
t.Error("Providers.Google.ClientId should be clientid, got", c.Providers.Google.ClientId)
|
||||
}
|
||||
if c.ClientSecretLegacy != "verysecret" {
|
||||
t.Error("ClientSecretLegacy should be verysecret, got", c.ClientSecretLegacy)
|
||||
}
|
||||
if c.Providers.Google.ClientSecret != "verysecret" {
|
||||
t.Error("Providers.Google.ClientSecret should be verysecret, got", c.Providers.Google.ClientSecret)
|
||||
}
|
||||
if c.PromptLegacy != "prompt" {
|
||||
t.Error("PromptLegacy should be prompt, got", c.PromptLegacy)
|
||||
}
|
||||
if c.Providers.Google.Prompt != "prompt" {
|
||||
t.Error("Providers.Google.Prompt should be prompt, got", c.Providers.Google.Prompt)
|
||||
}
|
||||
|
||||
// "cookie-secure" used to be a standard go bool flag that could take
|
||||
// true, TRUE, 1, false, FALSE, 0 etc. values.
|
||||
// Here we're checking that format is still suppoted
|
||||
if c.CookieSecureLegacy != "false" || c.InsecureCookie != true {
|
||||
t.Error("Setting cookie-secure=false should set InsecureCookie true, got", c.InsecureCookie)
|
||||
}
|
||||
c, err = NewConfig([]string{"--cookie-secure=TRUE"})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if c.CookieSecureLegacy != "TRUE" || c.InsecureCookie != false {
|
||||
t.Error("Setting cookie-secure=TRUE should set InsecureCookie false, got", c.InsecureCookie)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigParseIni(t *testing.T) {
|
||||
c, err := NewConfig([]string{
|
||||
"--config=../test/config0",
|
||||
"--config=../test/config1",
|
||||
"--csrf-cookie-name=csrfcookiename",
|
||||
})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if c.CookieName != "inicookiename" {
|
||||
t.Error("CookieName should be read as inicookiename from ini file, got", c.CookieName)
|
||||
}
|
||||
if c.CSRFCookieName != "csrfcookiename" {
|
||||
t.Error("CSRFCookieName argument should override ini file, got", c.CSRFCookieName)
|
||||
}
|
||||
if c.Path != "/two" {
|
||||
t.Error("Path in second ini file should override first ini file, got", c.Path)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigFileBackwardsCompatability(t *testing.T) {
|
||||
c, err := NewConfig([]string{
|
||||
"--config=../test/config-legacy",
|
||||
})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if c.Path != "/two" {
|
||||
t.Error("Path in legacy config file should be read, got", c.Path)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigParseEnvironment(t *testing.T) {
|
||||
os.Setenv("COOKIE_NAME", "env_cookie_name")
|
||||
c, err := NewConfig([]string{})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if c.CookieName != "env_cookie_name" {
|
||||
t.Error("CookieName should be read as env_cookie_name from environment, got", c.CookieName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigTransformation(t *testing.T) {
|
||||
c, err := NewConfig([]string{
|
||||
"--url-path=_oauthpath",
|
||||
"--secret=verysecret",
|
||||
"--lifetime=200",
|
||||
})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if c.Path != "/_oauthpath" {
|
||||
t.Error("Path should add slash to front to get /_oauthpath, got:", c.Path)
|
||||
}
|
||||
|
||||
if c.SecretString != "verysecret" {
|
||||
t.Error("SecretString should be verysecret, got:", c.SecretString)
|
||||
}
|
||||
if bytes.Compare(c.Secret, []byte("verysecret")) != 0 {
|
||||
t.Error("Secret should be []byte(verysecret), got:", string(c.Secret))
|
||||
}
|
||||
|
||||
if c.LifetimeString != 200 {
|
||||
t.Error("LifetimeString should be 200, got:", c.LifetimeString)
|
||||
}
|
||||
if c.Lifetime != time.Second*time.Duration(200) {
|
||||
t.Error("Lifetime default should be 200, got", c.Lifetime)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigCommaSeparatedList(t *testing.T) {
|
||||
list := CommaSeparatedList{}
|
||||
|
||||
err := list.UnmarshalFlag("one,two")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if len(list) != 2 || list[0] != "one" || list[1] != "two" {
|
||||
t.Error("Expected UnmarshalFlag to provide CommaSeparatedList{one,two}, got", list)
|
||||
}
|
||||
|
||||
marshal, err := list.MarshalFlag()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if marshal != "one,two" {
|
||||
t.Error("Expected MarshalFlag to provide \"one,two\", got", list)
|
||||
}
|
||||
}
|
50
internal/log.go
Normal file
50
internal/log.go
Normal file
@ -0,0 +1,50 @@
|
||||
package tfa
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var log logrus.FieldLogger
|
||||
|
||||
func NewDefaultLogger() logrus.FieldLogger {
|
||||
// Setup logger
|
||||
log = logrus.StandardLogger()
|
||||
logrus.SetOutput(os.Stdout)
|
||||
|
||||
// Set logger format
|
||||
switch config.LogFormat {
|
||||
case "pretty":
|
||||
break
|
||||
case "json":
|
||||
logrus.SetFormatter(&logrus.JSONFormatter{})
|
||||
// "text" is the default
|
||||
default:
|
||||
logrus.SetFormatter(&logrus.TextFormatter{
|
||||
DisableColors: true,
|
||||
FullTimestamp: true,
|
||||
})
|
||||
}
|
||||
|
||||
// Set logger level
|
||||
switch config.LogLevel {
|
||||
case "trace":
|
||||
logrus.SetLevel(logrus.TraceLevel)
|
||||
case "debug":
|
||||
logrus.SetLevel(logrus.DebugLevel)
|
||||
case "info":
|
||||
logrus.SetLevel(logrus.InfoLevel)
|
||||
case "error":
|
||||
logrus.SetLevel(logrus.ErrorLevel)
|
||||
case "fatal":
|
||||
logrus.SetLevel(logrus.FatalLevel)
|
||||
case "panic":
|
||||
logrus.SetLevel(logrus.PanicLevel)
|
||||
// warn is the default
|
||||
default:
|
||||
logrus.SetLevel(logrus.WarnLevel)
|
||||
}
|
||||
|
||||
return log
|
||||
}
|
96
internal/provider/google.go
Normal file
96
internal/provider/google.go
Normal file
@ -0,0 +1,96 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
type Google struct {
|
||||
ClientId string `long:"client-id" env:"CLIENT_ID" description:"Client ID"`
|
||||
ClientSecret string `long:"client-secret" env:"CLIENT_SECRET" description:"Client Secret" json:"-"`
|
||||
Scope string
|
||||
Prompt string `long:"prompt" env:"PROMPT" description:"Space separated list of OpenID prompt options"`
|
||||
|
||||
LoginURL *url.URL
|
||||
TokenURL *url.URL
|
||||
UserURL *url.URL
|
||||
}
|
||||
|
||||
func (g *Google) Build() {
|
||||
g.LoginURL = &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "accounts.google.com",
|
||||
Path: "/o/oauth2/auth",
|
||||
}
|
||||
g.TokenURL = &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "www.googleapis.com",
|
||||
Path: "/oauth2/v3/token",
|
||||
}
|
||||
g.UserURL = &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "www.googleapis.com",
|
||||
Path: "/oauth2/v2/userinfo",
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Google) GetLoginURL(redirectUri, state string) string {
|
||||
q := url.Values{}
|
||||
q.Set("client_id", g.ClientId)
|
||||
q.Set("response_type", "code")
|
||||
q.Set("scope", g.Scope)
|
||||
if g.Prompt != "" {
|
||||
q.Set("prompt", g.Prompt)
|
||||
}
|
||||
q.Set("redirect_uri", redirectUri)
|
||||
q.Set("state", state)
|
||||
|
||||
var u url.URL
|
||||
u = *g.LoginURL
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
return u.String()
|
||||
}
|
||||
|
||||
func (g *Google) ExchangeCode(redirectUri, code string) (string, error) {
|
||||
form := url.Values{}
|
||||
form.Set("client_id", g.ClientId)
|
||||
form.Set("client_secret", g.ClientSecret)
|
||||
form.Set("grant_type", "authorization_code")
|
||||
form.Set("redirect_uri", redirectUri)
|
||||
form.Set("code", code)
|
||||
|
||||
res, err := http.PostForm(g.TokenURL.String(), form)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var token Token
|
||||
defer res.Body.Close()
|
||||
err = json.NewDecoder(res.Body).Decode(&token)
|
||||
|
||||
return token.Token, err
|
||||
}
|
||||
|
||||
func (g *Google) GetUser(token string) (User, error) {
|
||||
var user User
|
||||
|
||||
client := &http.Client{}
|
||||
req, err := http.NewRequest("GET", g.UserURL.String(), nil)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
|
||||
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", token))
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
|
||||
defer res.Body.Close()
|
||||
err = json.NewDecoder(res.Body).Decode(&user)
|
||||
|
||||
return user, err
|
||||
}
|
16
internal/provider/providers.go
Normal file
16
internal/provider/providers.go
Normal file
@ -0,0 +1,16 @@
|
||||
package provider
|
||||
|
||||
type Providers struct {
|
||||
Google Google `group:"Google Provider" namespace:"google" env-namespace:"GOOGLE"`
|
||||
}
|
||||
|
||||
type Token struct {
|
||||
Token string `json:"access_token"`
|
||||
}
|
||||
|
||||
type User struct {
|
||||
Id string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Verified bool `json:"verified_email"`
|
||||
Hd string `json:"hd"`
|
||||
}
|
179
internal/server.go
Normal file
179
internal/server.go
Normal file
@ -0,0 +1,179 @@
|
||||
package tfa
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/containous/traefik/pkg/rules"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
router *rules.Router
|
||||
}
|
||||
|
||||
func NewServer() *Server {
|
||||
s := &Server{}
|
||||
s.buildRoutes()
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Server) buildRoutes() {
|
||||
var err error
|
||||
s.router, err = rules.NewRouter()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Let's build a router
|
||||
for _, rule := range config.Rules {
|
||||
if rule.Action == "allow" {
|
||||
s.router.AddRoute(rule.Rule, 1, s.AllowHandler())
|
||||
} else {
|
||||
s.router.AddRoute(rule.Rule, 1, s.AuthHandler())
|
||||
}
|
||||
}
|
||||
|
||||
// Add callback handler
|
||||
s.router.Handle(config.Path, s.AuthCallbackHandler())
|
||||
|
||||
// Add a default handler
|
||||
if config.DefaultAction == "allow" {
|
||||
s.router.NewRoute().Handler(s.AllowHandler())
|
||||
} else {
|
||||
s.router.NewRoute().Handler(s.AuthHandler())
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) RootHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// Modify request
|
||||
r.URL, _ = url.Parse(r.Header.Get("X-Forwarded-Uri"))
|
||||
|
||||
// Pass to mux
|
||||
s.router.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// Handler that allows requests
|
||||
func (s *Server) AllowHandler() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
s.logger(r, "Allowing request")
|
||||
w.WriteHeader(200)
|
||||
}
|
||||
}
|
||||
|
||||
// Authenticate requests
|
||||
func (s *Server) AuthHandler() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// Logging setup
|
||||
logger := s.logger(r, "Authenticating request")
|
||||
|
||||
// Get auth cookie
|
||||
c, err := r.Cookie(config.CookieName)
|
||||
if err != nil {
|
||||
// Error indicates no cookie, generate nonce
|
||||
err, nonce := Nonce()
|
||||
if err != nil {
|
||||
logger.Errorf("Error generating nonce, %v", err)
|
||||
http.Error(w, "Service unavailable", 503)
|
||||
return
|
||||
}
|
||||
|
||||
// Set the CSRF cookie
|
||||
http.SetCookie(w, MakeCSRFCookie(r, nonce))
|
||||
logger.Debug("Set CSRF cookie and redirecting to google login")
|
||||
|
||||
// Forward them on
|
||||
http.Redirect(w, r, GetLoginURL(r, nonce), http.StatusTemporaryRedirect)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Validate cookie
|
||||
valid, email, err := ValidateCookie(r, c)
|
||||
if !valid {
|
||||
logger.Errorf("Invalid cookie: %v", err)
|
||||
http.Error(w, "Not authorized", 401)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate user
|
||||
valid = ValidateEmail(email)
|
||||
if !valid {
|
||||
logger.WithFields(logrus.Fields{
|
||||
"email": email,
|
||||
}).Errorf("Invalid email")
|
||||
http.Error(w, "Not authorized", 401)
|
||||
return
|
||||
}
|
||||
|
||||
// Valid request
|
||||
logger.Debugf("Allowing valid request ")
|
||||
w.Header().Set("X-Forwarded-User", email)
|
||||
w.WriteHeader(200)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle auth callback
|
||||
func (s *Server) AuthCallbackHandler() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// Logging setup
|
||||
logger := s.logger(r, "Handling callback")
|
||||
|
||||
// Check for CSRF cookie
|
||||
c, err := r.Cookie(config.CSRFCookieName)
|
||||
if err != nil {
|
||||
logger.Warn("Missing csrf cookie")
|
||||
http.Error(w, "Not authorized", 401)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate state
|
||||
valid, redirect, err := ValidateCSRFCookie(r, c)
|
||||
if !valid {
|
||||
logger.Warnf("Error validating csrf cookie: %v", err)
|
||||
http.Error(w, "Not authorized", 401)
|
||||
return
|
||||
}
|
||||
|
||||
// Clear CSRF cookie
|
||||
http.SetCookie(w, ClearCSRFCookie(r))
|
||||
|
||||
// Exchange code for token
|
||||
token, err := ExchangeCode(r)
|
||||
if err != nil {
|
||||
logger.Errorf("Code exchange failed with: %v", err)
|
||||
http.Error(w, "Service unavailable", 503)
|
||||
return
|
||||
}
|
||||
|
||||
// Get user
|
||||
user, err := GetUser(token)
|
||||
if err != nil {
|
||||
logger.Errorf("Error getting user: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Generate cookie
|
||||
http.SetCookie(w, MakeCookie(r, user.Email))
|
||||
logger.WithFields(logrus.Fields{
|
||||
"user": user.Email,
|
||||
}).Infof("Generated auth cookie")
|
||||
|
||||
// Redirect
|
||||
http.Redirect(w, r, redirect, http.StatusTemporaryRedirect)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) logger(r *http.Request, msg string) *logrus.Entry {
|
||||
// Create logger
|
||||
logger := log.WithFields(logrus.Fields{
|
||||
"RemoteAddr": r.RemoteAddr,
|
||||
})
|
||||
|
||||
// Log request
|
||||
logger.WithFields(logrus.Fields{
|
||||
"Headers": r.Header,
|
||||
}).Debugf(msg)
|
||||
|
||||
return logger
|
||||
}
|
237
internal/server_test.go
Normal file
237
internal/server_test.go
Normal file
@ -0,0 +1,237 @@
|
||||
package tfa
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
/**
|
||||
* Setup
|
||||
*/
|
||||
|
||||
func init() {
|
||||
config.LogLevel = "panic"
|
||||
log = NewDefaultLogger()
|
||||
}
|
||||
|
||||
/**
|
||||
* Tests
|
||||
*/
|
||||
|
||||
func TestServerAuthHandler(t *testing.T) {
|
||||
config, _ = NewConfig([]string{})
|
||||
|
||||
// Should redirect vanilla request to login url
|
||||
req := newHttpRequest("/foo")
|
||||
|
||||
res, _ := httpRequest(req, nil)
|
||||
if res.StatusCode != 307 {
|
||||
t.Error("Vanilla request should be redirected with 307, got:", res.StatusCode)
|
||||
}
|
||||
fwd, _ := res.Location()
|
||||
if fwd.Scheme != "https" || fwd.Host != "accounts.google.com" || fwd.Path != "/o/oauth2/auth" {
|
||||
t.Error("Vanilla request should be redirected to login url, got:", fwd)
|
||||
}
|
||||
|
||||
// Should catch invalid cookie
|
||||
req = newHttpRequest("/foo")
|
||||
c := MakeCookie(req, "test@example.com")
|
||||
parts := strings.Split(c.Value, "|")
|
||||
c.Value = fmt.Sprintf("bad|%s|%s", parts[1], parts[2])
|
||||
|
||||
res, _ = httpRequest(req, c)
|
||||
if res.StatusCode != 401 {
|
||||
t.Error("Request with invalid cookie shound't be authorised", res.StatusCode)
|
||||
}
|
||||
|
||||
// Should validate email
|
||||
req = newHttpRequest("/foo")
|
||||
c = MakeCookie(req, "test@example.com")
|
||||
config.Domains = []string{"test.com"}
|
||||
|
||||
res, _ = httpRequest(req, c)
|
||||
if res.StatusCode != 401 {
|
||||
t.Error("Request with invalid email shound't be authorised", res.StatusCode)
|
||||
}
|
||||
|
||||
// Should allow valid request email
|
||||
req = newHttpRequest("/foo")
|
||||
|
||||
c = MakeCookie(req, "test@example.com")
|
||||
config.Domains = []string{}
|
||||
|
||||
res, _ = httpRequest(req, c)
|
||||
if res.StatusCode != 200 {
|
||||
t.Error("Valid request should be allowed, got:", res.StatusCode)
|
||||
}
|
||||
|
||||
// Should pass through user
|
||||
users := res.Header["X-Forwarded-User"]
|
||||
if len(users) != 1 {
|
||||
t.Error("Valid request missing X-Forwarded-User header")
|
||||
} else if users[0] != "test@example.com" {
|
||||
t.Error("X-Forwarded-User should match user, got: ", users)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerAuthCallback(t *testing.T) {
|
||||
config, _ = NewConfig([]string{})
|
||||
|
||||
// Setup token server
|
||||
tokenServerHandler := &TokenServerHandler{}
|
||||
tokenServer := httptest.NewServer(tokenServerHandler)
|
||||
defer tokenServer.Close()
|
||||
tokenUrl, _ := url.Parse(tokenServer.URL)
|
||||
config.Providers.Google.TokenURL = tokenUrl
|
||||
|
||||
// Setup user server
|
||||
userServerHandler := &UserServerHandler{}
|
||||
userServer := httptest.NewServer(userServerHandler)
|
||||
defer userServer.Close()
|
||||
userUrl, _ := url.Parse(userServer.URL)
|
||||
config.Providers.Google.UserURL = userUrl
|
||||
|
||||
// Should pass auth response request to callback
|
||||
req := newHttpRequest("/_oauth")
|
||||
res, _ := httpRequest(req, nil)
|
||||
if res.StatusCode != 401 {
|
||||
t.Error("Auth callback without cookie shound't be authorised, got:", res.StatusCode)
|
||||
}
|
||||
|
||||
// Should catch invalid csrf cookie
|
||||
req = newHttpRequest("/_oauth?state=12345678901234567890123456789012:http://redirect")
|
||||
c := MakeCSRFCookie(req, "nononononononononononononononono")
|
||||
res, _ = httpRequest(req, c)
|
||||
if res.StatusCode != 401 {
|
||||
t.Error("Auth callback with invalid cookie shound't be authorised, got:", res.StatusCode)
|
||||
}
|
||||
|
||||
// Should redirect valid request
|
||||
req = newHttpRequest("/_oauth?state=12345678901234567890123456789012:http://redirect")
|
||||
c = MakeCSRFCookie(req, "12345678901234567890123456789012")
|
||||
res, _ = httpRequest(req, c)
|
||||
if res.StatusCode != 307 {
|
||||
t.Error("Valid callback should be allowed, got:", res.StatusCode)
|
||||
}
|
||||
fwd, _ := res.Location()
|
||||
if fwd.Scheme != "http" || fwd.Host != "redirect" || fwd.Path != "" {
|
||||
t.Error("Valid request should be redirected to return url, got:", fwd)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerDefaultAction(t *testing.T) {
|
||||
config, _ = NewConfig([]string{})
|
||||
|
||||
req := newHttpRequest("/random")
|
||||
res, _ := httpRequest(req, nil)
|
||||
if res.StatusCode != 307 {
|
||||
t.Error("Request should require auth with auth default handler, got:", res.StatusCode)
|
||||
}
|
||||
|
||||
config.DefaultAction = "allow"
|
||||
req = newHttpRequest("/random")
|
||||
res, _ = httpRequest(req, nil)
|
||||
if res.StatusCode != 200 {
|
||||
t.Error("Request should be allowed with allow default handler, got:", res.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerRoutePathPrefix(t *testing.T) {
|
||||
config, _ = NewConfig([]string{})
|
||||
config.Rules = map[string]*Rule{
|
||||
"web1": {
|
||||
Action: "allow",
|
||||
Rule: "PathPrefix(`/api`)",
|
||||
},
|
||||
}
|
||||
|
||||
// Should block any request
|
||||
req := newHttpRequest("/random")
|
||||
res, _ := httpRequest(req, nil)
|
||||
if res.StatusCode != 307 {
|
||||
t.Error("Request not matching any rule should require auth, got:", res.StatusCode)
|
||||
}
|
||||
|
||||
// Should allow /api request
|
||||
req = newHttpRequest("/api")
|
||||
res, _ = httpRequest(req, nil)
|
||||
if res.StatusCode != 200 {
|
||||
t.Error("Request matching allowed rule should be allowed, got:", res.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Utilities
|
||||
*/
|
||||
|
||||
type TokenServerHandler struct{}
|
||||
|
||||
func (t *TokenServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprint(w, `{"access_token":"123456789"}`)
|
||||
}
|
||||
|
||||
type UserServerHandler struct{}
|
||||
|
||||
func (t *UserServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprint(w, `{
|
||||
"id":"1",
|
||||
"email":"example@example.com",
|
||||
"verified_email":true,
|
||||
"hd":"example.com"
|
||||
}`)
|
||||
}
|
||||
|
||||
func httpRequest(r *http.Request, c *http.Cookie) (*http.Response, string) {
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Set cookies on recorder
|
||||
if c != nil {
|
||||
http.SetCookie(w, c)
|
||||
}
|
||||
|
||||
// Copy into request
|
||||
for _, c := range w.HeaderMap["Set-Cookie"] {
|
||||
r.Header.Add("Cookie", c)
|
||||
}
|
||||
|
||||
|
||||
NewServer().RootHandler(w, r)
|
||||
|
||||
res := w.Result()
|
||||
body, _ := ioutil.ReadAll(res.Body)
|
||||
|
||||
// if res.StatusCode > 300 && res.StatusCode < 400 {
|
||||
// fmt.Printf("%#v", res.Header)
|
||||
// }
|
||||
|
||||
return res, string(body)
|
||||
}
|
||||
|
||||
func newHttpRequest(uri string) *http.Request {
|
||||
r := httptest.NewRequest("", "http://example.com/", nil)
|
||||
r.Header.Add("X-Forwarded-Uri", uri)
|
||||
return r
|
||||
}
|
||||
|
||||
func qsDiff(t *testing.T, one, two url.Values) []string {
|
||||
errs := make([]string, 0)
|
||||
for k := range one {
|
||||
if two.Get(k) == "" {
|
||||
errs = append(errs, fmt.Sprintf("Key missing: %s", k))
|
||||
}
|
||||
if one.Get(k) != two.Get(k) {
|
||||
errs = append(errs, fmt.Sprintf("Value different for %s: expected: '%s' got: '%s'", k, one.Get(k), two.Get(k)))
|
||||
}
|
||||
}
|
||||
for k := range two {
|
||||
if one.Get(k) == "" {
|
||||
errs = append(errs, fmt.Sprintf("Extra key: %s", k))
|
||||
}
|
||||
}
|
||||
return errs
|
||||
}
|
217
main.go
217
main.go
@ -1,217 +0,0 @@
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
"strings"
|
||||
"net/url"
|
||||
"net/http"
|
||||
|
||||
"github.com/namsral/flag"
|
||||
"github.com/op/go-logging"
|
||||
)
|
||||
|
||||
// Vars
|
||||
var fw *ForwardAuth;
|
||||
var log = logging.MustGetLogger("traefik-forward-auth")
|
||||
|
||||
// Primary handler
|
||||
func handler(w http.ResponseWriter, r *http.Request) {
|
||||
// Parse uri
|
||||
uri, err := url.Parse(r.Header.Get("X-Forwarded-Uri"))
|
||||
if err != nil {
|
||||
log.Error("Error parsing url")
|
||||
http.Error(w, "Service unavailable", 503)
|
||||
return
|
||||
}
|
||||
|
||||
// Direct mode
|
||||
if fw.Direct {
|
||||
uri = r.URL
|
||||
}
|
||||
|
||||
// Handle callback
|
||||
if uri.Path == fw.Path {
|
||||
handleCallback(w, r, uri.Query())
|
||||
return
|
||||
}
|
||||
|
||||
c, err := r.Cookie(fw.CookieName)
|
||||
if err != nil {
|
||||
// Error indicates no cookie, generate nonce
|
||||
err, nonce := fw.Nonce()
|
||||
if err != nil {
|
||||
log.Error("Error generating nonce")
|
||||
http.Error(w, "Service unavailable", 503)
|
||||
return
|
||||
}
|
||||
|
||||
// Set the CSRF cookie
|
||||
http.SetCookie(w, fw.MakeCSRFCookie(r, nonce))
|
||||
log.Debug("Set CSRF cookie and redirecting to google login")
|
||||
|
||||
// Forward them on
|
||||
http.Redirect(w, r, fw.GetLoginURL(r, nonce), http.StatusTemporaryRedirect)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Validate cookie
|
||||
valid, email, err := fw.ValidateCookie(r, c)
|
||||
if !valid {
|
||||
log.Debugf("Invlaid cookie: %s", err)
|
||||
http.Error(w, "Not authorized", 401)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate user
|
||||
valid = fw.ValidateEmail(email)
|
||||
if !valid {
|
||||
log.Debugf("Invalid email: %s", email)
|
||||
http.Error(w, "Not authorized", 401)
|
||||
return
|
||||
}
|
||||
|
||||
// Valid request
|
||||
w.WriteHeader(200)
|
||||
}
|
||||
|
||||
|
||||
// Authenticate user after they have come back from google
|
||||
func handleCallback(w http.ResponseWriter, r *http.Request, qs url.Values) {
|
||||
// Check for CSRF cookie
|
||||
csrfCookie, err := r.Cookie(fw.CSRFCookieName)
|
||||
if err != nil {
|
||||
log.Debug("Missing csrf cookie")
|
||||
http.Error(w, "Not authorized", 401)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate state
|
||||
state := qs.Get("state")
|
||||
valid, redirect, err := fw.ValidateCSRFCookie(csrfCookie, state)
|
||||
if !valid {
|
||||
log.Debugf("Invalid oauth state, expected '%s', got '%s'\n", csrfCookie.Value, state)
|
||||
http.Error(w, "Not authorized", 401)
|
||||
return
|
||||
}
|
||||
|
||||
// Clear CSRF cookie
|
||||
http.SetCookie(w, fw.ClearCSRFCookie(r))
|
||||
|
||||
// Exchange code for token
|
||||
token, err := fw.ExchangeCode(r, qs.Get("code"))
|
||||
if err != nil {
|
||||
log.Debugf("Code exchange failed with: %s\n", err)
|
||||
http.Error(w, "Service unavailable", 503)
|
||||
return
|
||||
}
|
||||
|
||||
// Get user
|
||||
user, err := fw.GetUser(token)
|
||||
if err != nil {
|
||||
log.Debugf("Error getting user: %s\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Generate cookie
|
||||
http.SetCookie(w, fw.MakeCookie(r, user.Email))
|
||||
log.Debugf("Generated auth cookie for %s\n", user.Email)
|
||||
|
||||
// Redirect
|
||||
http.Redirect(w, r, redirect, http.StatusTemporaryRedirect)
|
||||
}
|
||||
|
||||
|
||||
// Main
|
||||
func main() {
|
||||
// Parse options
|
||||
flag.String(flag.DefaultConfigFlagname, "", "Path to config file")
|
||||
path := flag.String("url-path", "_oauth", "Callback URL")
|
||||
lifetime := flag.Int("lifetime", 43200, "Session length in seconds")
|
||||
clientId := flag.String("client-id", "", "*Google Client ID (required)")
|
||||
clientSecret := flag.String("client-secret", "", "*Google Client Secret (required)")
|
||||
cookieName := flag.String("cookie-name", "_forward_auth", "Cookie Name")
|
||||
cSRFCookieName := flag.String("csrf-cookie-name", "_forward_auth_csrf", "CSRF Cookie Name")
|
||||
cookieDomainList := flag.String("cookie-domains", "", "Comma separated list of cookie domains") //todo
|
||||
cookieSecret := flag.String("cookie-secret", "", "*Cookie secret (required)")
|
||||
cookieSecure := flag.Bool("cookie-secure", true, "Use secure cookies")
|
||||
domainList := flag.String("domain", "", "Comma separated list of email domains to allow")
|
||||
direct := flag.Bool("direct", false, "Run in direct mode (use own hostname as oppose to X-Forwarded-Host, used for testing/development)")
|
||||
|
||||
flag.Parse()
|
||||
|
||||
// Check for show stopper errors
|
||||
err := false
|
||||
if *clientId == "" {
|
||||
err = true
|
||||
log.Critical("client-id must be set")
|
||||
}
|
||||
if *clientSecret == "" {
|
||||
err = true
|
||||
log.Critical("client-secret must be set")
|
||||
}
|
||||
if *cookieSecret == "" {
|
||||
err = true
|
||||
log.Critical("cookie-secret must be set")
|
||||
}
|
||||
if err {
|
||||
return
|
||||
}
|
||||
|
||||
// Parse lists
|
||||
var cookieDomains []CookieDomain
|
||||
if *cookieDomainList != "" {
|
||||
for _, d := range strings.Split(*cookieDomainList, ",") {
|
||||
cookieDomain := NewCookieDomain(d)
|
||||
cookieDomains = append(cookieDomains, *cookieDomain)
|
||||
}
|
||||
}
|
||||
|
||||
var domain []string
|
||||
if *domainList != "" {
|
||||
domain = strings.Split(*domainList, ",")
|
||||
}
|
||||
|
||||
// Setup
|
||||
fw = &ForwardAuth{
|
||||
Path: fmt.Sprintf("/%s", *path),
|
||||
Lifetime: time.Second * time.Duration(*lifetime),
|
||||
|
||||
ClientId: *clientId,
|
||||
ClientSecret: *clientSecret,
|
||||
Scope: "https://www.googleapis.com/auth/userinfo.profile https://www.googleapis.com/auth/userinfo.email",
|
||||
LoginURL: &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "accounts.google.com",
|
||||
Path: "/o/oauth2/auth",
|
||||
},
|
||||
TokenURL: &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "www.googleapis.com",
|
||||
Path: "/oauth2/v3/token",
|
||||
},
|
||||
UserURL: &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "www.googleapis.com",
|
||||
Path: "/oauth2/v2/userinfo",
|
||||
},
|
||||
|
||||
CookieName: *cookieName,
|
||||
CSRFCookieName: *cSRFCookieName,
|
||||
CookieDomains: cookieDomains,
|
||||
CookieSecret: []byte(*cookieSecret),
|
||||
CookieSecure: *cookieSecure,
|
||||
|
||||
Domain: domain,
|
||||
|
||||
Direct: *direct,
|
||||
}
|
||||
|
||||
// Attach handler
|
||||
http.HandleFunc("/", handler)
|
||||
|
||||
log.Notice("Litening on :4181")
|
||||
log.Notice(http.ListenAndServe(":4181", nil))
|
||||
}
|
181
main_test.go
181
main_test.go
@ -1,181 +0,0 @@
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
// "reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"net/url"
|
||||
"net/http"
|
||||
"io/ioutil"
|
||||
"net/http/httptest"
|
||||
|
||||
"github.com/op/go-logging"
|
||||
)
|
||||
|
||||
|
||||
type TokenServerHandler struct {}
|
||||
func (t *TokenServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprint(w, `{"access_token":"123456789"}`)
|
||||
}
|
||||
|
||||
type UserServerHandler struct {}
|
||||
func (t *UserServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprint(w, `{
|
||||
"id":"1",
|
||||
"email":"example@example.com",
|
||||
"verified_email":true,
|
||||
"hd":"example.com"
|
||||
}`)
|
||||
}
|
||||
|
||||
func init() {
|
||||
// Remove for debugging
|
||||
logging.SetLevel(logging.INFO, "traefik-forward-auth")
|
||||
}
|
||||
|
||||
func httpRequest(r *http.Request, c *http.Cookie) (*http.Response, string) {
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Set cookies on recorder
|
||||
if c != nil {
|
||||
http.SetCookie(w, c)
|
||||
}
|
||||
|
||||
// Copy into request
|
||||
for _, c := range w.HeaderMap["Set-Cookie"] {
|
||||
r.Header.Add("Cookie", c)
|
||||
}
|
||||
|
||||
handler(w, r)
|
||||
|
||||
res := w.Result()
|
||||
body, _ := ioutil.ReadAll(res.Body)
|
||||
|
||||
return res, string(body)
|
||||
}
|
||||
|
||||
func newHttpRequest(uri string) *http.Request {
|
||||
r := httptest.NewRequest("", "http://example.com", nil)
|
||||
r.Header.Add("X-Forwarded-Uri", uri)
|
||||
return r
|
||||
}
|
||||
|
||||
func TestHandler(t *testing.T) {
|
||||
fw = &ForwardAuth{
|
||||
Path: "_oauth",
|
||||
ClientId: "idtest",
|
||||
ClientSecret: "sectest",
|
||||
Scope: "scopetest",
|
||||
LoginURL: &url.URL{
|
||||
Scheme: "http",
|
||||
Host: "test.com",
|
||||
Path: "/auth",
|
||||
},
|
||||
CookieName: "cookie_test",
|
||||
Lifetime: time.Second * time.Duration(10),
|
||||
}
|
||||
|
||||
// Should redirect vanilla request to login url
|
||||
req := newHttpRequest("foo")
|
||||
res, _ := httpRequest(req, nil)
|
||||
if res.StatusCode != 307 {
|
||||
t.Error("Vanilla request should be redirected with 307, got:", res.StatusCode)
|
||||
}
|
||||
fwd, _ := res.Location()
|
||||
if fwd.Scheme != "http" || fwd.Host != "test.com" || fwd.Path != "/auth" {
|
||||
t.Error("Vanilla request should be redirected to login url, got:", fwd)
|
||||
}
|
||||
|
||||
// Should catch invalid cookie
|
||||
req = newHttpRequest("foo")
|
||||
|
||||
c := fw.MakeCookie(req, "test@example.com")
|
||||
parts := strings.Split(c.Value, "|")
|
||||
c.Value = fmt.Sprintf("bad|%s|%s", parts[1], parts[2])
|
||||
|
||||
res, _ = httpRequest(req, c)
|
||||
if res.StatusCode != 401 {
|
||||
t.Error("Request with invalid cookie shound't be authorised", res.StatusCode)
|
||||
}
|
||||
|
||||
// Should validate email
|
||||
req = newHttpRequest("foo")
|
||||
|
||||
c = fw.MakeCookie(req, "test@example.com")
|
||||
fw.Domain = []string{"test.com"}
|
||||
|
||||
res, _ = httpRequest(req, c)
|
||||
if res.StatusCode != 401 {
|
||||
t.Error("Request with invalid cookie shound't be authorised", res.StatusCode)
|
||||
}
|
||||
|
||||
// Should allow valid request email
|
||||
req = newHttpRequest("foo")
|
||||
|
||||
c = fw.MakeCookie(req, "test@example.com")
|
||||
fw.Domain = []string{}
|
||||
|
||||
res, _ = httpRequest(req, c)
|
||||
if res.StatusCode != 200 {
|
||||
t.Error("Valid request should be allowed, got:", res.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallback(t *testing.T) {
|
||||
fw = &ForwardAuth{
|
||||
Path: "_oauth",
|
||||
ClientId: "idtest",
|
||||
ClientSecret: "sectest",
|
||||
Scope: "scopetest",
|
||||
LoginURL: &url.URL{
|
||||
Scheme: "http",
|
||||
Host: "test.com",
|
||||
Path: "/auth",
|
||||
},
|
||||
CSRFCookieName: "csrf_test",
|
||||
}
|
||||
|
||||
// Setup token server
|
||||
tokenServerHandler := &TokenServerHandler{}
|
||||
tokenServer := httptest.NewServer(tokenServerHandler)
|
||||
defer tokenServer.Close()
|
||||
tokenUrl, _ := url.Parse(tokenServer.URL)
|
||||
fw.TokenURL = tokenUrl
|
||||
|
||||
// Setup user server
|
||||
userServerHandler := &UserServerHandler{}
|
||||
userServer := httptest.NewServer(userServerHandler)
|
||||
defer userServer.Close()
|
||||
userUrl, _ := url.Parse(userServer.URL)
|
||||
fw.UserURL = userUrl
|
||||
|
||||
// Should pass auth response request to callback
|
||||
req := newHttpRequest("_oauth")
|
||||
res, _ := httpRequest(req, nil)
|
||||
if res.StatusCode != 401 {
|
||||
t.Error("Auth callback without cookie shound't be authorised, got:", res.StatusCode)
|
||||
}
|
||||
|
||||
// Should catch invalid csrf cookie
|
||||
req = newHttpRequest("_oauth?state=12345678901234567890123456789012:http://redirect")
|
||||
c := fw.MakeCSRFCookie(req, "nononononononononononononononono")
|
||||
res, _ = httpRequest(req, c)
|
||||
if res.StatusCode != 401 {
|
||||
t.Error("Auth callback with invalid cookie shound't be authorised, got:", res.StatusCode)
|
||||
}
|
||||
|
||||
// Should redirect valid request
|
||||
req = newHttpRequest("_oauth?state=12345678901234567890123456789012:http://redirect")
|
||||
c = fw.MakeCSRFCookie(req, "12345678901234567890123456789012")
|
||||
res, _ = httpRequest(req, c)
|
||||
if res.StatusCode != 307 {
|
||||
t.Error("Valid callback should be allowed, got:", res.StatusCode)
|
||||
}
|
||||
fwd, _ := res.Location()
|
||||
if fwd.Scheme != "http" || fwd.Host != "redirect" || fwd.Path != "" {
|
||||
t.Error("Valid request should be redirected to return url, got:", fwd)
|
||||
}
|
||||
}
|
1
test/config-legacy
Normal file
1
test/config-legacy
Normal file
@ -0,0 +1 @@
|
||||
url-path two
|
3
test/config0
Normal file
3
test/config0
Normal file
@ -0,0 +1,3 @@
|
||||
cookie-name=inicookiename
|
||||
csrf-cookie-name=inicsrfcookiename
|
||||
url-path=one
|
1
test/config1
Normal file
1
test/config1
Normal file
@ -0,0 +1 @@
|
||||
url-path=two
|
Reference in New Issue
Block a user