Multiple provider support + OIDC provider
This commit is contained in:
@ -2,13 +2,15 @@ package provider
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// Google provider
|
||||
type Google struct {
|
||||
ClientId string `long:"client-id" env:"CLIENT_ID" description:"Client ID"`
|
||||
ClientID string `long:"client-id" env:"CLIENT_ID" description:"Client ID"`
|
||||
ClientSecret string `long:"client-secret" env:"CLIENT_SECRET" description:"Client Secret" json:"-"`
|
||||
Scope string
|
||||
Prompt string `long:"prompt" env:"PROMPT" description:"Space separated list of OpenID prompt options"`
|
||||
@ -18,15 +20,48 @@ type Google struct {
|
||||
UserURL *url.URL
|
||||
}
|
||||
|
||||
func (g *Google) GetLoginURL(redirectUri, state string) string {
|
||||
// Name returns the name of the provider
|
||||
func (g *Google) Name() string {
|
||||
return "google"
|
||||
}
|
||||
|
||||
// Setup performs validation and setup
|
||||
func (g *Google) Setup() error {
|
||||
if g.ClientID == "" || g.ClientSecret == "" {
|
||||
return errors.New("providers.google.client-id, providers.google.client-secret must be set")
|
||||
}
|
||||
|
||||
// Set static values
|
||||
g.Scope = "https://www.googleapis.com/auth/userinfo.profile https://www.googleapis.com/auth/userinfo.email"
|
||||
g.LoginURL = &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "accounts.google.com",
|
||||
Path: "/o/oauth2/auth",
|
||||
}
|
||||
g.TokenURL = &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "www.googleapis.com",
|
||||
Path: "/oauth2/v3/token",
|
||||
}
|
||||
g.UserURL = &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "www.googleapis.com",
|
||||
Path: "/oauth2/v2/userinfo",
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetLoginURL provides the login url for the given redirect uri and state
|
||||
func (g *Google) GetLoginURL(redirectURI, state string) string {
|
||||
q := url.Values{}
|
||||
q.Set("client_id", g.ClientId)
|
||||
q.Set("client_id", g.ClientID)
|
||||
q.Set("response_type", "code")
|
||||
q.Set("scope", g.Scope)
|
||||
if g.Prompt != "" {
|
||||
q.Set("prompt", g.Prompt)
|
||||
}
|
||||
q.Set("redirect_uri", redirectUri)
|
||||
q.Set("redirect_uri", redirectURI)
|
||||
q.Set("state", state)
|
||||
|
||||
var u url.URL
|
||||
@ -36,12 +71,13 @@ func (g *Google) GetLoginURL(redirectUri, state string) string {
|
||||
return u.String()
|
||||
}
|
||||
|
||||
func (g *Google) ExchangeCode(redirectUri, code string) (string, error) {
|
||||
// ExchangeCode exchanges the given redirect uri and code for a token
|
||||
func (g *Google) ExchangeCode(redirectURI, code string) (string, error) {
|
||||
form := url.Values{}
|
||||
form.Set("client_id", g.ClientId)
|
||||
form.Set("client_id", g.ClientID)
|
||||
form.Set("client_secret", g.ClientSecret)
|
||||
form.Set("grant_type", "authorization_code")
|
||||
form.Set("redirect_uri", redirectUri)
|
||||
form.Set("redirect_uri", redirectURI)
|
||||
form.Set("code", code)
|
||||
|
||||
res, err := http.PostForm(g.TokenURL.String(), form)
|
||||
@ -49,13 +85,14 @@ func (g *Google) ExchangeCode(redirectUri, code string) (string, error) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var token Token
|
||||
var token token
|
||||
defer res.Body.Close()
|
||||
err = json.NewDecoder(res.Body).Decode(&token)
|
||||
|
||||
return token.Token, err
|
||||
}
|
||||
|
||||
// GetUser uses the given token and returns a complete provider.User object
|
||||
func (g *Google) GetUser(token string) (User, error) {
|
||||
var user User
|
||||
|
||||
|
151
internal/provider/google_test.go
Normal file
151
internal/provider/google_test.go
Normal file
@ -0,0 +1,151 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// Tests
|
||||
|
||||
func TestGoogleName(t *testing.T) {
|
||||
p := Google{}
|
||||
assert.Equal(t, "google", p.Name())
|
||||
}
|
||||
|
||||
func TestGoogleSetup(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
p := Google{}
|
||||
|
||||
// Check validation
|
||||
err := p.Setup()
|
||||
if assert.Error(err) {
|
||||
assert.Equal("providers.google.client-id, providers.google.client-secret must be set", err.Error())
|
||||
}
|
||||
|
||||
// Check setup
|
||||
p = Google{
|
||||
ClientID: "id",
|
||||
ClientSecret: "secret",
|
||||
}
|
||||
err = p.Setup()
|
||||
assert.Nil(err)
|
||||
assert.Equal("https://www.googleapis.com/auth/userinfo.profile https://www.googleapis.com/auth/userinfo.email", p.Scope)
|
||||
assert.Equal("", p.Prompt)
|
||||
|
||||
assert.Equal(&url.URL{
|
||||
Scheme: "https",
|
||||
Host: "accounts.google.com",
|
||||
Path: "/o/oauth2/auth",
|
||||
}, p.LoginURL)
|
||||
|
||||
assert.Equal(&url.URL{
|
||||
Scheme: "https",
|
||||
Host: "www.googleapis.com",
|
||||
Path: "/oauth2/v3/token",
|
||||
}, p.TokenURL)
|
||||
|
||||
assert.Equal(&url.URL{
|
||||
Scheme: "https",
|
||||
Host: "www.googleapis.com",
|
||||
Path: "/oauth2/v2/userinfo",
|
||||
}, p.UserURL)
|
||||
}
|
||||
|
||||
func TestGoogleGetLoginURL(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
p := Google{
|
||||
ClientID: "idtest",
|
||||
ClientSecret: "sectest",
|
||||
Scope: "scopetest",
|
||||
Prompt: "consent select_account",
|
||||
LoginURL: &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "google.com",
|
||||
Path: "/auth",
|
||||
},
|
||||
}
|
||||
|
||||
// Check url
|
||||
uri, err := url.Parse(p.GetLoginURL("http://example.com/_oauth", "state"))
|
||||
assert.Nil(err)
|
||||
assert.Equal("https", uri.Scheme)
|
||||
assert.Equal("google.com", uri.Host)
|
||||
assert.Equal("/auth", uri.Path)
|
||||
|
||||
// Check query string
|
||||
qs := uri.Query()
|
||||
expectedQs := url.Values{
|
||||
"client_id": []string{"idtest"},
|
||||
"redirect_uri": []string{"http://example.com/_oauth"},
|
||||
"response_type": []string{"code"},
|
||||
"scope": []string{"scopetest"},
|
||||
"prompt": []string{"consent select_account"},
|
||||
"state": []string{"state"},
|
||||
}
|
||||
assert.Equal(expectedQs, qs)
|
||||
}
|
||||
|
||||
func TestGoogleExchangeCode(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
// Setup server
|
||||
expected := url.Values{
|
||||
"client_id": []string{"idtest"},
|
||||
"client_secret": []string{"sectest"},
|
||||
"code": []string{"code"},
|
||||
"grant_type": []string{"authorization_code"},
|
||||
"redirect_uri": []string{"http://example.com/_oauth"},
|
||||
}
|
||||
server, serverURL := NewOAuthServer(t, map[string]string{
|
||||
"token": expected.Encode(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
// Setup provider
|
||||
p := Google{
|
||||
ClientID: "idtest",
|
||||
ClientSecret: "sectest",
|
||||
Scope: "scopetest",
|
||||
Prompt: "consent select_account",
|
||||
TokenURL: &url.URL{
|
||||
Scheme: serverURL.Scheme,
|
||||
Host: serverURL.Host,
|
||||
Path: "/token",
|
||||
},
|
||||
}
|
||||
|
||||
token, err := p.ExchangeCode("http://example.com/_oauth", "code")
|
||||
assert.Nil(err)
|
||||
assert.Equal("123456789", token)
|
||||
}
|
||||
|
||||
func TestGoogleGetUser(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
// Setup server
|
||||
server, serverURL := NewOAuthServer(t, nil)
|
||||
defer server.Close()
|
||||
|
||||
// Setup provider
|
||||
p := Google{
|
||||
ClientID: "idtest",
|
||||
ClientSecret: "sectest",
|
||||
Scope: "scopetest",
|
||||
Prompt: "consent select_account",
|
||||
UserURL: &url.URL{
|
||||
Scheme: serverURL.Scheme,
|
||||
Host: serverURL.Host,
|
||||
Path: "/userinfo",
|
||||
},
|
||||
}
|
||||
|
||||
user, err := p.GetUser("123456789")
|
||||
assert.Nil(err)
|
||||
|
||||
assert.Equal("1", user.ID)
|
||||
assert.Equal("example@example.com", user.Email)
|
||||
assert.True(user.Verified)
|
||||
assert.Equal("example.com", user.Hd)
|
||||
}
|
108
internal/provider/oidc.go
Normal file
108
internal/provider/oidc.go
Normal file
@ -0,0 +1,108 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/coreos/go-oidc"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// OIDC provider
|
||||
type OIDC struct {
|
||||
OAuthProvider
|
||||
|
||||
IssuerURL string `long:"issuer-url" env:"ISSUER_URL" description:"Issuer URL"`
|
||||
ClientID string `long:"client-id" env:"CLIENT_ID" description:"Client ID"`
|
||||
ClientSecret string `long:"client-secret" env:"CLIENT_SECRET" description:"Client Secret" json:"-"`
|
||||
|
||||
provider *oidc.Provider
|
||||
verifier *oidc.IDTokenVerifier
|
||||
}
|
||||
|
||||
// Name returns the name of the provider
|
||||
func (o *OIDC) Name() string {
|
||||
return "oidc"
|
||||
}
|
||||
|
||||
// Setup performs validation and setup
|
||||
func (o *OIDC) Setup() error {
|
||||
// Check parms
|
||||
if o.IssuerURL == "" || o.ClientID == "" || o.ClientSecret == "" {
|
||||
return errors.New("providers.oidc.issuer-url, providers.oidc.client-id, providers.oidc.client-secret must be set")
|
||||
}
|
||||
|
||||
var err error
|
||||
o.ctx = context.Background()
|
||||
|
||||
// Try to initiate provider
|
||||
o.provider, err = oidc.NewProvider(o.ctx, o.IssuerURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create oauth2 config
|
||||
o.Config = &oauth2.Config{
|
||||
ClientID: o.ClientID,
|
||||
ClientSecret: o.ClientSecret,
|
||||
Endpoint: o.provider.Endpoint(),
|
||||
|
||||
// "openid" is a required scope for OpenID Connect flows.
|
||||
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||
}
|
||||
|
||||
// Create OIDC verifier
|
||||
o.verifier = o.provider.Verifier(&oidc.Config{
|
||||
ClientID: o.ClientID,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetLoginURL provides the login url for the given redirect uri and state
|
||||
func (o *OIDC) GetLoginURL(redirectURI, state string) string {
|
||||
return o.OAuthGetLoginURL(redirectURI, state)
|
||||
}
|
||||
|
||||
// ExchangeCode exchanges the given redirect uri and code for a token
|
||||
func (o *OIDC) ExchangeCode(redirectURI, code string) (string, error) {
|
||||
token, err := o.OAuthExchangeCode(redirectURI, code)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Extract ID token
|
||||
rawIDToken, ok := token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
return "", errors.New("Missing id_token")
|
||||
}
|
||||
|
||||
return rawIDToken, nil
|
||||
}
|
||||
|
||||
// GetUser uses the given token and returns a complete provider.User object
|
||||
func (o *OIDC) GetUser(token string) (User, error) {
|
||||
var user User
|
||||
|
||||
// Parse & Verify ID Token
|
||||
idToken, err := o.verifier.Verify(o.ctx, token)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
|
||||
// Extract custom claims
|
||||
var claims struct {
|
||||
ID string `json:"sub"`
|
||||
Email string `json:"email"`
|
||||
Verified bool `json:"email_verified"`
|
||||
}
|
||||
if err := idToken.Claims(&claims); err != nil {
|
||||
return user, err
|
||||
}
|
||||
|
||||
user.ID = claims.ID
|
||||
user.Email = claims.Email
|
||||
user.Verified = claims.Verified
|
||||
|
||||
return user, nil
|
||||
}
|
252
internal/provider/oidc_test.go
Normal file
252
internal/provider/oidc_test.go
Normal file
@ -0,0 +1,252 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
jose "gopkg.in/square/go-jose.v2"
|
||||
)
|
||||
|
||||
// Tests
|
||||
|
||||
func TestOIDCName(t *testing.T) {
|
||||
p := OIDC{}
|
||||
assert.Equal(t, "oidc", p.Name())
|
||||
}
|
||||
|
||||
func TestOIDCSetup(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
p := OIDC{}
|
||||
|
||||
err := p.Setup()
|
||||
if assert.Error(err) {
|
||||
assert.Equal("providers.oidc.issuer-url, providers.oidc.client-id, providers.oidc.client-secret must be set", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDCGetLoginURL(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
provider, server, serverURL, _ := setupOIDCTest(t, nil)
|
||||
defer server.Close()
|
||||
|
||||
// Check url
|
||||
uri, err := url.Parse(provider.GetLoginURL("http://example.com/_oauth", "state"))
|
||||
assert.Nil(err)
|
||||
assert.Equal(serverURL.Scheme, uri.Scheme)
|
||||
assert.Equal(serverURL.Host, uri.Host)
|
||||
assert.Equal("/auth", uri.Path)
|
||||
|
||||
// Check query string
|
||||
qs := uri.Query()
|
||||
expectedQs := url.Values{
|
||||
"client_id": []string{"idtest"},
|
||||
"redirect_uri": []string{"http://example.com/_oauth"},
|
||||
"response_type": []string{"code"},
|
||||
"scope": []string{"openid profile email"},
|
||||
"state": []string{"state"},
|
||||
}
|
||||
assert.Equal(expectedQs, qs)
|
||||
|
||||
// Calling the method should not modify the underlying config
|
||||
assert.Equal("", provider.Config.RedirectURL)
|
||||
}
|
||||
|
||||
func TestOIDCExchangeCode(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
provider, server, _, _ := setupOIDCTest(t, map[string]map[string]string{
|
||||
"token": {
|
||||
"code": "code",
|
||||
"grant_type": "authorization_code",
|
||||
"redirect_uri": "http://example.com/_oauth",
|
||||
},
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
token, err := provider.ExchangeCode("http://example.com/_oauth", "code")
|
||||
assert.Nil(err)
|
||||
assert.Equal("id_123456789", token)
|
||||
}
|
||||
|
||||
func TestOIDCGetUser(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
provider, server, serverURL, key := setupOIDCTest(t, nil)
|
||||
defer server.Close()
|
||||
|
||||
// Generate JWT
|
||||
token := key.sign(t, []byte(`{
|
||||
"iss": "`+serverURL.String()+`",
|
||||
"exp":`+strconv.FormatInt(time.Now().Add(time.Hour).Unix(), 10)+`,
|
||||
"aud": "idtest",
|
||||
"sub": "1",
|
||||
"email": "example@example.com",
|
||||
"email_verified": true
|
||||
}`))
|
||||
|
||||
// Get user
|
||||
user, err := provider.GetUser(token)
|
||||
assert.Nil(err)
|
||||
assert.Equal("1", user.ID)
|
||||
assert.Equal("example@example.com", user.Email)
|
||||
assert.True(user.Verified)
|
||||
}
|
||||
|
||||
// Utils
|
||||
|
||||
// setOIDCTest creates a key, OIDCServer and initilises an OIDC provider
|
||||
func setupOIDCTest(t *testing.T, bodyValues map[string]map[string]string) (*OIDC, *httptest.Server, *url.URL, *rsaKey) {
|
||||
// Generate key
|
||||
key, err := newRSAKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
body := make(map[string]string)
|
||||
if bodyValues != nil {
|
||||
// URL encode bodyValues into body
|
||||
for method, values := range bodyValues {
|
||||
q := url.Values{}
|
||||
for k, v := range values {
|
||||
q.Set(k, v)
|
||||
}
|
||||
body[method] = q.Encode()
|
||||
}
|
||||
}
|
||||
|
||||
// Set up oidc server
|
||||
server, serverURL := NewOIDCServer(t, key, body)
|
||||
|
||||
// Setup provider
|
||||
p := OIDC{
|
||||
ClientID: "idtest",
|
||||
ClientSecret: "sectest",
|
||||
IssuerURL: serverURL.String(),
|
||||
}
|
||||
|
||||
// Initialise config/verifier
|
||||
err = p.Setup()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return &p, server, serverURL, key
|
||||
}
|
||||
|
||||
// OIDCServer is used in the OIDC Tests to mock an OIDC server
|
||||
type OIDCServer struct {
|
||||
t *testing.T
|
||||
url *url.URL
|
||||
body map[string]string // method -> body
|
||||
key *rsaKey
|
||||
}
|
||||
|
||||
func NewOIDCServer(t *testing.T, key *rsaKey, body map[string]string) (*httptest.Server, *url.URL) {
|
||||
handler := &OIDCServer{t: t, key: key, body: body}
|
||||
server := httptest.NewServer(handler)
|
||||
handler.url, _ = url.Parse(server.URL)
|
||||
return server, handler.url
|
||||
}
|
||||
|
||||
func (s *OIDCServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := ioutil.ReadAll(r.Body)
|
||||
|
||||
if r.URL.Path == "/.well-known/openid-configuration" {
|
||||
// Open id config
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
fmt.Fprint(w, `{
|
||||
"issuer":"`+s.url.String()+`",
|
||||
"authorization_endpoint":"`+s.url.String()+`/auth",
|
||||
"token_endpoint":"`+s.url.String()+`/token",
|
||||
"jwks_uri":"`+s.url.String()+`/jwks"
|
||||
}`)
|
||||
} else if r.URL.Path == "/token" {
|
||||
// Token request
|
||||
// Check body
|
||||
if b, ok := s.body["token"]; ok {
|
||||
if b != string(body) {
|
||||
s.t.Fatal("Unexpected request body, expected", b, "got", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
fmt.Fprint(w, `{
|
||||
"access_token":"123456789",
|
||||
"id_token":"id_123456789"
|
||||
}`)
|
||||
} else if r.URL.Path == "/jwks" {
|
||||
// Key request
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
fmt.Fprint(w, `{"keys":[`+s.key.publicJWK(s.t)+`]}`)
|
||||
} else {
|
||||
s.t.Fatal("Unrecognised request: ", r.URL, string(body))
|
||||
}
|
||||
}
|
||||
|
||||
// rsaKey is used in the OIDCServer tests to sign and verify requests
|
||||
type rsaKey struct {
|
||||
key *rsa.PrivateKey
|
||||
alg jose.SignatureAlgorithm
|
||||
jwkPub *jose.JSONWebKey
|
||||
jwkPriv *jose.JSONWebKey
|
||||
}
|
||||
|
||||
func newRSAKey() (*rsaKey, error) {
|
||||
key, err := rsa.GenerateKey(rand.Reader, 1028)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &rsaKey{
|
||||
key: key,
|
||||
alg: jose.RS256,
|
||||
jwkPub: &jose.JSONWebKey{
|
||||
Key: key.Public(),
|
||||
Algorithm: string(jose.RS256),
|
||||
},
|
||||
jwkPriv: &jose.JSONWebKey{
|
||||
Key: key,
|
||||
Algorithm: string(jose.RS256),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (k *rsaKey) publicJWK(t *testing.T) string {
|
||||
b, err := k.jwkPub.MarshalJSON()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// sign creates a JWS using the private key from the provided payload.
|
||||
func (k *rsaKey) sign(t *testing.T, payload []byte) string {
|
||||
signer, err := jose.NewSigner(jose.SigningKey{
|
||||
Algorithm: k.alg,
|
||||
Key: k.key,
|
||||
}, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
jws, err := signer.Sign(payload)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data, err := jws.CompactSerialize()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return data
|
||||
}
|
@ -1,16 +1,61 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
// "net/url"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// Providers contains all the implemented providers
|
||||
type Providers struct {
|
||||
Google Google `group:"Google Provider" namespace:"google" env-namespace:"GOOGLE"`
|
||||
OIDC OIDC `group:"OIDC Provider" namespace:"oidc" env-namespace:"OIDC"`
|
||||
}
|
||||
|
||||
type Token struct {
|
||||
// Provider is used to authenticate users
|
||||
type Provider interface {
|
||||
Name() string
|
||||
GetLoginURL(redirectURI, state string) string
|
||||
ExchangeCode(redirectURI, code string) (string, error)
|
||||
GetUser(token string) (User, error)
|
||||
Setup() error
|
||||
}
|
||||
|
||||
type token struct {
|
||||
Token string `json:"access_token"`
|
||||
}
|
||||
|
||||
// User is the authenticated user
|
||||
type User struct {
|
||||
Id string `json:"id"`
|
||||
ID string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Verified bool `json:"verified_email"`
|
||||
Hd string `json:"hd"`
|
||||
}
|
||||
|
||||
// OAuthProvider is a provider using the oauth2 library
|
||||
type OAuthProvider struct {
|
||||
Config *oauth2.Config
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// ConfigCopy returns a copy of the oauth2 config with the given redirectURI
|
||||
// which ensures the underlying config is not modified
|
||||
func (p *OAuthProvider) ConfigCopy(redirectURI string) oauth2.Config {
|
||||
config := *p.Config
|
||||
config.RedirectURL = redirectURI
|
||||
return config
|
||||
}
|
||||
|
||||
// OAuthGetLoginURL provides a base "GetLoginURL" for proiders using OAauth2
|
||||
func (p *OAuthProvider) OAuthGetLoginURL(redirectURI, state string) string {
|
||||
config := p.ConfigCopy(redirectURI)
|
||||
return config.AuthCodeURL(state)
|
||||
}
|
||||
|
||||
// OAuthExchangeCode provides a base "ExchangeCode" for proiders using OAauth2
|
||||
func (p *OAuthProvider) OAuthExchangeCode(redirectURI, code string) (*oauth2.Token, error) {
|
||||
config := p.ConfigCopy(redirectURI)
|
||||
return config.Exchange(p.ctx, code)
|
||||
}
|
||||
|
48
internal/provider/providers_test.go
Normal file
48
internal/provider/providers_test.go
Normal file
@ -0,0 +1,48 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Utilities
|
||||
|
||||
type OAuthServer struct {
|
||||
t *testing.T
|
||||
url *url.URL
|
||||
body map[string]string // method -> body
|
||||
}
|
||||
|
||||
func NewOAuthServer(t *testing.T, body map[string]string) (*httptest.Server, *url.URL) {
|
||||
handler := &OAuthServer{t: t, body: body}
|
||||
server := httptest.NewServer(handler)
|
||||
handler.url, _ = url.Parse(server.URL)
|
||||
return server, handler.url
|
||||
}
|
||||
|
||||
func (s *OAuthServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := ioutil.ReadAll(r.Body)
|
||||
// fmt.Println("Got request:", r.URL, r.Method, string(body))
|
||||
|
||||
if r.Method == "POST" && r.URL.Path == "/token" {
|
||||
if s.body["token"] != string(body) {
|
||||
s.t.Fatal("Unexpected request body, expected", s.body["token"], "got", string(body))
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
fmt.Fprintf(w, `{"access_token":"123456789"}`)
|
||||
} else if r.Method == "GET" && r.URL.Path == "/userinfo" {
|
||||
fmt.Fprint(w, `{
|
||||
"id":"1",
|
||||
"email":"example@example.com",
|
||||
"verified_email":true,
|
||||
"hd":"example.com"
|
||||
}`)
|
||||
} else {
|
||||
s.t.Fatal("Unrecognised request: ", r.Method, r.URL, string(body))
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user