Add Generic OAuth Provider (#138)
This commit is contained in:
@ -316,6 +316,11 @@ func TestMakeState(t *testing.T) {
|
||||
p2 := provider.OIDC{}
|
||||
state = MakeState(r, &p2, "nonce")
|
||||
assert.Equal("nonce:oidc:http://example.com/hello", state)
|
||||
|
||||
// Test with Generic OAuth
|
||||
p3 := provider.GenericOAuth{}
|
||||
state = MakeState(r, &p3, "nonce")
|
||||
assert.Equal("nonce:generic-oauth:http://example.com/hello", state)
|
||||
}
|
||||
|
||||
func TestAuthNonce(t *testing.T) {
|
||||
|
@ -31,7 +31,7 @@ type Config struct {
|
||||
CookieName string `long:"cookie-name" env:"COOKIE_NAME" default:"_forward_auth" description:"Cookie Name"`
|
||||
CSRFCookieName string `long:"csrf-cookie-name" env:"CSRF_COOKIE_NAME" default:"_forward_auth_csrf" description:"CSRF Cookie Name"`
|
||||
DefaultAction string `long:"default-action" env:"DEFAULT_ACTION" default:"auth" choice:"auth" choice:"allow" description:"Default action"`
|
||||
DefaultProvider string `long:"default-provider" env:"DEFAULT_PROVIDER" default:"google" choice:"google" choice:"oidc" description:"Default provider"`
|
||||
DefaultProvider string `long:"default-provider" env:"DEFAULT_PROVIDER" default:"google" choice:"google" choice:"oidc" choice:"generic-oauth" description:"Default provider"`
|
||||
Domains CommaSeparatedList `long:"domain" env:"DOMAIN" env-delim:"," description:"Only allow given email domains, can be set multiple times"`
|
||||
LifetimeString int `long:"lifetime" env:"LIFETIME" default:"43200" description:"Lifetime in seconds"`
|
||||
LogoutRedirect string `long:"logout-redirect" env:"LOGOUT_REDIRECT" description:"URL to redirect to following logout"`
|
||||
@ -275,6 +275,8 @@ func (c *Config) GetProvider(name string) (provider.Provider, error) {
|
||||
return &c.Providers.Google, nil
|
||||
case "oidc":
|
||||
return &c.Providers.OIDC, nil
|
||||
case "generic-oauth":
|
||||
return &c.Providers.GenericOAuth, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("Unknown provider: %s", name)
|
||||
|
@ -366,6 +366,11 @@ func TestConfigGetProvider(t *testing.T) {
|
||||
assert.Nil(err)
|
||||
assert.Equal(&c.Providers.OIDC, p)
|
||||
|
||||
// Should be able to get "generic-oauth" provider
|
||||
p, err = c.GetProvider("generic-oauth")
|
||||
assert.Nil(err)
|
||||
assert.Equal(&c.Providers.GenericOAuth, p)
|
||||
|
||||
// Should catch unknown provider
|
||||
p, err = c.GetProvider("bad")
|
||||
if assert.Error(err) {
|
||||
|
96
internal/provider/generic_oauth.go
Normal file
96
internal/provider/generic_oauth.go
Normal file
@ -0,0 +1,96 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// GenericOAuth provider
|
||||
type GenericOAuth struct {
|
||||
AuthURL string `long:"auth-url" env:"AUTH_URL" description:"Auth/Login URL"`
|
||||
TokenURL string `long:"token-url" env:"TOKEN_URL" description:"Token URL"`
|
||||
UserURL string `long:"user-url" env:"USER_URL" description:"URL used to retrieve user info"`
|
||||
ClientID string `long:"client-id" env:"CLIENT_ID" description:"Client ID"`
|
||||
ClientSecret string `long:"client-secret" env:"CLIENT_SECRET" description:"Client Secret" json:"-"`
|
||||
Scopes []string `long:"scope" env:"SCOPE" env-delim:"," default:"profile" default:"email" description:"Scopes"`
|
||||
TokenStyle string `long:"token-style" env:"TOKEN_STYLE" default:"header" choice:"header" choice:"query" description:"How token is presented when querying the User URL"`
|
||||
|
||||
OAuthProvider
|
||||
}
|
||||
|
||||
// Name returns the name of the provider
|
||||
func (o *GenericOAuth) Name() string {
|
||||
return "generic-oauth"
|
||||
}
|
||||
|
||||
// Setup performs validation and setup
|
||||
func (o *GenericOAuth) Setup() error {
|
||||
// Check parmas
|
||||
if o.AuthURL == "" || o.TokenURL == "" || o.UserURL == "" || o.ClientID == "" || o.ClientSecret == "" {
|
||||
return errors.New("providers.generic-oauth.auth-url, providers.generic-oauth.token-url, providers.generic-oauth.user-url, providers.generic-oauth.client-id, providers.generic-oauth.client-secret must be set")
|
||||
}
|
||||
|
||||
// Create oauth2 config
|
||||
o.Config = &oauth2.Config{
|
||||
ClientID: o.ClientID,
|
||||
ClientSecret: o.ClientSecret,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: o.AuthURL,
|
||||
TokenURL: o.TokenURL,
|
||||
},
|
||||
Scopes: o.Scopes,
|
||||
}
|
||||
|
||||
o.ctx = context.Background()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetLoginURL provides the login url for the given redirect uri and state
|
||||
func (o *GenericOAuth) GetLoginURL(redirectURI, state string) string {
|
||||
return o.OAuthGetLoginURL(redirectURI, state)
|
||||
}
|
||||
|
||||
// ExchangeCode exchanges the given redirect uri and code for a token
|
||||
func (o *GenericOAuth) ExchangeCode(redirectURI, code string) (string, error) {
|
||||
token, err := o.OAuthExchangeCode(redirectURI, code)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return token.AccessToken, nil
|
||||
}
|
||||
|
||||
// GetUser uses the given token and returns a complete provider.User object
|
||||
func (o *GenericOAuth) GetUser(token string) (User, error) {
|
||||
var user User
|
||||
|
||||
req, err := http.NewRequest("GET", o.UserURL, nil)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
|
||||
if o.TokenStyle == "header" {
|
||||
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", token))
|
||||
} else if o.TokenStyle == "query" {
|
||||
q := req.URL.Query()
|
||||
q.Add("access_token", token)
|
||||
req.URL.RawQuery = q.Encode()
|
||||
}
|
||||
|
||||
client := &http.Client{}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
|
||||
defer res.Body.Close()
|
||||
err = json.NewDecoder(res.Body).Decode(&user)
|
||||
|
||||
return user, err
|
||||
}
|
140
internal/provider/generic_oauth_test.go
Normal file
140
internal/provider/generic_oauth_test.go
Normal file
@ -0,0 +1,140 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// Tests
|
||||
|
||||
func TestGenericOAuthName(t *testing.T) {
|
||||
p := GenericOAuth{}
|
||||
assert.Equal(t, "generic-oauth", p.Name())
|
||||
}
|
||||
|
||||
func TestGenericOAuthSetup(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
p := GenericOAuth{}
|
||||
|
||||
// Check validation
|
||||
err := p.Setup()
|
||||
if assert.Error(err) {
|
||||
assert.Equal("providers.generic-oauth.auth-url, providers.generic-oauth.token-url, providers.generic-oauth.user-url, providers.generic-oauth.client-id, providers.generic-oauth.client-secret must be set", err.Error())
|
||||
}
|
||||
|
||||
// Check setup
|
||||
p = GenericOAuth{
|
||||
AuthURL: "https://provider.com/oauth2/auth",
|
||||
TokenURL: "https://provider.com/oauth2/token",
|
||||
UserURL: "https://provider.com/oauth2/user",
|
||||
ClientID: "id",
|
||||
ClientSecret: "secret",
|
||||
}
|
||||
err = p.Setup()
|
||||
assert.Nil(err)
|
||||
}
|
||||
|
||||
func TestGenericOAuthGetLoginURL(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
p := GenericOAuth{
|
||||
AuthURL: "https://provider.com/oauth2/auth",
|
||||
TokenURL: "https://provider.com/oauth2/token",
|
||||
UserURL: "https://provider.com/oauth2/user",
|
||||
ClientID: "idtest",
|
||||
ClientSecret: "secret",
|
||||
Scopes: []string{"scopetest"},
|
||||
}
|
||||
err := p.Setup()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Check url
|
||||
uri, err := url.Parse(p.GetLoginURL("http://example.com/_oauth", "state"))
|
||||
assert.Nil(err)
|
||||
assert.Equal("https", uri.Scheme)
|
||||
assert.Equal("provider.com", uri.Host)
|
||||
assert.Equal("/oauth2/auth", uri.Path)
|
||||
|
||||
// Check query string
|
||||
qs := uri.Query()
|
||||
expectedQs := url.Values{
|
||||
"client_id": []string{"idtest"},
|
||||
"redirect_uri": []string{"http://example.com/_oauth"},
|
||||
"response_type": []string{"code"},
|
||||
"scope": []string{"scopetest"},
|
||||
"state": []string{"state"},
|
||||
}
|
||||
assert.Equal(expectedQs, qs)
|
||||
}
|
||||
|
||||
func TestGenericOAuthExchangeCode(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
// Setup server
|
||||
expected := url.Values{
|
||||
"client_id": []string{"idtest"},
|
||||
"client_secret": []string{"sectest"},
|
||||
"code": []string{"code"},
|
||||
"grant_type": []string{"authorization_code"},
|
||||
"redirect_uri": []string{"http://example.com/_oauth"},
|
||||
}
|
||||
server, serverURL := NewOAuthServer(t, map[string]string{
|
||||
"token": expected.Encode(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
// Setup provider
|
||||
p := GenericOAuth{
|
||||
AuthURL: "https://provider.com/oauth2/auth",
|
||||
TokenURL: serverURL.String() + "/token",
|
||||
UserURL: "https://provider.com/oauth2/user",
|
||||
ClientID: "idtest",
|
||||
ClientSecret: "sectest",
|
||||
}
|
||||
err := p.Setup()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// We force AuthStyleInParams to prevent the test failure when the
|
||||
// AuthStyleInHeader is attempted
|
||||
p.Config.Endpoint.AuthStyle = oauth2.AuthStyleInParams
|
||||
|
||||
token, err := p.ExchangeCode("http://example.com/_oauth", "code")
|
||||
assert.Nil(err)
|
||||
assert.Equal("123456789", token)
|
||||
}
|
||||
|
||||
func TestGenericOAuthGetUser(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
// Setup server
|
||||
server, serverURL := NewOAuthServer(t, nil)
|
||||
defer server.Close()
|
||||
|
||||
// Setup provider
|
||||
p := GenericOAuth{
|
||||
AuthURL: "https://provider.com/oauth2/auth",
|
||||
TokenURL: "https://provider.com/oauth2/token",
|
||||
UserURL: serverURL.String() + "/userinfo",
|
||||
ClientID: "idtest",
|
||||
ClientSecret: "sectest",
|
||||
}
|
||||
err := p.Setup()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// We force AuthStyleInParams to prevent the test failure when the
|
||||
// AuthStyleInHeader is attempted
|
||||
p.Config.Endpoint.AuthStyle = oauth2.AuthStyleInParams
|
||||
|
||||
user, err := p.GetUser("123456789")
|
||||
assert.Nil(err)
|
||||
|
||||
assert.Equal("example@example.com", user.Email)
|
||||
}
|
@ -9,8 +9,9 @@ import (
|
||||
|
||||
// Providers contains all the implemented providers
|
||||
type Providers struct {
|
||||
Google Google `group:"Google Provider" namespace:"google" env-namespace:"GOOGLE"`
|
||||
OIDC OIDC `group:"OIDC Provider" namespace:"oidc" env-namespace:"OIDC"`
|
||||
Google Google `group:"Google Provider" namespace:"google" env-namespace:"GOOGLE"`
|
||||
OIDC OIDC `group:"OIDC Provider" namespace:"oidc" env-namespace:"OIDC"`
|
||||
GenericOAuth GenericOAuth `group:"Generic OAuth2 Provider" namespace:"generic-oauth" env-namespace:"GENERIC_OAUTH"`
|
||||
}
|
||||
|
||||
// Provider is used to authenticate users
|
||||
|
Reference in New Issue
Block a user