Compare commits
25 Commits
Author | SHA1 | Date | |
---|---|---|---|
5c800a0170 | |||
b1fdcc7f56 | |||
db31b09a72 | |||
e1d518db11 | |||
67339ae79a | |||
0b2889935e | |||
b3b31e2193 | |||
1a3a099ac1 | |||
afd8878188 | |||
6ccd1c6dfc | |||
df81be1147 | |||
5dcf889efe | |||
92d72dcdd2 | |||
4c1874b786 | |||
dcf4f6574d | |||
91775ff0a8 | |||
1832672f5e | |||
eaad0a9054 | |||
36fffd2382 | |||
ccbda4ec8c | |||
b014c5638a | |||
c897bc8387 | |||
96f9469abd | |||
b54871391f | |||
d230572879 |
2
.dockerignore
Normal file
2
.dockerignore
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
example
|
||||||
|
.travis.yml
|
@ -4,4 +4,5 @@ go:
|
|||||||
- "1.10"
|
- "1.10"
|
||||||
install:
|
install:
|
||||||
- go get github.com/namsral/flag
|
- go get github.com/namsral/flag
|
||||||
- go get github.com/op/go-logging
|
- go get github.com/sirupsen/logrus
|
||||||
|
script: go test -v ./...
|
||||||
|
@ -7,7 +7,7 @@ WORKDIR /app
|
|||||||
# Add libraries
|
# Add libraries
|
||||||
RUN apk add --no-cache git && \
|
RUN apk add --no-cache git && \
|
||||||
go get "github.com/namsral/flag" && \
|
go get "github.com/namsral/flag" && \
|
||||||
go get "github.com/op/go-logging" && \
|
go get "github.com/sirupsen/logrus" && \
|
||||||
apk del git
|
apk del git
|
||||||
|
|
||||||
# Copy & build
|
# Copy & build
|
||||||
|
58
README.md
58
README.md
@ -1,5 +1,5 @@
|
|||||||
|
|
||||||
# Traefik Forward Auth [](https://travis-ci.org/thomseddon/traefik-forward-auth)
|
# Traefik Forward Auth [](https://travis-ci.org/thomseddon/traefik-forward-auth) [](https://goreportcard.com/badge/github.com/thomseddon/traefik-forward-auth)
|
||||||
|
|
||||||
A minimal forward authentication service that provides Google oauth based login and authentication for the traefik reverse proxy.
|
A minimal forward authentication service that provides Google oauth based login and authentication for the traefik reverse proxy.
|
||||||
|
|
||||||
@ -24,16 +24,20 @@ The following configuration is supported:
|
|||||||
|-----------------------|------|-----------|
|
|-----------------------|------|-----------|
|
||||||
|-client-id|string|*Google Client ID (required)|
|
|-client-id|string|*Google Client ID (required)|
|
||||||
|-client-secret|string|*Google Client Secret (required)|
|
|-client-secret|string|*Google Client Secret (required)|
|
||||||
|
|-secret|string|*Secret used for signing (required)|
|
||||||
|-config|string|Path to config file|
|
|-config|string|Path to config file|
|
||||||
|-cookie-domains|string|Comma separated list of cookie domains|
|
|-auth-host|string|Central auth login (see below)|
|
||||||
|
|-cookie-domains|string|Comma separated list of cookie domains (see below)|
|
||||||
|-cookie-name|string|Cookie Name (default "_forward_auth")|
|
|-cookie-name|string|Cookie Name (default "_forward_auth")|
|
||||||
|-cookie-secret|string|*Cookie secret (required)|
|
|
||||||
|-cookie-secure|bool|Use secure cookies (default true)|
|
|-cookie-secure|bool|Use secure cookies (default true)|
|
||||||
|-csrf-cookie-name|string|CSRF Cookie Name (default "_forward_auth_csrf")|
|
|-csrf-cookie-name|string|CSRF Cookie Name (default "_forward_auth_csrf")|
|
||||||
|-direct|bool|Run in direct mode (use own hostname as oppose to <br>X-Forwarded-Host, used for testing/development)
|
|
||||||
|-domain|string|Comma separated list of email domains to allow|
|
|-domain|string|Comma separated list of email domains to allow|
|
||||||
|
|-whitelist|string|Comma separated list of email addresses to allow|
|
||||||
|-lifetime|int|Session length in seconds (default 43200)|
|
|-lifetime|int|Session length in seconds (default 43200)|
|
||||||
|-url-path|string|Callback URL (default "_oauth")|
|
|-url-path|string|Callback URL (default "_oauth")|
|
||||||
|
|-prompt|string|Space separated list of [OpenID prompt options](https://developers.google.com/identity/protocols/OpenIDConnect#prompt)|
|
||||||
|
|-log-level|string|Log level: trace, debug, info, warn, error, fatal, panic (default "warn")|
|
||||||
|
|-log-format|string|Log format: text, json, pretty (default "text")|
|
||||||
|
|
||||||
Configuration can also be supplied as environment variables (use upper case and swap `-`'s for `_`'s e.g. `-client-id` becomes `CLIENT_ID`)
|
Configuration can also be supplied as environment variables (use upper case and swap `-`'s for `_`'s e.g. `-client-id` becomes `CLIENT_ID`)
|
||||||
|
|
||||||
@ -47,6 +51,19 @@ Create a new project then search for and select "Credentials" in the search bar.
|
|||||||
|
|
||||||
Click, "Create Credentials" > "OAuth client ID". Select "Web Application", fill in the name of your app, skip "Authorized JavaScript origins" and fill "Authorized redirect URIs" with all the domains you will allow authentication from, appended with the `url-path` (e.g. https://app.test.com/_oauth)
|
Click, "Create Credentials" > "OAuth client ID". Select "Web Application", fill in the name of your app, skip "Authorized JavaScript origins" and fill "Authorized redirect URIs" with all the domains you will allow authentication from, appended with the `url-path` (e.g. https://app.test.com/_oauth)
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
The authenticated user is set in the `X-Forwarded-User` header, to pass this on add this to the `authResponseHeaders` as shown [here](https://github.com/thomseddon/traefik-forward-auth/blob/master/example/docker-compose-dev.yml).
|
||||||
|
|
||||||
|
## User Restriction
|
||||||
|
|
||||||
|
You can restrict who can login with the following parameters:
|
||||||
|
|
||||||
|
* `-domain` - Use this to limit logins to a specific domain, e.g. test.com only
|
||||||
|
* `-whitelist` - Use this to only allow specific users to login e.g. thom@test.com only
|
||||||
|
|
||||||
|
Note, if you pass `whitelist` then only this is checked and `domain` is effectively ignored.
|
||||||
|
|
||||||
## Cookie Domains
|
## Cookie Domains
|
||||||
|
|
||||||
You can supply a comma separated list of cookie domains, if the host of the original request is a subdomain of any given cookie domain, the authentication cookie will set with the given domain.
|
You can supply a comma separated list of cookie domains, if the host of the original request is a subdomain of any given cookie domain, the authentication cookie will set with the given domain.
|
||||||
@ -55,6 +72,39 @@ For example, if cookie domain is `test.com` and a request comes in on `app1.test
|
|||||||
|
|
||||||
Beware however, if using cookie domains whilst running multiple instances of traefik/traefik-forward-auth for the same domain, the cookies will clash. You can fix this by using the same `cookie-secret` in both instances, or using a different `cookie-name` on each.
|
Beware however, if using cookie domains whilst running multiple instances of traefik/traefik-forward-auth for the same domain, the cookies will clash. You can fix this by using the same `cookie-secret` in both instances, or using a different `cookie-name` on each.
|
||||||
|
|
||||||
|
## Operation Modes
|
||||||
|
|
||||||
|
#### Overlay
|
||||||
|
|
||||||
|
Overlay is the default operation mode, in this mode the authorisation endpoint is overlayed onto any domain. By default the `/_oauth` path is used, this can be customised using the `-url-path` option.
|
||||||
|
|
||||||
|
If a request comes in for `www.myapp.com/home` then the user will be redirected to the google login, following this they will be sent back to `www.myapp.com/_oauth`, where their token will be validated (this request will not be forwarded to your application). Following successful authoristion, the user will return to their originally requested url of `www.myapp.com/home`.
|
||||||
|
|
||||||
|
As the hostname in the `redirect_uri` is dynamically generated based on the orignal request, every hostname must be permitted in the Google OAuth console (e.g. `www.myappp.com` would need to be added in the above example)
|
||||||
|
|
||||||
|
#### Auth Host
|
||||||
|
|
||||||
|
This is an optional mode of operation that is useful when dealing with a large number of subdomains, it is activated by using the `-auth-host` config option (see [this example docker-compose.yml](https://github.com/thomseddon/traefik-forward-auth/blob/master/example/docker-compose-auth-host.yml)).
|
||||||
|
|
||||||
|
For example, if you have a few applications: `app1.test.com`, `app2.test.com`, `appN.test.com`, adding every domain to Google's console can become laborious.
|
||||||
|
To utilise an auth host, permit domain level cookies by setting the cookie domain to `test.com` then set the `auth-host` to: `auth.test.com`.
|
||||||
|
|
||||||
|
The user flow will then be:
|
||||||
|
|
||||||
|
1. Request to `app10.test.com/home/page`
|
||||||
|
2. User redirected to Google login
|
||||||
|
3. After Google login, user is redirected to `auth.test.com/_oauth`
|
||||||
|
4. Token, user and CSRF cookie is validated, auth cookie is set to `test.com`
|
||||||
|
5. User is redirected to `app10.test.com/home/page`
|
||||||
|
6. Request is allowed
|
||||||
|
|
||||||
|
With this setup, only `auth.test.com` must be permitted in the Google console.
|
||||||
|
|
||||||
|
Two criteria must be met for an `auth-host` to be used:
|
||||||
|
|
||||||
|
1. Request matches given `cookie-domain`
|
||||||
|
2. `auth-host` is also subdomain of same `cookie-domain`
|
||||||
|
|
||||||
## Copyright
|
## Copyright
|
||||||
|
|
||||||
2018 Thom Seddon
|
2018 Thom Seddon
|
||||||
|
44
example/docker-compose-auth-host.yml
Normal file
44
example/docker-compose-auth-host.yml
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
version: '3'
|
||||||
|
|
||||||
|
services:
|
||||||
|
traefik:
|
||||||
|
image: traefik
|
||||||
|
command: -c /traefik.toml --logLevel=DEBUG
|
||||||
|
ports:
|
||||||
|
- "8085:80"
|
||||||
|
- "8086:8080"
|
||||||
|
networks:
|
||||||
|
- traefik
|
||||||
|
volumes:
|
||||||
|
- ./traefik.toml:/traefik.toml
|
||||||
|
- /var/run/docker.sock:/var/run/docker.sock
|
||||||
|
|
||||||
|
whoami1:
|
||||||
|
image: emilevauge/whoami
|
||||||
|
networks:
|
||||||
|
- traefik
|
||||||
|
labels:
|
||||||
|
- "traefik.backend=whoami"
|
||||||
|
- "traefik.enable=true"
|
||||||
|
- "traefik.frontend.rule=Host:whoami.yourdomain.com"
|
||||||
|
|
||||||
|
traefik-forward-auth:
|
||||||
|
image: thomseddon/traefik-forward-auth
|
||||||
|
environment:
|
||||||
|
- CLIENT_ID=your-client-id
|
||||||
|
- CLIENT_SECRET=your-client-secret
|
||||||
|
- SECRET=something-random
|
||||||
|
- COOKIE_SECURE=false
|
||||||
|
- DOMAIN=yourcompany.com
|
||||||
|
- AUTH_HOST=auth.yourdomain.com
|
||||||
|
networks:
|
||||||
|
- traefik
|
||||||
|
# When using an auth host, adding it here prompts traefik to generate certs
|
||||||
|
labels:
|
||||||
|
- traefik.enable=true
|
||||||
|
- traefik.port=4181
|
||||||
|
- traefik.backend=traefik-forward-auth
|
||||||
|
- traefik.frontend.rule=Host:auth.yourdomain.com
|
||||||
|
|
||||||
|
networks:
|
||||||
|
traefik:
|
48
example/docker-compose-dev.yml
Normal file
48
example/docker-compose-dev.yml
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
version: '3'
|
||||||
|
|
||||||
|
services:
|
||||||
|
traefik:
|
||||||
|
image: traefik
|
||||||
|
command: -c /traefik.toml
|
||||||
|
# command: -c /traefik.toml --logLevel=DEBUG
|
||||||
|
ports:
|
||||||
|
- "8085:80"
|
||||||
|
- "8086:8080"
|
||||||
|
networks:
|
||||||
|
- traefik
|
||||||
|
volumes:
|
||||||
|
- ./traefik.toml:/traefik.toml
|
||||||
|
- /var/run/docker.sock:/var/run/docker.sock
|
||||||
|
|
||||||
|
whoami1:
|
||||||
|
image: emilevauge/whoami
|
||||||
|
networks:
|
||||||
|
- traefik
|
||||||
|
labels:
|
||||||
|
- "traefik.backend=whoami1"
|
||||||
|
- "traefik.enable=true"
|
||||||
|
- "traefik.frontend.rule=Host:whoami.localhost.com"
|
||||||
|
|
||||||
|
whoami2:
|
||||||
|
image: emilevauge/whoami
|
||||||
|
networks:
|
||||||
|
- traefik
|
||||||
|
labels:
|
||||||
|
- "traefik.backend=whoami2"
|
||||||
|
- "traefik.enable=true"
|
||||||
|
- "traefik.frontend.rule=Host:whoami.localhost.org"
|
||||||
|
|
||||||
|
traefik-forward-auth:
|
||||||
|
build: ../
|
||||||
|
environment:
|
||||||
|
- CLIENT_ID=test
|
||||||
|
- CLIENT_SECRET=test
|
||||||
|
- COOKIE_SECRET=something-random
|
||||||
|
- COOKIE_SECURE=false
|
||||||
|
- COOKIE_DOMAINS=localhost.com
|
||||||
|
- AUTH_URL=http://auth.localhost.com:8085/_oauth
|
||||||
|
networks:
|
||||||
|
- traefik
|
||||||
|
|
||||||
|
networks:
|
||||||
|
traefik:
|
@ -22,12 +22,12 @@ services:
|
|||||||
- "traefik.enable=true"
|
- "traefik.enable=true"
|
||||||
- "traefik.frontend.rule=Host:whoami.localhost.com"
|
- "traefik.frontend.rule=Host:whoami.localhost.com"
|
||||||
|
|
||||||
forward-oauth:
|
traefik-forward-auth:
|
||||||
image: thomseddon/traefik-forward-auth
|
image: thomseddon/traefik-forward-auth
|
||||||
environment:
|
environment:
|
||||||
- CLIENT_ID=your-client-id
|
- CLIENT_ID=your-client-id
|
||||||
- CLIENT_SECRET=your-client-secret
|
- CLIENT_SECRET=your-client-secret
|
||||||
- COOKIE_SECRET=something-random
|
- SECRET=something-random
|
||||||
- COOKIE_SECURE=false
|
- COOKIE_SECURE=false
|
||||||
- DOMAIN=yourcompany.com
|
- DOMAIN=yourcompany.com
|
||||||
networks:
|
networks:
|
||||||
|
@ -37,7 +37,8 @@
|
|||||||
address = ":80"
|
address = ":80"
|
||||||
|
|
||||||
[entryPoints.http.auth.forward]
|
[entryPoints.http.auth.forward]
|
||||||
address = "http://forward-oauth:4181"
|
address = "http://traefik-forward-auth:4181"
|
||||||
|
authResponseHeaders = ["X-Forwarded-User"]
|
||||||
|
|
||||||
################################################################
|
################################################################
|
||||||
# Traefik logs configuration
|
# Traefik logs configuration
|
||||||
|
478
forwardauth.go
478
forwardauth.go
@ -1,358 +1,390 @@
|
|||||||
|
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"crypto/hmac"
|
||||||
"time"
|
"crypto/rand"
|
||||||
"errors"
|
"crypto/sha256"
|
||||||
"strings"
|
"encoding/base64"
|
||||||
"strconv"
|
"encoding/json"
|
||||||
"net/url"
|
"errors"
|
||||||
"net/http"
|
"fmt"
|
||||||
"crypto/hmac"
|
"net/http"
|
||||||
"crypto/rand"
|
"net/url"
|
||||||
"crypto/sha256"
|
"strconv"
|
||||||
"encoding/json"
|
"strings"
|
||||||
"encoding/base64"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Forward Auth
|
// Forward Auth
|
||||||
type ForwardAuth struct {
|
type ForwardAuth struct {
|
||||||
Path string
|
Path string
|
||||||
Lifetime time.Duration
|
Lifetime time.Duration
|
||||||
|
Secret []byte
|
||||||
|
|
||||||
ClientId string
|
ClientId string
|
||||||
ClientSecret string
|
ClientSecret string `json:"-"`
|
||||||
Scope string
|
Scope string
|
||||||
|
|
||||||
LoginURL *url.URL
|
LoginURL *url.URL
|
||||||
TokenURL *url.URL
|
TokenURL *url.URL
|
||||||
UserURL *url.URL
|
UserURL *url.URL
|
||||||
|
|
||||||
CookieName string
|
AuthHost string
|
||||||
CookieDomains []CookieDomain
|
|
||||||
CSRFCookieName string
|
|
||||||
CookieSecret []byte
|
|
||||||
CookieSecure bool
|
|
||||||
|
|
||||||
Domain []string
|
CookieName string
|
||||||
|
CookieDomains []CookieDomain
|
||||||
|
CSRFCookieName string
|
||||||
|
CookieSecure bool
|
||||||
|
|
||||||
Direct bool
|
Domain []string
|
||||||
|
Whitelist []string
|
||||||
|
|
||||||
|
Prompt string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Request Validation
|
// Request Validation
|
||||||
|
|
||||||
// Cookie = hash(secret, cookie domain, email, expires)|expires|email
|
// Cookie = hash(secret, cookie domain, email, expires)|expires|email
|
||||||
func (f *ForwardAuth) ValidateCookie(r *http.Request, c *http.Cookie) (bool, string, error) {
|
func (f *ForwardAuth) ValidateCookie(r *http.Request, c *http.Cookie) (bool, string, error) {
|
||||||
parts := strings.Split(c.Value, "|")
|
parts := strings.Split(c.Value, "|")
|
||||||
|
|
||||||
if len(parts) != 3 {
|
if len(parts) != 3 {
|
||||||
return false, "", errors.New("Invalid cookie format")
|
return false, "", errors.New("Invalid cookie format")
|
||||||
}
|
}
|
||||||
|
|
||||||
mac, err := base64.URLEncoding.DecodeString(parts[0])
|
mac, err := base64.URLEncoding.DecodeString(parts[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, "", errors.New("Unable to decode cookie mac")
|
return false, "", errors.New("Unable to decode cookie mac")
|
||||||
}
|
}
|
||||||
|
|
||||||
expectedSignature := f.cookieSignature(r, parts[2], parts[1])
|
expectedSignature := f.cookieSignature(r, parts[2], parts[1])
|
||||||
expected, err := base64.URLEncoding.DecodeString(expectedSignature)
|
expected, err := base64.URLEncoding.DecodeString(expectedSignature)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, "", errors.New("Unable to generate mac")
|
return false, "", errors.New("Unable to generate mac")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Valid token?
|
// Valid token?
|
||||||
if !hmac.Equal(mac, expected) {
|
if !hmac.Equal(mac, expected) {
|
||||||
return false, "", errors.New("Invalid cookie mac")
|
return false, "", errors.New("Invalid cookie mac")
|
||||||
}
|
}
|
||||||
|
|
||||||
expires, err := strconv.ParseInt(parts[1], 10, 64)
|
expires, err := strconv.ParseInt(parts[1], 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, "", errors.New("Unable to parse cookie expiry")
|
return false, "", errors.New("Unable to parse cookie expiry")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Has it expired?
|
// Has it expired?
|
||||||
if time.Unix(expires, 0).Before(time.Now()) {
|
if time.Unix(expires, 0).Before(time.Now()) {
|
||||||
return false, "", errors.New("Cookie has expired")
|
return false, "", errors.New("Cookie has expired")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Looks valid
|
// Looks valid
|
||||||
return true, parts[2], nil
|
return true, parts[2], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate email
|
// Validate email
|
||||||
func (f *ForwardAuth) ValidateEmail(email string) bool {
|
func (f *ForwardAuth) ValidateEmail(email string) bool {
|
||||||
if len(f.Domain) > 0 {
|
found := false
|
||||||
parts := strings.Split(email, "@")
|
if len(f.Whitelist) > 0 {
|
||||||
if len(parts) < 2 {
|
for _, whitelist := range f.Whitelist {
|
||||||
return false
|
if email == whitelist {
|
||||||
}
|
found = true
|
||||||
found := false
|
}
|
||||||
for _, domain := range f.Domain {
|
}
|
||||||
if domain == parts[1] {
|
} else if len(f.Domain) > 0 {
|
||||||
found = true
|
parts := strings.Split(email, "@")
|
||||||
}
|
if len(parts) < 2 {
|
||||||
}
|
return false
|
||||||
if !found {
|
}
|
||||||
return false
|
for _, domain := range f.Domain {
|
||||||
}
|
if domain == parts[1] {
|
||||||
}
|
found = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
return true
|
return found
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// OAuth Methods
|
// OAuth Methods
|
||||||
|
|
||||||
// Get login url
|
// Get login url
|
||||||
func (f *ForwardAuth) GetLoginURL(r *http.Request, nonce string) string {
|
func (f *ForwardAuth) GetLoginURL(r *http.Request, nonce string) string {
|
||||||
state := fmt.Sprintf("%s:%s", nonce, f.returnUrl(r))
|
state := fmt.Sprintf("%s:%s", nonce, f.returnUrl(r))
|
||||||
|
|
||||||
q := url.Values{}
|
q := url.Values{}
|
||||||
q.Set("client_id", fw.ClientId)
|
q.Set("client_id", fw.ClientId)
|
||||||
q.Set("response_type", "code")
|
q.Set("response_type", "code")
|
||||||
q.Set("scope", fw.Scope)
|
q.Set("scope", fw.Scope)
|
||||||
// q.Set("approval_prompt", fw.ClientId)
|
if fw.Prompt != "" {
|
||||||
q.Set("redirect_uri", f.redirectUri(r))
|
q.Set("prompt", fw.Prompt)
|
||||||
q.Set("state", state)
|
}
|
||||||
|
q.Set("redirect_uri", f.redirectUri(r))
|
||||||
|
q.Set("state", state)
|
||||||
|
|
||||||
var u url.URL
|
var u url.URL
|
||||||
u = *fw.LoginURL
|
u = *fw.LoginURL
|
||||||
u.RawQuery = q.Encode()
|
u.RawQuery = q.Encode()
|
||||||
|
|
||||||
return u.String()
|
return u.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Exchange code for token
|
// Exchange code for token
|
||||||
|
|
||||||
type Token struct {
|
type Token struct {
|
||||||
Token string `json:"access_token"`
|
Token string `json:"access_token"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *ForwardAuth) ExchangeCode(r *http.Request, code string) (string, error) {
|
func (f *ForwardAuth) ExchangeCode(r *http.Request, code string) (string, error) {
|
||||||
form := url.Values{}
|
form := url.Values{}
|
||||||
form.Set("client_id", fw.ClientId)
|
form.Set("client_id", fw.ClientId)
|
||||||
form.Set("client_secret", fw.ClientSecret)
|
form.Set("client_secret", fw.ClientSecret)
|
||||||
form.Set("grant_type", "authorization_code")
|
form.Set("grant_type", "authorization_code")
|
||||||
form.Set("redirect_uri", f.redirectUri(r))
|
form.Set("redirect_uri", f.redirectUri(r))
|
||||||
form.Set("code", code)
|
form.Set("code", code)
|
||||||
|
|
||||||
|
res, err := http.PostForm(fw.TokenURL.String(), form)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
res, err := http.PostForm(fw.TokenURL.String(), form)
|
var token Token
|
||||||
if err != nil {
|
defer res.Body.Close()
|
||||||
return "", err
|
err = json.NewDecoder(res.Body).Decode(&token)
|
||||||
}
|
|
||||||
|
|
||||||
var token Token
|
return token.Token, err
|
||||||
defer res.Body.Close()
|
|
||||||
err = json.NewDecoder(res.Body).Decode(&token)
|
|
||||||
|
|
||||||
return token.Token, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get user with token
|
// Get user with token
|
||||||
|
|
||||||
type User struct {
|
type User struct {
|
||||||
Id string `json:"id"`
|
Id string `json:"id"`
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
Verified bool `json:"verified_email"`
|
Verified bool `json:"verified_email"`
|
||||||
Hd string `json:"hd"`
|
Hd string `json:"hd"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *ForwardAuth) GetUser(token string) (User, error) {
|
func (f *ForwardAuth) GetUser(token string) (User, error) {
|
||||||
var user User
|
var user User
|
||||||
|
|
||||||
client := &http.Client{}
|
client := &http.Client{}
|
||||||
req, err := http.NewRequest("GET", fw.UserURL.String(), nil)
|
req, err := http.NewRequest("GET", fw.UserURL.String(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return user, err
|
return user, err
|
||||||
}
|
}
|
||||||
|
|
||||||
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", token))
|
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", token))
|
||||||
res, err := client.Do(req)
|
res, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return user, err
|
return user, err
|
||||||
}
|
}
|
||||||
|
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
err = json.NewDecoder(res.Body).Decode(&user)
|
err = json.NewDecoder(res.Body).Decode(&user)
|
||||||
|
|
||||||
return user, err
|
return user, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Utility methods
|
// Utility methods
|
||||||
|
|
||||||
// Get the redirect base
|
// Get the redirect base
|
||||||
func (f *ForwardAuth) redirectBase(r *http.Request) string {
|
func (f *ForwardAuth) redirectBase(r *http.Request) string {
|
||||||
proto := r.Header.Get("X-Forwarded-Proto")
|
proto := r.Header.Get("X-Forwarded-Proto")
|
||||||
host := r.Header.Get("X-Forwarded-Host")
|
host := r.Header.Get("X-Forwarded-Host")
|
||||||
|
|
||||||
// Direct mode
|
return fmt.Sprintf("%s://%s", proto, host)
|
||||||
if f.Direct {
|
|
||||||
proto = "http"
|
|
||||||
host = r.Host
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Sprintf("%s://%s", proto, host)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return url
|
// Return url
|
||||||
func (f *ForwardAuth) returnUrl(r *http.Request) string {
|
func (f *ForwardAuth) returnUrl(r *http.Request) string {
|
||||||
path := r.Header.Get("X-Forwarded-Uri")
|
path := r.Header.Get("X-Forwarded-Uri")
|
||||||
|
|
||||||
// Testing
|
return fmt.Sprintf("%s%s", f.redirectBase(r), path)
|
||||||
if f.Direct {
|
|
||||||
path = r.URL.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Sprintf("%s%s", f.redirectBase(r), path)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get oauth redirect uri
|
// Get oauth redirect uri
|
||||||
func (f *ForwardAuth) redirectUri(r *http.Request) string {
|
func (f *ForwardAuth) redirectUri(r *http.Request) string {
|
||||||
return fmt.Sprintf("%s%s", f.redirectBase(r), f.Path)
|
if use, _ := f.useAuthDomain(r); use {
|
||||||
|
proto := r.Header.Get("X-Forwarded-Proto")
|
||||||
|
return fmt.Sprintf("%s://%s%s", proto, f.AuthHost, f.Path)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf("%s%s", f.redirectBase(r), f.Path)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should we use auth host + what it is
|
||||||
|
func (f *ForwardAuth) useAuthDomain(r *http.Request) (bool, string) {
|
||||||
|
if f.AuthHost == "" {
|
||||||
|
return false, ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Does the request match a given cookie domain?
|
||||||
|
reqMatch, reqHost := f.matchCookieDomains(r.Header.Get("X-Forwarded-Host"))
|
||||||
|
|
||||||
|
// Do any of the auth hosts match a cookie domain?
|
||||||
|
authMatch, authHost := f.matchCookieDomains(f.AuthHost)
|
||||||
|
|
||||||
|
// We need both to match the same domain
|
||||||
|
return reqMatch && authMatch && reqHost == authHost, reqHost
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cookie methods
|
// Cookie methods
|
||||||
|
|
||||||
// Create an auth cookie
|
// Create an auth cookie
|
||||||
func (f *ForwardAuth) MakeCookie(r *http.Request, email string) *http.Cookie {
|
func (f *ForwardAuth) MakeCookie(r *http.Request, email string) *http.Cookie {
|
||||||
expires := f.cookieExpiry()
|
expires := f.cookieExpiry()
|
||||||
mac := f.cookieSignature(r, email, fmt.Sprintf("%d", expires.Unix()))
|
mac := f.cookieSignature(r, email, fmt.Sprintf("%d", expires.Unix()))
|
||||||
value := fmt.Sprintf("%s|%d|%s", mac, expires.Unix(), email)
|
value := fmt.Sprintf("%s|%d|%s", mac, expires.Unix(), email)
|
||||||
|
|
||||||
return &http.Cookie{
|
return &http.Cookie{
|
||||||
Name: f.CookieName,
|
Name: f.CookieName,
|
||||||
Value: value,
|
Value: value,
|
||||||
Path: "/",
|
Path: "/",
|
||||||
Domain: f.cookieDomain(r),
|
Domain: f.cookieDomain(r),
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
Secure: f.CookieSecure,
|
Secure: f.CookieSecure,
|
||||||
Expires: expires,
|
Expires: expires,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make a CSRF cookie (used during login only)
|
// Make a CSRF cookie (used during login only)
|
||||||
func (f *ForwardAuth) MakeCSRFCookie(r *http.Request, nonce string) *http.Cookie {
|
func (f *ForwardAuth) MakeCSRFCookie(r *http.Request, nonce string) *http.Cookie {
|
||||||
return &http.Cookie{
|
return &http.Cookie{
|
||||||
Name: f.CSRFCookieName,
|
Name: f.CSRFCookieName,
|
||||||
Value: nonce,
|
Value: nonce,
|
||||||
Path: "/",
|
Path: "/",
|
||||||
Domain: f.cookieDomain(r),
|
Domain: f.csrfCookieDomain(r),
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
Secure: f.CookieSecure,
|
Secure: f.CookieSecure,
|
||||||
Expires: f.cookieExpiry(),
|
Expires: f.cookieExpiry(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a cookie to clear csrf cookie
|
// Create a cookie to clear csrf cookie
|
||||||
func (f *ForwardAuth) ClearCSRFCookie(r *http.Request) *http.Cookie {
|
func (f *ForwardAuth) ClearCSRFCookie(r *http.Request) *http.Cookie {
|
||||||
return &http.Cookie{
|
return &http.Cookie{
|
||||||
Name: f.CSRFCookieName,
|
Name: f.CSRFCookieName,
|
||||||
Value: "",
|
Value: "",
|
||||||
Path: "/",
|
Path: "/",
|
||||||
Domain: f.cookieDomain(r),
|
Domain: f.csrfCookieDomain(r),
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
Secure: f.CookieSecure,
|
Secure: f.CookieSecure,
|
||||||
Expires: time.Now().Local().Add(time.Hour * -1),
|
Expires: time.Now().Local().Add(time.Hour * -1),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate the csrf cookie against state
|
// Validate the csrf cookie against state
|
||||||
func (f *ForwardAuth) ValidateCSRFCookie(c *http.Cookie, state string) (bool, string, error) {
|
func (f *ForwardAuth) ValidateCSRFCookie(c *http.Cookie, state string) (bool, string, error) {
|
||||||
if len(c.Value) != 32 {
|
if len(c.Value) != 32 {
|
||||||
return false, "", errors.New("Invalid CSRF cookie value")
|
return false, "", errors.New("Invalid CSRF cookie value")
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(state) < 34 {
|
if len(state) < 34 {
|
||||||
return false, "", errors.New("Invalid CSRF state value")
|
return false, "", errors.New("Invalid CSRF state value")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check nonce match
|
// Check nonce match
|
||||||
if c.Value != state[:32] {
|
if c.Value != state[:32] {
|
||||||
return false, "", errors.New("CSRF cookie does not match state")
|
return false, "", errors.New("CSRF cookie does not match state")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Valid, return redirect
|
// Valid, return redirect
|
||||||
return true, state[33:], nil
|
return true, state[33:], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *ForwardAuth) Nonce() (error, string) {
|
func (f *ForwardAuth) Nonce() (error, string) {
|
||||||
// Make nonce
|
// Make nonce
|
||||||
nonce := make([]byte, 16)
|
nonce := make([]byte, 16)
|
||||||
_, err := rand.Read(nonce)
|
_, err := rand.Read(nonce)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err, ""
|
return err, ""
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, fmt.Sprintf("%x", nonce)
|
return nil, fmt.Sprintf("%x", nonce)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cookie domain
|
// Cookie domain
|
||||||
func (f *ForwardAuth) cookieDomain(r *http.Request) string {
|
func (f *ForwardAuth) cookieDomain(r *http.Request) string {
|
||||||
host := r.Header.Get("X-Forwarded-Host")
|
host := r.Header.Get("X-Forwarded-Host")
|
||||||
|
|
||||||
// Direct mode
|
// Check if any of the given cookie domains matches
|
||||||
if f.Direct {
|
_, domain := f.matchCookieDomains(host)
|
||||||
host = r.Host
|
return domain
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove port for matching
|
// Cookie domain
|
||||||
p := strings.Split(host, ":")
|
func (f *ForwardAuth) csrfCookieDomain(r *http.Request) string {
|
||||||
|
var host string
|
||||||
|
if use, domain := f.useAuthDomain(r); use {
|
||||||
|
host = domain
|
||||||
|
} else {
|
||||||
|
host = r.Header.Get("X-Forwarded-Host")
|
||||||
|
}
|
||||||
|
|
||||||
// Check if any of the given cookie domains matches
|
// Remove port
|
||||||
for _, domain := range f.CookieDomains {
|
p := strings.Split(host, ":")
|
||||||
if domain.Match(p[0]) {
|
return p[0]
|
||||||
return domain.Domain
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return p[0]
|
// Return matching cookie domain if exists
|
||||||
|
func (f *ForwardAuth) matchCookieDomains(domain string) (bool, string) {
|
||||||
|
// Remove port
|
||||||
|
p := strings.Split(domain, ":")
|
||||||
|
|
||||||
|
for _, d := range f.CookieDomains {
|
||||||
|
if d.Match(p[0]) {
|
||||||
|
return true, d.Domain
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, p[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create cookie hmac
|
// Create cookie hmac
|
||||||
func (f *ForwardAuth) cookieSignature(r *http.Request, email, expires string) string {
|
func (f *ForwardAuth) cookieSignature(r *http.Request, email, expires string) string {
|
||||||
hash := hmac.New(sha256.New, f.CookieSecret)
|
hash := hmac.New(sha256.New, f.Secret)
|
||||||
hash.Write([]byte(f.cookieDomain(r)))
|
hash.Write([]byte(f.cookieDomain(r)))
|
||||||
hash.Write([]byte(email))
|
hash.Write([]byte(email))
|
||||||
hash.Write([]byte(expires))
|
hash.Write([]byte(expires))
|
||||||
return base64.URLEncoding.EncodeToString(hash.Sum(nil))
|
return base64.URLEncoding.EncodeToString(hash.Sum(nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get cookie expirary
|
// Get cookie expirary
|
||||||
func (f *ForwardAuth) cookieExpiry() time.Time {
|
func (f *ForwardAuth) cookieExpiry() time.Time {
|
||||||
return time.Now().Local().Add(f.Lifetime)
|
return time.Now().Local().Add(f.Lifetime)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cookie Domain
|
// Cookie Domain
|
||||||
|
|
||||||
// Cookie Domain
|
// Cookie Domain
|
||||||
type CookieDomain struct {
|
type CookieDomain struct {
|
||||||
Domain string
|
Domain string
|
||||||
DomainLen int
|
DomainLen int
|
||||||
SubDomain string
|
SubDomain string
|
||||||
SubDomainLen int
|
SubDomainLen int
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewCookieDomain(domain string) *CookieDomain {
|
func NewCookieDomain(domain string) *CookieDomain {
|
||||||
return &CookieDomain{
|
return &CookieDomain{
|
||||||
Domain: domain,
|
Domain: domain,
|
||||||
DomainLen: len(domain),
|
DomainLen: len(domain),
|
||||||
SubDomain: fmt.Sprintf(".%s", domain),
|
SubDomain: fmt.Sprintf(".%s", domain),
|
||||||
SubDomainLen: len(domain) + 1,
|
SubDomainLen: len(domain) + 1,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *CookieDomain) Match(host string) bool {
|
func (c *CookieDomain) Match(host string) bool {
|
||||||
// Exact domain match?
|
// Exact domain match?
|
||||||
if host == c.Domain {
|
if host == c.Domain {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Subdomain match?
|
// Subdomain match?
|
||||||
if len(host) >= c.SubDomainLen && host[len(host) - c.SubDomainLen:] == c.SubDomain {
|
if len(host) >= c.SubDomainLen && host[len(host)-c.SubDomainLen:] == c.SubDomain {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
@ -1,138 +1,284 @@
|
|||||||
|
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
// "fmt"
|
// "fmt"
|
||||||
"time"
|
"net/http"
|
||||||
"reflect"
|
"net/url"
|
||||||
"testing"
|
"reflect"
|
||||||
"net/url"
|
"testing"
|
||||||
"net/http"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestValidateCookie(t *testing.T) {
|
func TestValidateCookie(t *testing.T) {
|
||||||
fw = &ForwardAuth{}
|
fw = &ForwardAuth{}
|
||||||
r, _ := http.NewRequest("GET", "http://example.com", nil)
|
r, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||||
c := &http.Cookie{}
|
c := &http.Cookie{}
|
||||||
|
|
||||||
// Should require 3 parts
|
// Should require 3 parts
|
||||||
c.Value = ""
|
c.Value = ""
|
||||||
valid, _, err := fw.ValidateCookie(r, c)
|
valid, _, err := fw.ValidateCookie(r, c)
|
||||||
if valid || err.Error() != "Invalid cookie format" {
|
if valid || err.Error() != "Invalid cookie format" {
|
||||||
t.Error("Should get \"Invalid cookie format\", got:", err)
|
t.Error("Should get \"Invalid cookie format\", got:", err)
|
||||||
}
|
}
|
||||||
c.Value = "1|2"
|
c.Value = "1|2"
|
||||||
valid, _, err = fw.ValidateCookie(r, c)
|
valid, _, err = fw.ValidateCookie(r, c)
|
||||||
if valid || err.Error() != "Invalid cookie format" {
|
if valid || err.Error() != "Invalid cookie format" {
|
||||||
t.Error("Should get \"Invalid cookie format\", got:", err)
|
t.Error("Should get \"Invalid cookie format\", got:", err)
|
||||||
}
|
}
|
||||||
c.Value = "1|2|3|4"
|
c.Value = "1|2|3|4"
|
||||||
valid, _, err = fw.ValidateCookie(r, c)
|
valid, _, err = fw.ValidateCookie(r, c)
|
||||||
if valid || err.Error() != "Invalid cookie format" {
|
if valid || err.Error() != "Invalid cookie format" {
|
||||||
t.Error("Should get \"Invalid cookie format\", got:", err)
|
t.Error("Should get \"Invalid cookie format\", got:", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should catch invalid mac
|
// Should catch invalid mac
|
||||||
c.Value = "MQ==|2|3"
|
c.Value = "MQ==|2|3"
|
||||||
valid, _, err = fw.ValidateCookie(r, c)
|
valid, _, err = fw.ValidateCookie(r, c)
|
||||||
if valid || err.Error() != "Invalid cookie mac" {
|
if valid || err.Error() != "Invalid cookie mac" {
|
||||||
t.Error("Should get \"Invalid cookie mac\", got:", err)
|
t.Error("Should get \"Invalid cookie mac\", got:", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should catch expired
|
// Should catch expired
|
||||||
fw.Lifetime = time.Second * time.Duration(-1)
|
fw.Lifetime = time.Second * time.Duration(-1)
|
||||||
c = fw.MakeCookie(r, "test@test.com")
|
c = fw.MakeCookie(r, "test@test.com")
|
||||||
valid, _, err = fw.ValidateCookie(r, c)
|
valid, _, err = fw.ValidateCookie(r, c)
|
||||||
if valid || err.Error() != "Cookie has expired" {
|
if valid || err.Error() != "Cookie has expired" {
|
||||||
t.Error("Should get \"Cookie has expired\", got:", err)
|
t.Error("Should get \"Cookie has expired\", got:", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should accept valid cookie
|
// Should accept valid cookie
|
||||||
fw.Lifetime = time.Second * time.Duration(10)
|
fw.Lifetime = time.Second * time.Duration(10)
|
||||||
c = fw.MakeCookie(r, "test@test.com")
|
c = fw.MakeCookie(r, "test@test.com")
|
||||||
valid, email, err := fw.ValidateCookie(r, c)
|
valid, email, err := fw.ValidateCookie(r, c)
|
||||||
if !valid {
|
if !valid {
|
||||||
t.Error("Valid request should return as valid")
|
t.Error("Valid request should return as valid")
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error("Valid request should not return error, got:", err)
|
t.Error("Valid request should not return error, got:", err)
|
||||||
}
|
}
|
||||||
if email != "test@test.com" {
|
if email != "test@test.com" {
|
||||||
t.Error("Valid request should return user email")
|
t.Error("Valid request should return user email")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateEmail(t *testing.T) {
|
func TestValidateEmail(t *testing.T) {
|
||||||
fw = &ForwardAuth{}
|
fw = &ForwardAuth{}
|
||||||
|
|
||||||
// Should allow any
|
// Should allow any
|
||||||
if !fw.ValidateEmail("test@test.com") || !fw.ValidateEmail("one@two.com") {
|
if !fw.ValidateEmail("test@test.com") || !fw.ValidateEmail("one@two.com") {
|
||||||
t.Error("Should allow any domain if email domain is not defined")
|
t.Error("Should allow any domain if email domain is not defined")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should block non matching domain
|
// Should block non matching domain
|
||||||
fw.Domain = []string{"test.com"}
|
fw.Domain = []string{"test.com"}
|
||||||
if fw.ValidateEmail("one@two.com") {
|
if fw.ValidateEmail("one@two.com") {
|
||||||
t.Error("Should not allow user from another domain")
|
t.Error("Should not allow user from another domain")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should allow matching domain
|
// Should allow matching domain
|
||||||
fw.Domain = []string{"test.com"}
|
fw.Domain = []string{"test.com"}
|
||||||
if !fw.ValidateEmail("test@test.com") {
|
if !fw.ValidateEmail("test@test.com") {
|
||||||
t.Error("Should allow user from allowed domain")
|
t.Error("Should allow user from allowed domain")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Should block non whitelisted email address
|
||||||
|
fw.Domain = []string{}
|
||||||
|
fw.Whitelist = []string{"test@test.com"}
|
||||||
|
if fw.ValidateEmail("one@two.com") {
|
||||||
|
t.Error("Should not allow user not in whitelist.")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should allow matching whitelisted email address
|
||||||
|
fw.Domain = []string{}
|
||||||
|
fw.Whitelist = []string{"test@test.com"}
|
||||||
|
if !fw.ValidateEmail("test@test.com") {
|
||||||
|
t.Error("Should allow user in whitelist.")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetLoginURL(t *testing.T) {
|
func TestGetLoginURL(t *testing.T) {
|
||||||
fw = &ForwardAuth{
|
r, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||||
Path: "/_oauth",
|
r.Header.Add("X-Forwarded-Proto", "http")
|
||||||
ClientId: "idtest",
|
r.Header.Add("X-Forwarded-Host", "example.com")
|
||||||
ClientSecret: "sectest",
|
r.Header.Add("X-Forwarded-Uri", "/hello")
|
||||||
Scope: "scopetest",
|
|
||||||
LoginURL: &url.URL{
|
|
||||||
Scheme: "https",
|
|
||||||
Host: "test.com",
|
|
||||||
Path: "/auth",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
r, _ := http.NewRequest("GET", "http://example.com", nil)
|
|
||||||
r.Header.Add("X-Forwarded-Proto", "http")
|
|
||||||
r.Header.Add("X-Forwarded-Host", "example.com")
|
|
||||||
r.Header.Add("X-Forwarded-Uri", "/hello")
|
|
||||||
|
|
||||||
// Check url
|
fw = &ForwardAuth{
|
||||||
uri, err := url.Parse(fw.GetLoginURL(r, "nonce"))
|
Path: "/_oauth",
|
||||||
if err != nil {
|
ClientId: "idtest",
|
||||||
t.Error("Error parsing login url:", err)
|
ClientSecret: "sectest",
|
||||||
}
|
Scope: "scopetest",
|
||||||
if uri.Scheme != "https" {
|
LoginURL: &url.URL{
|
||||||
t.Error("Expected login Scheme to be \"https\", got:", uri.Scheme)
|
Scheme: "https",
|
||||||
}
|
Host: "test.com",
|
||||||
if uri.Host != "test.com" {
|
Path: "/auth",
|
||||||
t.Error("Expected login Host to be \"test.com\", got:", uri.Host)
|
},
|
||||||
}
|
}
|
||||||
if uri.Path != "/auth" {
|
|
||||||
t.Error("Expected login Path to be \"/auth\", got:", uri.Path)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check query string
|
// Check url
|
||||||
qs := uri.Query()
|
uri, err := url.Parse(fw.GetLoginURL(r, "nonce"))
|
||||||
expectedQs := url.Values{
|
if err != nil {
|
||||||
"client_id": []string{"idtest"},
|
t.Error("Error parsing login url:", err)
|
||||||
"redirect_uri": []string{"http://example.com/_oauth"},
|
}
|
||||||
"response_type": []string{"code"},
|
if uri.Scheme != "https" {
|
||||||
"scope": []string{"scopetest"},
|
t.Error("Expected login Scheme to be \"https\", got:", uri.Scheme)
|
||||||
"state": []string{"nonce:http://example.com/hello"},
|
}
|
||||||
}
|
if uri.Host != "test.com" {
|
||||||
if !reflect.DeepEqual(qs, expectedQs) {
|
t.Error("Expected login Host to be \"test.com\", got:", uri.Host)
|
||||||
t.Error("Incorrect login query string, expected:")
|
}
|
||||||
t.Error(expectedQs)
|
if uri.Path != "/auth" {
|
||||||
t.Error("Got:")
|
t.Error("Expected login Path to be \"/auth\", got:", uri.Path)
|
||||||
t.Error(qs)
|
}
|
||||||
}
|
|
||||||
|
// Check query string
|
||||||
|
qs := uri.Query()
|
||||||
|
expectedQs := url.Values{
|
||||||
|
"client_id": []string{"idtest"},
|
||||||
|
"redirect_uri": []string{"http://example.com/_oauth"},
|
||||||
|
"response_type": []string{"code"},
|
||||||
|
"scope": []string{"scopetest"},
|
||||||
|
"state": []string{"nonce:http://example.com/hello"},
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(qs, expectedQs) {
|
||||||
|
t.Error("Incorrect login query string:")
|
||||||
|
qsDiff(expectedQs, qs)
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// With Auth URL but no matching cookie domain
|
||||||
|
// - will not use auth host
|
||||||
|
//
|
||||||
|
fw = &ForwardAuth{
|
||||||
|
Path: "/_oauth",
|
||||||
|
AuthHost: "auth.example.com",
|
||||||
|
ClientId: "idtest",
|
||||||
|
ClientSecret: "sectest",
|
||||||
|
Scope: "scopetest",
|
||||||
|
LoginURL: &url.URL{
|
||||||
|
Scheme: "https",
|
||||||
|
Host: "test.com",
|
||||||
|
Path: "/auth",
|
||||||
|
},
|
||||||
|
Prompt: "consent select_account",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check url
|
||||||
|
uri, err = url.Parse(fw.GetLoginURL(r, "nonce"))
|
||||||
|
if err != nil {
|
||||||
|
t.Error("Error parsing login url:", err)
|
||||||
|
}
|
||||||
|
if uri.Scheme != "https" {
|
||||||
|
t.Error("Expected login Scheme to be \"https\", got:", uri.Scheme)
|
||||||
|
}
|
||||||
|
if uri.Host != "test.com" {
|
||||||
|
t.Error("Expected login Host to be \"test.com\", got:", uri.Host)
|
||||||
|
}
|
||||||
|
if uri.Path != "/auth" {
|
||||||
|
t.Error("Expected login Path to be \"/auth\", got:", uri.Path)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check query string
|
||||||
|
qs = uri.Query()
|
||||||
|
expectedQs = url.Values{
|
||||||
|
"client_id": []string{"idtest"},
|
||||||
|
"redirect_uri": []string{"http://example.com/_oauth"},
|
||||||
|
"response_type": []string{"code"},
|
||||||
|
"scope": []string{"scopetest"},
|
||||||
|
"prompt": []string{"consent select_account"},
|
||||||
|
"state": []string{"nonce:http://example.com/hello"},
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(qs, expectedQs) {
|
||||||
|
t.Error("Incorrect login query string:")
|
||||||
|
qsDiff(expectedQs, qs)
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// With correct Auth URL + cookie domain
|
||||||
|
//
|
||||||
|
cookieDomain := NewCookieDomain("example.com")
|
||||||
|
fw = &ForwardAuth{
|
||||||
|
Path: "/_oauth",
|
||||||
|
AuthHost: "auth.example.com",
|
||||||
|
ClientId: "idtest",
|
||||||
|
ClientSecret: "sectest",
|
||||||
|
Scope: "scopetest",
|
||||||
|
LoginURL: &url.URL{
|
||||||
|
Scheme: "https",
|
||||||
|
Host: "test.com",
|
||||||
|
Path: "/auth",
|
||||||
|
},
|
||||||
|
CookieDomains: []CookieDomain{*cookieDomain},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check url
|
||||||
|
uri, err = url.Parse(fw.GetLoginURL(r, "nonce"))
|
||||||
|
if err != nil {
|
||||||
|
t.Error("Error parsing login url:", err)
|
||||||
|
}
|
||||||
|
if uri.Scheme != "https" {
|
||||||
|
t.Error("Expected login Scheme to be \"https\", got:", uri.Scheme)
|
||||||
|
}
|
||||||
|
if uri.Host != "test.com" {
|
||||||
|
t.Error("Expected login Host to be \"test.com\", got:", uri.Host)
|
||||||
|
}
|
||||||
|
if uri.Path != "/auth" {
|
||||||
|
t.Error("Expected login Path to be \"/auth\", got:", uri.Path)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check query string
|
||||||
|
qs = uri.Query()
|
||||||
|
expectedQs = url.Values{
|
||||||
|
"client_id": []string{"idtest"},
|
||||||
|
"redirect_uri": []string{"http://auth.example.com/_oauth"},
|
||||||
|
"response_type": []string{"code"},
|
||||||
|
"scope": []string{"scopetest"},
|
||||||
|
"state": []string{"nonce:http://example.com/hello"},
|
||||||
|
}
|
||||||
|
qsDiff(expectedQs, qs)
|
||||||
|
if !reflect.DeepEqual(qs, expectedQs) {
|
||||||
|
t.Error("Incorrect login query string:")
|
||||||
|
qsDiff(expectedQs, qs)
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// With Auth URL + cookie domain, but from different domain
|
||||||
|
// - will not use auth host
|
||||||
|
//
|
||||||
|
r, _ = http.NewRequest("GET", "http://another.com", nil)
|
||||||
|
r.Header.Add("X-Forwarded-Proto", "http")
|
||||||
|
r.Header.Add("X-Forwarded-Host", "another.com")
|
||||||
|
r.Header.Add("X-Forwarded-Uri", "/hello")
|
||||||
|
|
||||||
|
// Check url
|
||||||
|
uri, err = url.Parse(fw.GetLoginURL(r, "nonce"))
|
||||||
|
if err != nil {
|
||||||
|
t.Error("Error parsing login url:", err)
|
||||||
|
}
|
||||||
|
if uri.Scheme != "https" {
|
||||||
|
t.Error("Expected login Scheme to be \"https\", got:", uri.Scheme)
|
||||||
|
}
|
||||||
|
if uri.Host != "test.com" {
|
||||||
|
t.Error("Expected login Host to be \"test.com\", got:", uri.Host)
|
||||||
|
}
|
||||||
|
if uri.Path != "/auth" {
|
||||||
|
t.Error("Expected login Path to be \"/auth\", got:", uri.Path)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check query string
|
||||||
|
qs = uri.Query()
|
||||||
|
expectedQs = url.Values{
|
||||||
|
"client_id": []string{"idtest"},
|
||||||
|
"redirect_uri": []string{"http://another.com/_oauth"},
|
||||||
|
"response_type": []string{"code"},
|
||||||
|
"scope": []string{"scopetest"},
|
||||||
|
"state": []string{"nonce:http://another.com/hello"},
|
||||||
|
}
|
||||||
|
qsDiff(expectedQs, qs)
|
||||||
|
if !reflect.DeepEqual(qs, expectedQs) {
|
||||||
|
t.Error("Incorrect login query string:")
|
||||||
|
qsDiff(expectedQs, qs)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// TODO
|
// TODO
|
||||||
// func TestExchangeCode(t *testing.T) {
|
// func TestExchangeCode(t *testing.T) {
|
||||||
// }
|
// }
|
||||||
@ -145,98 +291,124 @@ func TestGetLoginURL(t *testing.T) {
|
|||||||
// func TestMakeCookie(t *testing.T) {
|
// func TestMakeCookie(t *testing.T) {
|
||||||
// }
|
// }
|
||||||
|
|
||||||
// func TestMakeCSRFCookie(t *testing.T) {
|
func TestMakeCSRFCookie(t *testing.T) {
|
||||||
// t.Log("TODO")
|
r, _ := http.NewRequest("GET", "http://app.example.com", nil)
|
||||||
// }
|
r.Header.Add("X-Forwarded-Host", "app.example.com")
|
||||||
|
|
||||||
|
// No cookie domain or auth url
|
||||||
|
fw = &ForwardAuth{}
|
||||||
|
c := fw.MakeCSRFCookie(r, "12345678901234567890123456789012")
|
||||||
|
if c.Domain != "app.example.com" {
|
||||||
|
t.Error("Cookie Domain should match request domain, got:", c.Domain)
|
||||||
|
}
|
||||||
|
|
||||||
|
// With cookie domain but no auth url
|
||||||
|
cookieDomain := NewCookieDomain("example.com")
|
||||||
|
fw = &ForwardAuth{CookieDomains: []CookieDomain{*cookieDomain}}
|
||||||
|
c = fw.MakeCSRFCookie(r, "12345678901234567890123456789012")
|
||||||
|
if c.Domain != "app.example.com" {
|
||||||
|
t.Error("Cookie Domain should match request domain, got:", c.Domain)
|
||||||
|
}
|
||||||
|
|
||||||
|
// With cookie domain and auth url
|
||||||
|
fw = &ForwardAuth{
|
||||||
|
AuthHost: "auth.example.com",
|
||||||
|
CookieDomains: []CookieDomain{*cookieDomain},
|
||||||
|
}
|
||||||
|
c = fw.MakeCSRFCookie(r, "12345678901234567890123456789012")
|
||||||
|
if c.Domain != "example.com" {
|
||||||
|
t.Error("Cookie Domain should match request domain, got:", c.Domain)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestClearCSRFCookie(t *testing.T) {
|
func TestClearCSRFCookie(t *testing.T) {
|
||||||
fw = &ForwardAuth{}
|
fw = &ForwardAuth{}
|
||||||
r, _ := http.NewRequest("GET", "http://example.com", nil)
|
r, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||||
|
|
||||||
c := fw.ClearCSRFCookie(r)
|
c := fw.ClearCSRFCookie(r)
|
||||||
if c.Value != "" {
|
if c.Value != "" {
|
||||||
t.Error("ClearCSRFCookie should create cookie with empty value")
|
t.Error("ClearCSRFCookie should create cookie with empty value")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateCSRFCookie(t *testing.T) {
|
func TestValidateCSRFCookie(t *testing.T) {
|
||||||
fw = &ForwardAuth{}
|
fw = &ForwardAuth{}
|
||||||
c := &http.Cookie{}
|
c := &http.Cookie{}
|
||||||
|
|
||||||
// Should require 32 char string
|
// Should require 32 char string
|
||||||
c.Value = ""
|
c.Value = ""
|
||||||
valid, _, err := fw.ValidateCSRFCookie(c, "")
|
valid, _, err := fw.ValidateCSRFCookie(c, "")
|
||||||
if valid || err.Error() != "Invalid CSRF cookie value" {
|
if valid || err.Error() != "Invalid CSRF cookie value" {
|
||||||
t.Error("Should get \"Invalid CSRF cookie value\", got:", err)
|
t.Error("Should get \"Invalid CSRF cookie value\", got:", err)
|
||||||
}
|
}
|
||||||
c.Value = "123456789012345678901234567890123"
|
c.Value = "123456789012345678901234567890123"
|
||||||
valid, _, err = fw.ValidateCSRFCookie(c, "")
|
valid, _, err = fw.ValidateCSRFCookie(c, "")
|
||||||
if valid || err.Error() != "Invalid CSRF cookie value" {
|
if valid || err.Error() != "Invalid CSRF cookie value" {
|
||||||
t.Error("Should get \"Invalid CSRF cookie value\", got:", err)
|
t.Error("Should get \"Invalid CSRF cookie value\", got:", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should require valid state
|
// Should require valid state
|
||||||
c.Value = "12345678901234567890123456789012"
|
c.Value = "12345678901234567890123456789012"
|
||||||
valid, _, err = fw.ValidateCSRFCookie(c, "12345678901234567890123456789012:")
|
valid, _, err = fw.ValidateCSRFCookie(c, "12345678901234567890123456789012:")
|
||||||
if valid || err.Error() != "Invalid CSRF state value" {
|
if valid || err.Error() != "Invalid CSRF state value" {
|
||||||
t.Error("Should get \"Invalid CSRF state value\", got:", err)
|
t.Error("Should get \"Invalid CSRF state value\", got:", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should allow valid state
|
// Should allow valid state
|
||||||
c.Value = "12345678901234567890123456789012"
|
c.Value = "12345678901234567890123456789012"
|
||||||
valid, state, err := fw.ValidateCSRFCookie(c, "12345678901234567890123456789012:99")
|
valid, state, err := fw.ValidateCSRFCookie(c, "12345678901234567890123456789012:99")
|
||||||
if !valid {
|
if !valid {
|
||||||
t.Error("Valid request should return as valid")
|
t.Error("Valid request should return as valid")
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error("Valid request should not return error, got:", err)
|
t.Error("Valid request should not return error, got:", err)
|
||||||
}
|
}
|
||||||
if state != "99" {
|
if state != "99" {
|
||||||
t.Error("Valid request should return correct state, got:", state)
|
t.Error("Valid request should return correct state, got:", state)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNonce(t *testing.T) {
|
func TestNonce(t *testing.T) {
|
||||||
fw = &ForwardAuth{}
|
fw = &ForwardAuth{}
|
||||||
|
|
||||||
err, nonce1 := fw.Nonce()
|
err, nonce1 := fw.Nonce()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error("Error generation nonce:", err)
|
t.Error("Error generation nonce:", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err, nonce2 := fw.Nonce()
|
err, nonce2 := fw.Nonce()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error("Error generation nonce:", err)
|
t.Error("Error generation nonce:", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(nonce1) != 32 || len(nonce2) != 32 {
|
if len(nonce1) != 32 || len(nonce2) != 32 {
|
||||||
t.Error("Nonce should be 32 chars")
|
t.Error("Nonce should be 32 chars")
|
||||||
}
|
}
|
||||||
if nonce1 == nonce2 {
|
if nonce1 == nonce2 {
|
||||||
t.Error("Nonce should not be equal")
|
t.Error("Nonce should not be equal")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCookieDomainMatch(t *testing.T) {
|
func TestCookieDomainMatch(t *testing.T) {
|
||||||
cd := NewCookieDomain("example.com")
|
cd := NewCookieDomain("example.com")
|
||||||
|
|
||||||
// Exact should match
|
// Exact should match
|
||||||
if !cd.Match("example.com") {
|
if !cd.Match("example.com") {
|
||||||
t.Error("Exact domain should match")
|
t.Error("Exact domain should match")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Subdomain should match
|
// Subdomain should match
|
||||||
if !cd.Match("test.example.com") {
|
if !cd.Match("test.example.com") {
|
||||||
t.Error("Subdomain should match")
|
t.Error("Subdomain should match")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Derived domain should not match
|
// Derived domain should not match
|
||||||
if cd.Match("testexample.com") {
|
if cd.Match("testexample.com") {
|
||||||
t.Error("Derived domain should not match")
|
t.Error("Derived domain should not match")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Other domain should not match
|
// Other domain should not match
|
||||||
if cd.Match("test.com") {
|
if cd.Match("test.com") {
|
||||||
t.Error("Other domain should not match")
|
t.Error("Other domain should not match")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
48
log.go
Normal file
48
log.go
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
func CreateLogger(logLevel, logFormat string) logrus.FieldLogger {
|
||||||
|
// Setup logger
|
||||||
|
log := logrus.StandardLogger()
|
||||||
|
logrus.SetOutput(os.Stdout)
|
||||||
|
|
||||||
|
// Set logger format
|
||||||
|
switch logFormat {
|
||||||
|
case "pretty":
|
||||||
|
break
|
||||||
|
case "json":
|
||||||
|
logrus.SetFormatter(&logrus.JSONFormatter{})
|
||||||
|
// "text" is the default
|
||||||
|
default:
|
||||||
|
logrus.SetFormatter(&logrus.TextFormatter{
|
||||||
|
DisableColors: true,
|
||||||
|
FullTimestamp: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set logger level
|
||||||
|
switch logLevel {
|
||||||
|
case "trace":
|
||||||
|
logrus.SetLevel(logrus.TraceLevel)
|
||||||
|
case "debug":
|
||||||
|
logrus.SetLevel(logrus.DebugLevel)
|
||||||
|
case "info":
|
||||||
|
logrus.SetLevel(logrus.InfoLevel)
|
||||||
|
case "error":
|
||||||
|
logrus.SetLevel(logrus.ErrorLevel)
|
||||||
|
case "fatal":
|
||||||
|
logrus.SetLevel(logrus.FatalLevel)
|
||||||
|
case "panic":
|
||||||
|
logrus.SetLevel(logrus.PanicLevel)
|
||||||
|
// warn is the default
|
||||||
|
default:
|
||||||
|
logrus.SetLevel(logrus.WarnLevel)
|
||||||
|
}
|
||||||
|
|
||||||
|
return log
|
||||||
|
}
|
368
main.go
368
main.go
@ -1,217 +1,239 @@
|
|||||||
|
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"encoding/json"
|
||||||
"time"
|
"fmt"
|
||||||
"strings"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"net/http"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/namsral/flag"
|
"github.com/namsral/flag"
|
||||||
"github.com/op/go-logging"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Vars
|
// Vars
|
||||||
var fw *ForwardAuth;
|
var fw *ForwardAuth
|
||||||
var log = logging.MustGetLogger("traefik-forward-auth")
|
var log logrus.FieldLogger
|
||||||
|
|
||||||
// Primary handler
|
// Primary handler
|
||||||
func handler(w http.ResponseWriter, r *http.Request) {
|
func handler(w http.ResponseWriter, r *http.Request) {
|
||||||
// Parse uri
|
// Logging setup
|
||||||
uri, err := url.Parse(r.Header.Get("X-Forwarded-Uri"))
|
logger := log.WithFields(logrus.Fields{
|
||||||
if err != nil {
|
"RemoteAddr": r.RemoteAddr,
|
||||||
log.Error("Error parsing url")
|
})
|
||||||
http.Error(w, "Service unavailable", 503)
|
logger.WithFields(logrus.Fields{
|
||||||
return
|
"Headers": r.Header,
|
||||||
}
|
}).Debugf("Handling request")
|
||||||
|
|
||||||
// Direct mode
|
// Parse uri
|
||||||
if fw.Direct {
|
uri, err := url.Parse(r.Header.Get("X-Forwarded-Uri"))
|
||||||
uri = r.URL
|
if err != nil {
|
||||||
}
|
logger.Errorf("Error parsing X-Forwarded-Uri, %v", err)
|
||||||
|
http.Error(w, "Service unavailable", 503)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Handle callback
|
// Handle callback
|
||||||
if uri.Path == fw.Path {
|
if uri.Path == fw.Path {
|
||||||
handleCallback(w, r, uri.Query())
|
logger.Debugf("Passing request to auth callback")
|
||||||
return
|
handleCallback(w, r, uri.Query(), logger)
|
||||||
}
|
return
|
||||||
|
}
|
||||||
|
|
||||||
c, err := r.Cookie(fw.CookieName)
|
// Get auth cookie
|
||||||
if err != nil {
|
c, err := r.Cookie(fw.CookieName)
|
||||||
// Error indicates no cookie, generate nonce
|
if err != nil {
|
||||||
err, nonce := fw.Nonce()
|
// Error indicates no cookie, generate nonce
|
||||||
if err != nil {
|
err, nonce := fw.Nonce()
|
||||||
log.Error("Error generating nonce")
|
if err != nil {
|
||||||
http.Error(w, "Service unavailable", 503)
|
logger.Errorf("Error generating nonce, %v", err)
|
||||||
return
|
http.Error(w, "Service unavailable", 503)
|
||||||
}
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Set the CSRF cookie
|
// Set the CSRF cookie
|
||||||
http.SetCookie(w, fw.MakeCSRFCookie(r, nonce))
|
http.SetCookie(w, fw.MakeCSRFCookie(r, nonce))
|
||||||
log.Debug("Set CSRF cookie and redirecting to google login")
|
logger.Debug("Set CSRF cookie and redirecting to google login")
|
||||||
|
|
||||||
// Forward them on
|
// Forward them on
|
||||||
http.Redirect(w, r, fw.GetLoginURL(r, nonce), http.StatusTemporaryRedirect)
|
http.Redirect(w, r, fw.GetLoginURL(r, nonce), http.StatusTemporaryRedirect)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate cookie
|
// Validate cookie
|
||||||
valid, email, err := fw.ValidateCookie(r, c)
|
valid, email, err := fw.ValidateCookie(r, c)
|
||||||
if !valid {
|
if !valid {
|
||||||
log.Debugf("Invlaid cookie: %s", err)
|
logger.Errorf("Invalid cookie: %v", err)
|
||||||
http.Error(w, "Not authorized", 401)
|
http.Error(w, "Not authorized", 401)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate user
|
// Validate user
|
||||||
valid = fw.ValidateEmail(email)
|
valid = fw.ValidateEmail(email)
|
||||||
if !valid {
|
if !valid {
|
||||||
log.Debugf("Invalid email: %s", email)
|
logger.WithFields(logrus.Fields{
|
||||||
http.Error(w, "Not authorized", 401)
|
"email": email,
|
||||||
return
|
}).Errorf("Invalid email")
|
||||||
}
|
http.Error(w, "Not authorized", 401)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Valid request
|
// Valid request
|
||||||
w.WriteHeader(200)
|
logger.Debugf("Allowing valid request ")
|
||||||
|
w.Header().Set("X-Forwarded-User", email)
|
||||||
|
w.WriteHeader(200)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// Authenticate user after they have come back from google
|
// Authenticate user after they have come back from google
|
||||||
func handleCallback(w http.ResponseWriter, r *http.Request, qs url.Values) {
|
func handleCallback(w http.ResponseWriter, r *http.Request, qs url.Values,
|
||||||
// Check for CSRF cookie
|
logger logrus.FieldLogger) {
|
||||||
csrfCookie, err := r.Cookie(fw.CSRFCookieName)
|
// Check for CSRF cookie
|
||||||
if err != nil {
|
csrfCookie, err := r.Cookie(fw.CSRFCookieName)
|
||||||
log.Debug("Missing csrf cookie")
|
if err != nil {
|
||||||
http.Error(w, "Not authorized", 401)
|
logger.Warn("Missing csrf cookie")
|
||||||
return
|
http.Error(w, "Not authorized", 401)
|
||||||
}
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Validate state
|
// Validate state
|
||||||
state := qs.Get("state")
|
state := qs.Get("state")
|
||||||
valid, redirect, err := fw.ValidateCSRFCookie(csrfCookie, state)
|
valid, redirect, err := fw.ValidateCSRFCookie(csrfCookie, state)
|
||||||
if !valid {
|
if !valid {
|
||||||
log.Debugf("Invalid oauth state, expected '%s', got '%s'\n", csrfCookie.Value, state)
|
logger.WithFields(logrus.Fields{
|
||||||
http.Error(w, "Not authorized", 401)
|
"csrf": csrfCookie.Value,
|
||||||
return
|
"state": state,
|
||||||
}
|
}).Warnf("Error validating csrf cookie: %v", err)
|
||||||
|
http.Error(w, "Not authorized", 401)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Clear CSRF cookie
|
// Clear CSRF cookie
|
||||||
http.SetCookie(w, fw.ClearCSRFCookie(r))
|
http.SetCookie(w, fw.ClearCSRFCookie(r))
|
||||||
|
|
||||||
// Exchange code for token
|
// Exchange code for token
|
||||||
token, err := fw.ExchangeCode(r, qs.Get("code"))
|
token, err := fw.ExchangeCode(r, qs.Get("code"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("Code exchange failed with: %s\n", err)
|
logger.Errorf("Code exchange failed with: %v", err)
|
||||||
http.Error(w, "Service unavailable", 503)
|
http.Error(w, "Service unavailable", 503)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get user
|
// Get user
|
||||||
user, err := fw.GetUser(token)
|
user, err := fw.GetUser(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("Error getting user: %s\n", err)
|
logger.Errorf("Error getting user: %s", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate cookie
|
// Generate cookie
|
||||||
http.SetCookie(w, fw.MakeCookie(r, user.Email))
|
http.SetCookie(w, fw.MakeCookie(r, user.Email))
|
||||||
log.Debugf("Generated auth cookie for %s\n", user.Email)
|
logger.WithFields(logrus.Fields{
|
||||||
|
"user": user.Email,
|
||||||
|
}).Infof("Generated auth cookie")
|
||||||
|
|
||||||
// Redirect
|
// Redirect
|
||||||
http.Redirect(w, r, redirect, http.StatusTemporaryRedirect)
|
http.Redirect(w, r, redirect, http.StatusTemporaryRedirect)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// Main
|
// Main
|
||||||
func main() {
|
func main() {
|
||||||
// Parse options
|
// Parse options
|
||||||
flag.String(flag.DefaultConfigFlagname, "", "Path to config file")
|
flag.String(flag.DefaultConfigFlagname, "", "Path to config file")
|
||||||
path := flag.String("url-path", "_oauth", "Callback URL")
|
path := flag.String("url-path", "_oauth", "Callback URL")
|
||||||
lifetime := flag.Int("lifetime", 43200, "Session length in seconds")
|
lifetime := flag.Int("lifetime", 43200, "Session length in seconds")
|
||||||
clientId := flag.String("client-id", "", "*Google Client ID (required)")
|
secret := flag.String("secret", "", "*Secret used for signing (required)")
|
||||||
clientSecret := flag.String("client-secret", "", "*Google Client Secret (required)")
|
authHost := flag.String("auth-host", "", "Central auth login")
|
||||||
cookieName := flag.String("cookie-name", "_forward_auth", "Cookie Name")
|
clientId := flag.String("client-id", "", "*Google Client ID (required)")
|
||||||
cSRFCookieName := flag.String("csrf-cookie-name", "_forward_auth_csrf", "CSRF Cookie Name")
|
clientSecret := flag.String("client-secret", "", "*Google Client Secret (required)")
|
||||||
cookieDomainList := flag.String("cookie-domains", "", "Comma separated list of cookie domains") //todo
|
cookieName := flag.String("cookie-name", "_forward_auth", "Cookie Name")
|
||||||
cookieSecret := flag.String("cookie-secret", "", "*Cookie secret (required)")
|
cSRFCookieName := flag.String("csrf-cookie-name", "_forward_auth_csrf", "CSRF Cookie Name")
|
||||||
cookieSecure := flag.Bool("cookie-secure", true, "Use secure cookies")
|
cookieDomainList := flag.String("cookie-domains", "", "Comma separated list of cookie domains") //todo
|
||||||
domainList := flag.String("domain", "", "Comma separated list of email domains to allow")
|
cookieSecret := flag.String("cookie-secret", "", "Deprecated")
|
||||||
direct := flag.Bool("direct", false, "Run in direct mode (use own hostname as oppose to X-Forwarded-Host, used for testing/development)")
|
cookieSecure := flag.Bool("cookie-secure", true, "Use secure cookies")
|
||||||
|
domainList := flag.String("domain", "", "Comma separated list of email domains to allow")
|
||||||
|
emailWhitelist := flag.String("whitelist", "", "Comma separated list of emails to allow")
|
||||||
|
prompt := flag.String("prompt", "", "Space separated list of OpenID prompt options")
|
||||||
|
logLevel := flag.String("log-level", "warn", "Log level: trace, debug, info, warn, error, fatal, panic")
|
||||||
|
logFormat := flag.String("log-format", "text", "Log format: text, json, pretty")
|
||||||
|
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
// Check for show stopper errors
|
// Setup logger
|
||||||
err := false
|
log = CreateLogger(*logLevel, *logFormat)
|
||||||
if *clientId == "" {
|
|
||||||
err = true
|
|
||||||
log.Critical("client-id must be set")
|
|
||||||
}
|
|
||||||
if *clientSecret == "" {
|
|
||||||
err = true
|
|
||||||
log.Critical("client-secret must be set")
|
|
||||||
}
|
|
||||||
if *cookieSecret == "" {
|
|
||||||
err = true
|
|
||||||
log.Critical("cookie-secret must be set")
|
|
||||||
}
|
|
||||||
if err {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse lists
|
// Backwards compatibility
|
||||||
var cookieDomains []CookieDomain
|
if *secret == "" && *cookieSecret != "" {
|
||||||
if *cookieDomainList != "" {
|
*secret = *cookieSecret
|
||||||
for _, d := range strings.Split(*cookieDomainList, ",") {
|
}
|
||||||
cookieDomain := NewCookieDomain(d)
|
|
||||||
cookieDomains = append(cookieDomains, *cookieDomain)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var domain []string
|
// Check for show stopper errors
|
||||||
if *domainList != "" {
|
if *clientId == "" || *clientSecret == "" || *secret == "" {
|
||||||
domain = strings.Split(*domainList, ",")
|
log.Fatal("client-id, client-secret and secret must all be set")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Setup
|
// Parse lists
|
||||||
fw = &ForwardAuth{
|
var cookieDomains []CookieDomain
|
||||||
Path: fmt.Sprintf("/%s", *path),
|
if *cookieDomainList != "" {
|
||||||
Lifetime: time.Second * time.Duration(*lifetime),
|
for _, d := range strings.Split(*cookieDomainList, ",") {
|
||||||
|
cookieDomain := NewCookieDomain(d)
|
||||||
|
cookieDomains = append(cookieDomains, *cookieDomain)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
ClientId: *clientId,
|
var domain []string
|
||||||
ClientSecret: *clientSecret,
|
if *domainList != "" {
|
||||||
Scope: "https://www.googleapis.com/auth/userinfo.profile https://www.googleapis.com/auth/userinfo.email",
|
domain = strings.Split(*domainList, ",")
|
||||||
LoginURL: &url.URL{
|
}
|
||||||
Scheme: "https",
|
var whitelist []string
|
||||||
Host: "accounts.google.com",
|
if *emailWhitelist != "" {
|
||||||
Path: "/o/oauth2/auth",
|
whitelist = strings.Split(*emailWhitelist, ",")
|
||||||
},
|
}
|
||||||
TokenURL: &url.URL{
|
|
||||||
Scheme: "https",
|
|
||||||
Host: "www.googleapis.com",
|
|
||||||
Path: "/oauth2/v3/token",
|
|
||||||
},
|
|
||||||
UserURL: &url.URL{
|
|
||||||
Scheme: "https",
|
|
||||||
Host: "www.googleapis.com",
|
|
||||||
Path: "/oauth2/v2/userinfo",
|
|
||||||
},
|
|
||||||
|
|
||||||
CookieName: *cookieName,
|
// Setup
|
||||||
CSRFCookieName: *cSRFCookieName,
|
fw = &ForwardAuth{
|
||||||
CookieDomains: cookieDomains,
|
Path: fmt.Sprintf("/%s", *path),
|
||||||
CookieSecret: []byte(*cookieSecret),
|
Lifetime: time.Second * time.Duration(*lifetime),
|
||||||
CookieSecure: *cookieSecure,
|
Secret: []byte(*secret),
|
||||||
|
AuthHost: *authHost,
|
||||||
|
|
||||||
Domain: domain,
|
ClientId: *clientId,
|
||||||
|
ClientSecret: *clientSecret,
|
||||||
|
Scope: "https://www.googleapis.com/auth/userinfo.profile https://www.googleapis.com/auth/userinfo.email",
|
||||||
|
LoginURL: &url.URL{
|
||||||
|
Scheme: "https",
|
||||||
|
Host: "accounts.google.com",
|
||||||
|
Path: "/o/oauth2/auth",
|
||||||
|
},
|
||||||
|
TokenURL: &url.URL{
|
||||||
|
Scheme: "https",
|
||||||
|
Host: "www.googleapis.com",
|
||||||
|
Path: "/oauth2/v3/token",
|
||||||
|
},
|
||||||
|
UserURL: &url.URL{
|
||||||
|
Scheme: "https",
|
||||||
|
Host: "www.googleapis.com",
|
||||||
|
Path: "/oauth2/v2/userinfo",
|
||||||
|
},
|
||||||
|
|
||||||
Direct: *direct,
|
CookieName: *cookieName,
|
||||||
}
|
CSRFCookieName: *cSRFCookieName,
|
||||||
|
CookieDomains: cookieDomains,
|
||||||
|
CookieSecure: *cookieSecure,
|
||||||
|
|
||||||
// Attach handler
|
Domain: domain,
|
||||||
http.HandleFunc("/", handler)
|
Whitelist: whitelist,
|
||||||
|
|
||||||
log.Notice("Litening on :4181")
|
Prompt: *prompt,
|
||||||
log.Notice(http.ListenAndServe(":4181", nil))
|
}
|
||||||
|
|
||||||
|
// Attach handler
|
||||||
|
http.HandleFunc("/", handler)
|
||||||
|
|
||||||
|
// Start
|
||||||
|
jsonConf, _ := json.Marshal(fw)
|
||||||
|
log.Debugf("Starting with options: %s", string(jsonConf))
|
||||||
|
log.Info("Listening on :4181")
|
||||||
|
log.Info(http.ListenAndServe(":4181", nil))
|
||||||
}
|
}
|
||||||
|
289
main_test.go
289
main_test.go
@ -1,29 +1,31 @@
|
|||||||
|
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
// "reflect"
|
// "reflect"
|
||||||
"strings"
|
"io/ioutil"
|
||||||
"testing"
|
"net/http"
|
||||||
"net/url"
|
"net/http/httptest"
|
||||||
"net/http"
|
"net/url"
|
||||||
"io/ioutil"
|
"strings"
|
||||||
"net/http/httptest"
|
"testing"
|
||||||
|
|
||||||
"github.com/op/go-logging"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Utilities
|
||||||
|
*/
|
||||||
|
|
||||||
|
type TokenServerHandler struct{}
|
||||||
|
|
||||||
type TokenServerHandler struct {}
|
|
||||||
func (t *TokenServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (t *TokenServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
fmt.Fprint(w, `{"access_token":"123456789"}`)
|
fmt.Fprint(w, `{"access_token":"123456789"}`)
|
||||||
}
|
}
|
||||||
|
|
||||||
type UserServerHandler struct {}
|
type UserServerHandler struct{}
|
||||||
|
|
||||||
func (t *UserServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (t *UserServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
fmt.Fprint(w, `{
|
fmt.Fprint(w, `{
|
||||||
"id":"1",
|
"id":"1",
|
||||||
"email":"example@example.com",
|
"email":"example@example.com",
|
||||||
"verified_email":true,
|
"verified_email":true,
|
||||||
@ -32,150 +34,177 @@ func (t *UserServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
// Remove for debugging
|
log = CreateLogger("panic", "")
|
||||||
logging.SetLevel(logging.INFO, "traefik-forward-auth")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func httpRequest(r *http.Request, c *http.Cookie) (*http.Response, string) {
|
func httpRequest(r *http.Request, c *http.Cookie) (*http.Response, string) {
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
// Set cookies on recorder
|
// Set cookies on recorder
|
||||||
if c != nil {
|
if c != nil {
|
||||||
http.SetCookie(w, c)
|
http.SetCookie(w, c)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Copy into request
|
// Copy into request
|
||||||
for _, c := range w.HeaderMap["Set-Cookie"] {
|
for _, c := range w.HeaderMap["Set-Cookie"] {
|
||||||
r.Header.Add("Cookie", c)
|
r.Header.Add("Cookie", c)
|
||||||
}
|
}
|
||||||
|
|
||||||
handler(w, r)
|
handler(w, r)
|
||||||
|
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
body, _ := ioutil.ReadAll(res.Body)
|
body, _ := ioutil.ReadAll(res.Body)
|
||||||
|
|
||||||
return res, string(body)
|
return res, string(body)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHttpRequest(uri string) *http.Request {
|
func newHttpRequest(uri string) *http.Request {
|
||||||
r := httptest.NewRequest("", "http://example.com", nil)
|
r := httptest.NewRequest("", "http://example.com", nil)
|
||||||
r.Header.Add("X-Forwarded-Uri", uri)
|
r.Header.Add("X-Forwarded-Uri", uri)
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func qsDiff(one, two url.Values) {
|
||||||
|
for k := range one {
|
||||||
|
if two.Get(k) == "" {
|
||||||
|
fmt.Printf("Key missing: %s\n", k)
|
||||||
|
}
|
||||||
|
if one.Get(k) != two.Get(k) {
|
||||||
|
fmt.Printf("Value different for %s: expected: '%s' got: '%s'\n", k, one.Get(k), two.Get(k))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for k := range two {
|
||||||
|
if one.Get(k) == "" {
|
||||||
|
fmt.Printf("Extra key: %s\n", k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Tests
|
||||||
|
*/
|
||||||
|
|
||||||
func TestHandler(t *testing.T) {
|
func TestHandler(t *testing.T) {
|
||||||
fw = &ForwardAuth{
|
fw = &ForwardAuth{
|
||||||
Path: "_oauth",
|
Path: "_oauth",
|
||||||
ClientId: "idtest",
|
ClientId: "idtest",
|
||||||
ClientSecret: "sectest",
|
ClientSecret: "sectest",
|
||||||
Scope: "scopetest",
|
Scope: "scopetest",
|
||||||
LoginURL: &url.URL{
|
LoginURL: &url.URL{
|
||||||
Scheme: "http",
|
Scheme: "http",
|
||||||
Host: "test.com",
|
Host: "test.com",
|
||||||
Path: "/auth",
|
Path: "/auth",
|
||||||
},
|
},
|
||||||
CookieName: "cookie_test",
|
CookieName: "cookie_test",
|
||||||
Lifetime: time.Second * time.Duration(10),
|
Lifetime: time.Second * time.Duration(10),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should redirect vanilla request to login url
|
// Should redirect vanilla request to login url
|
||||||
req := newHttpRequest("foo")
|
req := newHttpRequest("foo")
|
||||||
res, _ := httpRequest(req, nil)
|
res, _ := httpRequest(req, nil)
|
||||||
if res.StatusCode != 307 {
|
if res.StatusCode != 307 {
|
||||||
t.Error("Vanilla request should be redirected with 307, got:", res.StatusCode)
|
t.Error("Vanilla request should be redirected with 307, got:", res.StatusCode)
|
||||||
}
|
}
|
||||||
fwd, _ := res.Location()
|
fwd, _ := res.Location()
|
||||||
if fwd.Scheme != "http" || fwd.Host != "test.com" || fwd.Path != "/auth" {
|
if fwd.Scheme != "http" || fwd.Host != "test.com" || fwd.Path != "/auth" {
|
||||||
t.Error("Vanilla request should be redirected to login url, got:", fwd)
|
t.Error("Vanilla request should be redirected to login url, got:", fwd)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should catch invalid cookie
|
// Should catch invalid cookie
|
||||||
req = newHttpRequest("foo")
|
req = newHttpRequest("foo")
|
||||||
|
|
||||||
c := fw.MakeCookie(req, "test@example.com")
|
c := fw.MakeCookie(req, "test@example.com")
|
||||||
parts := strings.Split(c.Value, "|")
|
parts := strings.Split(c.Value, "|")
|
||||||
c.Value = fmt.Sprintf("bad|%s|%s", parts[1], parts[2])
|
c.Value = fmt.Sprintf("bad|%s|%s", parts[1], parts[2])
|
||||||
|
|
||||||
res, _ = httpRequest(req, c)
|
res, _ = httpRequest(req, c)
|
||||||
if res.StatusCode != 401 {
|
if res.StatusCode != 401 {
|
||||||
t.Error("Request with invalid cookie shound't be authorised", res.StatusCode)
|
t.Error("Request with invalid cookie shound't be authorised", res.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should validate email
|
// Should validate email
|
||||||
req = newHttpRequest("foo")
|
req = newHttpRequest("foo")
|
||||||
|
|
||||||
c = fw.MakeCookie(req, "test@example.com")
|
c = fw.MakeCookie(req, "test@example.com")
|
||||||
fw.Domain = []string{"test.com"}
|
fw.Domain = []string{"test.com"}
|
||||||
|
|
||||||
res, _ = httpRequest(req, c)
|
res, _ = httpRequest(req, c)
|
||||||
if res.StatusCode != 401 {
|
if res.StatusCode != 401 {
|
||||||
t.Error("Request with invalid cookie shound't be authorised", res.StatusCode)
|
t.Error("Request with invalid cookie shound't be authorised", res.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should allow valid request email
|
// Should allow valid request email
|
||||||
req = newHttpRequest("foo")
|
req = newHttpRequest("foo")
|
||||||
|
|
||||||
c = fw.MakeCookie(req, "test@example.com")
|
c = fw.MakeCookie(req, "test@example.com")
|
||||||
fw.Domain = []string{}
|
fw.Domain = []string{}
|
||||||
|
|
||||||
res, _ = httpRequest(req, c)
|
res, _ = httpRequest(req, c)
|
||||||
if res.StatusCode != 200 {
|
if res.StatusCode != 200 {
|
||||||
t.Error("Valid request should be allowed, got:", res.StatusCode)
|
t.Error("Valid request should be allowed, got:", res.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Should pass through user
|
||||||
|
users := res.Header["X-Forwarded-User"]
|
||||||
|
if len(users) != 1 {
|
||||||
|
t.Error("Valid request missing X-Forwarded-User header")
|
||||||
|
} else if users[0] != "test@example.com" {
|
||||||
|
t.Error("X-Forwarded-User should match user, got: ", users)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCallback(t *testing.T) {
|
func TestCallback(t *testing.T) {
|
||||||
fw = &ForwardAuth{
|
fw = &ForwardAuth{
|
||||||
Path: "_oauth",
|
Path: "_oauth",
|
||||||
ClientId: "idtest",
|
ClientId: "idtest",
|
||||||
ClientSecret: "sectest",
|
ClientSecret: "sectest",
|
||||||
Scope: "scopetest",
|
Scope: "scopetest",
|
||||||
LoginURL: &url.URL{
|
LoginURL: &url.URL{
|
||||||
Scheme: "http",
|
Scheme: "http",
|
||||||
Host: "test.com",
|
Host: "test.com",
|
||||||
Path: "/auth",
|
Path: "/auth",
|
||||||
},
|
},
|
||||||
CSRFCookieName: "csrf_test",
|
CSRFCookieName: "csrf_test",
|
||||||
}
|
}
|
||||||
|
|
||||||
// Setup token server
|
// Setup token server
|
||||||
tokenServerHandler := &TokenServerHandler{}
|
tokenServerHandler := &TokenServerHandler{}
|
||||||
tokenServer := httptest.NewServer(tokenServerHandler)
|
tokenServer := httptest.NewServer(tokenServerHandler)
|
||||||
defer tokenServer.Close()
|
defer tokenServer.Close()
|
||||||
tokenUrl, _ := url.Parse(tokenServer.URL)
|
tokenUrl, _ := url.Parse(tokenServer.URL)
|
||||||
fw.TokenURL = tokenUrl
|
fw.TokenURL = tokenUrl
|
||||||
|
|
||||||
// Setup user server
|
// Setup user server
|
||||||
userServerHandler := &UserServerHandler{}
|
userServerHandler := &UserServerHandler{}
|
||||||
userServer := httptest.NewServer(userServerHandler)
|
userServer := httptest.NewServer(userServerHandler)
|
||||||
defer userServer.Close()
|
defer userServer.Close()
|
||||||
userUrl, _ := url.Parse(userServer.URL)
|
userUrl, _ := url.Parse(userServer.URL)
|
||||||
fw.UserURL = userUrl
|
fw.UserURL = userUrl
|
||||||
|
|
||||||
// Should pass auth response request to callback
|
// Should pass auth response request to callback
|
||||||
req := newHttpRequest("_oauth")
|
req := newHttpRequest("_oauth")
|
||||||
res, _ := httpRequest(req, nil)
|
res, _ := httpRequest(req, nil)
|
||||||
if res.StatusCode != 401 {
|
if res.StatusCode != 401 {
|
||||||
t.Error("Auth callback without cookie shound't be authorised, got:", res.StatusCode)
|
t.Error("Auth callback without cookie shound't be authorised, got:", res.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should catch invalid csrf cookie
|
// Should catch invalid csrf cookie
|
||||||
req = newHttpRequest("_oauth?state=12345678901234567890123456789012:http://redirect")
|
req = newHttpRequest("_oauth?state=12345678901234567890123456789012:http://redirect")
|
||||||
c := fw.MakeCSRFCookie(req, "nononononononononononononononono")
|
c := fw.MakeCSRFCookie(req, "nononononononononononononononono")
|
||||||
res, _ = httpRequest(req, c)
|
res, _ = httpRequest(req, c)
|
||||||
if res.StatusCode != 401 {
|
if res.StatusCode != 401 {
|
||||||
t.Error("Auth callback with invalid cookie shound't be authorised, got:", res.StatusCode)
|
t.Error("Auth callback with invalid cookie shound't be authorised, got:", res.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should redirect valid request
|
// Should redirect valid request
|
||||||
req = newHttpRequest("_oauth?state=12345678901234567890123456789012:http://redirect")
|
req = newHttpRequest("_oauth?state=12345678901234567890123456789012:http://redirect")
|
||||||
c = fw.MakeCSRFCookie(req, "12345678901234567890123456789012")
|
c = fw.MakeCSRFCookie(req, "12345678901234567890123456789012")
|
||||||
res, _ = httpRequest(req, c)
|
res, _ = httpRequest(req, c)
|
||||||
if res.StatusCode != 307 {
|
if res.StatusCode != 307 {
|
||||||
t.Error("Valid callback should be allowed, got:", res.StatusCode)
|
t.Error("Valid callback should be allowed, got:", res.StatusCode)
|
||||||
}
|
}
|
||||||
fwd, _ := res.Location()
|
fwd, _ := res.Location()
|
||||||
if fwd.Scheme != "http" || fwd.Host != "redirect" || fwd.Path != "" {
|
if fwd.Scheme != "http" || fwd.Host != "redirect" || fwd.Path != "" {
|
||||||
t.Error("Valid request should be redirected to return url, got:", fwd)
|
t.Error("Valid request should be redirected to return url, got:", fwd)
|
||||||
}
|
}
|
||||||
}
|
}
|
Reference in New Issue
Block a user