Add more v2 tests + fixes + improve legacy config parsing
This commit is contained in:
parent
43775591fa
commit
d890a4aad6
11
Dockerfile
11
Dockerfile
@ -1,20 +1,15 @@
|
|||||||
FROM golang:1.10-alpine as builder
|
FROM golang:1.12-alpine as builder
|
||||||
|
|
||||||
# Setup
|
# Setup
|
||||||
RUN mkdir -p /go/src/github.com/thomseddon/traefik-forward-auth
|
RUN mkdir -p /go/src/github.com/thomseddon/traefik-forward-auth
|
||||||
WORKDIR /go/src/github.com/thomseddon/traefik-forward-auth
|
WORKDIR /go/src/github.com/thomseddon/traefik-forward-auth
|
||||||
|
|
||||||
# Add libraries
|
# Add libraries
|
||||||
RUN apk add --no-cache git && \
|
RUN apk add --no-cache git
|
||||||
go get "github.com/BurntSushi/toml" && \
|
|
||||||
go get "github.com/gorilla/mux" && \
|
|
||||||
go get "github.com/namsral/flag" && \
|
|
||||||
go get "github.com/sirupsen/logrus" && \
|
|
||||||
apk del git
|
|
||||||
|
|
||||||
# Copy & build
|
# Copy & build
|
||||||
ADD . /go/src/github.com/thomseddon/traefik-forward-auth/
|
ADD . /go/src/github.com/thomseddon/traefik-forward-auth/
|
||||||
RUN CGO_ENABLED=0 GOOS=linux go build -a -installsuffix nocgo -o /traefik-forward-auth .
|
RUN CGO_ENABLED=0 GOOS=linux GO111MODULE=on go build -a -installsuffix nocgo -o /traefik-forward-auth github.com/thomseddon/traefik-forward-auth/cmd
|
||||||
|
|
||||||
# Copy into scratch container
|
# Copy into scratch container
|
||||||
FROM scratch
|
FROM scratch
|
||||||
|
@ -1,13 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
// import (
|
|
||||||
// "testing"
|
|
||||||
// )
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Tests
|
|
||||||
*/
|
|
||||||
|
|
||||||
// func TestMain(t *testing.T) {
|
|
||||||
|
|
||||||
// }
|
|
@ -164,7 +164,7 @@ func MakeCookie(r *http.Request, email string) *http.Cookie {
|
|||||||
Path: "/",
|
Path: "/",
|
||||||
Domain: cookieDomain(r),
|
Domain: cookieDomain(r),
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
Secure: !config.CookieInsecure,
|
Secure: !config.InsecureCookie,
|
||||||
Expires: expires,
|
Expires: expires,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -177,7 +177,7 @@ func MakeCSRFCookie(r *http.Request, nonce string) *http.Cookie {
|
|||||||
Path: "/",
|
Path: "/",
|
||||||
Domain: csrfCookieDomain(r),
|
Domain: csrfCookieDomain(r),
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
Secure: !config.CookieInsecure,
|
Secure: !config.InsecureCookie,
|
||||||
Expires: cookieExpiry(),
|
Expires: cookieExpiry(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -190,7 +190,7 @@ func ClearCSRFCookie(r *http.Request) *http.Cookie {
|
|||||||
Path: "/",
|
Path: "/",
|
||||||
Domain: csrfCookieDomain(r),
|
Domain: csrfCookieDomain(r),
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
Secure: !config.CookieInsecure,
|
Secure: !config.InsecureCookie,
|
||||||
Expires: time.Now().Local().Add(time.Hour * -1),
|
Expires: time.Now().Local().Add(time.Hour * -1),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -325,5 +325,9 @@ func (c *CookieDomains) UnmarshalFlag(value string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *CookieDomains) MarshalFlag() (string, error) {
|
func (c *CookieDomains) MarshalFlag() (string, error) {
|
||||||
return fmt.Sprintf("%+v", *c), nil
|
var domains []string
|
||||||
|
for _, d := range *c {
|
||||||
|
domains = append(domains, d.Domain)
|
||||||
|
}
|
||||||
|
return strings.Join(domains, ","), nil
|
||||||
}
|
}
|
||||||
|
@ -5,26 +5,19 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/thomseddon/traefik-forward-auth/internal/provider"
|
"github.com/thomseddon/traefik-forward-auth/internal/provider"
|
||||||
)
|
)
|
||||||
|
|
||||||
/**
|
|
||||||
* Setup
|
|
||||||
*/
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
// fw = &ForwardAuth{}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Tests
|
* Tests
|
||||||
*/
|
*/
|
||||||
|
|
||||||
func TestValidateCookie(t *testing.T) {
|
func TestAuthValidateCookie(t *testing.T) {
|
||||||
config = Config{}
|
config, _ = NewConfig([]string{})
|
||||||
r, _ := http.NewRequest("GET", "http://example.com", nil)
|
r, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||||
c := &http.Cookie{}
|
c := &http.Cookie{}
|
||||||
|
|
||||||
@ -75,8 +68,8 @@ func TestValidateCookie(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateEmail(t *testing.T) {
|
func TestAuthValidateEmail(t *testing.T) {
|
||||||
config = Config{}
|
config, _ = NewConfig([]string{})
|
||||||
|
|
||||||
// Should allow any
|
// Should allow any
|
||||||
if !ValidateEmail("test@test.com") || !ValidateEmail("one@two.com") {
|
if !ValidateEmail("test@test.com") || !ValidateEmail("one@two.com") {
|
||||||
@ -110,28 +103,28 @@ func TestValidateEmail(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetLoginURL(t *testing.T) {
|
// TODO: Split google tests out
|
||||||
|
func TestAuthGetLoginURL(t *testing.T) {
|
||||||
|
google := provider.Google{
|
||||||
|
ClientId: "idtest",
|
||||||
|
ClientSecret: "sectest",
|
||||||
|
Scope: "scopetest",
|
||||||
|
Prompt: "consent select_account",
|
||||||
|
LoginURL: &url.URL{
|
||||||
|
Scheme: "https",
|
||||||
|
Host: "test.com",
|
||||||
|
Path: "/auth",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
config, _ = NewConfig([]string{})
|
||||||
|
config.Providers.Google = google
|
||||||
|
|
||||||
r, _ := http.NewRequest("GET", "http://example.com", nil)
|
r, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||||
r.Header.Add("X-Forwarded-Proto", "http")
|
r.Header.Add("X-Forwarded-Proto", "http")
|
||||||
r.Header.Add("X-Forwarded-Host", "example.com")
|
r.Header.Add("X-Forwarded-Host", "example.com")
|
||||||
r.Header.Add("X-Forwarded-Uri", "/hello")
|
r.Header.Add("X-Forwarded-Uri", "/hello")
|
||||||
|
|
||||||
config = Config{
|
|
||||||
Path: "/_oauth",
|
|
||||||
Providers: provider.Providers{
|
|
||||||
Google: provider.Google{
|
|
||||||
ClientId: "idtest",
|
|
||||||
ClientSecret: "sectest",
|
|
||||||
Scope: "scopetest",
|
|
||||||
LoginURL: &url.URL{
|
|
||||||
Scheme: "https",
|
|
||||||
Host: "test.com",
|
|
||||||
Path: "/auth",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check url
|
// Check url
|
||||||
uri, err := url.Parse(GetLoginURL(r, "nonce"))
|
uri, err := url.Parse(GetLoginURL(r, "nonce"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -154,6 +147,7 @@ func TestGetLoginURL(t *testing.T) {
|
|||||||
"redirect_uri": []string{"http://example.com/_oauth"},
|
"redirect_uri": []string{"http://example.com/_oauth"},
|
||||||
"response_type": []string{"code"},
|
"response_type": []string{"code"},
|
||||||
"scope": []string{"scopetest"},
|
"scope": []string{"scopetest"},
|
||||||
|
"prompt": []string{"consent select_account"},
|
||||||
"state": []string{"nonce:http://example.com/hello"},
|
"state": []string{"nonce:http://example.com/hello"},
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(qs, expectedQs) {
|
if !reflect.DeepEqual(qs, expectedQs) {
|
||||||
@ -166,23 +160,9 @@ func TestGetLoginURL(t *testing.T) {
|
|||||||
// With Auth URL but no matching cookie domain
|
// With Auth URL but no matching cookie domain
|
||||||
// - will not use auth host
|
// - will not use auth host
|
||||||
//
|
//
|
||||||
config = Config{
|
config, _ = NewConfig([]string{})
|
||||||
Path: "/_oauth",
|
config.AuthHost = "auth.example.com"
|
||||||
AuthHost: "auth.example.com",
|
config.Providers.Google = google
|
||||||
Providers: provider.Providers{
|
|
||||||
Google: provider.Google{
|
|
||||||
ClientId: "idtest",
|
|
||||||
ClientSecret: "sectest",
|
|
||||||
Scope: "scopetest",
|
|
||||||
Prompt: "consent select_account",
|
|
||||||
LoginURL: &url.URL{
|
|
||||||
Scheme: "https",
|
|
||||||
Host: "test.com",
|
|
||||||
Path: "/auth",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check url
|
// Check url
|
||||||
uri, err = url.Parse(GetLoginURL(r, "nonce"))
|
uri, err = url.Parse(GetLoginURL(r, "nonce"))
|
||||||
@ -218,25 +198,10 @@ func TestGetLoginURL(t *testing.T) {
|
|||||||
//
|
//
|
||||||
// With correct Auth URL + cookie domain
|
// With correct Auth URL + cookie domain
|
||||||
//
|
//
|
||||||
cookieDomain := NewCookieDomain("example.com")
|
config, _ = NewConfig([]string{})
|
||||||
config = Config{
|
config.AuthHost = "auth.example.com"
|
||||||
Path: "/_oauth",
|
config.CookieDomains = []CookieDomain{*NewCookieDomain("example.com")}
|
||||||
AuthHost: "auth.example.com",
|
config.Providers.Google = google
|
||||||
CookieDomains: []CookieDomain{*cookieDomain},
|
|
||||||
Providers: provider.Providers{
|
|
||||||
Google: provider.Google{
|
|
||||||
ClientId: "idtest",
|
|
||||||
ClientSecret: "sectest",
|
|
||||||
Scope: "scopetest",
|
|
||||||
Prompt: "consent select_account",
|
|
||||||
LoginURL: &url.URL{
|
|
||||||
Scheme: "https",
|
|
||||||
Host: "test.com",
|
|
||||||
Path: "/auth",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check url
|
// Check url
|
||||||
uri, err = url.Parse(GetLoginURL(r, "nonce"))
|
uri, err = url.Parse(GetLoginURL(r, "nonce"))
|
||||||
@ -317,19 +282,59 @@ func TestGetLoginURL(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TODO
|
// TODO
|
||||||
// func TestExchangeCode(t *testing.T) {
|
// func TestAuthExchangeCode(t *testing.T) {
|
||||||
// }
|
// }
|
||||||
|
|
||||||
// TODO
|
// TODO
|
||||||
// func TestGetUser(t *testing.T) {
|
// func TestAuthGetUser(t *testing.T) {
|
||||||
// }
|
// }
|
||||||
|
|
||||||
// TODO? Tested in TestValidateCookie
|
func TestAuthMakeCookie(t *testing.T) {
|
||||||
// func TestMakeCookie(t *testing.T) {
|
config, _ = NewConfig([]string{})
|
||||||
// }
|
r, _ := http.NewRequest("GET", "http://app.example.com", nil)
|
||||||
|
r.Header.Add("X-Forwarded-Host", "app.example.com")
|
||||||
|
|
||||||
func TestMakeCSRFCookie(t *testing.T) {
|
c := MakeCookie(r, "test@example.com")
|
||||||
config = Config{}
|
if c.Name != "_forward_auth" {
|
||||||
|
t.Error("Cookie name should be \"_forward_auth\", got:", c.Name)
|
||||||
|
}
|
||||||
|
parts := strings.Split(c.Value, "|")
|
||||||
|
if len(parts) != 3 {
|
||||||
|
t.Error("Cookie should be in 3 parts, got:", c.Value)
|
||||||
|
}
|
||||||
|
valid, _, _ := ValidateCookie(r, c)
|
||||||
|
if !valid {
|
||||||
|
t.Error("Should generate valid cookie:", c.Value)
|
||||||
|
}
|
||||||
|
if c.Path != "/" {
|
||||||
|
t.Error("Cookie path should be \"/\", got:", c.Path)
|
||||||
|
}
|
||||||
|
if c.Domain != "app.example.com" {
|
||||||
|
t.Error("Cookie domain should be \"app.example.com\", got:", c.Domain)
|
||||||
|
}
|
||||||
|
if c.Secure != true {
|
||||||
|
t.Error("Cookie domain should be true, got:", c.Secure)
|
||||||
|
}
|
||||||
|
if !c.Expires.After(time.Now().Local()) {
|
||||||
|
t.Error("Expires should be after now, got:", c.Expires)
|
||||||
|
}
|
||||||
|
if !c.Expires.Before(time.Now().Local().Add(config.Lifetime).Add(10 * time.Second)) {
|
||||||
|
t.Error("Expires should be before lifetime + 10 seconds, got:", c.Expires)
|
||||||
|
}
|
||||||
|
|
||||||
|
config.CookieName = "testname"
|
||||||
|
config.InsecureCookie = true
|
||||||
|
c = MakeCookie(r, "test@example.com")
|
||||||
|
if c.Name != "testname" {
|
||||||
|
t.Error("Cookie name should be \"testname\", got:", c.Name)
|
||||||
|
}
|
||||||
|
if c.Secure != false {
|
||||||
|
t.Error("Cookie domain should be false, got:", c.Secure)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthMakeCSRFCookie(t *testing.T) {
|
||||||
|
config, _ = NewConfig([]string{})
|
||||||
r, _ := http.NewRequest("GET", "http://app.example.com", nil)
|
r, _ := http.NewRequest("GET", "http://app.example.com", nil)
|
||||||
r.Header.Add("X-Forwarded-Host", "app.example.com")
|
r.Header.Add("X-Forwarded-Host", "app.example.com")
|
||||||
|
|
||||||
@ -340,9 +345,8 @@ func TestMakeCSRFCookie(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// With cookie domain but no auth url
|
// With cookie domain but no auth url
|
||||||
cookieDomain := NewCookieDomain("example.com")
|
|
||||||
config = Config{
|
config = Config{
|
||||||
CookieDomains: []CookieDomain{*cookieDomain},
|
CookieDomains: []CookieDomain{*NewCookieDomain("example.com")},
|
||||||
}
|
}
|
||||||
c = MakeCSRFCookie(r, "12345678901234567890123456789012")
|
c = MakeCSRFCookie(r, "12345678901234567890123456789012")
|
||||||
if c.Domain != "app.example.com" {
|
if c.Domain != "app.example.com" {
|
||||||
@ -352,7 +356,7 @@ func TestMakeCSRFCookie(t *testing.T) {
|
|||||||
// With cookie domain and auth url
|
// With cookie domain and auth url
|
||||||
config = Config{
|
config = Config{
|
||||||
AuthHost: "auth.example.com",
|
AuthHost: "auth.example.com",
|
||||||
CookieDomains: []CookieDomain{*cookieDomain},
|
CookieDomains: []CookieDomain{*NewCookieDomain("example.com")},
|
||||||
}
|
}
|
||||||
c = MakeCSRFCookie(r, "12345678901234567890123456789012")
|
c = MakeCSRFCookie(r, "12345678901234567890123456789012")
|
||||||
if c.Domain != "example.com" {
|
if c.Domain != "example.com" {
|
||||||
@ -360,8 +364,8 @@ func TestMakeCSRFCookie(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClearCSRFCookie(t *testing.T) {
|
func TestAuthClearCSRFCookie(t *testing.T) {
|
||||||
config = Config{}
|
config, _ = NewConfig([]string{})
|
||||||
r, _ := http.NewRequest("GET", "http://example.com", nil)
|
r, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||||
|
|
||||||
c := ClearCSRFCookie(r)
|
c := ClearCSRFCookie(r)
|
||||||
@ -370,8 +374,8 @@ func TestClearCSRFCookie(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateCSRFCookie(t *testing.T) {
|
func TestAuthValidateCSRFCookie(t *testing.T) {
|
||||||
config = Config{}
|
config, _ = NewConfig([]string{})
|
||||||
c := &http.Cookie{}
|
c := &http.Cookie{}
|
||||||
|
|
||||||
newCsrfRequest := func(state string) *http.Request {
|
newCsrfRequest := func(state string) *http.Request {
|
||||||
@ -416,7 +420,7 @@ func TestValidateCSRFCookie(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNonce(t *testing.T) {
|
func TestAuthNonce(t *testing.T) {
|
||||||
err, nonce1 := Nonce()
|
err, nonce1 := Nonce()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error("Error generation nonce:", err)
|
t.Error("Error generation nonce:", err)
|
||||||
@ -435,7 +439,7 @@ func TestNonce(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCookieDomainMatch(t *testing.T) {
|
func TestAuthCookieDomainMatch(t *testing.T) {
|
||||||
cd := NewCookieDomain("example.com")
|
cd := NewCookieDomain("example.com")
|
||||||
|
|
||||||
// Exact should match
|
// Exact should match
|
||||||
@ -458,3 +462,29 @@ func TestCookieDomainMatch(t *testing.T) {
|
|||||||
t.Error("Other domain should not match")
|
t.Error("Other domain should not match")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAuthCookieDomains(t *testing.T) {
|
||||||
|
cds := CookieDomains{}
|
||||||
|
|
||||||
|
err := cds.UnmarshalFlag("one.com,two.org")
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
if len(cds) != 2 {
|
||||||
|
t.Error("Expected UnmarshalFlag to provide 2 CookieDomains, got", cds)
|
||||||
|
}
|
||||||
|
if cds[0].Domain != "one.com" || cds[0].SubDomain != ".one.com" {
|
||||||
|
t.Error("Expected UnmarshalFlag to provide one.com, got", cds[0])
|
||||||
|
}
|
||||||
|
if cds[1].Domain != "two.org" || cds[1].SubDomain != ".two.org" {
|
||||||
|
t.Error("Expected UnmarshalFlag to provide two.org, got", cds[1])
|
||||||
|
}
|
||||||
|
|
||||||
|
marshal, err := cds.MarshalFlag()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
if marshal != "one.com,two.org" {
|
||||||
|
t.Error("Expected MarshalFlag to provide \"one.com,two.org\", got", cds)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -1,10 +1,14 @@
|
|||||||
package tfa
|
package tfa
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
"os"
|
"os"
|
||||||
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@ -16,48 +20,41 @@ import (
|
|||||||
var config Config
|
var config Config
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
LogLevel string `long:"log-level" default:"warn" description:"Log level: trace, debug, info, warn, error, fatal, panic"`
|
LogLevel string `long:"log-level" env:"LOG_LEVEL" default:"warn" choice:"trace" choice:"debug" choice:"info" choice:"warn" choice:"error" choice:"fatal" choice:"panic" description:"Log level"`
|
||||||
LogFormat string `long:"log-format" default:"text" description:"Log format: text, json, pretty"`
|
LogFormat string `long:"log-format" env:"LOG_FORMAT" default:"text" choice:"text" choice:"json" choice:"pretty" description:"Log format"`
|
||||||
|
|
||||||
AuthHost string `long:"auth-host" description:"Host for central auth login"`
|
AuthHost string `long:"auth-host" env:"AUTH_HOST" description:"Host for central auth login"`
|
||||||
ConfigFile string `long:"config-file" description:"Config File"`
|
Config func(s string) error `long:"config" env:"CONFIG" description:"Config file"`
|
||||||
CookieDomains CookieDomains `long:"cookie-domains" description:"Comma separated list of cookie domains"`
|
CookieDomains CookieDomains `long:"cookie-domains" env:"COOKIE_DOMAINS" description:"Comma separated list of cookie domains"`
|
||||||
CookieInsecure bool `long:"cookie-insecure" description:"Use secure cookies"`
|
InsecureCookie bool `long:"insecure-cookie" env:"INSECURE_COOKIE" description:"Use insecure cookies"`
|
||||||
CookieName string `long:"cookie-name" default:"_forward_auth" description:"Cookie Name"`
|
CookieName string `long:"cookie-name" env:"COOKIE_NAME" default:"_forward_auth" description:"Cookie Name"`
|
||||||
CSRFCookieName string `long:"csrf-cookie-name" default:"_forward_auth_csrf" description:"CSRF Cookie Name"`
|
CSRFCookieName string `long:"csrf-cookie-name" env:"CSRF_COOKIE_NAME" default:"_forward_auth_csrf" description:"CSRF Cookie Name"`
|
||||||
DefaultAction string `long:"default-action" default:"allow" description:"Default Action"`
|
DefaultAction string `long:"default-action" env:"DEFAULT_ACTION" default:"auth" choice:"auth" choice:"allow" description:"Default Action"`
|
||||||
Domains CommaSeparatedList `long:"domains" description:"Comma separated list of email domains to allow"`
|
Domains CommaSeparatedList `long:"domains" env:"DOMAINS" description:"Comma separated list of email domains to allow"`
|
||||||
LifetimeString int `long:"lifetime" default:"43200" description:"Lifetime in seconds"`
|
LifetimeString int `long:"lifetime" env:"LIFETIME" default:"43200" description:"Lifetime in seconds"`
|
||||||
Path string `long:"path" default:"_oauth" description:"Callback URL Path"`
|
Path string `long:"url-path" env:"URL_PATH" default:"_oauth" description:"Callback URL Path"`
|
||||||
SecretString string `long:"secret" description:"*Secret used for signing (required)"`
|
SecretString string `long:"secret" env:"SECRET" description:"*Secret used for signing (required)"`
|
||||||
Whitelist CommaSeparatedList `long:"whitelist" description:"Comma separated list of email addresses to allow"`
|
Whitelist CommaSeparatedList `long:"whitelist" env:"WHITELIST" description:"Comma separated list of email addresses to allow"`
|
||||||
|
|
||||||
Providers provider.Providers
|
Providers provider.Providers `group:"providers" namespace:"providers" env-namespace:"PROVIDERS"`
|
||||||
Rules map[string]*Rule `long:"rule"`
|
Rules map[string]*Rule `long:"rules.<name>.<param>" description:"Rule definitions, see docs, param can be: \"action\", \"rule\""`
|
||||||
|
|
||||||
|
// Filled during transformations
|
||||||
Secret []byte
|
Secret []byte
|
||||||
Lifetime time.Duration
|
Lifetime time.Duration
|
||||||
|
|
||||||
Prompt string `long:"prompt" description:"DEPRECATED - Use providers.google.prompt"`
|
// Legacy
|
||||||
// TODO: Need to mimick the default behaviour of bool flags
|
ClientIdLegacy string `long:"client-id" env:"CLIENT_ID" group:"DEPs" description:"DEPRECATED - Use \"providers.google.client-id\""`
|
||||||
CookieSecure string `long:"cookie-secure" default:"true" description:"DEPRECATED - Use \"cookie-insecure\""`
|
ClientSecretLegacy string `long:"client-secret" env:"CLIENT_SECRET" description:"DEPRECATED - Use \"providers.google.client-id\""`
|
||||||
|
PromptLegacy string `long:"prompt" env:"PROMPT" description:"DEPRECATED - Use \"providers.google.prompt\""`
|
||||||
flags []string
|
CookieSecureLegacy string `long:"cookie-secure" env:"COOKIE_SECURE" namespace:"DERPS" description:"DEPRECATED - Use \"insecure-cookie\""`
|
||||||
usingToml bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO:
|
|
||||||
// - parse ini
|
|
||||||
// - parse env vars
|
|
||||||
// - parse env var file
|
|
||||||
// - support multiple config files
|
|
||||||
// - maintain backwards compat
|
|
||||||
|
|
||||||
func NewGlobalConfig() Config {
|
func NewGlobalConfig() Config {
|
||||||
var err error
|
var err error
|
||||||
config, err = NewConfig(os.Args[1:])
|
config, err = NewConfig(os.Args[1:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("startup error: %+v", err)
|
fmt.Printf("%+v\n", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -74,7 +71,28 @@ func NewConfig(args []string) (Config, error) {
|
|||||||
return c, err
|
return c, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Struct defaults
|
// TODO: as log flags have now been parsed maybe we should return here so
|
||||||
|
// any further errors can be logged via logrus instead of printed?
|
||||||
|
|
||||||
|
// Backwards compatability
|
||||||
|
if c.ClientIdLegacy != "" {
|
||||||
|
c.Providers.Google.ClientId = c.ClientIdLegacy
|
||||||
|
}
|
||||||
|
if c.ClientSecretLegacy != "" {
|
||||||
|
c.Providers.Google.ClientSecret = c.ClientSecretLegacy
|
||||||
|
}
|
||||||
|
if c.PromptLegacy != "" {
|
||||||
|
c.Providers.Google.Prompt = c.PromptLegacy
|
||||||
|
}
|
||||||
|
if c.CookieSecureLegacy != "" {
|
||||||
|
secure, err := strconv.ParseBool(c.CookieSecureLegacy)
|
||||||
|
if err != nil {
|
||||||
|
return c, err
|
||||||
|
}
|
||||||
|
c.InsecureCookie = !secure
|
||||||
|
}
|
||||||
|
|
||||||
|
// Provider defaults
|
||||||
c.Providers.Google.Build()
|
c.Providers.Google.Build()
|
||||||
|
|
||||||
// Transformations
|
// Transformations
|
||||||
@ -82,25 +100,35 @@ func NewConfig(args []string) (Config, error) {
|
|||||||
c.Secret = []byte(c.SecretString)
|
c.Secret = []byte(c.SecretString)
|
||||||
c.Lifetime = time.Second * time.Duration(c.LifetimeString)
|
c.Lifetime = time.Second * time.Duration(c.LifetimeString)
|
||||||
|
|
||||||
// TODO: Backwards compatability
|
|
||||||
// "secret" used to be "cookie-secret"
|
|
||||||
|
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Config) parseFlags(args []string) error {
|
func (c *Config) parseFlags(args []string) error {
|
||||||
parser := flags.NewParser(c, flags.Default)
|
p := flags.NewParser(c, flags.Default)
|
||||||
parser.UnknownOptionHandler = c.parseUnknownFlag
|
p.UnknownOptionHandler = c.parseUnknownFlag
|
||||||
|
|
||||||
_, err := parser.ParseArgs(args)
|
i := flags.NewIniParser(p)
|
||||||
if err != nil {
|
c.Config = func(s string) error {
|
||||||
flagsErr, ok := err.(*flags.Error)
|
// Try parsing at as an ini
|
||||||
if ok && flagsErr.Type == flags.ErrHelp {
|
err := i.ParseFile(s)
|
||||||
// Library has just printed cli help
|
|
||||||
os.Exit(0)
|
// If it fails with a syntax error, try converting legacy to ini
|
||||||
} else {
|
if err != nil && strings.Contains(err.Error(), "malformed key=value") {
|
||||||
return err
|
converted, convertErr := convertLegacyToIni(s)
|
||||||
|
if convertErr != nil {
|
||||||
|
// If conversion fails, return the original error
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return i.Parse(converted)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := p.ParseArgs(args)
|
||||||
|
if err != nil {
|
||||||
|
return handlFlagError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -139,7 +167,7 @@ func (c *Config) parseUnknownFlag(option string, arg flags.SplitArgument, args [
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Add param value to rule
|
// Add param value to rule
|
||||||
switch(parts[2]) {
|
switch parts[2] {
|
||||||
case "action":
|
case "action":
|
||||||
rule.Action = val
|
rule.Action = val
|
||||||
case "rule":
|
case "rule":
|
||||||
@ -156,6 +184,27 @@ func (c *Config) parseUnknownFlag(option string, arg flags.SplitArgument, args [
|
|||||||
return args, nil
|
return args, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func handlFlagError(err error) error {
|
||||||
|
flagsErr, ok := err.(*flags.Error)
|
||||||
|
if ok && flagsErr.Type == flags.ErrHelp {
|
||||||
|
// Library has just printed cli help
|
||||||
|
os.Exit(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var legacyFileFormat = regexp.MustCompile(`^([a-z-]+) ([\w\W]+)$`)
|
||||||
|
|
||||||
|
func convertLegacyToIni(name string) (io.Reader, error) {
|
||||||
|
b, err := ioutil.ReadFile(name)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return bytes.NewReader(legacyFileFormat.ReplaceAll(b, []byte("$1=$2"))), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Config) Validate() {
|
func (c *Config) Validate() {
|
||||||
// Check for show stopper errors
|
// Check for show stopper errors
|
||||||
if len(c.Secret) == 0 {
|
if len(c.Secret) == 0 {
|
||||||
@ -185,7 +234,7 @@ type Rule struct {
|
|||||||
|
|
||||||
func NewRule() *Rule {
|
func NewRule() *Rule {
|
||||||
return &Rule{
|
return &Rule{
|
||||||
Action: "auth",
|
Action: "auth",
|
||||||
Provider: "google", // TODO: Use default provider
|
Provider: "google", // TODO: Use default provider
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -201,32 +250,6 @@ func (r *Rule) Validate() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Rule) UnmarshalFlag(value string) error {
|
|
||||||
// Format is "action:rule"
|
|
||||||
parts := strings.SplitN(value, ":", 2)
|
|
||||||
|
|
||||||
if len(parts) != 2 {
|
|
||||||
return errors.New("invalid rule format, should be \"action:rule\"")
|
|
||||||
}
|
|
||||||
|
|
||||||
if parts[0] != "auth" && parts[0] != "allow" {
|
|
||||||
return errors.New("invalid rule action, must be \"auth\" or \"allow\"")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse rule
|
|
||||||
*r = Rule{
|
|
||||||
Action: parts[0],
|
|
||||||
Rule: parts[1],
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Rule) MarshalFlag() (string, error) {
|
|
||||||
// TODO: format correctly
|
|
||||||
return fmt.Sprintf("%+v", *r), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type CommaSeparatedList []string
|
type CommaSeparatedList []string
|
||||||
|
|
||||||
func (c *CommaSeparatedList) UnmarshalFlag(value string) error {
|
func (c *CommaSeparatedList) UnmarshalFlag(value string) error {
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
package tfa
|
package tfa
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@ -26,14 +28,11 @@ func TestConfigDefaults(t *testing.T) {
|
|||||||
if c.AuthHost != "" {
|
if c.AuthHost != "" {
|
||||||
t.Error("AuthHost default should be empty, got", c.AuthHost)
|
t.Error("AuthHost default should be empty, got", c.AuthHost)
|
||||||
}
|
}
|
||||||
if c.ConfigFile != "" {
|
|
||||||
t.Error("ConfigFile default should be empty, got", c.ConfigFile)
|
|
||||||
}
|
|
||||||
if len(c.CookieDomains) != 0 {
|
if len(c.CookieDomains) != 0 {
|
||||||
t.Error("CookieDomains default should be empty, got", c.CookieDomains)
|
t.Error("CookieDomains default should be empty, got", c.CookieDomains)
|
||||||
}
|
}
|
||||||
if c.CookieInsecure != false {
|
if c.InsecureCookie != false {
|
||||||
t.Error("CookieInsecure default should be false, got", c.CookieInsecure)
|
t.Error("InsecureCookie default should be false, got", c.InsecureCookie)
|
||||||
}
|
}
|
||||||
if c.CookieName != "_forward_auth" {
|
if c.CookieName != "_forward_auth" {
|
||||||
t.Error("CookieName default should be _forward_auth, got", c.CookieName)
|
t.Error("CookieName default should be _forward_auth, got", c.CookieName)
|
||||||
@ -41,13 +40,13 @@ func TestConfigDefaults(t *testing.T) {
|
|||||||
if c.CSRFCookieName != "_forward_auth_csrf" {
|
if c.CSRFCookieName != "_forward_auth_csrf" {
|
||||||
t.Error("CSRFCookieName default should be _forward_auth_csrf, got", c.CSRFCookieName)
|
t.Error("CSRFCookieName default should be _forward_auth_csrf, got", c.CSRFCookieName)
|
||||||
}
|
}
|
||||||
if c.DefaultAction != "allow" {
|
if c.DefaultAction != "auth" {
|
||||||
t.Error("DefaultAction default should be allow, got", c.DefaultAction)
|
t.Error("DefaultAction default should be auth, got", c.DefaultAction)
|
||||||
}
|
}
|
||||||
if len(c.Domains) != 0 {
|
if len(c.Domains) != 0 {
|
||||||
t.Error("Domain default should be empty, got", c.Domains)
|
t.Error("Domain default should be empty, got", c.Domains)
|
||||||
}
|
}
|
||||||
if c.Lifetime != time.Second * time.Duration(43200) {
|
if c.Lifetime != time.Second*time.Duration(43200) {
|
||||||
t.Error("Lifetime default should be 43200, got", c.Lifetime)
|
t.Error("Lifetime default should be 43200, got", c.Lifetime)
|
||||||
}
|
}
|
||||||
if c.Path != "/_oauth" {
|
if c.Path != "/_oauth" {
|
||||||
@ -60,17 +59,12 @@ func TestConfigDefaults(t *testing.T) {
|
|||||||
if c.Providers.Google.Prompt != "" {
|
if c.Providers.Google.Prompt != "" {
|
||||||
t.Error("Providers.Google.Prompt default should be empty, got", c.Providers.Google.Prompt)
|
t.Error("Providers.Google.Prompt default should be empty, got", c.Providers.Google.Prompt)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deprecated options
|
|
||||||
if c.CookieSecure != "true" {
|
|
||||||
t.Error("CookieSecure default should be true, got", c.CookieSecure)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConfigParseFlags(t *testing.T) {
|
func TestConfigParseArgs(t *testing.T) {
|
||||||
c, err := NewConfig([]string{
|
c, err := NewConfig([]string{
|
||||||
"--path=_oauthpath",
|
"--cookie-name=cookiename",
|
||||||
"--cookie-name", "\"cookiename\"",
|
"--csrf-cookie-name", "\"csrfcookiename\"",
|
||||||
"--rule.1.action=allow",
|
"--rule.1.action=allow",
|
||||||
"--rule.1.rule=PathPrefix(`/one`)",
|
"--rule.1.rule=PathPrefix(`/one`)",
|
||||||
"--rule.two.action=auth",
|
"--rule.two.action=auth",
|
||||||
@ -81,12 +75,12 @@ func TestConfigParseFlags(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check normal flags
|
// Check normal flags
|
||||||
if c.Path != "/_oauthpath" {
|
|
||||||
t.Error("Path default should be /_oauthpath, got", c.Path)
|
|
||||||
}
|
|
||||||
if c.CookieName != "cookiename" {
|
if c.CookieName != "cookiename" {
|
||||||
t.Error("CookieName default should be cookiename, got", c.CookieName)
|
t.Error("CookieName default should be cookiename, got", c.CookieName)
|
||||||
}
|
}
|
||||||
|
if c.CSRFCookieName != "csrfcookiename" {
|
||||||
|
t.Error("CSRFCookieName default should be csrfcookiename, got", c.CSRFCookieName)
|
||||||
|
}
|
||||||
|
|
||||||
// Check rules
|
// Check rules
|
||||||
if len(c.Rules) != 2 {
|
if len(c.Rules) != 2 {
|
||||||
@ -124,23 +118,152 @@ func TestConfigParseFlags(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// func TestConfigParseUnknownFlags(t *testing.T) {
|
func TestConfigParseUnknownFlags(t *testing.T) {
|
||||||
// c := NewConfig([]string{
|
_, err := NewConfig([]string{
|
||||||
// "--unknown=_oauthpath",
|
"--unknown=_oauthpath2",
|
||||||
// })
|
})
|
||||||
|
if err.Error() != "unknown flag: unknown" {
|
||||||
|
t.Error("Error should be \"unknown flag: unknown\", got:", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// }
|
func TestConfigFlagBackwardsCompatability(t *testing.T) {
|
||||||
|
c, err := NewConfig([]string{
|
||||||
|
"--client-id=clientid",
|
||||||
|
"--client-secret=verysecret",
|
||||||
|
"--prompt=prompt",
|
||||||
|
"--lifetime=200",
|
||||||
|
"--cookie-secure=false",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
// func TestConfigToml(t *testing.T) {
|
if c.ClientIdLegacy != "clientid" {
|
||||||
// logrus.SetLevel(logrus.DebugLevel)
|
t.Error("ClientIdLegacy should be clientid, got", c.ClientIdLegacy)
|
||||||
// flag.CommandLine = flag.NewFlagSet("tfa-test", flag.ContinueOnError)
|
}
|
||||||
|
if c.Providers.Google.ClientId != "clientid" {
|
||||||
|
t.Error("Providers.Google.ClientId should be clientid, got", c.Providers.Google.ClientId)
|
||||||
|
}
|
||||||
|
if c.ClientSecretLegacy != "verysecret" {
|
||||||
|
t.Error("ClientSecretLegacy should be verysecret, got", c.ClientSecretLegacy)
|
||||||
|
}
|
||||||
|
if c.Providers.Google.ClientSecret != "verysecret" {
|
||||||
|
t.Error("Providers.Google.ClientSecret should be verysecret, got", c.Providers.Google.ClientSecret)
|
||||||
|
}
|
||||||
|
if c.PromptLegacy != "prompt" {
|
||||||
|
t.Error("PromptLegacy should be prompt, got", c.PromptLegacy)
|
||||||
|
}
|
||||||
|
if c.Providers.Google.Prompt != "prompt" {
|
||||||
|
t.Error("Providers.Google.Prompt should be prompt, got", c.Providers.Google.Prompt)
|
||||||
|
}
|
||||||
|
|
||||||
// flags := []string{
|
// "cookie-secure" used to be a standard go bool flag that could take
|
||||||
// "-config=../test/config.toml",
|
// true, TRUE, 1, false, FALSE, 0 etc. values.
|
||||||
// }
|
// Here we're checking that format is still suppoted
|
||||||
// c := NewDefaultConfigWithFlags(flags)
|
if c.CookieSecureLegacy != "false" || c.InsecureCookie != true {
|
||||||
|
t.Error("Setting cookie-secure=false should set InsecureCookie true, got", c.InsecureCookie)
|
||||||
|
}
|
||||||
|
c, err = NewConfig([]string{"--cookie-secure=TRUE"})
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
if c.CookieSecureLegacy != "TRUE" || c.InsecureCookie != false {
|
||||||
|
t.Error("Setting cookie-secure=TRUE should set InsecureCookie false, got", c.InsecureCookie)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// if c == nil {
|
func TestConfigParseIni(t *testing.T) {
|
||||||
// t.Error(c)
|
c, err := NewConfig([]string{
|
||||||
// }
|
"--config=../test/config0",
|
||||||
// }
|
"--config=../test/config1",
|
||||||
|
"--csrf-cookie-name=csrfcookiename",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.CookieName != "inicookiename" {
|
||||||
|
t.Error("CookieName should be read as inicookiename from ini file, got", c.CookieName)
|
||||||
|
}
|
||||||
|
if c.CSRFCookieName != "csrfcookiename" {
|
||||||
|
t.Error("CSRFCookieName argument should override ini file, got", c.CSRFCookieName)
|
||||||
|
}
|
||||||
|
if c.Path != "/two" {
|
||||||
|
t.Error("Path in second ini file should override first ini file, got", c.Path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigFileBackwardsCompatability(t *testing.T) {
|
||||||
|
c, err := NewConfig([]string{
|
||||||
|
"--config=../test/config-legacy",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.Path != "/two" {
|
||||||
|
t.Error("Path in legacy config file should be read, got", c.Path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigParseEnvironment(t *testing.T) {
|
||||||
|
os.Setenv("COOKIE_NAME", "env_cookie_name")
|
||||||
|
c, err := NewConfig([]string{})
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.CookieName != "env_cookie_name" {
|
||||||
|
t.Error("CookieName should be read as env_cookie_name from environment, got", c.CookieName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigTransformation(t *testing.T) {
|
||||||
|
c, err := NewConfig([]string{
|
||||||
|
"--url-path=_oauthpath",
|
||||||
|
"--secret=verysecret",
|
||||||
|
"--lifetime=200",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.Path != "/_oauthpath" {
|
||||||
|
t.Error("Path should add slash to front to get /_oauthpath, got:", c.Path)
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.SecretString != "verysecret" {
|
||||||
|
t.Error("SecretString should be verysecret, got:", c.SecretString)
|
||||||
|
}
|
||||||
|
if bytes.Compare(c.Secret, []byte("verysecret")) != 0 {
|
||||||
|
t.Error("Secret should be []byte(verysecret), got:", string(c.Secret))
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.LifetimeString != 200 {
|
||||||
|
t.Error("LifetimeString should be 200, got:", c.LifetimeString)
|
||||||
|
}
|
||||||
|
if c.Lifetime != time.Second*time.Duration(200) {
|
||||||
|
t.Error("Lifetime default should be 200, got", c.Lifetime)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigCommaSeparatedList(t *testing.T) {
|
||||||
|
list := CommaSeparatedList{}
|
||||||
|
|
||||||
|
err := list.UnmarshalFlag("one,two")
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
if len(list) != 2 || list[0] != "one" || list[1] != "two" {
|
||||||
|
t.Error("Expected UnmarshalFlag to provide CommaSeparatedList{one,two}, got", list)
|
||||||
|
}
|
||||||
|
|
||||||
|
marshal, err := list.MarshalFlag()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
if marshal != "one,two" {
|
||||||
|
t.Error("Expected MarshalFlag to provide \"one,two\", got", list)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -8,10 +8,10 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Google struct {
|
type Google struct {
|
||||||
ClientId string `long:"providers.google.client-id" description:"Client ID"`
|
ClientId string `long:"client-id" env:"CLIENT_ID" description:"Client ID"`
|
||||||
ClientSecret string `long:"providers.google.client-secret" description:"Client Secret" json:"-"`
|
ClientSecret string `long:"client-secret" env:"CLIENT_SECRET" description:"Client Secret" json:"-"`
|
||||||
Scope string
|
Scope string
|
||||||
Prompt string `long:"providers.google.prompt" description:"Space separated list of OpenID prompt options"`
|
Prompt string `long:"prompt" env:"PROMPT" description:"Space separated list of OpenID prompt options"`
|
||||||
|
|
||||||
LoginURL *url.URL
|
LoginURL *url.URL
|
||||||
TokenURL *url.URL
|
TokenURL *url.URL
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
package provider
|
package provider
|
||||||
|
|
||||||
type Providers struct {
|
type Providers struct {
|
||||||
Google Google `group:"Google Provider"`
|
Google Google `group:"Google Provider" namespace:"google" env-namespace:"GOOGLE"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Token struct {
|
type Token struct {
|
||||||
|
@ -38,7 +38,11 @@ func (s *Server) buildRoutes() {
|
|||||||
s.router.Handle(config.Path, s.AuthCallbackHandler())
|
s.router.Handle(config.Path, s.AuthCallbackHandler())
|
||||||
|
|
||||||
// Add a default handler
|
// Add a default handler
|
||||||
s.router.NewRoute().Handler(s.AuthHandler())
|
if config.DefaultAction == "allow" {
|
||||||
|
s.router.NewRoute().Handler(s.AllowHandler())
|
||||||
|
} else {
|
||||||
|
s.router.NewRoute().Handler(s.AuthHandler())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) RootHandler(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) RootHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
@ -2,16 +2,12 @@ package tfa
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
|
||||||
// "reflect"
|
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/thomseddon/traefik-forward-auth/internal/provider"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -23,6 +19,152 @@ func init() {
|
|||||||
log = NewDefaultLogger()
|
log = NewDefaultLogger()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Tests
|
||||||
|
*/
|
||||||
|
|
||||||
|
func TestServerAuthHandler(t *testing.T) {
|
||||||
|
config, _ = NewConfig([]string{})
|
||||||
|
|
||||||
|
// Should redirect vanilla request to login url
|
||||||
|
req := newHttpRequest("/foo")
|
||||||
|
|
||||||
|
res, _ := httpRequest(req, nil)
|
||||||
|
if res.StatusCode != 307 {
|
||||||
|
t.Error("Vanilla request should be redirected with 307, got:", res.StatusCode)
|
||||||
|
}
|
||||||
|
fwd, _ := res.Location()
|
||||||
|
if fwd.Scheme != "https" || fwd.Host != "accounts.google.com" || fwd.Path != "/o/oauth2/auth" {
|
||||||
|
t.Error("Vanilla request should be redirected to login url, got:", fwd)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should catch invalid cookie
|
||||||
|
req = newHttpRequest("/foo")
|
||||||
|
c := MakeCookie(req, "test@example.com")
|
||||||
|
parts := strings.Split(c.Value, "|")
|
||||||
|
c.Value = fmt.Sprintf("bad|%s|%s", parts[1], parts[2])
|
||||||
|
|
||||||
|
res, _ = httpRequest(req, c)
|
||||||
|
if res.StatusCode != 401 {
|
||||||
|
t.Error("Request with invalid cookie shound't be authorised", res.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should validate email
|
||||||
|
req = newHttpRequest("/foo")
|
||||||
|
c = MakeCookie(req, "test@example.com")
|
||||||
|
config.Domains = []string{"test.com"}
|
||||||
|
|
||||||
|
res, _ = httpRequest(req, c)
|
||||||
|
if res.StatusCode != 401 {
|
||||||
|
t.Error("Request with invalid email shound't be authorised", res.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should allow valid request email
|
||||||
|
req = newHttpRequest("/foo")
|
||||||
|
|
||||||
|
c = MakeCookie(req, "test@example.com")
|
||||||
|
config.Domains = []string{}
|
||||||
|
|
||||||
|
res, _ = httpRequest(req, c)
|
||||||
|
if res.StatusCode != 200 {
|
||||||
|
t.Error("Valid request should be allowed, got:", res.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should pass through user
|
||||||
|
users := res.Header["X-Forwarded-User"]
|
||||||
|
if len(users) != 1 {
|
||||||
|
t.Error("Valid request missing X-Forwarded-User header")
|
||||||
|
} else if users[0] != "test@example.com" {
|
||||||
|
t.Error("X-Forwarded-User should match user, got: ", users)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerAuthCallback(t *testing.T) {
|
||||||
|
config, _ = NewConfig([]string{})
|
||||||
|
|
||||||
|
// Setup token server
|
||||||
|
tokenServerHandler := &TokenServerHandler{}
|
||||||
|
tokenServer := httptest.NewServer(tokenServerHandler)
|
||||||
|
defer tokenServer.Close()
|
||||||
|
tokenUrl, _ := url.Parse(tokenServer.URL)
|
||||||
|
config.Providers.Google.TokenURL = tokenUrl
|
||||||
|
|
||||||
|
// Setup user server
|
||||||
|
userServerHandler := &UserServerHandler{}
|
||||||
|
userServer := httptest.NewServer(userServerHandler)
|
||||||
|
defer userServer.Close()
|
||||||
|
userUrl, _ := url.Parse(userServer.URL)
|
||||||
|
config.Providers.Google.UserURL = userUrl
|
||||||
|
|
||||||
|
// Should pass auth response request to callback
|
||||||
|
req := newHttpRequest("/_oauth")
|
||||||
|
res, _ := httpRequest(req, nil)
|
||||||
|
if res.StatusCode != 401 {
|
||||||
|
t.Error("Auth callback without cookie shound't be authorised, got:", res.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should catch invalid csrf cookie
|
||||||
|
req = newHttpRequest("/_oauth?state=12345678901234567890123456789012:http://redirect")
|
||||||
|
c := MakeCSRFCookie(req, "nononononononononononononononono")
|
||||||
|
res, _ = httpRequest(req, c)
|
||||||
|
if res.StatusCode != 401 {
|
||||||
|
t.Error("Auth callback with invalid cookie shound't be authorised, got:", res.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should redirect valid request
|
||||||
|
req = newHttpRequest("/_oauth?state=12345678901234567890123456789012:http://redirect")
|
||||||
|
c = MakeCSRFCookie(req, "12345678901234567890123456789012")
|
||||||
|
res, _ = httpRequest(req, c)
|
||||||
|
if res.StatusCode != 307 {
|
||||||
|
t.Error("Valid callback should be allowed, got:", res.StatusCode)
|
||||||
|
}
|
||||||
|
fwd, _ := res.Location()
|
||||||
|
if fwd.Scheme != "http" || fwd.Host != "redirect" || fwd.Path != "" {
|
||||||
|
t.Error("Valid request should be redirected to return url, got:", fwd)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerDefaultAction(t *testing.T) {
|
||||||
|
config, _ = NewConfig([]string{})
|
||||||
|
|
||||||
|
req := newHttpRequest("/random")
|
||||||
|
res, _ := httpRequest(req, nil)
|
||||||
|
if res.StatusCode != 307 {
|
||||||
|
t.Error("Request should require auth with auth default handler, got:", res.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
config.DefaultAction = "allow"
|
||||||
|
req = newHttpRequest("/random")
|
||||||
|
res, _ = httpRequest(req, nil)
|
||||||
|
if res.StatusCode != 200 {
|
||||||
|
t.Error("Request should be allowed with allow default handler, got:", res.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerRoutePathPrefix(t *testing.T) {
|
||||||
|
config, _ = NewConfig([]string{})
|
||||||
|
config.Rules = map[string]*Rule{
|
||||||
|
"web1": {
|
||||||
|
Action: "allow",
|
||||||
|
Rule: "PathPrefix(`/api`)",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should block any request
|
||||||
|
req := newHttpRequest("/random")
|
||||||
|
res, _ := httpRequest(req, nil)
|
||||||
|
if res.StatusCode != 307 {
|
||||||
|
t.Error("Request not matching any rule should require auth, got:", res.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should allow /api request
|
||||||
|
req = newHttpRequest("/api")
|
||||||
|
res, _ = httpRequest(req, nil)
|
||||||
|
if res.StatusCode != 200 {
|
||||||
|
t.Error("Request matching allowed rule should be allowed, got:", res.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Utilities
|
* Utilities
|
||||||
*/
|
*/
|
||||||
@ -44,7 +186,7 @@ func (t *UserServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
}`)
|
}`)
|
||||||
}
|
}
|
||||||
|
|
||||||
func httpRequest(s *Server, r *http.Request, c *http.Cookie) (*http.Response, string) {
|
func httpRequest(r *http.Request, c *http.Cookie) (*http.Response, string) {
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
// Set cookies on recorder
|
// Set cookies on recorder
|
||||||
@ -57,7 +199,8 @@ func httpRequest(s *Server, r *http.Request, c *http.Cookie) (*http.Response, st
|
|||||||
r.Header.Add("Cookie", c)
|
r.Header.Add("Cookie", c)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.RootHandler(w, r)
|
|
||||||
|
NewServer().RootHandler(w, r)
|
||||||
|
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
body, _ := ioutil.ReadAll(res.Body)
|
body, _ := ioutil.ReadAll(res.Body)
|
||||||
@ -92,186 +235,3 @@ func qsDiff(t *testing.T, one, two url.Values) []string {
|
|||||||
}
|
}
|
||||||
return errs
|
return errs
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Tests
|
|
||||||
*/
|
|
||||||
|
|
||||||
func TestServerHandler(t *testing.T) {
|
|
||||||
server := NewServer()
|
|
||||||
|
|
||||||
config = Config{
|
|
||||||
Path: "/_oauth",
|
|
||||||
CookieName: "cookie_test",
|
|
||||||
Lifetime: time.Second * time.Duration(10),
|
|
||||||
Providers: provider.Providers{
|
|
||||||
Google: provider.Google{
|
|
||||||
ClientId: "idtest",
|
|
||||||
ClientSecret: "sectest",
|
|
||||||
Scope: "scopetest",
|
|
||||||
LoginURL: &url.URL{
|
|
||||||
Scheme: "http",
|
|
||||||
Host: "test.com",
|
|
||||||
Path: "/auth",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Should redirect vanilla request to login url
|
|
||||||
req := newHttpRequest("/foo")
|
|
||||||
res, _ := httpRequest(server, req, nil)
|
|
||||||
if res.StatusCode != 307 {
|
|
||||||
t.Error("Vanilla request should be redirected with 307, got:", res.StatusCode)
|
|
||||||
}
|
|
||||||
fwd, _ := res.Location()
|
|
||||||
if fwd.Scheme != "http" || fwd.Host != "test.com" || fwd.Path != "/auth" {
|
|
||||||
t.Error("Vanilla request should be redirected to login url, got:", fwd)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Should catch invalid cookie
|
|
||||||
req = newHttpRequest("/foo")
|
|
||||||
|
|
||||||
c := MakeCookie(req, "test@example.com")
|
|
||||||
parts := strings.Split(c.Value, "|")
|
|
||||||
c.Value = fmt.Sprintf("bad|%s|%s", parts[1], parts[2])
|
|
||||||
|
|
||||||
res, _ = httpRequest(server, req, c)
|
|
||||||
if res.StatusCode != 401 {
|
|
||||||
t.Error("Request with invalid cookie shound't be authorised", res.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Should validate email
|
|
||||||
req = newHttpRequest("/foo")
|
|
||||||
|
|
||||||
c = MakeCookie(req, "test@example.com")
|
|
||||||
config.Domains = []string{"test.com"}
|
|
||||||
|
|
||||||
res, _ = httpRequest(server, req, c)
|
|
||||||
if res.StatusCode != 401 {
|
|
||||||
t.Error("Request with invalid cookie shound't be authorised", res.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Should allow valid request email
|
|
||||||
req = newHttpRequest("/foo")
|
|
||||||
|
|
||||||
c = MakeCookie(req, "test@example.com")
|
|
||||||
config.Domains = []string{}
|
|
||||||
|
|
||||||
res, _ = httpRequest(server, req, c)
|
|
||||||
if res.StatusCode != 200 {
|
|
||||||
t.Error("Valid request should be allowed, got:", res.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Should pass through user
|
|
||||||
users := res.Header["X-Forwarded-User"]
|
|
||||||
if len(users) != 1 {
|
|
||||||
t.Error("Valid request missing X-Forwarded-User header")
|
|
||||||
} else if users[0] != "test@example.com" {
|
|
||||||
t.Error("X-Forwarded-User should match user, got: ", users)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestServerAuthCallback(t *testing.T) {
|
|
||||||
server := NewServer()
|
|
||||||
config = Config{
|
|
||||||
Path: "/_oauth",
|
|
||||||
CookieName: "cookie_test",
|
|
||||||
Lifetime: time.Second * time.Duration(10),
|
|
||||||
CSRFCookieName: "csrf_test",
|
|
||||||
Providers: provider.Providers{
|
|
||||||
Google: provider.Google{
|
|
||||||
ClientId: "idtest",
|
|
||||||
ClientSecret: "sectest",
|
|
||||||
Scope: "scopetest",
|
|
||||||
LoginURL: &url.URL{
|
|
||||||
Scheme: "http",
|
|
||||||
Host: "test.com",
|
|
||||||
Path: "/auth",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Setup token server
|
|
||||||
tokenServerHandler := &TokenServerHandler{}
|
|
||||||
tokenServer := httptest.NewServer(tokenServerHandler)
|
|
||||||
defer tokenServer.Close()
|
|
||||||
tokenUrl, _ := url.Parse(tokenServer.URL)
|
|
||||||
config.Providers.Google.TokenURL = tokenUrl
|
|
||||||
|
|
||||||
// Setup user server
|
|
||||||
userServerHandler := &UserServerHandler{}
|
|
||||||
userServer := httptest.NewServer(userServerHandler)
|
|
||||||
defer userServer.Close()
|
|
||||||
userUrl, _ := url.Parse(userServer.URL)
|
|
||||||
config.Providers.Google.UserURL = userUrl
|
|
||||||
|
|
||||||
// Should pass auth response request to callback
|
|
||||||
req := newHttpRequest("/_oauth")
|
|
||||||
res, _ := httpRequest(server, req, nil)
|
|
||||||
if res.StatusCode != 401 {
|
|
||||||
t.Error("Auth callback without cookie shound't be authorised, got:", res.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Should catch invalid csrf cookie
|
|
||||||
req = newHttpRequest("/_oauth?state=12345678901234567890123456789012:http://redirect")
|
|
||||||
c := MakeCSRFCookie(req, "nononononononononononononononono")
|
|
||||||
res, _ = httpRequest(server, req, c)
|
|
||||||
if res.StatusCode != 401 {
|
|
||||||
t.Error("Auth callback with invalid cookie shound't be authorised, got:", res.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Should redirect valid request
|
|
||||||
req = newHttpRequest("/_oauth?state=12345678901234567890123456789012:http://redirect")
|
|
||||||
c = MakeCSRFCookie(req, "12345678901234567890123456789012")
|
|
||||||
res, _ = httpRequest(server, req, c)
|
|
||||||
if res.StatusCode != 307 {
|
|
||||||
t.Error("Valid callback should be allowed, got:", res.StatusCode)
|
|
||||||
}
|
|
||||||
fwd, _ := res.Location()
|
|
||||||
if fwd.Scheme != "http" || fwd.Host != "redirect" || fwd.Path != "" {
|
|
||||||
t.Error("Valid request should be redirected to return url, got:", fwd)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestServerRoutePathPrefix(t *testing.T) {
|
|
||||||
config = Config{
|
|
||||||
Path: "/_oauth",
|
|
||||||
CookieName: "cookie_test",
|
|
||||||
Lifetime: time.Second * time.Duration(10),
|
|
||||||
Providers: provider.Providers{
|
|
||||||
Google: provider.Google{
|
|
||||||
ClientId: "idtest",
|
|
||||||
ClientSecret: "sectest",
|
|
||||||
Scope: "scopetest",
|
|
||||||
LoginURL: &url.URL{
|
|
||||||
Scheme: "http",
|
|
||||||
Host: "test.com",
|
|
||||||
Path: "/auth",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Rules: map[string]*Rule{
|
|
||||||
"web1": &Rule{
|
|
||||||
Action: "allow",
|
|
||||||
Rule: "PathPrefix(`/api`)",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
server := NewServer()
|
|
||||||
|
|
||||||
// Should block any request
|
|
||||||
req := newHttpRequest("/random")
|
|
||||||
res, _ := httpRequest(server, req, nil)
|
|
||||||
if res.StatusCode != 307 {
|
|
||||||
t.Error("Request not matching any rule should require auth, got:", res.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Should allow /api request
|
|
||||||
req = newHttpRequest("/api")
|
|
||||||
res, _ = httpRequest(server, req, nil)
|
|
||||||
if res.StatusCode != 200 {
|
|
||||||
t.Error("Request matching allowed rule should be allowed, got:", res.StatusCode)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
1
test/config-legacy
Normal file
1
test/config-legacy
Normal file
@ -0,0 +1 @@
|
|||||||
|
url-path two
|
3
test/config0
Normal file
3
test/config0
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
cookie-name=inicookiename
|
||||||
|
csrf-cookie-name=inicsrfcookiename
|
||||||
|
url-path=one
|
1
test/config1
Normal file
1
test/config1
Normal file
@ -0,0 +1 @@
|
|||||||
|
url-path=two
|
Loading…
x
Reference in New Issue
Block a user