Multiple provider support + OIDC provider

This commit is contained in:
Thom Seddon
2019-09-18 17:55:52 +01:00
parent 5dfd4f2878
commit c9289d6fc1
16 changed files with 1043 additions and 278 deletions

View File

@ -2,13 +2,15 @@ package provider
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
)
// Google provider
type Google struct {
ClientId string `long:"client-id" env:"CLIENT_ID" description:"Client ID"`
ClientID string `long:"client-id" env:"CLIENT_ID" description:"Client ID"`
ClientSecret string `long:"client-secret" env:"CLIENT_SECRET" description:"Client Secret" json:"-"`
Scope string
Prompt string `long:"prompt" env:"PROMPT" description:"Space separated list of OpenID prompt options"`
@ -18,15 +20,48 @@ type Google struct {
UserURL *url.URL
}
func (g *Google) GetLoginURL(redirectUri, state string) string {
// Name returns the name of the provider
func (g *Google) Name() string {
return "google"
}
// Setup performs validation and setup
func (g *Google) Setup() error {
if g.ClientID == "" || g.ClientSecret == "" {
return errors.New("providers.google.client-id, providers.google.client-secret must be set")
}
// Set static values
g.Scope = "https://www.googleapis.com/auth/userinfo.profile https://www.googleapis.com/auth/userinfo.email"
g.LoginURL = &url.URL{
Scheme: "https",
Host: "accounts.google.com",
Path: "/o/oauth2/auth",
}
g.TokenURL = &url.URL{
Scheme: "https",
Host: "www.googleapis.com",
Path: "/oauth2/v3/token",
}
g.UserURL = &url.URL{
Scheme: "https",
Host: "www.googleapis.com",
Path: "/oauth2/v2/userinfo",
}
return nil
}
// GetLoginURL provides the login url for the given redirect uri and state
func (g *Google) GetLoginURL(redirectURI, state string) string {
q := url.Values{}
q.Set("client_id", g.ClientId)
q.Set("client_id", g.ClientID)
q.Set("response_type", "code")
q.Set("scope", g.Scope)
if g.Prompt != "" {
q.Set("prompt", g.Prompt)
}
q.Set("redirect_uri", redirectUri)
q.Set("redirect_uri", redirectURI)
q.Set("state", state)
var u url.URL
@ -36,12 +71,13 @@ func (g *Google) GetLoginURL(redirectUri, state string) string {
return u.String()
}
func (g *Google) ExchangeCode(redirectUri, code string) (string, error) {
// ExchangeCode exchanges the given redirect uri and code for a token
func (g *Google) ExchangeCode(redirectURI, code string) (string, error) {
form := url.Values{}
form.Set("client_id", g.ClientId)
form.Set("client_id", g.ClientID)
form.Set("client_secret", g.ClientSecret)
form.Set("grant_type", "authorization_code")
form.Set("redirect_uri", redirectUri)
form.Set("redirect_uri", redirectURI)
form.Set("code", code)
res, err := http.PostForm(g.TokenURL.String(), form)
@ -49,13 +85,14 @@ func (g *Google) ExchangeCode(redirectUri, code string) (string, error) {
return "", err
}
var token Token
var token token
defer res.Body.Close()
err = json.NewDecoder(res.Body).Decode(&token)
return token.Token, err
}
// GetUser uses the given token and returns a complete provider.User object
func (g *Google) GetUser(token string) (User, error) {
var user User

View File

@ -0,0 +1,151 @@
package provider
import (
"net/url"
"testing"
"github.com/stretchr/testify/assert"
)
// Tests
func TestGoogleName(t *testing.T) {
p := Google{}
assert.Equal(t, "google", p.Name())
}
func TestGoogleSetup(t *testing.T) {
assert := assert.New(t)
p := Google{}
// Check validation
err := p.Setup()
if assert.Error(err) {
assert.Equal("providers.google.client-id, providers.google.client-secret must be set", err.Error())
}
// Check setup
p = Google{
ClientID: "id",
ClientSecret: "secret",
}
err = p.Setup()
assert.Nil(err)
assert.Equal("https://www.googleapis.com/auth/userinfo.profile https://www.googleapis.com/auth/userinfo.email", p.Scope)
assert.Equal("", p.Prompt)
assert.Equal(&url.URL{
Scheme: "https",
Host: "accounts.google.com",
Path: "/o/oauth2/auth",
}, p.LoginURL)
assert.Equal(&url.URL{
Scheme: "https",
Host: "www.googleapis.com",
Path: "/oauth2/v3/token",
}, p.TokenURL)
assert.Equal(&url.URL{
Scheme: "https",
Host: "www.googleapis.com",
Path: "/oauth2/v2/userinfo",
}, p.UserURL)
}
func TestGoogleGetLoginURL(t *testing.T) {
assert := assert.New(t)
p := Google{
ClientID: "idtest",
ClientSecret: "sectest",
Scope: "scopetest",
Prompt: "consent select_account",
LoginURL: &url.URL{
Scheme: "https",
Host: "google.com",
Path: "/auth",
},
}
// Check url
uri, err := url.Parse(p.GetLoginURL("http://example.com/_oauth", "state"))
assert.Nil(err)
assert.Equal("https", uri.Scheme)
assert.Equal("google.com", uri.Host)
assert.Equal("/auth", uri.Path)
// Check query string
qs := uri.Query()
expectedQs := url.Values{
"client_id": []string{"idtest"},
"redirect_uri": []string{"http://example.com/_oauth"},
"response_type": []string{"code"},
"scope": []string{"scopetest"},
"prompt": []string{"consent select_account"},
"state": []string{"state"},
}
assert.Equal(expectedQs, qs)
}
func TestGoogleExchangeCode(t *testing.T) {
assert := assert.New(t)
// Setup server
expected := url.Values{
"client_id": []string{"idtest"},
"client_secret": []string{"sectest"},
"code": []string{"code"},
"grant_type": []string{"authorization_code"},
"redirect_uri": []string{"http://example.com/_oauth"},
}
server, serverURL := NewOAuthServer(t, map[string]string{
"token": expected.Encode(),
})
defer server.Close()
// Setup provider
p := Google{
ClientID: "idtest",
ClientSecret: "sectest",
Scope: "scopetest",
Prompt: "consent select_account",
TokenURL: &url.URL{
Scheme: serverURL.Scheme,
Host: serverURL.Host,
Path: "/token",
},
}
token, err := p.ExchangeCode("http://example.com/_oauth", "code")
assert.Nil(err)
assert.Equal("123456789", token)
}
func TestGoogleGetUser(t *testing.T) {
assert := assert.New(t)
// Setup server
server, serverURL := NewOAuthServer(t, nil)
defer server.Close()
// Setup provider
p := Google{
ClientID: "idtest",
ClientSecret: "sectest",
Scope: "scopetest",
Prompt: "consent select_account",
UserURL: &url.URL{
Scheme: serverURL.Scheme,
Host: serverURL.Host,
Path: "/userinfo",
},
}
user, err := p.GetUser("123456789")
assert.Nil(err)
assert.Equal("1", user.ID)
assert.Equal("example@example.com", user.Email)
assert.True(user.Verified)
assert.Equal("example.com", user.Hd)
}

108
internal/provider/oidc.go Normal file
View File

@ -0,0 +1,108 @@
package provider
import (
"context"
"errors"
"github.com/coreos/go-oidc"
"golang.org/x/oauth2"
)
// OIDC provider
type OIDC struct {
OAuthProvider
IssuerURL string `long:"issuer-url" env:"ISSUER_URL" description:"Issuer URL"`
ClientID string `long:"client-id" env:"CLIENT_ID" description:"Client ID"`
ClientSecret string `long:"client-secret" env:"CLIENT_SECRET" description:"Client Secret" json:"-"`
provider *oidc.Provider
verifier *oidc.IDTokenVerifier
}
// Name returns the name of the provider
func (o *OIDC) Name() string {
return "oidc"
}
// Setup performs validation and setup
func (o *OIDC) Setup() error {
// Check parms
if o.IssuerURL == "" || o.ClientID == "" || o.ClientSecret == "" {
return errors.New("providers.oidc.issuer-url, providers.oidc.client-id, providers.oidc.client-secret must be set")
}
var err error
o.ctx = context.Background()
// Try to initiate provider
o.provider, err = oidc.NewProvider(o.ctx, o.IssuerURL)
if err != nil {
return err
}
// Create oauth2 config
o.Config = &oauth2.Config{
ClientID: o.ClientID,
ClientSecret: o.ClientSecret,
Endpoint: o.provider.Endpoint(),
// "openid" is a required scope for OpenID Connect flows.
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
}
// Create OIDC verifier
o.verifier = o.provider.Verifier(&oidc.Config{
ClientID: o.ClientID,
})
return nil
}
// GetLoginURL provides the login url for the given redirect uri and state
func (o *OIDC) GetLoginURL(redirectURI, state string) string {
return o.OAuthGetLoginURL(redirectURI, state)
}
// ExchangeCode exchanges the given redirect uri and code for a token
func (o *OIDC) ExchangeCode(redirectURI, code string) (string, error) {
token, err := o.OAuthExchangeCode(redirectURI, code)
if err != nil {
return "", err
}
// Extract ID token
rawIDToken, ok := token.Extra("id_token").(string)
if !ok {
return "", errors.New("Missing id_token")
}
return rawIDToken, nil
}
// GetUser uses the given token and returns a complete provider.User object
func (o *OIDC) GetUser(token string) (User, error) {
var user User
// Parse & Verify ID Token
idToken, err := o.verifier.Verify(o.ctx, token)
if err != nil {
return user, err
}
// Extract custom claims
var claims struct {
ID string `json:"sub"`
Email string `json:"email"`
Verified bool `json:"email_verified"`
}
if err := idToken.Claims(&claims); err != nil {
return user, err
}
user.ID = claims.ID
user.Email = claims.Email
user.Verified = claims.Verified
return user, nil
}

View File

@ -0,0 +1,252 @@
package provider
import (
"crypto/rand"
"crypto/rsa"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"strconv"
"testing"
"time"
"github.com/stretchr/testify/assert"
jose "gopkg.in/square/go-jose.v2"
)
// Tests
func TestOIDCName(t *testing.T) {
p := OIDC{}
assert.Equal(t, "oidc", p.Name())
}
func TestOIDCSetup(t *testing.T) {
assert := assert.New(t)
p := OIDC{}
err := p.Setup()
if assert.Error(err) {
assert.Equal("providers.oidc.issuer-url, providers.oidc.client-id, providers.oidc.client-secret must be set", err.Error())
}
}
func TestOIDCGetLoginURL(t *testing.T) {
assert := assert.New(t)
provider, server, serverURL, _ := setupOIDCTest(t, nil)
defer server.Close()
// Check url
uri, err := url.Parse(provider.GetLoginURL("http://example.com/_oauth", "state"))
assert.Nil(err)
assert.Equal(serverURL.Scheme, uri.Scheme)
assert.Equal(serverURL.Host, uri.Host)
assert.Equal("/auth", uri.Path)
// Check query string
qs := uri.Query()
expectedQs := url.Values{
"client_id": []string{"idtest"},
"redirect_uri": []string{"http://example.com/_oauth"},
"response_type": []string{"code"},
"scope": []string{"openid profile email"},
"state": []string{"state"},
}
assert.Equal(expectedQs, qs)
// Calling the method should not modify the underlying config
assert.Equal("", provider.Config.RedirectURL)
}
func TestOIDCExchangeCode(t *testing.T) {
assert := assert.New(t)
provider, server, _, _ := setupOIDCTest(t, map[string]map[string]string{
"token": {
"code": "code",
"grant_type": "authorization_code",
"redirect_uri": "http://example.com/_oauth",
},
})
defer server.Close()
token, err := provider.ExchangeCode("http://example.com/_oauth", "code")
assert.Nil(err)
assert.Equal("id_123456789", token)
}
func TestOIDCGetUser(t *testing.T) {
assert := assert.New(t)
provider, server, serverURL, key := setupOIDCTest(t, nil)
defer server.Close()
// Generate JWT
token := key.sign(t, []byte(`{
"iss": "`+serverURL.String()+`",
"exp":`+strconv.FormatInt(time.Now().Add(time.Hour).Unix(), 10)+`,
"aud": "idtest",
"sub": "1",
"email": "example@example.com",
"email_verified": true
}`))
// Get user
user, err := provider.GetUser(token)
assert.Nil(err)
assert.Equal("1", user.ID)
assert.Equal("example@example.com", user.Email)
assert.True(user.Verified)
}
// Utils
// setOIDCTest creates a key, OIDCServer and initilises an OIDC provider
func setupOIDCTest(t *testing.T, bodyValues map[string]map[string]string) (*OIDC, *httptest.Server, *url.URL, *rsaKey) {
// Generate key
key, err := newRSAKey()
if err != nil {
t.Fatal(err)
}
body := make(map[string]string)
if bodyValues != nil {
// URL encode bodyValues into body
for method, values := range bodyValues {
q := url.Values{}
for k, v := range values {
q.Set(k, v)
}
body[method] = q.Encode()
}
}
// Set up oidc server
server, serverURL := NewOIDCServer(t, key, body)
// Setup provider
p := OIDC{
ClientID: "idtest",
ClientSecret: "sectest",
IssuerURL: serverURL.String(),
}
// Initialise config/verifier
err = p.Setup()
if err != nil {
t.Fatal(err)
}
return &p, server, serverURL, key
}
// OIDCServer is used in the OIDC Tests to mock an OIDC server
type OIDCServer struct {
t *testing.T
url *url.URL
body map[string]string // method -> body
key *rsaKey
}
func NewOIDCServer(t *testing.T, key *rsaKey, body map[string]string) (*httptest.Server, *url.URL) {
handler := &OIDCServer{t: t, key: key, body: body}
server := httptest.NewServer(handler)
handler.url, _ = url.Parse(server.URL)
return server, handler.url
}
func (s *OIDCServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
body, _ := ioutil.ReadAll(r.Body)
if r.URL.Path == "/.well-known/openid-configuration" {
// Open id config
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, `{
"issuer":"`+s.url.String()+`",
"authorization_endpoint":"`+s.url.String()+`/auth",
"token_endpoint":"`+s.url.String()+`/token",
"jwks_uri":"`+s.url.String()+`/jwks"
}`)
} else if r.URL.Path == "/token" {
// Token request
// Check body
if b, ok := s.body["token"]; ok {
if b != string(body) {
s.t.Fatal("Unexpected request body, expected", b, "got", string(body))
}
}
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, `{
"access_token":"123456789",
"id_token":"id_123456789"
}`)
} else if r.URL.Path == "/jwks" {
// Key request
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, `{"keys":[`+s.key.publicJWK(s.t)+`]}`)
} else {
s.t.Fatal("Unrecognised request: ", r.URL, string(body))
}
}
// rsaKey is used in the OIDCServer tests to sign and verify requests
type rsaKey struct {
key *rsa.PrivateKey
alg jose.SignatureAlgorithm
jwkPub *jose.JSONWebKey
jwkPriv *jose.JSONWebKey
}
func newRSAKey() (*rsaKey, error) {
key, err := rsa.GenerateKey(rand.Reader, 1028)
if err != nil {
return nil, err
}
return &rsaKey{
key: key,
alg: jose.RS256,
jwkPub: &jose.JSONWebKey{
Key: key.Public(),
Algorithm: string(jose.RS256),
},
jwkPriv: &jose.JSONWebKey{
Key: key,
Algorithm: string(jose.RS256),
},
}, nil
}
func (k *rsaKey) publicJWK(t *testing.T) string {
b, err := k.jwkPub.MarshalJSON()
if err != nil {
t.Fatal(err)
}
return string(b)
}
// sign creates a JWS using the private key from the provided payload.
func (k *rsaKey) sign(t *testing.T, payload []byte) string {
signer, err := jose.NewSigner(jose.SigningKey{
Algorithm: k.alg,
Key: k.key,
}, nil)
if err != nil {
t.Fatal(err)
}
jws, err := signer.Sign(payload)
if err != nil {
t.Fatal(err)
}
data, err := jws.CompactSerialize()
if err != nil {
t.Fatal(err)
}
return data
}

View File

@ -1,16 +1,61 @@
package provider
import (
"context"
// "net/url"
"golang.org/x/oauth2"
)
// Providers contains all the implemented providers
type Providers struct {
Google Google `group:"Google Provider" namespace:"google" env-namespace:"GOOGLE"`
OIDC OIDC `group:"OIDC Provider" namespace:"oidc" env-namespace:"OIDC"`
}
type Token struct {
// Provider is used to authenticate users
type Provider interface {
Name() string
GetLoginURL(redirectURI, state string) string
ExchangeCode(redirectURI, code string) (string, error)
GetUser(token string) (User, error)
Setup() error
}
type token struct {
Token string `json:"access_token"`
}
// User is the authenticated user
type User struct {
Id string `json:"id"`
ID string `json:"id"`
Email string `json:"email"`
Verified bool `json:"verified_email"`
Hd string `json:"hd"`
}
// OAuthProvider is a provider using the oauth2 library
type OAuthProvider struct {
Config *oauth2.Config
ctx context.Context
}
// ConfigCopy returns a copy of the oauth2 config with the given redirectURI
// which ensures the underlying config is not modified
func (p *OAuthProvider) ConfigCopy(redirectURI string) oauth2.Config {
config := *p.Config
config.RedirectURL = redirectURI
return config
}
// OAuthGetLoginURL provides a base "GetLoginURL" for proiders using OAauth2
func (p *OAuthProvider) OAuthGetLoginURL(redirectURI, state string) string {
config := p.ConfigCopy(redirectURI)
return config.AuthCodeURL(state)
}
// OAuthExchangeCode provides a base "ExchangeCode" for proiders using OAauth2
func (p *OAuthProvider) OAuthExchangeCode(redirectURI, code string) (*oauth2.Token, error) {
config := p.ConfigCopy(redirectURI)
return config.Exchange(p.ctx, code)
}

View File

@ -0,0 +1,48 @@
package provider
import (
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"testing"
)
// Utilities
type OAuthServer struct {
t *testing.T
url *url.URL
body map[string]string // method -> body
}
func NewOAuthServer(t *testing.T, body map[string]string) (*httptest.Server, *url.URL) {
handler := &OAuthServer{t: t, body: body}
server := httptest.NewServer(handler)
handler.url, _ = url.Parse(server.URL)
return server, handler.url
}
func (s *OAuthServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
body, _ := ioutil.ReadAll(r.Body)
// fmt.Println("Got request:", r.URL, r.Method, string(body))
if r.Method == "POST" && r.URL.Path == "/token" {
if s.body["token"] != string(body) {
s.t.Fatal("Unexpected request body, expected", s.body["token"], "got", string(body))
}
w.Header().Set("Content-Type", "application/json")
fmt.Fprintf(w, `{"access_token":"123456789"}`)
} else if r.Method == "GET" && r.URL.Path == "/userinfo" {
fmt.Fprint(w, `{
"id":"1",
"email":"example@example.com",
"verified_email":true,
"hd":"example.com"
}`)
} else {
s.t.Fatal("Unrecognised request: ", r.Method, r.URL, string(body))
}
}