2019-09-18 17:55:52 +01:00
|
|
|
package provider
|
|
|
|
|
|
|
|
import (
|
|
|
|
"crypto/rand"
|
|
|
|
"crypto/rsa"
|
|
|
|
"fmt"
|
|
|
|
"io/ioutil"
|
|
|
|
"net/http"
|
|
|
|
"net/http/httptest"
|
|
|
|
"net/url"
|
|
|
|
"strconv"
|
|
|
|
"testing"
|
|
|
|
"time"
|
|
|
|
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
|
|
jose "gopkg.in/square/go-jose.v2"
|
|
|
|
)
|
|
|
|
|
|
|
|
// Tests
|
|
|
|
|
|
|
|
func TestOIDCName(t *testing.T) {
|
|
|
|
p := OIDC{}
|
|
|
|
assert.Equal(t, "oidc", p.Name())
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestOIDCSetup(t *testing.T) {
|
|
|
|
assert := assert.New(t)
|
|
|
|
p := OIDC{}
|
|
|
|
|
|
|
|
err := p.Setup()
|
|
|
|
if assert.Error(err) {
|
|
|
|
assert.Equal("providers.oidc.issuer-url, providers.oidc.client-id, providers.oidc.client-secret must be set", err.Error())
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestOIDCGetLoginURL(t *testing.T) {
|
|
|
|
assert := assert.New(t)
|
|
|
|
|
|
|
|
provider, server, serverURL, _ := setupOIDCTest(t, nil)
|
|
|
|
defer server.Close()
|
|
|
|
|
|
|
|
// Check url
|
|
|
|
uri, err := url.Parse(provider.GetLoginURL("http://example.com/_oauth", "state"))
|
|
|
|
assert.Nil(err)
|
|
|
|
assert.Equal(serverURL.Scheme, uri.Scheme)
|
|
|
|
assert.Equal(serverURL.Host, uri.Host)
|
|
|
|
assert.Equal("/auth", uri.Path)
|
|
|
|
|
|
|
|
// Check query string
|
|
|
|
qs := uri.Query()
|
|
|
|
expectedQs := url.Values{
|
|
|
|
"client_id": []string{"idtest"},
|
|
|
|
"redirect_uri": []string{"http://example.com/_oauth"},
|
|
|
|
"response_type": []string{"code"},
|
|
|
|
"scope": []string{"openid profile email"},
|
|
|
|
"state": []string{"state"},
|
|
|
|
}
|
|
|
|
assert.Equal(expectedQs, qs)
|
|
|
|
|
|
|
|
// Calling the method should not modify the underlying config
|
|
|
|
assert.Equal("", provider.Config.RedirectURL)
|
2020-06-11 12:24:51 +01:00
|
|
|
|
|
|
|
//
|
|
|
|
// Test with resource config option
|
|
|
|
//
|
|
|
|
provider.Resource = "resourcetest"
|
|
|
|
|
|
|
|
// Check url
|
|
|
|
uri, err = url.Parse(provider.GetLoginURL("http://example.com/_oauth", "state"))
|
|
|
|
assert.Nil(err)
|
|
|
|
assert.Equal(serverURL.Scheme, uri.Scheme)
|
|
|
|
assert.Equal(serverURL.Host, uri.Host)
|
|
|
|
assert.Equal("/auth", uri.Path)
|
|
|
|
|
|
|
|
// Check query string
|
|
|
|
qs = uri.Query()
|
|
|
|
expectedQs = url.Values{
|
|
|
|
"client_id": []string{"idtest"},
|
|
|
|
"redirect_uri": []string{"http://example.com/_oauth"},
|
|
|
|
"response_type": []string{"code"},
|
|
|
|
"scope": []string{"openid profile email"},
|
|
|
|
"state": []string{"state"},
|
|
|
|
"resource": []string{"resourcetest"},
|
|
|
|
}
|
|
|
|
assert.Equal(expectedQs, qs)
|
|
|
|
|
|
|
|
// Calling the method should not modify the underlying config
|
|
|
|
assert.Equal("", provider.Config.RedirectURL)
|
2019-09-18 17:55:52 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
func TestOIDCExchangeCode(t *testing.T) {
|
|
|
|
assert := assert.New(t)
|
|
|
|
|
|
|
|
provider, server, _, _ := setupOIDCTest(t, map[string]map[string]string{
|
|
|
|
"token": {
|
|
|
|
"code": "code",
|
|
|
|
"grant_type": "authorization_code",
|
|
|
|
"redirect_uri": "http://example.com/_oauth",
|
|
|
|
},
|
|
|
|
})
|
|
|
|
defer server.Close()
|
|
|
|
|
|
|
|
token, err := provider.ExchangeCode("http://example.com/_oauth", "code")
|
|
|
|
assert.Nil(err)
|
|
|
|
assert.Equal("id_123456789", token)
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestOIDCGetUser(t *testing.T) {
|
|
|
|
assert := assert.New(t)
|
|
|
|
|
|
|
|
provider, server, serverURL, key := setupOIDCTest(t, nil)
|
|
|
|
defer server.Close()
|
|
|
|
|
|
|
|
// Generate JWT
|
|
|
|
token := key.sign(t, []byte(`{
|
|
|
|
"iss": "`+serverURL.String()+`",
|
|
|
|
"exp":`+strconv.FormatInt(time.Now().Add(time.Hour).Unix(), 10)+`,
|
|
|
|
"aud": "idtest",
|
|
|
|
"sub": "1",
|
|
|
|
"email": "example@example.com",
|
|
|
|
"email_verified": true
|
|
|
|
}`))
|
|
|
|
|
|
|
|
// Get user
|
|
|
|
user, err := provider.GetUser(token)
|
|
|
|
assert.Nil(err)
|
|
|
|
assert.Equal("1", user.ID)
|
|
|
|
assert.Equal("example@example.com", user.Email)
|
|
|
|
assert.True(user.Verified)
|
|
|
|
}
|
|
|
|
|
|
|
|
// Utils
|
|
|
|
|
|
|
|
// setOIDCTest creates a key, OIDCServer and initilises an OIDC provider
|
|
|
|
func setupOIDCTest(t *testing.T, bodyValues map[string]map[string]string) (*OIDC, *httptest.Server, *url.URL, *rsaKey) {
|
|
|
|
// Generate key
|
|
|
|
key, err := newRSAKey()
|
|
|
|
if err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
}
|
|
|
|
|
|
|
|
body := make(map[string]string)
|
|
|
|
if bodyValues != nil {
|
|
|
|
// URL encode bodyValues into body
|
|
|
|
for method, values := range bodyValues {
|
|
|
|
q := url.Values{}
|
|
|
|
for k, v := range values {
|
|
|
|
q.Set(k, v)
|
|
|
|
}
|
|
|
|
body[method] = q.Encode()
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Set up oidc server
|
|
|
|
server, serverURL := NewOIDCServer(t, key, body)
|
|
|
|
|
|
|
|
// Setup provider
|
|
|
|
p := OIDC{
|
|
|
|
ClientID: "idtest",
|
|
|
|
ClientSecret: "sectest",
|
|
|
|
IssuerURL: serverURL.String(),
|
|
|
|
}
|
|
|
|
|
|
|
|
// Initialise config/verifier
|
|
|
|
err = p.Setup()
|
|
|
|
if err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
}
|
|
|
|
|
|
|
|
return &p, server, serverURL, key
|
|
|
|
}
|
|
|
|
|
|
|
|
// OIDCServer is used in the OIDC Tests to mock an OIDC server
|
|
|
|
type OIDCServer struct {
|
|
|
|
t *testing.T
|
|
|
|
url *url.URL
|
|
|
|
body map[string]string // method -> body
|
|
|
|
key *rsaKey
|
|
|
|
}
|
|
|
|
|
|
|
|
func NewOIDCServer(t *testing.T, key *rsaKey, body map[string]string) (*httptest.Server, *url.URL) {
|
|
|
|
handler := &OIDCServer{t: t, key: key, body: body}
|
|
|
|
server := httptest.NewServer(handler)
|
|
|
|
handler.url, _ = url.Parse(server.URL)
|
|
|
|
return server, handler.url
|
|
|
|
}
|
|
|
|
|
|
|
|
func (s *OIDCServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
|
|
body, _ := ioutil.ReadAll(r.Body)
|
|
|
|
|
|
|
|
if r.URL.Path == "/.well-known/openid-configuration" {
|
|
|
|
// Open id config
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
|
|
fmt.Fprint(w, `{
|
|
|
|
"issuer":"`+s.url.String()+`",
|
|
|
|
"authorization_endpoint":"`+s.url.String()+`/auth",
|
|
|
|
"token_endpoint":"`+s.url.String()+`/token",
|
|
|
|
"jwks_uri":"`+s.url.String()+`/jwks"
|
|
|
|
}`)
|
|
|
|
} else if r.URL.Path == "/token" {
|
|
|
|
// Token request
|
|
|
|
// Check body
|
|
|
|
if b, ok := s.body["token"]; ok {
|
|
|
|
if b != string(body) {
|
|
|
|
s.t.Fatal("Unexpected request body, expected", b, "got", string(body))
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
|
|
fmt.Fprint(w, `{
|
|
|
|
"access_token":"123456789",
|
|
|
|
"id_token":"id_123456789"
|
|
|
|
}`)
|
|
|
|
} else if r.URL.Path == "/jwks" {
|
|
|
|
// Key request
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
|
|
fmt.Fprint(w, `{"keys":[`+s.key.publicJWK(s.t)+`]}`)
|
|
|
|
} else {
|
|
|
|
s.t.Fatal("Unrecognised request: ", r.URL, string(body))
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// rsaKey is used in the OIDCServer tests to sign and verify requests
|
|
|
|
type rsaKey struct {
|
|
|
|
key *rsa.PrivateKey
|
|
|
|
alg jose.SignatureAlgorithm
|
|
|
|
jwkPub *jose.JSONWebKey
|
|
|
|
jwkPriv *jose.JSONWebKey
|
|
|
|
}
|
|
|
|
|
|
|
|
func newRSAKey() (*rsaKey, error) {
|
|
|
|
key, err := rsa.GenerateKey(rand.Reader, 1028)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
return &rsaKey{
|
|
|
|
key: key,
|
|
|
|
alg: jose.RS256,
|
|
|
|
jwkPub: &jose.JSONWebKey{
|
|
|
|
Key: key.Public(),
|
|
|
|
Algorithm: string(jose.RS256),
|
|
|
|
},
|
|
|
|
jwkPriv: &jose.JSONWebKey{
|
|
|
|
Key: key,
|
|
|
|
Algorithm: string(jose.RS256),
|
|
|
|
},
|
|
|
|
}, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (k *rsaKey) publicJWK(t *testing.T) string {
|
|
|
|
b, err := k.jwkPub.MarshalJSON()
|
|
|
|
if err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
}
|
|
|
|
|
|
|
|
return string(b)
|
|
|
|
}
|
|
|
|
|
|
|
|
// sign creates a JWS using the private key from the provided payload.
|
|
|
|
func (k *rsaKey) sign(t *testing.T, payload []byte) string {
|
|
|
|
signer, err := jose.NewSigner(jose.SigningKey{
|
|
|
|
Algorithm: k.alg,
|
|
|
|
Key: k.key,
|
|
|
|
}, nil)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
}
|
|
|
|
jws, err := signer.Sign(payload)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
}
|
|
|
|
|
|
|
|
data, err := jws.CompactSerialize()
|
|
|
|
if err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
}
|
|
|
|
return data
|
|
|
|
}
|