Thom Seddon be2b4ba9f4
Remove unused user fields (#141)
These aren't actually used anywhere and can result in a parse error
if the ID field isn't a string
2020-06-29 21:01:59 +01:00

278 lines
6.4 KiB
Go

package provider
import (
"crypto/rand"
"crypto/rsa"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"strconv"
"testing"
"time"
"github.com/stretchr/testify/assert"
jose "gopkg.in/square/go-jose.v2"
)
// Tests
// TestOIDCName asserts that the OIDC provider identifies itself as "oidc".
func TestOIDCName(t *testing.T) {
	provider := OIDC{}
	name := provider.Name()
	assert.Equal(t, "oidc", name)
}
// TestOIDCSetup asserts that calling Setup on a zero-value provider returns
// an error naming the options that must be set.
func TestOIDCSetup(t *testing.T) {
	assert := assert.New(t)

	const wantMsg = "providers.oidc.issuer-url, providers.oidc.client-id, providers.oidc.client-secret must be set"

	var p OIDC
	// Only inspect the message when an error was actually returned
	if err := p.Setup(); assert.Error(err) {
		assert.Equal(wantMsg, err.Error())
	}
}
// TestOIDCGetLoginURL asserts the location and query string of the URL built
// by GetLoginURL, first with the provider as configured by setupOIDCTest and
// then again with the Resource option set.
func TestOIDCGetLoginURL(t *testing.T) {
	assert := assert.New(t)

	provider, server, serverURL, _ := setupOIDCTest(t, nil)
	defer server.Close()

	// verify requests a login URL and checks where it points and which query
	// parameters it carries.
	verify := func(wantQuery url.Values) {
		// Check url
		loginURL, err := url.Parse(provider.GetLoginURL("http://example.com/_oauth", "state"))
		assert.Nil(err)
		assert.Equal(serverURL.Scheme, loginURL.Scheme)
		assert.Equal(serverURL.Host, loginURL.Host)
		assert.Equal("/auth", loginURL.Path)

		// Check query string
		assert.Equal(wantQuery, loginURL.Query())

		// Calling the method should not modify the underlying config
		assert.Equal("", provider.Config.RedirectURL)
	}

	wantQuery := url.Values{
		"client_id":     {"idtest"},
		"redirect_uri":  {"http://example.com/_oauth"},
		"response_type": {"code"},
		"scope":         {"openid profile email"},
		"state":         {"state"},
	}
	verify(wantQuery)

	//
	// Test with resource config option
	//
	provider.Resource = "resourcetest"
	wantQuery.Set("resource", "resourcetest")
	verify(wantQuery)
}
// TestOIDCExchangeCode asserts that ExchangeCode posts the expected form body
// to the mock token endpoint and returns the id_token from its response.
func TestOIDCExchangeCode(t *testing.T) {
	assert := assert.New(t)

	// Form values the mock server expects on its token endpoint
	wantTokenBody := map[string]string{
		"code":         "code",
		"grant_type":   "authorization_code",
		"redirect_uri": "http://example.com/_oauth",
	}

	provider, server, _, _ := setupOIDCTest(t, map[string]map[string]string{
		"token": wantTokenBody,
	})
	defer server.Close()

	idToken, err := provider.ExchangeCode("http://example.com/_oauth", "code")
	assert.Nil(err)
	assert.Equal("id_123456789", idToken)
}
// TestOIDCGetUser signs a JWT with the test key and asserts that GetUser
// accepts it and returns the email claim.
func TestOIDCGetUser(t *testing.T) {
	assert := assert.New(t)

	provider, server, serverURL, key := setupOIDCTest(t, nil)
	defer server.Close()

	// Claims: issued by the mock server, for the test client, expiring in an hour
	issuer := serverURL.String()
	expiry := strconv.FormatInt(time.Now().Add(time.Hour).Unix(), 10)
	claims := `{
"iss": "` + issuer + `",
"exp":` + expiry + `,
"aud": "idtest",
"sub": "1",
"email": "example@example.com",
"email_verified": true
}`

	// Generate JWT
	token := key.sign(t, []byte(claims))

	// Get user
	user, err := provider.GetUser(token)
	assert.Nil(err)
	assert.Equal("example@example.com", user.Email)
}
// Utils
// setupOIDCTest creates a key, an OIDCServer and initialises an OIDC provider
// pointed at that server.
//
// bodyValues maps the name the mock server looks request bodies up under
// (e.g. "token") to the form values expected in that request; each set of
// values is URL encoded before being handed to the server. It may be nil.
//
// The test is aborted if the key cannot be generated or the provider cannot
// be set up. On success the caller is responsible for closing the returned
// server.
func setupOIDCTest(t *testing.T, bodyValues map[string]map[string]string) (*OIDC, *httptest.Server, *url.URL, *rsaKey) {
	// Attribute failures to the calling test rather than to this helper
	t.Helper()

	// Generate key
	key, err := newRSAKey()
	if err != nil {
		t.Fatal(err)
	}

	// URL encode bodyValues into body. Ranging over a nil map is a no-op, so
	// no explicit nil check is needed.
	body := make(map[string]string, len(bodyValues))
	for method, values := range bodyValues {
		q := url.Values{}
		for k, v := range values {
			q.Set(k, v)
		}
		body[method] = q.Encode()
	}

	// Set up oidc server
	server, serverURL := NewOIDCServer(t, key, body)

	// Setup provider
	p := OIDC{
		ClientID:     "idtest",
		ClientSecret: "sectest",
		IssuerURL:    serverURL.String(),
	}

	// Initialise config/verifier
	if err := p.Setup(); err != nil {
		// The caller never receives the server on this path, so close it here
		// instead of leaking the listener.
		server.Close()
		t.Fatal(err)
	}

	return &p, server, serverURL, key
}
// OIDCServer is used in the OIDC Tests to mock an OIDC server. Its ServeHTTP
// method answers the discovery, token and JWKS requests.
type OIDCServer struct {
// t is the test that handler failures are reported to
t *testing.T
// url is the server's own base URL, used to build the endpoint URLs
// advertised in the discovery document
url *url.URL
body map[string]string // endpoint name (e.g. "token") -> expected URL-encoded request body
// key provides the public JWK returned from the /jwks endpoint
key *rsaKey
}
// NewOIDCServer starts an httptest server backed by an OIDCServer handler and
// returns it together with its parsed base URL.
//
// key provides the public JWK served from /jwks, and body holds the expected
// request bodies keyed by endpoint name. The test is aborted if the server's
// URL cannot be parsed. The caller must Close the returned server.
func NewOIDCServer(t *testing.T, key *rsaKey, body map[string]string) (*httptest.Server, *url.URL) {
	handler := &OIDCServer{t: t, key: key, body: body}
	server := httptest.NewServer(handler)

	// The handler needs its own address to build the discovery document, and
	// that address is only known once the server is listening.
	serverURL, err := url.Parse(server.URL)
	if err != nil {
		// Without this check the handler would hold a nil URL and panic on
		// its first discovery request; fail here with the real cause instead.
		server.Close()
		t.Fatal(err)
	}
	handler.url = serverURL

	return server, serverURL
}
// ServeHTTP implements http.Handler. It answers the three endpoints used by
// the tests: the OpenID discovery document, the token endpoint and the JWK
// set. Any other path is reported as a test failure.
//
// Failures are recorded with t.Error plus an HTTP error status rather than
// t.Fatal: the handler runs on the server's goroutine, and Fatal/FailNow must
// only be called from the goroutine running the test function.
func (s *OIDCServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	body, err := ioutil.ReadAll(r.Body)
	if err != nil {
		s.t.Error("Failed to read request body: ", err)
		http.Error(w, "failed to read request body", http.StatusInternalServerError)
		return
	}

	switch r.URL.Path {
	case "/.well-known/openid-configuration":
		// Open id config
		base := s.url.String()
		w.Header().Set("Content-Type", "application/json")
		fmt.Fprint(w, `{
"issuer":"`+base+`",
"authorization_endpoint":"`+base+`/auth",
"token_endpoint":"`+base+`/token",
"jwks_uri":"`+base+`/jwks"
}`)

	case "/token":
		// Token request
		// Check body, but only when the test registered an expectation
		if want, ok := s.body["token"]; ok && want != string(body) {
			s.t.Error("Unexpected request body, expected", want, "got", string(body))
			http.Error(w, "unexpected request body", http.StatusBadRequest)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		fmt.Fprint(w, `{
"access_token":"123456789",
"id_token":"id_123456789"
}`)

	case "/jwks":
		// Key request
		// Marshal directly instead of going through publicJWK, which calls
		// t.Fatal on failure and so must not run on this goroutine.
		jwk, err := s.key.jwkPub.MarshalJSON()
		if err != nil {
			s.t.Error(err)
			http.Error(w, "failed to marshal key", http.StatusInternalServerError)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		fmt.Fprint(w, `{"keys":[`+string(jwk)+`]}`)

	default:
		s.t.Error("Unrecognised request: ", r.URL, string(body))
		http.NotFound(w, r)
	}
}
// rsaKey is used in the OIDCServer tests to sign tokens and to publish the
// matching public key so they can be verified
type rsaKey struct {
// key is the RSA private key handed to the signer
key *rsa.PrivateKey
// alg is the signature algorithm handed to the signer
alg jose.SignatureAlgorithm
// jwkPub wraps the public key as a JWK; publicJWK marshals it to JSON
jwkPub *jose.JSONWebKey
// jwkPriv wraps the private key as a JWK
jwkPriv *jose.JSONWebKey
}
// newRSAKey generates a fresh RSA key pair and wraps it, together with JWK
// representations of its public and private halves, for RS256 signing.
//
// It returns an error if key generation fails.
func newRSAKey() (*rsaKey, error) {
	// 2048 bits is the commonly recommended minimum RSA modulus size. The
	// previous value of 1028 was a non-standard size (most likely a typo).
	const keyBits = 2048

	key, err := rsa.GenerateKey(rand.Reader, keyBits)
	if err != nil {
		return nil, err
	}

	alg := jose.RS256
	return &rsaKey{
		key: key,
		alg: alg,
		jwkPub: &jose.JSONWebKey{
			Key:       key.Public(),
			Algorithm: string(alg),
		},
		jwkPriv: &jose.JSONWebKey{
			Key:       key,
			Algorithm: string(alg),
		},
	}, nil
}
// publicJWK returns the public key's JWK marshalled to a JSON string. The
// test is aborted via t.Fatal if marshalling fails.
func (k *rsaKey) publicJWK(t *testing.T) string {
	encoded, marshalErr := k.jwkPub.MarshalJSON()
	if marshalErr != nil {
		t.Fatal(marshalErr)
	}

	jwk := string(encoded)
	return jwk
}
// sign returns the compact serialization of a JWS over payload, produced
// with the private key and algorithm held by k. Any failure while building
// the signer, signing or serializing aborts the test via t.Fatal.
func (k *rsaKey) sign(t *testing.T, payload []byte) string {
	signingKey := jose.SigningKey{Algorithm: k.alg, Key: k.key}

	signer, err := jose.NewSigner(signingKey, nil)
	if err != nil {
		t.Fatal(err)
	}

	signed, err := signer.Sign(payload)
	if err != nil {
		t.Fatal(err)
	}

	compact, err := signed.CompactSerialize()
	if err != nil {
		t.Fatal(err)
	}

	return compact
}