/*
Package gothic wraps common behaviour when using Goth. This makes it quick, and easy, to get up
and running with Goth. Of course, if you want complete control over how things flow, in regards
to the authentication process, feel free to use Goth directly.

See https://github.com/markbates/goth/examples/main.go to see this in action.
*/
package gothic
import (
	"bytes"
	"compress/gzip"
	crand "crypto/rand"
	"encoding/base64"
	"errors"
	"fmt"
	"io/ioutil"
	"math/rand"
	"net/http"
	"net/url"
	"os"
	"strings"
	"time"

	"github.com/gorilla/mux"
	"github.com/gorilla/sessions"
	"github.com/markbates/goth"
)
// SessionName is the key used to access the session store.
const SessionName = "_gothic_session"
// Store can/should be set by applications using gothic. The default is a cookie store.
var Store sessions . Store
var defaultStore sessions . Store
var keySet = false
2018-02-19 07:10:51 +02:00
var gothicRand * rand . Rand
2017-02-22 08:14:37 +01:00
func init ( ) {
key := [ ] byte ( os . Getenv ( "SESSION_SECRET" ) )
keySet = len ( key ) != 0
2018-02-19 07:10:51 +02:00
cookieStore := sessions . NewCookieStore ( [ ] byte ( key ) )
cookieStore . Options . HttpOnly = true
Store = cookieStore
2017-02-22 08:14:37 +01:00
defaultStore = Store
2018-02-19 07:10:51 +02:00
gothicRand = rand . New ( rand . NewSource ( time . Now ( ) . UnixNano ( ) ) )
2017-02-22 08:14:37 +01:00
}
/ *
2018-02-19 07:10:51 +02:00
BeginAuthHandler is a convenience handler for starting the authentication process .
2017-02-22 08:14:37 +01:00
It expects to be able to get the name of the provider from the query parameters
as either "provider" or ":provider" .
BeginAuthHandler will redirect the user to the appropriate authentication end - point
for the requested provider .
See https : //github.com/markbates/goth/examples/main.go to see this in action.
* /
func BeginAuthHandler ( res http . ResponseWriter , req * http . Request ) {
url , err := GetAuthURL ( res , req )
if err != nil {
res . WriteHeader ( http . StatusBadRequest )
fmt . Fprintln ( res , err )
return
}
http . Redirect ( res , req , url , http . StatusTemporaryRedirect )
}
// SetState sets the state string associated with the given request.
// If the request carries a non-empty "state" query parameter, that value is
// returned unchanged. Otherwise a fresh random nonce is generated: 64 bytes
// from crypto/rand, base64 URL-encoded with padding.
// This state is sent to the provider and can be retrieved during the
// callback.
//
// SetState panics if the operating system's secure random source fails,
// because returning a guessable state would silently disable CSRF protection.
var SetState = func(req *http.Request) string {
	state := req.URL.Query().Get("state")
	if len(state) > 0 {
		return state
	}

	// If a state query param is not passed in, generate a random
	// base64-encoded nonce so that the state on the auth URL
	// is unguessable, preventing CSRF attacks, as described in
	//
	// https://auth0.com/docs/protocols/oauth2/oauth-state#keep-reading
	//
	// The nonce must come from a cryptographically secure source. A
	// math/rand generator seeded with the wall-clock time can be
	// reconstructed by anyone who can estimate when it was seeded, which
	// would make the state predictable.
	nonceBytes := make([]byte, 64)
	if _, err := crand.Read(nonceBytes); err != nil {
		// There is no safe fallback and this function has no error result.
		panic("gothic: source of randomness unavailable: " + err.Error())
	}
	return base64.URLEncoding.EncodeToString(nonceBytes)
}
// GetState returns the "state" query parameter of the callback request,
// i.e. the state value the provider echoed back. It returns "" when the
// parameter is absent. Comparing it with the state originally sent guards
// against CSRF; see http://tools.ietf.org/html/rfc6749#section-10.12
var GetState = func(req *http.Request) string {
	params := req.URL.Query()
	return params.Get("state")
}
/ *
GetAuthURL starts the authentication process with the requested provided .
It will return a URL that should be used to send users to .
It expects to be able to get the name of the provider from the query parameters
as either "provider" or ":provider" .
I would recommend using the BeginAuthHandler instead of doing all of these steps
yourself , but that ' s entirely up to you .
* /
func GetAuthURL ( res http . ResponseWriter , req * http . Request ) ( string , error ) {
if ! keySet && defaultStore == Store {
fmt . Println ( "goth/gothic: no SESSION_SECRET environment variable is set. The default cookie store is not available and any calls will fail. Ignore this warning if you are using a different store." )
}
providerName , err := GetProviderName ( req )
if err != nil {
return "" , err
}
provider , err := goth . GetProvider ( providerName )
if err != nil {
return "" , err
}
sess , err := provider . BeginAuth ( SetState ( req ) )
if err != nil {
return "" , err
}
url , err := sess . GetAuthURL ( )
if err != nil {
return "" , err
}
err = storeInSession ( providerName , sess . Marshal ( ) , req , res )
if err != nil {
return "" , err
}
return url , err
}
/ *
CompleteUserAuth does what it says on the tin . It completes the authentication
process and fetches all of the basic information about the user from the provider .
It expects to be able to get the name of the provider from the query parameters
as either "provider" or ":provider" .
See https : //github.com/markbates/goth/examples/main.go to see this in action.
* /
var CompleteUserAuth = func ( res http . ResponseWriter , req * http . Request ) ( goth . User , error ) {
2018-02-19 07:10:51 +02:00
defer Logout ( res , req )
2017-02-22 08:14:37 +01:00
if ! keySet && defaultStore == Store {
fmt . Println ( "goth/gothic: no SESSION_SECRET environment variable is set. The default cookie store is not available and any calls will fail. Ignore this warning if you are using a different store." )
}
providerName , err := GetProviderName ( req )
if err != nil {
return goth . User { } , err
}
provider , err := goth . GetProvider ( providerName )
if err != nil {
return goth . User { } , err
}
value , err := getFromSession ( providerName , req )
if err != nil {
return goth . User { } , err
}
sess , err := provider . UnmarshalSession ( value )
if err != nil {
return goth . User { } , err
}
2018-02-19 07:10:51 +02:00
err = validateState ( req , sess )
if err != nil {
return goth . User { } , err
}
2017-02-22 08:14:37 +01:00
user , err := provider . FetchUser ( sess )
if err == nil {
// user can be found with existing session data
return user , err
}
// get new token and retry fetch
_ , err = sess . Authorize ( provider , req . URL . Query ( ) )
if err != nil {
return goth . User { } , err
}
err = storeInSession ( providerName , sess . Marshal ( ) , req , res )
if err != nil {
return goth . User { } , err
}
2018-02-19 07:10:51 +02:00
gu , err := provider . FetchUser ( sess )
return gu , err
}
// validateState checks that the "state" query parameter of the callback
// request equals the "state" parameter embedded in the session's auth URL.
//
// It returns nil when they match, or when the auth URL carried no state at
// all. It returns an error when they differ, or when the auth URL cannot be
// obtained or parsed.
func validateState(req *http.Request, sess goth.Session) error {
	rawAuthURL, err := sess.GetAuthURL()
	if err != nil {
		return err
	}
	parsed, err := url.Parse(rawAuthURL)
	if err != nil {
		return err
	}

	sent := parsed.Query().Get("state")
	if sent == "" {
		return nil
	}
	if received := req.URL.Query().Get("state"); received != sent {
		return errors.New("state token mismatch")
	}
	return nil
}
// Logout invalidates a user session.
func Logout ( res http . ResponseWriter , req * http . Request ) error {
session , err := Store . Get ( req , SessionName )
if err != nil {
return err
}
session . Options . MaxAge = - 1
session . Values = make ( map [ interface { } ] interface { } )
err = session . Save ( req , res )
if err != nil {
return errors . New ( "Could not delete user session " )
}
return nil
2017-02-22 08:14:37 +01:00
}
// GetProviderName is a function used to get the name of a provider
// for a given request. By default, this provider is fetched from
// the URL query string. If you provide it in a different way,
// assign your own function to this variable that returns the provider
// name for your request.
var GetProviderName = getProviderName
func getProviderName ( req * http . Request ) ( string , error ) {
2018-02-19 07:10:51 +02:00
// get all the used providers
providers := goth . GetProviders ( )
// loop over the used providers, if we already have a valid session for any provider (ie. user is already logged-in with a provider), then return that provider name
for _ , provider := range providers {
p := provider . Name ( )
session , _ := Store . Get ( req , p + SessionName )
value := session . Values [ p ]
if _ , ok := value . ( string ) ; ok {
2017-02-22 08:14:37 +01:00
return p , nil
}
}
2018-02-19 07:10:51 +02:00
// try to get it from the url param "provider"
if p := req . URL . Query ( ) . Get ( "provider" ) ; p != "" {
return p , nil
2017-02-22 08:14:37 +01:00
}
2018-02-19 07:10:51 +02:00
// try to get it from the url param ":provider"
if p := req . URL . Query ( ) . Get ( ":provider" ) ; p != "" {
return p , nil
}
// try to get it from the context's value of "provider" key
if p , ok := mux . Vars ( req ) [ "provider" ] ; ok {
return p , nil
2017-02-22 08:14:37 +01:00
}
2018-02-19 07:10:51 +02:00
// try to get it from the go-context's value of "provider" key
if p , ok := req . Context ( ) . Value ( "provider" ) . ( string ) ; ok {
return p , nil
}
// if not found then return an empty string with the corresponding error
return "" , errors . New ( "you must select a provider" )
2017-02-22 08:14:37 +01:00
}
func storeInSession ( key string , value string , req * http . Request , res http . ResponseWriter ) error {
2018-02-19 07:10:51 +02:00
session , _ := Store . Get ( req , SessionName )
2017-02-22 08:14:37 +01:00
2018-02-19 07:10:51 +02:00
if err := updateSessionValue ( session , key , value ) ; err != nil {
return err
}
2017-02-22 08:14:37 +01:00
return session . Save ( req , res )
}
func getFromSession ( key string , req * http . Request ) ( string , error ) {
2018-02-19 07:10:51 +02:00
session , _ := Store . Get ( req , SessionName )
value , err := getSessionValue ( session , key )
if err != nil {
return "" , errors . New ( "could not find a matching session for this request" )
}
2017-02-22 08:14:37 +01:00
2018-02-19 07:10:51 +02:00
return value , nil
}
func getSessionValue ( session * sessions . Session , key string ) ( string , error ) {
2017-02-22 08:14:37 +01:00
value := session . Values [ key ]
if value == nil {
2018-02-19 07:10:51 +02:00
return "" , fmt . Errorf ( "could not find a matching session for this request" )
2017-02-22 08:14:37 +01:00
}
2018-02-19 07:10:51 +02:00
rdata := strings . NewReader ( value . ( string ) )
r , err := gzip . NewReader ( rdata )
if err != nil {
return "" , err
}
s , err := ioutil . ReadAll ( r )
if err != nil {
return "" , err
}
return string ( s ) , nil
}
// updateSessionValue gzip-compresses value and stores the compressed bytes,
// as a string, under key in session. It does not save the session; callers
// must do that themselves. On error the session is left unchanged.
func updateSessionValue(session *sessions.Session, key, value string) error {
	var compressed bytes.Buffer
	zw := gzip.NewWriter(&compressed)

	// Write, then Flush, then Close, stopping at the first failure. Flush is
	// kept before Close so the stored bytes stay exactly as before.
	_, err := zw.Write([]byte(value))
	if err == nil {
		err = zw.Flush()
	}
	if err == nil {
		err = zw.Close()
	}
	if err != nil {
		return err
	}

	session.Values[key] = compressed.String()
	return nil
}