Multiple provider support + OIDC provider
This commit is contained in:
252
internal/provider/oidc_test.go
Normal file
252
internal/provider/oidc_test.go
Normal file
@ -0,0 +1,252 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
jose "gopkg.in/square/go-jose.v2"
|
||||
)
|
||||
|
||||
// Tests
|
||||
|
||||
func TestOIDCName(t *testing.T) {
|
||||
p := OIDC{}
|
||||
assert.Equal(t, "oidc", p.Name())
|
||||
}
|
||||
|
||||
func TestOIDCSetup(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
p := OIDC{}
|
||||
|
||||
err := p.Setup()
|
||||
if assert.Error(err) {
|
||||
assert.Equal("providers.oidc.issuer-url, providers.oidc.client-id, providers.oidc.client-secret must be set", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDCGetLoginURL(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
provider, server, serverURL, _ := setupOIDCTest(t, nil)
|
||||
defer server.Close()
|
||||
|
||||
// Check url
|
||||
uri, err := url.Parse(provider.GetLoginURL("http://example.com/_oauth", "state"))
|
||||
assert.Nil(err)
|
||||
assert.Equal(serverURL.Scheme, uri.Scheme)
|
||||
assert.Equal(serverURL.Host, uri.Host)
|
||||
assert.Equal("/auth", uri.Path)
|
||||
|
||||
// Check query string
|
||||
qs := uri.Query()
|
||||
expectedQs := url.Values{
|
||||
"client_id": []string{"idtest"},
|
||||
"redirect_uri": []string{"http://example.com/_oauth"},
|
||||
"response_type": []string{"code"},
|
||||
"scope": []string{"openid profile email"},
|
||||
"state": []string{"state"},
|
||||
}
|
||||
assert.Equal(expectedQs, qs)
|
||||
|
||||
// Calling the method should not modify the underlying config
|
||||
assert.Equal("", provider.Config.RedirectURL)
|
||||
}
|
||||
|
||||
func TestOIDCExchangeCode(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
provider, server, _, _ := setupOIDCTest(t, map[string]map[string]string{
|
||||
"token": {
|
||||
"code": "code",
|
||||
"grant_type": "authorization_code",
|
||||
"redirect_uri": "http://example.com/_oauth",
|
||||
},
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
token, err := provider.ExchangeCode("http://example.com/_oauth", "code")
|
||||
assert.Nil(err)
|
||||
assert.Equal("id_123456789", token)
|
||||
}
|
||||
|
||||
func TestOIDCGetUser(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
provider, server, serverURL, key := setupOIDCTest(t, nil)
|
||||
defer server.Close()
|
||||
|
||||
// Generate JWT
|
||||
token := key.sign(t, []byte(`{
|
||||
"iss": "`+serverURL.String()+`",
|
||||
"exp":`+strconv.FormatInt(time.Now().Add(time.Hour).Unix(), 10)+`,
|
||||
"aud": "idtest",
|
||||
"sub": "1",
|
||||
"email": "example@example.com",
|
||||
"email_verified": true
|
||||
}`))
|
||||
|
||||
// Get user
|
||||
user, err := provider.GetUser(token)
|
||||
assert.Nil(err)
|
||||
assert.Equal("1", user.ID)
|
||||
assert.Equal("example@example.com", user.Email)
|
||||
assert.True(user.Verified)
|
||||
}
|
||||
|
||||
// Utils
|
||||
|
||||
// setOIDCTest creates a key, OIDCServer and initilises an OIDC provider
|
||||
func setupOIDCTest(t *testing.T, bodyValues map[string]map[string]string) (*OIDC, *httptest.Server, *url.URL, *rsaKey) {
|
||||
// Generate key
|
||||
key, err := newRSAKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
body := make(map[string]string)
|
||||
if bodyValues != nil {
|
||||
// URL encode bodyValues into body
|
||||
for method, values := range bodyValues {
|
||||
q := url.Values{}
|
||||
for k, v := range values {
|
||||
q.Set(k, v)
|
||||
}
|
||||
body[method] = q.Encode()
|
||||
}
|
||||
}
|
||||
|
||||
// Set up oidc server
|
||||
server, serverURL := NewOIDCServer(t, key, body)
|
||||
|
||||
// Setup provider
|
||||
p := OIDC{
|
||||
ClientID: "idtest",
|
||||
ClientSecret: "sectest",
|
||||
IssuerURL: serverURL.String(),
|
||||
}
|
||||
|
||||
// Initialise config/verifier
|
||||
err = p.Setup()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return &p, server, serverURL, key
|
||||
}
|
||||
|
||||
// OIDCServer is used in the OIDC Tests to mock an OIDC server
|
||||
type OIDCServer struct {
|
||||
t *testing.T
|
||||
url *url.URL
|
||||
body map[string]string // method -> body
|
||||
key *rsaKey
|
||||
}
|
||||
|
||||
func NewOIDCServer(t *testing.T, key *rsaKey, body map[string]string) (*httptest.Server, *url.URL) {
|
||||
handler := &OIDCServer{t: t, key: key, body: body}
|
||||
server := httptest.NewServer(handler)
|
||||
handler.url, _ = url.Parse(server.URL)
|
||||
return server, handler.url
|
||||
}
|
||||
|
||||
func (s *OIDCServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := ioutil.ReadAll(r.Body)
|
||||
|
||||
if r.URL.Path == "/.well-known/openid-configuration" {
|
||||
// Open id config
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
fmt.Fprint(w, `{
|
||||
"issuer":"`+s.url.String()+`",
|
||||
"authorization_endpoint":"`+s.url.String()+`/auth",
|
||||
"token_endpoint":"`+s.url.String()+`/token",
|
||||
"jwks_uri":"`+s.url.String()+`/jwks"
|
||||
}`)
|
||||
} else if r.URL.Path == "/token" {
|
||||
// Token request
|
||||
// Check body
|
||||
if b, ok := s.body["token"]; ok {
|
||||
if b != string(body) {
|
||||
s.t.Fatal("Unexpected request body, expected", b, "got", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
fmt.Fprint(w, `{
|
||||
"access_token":"123456789",
|
||||
"id_token":"id_123456789"
|
||||
}`)
|
||||
} else if r.URL.Path == "/jwks" {
|
||||
// Key request
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
fmt.Fprint(w, `{"keys":[`+s.key.publicJWK(s.t)+`]}`)
|
||||
} else {
|
||||
s.t.Fatal("Unrecognised request: ", r.URL, string(body))
|
||||
}
|
||||
}
|
||||
|
||||
// rsaKey is used in the OIDCServer tests to sign and verify requests
|
||||
type rsaKey struct {
|
||||
key *rsa.PrivateKey
|
||||
alg jose.SignatureAlgorithm
|
||||
jwkPub *jose.JSONWebKey
|
||||
jwkPriv *jose.JSONWebKey
|
||||
}
|
||||
|
||||
func newRSAKey() (*rsaKey, error) {
|
||||
key, err := rsa.GenerateKey(rand.Reader, 1028)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &rsaKey{
|
||||
key: key,
|
||||
alg: jose.RS256,
|
||||
jwkPub: &jose.JSONWebKey{
|
||||
Key: key.Public(),
|
||||
Algorithm: string(jose.RS256),
|
||||
},
|
||||
jwkPriv: &jose.JSONWebKey{
|
||||
Key: key,
|
||||
Algorithm: string(jose.RS256),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (k *rsaKey) publicJWK(t *testing.T) string {
|
||||
b, err := k.jwkPub.MarshalJSON()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// sign creates a JWS using the private key from the provided payload.
|
||||
func (k *rsaKey) sign(t *testing.T, payload []byte) string {
|
||||
signer, err := jose.NewSigner(jose.SigningKey{
|
||||
Algorithm: k.alg,
|
||||
Key: k.key,
|
||||
}, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
jws, err := signer.Sign(payload)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, err := jws.CompactSerialize()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return data
|
||||
}
|
Reference in New Issue
Block a user