Begin refactor + selective auth
This commit is contained in:
parent
f1ba9b5ac4
commit
d51b93d4b0
@ -3,6 +3,8 @@ sudo: false
|
|||||||
go:
|
go:
|
||||||
- "1.10"
|
- "1.10"
|
||||||
install:
|
install:
|
||||||
|
- go get github.com/BurntSushi/toml
|
||||||
|
- go get github.com/gorilla/mux
|
||||||
- go get github.com/namsral/flag
|
- go get github.com/namsral/flag
|
||||||
- go get github.com/sirupsen/logrus
|
- go get github.com/sirupsen/logrus
|
||||||
script: go test -v ./...
|
script: go test -v ./...
|
||||||
|
@ -1,17 +1,19 @@
|
|||||||
FROM golang:1.10-alpine as builder
|
FROM golang:1.10-alpine as builder
|
||||||
|
|
||||||
# Setup
|
# Setup
|
||||||
RUN mkdir /app
|
RUN mkdir -p /go/src/github.com/thomseddon/traefik-forward-auth
|
||||||
WORKDIR /app
|
WORKDIR /go/src/github.com/thomseddon/traefik-forward-auth
|
||||||
|
|
||||||
# Add libraries
|
# Add libraries
|
||||||
RUN apk add --no-cache git && \
|
RUN apk add --no-cache git && \
|
||||||
|
go get "github.com/BurntSushi/toml" && \
|
||||||
|
go get "github.com/gorilla/mux" && \
|
||||||
go get "github.com/namsral/flag" && \
|
go get "github.com/namsral/flag" && \
|
||||||
go get "github.com/sirupsen/logrus" && \
|
go get "github.com/sirupsen/logrus" && \
|
||||||
apk del git
|
apk del git
|
||||||
|
|
||||||
# Copy & build
|
# Copy & build
|
||||||
ADD . /app/
|
ADD . /go/src/github.com/thomseddon/traefik-forward-auth/
|
||||||
RUN CGO_ENABLED=0 GOOS=linux go build -a -installsuffix nocgo -o /traefik-forward-auth .
|
RUN CGO_ENABLED=0 GOOS=linux go build -a -installsuffix nocgo -o /traefik-forward-auth .
|
||||||
|
|
||||||
# Copy into scratch container
|
# Copy into scratch container
|
||||||
|
2
Makefile
2
Makefile
@ -1,5 +1,5 @@
|
|||||||
|
|
||||||
format:
|
format:
|
||||||
gofmt -w -s *.go
|
gofmt -w -s *.go provider/*.go
|
||||||
|
|
||||||
.PHONY: format
|
.PHONY: format
|
||||||
|
204
config.go
Normal file
204
config.go
Normal file
@ -0,0 +1,204 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/BurntSushi/toml"
|
||||||
|
"github.com/namsral/flag"
|
||||||
|
"github.com/thomseddon/traefik-forward-auth/provider"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
DefaultAction string
|
||||||
|
Path string
|
||||||
|
Lifetime time.Duration
|
||||||
|
Secret string
|
||||||
|
SecretBytes []byte
|
||||||
|
AuthHost string
|
||||||
|
|
||||||
|
LogLevel *string
|
||||||
|
LogFormat *string
|
||||||
|
TomlConfig *string // temp
|
||||||
|
|
||||||
|
CookieName string
|
||||||
|
CookieDomains []CookieDomain
|
||||||
|
CSRFCookieName string
|
||||||
|
CookieSecure bool
|
||||||
|
|
||||||
|
Domain []string
|
||||||
|
Whitelist []string
|
||||||
|
|
||||||
|
Providers provider.Providers
|
||||||
|
Rules map[string]Rules
|
||||||
|
}
|
||||||
|
|
||||||
|
type Rules struct {
|
||||||
|
Action string
|
||||||
|
Match []Match
|
||||||
|
}
|
||||||
|
|
||||||
|
type Match struct {
|
||||||
|
Host []string
|
||||||
|
PathPrefix []string
|
||||||
|
Header [][]string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConfig() *Config {
|
||||||
|
c := &Config{}
|
||||||
|
c.parseFlags()
|
||||||
|
c.applyDefaults()
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Fix
|
||||||
|
// At the moment any flag value will overwrite the toml config
|
||||||
|
// Need to put the flag default values in applyDefaults & empty the flag
|
||||||
|
// defaults so we can check if they're being passed and set accordingly
|
||||||
|
// Ideally we also need to remove the two calls to parseFlags
|
||||||
|
//
|
||||||
|
// We also need to check the default -config flag for toml suffix and
|
||||||
|
// parse that as needed
|
||||||
|
//
|
||||||
|
// Ideally we'd also support multiple config files
|
||||||
|
|
||||||
|
func NewParsedConfig() *Config {
|
||||||
|
c := &Config{}
|
||||||
|
|
||||||
|
// Temp
|
||||||
|
c.parseFlags()
|
||||||
|
|
||||||
|
// Parse toml
|
||||||
|
if *c.TomlConfig != "" {
|
||||||
|
if _, err := toml.DecodeFile(*c.TomlConfig, &c); err != nil {
|
||||||
|
panic(err)
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.applyDefaults()
|
||||||
|
|
||||||
|
// Conversions
|
||||||
|
c.SecretBytes = []byte(c.Secret)
|
||||||
|
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) Checks() {
|
||||||
|
// Check for show stopper errors
|
||||||
|
if c.Providers.Google.ClientId == "" || c.Providers.Google.ClientSecret == "" || len(c.Secret) == 0 {
|
||||||
|
log.Fatal("client-id, client-secret and secret must all be set")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) applyDefaults() {
|
||||||
|
// Providers
|
||||||
|
// Google
|
||||||
|
if c.Providers.Google.Scope == "" {
|
||||||
|
c.Providers.Google.Scope = "https://www.googleapis.com/auth/userinfo.profile https://www.googleapis.com/auth/userinfo.email"
|
||||||
|
}
|
||||||
|
if c.Providers.Google.LoginURL == nil {
|
||||||
|
c.Providers.Google.LoginURL = &url.URL{
|
||||||
|
Scheme: "https",
|
||||||
|
Host: "accounts.google.com",
|
||||||
|
Path: "/o/oauth2/auth",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if c.Providers.Google.TokenURL == nil {
|
||||||
|
c.Providers.Google.TokenURL = &url.URL{
|
||||||
|
Scheme: "https",
|
||||||
|
Host: "www.googleapis.com",
|
||||||
|
Path: "/oauth2/v3/token",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if c.Providers.Google.UserURL == nil {
|
||||||
|
c.Providers.Google.UserURL = &url.URL{
|
||||||
|
Scheme: "https",
|
||||||
|
Host: "www.googleapis.com",
|
||||||
|
Path: "/oauth2/v2/userinfo",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) parseFlags() {
|
||||||
|
c.LogLevel = flag.String("log-level", "warn", "Log level: trace, debug, info, warn, error, fatal, panic")
|
||||||
|
c.LogFormat = flag.String("log-format", "text", "Log format: text, json, pretty")
|
||||||
|
c.TomlConfig = flag.String("toml-config", "", "TEMP")
|
||||||
|
|
||||||
|
// Legacy?
|
||||||
|
path := flag.String("url-path", "_oauth", "Callback URL")
|
||||||
|
lifetime := flag.Int("lifetime", 43200, "Session length in seconds")
|
||||||
|
secret := flag.String("secret", "", "*Secret used for signing (required)")
|
||||||
|
authHost := flag.String("auth-host", "", "Central auth login")
|
||||||
|
clientId := flag.String("client-id", "", "*Google Client ID (required)")
|
||||||
|
clientSecret := flag.String("client-secret", "", "*Google Client Secret (required)")
|
||||||
|
cookieName := flag.String("cookie-name", "_forward_auth", "Cookie Name")
|
||||||
|
cSRFCookieName := flag.String("csrf-cookie-name", "_forward_auth_csrf", "CSRF Cookie Name")
|
||||||
|
cookieDomainList := flag.String("cookie-domains", "", "Comma separated list of cookie domains") //todo
|
||||||
|
cookieSecret := flag.String("cookie-secret", "", "Deprecated")
|
||||||
|
cookieSecure := flag.Bool("cookie-secure", true, "Use secure cookies")
|
||||||
|
domainList := flag.String("domain", "", "Comma separated list of email domains to allow")
|
||||||
|
emailWhitelist := flag.String("whitelist", "", "Comma separated list of emails to allow")
|
||||||
|
prompt := flag.String("prompt", "", "Space separated list of OpenID prompt options")
|
||||||
|
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
// Add to config
|
||||||
|
c.Path = fmt.Sprintf("/%s", *path)
|
||||||
|
c.Lifetime = time.Second * time.Duration(*lifetime)
|
||||||
|
c.AuthHost = *authHost
|
||||||
|
c.Providers.Google.ClientId = *clientId
|
||||||
|
c.Providers.Google.ClientSecret = *clientSecret
|
||||||
|
c.Providers.Google.Prompt = *prompt
|
||||||
|
c.CookieName = *cookieName
|
||||||
|
c.CSRFCookieName = *cSRFCookieName
|
||||||
|
c.CookieSecure = *cookieSecure
|
||||||
|
|
||||||
|
// Backwards compatibility
|
||||||
|
if *secret == "" && *cookieSecret != "" {
|
||||||
|
*secret = *cookieSecret
|
||||||
|
}
|
||||||
|
c.Secret = *secret
|
||||||
|
|
||||||
|
// Parse lists
|
||||||
|
if *cookieDomainList != "" {
|
||||||
|
for _, d := range strings.Split(*cookieDomainList, ",") {
|
||||||
|
cookieDomain := NewCookieDomain(d)
|
||||||
|
c.CookieDomains = append(c.CookieDomains, *cookieDomain)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if *domainList != "" {
|
||||||
|
c.Domain = strings.Split(*domainList, ",")
|
||||||
|
}
|
||||||
|
|
||||||
|
if *emailWhitelist != "" {
|
||||||
|
c.Whitelist = strings.Split(*emailWhitelist, ",")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Temp
|
||||||
|
func (c Config) Walk() {
|
||||||
|
for name, rule := range c.Rules {
|
||||||
|
fmt.Printf("Rule: %s\n", name)
|
||||||
|
for _, match := range rule.Match {
|
||||||
|
if len(match.Host) > 0 {
|
||||||
|
for _, val := range match.Host {
|
||||||
|
fmt.Printf(" - Host: %s\n", val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(match.PathPrefix) > 0 {
|
||||||
|
for _, val := range match.PathPrefix {
|
||||||
|
fmt.Printf(" - PathPrefix: %s\n", val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(match.Header) > 0 {
|
||||||
|
for _, val := range match.Header {
|
||||||
|
fmt.Printf(" - Header: %s: %s\n", val[0], val[1])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
13
config_test.go
Normal file
13
config_test.go
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
// import (
|
||||||
|
// "testing"
|
||||||
|
// )
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Tests
|
||||||
|
*/
|
||||||
|
|
||||||
|
// func TestMain(t *testing.T) {
|
||||||
|
|
||||||
|
// }
|
151
forwardauth.go
151
forwardauth.go
@ -5,41 +5,31 @@ import (
|
|||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
// "encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
// "net/url"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/thomseddon/traefik-forward-auth/provider"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ForwardAuthContext int
|
||||||
|
|
||||||
|
const (
|
||||||
|
Nonce ForwardAuthContext = iota
|
||||||
|
Request
|
||||||
)
|
)
|
||||||
|
|
||||||
// Forward Auth
|
// Forward Auth
|
||||||
type ForwardAuth struct {
|
type ForwardAuth struct {
|
||||||
Path string
|
}
|
||||||
Lifetime time.Duration
|
|
||||||
Secret []byte
|
|
||||||
|
|
||||||
ClientId string
|
func NewForwardAuth() *ForwardAuth {
|
||||||
ClientSecret string `json:"-"`
|
return &ForwardAuth{}
|
||||||
Scope string
|
|
||||||
|
|
||||||
LoginURL *url.URL
|
|
||||||
TokenURL *url.URL
|
|
||||||
UserURL *url.URL
|
|
||||||
|
|
||||||
AuthHost string
|
|
||||||
|
|
||||||
CookieName string
|
|
||||||
CookieDomains []CookieDomain
|
|
||||||
CSRFCookieName string
|
|
||||||
CookieSecure bool
|
|
||||||
|
|
||||||
Domain []string
|
|
||||||
Whitelist []string
|
|
||||||
|
|
||||||
Prompt string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Request Validation
|
// Request Validation
|
||||||
@ -85,18 +75,18 @@ func (f *ForwardAuth) ValidateCookie(r *http.Request, c *http.Cookie) (bool, str
|
|||||||
// Validate email
|
// Validate email
|
||||||
func (f *ForwardAuth) ValidateEmail(email string) bool {
|
func (f *ForwardAuth) ValidateEmail(email string) bool {
|
||||||
found := false
|
found := false
|
||||||
if len(f.Whitelist) > 0 {
|
if len(config.Whitelist) > 0 {
|
||||||
for _, whitelist := range f.Whitelist {
|
for _, whitelist := range config.Whitelist {
|
||||||
if email == whitelist {
|
if email == whitelist {
|
||||||
found = true
|
found = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if len(f.Domain) > 0 {
|
} else if len(config.Domain) > 0 {
|
||||||
parts := strings.Split(email, "@")
|
parts := strings.Split(email, "@")
|
||||||
if len(parts) < 2 {
|
if len(parts) < 2 {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
for _, domain := range f.Domain {
|
for _, domain := range config.Domain {
|
||||||
if domain == parts[1] {
|
if domain == parts[1] {
|
||||||
found = true
|
found = true
|
||||||
}
|
}
|
||||||
@ -114,77 +104,24 @@ func (f *ForwardAuth) ValidateEmail(email string) bool {
|
|||||||
func (f *ForwardAuth) GetLoginURL(r *http.Request, nonce string) string {
|
func (f *ForwardAuth) GetLoginURL(r *http.Request, nonce string) string {
|
||||||
state := fmt.Sprintf("%s:%s", nonce, f.returnUrl(r))
|
state := fmt.Sprintf("%s:%s", nonce, f.returnUrl(r))
|
||||||
|
|
||||||
q := url.Values{}
|
// TODO: Support multiple providers
|
||||||
q.Set("client_id", fw.ClientId)
|
return config.Providers.Google.GetLoginURL(f.redirectUri(r), state)
|
||||||
q.Set("response_type", "code")
|
|
||||||
q.Set("scope", fw.Scope)
|
|
||||||
if fw.Prompt != "" {
|
|
||||||
q.Set("prompt", fw.Prompt)
|
|
||||||
}
|
|
||||||
q.Set("redirect_uri", f.redirectUri(r))
|
|
||||||
q.Set("state", state)
|
|
||||||
|
|
||||||
var u url.URL
|
|
||||||
u = *fw.LoginURL
|
|
||||||
u.RawQuery = q.Encode()
|
|
||||||
|
|
||||||
return u.String()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Exchange code for token
|
// Exchange code for token
|
||||||
|
|
||||||
type Token struct {
|
func (f *ForwardAuth) ExchangeCode(r *http.Request) (string, error) {
|
||||||
Token string `json:"access_token"`
|
code := r.URL.Query().Get("code")
|
||||||
}
|
|
||||||
|
|
||||||
func (f *ForwardAuth) ExchangeCode(r *http.Request, code string) (string, error) {
|
// TODO: Support multiple providers
|
||||||
form := url.Values{}
|
return config.Providers.Google.ExchangeCode(f.redirectUri(r), code)
|
||||||
form.Set("client_id", fw.ClientId)
|
|
||||||
form.Set("client_secret", fw.ClientSecret)
|
|
||||||
form.Set("grant_type", "authorization_code")
|
|
||||||
form.Set("redirect_uri", f.redirectUri(r))
|
|
||||||
form.Set("code", code)
|
|
||||||
|
|
||||||
res, err := http.PostForm(fw.TokenURL.String(), form)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
var token Token
|
|
||||||
defer res.Body.Close()
|
|
||||||
err = json.NewDecoder(res.Body).Decode(&token)
|
|
||||||
|
|
||||||
return token.Token, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get user with token
|
// Get user with token
|
||||||
|
|
||||||
type User struct {
|
func (f *ForwardAuth) GetUser(token string) (provider.User, error) {
|
||||||
Id string `json:"id"`
|
// TODO: Support multiple providers
|
||||||
Email string `json:"email"`
|
return config.Providers.Google.GetUser(token)
|
||||||
Verified bool `json:"verified_email"`
|
|
||||||
Hd string `json:"hd"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *ForwardAuth) GetUser(token string) (User, error) {
|
|
||||||
var user User
|
|
||||||
|
|
||||||
client := &http.Client{}
|
|
||||||
req, err := http.NewRequest("GET", fw.UserURL.String(), nil)
|
|
||||||
if err != nil {
|
|
||||||
return user, err
|
|
||||||
}
|
|
||||||
|
|
||||||
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", token))
|
|
||||||
res, err := client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return user, err
|
|
||||||
}
|
|
||||||
|
|
||||||
defer res.Body.Close()
|
|
||||||
err = json.NewDecoder(res.Body).Decode(&user)
|
|
||||||
|
|
||||||
return user, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Utility methods
|
// Utility methods
|
||||||
@ -197,7 +134,7 @@ func (f *ForwardAuth) redirectBase(r *http.Request) string {
|
|||||||
return fmt.Sprintf("%s://%s", proto, host)
|
return fmt.Sprintf("%s://%s", proto, host)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return url
|
// // Return url
|
||||||
func (f *ForwardAuth) returnUrl(r *http.Request) string {
|
func (f *ForwardAuth) returnUrl(r *http.Request) string {
|
||||||
path := r.Header.Get("X-Forwarded-Uri")
|
path := r.Header.Get("X-Forwarded-Uri")
|
||||||
|
|
||||||
@ -208,15 +145,15 @@ func (f *ForwardAuth) returnUrl(r *http.Request) string {
|
|||||||
func (f *ForwardAuth) redirectUri(r *http.Request) string {
|
func (f *ForwardAuth) redirectUri(r *http.Request) string {
|
||||||
if use, _ := f.useAuthDomain(r); use {
|
if use, _ := f.useAuthDomain(r); use {
|
||||||
proto := r.Header.Get("X-Forwarded-Proto")
|
proto := r.Header.Get("X-Forwarded-Proto")
|
||||||
return fmt.Sprintf("%s://%s%s", proto, f.AuthHost, f.Path)
|
return fmt.Sprintf("%s://%s%s", proto, config.AuthHost, config.Path)
|
||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Sprintf("%s%s", f.redirectBase(r), f.Path)
|
return fmt.Sprintf("%s%s", f.redirectBase(r), config.Path)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should we use auth host + what it is
|
// Should we use auth host + what it is
|
||||||
func (f *ForwardAuth) useAuthDomain(r *http.Request) (bool, string) {
|
func (f *ForwardAuth) useAuthDomain(r *http.Request) (bool, string) {
|
||||||
if f.AuthHost == "" {
|
if config.AuthHost == "" {
|
||||||
return false, ""
|
return false, ""
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -224,7 +161,7 @@ func (f *ForwardAuth) useAuthDomain(r *http.Request) (bool, string) {
|
|||||||
reqMatch, reqHost := f.matchCookieDomains(r.Header.Get("X-Forwarded-Host"))
|
reqMatch, reqHost := f.matchCookieDomains(r.Header.Get("X-Forwarded-Host"))
|
||||||
|
|
||||||
// Do any of the auth hosts match a cookie domain?
|
// Do any of the auth hosts match a cookie domain?
|
||||||
authMatch, authHost := f.matchCookieDomains(f.AuthHost)
|
authMatch, authHost := f.matchCookieDomains(config.AuthHost)
|
||||||
|
|
||||||
// We need both to match the same domain
|
// We need both to match the same domain
|
||||||
return reqMatch && authMatch && reqHost == authHost, reqHost
|
return reqMatch && authMatch && reqHost == authHost, reqHost
|
||||||
@ -239,12 +176,12 @@ func (f *ForwardAuth) MakeCookie(r *http.Request, email string) *http.Cookie {
|
|||||||
value := fmt.Sprintf("%s|%d|%s", mac, expires.Unix(), email)
|
value := fmt.Sprintf("%s|%d|%s", mac, expires.Unix(), email)
|
||||||
|
|
||||||
return &http.Cookie{
|
return &http.Cookie{
|
||||||
Name: f.CookieName,
|
Name: config.CookieName,
|
||||||
Value: value,
|
Value: value,
|
||||||
Path: "/",
|
Path: "/",
|
||||||
Domain: f.cookieDomain(r),
|
Domain: f.cookieDomain(r),
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
Secure: f.CookieSecure,
|
Secure: config.CookieSecure,
|
||||||
Expires: expires,
|
Expires: expires,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -252,12 +189,12 @@ func (f *ForwardAuth) MakeCookie(r *http.Request, email string) *http.Cookie {
|
|||||||
// Make a CSRF cookie (used during login only)
|
// Make a CSRF cookie (used during login only)
|
||||||
func (f *ForwardAuth) MakeCSRFCookie(r *http.Request, nonce string) *http.Cookie {
|
func (f *ForwardAuth) MakeCSRFCookie(r *http.Request, nonce string) *http.Cookie {
|
||||||
return &http.Cookie{
|
return &http.Cookie{
|
||||||
Name: f.CSRFCookieName,
|
Name: config.CSRFCookieName,
|
||||||
Value: nonce,
|
Value: nonce,
|
||||||
Path: "/",
|
Path: "/",
|
||||||
Domain: f.csrfCookieDomain(r),
|
Domain: f.csrfCookieDomain(r),
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
Secure: f.CookieSecure,
|
Secure: config.CookieSecure,
|
||||||
Expires: f.cookieExpiry(),
|
Expires: f.cookieExpiry(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -265,18 +202,20 @@ func (f *ForwardAuth) MakeCSRFCookie(r *http.Request, nonce string) *http.Cookie
|
|||||||
// Create a cookie to clear csrf cookie
|
// Create a cookie to clear csrf cookie
|
||||||
func (f *ForwardAuth) ClearCSRFCookie(r *http.Request) *http.Cookie {
|
func (f *ForwardAuth) ClearCSRFCookie(r *http.Request) *http.Cookie {
|
||||||
return &http.Cookie{
|
return &http.Cookie{
|
||||||
Name: f.CSRFCookieName,
|
Name: config.CSRFCookieName,
|
||||||
Value: "",
|
Value: "",
|
||||||
Path: "/",
|
Path: "/",
|
||||||
Domain: f.csrfCookieDomain(r),
|
Domain: f.csrfCookieDomain(r),
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
Secure: f.CookieSecure,
|
Secure: config.CookieSecure,
|
||||||
Expires: time.Now().Local().Add(time.Hour * -1),
|
Expires: time.Now().Local().Add(time.Hour * -1),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate the csrf cookie against state
|
// // Validate the csrf cookie against state
|
||||||
func (f *ForwardAuth) ValidateCSRFCookie(c *http.Cookie, state string) (bool, string, error) {
|
func (f *ForwardAuth) ValidateCSRFCookie(r *http.Request, c *http.Cookie) (bool, string, error) {
|
||||||
|
state := r.URL.Query().Get("state")
|
||||||
|
|
||||||
if len(c.Value) != 32 {
|
if len(c.Value) != 32 {
|
||||||
return false, "", errors.New("Invalid CSRF cookie value")
|
return false, "", errors.New("Invalid CSRF cookie value")
|
||||||
}
|
}
|
||||||
@ -333,7 +272,7 @@ func (f *ForwardAuth) matchCookieDomains(domain string) (bool, string) {
|
|||||||
// Remove port
|
// Remove port
|
||||||
p := strings.Split(domain, ":")
|
p := strings.Split(domain, ":")
|
||||||
|
|
||||||
for _, d := range f.CookieDomains {
|
for _, d := range config.CookieDomains {
|
||||||
if d.Match(p[0]) {
|
if d.Match(p[0]) {
|
||||||
return true, d.Domain
|
return true, d.Domain
|
||||||
}
|
}
|
||||||
@ -344,7 +283,7 @@ func (f *ForwardAuth) matchCookieDomains(domain string) (bool, string) {
|
|||||||
|
|
||||||
// Create cookie hmac
|
// Create cookie hmac
|
||||||
func (f *ForwardAuth) cookieSignature(r *http.Request, email, expires string) string {
|
func (f *ForwardAuth) cookieSignature(r *http.Request, email, expires string) string {
|
||||||
hash := hmac.New(sha256.New, f.Secret)
|
hash := hmac.New(sha256.New, config.SecretBytes)
|
||||||
hash.Write([]byte(f.cookieDomain(r)))
|
hash.Write([]byte(f.cookieDomain(r)))
|
||||||
hash.Write([]byte(email))
|
hash.Write([]byte(email))
|
||||||
hash.Write([]byte(expires))
|
hash.Write([]byte(expires))
|
||||||
@ -353,7 +292,7 @@ func (f *ForwardAuth) cookieSignature(r *http.Request, email, expires string) st
|
|||||||
|
|
||||||
// Get cookie expirary
|
// Get cookie expirary
|
||||||
func (f *ForwardAuth) cookieExpiry() time.Time {
|
func (f *ForwardAuth) cookieExpiry() time.Time {
|
||||||
return time.Now().Local().Add(f.Lifetime)
|
return time.Now().Local().Add(config.Lifetime)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cookie Domain
|
// Cookie Domain
|
||||||
|
@ -1,16 +1,30 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
// "fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/thomseddon/traefik-forward-auth/provider"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestValidateCookie(t *testing.T) {
|
/**
|
||||||
|
* Setup
|
||||||
|
*/
|
||||||
|
|
||||||
|
func init() {
|
||||||
fw = &ForwardAuth{}
|
fw = &ForwardAuth{}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Tests
|
||||||
|
*/
|
||||||
|
|
||||||
|
func TestValidateCookie(t *testing.T) {
|
||||||
|
config = &Config{}
|
||||||
r, _ := http.NewRequest("GET", "http://example.com", nil)
|
r, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||||
c := &http.Cookie{}
|
c := &http.Cookie{}
|
||||||
|
|
||||||
@ -39,7 +53,7 @@ func TestValidateCookie(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Should catch expired
|
// Should catch expired
|
||||||
fw.Lifetime = time.Second * time.Duration(-1)
|
config.Lifetime = time.Second * time.Duration(-1)
|
||||||
c = fw.MakeCookie(r, "test@test.com")
|
c = fw.MakeCookie(r, "test@test.com")
|
||||||
valid, _, err = fw.ValidateCookie(r, c)
|
valid, _, err = fw.ValidateCookie(r, c)
|
||||||
if valid || err.Error() != "Cookie has expired" {
|
if valid || err.Error() != "Cookie has expired" {
|
||||||
@ -47,7 +61,7 @@ func TestValidateCookie(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Should accept valid cookie
|
// Should accept valid cookie
|
||||||
fw.Lifetime = time.Second * time.Duration(10)
|
config.Lifetime = time.Second * time.Duration(10)
|
||||||
c = fw.MakeCookie(r, "test@test.com")
|
c = fw.MakeCookie(r, "test@test.com")
|
||||||
valid, email, err := fw.ValidateCookie(r, c)
|
valid, email, err := fw.ValidateCookie(r, c)
|
||||||
if !valid {
|
if !valid {
|
||||||
@ -62,7 +76,7 @@ func TestValidateCookie(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateEmail(t *testing.T) {
|
func TestValidateEmail(t *testing.T) {
|
||||||
fw = &ForwardAuth{}
|
config = &Config{}
|
||||||
|
|
||||||
// Should allow any
|
// Should allow any
|
||||||
if !fw.ValidateEmail("test@test.com") || !fw.ValidateEmail("one@two.com") {
|
if !fw.ValidateEmail("test@test.com") || !fw.ValidateEmail("one@two.com") {
|
||||||
@ -70,27 +84,27 @@ func TestValidateEmail(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Should block non matching domain
|
// Should block non matching domain
|
||||||
fw.Domain = []string{"test.com"}
|
config.Domain = []string{"test.com"}
|
||||||
if fw.ValidateEmail("one@two.com") {
|
if fw.ValidateEmail("one@two.com") {
|
||||||
t.Error("Should not allow user from another domain")
|
t.Error("Should not allow user from another domain")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should allow matching domain
|
// Should allow matching domain
|
||||||
fw.Domain = []string{"test.com"}
|
config.Domain = []string{"test.com"}
|
||||||
if !fw.ValidateEmail("test@test.com") {
|
if !fw.ValidateEmail("test@test.com") {
|
||||||
t.Error("Should allow user from allowed domain")
|
t.Error("Should allow user from allowed domain")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should block non whitelisted email address
|
// Should block non whitelisted email address
|
||||||
fw.Domain = []string{}
|
config.Domain = []string{}
|
||||||
fw.Whitelist = []string{"test@test.com"}
|
config.Whitelist = []string{"test@test.com"}
|
||||||
if fw.ValidateEmail("one@two.com") {
|
if fw.ValidateEmail("one@two.com") {
|
||||||
t.Error("Should not allow user not in whitelist.")
|
t.Error("Should not allow user not in whitelist.")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should allow matching whitelisted email address
|
// Should allow matching whitelisted email address
|
||||||
fw.Domain = []string{}
|
config.Domain = []string{}
|
||||||
fw.Whitelist = []string{"test@test.com"}
|
config.Whitelist = []string{"test@test.com"}
|
||||||
if !fw.ValidateEmail("test@test.com") {
|
if !fw.ValidateEmail("test@test.com") {
|
||||||
t.Error("Should allow user in whitelist.")
|
t.Error("Should allow user in whitelist.")
|
||||||
}
|
}
|
||||||
@ -102,8 +116,10 @@ func TestGetLoginURL(t *testing.T) {
|
|||||||
r.Header.Add("X-Forwarded-Host", "example.com")
|
r.Header.Add("X-Forwarded-Host", "example.com")
|
||||||
r.Header.Add("X-Forwarded-Uri", "/hello")
|
r.Header.Add("X-Forwarded-Uri", "/hello")
|
||||||
|
|
||||||
fw = &ForwardAuth{
|
config = &Config{
|
||||||
Path: "/_oauth",
|
Path: "/_oauth",
|
||||||
|
Providers: provider.Providers{
|
||||||
|
Google: provider.Google{
|
||||||
ClientId: "idtest",
|
ClientId: "idtest",
|
||||||
ClientSecret: "sectest",
|
ClientSecret: "sectest",
|
||||||
Scope: "scopetest",
|
Scope: "scopetest",
|
||||||
@ -112,6 +128,8 @@ func TestGetLoginURL(t *testing.T) {
|
|||||||
Host: "test.com",
|
Host: "test.com",
|
||||||
Path: "/auth",
|
Path: "/auth",
|
||||||
},
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check url
|
// Check url
|
||||||
@ -147,18 +165,22 @@ func TestGetLoginURL(t *testing.T) {
|
|||||||
// With Auth URL but no matching cookie domain
|
// With Auth URL but no matching cookie domain
|
||||||
// - will not use auth host
|
// - will not use auth host
|
||||||
//
|
//
|
||||||
fw = &ForwardAuth{
|
config = &Config{
|
||||||
Path: "/_oauth",
|
Path: "/_oauth",
|
||||||
AuthHost: "auth.example.com",
|
AuthHost: "auth.example.com",
|
||||||
|
Providers: provider.Providers{
|
||||||
|
Google: provider.Google{
|
||||||
ClientId: "idtest",
|
ClientId: "idtest",
|
||||||
ClientSecret: "sectest",
|
ClientSecret: "sectest",
|
||||||
Scope: "scopetest",
|
Scope: "scopetest",
|
||||||
|
Prompt: "consent select_account",
|
||||||
LoginURL: &url.URL{
|
LoginURL: &url.URL{
|
||||||
Scheme: "https",
|
Scheme: "https",
|
||||||
Host: "test.com",
|
Host: "test.com",
|
||||||
Path: "/auth",
|
Path: "/auth",
|
||||||
},
|
},
|
||||||
Prompt: "consent select_account",
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check url
|
// Check url
|
||||||
@ -195,18 +217,23 @@ func TestGetLoginURL(t *testing.T) {
|
|||||||
// With correct Auth URL + cookie domain
|
// With correct Auth URL + cookie domain
|
||||||
//
|
//
|
||||||
cookieDomain := NewCookieDomain("example.com")
|
cookieDomain := NewCookieDomain("example.com")
|
||||||
fw = &ForwardAuth{
|
config = &Config{
|
||||||
Path: "/_oauth",
|
Path: "/_oauth",
|
||||||
AuthHost: "auth.example.com",
|
AuthHost: "auth.example.com",
|
||||||
|
CookieDomains: []CookieDomain{*cookieDomain},
|
||||||
|
Providers: provider.Providers{
|
||||||
|
Google: provider.Google{
|
||||||
ClientId: "idtest",
|
ClientId: "idtest",
|
||||||
ClientSecret: "sectest",
|
ClientSecret: "sectest",
|
||||||
Scope: "scopetest",
|
Scope: "scopetest",
|
||||||
|
Prompt: "consent select_account",
|
||||||
LoginURL: &url.URL{
|
LoginURL: &url.URL{
|
||||||
Scheme: "https",
|
Scheme: "https",
|
||||||
Host: "test.com",
|
Host: "test.com",
|
||||||
Path: "/auth",
|
Path: "/auth",
|
||||||
},
|
},
|
||||||
CookieDomains: []CookieDomain{*cookieDomain},
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check url
|
// Check url
|
||||||
@ -232,6 +259,7 @@ func TestGetLoginURL(t *testing.T) {
|
|||||||
"response_type": []string{"code"},
|
"response_type": []string{"code"},
|
||||||
"scope": []string{"scopetest"},
|
"scope": []string{"scopetest"},
|
||||||
"state": []string{"nonce:http://example.com/hello"},
|
"state": []string{"nonce:http://example.com/hello"},
|
||||||
|
"prompt": []string{"consent select_account"},
|
||||||
}
|
}
|
||||||
qsDiff(expectedQs, qs)
|
qsDiff(expectedQs, qs)
|
||||||
if !reflect.DeepEqual(qs, expectedQs) {
|
if !reflect.DeepEqual(qs, expectedQs) {
|
||||||
@ -271,6 +299,7 @@ func TestGetLoginURL(t *testing.T) {
|
|||||||
"response_type": []string{"code"},
|
"response_type": []string{"code"},
|
||||||
"scope": []string{"scopetest"},
|
"scope": []string{"scopetest"},
|
||||||
"state": []string{"nonce:http://another.com/hello"},
|
"state": []string{"nonce:http://another.com/hello"},
|
||||||
|
"prompt": []string{"consent select_account"},
|
||||||
}
|
}
|
||||||
qsDiff(expectedQs, qs)
|
qsDiff(expectedQs, qs)
|
||||||
if !reflect.DeepEqual(qs, expectedQs) {
|
if !reflect.DeepEqual(qs, expectedQs) {
|
||||||
@ -292,11 +321,11 @@ func TestGetLoginURL(t *testing.T) {
|
|||||||
// }
|
// }
|
||||||
|
|
||||||
func TestMakeCSRFCookie(t *testing.T) {
|
func TestMakeCSRFCookie(t *testing.T) {
|
||||||
|
config = &Config{}
|
||||||
r, _ := http.NewRequest("GET", "http://app.example.com", nil)
|
r, _ := http.NewRequest("GET", "http://app.example.com", nil)
|
||||||
r.Header.Add("X-Forwarded-Host", "app.example.com")
|
r.Header.Add("X-Forwarded-Host", "app.example.com")
|
||||||
|
|
||||||
// No cookie domain or auth url
|
// No cookie domain or auth url
|
||||||
fw = &ForwardAuth{}
|
|
||||||
c := fw.MakeCSRFCookie(r, "12345678901234567890123456789012")
|
c := fw.MakeCSRFCookie(r, "12345678901234567890123456789012")
|
||||||
if c.Domain != "app.example.com" {
|
if c.Domain != "app.example.com" {
|
||||||
t.Error("Cookie Domain should match request domain, got:", c.Domain)
|
t.Error("Cookie Domain should match request domain, got:", c.Domain)
|
||||||
@ -304,14 +333,16 @@ func TestMakeCSRFCookie(t *testing.T) {
|
|||||||
|
|
||||||
// With cookie domain but no auth url
|
// With cookie domain but no auth url
|
||||||
cookieDomain := NewCookieDomain("example.com")
|
cookieDomain := NewCookieDomain("example.com")
|
||||||
fw = &ForwardAuth{CookieDomains: []CookieDomain{*cookieDomain}}
|
config = &Config{
|
||||||
|
CookieDomains: []CookieDomain{*cookieDomain},
|
||||||
|
}
|
||||||
c = fw.MakeCSRFCookie(r, "12345678901234567890123456789012")
|
c = fw.MakeCSRFCookie(r, "12345678901234567890123456789012")
|
||||||
if c.Domain != "app.example.com" {
|
if c.Domain != "app.example.com" {
|
||||||
t.Error("Cookie Domain should match request domain, got:", c.Domain)
|
t.Error("Cookie Domain should match request domain, got:", c.Domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
// With cookie domain and auth url
|
// With cookie domain and auth url
|
||||||
fw = &ForwardAuth{
|
config = &Config{
|
||||||
AuthHost: "auth.example.com",
|
AuthHost: "auth.example.com",
|
||||||
CookieDomains: []CookieDomain{*cookieDomain},
|
CookieDomains: []CookieDomain{*cookieDomain},
|
||||||
}
|
}
|
||||||
@ -322,7 +353,7 @@ func TestMakeCSRFCookie(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestClearCSRFCookie(t *testing.T) {
|
func TestClearCSRFCookie(t *testing.T) {
|
||||||
fw = &ForwardAuth{}
|
config = &Config{}
|
||||||
r, _ := http.NewRequest("GET", "http://example.com", nil)
|
r, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||||
|
|
||||||
c := fw.ClearCSRFCookie(r)
|
c := fw.ClearCSRFCookie(r)
|
||||||
@ -332,31 +363,40 @@ func TestClearCSRFCookie(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateCSRFCookie(t *testing.T) {
|
func TestValidateCSRFCookie(t *testing.T) {
|
||||||
fw = &ForwardAuth{}
|
config = &Config{}
|
||||||
c := &http.Cookie{}
|
c := &http.Cookie{}
|
||||||
|
|
||||||
|
newCsrfRequest := func(state string) *http.Request {
|
||||||
|
u := fmt.Sprintf("http://example.com?state=%s", state)
|
||||||
|
r, _ := http.NewRequest("GET", u, nil)
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
// Should require 32 char string
|
// Should require 32 char string
|
||||||
|
r := newCsrfRequest("")
|
||||||
c.Value = ""
|
c.Value = ""
|
||||||
valid, _, err := fw.ValidateCSRFCookie(c, "")
|
valid, _, err := fw.ValidateCSRFCookie(r, c)
|
||||||
if valid || err.Error() != "Invalid CSRF cookie value" {
|
if valid || err.Error() != "Invalid CSRF cookie value" {
|
||||||
t.Error("Should get \"Invalid CSRF cookie value\", got:", err)
|
t.Error("Should get \"Invalid CSRF cookie value\", got:", err)
|
||||||
}
|
}
|
||||||
c.Value = "123456789012345678901234567890123"
|
c.Value = "123456789012345678901234567890123"
|
||||||
valid, _, err = fw.ValidateCSRFCookie(c, "")
|
valid, _, err = fw.ValidateCSRFCookie(r, c)
|
||||||
if valid || err.Error() != "Invalid CSRF cookie value" {
|
if valid || err.Error() != "Invalid CSRF cookie value" {
|
||||||
t.Error("Should get \"Invalid CSRF cookie value\", got:", err)
|
t.Error("Should get \"Invalid CSRF cookie value\", got:", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should require valid state
|
// Should require valid state
|
||||||
|
r = newCsrfRequest("12345678901234567890123456789012:")
|
||||||
c.Value = "12345678901234567890123456789012"
|
c.Value = "12345678901234567890123456789012"
|
||||||
valid, _, err = fw.ValidateCSRFCookie(c, "12345678901234567890123456789012:")
|
valid, _, err = fw.ValidateCSRFCookie(r, c)
|
||||||
if valid || err.Error() != "Invalid CSRF state value" {
|
if valid || err.Error() != "Invalid CSRF state value" {
|
||||||
t.Error("Should get \"Invalid CSRF state value\", got:", err)
|
t.Error("Should get \"Invalid CSRF state value\", got:", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should allow valid state
|
// Should allow valid state
|
||||||
|
r = newCsrfRequest("12345678901234567890123456789012:99")
|
||||||
c.Value = "12345678901234567890123456789012"
|
c.Value = "12345678901234567890123456789012"
|
||||||
valid, state, err := fw.ValidateCSRFCookie(c, "12345678901234567890123456789012:99")
|
valid, state, err := fw.ValidateCSRFCookie(r, c)
|
||||||
if !valid {
|
if !valid {
|
||||||
t.Error("Valid request should return as valid")
|
t.Error("Valid request should return as valid")
|
||||||
}
|
}
|
||||||
@ -369,8 +409,6 @@ func TestValidateCSRFCookie(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestNonce(t *testing.T) {
|
func TestNonce(t *testing.T) {
|
||||||
fw = &ForwardAuth{}
|
|
||||||
|
|
||||||
err, nonce1 := fw.Nonce()
|
err, nonce1 := fw.Nonce()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error("Error generation nonce:", err)
|
t.Error("Error generation nonce:", err)
|
||||||
|
6
log.go
6
log.go
@ -6,13 +6,13 @@ import (
|
|||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
func CreateLogger(logLevel, logFormat string) logrus.FieldLogger {
|
func NewLogger() logrus.FieldLogger {
|
||||||
// Setup logger
|
// Setup logger
|
||||||
log := logrus.StandardLogger()
|
log := logrus.StandardLogger()
|
||||||
logrus.SetOutput(os.Stdout)
|
logrus.SetOutput(os.Stdout)
|
||||||
|
|
||||||
// Set logger format
|
// Set logger format
|
||||||
switch logFormat {
|
switch *config.LogFormat {
|
||||||
case "pretty":
|
case "pretty":
|
||||||
break
|
break
|
||||||
case "json":
|
case "json":
|
||||||
@ -26,7 +26,7 @@ func CreateLogger(logLevel, logFormat string) logrus.FieldLogger {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Set logger level
|
// Set logger level
|
||||||
switch logLevel {
|
switch *config.LogLevel {
|
||||||
case "trace":
|
case "trace":
|
||||||
logrus.SetLevel(logrus.TraceLevel)
|
logrus.SetLevel(logrus.TraceLevel)
|
||||||
case "debug":
|
case "debug":
|
||||||
|
225
main.go
225
main.go
@ -2,237 +2,38 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/namsral/flag"
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Vars
|
// Vars
|
||||||
var fw *ForwardAuth
|
var fw *ForwardAuth
|
||||||
var log logrus.FieldLogger
|
var log logrus.FieldLogger
|
||||||
|
var config *Config
|
||||||
// Primary handler
|
|
||||||
func handler(w http.ResponseWriter, r *http.Request) {
|
|
||||||
// Logging setup
|
|
||||||
logger := log.WithFields(logrus.Fields{
|
|
||||||
"SourceIP": r.Header.Get("X-Forwarded-For"),
|
|
||||||
})
|
|
||||||
logger.WithFields(logrus.Fields{
|
|
||||||
"Headers": r.Header,
|
|
||||||
}).Debugf("Handling request")
|
|
||||||
|
|
||||||
// Parse uri
|
|
||||||
uri, err := url.Parse(r.Header.Get("X-Forwarded-Uri"))
|
|
||||||
if err != nil {
|
|
||||||
logger.Errorf("Error parsing X-Forwarded-Uri, %v", err)
|
|
||||||
http.Error(w, "Service unavailable", 503)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle callback
|
|
||||||
if uri.Path == fw.Path {
|
|
||||||
logger.Debugf("Passing request to auth callback")
|
|
||||||
handleCallback(w, r, uri.Query(), logger)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get auth cookie
|
|
||||||
c, err := r.Cookie(fw.CookieName)
|
|
||||||
if err != nil {
|
|
||||||
// Error indicates no cookie, generate nonce
|
|
||||||
err, nonce := fw.Nonce()
|
|
||||||
if err != nil {
|
|
||||||
logger.Errorf("Error generating nonce, %v", err)
|
|
||||||
http.Error(w, "Service unavailable", 503)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set the CSRF cookie
|
|
||||||
http.SetCookie(w, fw.MakeCSRFCookie(r, nonce))
|
|
||||||
logger.Debug("Set CSRF cookie and redirecting to google login")
|
|
||||||
|
|
||||||
// Forward them on
|
|
||||||
http.Redirect(w, r, fw.GetLoginURL(r, nonce), http.StatusTemporaryRedirect)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate cookie
|
|
||||||
valid, email, err := fw.ValidateCookie(r, c)
|
|
||||||
if !valid {
|
|
||||||
logger.Errorf("Invalid cookie: %v", err)
|
|
||||||
http.Error(w, "Not authorized", 401)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate user
|
|
||||||
valid = fw.ValidateEmail(email)
|
|
||||||
if !valid {
|
|
||||||
logger.WithFields(logrus.Fields{
|
|
||||||
"email": email,
|
|
||||||
}).Errorf("Invalid email")
|
|
||||||
http.Error(w, "Not authorized", 401)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Valid request
|
|
||||||
logger.Debugf("Allowing valid request ")
|
|
||||||
w.Header().Set("X-Forwarded-User", email)
|
|
||||||
w.WriteHeader(200)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Authenticate user after they have come back from google
|
|
||||||
func handleCallback(w http.ResponseWriter, r *http.Request, qs url.Values,
|
|
||||||
logger logrus.FieldLogger) {
|
|
||||||
// Check for CSRF cookie
|
|
||||||
csrfCookie, err := r.Cookie(fw.CSRFCookieName)
|
|
||||||
if err != nil {
|
|
||||||
logger.Warn("Missing csrf cookie")
|
|
||||||
http.Error(w, "Not authorized", 401)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate state
|
|
||||||
state := qs.Get("state")
|
|
||||||
valid, redirect, err := fw.ValidateCSRFCookie(csrfCookie, state)
|
|
||||||
if !valid {
|
|
||||||
logger.WithFields(logrus.Fields{
|
|
||||||
"csrf": csrfCookie.Value,
|
|
||||||
"state": state,
|
|
||||||
}).Warnf("Error validating csrf cookie: %v", err)
|
|
||||||
http.Error(w, "Not authorized", 401)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clear CSRF cookie
|
|
||||||
http.SetCookie(w, fw.ClearCSRFCookie(r))
|
|
||||||
|
|
||||||
// Exchange code for token
|
|
||||||
token, err := fw.ExchangeCode(r, qs.Get("code"))
|
|
||||||
if err != nil {
|
|
||||||
logger.Errorf("Code exchange failed with: %v", err)
|
|
||||||
http.Error(w, "Service unavailable", 503)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get user
|
|
||||||
user, err := fw.GetUser(token)
|
|
||||||
if err != nil {
|
|
||||||
logger.Errorf("Error getting user: %s", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate cookie
|
|
||||||
http.SetCookie(w, fw.MakeCookie(r, user.Email))
|
|
||||||
logger.WithFields(logrus.Fields{
|
|
||||||
"user": user.Email,
|
|
||||||
}).Infof("Generated auth cookie")
|
|
||||||
|
|
||||||
// Redirect
|
|
||||||
http.Redirect(w, r, redirect, http.StatusTemporaryRedirect)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Main
|
// Main
|
||||||
func main() {
|
func main() {
|
||||||
// Parse options
|
// Parse config
|
||||||
flag.String(flag.DefaultConfigFlagname, "", "Path to config file")
|
config = NewParsedConfig()
|
||||||
path := flag.String("url-path", "_oauth", "Callback URL")
|
|
||||||
lifetime := flag.Int("lifetime", 43200, "Session length in seconds")
|
|
||||||
secret := flag.String("secret", "", "*Secret used for signing (required)")
|
|
||||||
authHost := flag.String("auth-host", "", "Central auth login")
|
|
||||||
clientId := flag.String("client-id", "", "*Google Client ID (required)")
|
|
||||||
clientSecret := flag.String("client-secret", "", "*Google Client Secret (required)")
|
|
||||||
cookieName := flag.String("cookie-name", "_forward_auth", "Cookie Name")
|
|
||||||
cSRFCookieName := flag.String("csrf-cookie-name", "_forward_auth_csrf", "CSRF Cookie Name")
|
|
||||||
cookieDomainList := flag.String("cookie-domains", "", "Comma separated list of cookie domains") //todo
|
|
||||||
cookieSecret := flag.String("cookie-secret", "", "Deprecated")
|
|
||||||
cookieSecure := flag.Bool("cookie-secure", true, "Use secure cookies")
|
|
||||||
domainList := flag.String("domain", "", "Comma separated list of email domains to allow")
|
|
||||||
emailWhitelist := flag.String("whitelist", "", "Comma separated list of emails to allow")
|
|
||||||
prompt := flag.String("prompt", "", "Space separated list of OpenID prompt options")
|
|
||||||
logLevel := flag.String("log-level", "warn", "Log level: trace, debug, info, warn, error, fatal, panic")
|
|
||||||
logFormat := flag.String("log-format", "text", "Log format: text, json, pretty")
|
|
||||||
|
|
||||||
flag.Parse()
|
|
||||||
|
|
||||||
// Setup logger
|
// Setup logger
|
||||||
log = CreateLogger(*logLevel, *logFormat)
|
log = NewLogger()
|
||||||
|
|
||||||
// Backwards compatibility
|
// Perform config checks
|
||||||
if *secret == "" && *cookieSecret != "" {
|
config.Checks()
|
||||||
*secret = *cookieSecret
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for show stopper errors
|
// Build forward auth handler
|
||||||
if *clientId == "" || *clientSecret == "" || *secret == "" {
|
fw = NewForwardAuth()
|
||||||
log.Fatal("client-id, client-secret and secret must all be set")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse lists
|
// Build server
|
||||||
var cookieDomains []CookieDomain
|
server := NewServer()
|
||||||
if *cookieDomainList != "" {
|
|
||||||
for _, d := range strings.Split(*cookieDomainList, ",") {
|
|
||||||
cookieDomain := NewCookieDomain(d)
|
|
||||||
cookieDomains = append(cookieDomains, *cookieDomain)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var domain []string
|
// Attach router to default server
|
||||||
if *domainList != "" {
|
http.HandleFunc("/", server.RootHandler)
|
||||||
domain = strings.Split(*domainList, ",")
|
|
||||||
}
|
|
||||||
var whitelist []string
|
|
||||||
if *emailWhitelist != "" {
|
|
||||||
whitelist = strings.Split(*emailWhitelist, ",")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Setup
|
|
||||||
fw = &ForwardAuth{
|
|
||||||
Path: fmt.Sprintf("/%s", *path),
|
|
||||||
Lifetime: time.Second * time.Duration(*lifetime),
|
|
||||||
Secret: []byte(*secret),
|
|
||||||
AuthHost: *authHost,
|
|
||||||
|
|
||||||
ClientId: *clientId,
|
|
||||||
ClientSecret: *clientSecret,
|
|
||||||
Scope: "https://www.googleapis.com/auth/userinfo.profile https://www.googleapis.com/auth/userinfo.email",
|
|
||||||
LoginURL: &url.URL{
|
|
||||||
Scheme: "https",
|
|
||||||
Host: "accounts.google.com",
|
|
||||||
Path: "/o/oauth2/auth",
|
|
||||||
},
|
|
||||||
TokenURL: &url.URL{
|
|
||||||
Scheme: "https",
|
|
||||||
Host: "www.googleapis.com",
|
|
||||||
Path: "/oauth2/v3/token",
|
|
||||||
},
|
|
||||||
UserURL: &url.URL{
|
|
||||||
Scheme: "https",
|
|
||||||
Host: "www.googleapis.com",
|
|
||||||
Path: "/oauth2/v2/userinfo",
|
|
||||||
},
|
|
||||||
|
|
||||||
CookieName: *cookieName,
|
|
||||||
CSRFCookieName: *cSRFCookieName,
|
|
||||||
CookieDomains: cookieDomains,
|
|
||||||
CookieSecure: *cookieSecure,
|
|
||||||
|
|
||||||
Domain: domain,
|
|
||||||
Whitelist: whitelist,
|
|
||||||
|
|
||||||
Prompt: *prompt,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Attach handler
|
|
||||||
http.HandleFunc("/", handler)
|
|
||||||
|
|
||||||
// Start
|
// Start
|
||||||
jsonConf, _ := json.Marshal(fw)
|
jsonConf, _ := json.Marshal(config)
|
||||||
log.Debugf("Starting with options: %s", string(jsonConf))
|
log.Debugf("Starting with options: %s", string(jsonConf))
|
||||||
log.Info("Listening on :4181")
|
log.Info("Listening on :4181")
|
||||||
log.Info(http.ListenAndServe(":4181", nil))
|
log.Info(http.ListenAndServe(":4181", nil))
|
||||||
|
207
main_test.go
207
main_test.go
@ -1,210 +1,13 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
// import (
|
||||||
"fmt"
|
// "testing"
|
||||||
"time"
|
// )
|
||||||
// "reflect"
|
|
||||||
"io/ioutil"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"net/url"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Utilities
|
|
||||||
*/
|
|
||||||
|
|
||||||
type TokenServerHandler struct{}
|
|
||||||
|
|
||||||
func (t *TokenServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
||||||
fmt.Fprint(w, `{"access_token":"123456789"}`)
|
|
||||||
}
|
|
||||||
|
|
||||||
type UserServerHandler struct{}
|
|
||||||
|
|
||||||
func (t *UserServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
||||||
fmt.Fprint(w, `{
|
|
||||||
"id":"1",
|
|
||||||
"email":"example@example.com",
|
|
||||||
"verified_email":true,
|
|
||||||
"hd":"example.com"
|
|
||||||
}`)
|
|
||||||
}
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
log = CreateLogger("panic", "")
|
|
||||||
}
|
|
||||||
|
|
||||||
func httpRequest(r *http.Request, c *http.Cookie) (*http.Response, string) {
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
|
|
||||||
// Set cookies on recorder
|
|
||||||
if c != nil {
|
|
||||||
http.SetCookie(w, c)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Copy into request
|
|
||||||
for _, c := range w.HeaderMap["Set-Cookie"] {
|
|
||||||
r.Header.Add("Cookie", c)
|
|
||||||
}
|
|
||||||
|
|
||||||
handler(w, r)
|
|
||||||
|
|
||||||
res := w.Result()
|
|
||||||
body, _ := ioutil.ReadAll(res.Body)
|
|
||||||
|
|
||||||
return res, string(body)
|
|
||||||
}
|
|
||||||
|
|
||||||
func newHttpRequest(uri string) *http.Request {
|
|
||||||
r := httptest.NewRequest("", "http://example.com", nil)
|
|
||||||
r.Header.Add("X-Forwarded-Uri", uri)
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
func qsDiff(one, two url.Values) {
|
|
||||||
for k := range one {
|
|
||||||
if two.Get(k) == "" {
|
|
||||||
fmt.Printf("Key missing: %s\n", k)
|
|
||||||
}
|
|
||||||
if one.Get(k) != two.Get(k) {
|
|
||||||
fmt.Printf("Value different for %s: expected: '%s' got: '%s'\n", k, one.Get(k), two.Get(k))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for k := range two {
|
|
||||||
if one.Get(k) == "" {
|
|
||||||
fmt.Printf("Extra key: %s\n", k)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Tests
|
* Tests
|
||||||
*/
|
*/
|
||||||
|
|
||||||
func TestHandler(t *testing.T) {
|
// func TestMain(t *testing.T) {
|
||||||
fw = &ForwardAuth{
|
|
||||||
Path: "_oauth",
|
|
||||||
ClientId: "idtest",
|
|
||||||
ClientSecret: "sectest",
|
|
||||||
Scope: "scopetest",
|
|
||||||
LoginURL: &url.URL{
|
|
||||||
Scheme: "http",
|
|
||||||
Host: "test.com",
|
|
||||||
Path: "/auth",
|
|
||||||
},
|
|
||||||
CookieName: "cookie_test",
|
|
||||||
Lifetime: time.Second * time.Duration(10),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Should redirect vanilla request to login url
|
// }
|
||||||
req := newHttpRequest("foo")
|
|
||||||
res, _ := httpRequest(req, nil)
|
|
||||||
if res.StatusCode != 307 {
|
|
||||||
t.Error("Vanilla request should be redirected with 307, got:", res.StatusCode)
|
|
||||||
}
|
|
||||||
fwd, _ := res.Location()
|
|
||||||
if fwd.Scheme != "http" || fwd.Host != "test.com" || fwd.Path != "/auth" {
|
|
||||||
t.Error("Vanilla request should be redirected to login url, got:", fwd)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Should catch invalid cookie
|
|
||||||
req = newHttpRequest("foo")
|
|
||||||
|
|
||||||
c := fw.MakeCookie(req, "test@example.com")
|
|
||||||
parts := strings.Split(c.Value, "|")
|
|
||||||
c.Value = fmt.Sprintf("bad|%s|%s", parts[1], parts[2])
|
|
||||||
|
|
||||||
res, _ = httpRequest(req, c)
|
|
||||||
if res.StatusCode != 401 {
|
|
||||||
t.Error("Request with invalid cookie shound't be authorised", res.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Should validate email
|
|
||||||
req = newHttpRequest("foo")
|
|
||||||
|
|
||||||
c = fw.MakeCookie(req, "test@example.com")
|
|
||||||
fw.Domain = []string{"test.com"}
|
|
||||||
|
|
||||||
res, _ = httpRequest(req, c)
|
|
||||||
if res.StatusCode != 401 {
|
|
||||||
t.Error("Request with invalid cookie shound't be authorised", res.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Should allow valid request email
|
|
||||||
req = newHttpRequest("foo")
|
|
||||||
|
|
||||||
c = fw.MakeCookie(req, "test@example.com")
|
|
||||||
fw.Domain = []string{}
|
|
||||||
|
|
||||||
res, _ = httpRequest(req, c)
|
|
||||||
if res.StatusCode != 200 {
|
|
||||||
t.Error("Valid request should be allowed, got:", res.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Should pass through user
|
|
||||||
users := res.Header["X-Forwarded-User"]
|
|
||||||
if len(users) != 1 {
|
|
||||||
t.Error("Valid request missing X-Forwarded-User header")
|
|
||||||
} else if users[0] != "test@example.com" {
|
|
||||||
t.Error("X-Forwarded-User should match user, got: ", users)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCallback(t *testing.T) {
|
|
||||||
fw = &ForwardAuth{
|
|
||||||
Path: "_oauth",
|
|
||||||
ClientId: "idtest",
|
|
||||||
ClientSecret: "sectest",
|
|
||||||
Scope: "scopetest",
|
|
||||||
LoginURL: &url.URL{
|
|
||||||
Scheme: "http",
|
|
||||||
Host: "test.com",
|
|
||||||
Path: "/auth",
|
|
||||||
},
|
|
||||||
CSRFCookieName: "csrf_test",
|
|
||||||
}
|
|
||||||
|
|
||||||
// Setup token server
|
|
||||||
tokenServerHandler := &TokenServerHandler{}
|
|
||||||
tokenServer := httptest.NewServer(tokenServerHandler)
|
|
||||||
defer tokenServer.Close()
|
|
||||||
tokenUrl, _ := url.Parse(tokenServer.URL)
|
|
||||||
fw.TokenURL = tokenUrl
|
|
||||||
|
|
||||||
// Setup user server
|
|
||||||
userServerHandler := &UserServerHandler{}
|
|
||||||
userServer := httptest.NewServer(userServerHandler)
|
|
||||||
defer userServer.Close()
|
|
||||||
userUrl, _ := url.Parse(userServer.URL)
|
|
||||||
fw.UserURL = userUrl
|
|
||||||
|
|
||||||
// Should pass auth response request to callback
|
|
||||||
req := newHttpRequest("_oauth")
|
|
||||||
res, _ := httpRequest(req, nil)
|
|
||||||
if res.StatusCode != 401 {
|
|
||||||
t.Error("Auth callback without cookie shound't be authorised, got:", res.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Should catch invalid csrf cookie
|
|
||||||
req = newHttpRequest("_oauth?state=12345678901234567890123456789012:http://redirect")
|
|
||||||
c := fw.MakeCSRFCookie(req, "nononononononononononononononono")
|
|
||||||
res, _ = httpRequest(req, c)
|
|
||||||
if res.StatusCode != 401 {
|
|
||||||
t.Error("Auth callback with invalid cookie shound't be authorised, got:", res.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Should redirect valid request
|
|
||||||
req = newHttpRequest("_oauth?state=12345678901234567890123456789012:http://redirect")
|
|
||||||
c = fw.MakeCSRFCookie(req, "12345678901234567890123456789012")
|
|
||||||
res, _ = httpRequest(req, c)
|
|
||||||
if res.StatusCode != 307 {
|
|
||||||
t.Error("Valid callback should be allowed, got:", res.StatusCode)
|
|
||||||
}
|
|
||||||
fwd, _ := res.Location()
|
|
||||||
if fwd.Scheme != "http" || fwd.Host != "redirect" || fwd.Path != "" {
|
|
||||||
t.Error("Valid request should be redirected to return url, got:", fwd)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
78
provider/google.go
Normal file
78
provider/google.go
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
package provider
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Google struct {
|
||||||
|
ClientId string
|
||||||
|
ClientSecret string `json:"-"`
|
||||||
|
Scope string
|
||||||
|
Prompt string
|
||||||
|
|
||||||
|
LoginURL *url.URL
|
||||||
|
TokenURL *url.URL
|
||||||
|
UserURL *url.URL
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *Google) GetLoginURL(redirectUri, state string) string {
|
||||||
|
q := url.Values{}
|
||||||
|
q.Set("client_id", g.ClientId)
|
||||||
|
q.Set("response_type", "code")
|
||||||
|
q.Set("scope", g.Scope)
|
||||||
|
if g.Prompt != "" {
|
||||||
|
q.Set("prompt", g.Prompt)
|
||||||
|
}
|
||||||
|
q.Set("redirect_uri", redirectUri)
|
||||||
|
q.Set("state", state)
|
||||||
|
|
||||||
|
var u url.URL
|
||||||
|
u = *g.LoginURL
|
||||||
|
u.RawQuery = q.Encode()
|
||||||
|
|
||||||
|
return u.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *Google) ExchangeCode(redirectUri, code string) (string, error) {
|
||||||
|
form := url.Values{}
|
||||||
|
form.Set("client_id", g.ClientId)
|
||||||
|
form.Set("client_secret", g.ClientSecret)
|
||||||
|
form.Set("grant_type", "authorization_code")
|
||||||
|
form.Set("redirect_uri", redirectUri)
|
||||||
|
form.Set("code", code)
|
||||||
|
|
||||||
|
res, err := http.PostForm(g.TokenURL.String(), form)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
var token Token
|
||||||
|
defer res.Body.Close()
|
||||||
|
err = json.NewDecoder(res.Body).Decode(&token)
|
||||||
|
|
||||||
|
return token.Token, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *Google) GetUser(token string) (User, error) {
|
||||||
|
var user User
|
||||||
|
|
||||||
|
client := &http.Client{}
|
||||||
|
req, err := http.NewRequest("GET", g.UserURL.String(), nil)
|
||||||
|
if err != nil {
|
||||||
|
return user, err
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", token))
|
||||||
|
res, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return user, err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer res.Body.Close()
|
||||||
|
err = json.NewDecoder(res.Body).Decode(&user)
|
||||||
|
|
||||||
|
return user, err
|
||||||
|
}
|
16
provider/providers.go
Normal file
16
provider/providers.go
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
package provider
|
||||||
|
|
||||||
|
type Providers struct {
|
||||||
|
Google Google
|
||||||
|
}
|
||||||
|
|
||||||
|
type Token struct {
|
||||||
|
Token string `json:"access_token"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
Id string `json:"id"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
Verified bool `json:"verified_email"`
|
||||||
|
Hd string `json:"hd"`
|
||||||
|
}
|
200
server.go
Normal file
200
server.go
Normal file
@ -0,0 +1,200 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
// "fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Server struct {
|
||||||
|
mux *mux.Router
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewServer() *Server {
|
||||||
|
s := &Server{}
|
||||||
|
s.buildRoutes()
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) buildRoutes() {
|
||||||
|
s.mux = mux.NewRouter()
|
||||||
|
|
||||||
|
// Let's build a server
|
||||||
|
for _, rules := range config.Rules {
|
||||||
|
// fmt.Printf("Rule: %s\n", name)
|
||||||
|
for _, match := range rules.Match {
|
||||||
|
s.attachHandler(&match, rules.Action)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add callback handler
|
||||||
|
s.mux.Handle(config.Path, s.AuthCallbackHandler())
|
||||||
|
|
||||||
|
// Add a default handler
|
||||||
|
s.mux.NewRoute().Handler(s.AuthHandler())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) RootHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Modify request
|
||||||
|
r.URL, _ = url.Parse(r.Header.Get("X-Forwarded-Uri"))
|
||||||
|
|
||||||
|
// Pass to mux
|
||||||
|
s.mux.ServeHTTP(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handler that allows requests
|
||||||
|
func (s *Server) AllowHandler() http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
s.logger(r, "Allowing request")
|
||||||
|
w.WriteHeader(200)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Authenticate requests
|
||||||
|
func (s *Server) AuthHandler() http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Logging setup
|
||||||
|
logger := s.logger(r, "Authenticating request")
|
||||||
|
|
||||||
|
// Get auth cookie
|
||||||
|
c, err := r.Cookie(config.CookieName)
|
||||||
|
if err != nil {
|
||||||
|
// Error indicates no cookie, generate nonce
|
||||||
|
err, nonce := fw.Nonce()
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("Error generating nonce, %v", err)
|
||||||
|
http.Error(w, "Service unavailable", 503)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the CSRF cookie
|
||||||
|
http.SetCookie(w, fw.MakeCSRFCookie(r, nonce))
|
||||||
|
logger.Debug("Set CSRF cookie and redirecting to google login")
|
||||||
|
|
||||||
|
// Forward them on
|
||||||
|
http.Redirect(w, r, fw.GetLoginURL(r, nonce), http.StatusTemporaryRedirect)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate cookie
|
||||||
|
valid, email, err := fw.ValidateCookie(r, c)
|
||||||
|
if !valid {
|
||||||
|
logger.Errorf("Invalid cookie: %v", err)
|
||||||
|
http.Error(w, "Not authorized", 401)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate user
|
||||||
|
valid = fw.ValidateEmail(email)
|
||||||
|
if !valid {
|
||||||
|
logger.WithFields(logrus.Fields{
|
||||||
|
"email": email,
|
||||||
|
}).Errorf("Invalid email")
|
||||||
|
http.Error(w, "Not authorized", 401)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Valid request
|
||||||
|
logger.Debugf("Allowing valid request ")
|
||||||
|
w.Header().Set("X-Forwarded-User", email)
|
||||||
|
w.WriteHeader(200)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle auth callback
|
||||||
|
func (s *Server) AuthCallbackHandler() http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Logging setup
|
||||||
|
logger := s.logger(r, "Handling callback")
|
||||||
|
|
||||||
|
// Check for CSRF cookie
|
||||||
|
c, err := r.Cookie(config.CSRFCookieName)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("Missing csrf cookie")
|
||||||
|
http.Error(w, "Not authorized", 401)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate state
|
||||||
|
valid, redirect, err := fw.ValidateCSRFCookie(r, c)
|
||||||
|
if !valid {
|
||||||
|
logger.Warnf("Error validating csrf cookie: %v", err)
|
||||||
|
http.Error(w, "Not authorized", 401)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear CSRF cookie
|
||||||
|
http.SetCookie(w, fw.ClearCSRFCookie(r))
|
||||||
|
|
||||||
|
// Exchange code for token
|
||||||
|
token, err := fw.ExchangeCode(r)
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("Code exchange failed with: %v", err)
|
||||||
|
http.Error(w, "Service unavailable", 503)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get user
|
||||||
|
user, err := fw.GetUser(token)
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("Error getting user: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate cookie
|
||||||
|
http.SetCookie(w, fw.MakeCookie(r, user.Email))
|
||||||
|
logger.WithFields(logrus.Fields{
|
||||||
|
"user": user.Email,
|
||||||
|
}).Infof("Generated auth cookie")
|
||||||
|
|
||||||
|
// Redirect
|
||||||
|
http.Redirect(w, r, redirect, http.StatusTemporaryRedirect)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build a handler for a given matcher
|
||||||
|
func (s *Server) attachHandler(m *Match, action string) {
|
||||||
|
// Build a new route matcher
|
||||||
|
route := s.mux.NewRoute()
|
||||||
|
|
||||||
|
for _, host := range m.Host {
|
||||||
|
route.Host(host)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, pathPrefix := range m.PathPrefix {
|
||||||
|
route.PathPrefix(pathPrefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, header := range m.Header {
|
||||||
|
if len(header) != 2 {
|
||||||
|
panic("todo")
|
||||||
|
}
|
||||||
|
|
||||||
|
route.Headers(header[0], header[1])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add handler to new route
|
||||||
|
if action == "allow" {
|
||||||
|
route.Handler(s.AllowHandler())
|
||||||
|
} else {
|
||||||
|
route.Handler(s.AuthHandler())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) logger(r *http.Request, msg string) *logrus.Entry {
|
||||||
|
// Create logger
|
||||||
|
logger := log.WithFields(logrus.Fields{
|
||||||
|
"RemoteAddr": r.RemoteAddr,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Log request
|
||||||
|
logger.WithFields(logrus.Fields{
|
||||||
|
"Headers": r.Header,
|
||||||
|
}).Debugf(msg)
|
||||||
|
|
||||||
|
return logger
|
||||||
|
}
|
277
server_test.go
Normal file
277
server_test.go
Normal file
@ -0,0 +1,277 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
// "reflect"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/thomseddon/traefik-forward-auth/provider"
|
||||||
|
)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Setup
|
||||||
|
*/
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
fw = &ForwardAuth{}
|
||||||
|
config = NewConfig()
|
||||||
|
|
||||||
|
logLevel := "panic"
|
||||||
|
config.LogLevel = &logLevel
|
||||||
|
log = NewLogger()
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Utilities
|
||||||
|
*/
|
||||||
|
|
||||||
|
type TokenServerHandler struct{}
|
||||||
|
|
||||||
|
func (t *TokenServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
fmt.Fprint(w, `{"access_token":"123456789"}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
type UserServerHandler struct{}
|
||||||
|
|
||||||
|
func (t *UserServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
fmt.Fprint(w, `{
|
||||||
|
"id":"1",
|
||||||
|
"email":"example@example.com",
|
||||||
|
"verified_email":true,
|
||||||
|
"hd":"example.com"
|
||||||
|
}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func httpRequest(s *Server, r *http.Request, c *http.Cookie) (*http.Response, string) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Set cookies on recorder
|
||||||
|
if c != nil {
|
||||||
|
http.SetCookie(w, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy into request
|
||||||
|
for _, c := range w.HeaderMap["Set-Cookie"] {
|
||||||
|
r.Header.Add("Cookie", c)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.RootHandler(w, r)
|
||||||
|
|
||||||
|
res := w.Result()
|
||||||
|
body, _ := ioutil.ReadAll(res.Body)
|
||||||
|
|
||||||
|
// if res.StatusCode > 300 && res.StatusCode < 400 {
|
||||||
|
// fmt.Printf("%#v", res.Header)
|
||||||
|
// }
|
||||||
|
|
||||||
|
return res, string(body)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHttpRequest(uri string) *http.Request {
|
||||||
|
r := httptest.NewRequest("", "http://example.com/", nil)
|
||||||
|
r.Header.Add("X-Forwarded-Uri", uri)
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func qsDiff(one, two url.Values) {
|
||||||
|
for k := range one {
|
||||||
|
if two.Get(k) == "" {
|
||||||
|
fmt.Printf("Key missing: %s\n", k)
|
||||||
|
}
|
||||||
|
if one.Get(k) != two.Get(k) {
|
||||||
|
fmt.Printf("Value different for %s: expected: '%s' got: '%s'\n", k, one.Get(k), two.Get(k))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for k := range two {
|
||||||
|
if one.Get(k) == "" {
|
||||||
|
fmt.Printf("Extra key: %s\n", k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Tests
|
||||||
|
*/
|
||||||
|
|
||||||
|
func TestServerHandler(t *testing.T) {
|
||||||
|
server := NewServer()
|
||||||
|
|
||||||
|
config = &Config{
|
||||||
|
Path: "/_oauth",
|
||||||
|
CookieName: "cookie_test",
|
||||||
|
Lifetime: time.Second * time.Duration(10),
|
||||||
|
Providers: provider.Providers{
|
||||||
|
Google: provider.Google{
|
||||||
|
ClientId: "idtest",
|
||||||
|
ClientSecret: "sectest",
|
||||||
|
Scope: "scopetest",
|
||||||
|
LoginURL: &url.URL{
|
||||||
|
Scheme: "http",
|
||||||
|
Host: "test.com",
|
||||||
|
Path: "/auth",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should redirect vanilla request to login url
|
||||||
|
req := newHttpRequest("/foo")
|
||||||
|
res, _ := httpRequest(server, req, nil)
|
||||||
|
if res.StatusCode != 307 {
|
||||||
|
t.Error("Vanilla request should be redirected with 307, got:", res.StatusCode)
|
||||||
|
}
|
||||||
|
fwd, _ := res.Location()
|
||||||
|
if fwd.Scheme != "http" || fwd.Host != "test.com" || fwd.Path != "/auth" {
|
||||||
|
t.Error("Vanilla request should be redirected to login url, got:", fwd)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should catch invalid cookie
|
||||||
|
req = newHttpRequest("/foo")
|
||||||
|
|
||||||
|
c := fw.MakeCookie(req, "test@example.com")
|
||||||
|
parts := strings.Split(c.Value, "|")
|
||||||
|
c.Value = fmt.Sprintf("bad|%s|%s", parts[1], parts[2])
|
||||||
|
|
||||||
|
res, _ = httpRequest(server, req, c)
|
||||||
|
if res.StatusCode != 401 {
|
||||||
|
t.Error("Request with invalid cookie shound't be authorised", res.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should validate email
|
||||||
|
req = newHttpRequest("/foo")
|
||||||
|
|
||||||
|
c = fw.MakeCookie(req, "test@example.com")
|
||||||
|
config.Domain = []string{"test.com"}
|
||||||
|
|
||||||
|
res, _ = httpRequest(server, req, c)
|
||||||
|
if res.StatusCode != 401 {
|
||||||
|
t.Error("Request with invalid cookie shound't be authorised", res.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should allow valid request email
|
||||||
|
req = newHttpRequest("/foo")
|
||||||
|
|
||||||
|
c = fw.MakeCookie(req, "test@example.com")
|
||||||
|
config.Domain = []string{}
|
||||||
|
|
||||||
|
res, _ = httpRequest(server, req, c)
|
||||||
|
if res.StatusCode != 200 {
|
||||||
|
t.Error("Valid request should be allowed, got:", res.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should pass through user
|
||||||
|
users := res.Header["X-Forwarded-User"]
|
||||||
|
if len(users) != 1 {
|
||||||
|
t.Error("Valid request missing X-Forwarded-User header")
|
||||||
|
} else if users[0] != "test@example.com" {
|
||||||
|
t.Error("X-Forwarded-User should match user, got: ", users)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerAuthCallback(t *testing.T) {
|
||||||
|
server := NewServer()
|
||||||
|
config = &Config{
|
||||||
|
Path: "/_oauth",
|
||||||
|
CookieName: "cookie_test",
|
||||||
|
Lifetime: time.Second * time.Duration(10),
|
||||||
|
CSRFCookieName: "csrf_test",
|
||||||
|
Providers: provider.Providers{
|
||||||
|
Google: provider.Google{
|
||||||
|
ClientId: "idtest",
|
||||||
|
ClientSecret: "sectest",
|
||||||
|
Scope: "scopetest",
|
||||||
|
LoginURL: &url.URL{
|
||||||
|
Scheme: "http",
|
||||||
|
Host: "test.com",
|
||||||
|
Path: "/auth",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup token server
|
||||||
|
tokenServerHandler := &TokenServerHandler{}
|
||||||
|
tokenServer := httptest.NewServer(tokenServerHandler)
|
||||||
|
defer tokenServer.Close()
|
||||||
|
tokenUrl, _ := url.Parse(tokenServer.URL)
|
||||||
|
config.Providers.Google.TokenURL = tokenUrl
|
||||||
|
|
||||||
|
// Setup user server
|
||||||
|
userServerHandler := &UserServerHandler{}
|
||||||
|
userServer := httptest.NewServer(userServerHandler)
|
||||||
|
defer userServer.Close()
|
||||||
|
userUrl, _ := url.Parse(userServer.URL)
|
||||||
|
config.Providers.Google.UserURL = userUrl
|
||||||
|
|
||||||
|
// Should pass auth response request to callback
|
||||||
|
req := newHttpRequest("/_oauth")
|
||||||
|
res, _ := httpRequest(server, req, nil)
|
||||||
|
if res.StatusCode != 401 {
|
||||||
|
t.Error("Auth callback without cookie shound't be authorised, got:", res.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should catch invalid csrf cookie
|
||||||
|
req = newHttpRequest("/_oauth?state=12345678901234567890123456789012:http://redirect")
|
||||||
|
c := fw.MakeCSRFCookie(req, "nononononononononononononononono")
|
||||||
|
res, _ = httpRequest(server, req, c)
|
||||||
|
if res.StatusCode != 401 {
|
||||||
|
t.Error("Auth callback with invalid cookie shound't be authorised, got:", res.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should redirect valid request
|
||||||
|
req = newHttpRequest("/_oauth?state=12345678901234567890123456789012:http://redirect")
|
||||||
|
c = fw.MakeCSRFCookie(req, "12345678901234567890123456789012")
|
||||||
|
res, _ = httpRequest(server, req, c)
|
||||||
|
if res.StatusCode != 307 {
|
||||||
|
t.Error("Valid callback should be allowed, got:", res.StatusCode)
|
||||||
|
}
|
||||||
|
fwd, _ := res.Location()
|
||||||
|
if fwd.Scheme != "http" || fwd.Host != "redirect" || fwd.Path != "" {
|
||||||
|
t.Error("Valid request should be redirected to return url, got:", fwd)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerMatcherPathPrefix(t *testing.T) {
|
||||||
|
server := NewServer()
|
||||||
|
config = &Config{
|
||||||
|
Path: "/_oauth",
|
||||||
|
CookieName: "cookie_test",
|
||||||
|
Lifetime: time.Second * time.Duration(10),
|
||||||
|
Providers: provider.Providers{
|
||||||
|
Google: provider.Google{
|
||||||
|
ClientId: "idtest",
|
||||||
|
ClientSecret: "sectest",
|
||||||
|
Scope: "scopetest",
|
||||||
|
LoginURL: &url.URL{
|
||||||
|
Scheme: "http",
|
||||||
|
Host: "test.com",
|
||||||
|
Path: "/auth",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Rules: map[string]Rules{
|
||||||
|
"rule1": {
|
||||||
|
Action: "allow",
|
||||||
|
Match: []Match{
|
||||||
|
{
|
||||||
|
PathPrefix: []string{"/api"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should allow /api request
|
||||||
|
req := newHttpRequest("/api")
|
||||||
|
c := fw.MakeCookie(req, "test@example.com")
|
||||||
|
res, _ := httpRequest(server, req, c)
|
||||||
|
if res.StatusCode != 200 {
|
||||||
|
t.Error("Request matching allowed rule should be allowed, got:", res.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user