diff --git a/internal/server_test.go b/internal/server_test.go index 71449d7..42919bc 100644 --- a/internal/server_test.go +++ b/internal/server_test.go @@ -121,19 +121,19 @@ func TestServerAuthCallback(t *testing.T) { assert := assert.New(t) config = newDefaultConfig() - // Setup token server - tokenServerHandler := &TokenServerHandler{} - tokenServer := httptest.NewServer(tokenServerHandler) - defer tokenServer.Close() - tokenUrl, _ := url.Parse(tokenServer.URL) - config.Providers.Google.TokenURL = tokenUrl - - // Setup user server - userServerHandler := &UserServerHandler{} - userServer := httptest.NewServer(userServerHandler) - defer userServer.Close() - userUrl, _ := url.Parse(userServer.URL) - config.Providers.Google.UserURL = userUrl + // Setup OAuth server + server, serverURL := NewOAuthServer(t) + defer server.Close() + config.Providers.Google.TokenURL = &url.URL{ + Scheme: serverURL.Scheme, + Host: serverURL.Host, + Path: "/token", + } + config.Providers.Google.UserURL = &url.URL{ + Scheme: serverURL.Scheme, + Host: serverURL.Host, + Path: "/userinfo", + } // Should pass auth response request to callback req := newDefaultHttpRequest("/_oauth") @@ -342,21 +342,30 @@ func TestServerRouteQuery(t *testing.T) { * Utilities */ -type TokenServerHandler struct{} - -func (t *TokenServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, `{"access_token":"123456789"}`) +type OAuthServer struct { + t *testing.T } -type UserServerHandler struct{} +func (s *OAuthServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/token" { + fmt.Fprintf(w, `{"access_token":"123456789"}`) + } else if r.URL.Path == "/userinfo" { + fmt.Fprint(w, `{ + "id":"1", + "email":"example@example.com", + "verified_email":true, + "hd":"example.com" + }`) + } else { + s.t.Fatal("Unrecognised request: ", r.Method, r.URL) + } +} -func (t *UserServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, `{ - "id":"1", - "email":"example@example.com", - "verified_email":true, - "hd":"example.com" - }`) +func NewOAuthServer(t *testing.T) (*httptest.Server, *url.URL) { + handler := &OAuthServer{} + server := httptest.NewServer(handler) + serverURL, _ := url.Parse(server.URL) + return server, serverURL } func doHttpRequest(r *http.Request, c *http.Cookie) (*http.Response, string) {