From 2a2692dadd7ba71dd040c9e3bcb4a5527ad6e180 Mon Sep 17 00:00:00 2001 From: Adam Hanna Date: Mon, 1 May 2017 16:00:33 -0700 Subject: [PATCH] finished writing tests --- README.md | 274 ++++++++++++++++ auth/service.go | 22 +- auth/service_unit_test.go | 112 +++++++ auth/service_util_unit_test.go | 2 + service.go | 6 +- service_e2e_test.go | 485 ++++++++++++++++++++++++++++ service_unit_test.go | 380 ++++++++++++++++++++++ service_util_unit_test.go | 29 ++ sessionerrs/sessionerrs.go | 2 +- store/service.go | 8 + store/service_integration_test.go | 267 +++++++++++++++ store/service_unit_test.go | 58 ++++ store/service_util.go | 6 +- store/service_util_unit_test.go | 29 ++ transport/service.go | 4 +- transport/service_interface.go | 2 +- transport/service_unit_test.go | 161 +++++++++ transport/service_util_unit_test.go | 28 ++ user/user_unit_test.go | 55 ++++ 19 files changed, 1915 insertions(+), 15 deletions(-) create mode 100644 auth/service_unit_test.go create mode 100644 service_e2e_test.go create mode 100644 service_unit_test.go create mode 100644 service_util_unit_test.go create mode 100644 store/service_integration_test.go create mode 100644 store/service_unit_test.go create mode 100644 store/service_util_unit_test.go create mode 100644 transport/service_unit_test.go create mode 100644 transport/service_util_unit_test.go create mode 100644 user/user_unit_test.go diff --git a/README.md b/README.md index e69de29..271dfd5 100644 --- a/README.md +++ b/README.md @@ -0,0 +1,274 @@ +# Go Sessions +A dead simple, highly customizable sessions service for go http servers + +## Quickstart + +~~~go +package main + +import ( + ... +) + +var sesh *sessions.Service + +var issueSession = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + userSession, seshErr := sesh.IssueUserSession("fakeUserID", "", w) + if seshErr != nil { + log.Printf("Err issuing user session: %v\n", seshErr) + http.Error(w, seshErr.Err.Error(), seshErr.Code) // seshErr is a custom err with an http code + return + } + log.Printf("In issue; user's session: %v\n", userSession) + + w.WriteHeader(http.StatusOK) +}) + +func main() { + seshStore := store.New(store.Options{}) + + // e.g. `$ openssl rand -base64 64` + authKey := "DOZDgBdMhGLImnk0BGYgOUI+h1n7U+OdxcZPctMbeFCsuAom2aFU4JPV4Qj11hbcb5yaM4WDuNP/3B7b+BnFhw==" + authOptions := auth.Options{ + Key: []byte(authKey), + } + seshAuth, err := auth.New(authOptions) + if err != nil { + log.Fatal(err) + } + + transportOptions := transport.Options{ + Secure: false, // note: can't use secure cookies in development! + } + seshTransport := transport.New(transportOptions) + + seshOptions := sessions.Options{} + sesh = sessions.New(seshStore, seshAuth, seshTransport, seshOptions) + + http.HandleFunc("/issue", issueSession) + + log.Println("Listening on localhost:8080") + log.Fatal(http.ListenAndServe("127.0.0.1:8080", nil)) +} +~~~ + +## Testing +Tests are broken down into three categories: unit, integration and e2e. Integration and e2e tests require a connection to a redis server. The connection address can be set in the `REDIS_URL` environment variable. The default is ":6379" + +To run all tests, simply: +~~~ +$ go test -tags="unit integration e2e" ./... +~~~ + +To run only tests from one of the categories: +~~~ +$ go test -tags="integration" ./... +~~~ + +To run only unit and integration tests: +~~~ +$ go test -tags="unit integration" ./... +~~~ + +## Example +The following example is a demonstration of using the session service along with a CSRF code to check for authentication. The CSRF code is stored in the userSession JSON field. + +~~~go +package main + +import ( + "crypto/rand" + "encoding/base64" + "encoding/json" + "io" + "log" + "net/http" + "time" + + "github.com/adam-hanna/sessions" + "github.com/adam-hanna/sessions/auth" + "github.com/adam-hanna/sessions/store" + "github.com/adam-hanna/sessions/transport" +) + +// SessionJSON is used for marshalling and unmarshalling custom session json information. +// We're using it as an opportunity to tie csrf strings to sessions to prevent csrf attacks +type SessionJSON struct { + CSRF string `json:"csrf"` +} + +var sesh *sessions.Service + +var issueSession = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + csrf, err := generateKey() + if err != nil { + log.Printf("Err generating csrf: %v\n", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + myJSON := SessionJSON{ + CSRF: csrf, + } + JSONBytes, err := json.Marshal(myJSON) + if err != nil { + log.Printf("Err generating json: %v\n", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + userSession, seshErr := sesh.IssueUserSession("fakeUserID", string(JSONBytes[:]), w) + if seshErr != nil { + log.Printf("Err issuing user session: %v\n", seshErr) + http.Error(w, seshErr.Err.Error(), seshErr.Code) + return + } + log.Printf("In issue; user's session: %v\n", userSession) + + // note: we set the csrf in a cookie, but look for it in request headers + csrfCookie := http.Cookie{ + Name: "csrf", + Value: csrf, + Expires: userSession.ExpiresAt, + Path: "/", + HttpOnly: false, + Secure: false, // note: can't use secure cookies in development + } + http.SetCookie(w, &csrfCookie) + + w.WriteHeader(http.StatusOK) +}) + +var requiresSession = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + userSession, seshErr := sesh.GetUserSession(r) + if seshErr != nil { + log.Printf("Err fetching user session: %v\n", seshErr) + http.Error(w, seshErr.Err.Error(), seshErr.Code) + return + } + log.Printf("In require; user session expiration before extension: %v\n", userSession.ExpiresAt.UTC()) + + myJSON := SessionJSON{} + if err := json.Unmarshal([]byte(userSession.JSON), &myJSON); err != nil { + log.Printf("Err issuing unmarshalling json: %v\n", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + log.Printf("In require; user's custom json: %v\n", myJSON) + + // note: we set the csrf in a cookie, but look for it in request headers + csrf := r.Header.Get("X-CSRF-Token") + if csrf != myJSON.CSRF { + log.Printf("Unauthorized! CSRF token doesn't match user session") + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + + // note that session expiry's need to be manually extended + seshErr = sesh.ExtendUserSession(userSession, r, w) + if seshErr != nil { + log.Printf("Err fetching user session: %v\n", seshErr) + http.Error(w, seshErr.Err.Error(), seshErr.Code) + return + } + log.Printf("In require; users session expiration after extension: %v\n", userSession.ExpiresAt.UTC()) + + // need to extend the csrf cookie, too + csrfCookie := http.Cookie{ + Name: "csrf", + Value: csrf, + Expires: userSession.ExpiresAt, + Path: "/", + HttpOnly: false, + Secure: false, // note: can't use secure cookies in development + } + http.SetCookie(w, &csrfCookie) + + w.WriteHeader(http.StatusOK) +}) + +var clearSession = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + userSession, err := sesh.GetUserSession(r) + if err != nil { + log.Printf("Err fetching user session: %v\n", err) + http.Error(w, err.Err.Error(), err.Code) + return + } + + log.Printf("In clear; session: %v\n", userSession) + + myJSON := SessionJSON{} + if err := json.Unmarshal([]byte(userSession.JSON), &myJSON); err != nil { + log.Printf("Err issuing unmarshalling json: %v\n", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + log.Printf("In require; user's custom json: %v\n", myJSON) + + // note: we set the csrf in a cookie, but look for it in request headers + csrf := r.Header.Get("X-CSRF-Token") + if csrf != myJSON.CSRF { + log.Printf("Unauthorized! CSRF token doesn't match user session") + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + + err = sesh.ClearUserSession(userSession, w) + if err != nil { + log.Printf("Err clearing user session: %v\n", err) + http.Error(w, err.Err.Error(), err.Code) + return + } + + // need to clear the csrf cookie, too + aLongTimeAgo := time.Now().Add(-1000 * time.Hour) + csrfCookie := http.Cookie{ + Name: "csrf", + Value: "", + Expires: aLongTimeAgo, + Path: "/", + HttpOnly: false, + Secure: false, // note: can't use secure cookies in development + } + http.SetCookie(w, &csrfCookie) + + w.WriteHeader(http.StatusOK) +}) + +func main() { + seshStore := store.New(store.Options{}) + + // e.g. `$ openssl rand -base64 64` + authKey := "DOZDgBdMhGLImnk0BGYgOUI+h1n7U+OdxcZPctMbeFCsuAom2aFU4JPV4Qj11hbcb5yaM4WDuNP/3B7b+BnFhw==" + authOptions := auth.Options{ + Key: []byte(authKey), + } + seshAuth, err := auth.New(authOptions) + if err != nil { + log.Fatal(err) + } + + transportOptions := transport.Options{ + Secure: false, // note: can't use secure cookies in development! + } + seshTransport := transport.New(transportOptions) + + seshOptions := sessions.Options{} + sesh = sessions.New(seshStore, seshAuth, seshTransport, seshOptions) + + http.HandleFunc("/issue", issueSession) + http.HandleFunc("/require", requiresSession) + http.HandleFunc("/clear", clearSession) // also requires a valid session + + log.Println("Listening on localhost:3000") + log.Fatal(http.ListenAndServe("127.0.0.1:3000", nil)) +} + +func generateKey() (string, error) { + b := make([]byte, 16) + if _, err := io.ReadFull(rand.Reader, b); err != nil { + return "", err + } + return base64.URLEncoding.EncodeToString(b), nil +} +~~~ \ No newline at end of file diff --git a/auth/service.go b/auth/service.go index 22e2797..e72b0f3 100644 --- a/auth/service.go +++ b/auth/service.go @@ -6,6 +6,16 @@ import ( "github.com/adam-hanna/sessions/sessionerrs" ) +// note @adam-hanna: can these be constants? +var ( + // ErrNoSessionKey is thrown when no key was provided for HMAC signing + ErrNoSessionKey = errors.New("no session key") + // ErrMalformedSession is thrown when the session value doesn't conform to expectations + ErrMalformedSession = errors.New("malformed session") + // ErrInvalidSessionSignature the signature included with the session can't be verified with the provided session key + ErrInvalidSessionSignature = errors.New("invalid session signature") +) + // Service performs signing and verification actions using HMAC type Service struct { options Options @@ -21,9 +31,9 @@ type Options struct { func New(options Options) (*Service, *sessionerrs.Custom) { // note @adam-hanna: should we perform other checks like min/max length? if len(options.Key) == 0 { - return &Service{}, &sessionerrs.Custom{ + return nil, &sessionerrs.Custom{ Code: 500, - Err: errors.New("no session key"), + Err: ErrNoSessionKey, } } return &Service{ @@ -31,7 +41,7 @@ func New(options Options) (*Service, *sessionerrs.Custom) { }, nil } -// SignAndBase64Encode signs the sessionID with they key and returns a base64 encoded string +// SignAndBase64Encode signs the sessionID with the key and returns a base64 encoded string func (s *Service) SignAndBase64Encode(sessionID string) (string, *sessionerrs.Custom) { userSessionIDBytes := []byte(sessionID) signedBytes := signHMAC(&userSessionIDBytes, &s.options.Key) @@ -55,12 +65,12 @@ func (s *Service) VerifyAndDecode(signed string) (string, *sessionerrs.Custom) { } } - // note: session uuid's are always 36 bytes long + // note: session uuid's are always 36 bytes long. This will make it difficult to switch to a new uuid algorithm! if len(decodedSessionValueBytes) <= 36 { // note @adam-hanna: is 401 the proper http status code, here? return "", &sessionerrs.Custom{ Code: 401, - Err: errors.New("invalid session"), + Err: ErrMalformedSession, } } sessionIDBytes := decodedSessionValueBytes[:36] @@ -72,7 +82,7 @@ func (s *Service) VerifyAndDecode(signed string) (string, *sessionerrs.Custom) { if !verified { return "", &sessionerrs.Custom{ Code: 401, - Err: errors.New("invalid session signature"), + Err: ErrInvalidSessionSignature, } } diff --git a/auth/service_unit_test.go b/auth/service_unit_test.go new file mode 100644 index 0000000..0887f39 --- /dev/null +++ b/auth/service_unit_test.go @@ -0,0 +1,112 @@ +// +build unit + +package auth + +import ( + "reflect" + "testing" + + "github.com/adam-hanna/sessions/sessionerrs" +) + +var ( + validKey = []byte("DOZDgBdMhGLImnk0BGYgOUI+h1n7U+OdxcZPctMbeFCsuAom2aFU4JPV4Qj11hbcb5yaM4WDuNP/3B7b+BnFhw==") + validService = Service{ + options: Options{ + Key: validKey, + }, + } +) + +// TestNew tests the New function +func TestNew(t *testing.T) { + var tests = []struct { + input Options + expectedErr sessionerrs.Custom + expectedService Service + }{ + { + Options{ + Key: validKey, + }, + sessionerrs.Custom{}, + validService, + }, + { + Options{ + Key: []byte{}, + }, + sessionerrs.Custom{ + Code: 500, + Err: ErrNoSessionKey, + }, + Service{}, + }, + } + + for idx, tt := range tests { + s, e := New(tt.input) + if e == nil { + e = &sessionerrs.Custom{} + } + if s == nil { + s = &Service{} + } + + assertErr := reflect.DeepEqual(tt.expectedErr, *e) + assertService := reflect.DeepEqual(tt.expectedService, *s) + + if !assertErr && !assertService { + t.Errorf("test #%d failed; assertErr: %t, assertService: %t, expectedErr: %v, expectedService: %v, received err: %v, received service: %v", idx+1, assertErr, assertService, tt.expectedErr, tt.expectedService, *s, *e) + } + } +} + +// TestSignAndBase64Encode tests the SignAndBase64Encode function +func TestSignAndBase64Encode(t *testing.T) { + // note: err returned is always nil, so don't need to test for it + var tests = []struct { + input string + expected string + pass bool + }{ + {"5f4cd331-c869-4871-bb41-76b726df9937", "NWY0Y2QzMzEtYzg2OS00ODcxLWJiNDEtNzZiNzI2ZGY5OTM3YGV5KkkGaOaikrAO9qqRa3hocM3OD0JDoXUtJ8LRJKKQw_8H6kAtbps8g4bQHoL--LyxWPesiTvlasxlnnNA7g==", true}, + {"4f4cd331-c869-4871-bb41-76b726df9937", "NWY0Y2QzMzEtYzg2OS00ODcxLWJiNDEtNzZiNzI2ZGY5OTM3YGV5KkkGaOaikrAO9qqRa3hocM3OD0JDoXUtJ8LRJKKQw_8H6kAtbps8g4bQHoL--LyxWPesiTvlasxlnnNA7g=a", false}, + } + + for idx, tt := range tests { + a, _ := validService.SignAndBase64Encode(tt.input) + + if tt.pass && a != tt.expected { + t.Errorf("test #%d failed; input: %s, expected output: %s, received: %s", idx+1, tt.input, tt.expected, a) + } else if !tt.pass && a == tt.expected { + t.Errorf("test #%d failed; input: %s, expected output: %s, received: %s", idx+1, tt.input, tt.expected, a) + } + } +} + +// TestSignAndBase64Encode tests the SignAndBase64Encode function +func TestVerifyAndDecode(t *testing.T) { + // note: err returned is always nil, so don't need to test for it + var tests = []struct { + expectedString string + input string + expectedErr sessionerrs.Custom + }{ + {"5f4cd331-c869-4871-bb41-76b726df9937", "NWY0Y2QzMzEtYzg2OS00ODcxLWJiNDEtNzZiNzI2ZGY5OTM3YGV5KkkGaOaikrAO9qqRa3hocM3OD0JDoXUtJ8LRJKKQw_8H6kAtbps8g4bQHoL--LyxWPesiTvlasxlnnNA7g==", sessionerrs.Custom{}}, + {"", "NWY0Y2QzMzEtYzg2OS00ODcxLWJiNDEtNzZiNzI2ZGY5OTM3YGV5KkkGaOaikrAO9qqRa3hocM3OD0JDoXUtJ8LRJKKQw_8H6kAtbps8g4bQHoL--LyxWPesiTvlasxlnnNA7g=a", sessionerrs.Custom{Code: 500, Err: ErrBase64Decode}}, + {"", "5f4cd331-c869-4871-bb41-76b726df9937", sessionerrs.Custom{Code: 401, Err: ErrMalformedSession}}, + {"", "NAY0Y2QzMzEtYzg2OS00ODcxLWJiNDEtNzZiNzI2ZGY5OTM3YGV5KkkGaOaikrAO9qqRa3hocM3OD0JDoXUtJ8LRJKKQw_8H6kAtbps8g4bQHoL--LyxWPesiTvlasxlnnNA7g==", sessionerrs.Custom{Code: 401, Err: ErrInvalidSessionSignature}}, + } + + for idx, tt := range tests { + a, e := validService.VerifyAndDecode(tt.input) + if e == nil { + e = &sessionerrs.Custom{} + } + + if a != tt.expectedString || *e != tt.expectedErr { + t.Errorf("test #%d failed; input: %s, expected string: %s, expected err: %v, received string: %s, received err: %v", idx+1, tt.input, tt.expectedString, tt.expectedErr, a, *e) + } + } +} diff --git a/auth/service_util_unit_test.go b/auth/service_util_unit_test.go index c26ad48..f5cf0ea 100644 --- a/auth/service_util_unit_test.go +++ b/auth/service_util_unit_test.go @@ -1,3 +1,5 @@ +// +build unit + package auth import ( diff --git a/service.go b/service.go index f7a6f8d..3cb0856 100644 --- a/service.go +++ b/service.go @@ -81,7 +81,7 @@ func (s *Service) ClearUserSession(userSession *user.Session, w http.ResponseWri // sessions that have expired, or that fail signature verification will return a custom session error with code 401 func (s *Service) GetUserSession(r *http.Request) (*user.Session, *sessionerrs.Custom) { // read the session from the request - signedSessionID, err := s.transport.FetchSessionFromRequest(r) + signedSessionID, err := s.transport.FetchSessionIDFromRequest(r) if err != nil { return nil, err } @@ -97,6 +97,8 @@ func (s *Service) GetUserSession(r *http.Request) (*user.Session, *sessionerrs.C } // ExtendUserSession extends the ExpiresAt of a session by the Options.ExpirationDuration +// +// Note that this function must be called, manually! Extension of user session expiry's does not happen automatically! func (s *Service) ExtendUserSession(userSession *user.Session, r *http.Request, w http.ResponseWriter) *sessionerrs.Custom { newExpiresAt := time.Now().Add(s.options.ExpirationDuration).UTC() @@ -110,7 +112,7 @@ func (s *Service) ExtendUserSession(userSession *user.Session, r *http.Request, } // fetch the signed session id from the request - signedSessionID, err := s.transport.FetchSessionFromRequest(r) + signedSessionID, err := s.transport.FetchSessionIDFromRequest(r) if err != nil { return err } diff --git a/service_e2e_test.go b/service_e2e_test.go new file mode 100644 index 0000000..f488ee2 --- /dev/null +++ b/service_e2e_test.go @@ -0,0 +1,485 @@ +// +build e2e + +package sessions + +import ( + "crypto/rand" + "encoding/base64" + "encoding/json" + "errors" + "io" + "log" + "net/http" + "net/http/httptest" + "os" + "testing" + "time" + + "github.com/adam-hanna/sessions/auth" + "github.com/adam-hanna/sessions/store" + "github.com/adam-hanna/sessions/transport" +) + +// SessionJSON is used for marshalling and unmarshalling custom session json information. +// We're using it as an opportunity to tie csrf strings to sessions to prevent csrf attacks +type SessionJSON struct { + CSRF string `json:"csrf"` +} + +var ( + issuedSessionIDs []string + + sesh *Service + seshStore *store.Service + + issueSession = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + csrf, err := generateKey() + if err != nil { + if testing.Verbose() { + log.Printf("Err generating csrf: %v\n", err) + } + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + myJSON := SessionJSON{ + CSRF: csrf, + } + JSONBytes, err := json.Marshal(myJSON) + if err != nil { + if testing.Verbose() { + log.Printf("Err generating json: %v\n", err) + } + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + userSession, seshErr := sesh.IssueUserSession("fakeUserID", string(JSONBytes[:]), w) + if seshErr != nil { + if testing.Verbose() { + log.Printf("Err issuing user session: %v\n", seshErr) + } + http.Error(w, seshErr.Err.Error(), seshErr.Code) + return + } + if testing.Verbose() { + log.Printf("In issue; user's session: %v\n", userSession) + } + + // we need to remove these from redis during testing shutdown + issuedSessionIDs = append(issuedSessionIDs, userSession.ID) + + // note: we set the csrf in a cookie, but look for it in request headers + csrfCookie := http.Cookie{ + Name: "csrf", + Value: csrf, + Expires: userSession.ExpiresAt, + Path: "/", + HttpOnly: false, + Secure: false, // note: can't use secure cookies in development + } + http.SetCookie(w, &csrfCookie) + + w.WriteHeader(http.StatusOK) + }) + + requiresSession = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + userSession, seshErr := sesh.GetUserSession(r) + if seshErr != nil { + if testing.Verbose() { + log.Printf("Err fetching user session: %v\n", seshErr) + } + http.Error(w, seshErr.Err.Error(), seshErr.Code) + return + } + if testing.Verbose() { + log.Printf("In require; user session expiration before extension: %v\n", userSession.ExpiresAt.UTC()) + } + + myJSON := SessionJSON{} + if err := json.Unmarshal([]byte(userSession.JSON), &myJSON); err != nil { + if testing.Verbose() { + log.Printf("Err issuing unmarshalling json: %v\n", err) + } + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if testing.Verbose() { + log.Printf("In require; user's custom json: %v\n", myJSON) + } + + // note: we set the csrf in a cookie, but look for it in request headers + csrf := r.Header.Get("X-CSRF-Token") + if csrf != myJSON.CSRF { + if testing.Verbose() { + log.Println("Unauthorized! CSRF token doesn't match user session") + } + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + + // note that session expiry's need to be manually extended + seshErr = sesh.ExtendUserSession(userSession, r, w) + if seshErr != nil { + if testing.Verbose() { + log.Printf("Err fetching user session: %v\n", seshErr) + } + http.Error(w, seshErr.Err.Error(), seshErr.Code) + return + } + if testing.Verbose() { + log.Printf("In require; users session expiration after extension: %v\n", userSession.ExpiresAt.UTC()) + } + + // need to extend the csrf cookie, too + csrfCookie := http.Cookie{ + Name: "csrf", + Value: csrf, + Expires: userSession.ExpiresAt, + Path: "/", + HttpOnly: false, + Secure: false, // note: can't use secure cookies in development + } + http.SetCookie(w, &csrfCookie) + + w.WriteHeader(http.StatusOK) + }) + + clearSession = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + userSession, err := sesh.GetUserSession(r) + if err != nil { + if testing.Verbose() { + log.Printf("Err fetching user session: %v\n", err) + } + http.Error(w, err.Err.Error(), err.Code) + return + } + if testing.Verbose() { + log.Printf("In clear; session: %v\n", userSession) + } + + myJSON := SessionJSON{} + if err := json.Unmarshal([]byte(userSession.JSON), &myJSON); err != nil { + if testing.Verbose() { + log.Printf("Err issuing unmarshalling json: %v\n", err) + } + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if testing.Verbose() { + log.Printf("In require; user's custom json: %v\n", myJSON) + } + + // note: we set the csrf in a cookie, but look for it in request headers + csrf := r.Header.Get("X-CSRF-Token") + if csrf != myJSON.CSRF { + if testing.Verbose() { + log.Println("Unauthorized! CSRF token doesn't match user session") + } + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + + err = sesh.ClearUserSession(userSession, w) + if err != nil { + if testing.Verbose() { + log.Printf("Err clearing user session: %v\n", err) + } + http.Error(w, err.Err.Error(), err.Code) + return + } + + // need to clear the csrf cookie, too + aLongTimeAgo := time.Now().Add(-1000 * time.Hour) + csrfCookie := http.Cookie{ + Name: "csrf", + Value: "", + Expires: aLongTimeAgo, + Path: "/", + HttpOnly: false, + Secure: false, // note: can't use secure cookies in development + } + http.SetCookie(w, &csrfCookie) + + w.WriteHeader(http.StatusOK) + }) +) + +func recoverHandler(next http.Handler) http.Handler { + // this catches any errors and returns an internal server error to the client + fn := func(w http.ResponseWriter, r *http.Request) { + defer func() { + if err := recover(); err != nil { + if testing.Verbose() { + log.Printf("Recovered! Panic: %+v\n", err) + } + http.Error(w, http.StatusText(500), 500) + } + }() + + next.ServeHTTP(w, r) + } + + return http.HandlerFunc(fn) +} + +func generateKey() (string, error) { + b := make([]byte, 16) + if _, err := io.ReadFull(rand.Reader, b); err != nil { + return "", err + } + return base64.URLEncoding.EncodeToString(b), nil +} + +func TestMain(m *testing.M) { + err := setup() + if err != nil { + log.Fatal("Err setting up e2e tests", err) + } + + code := m.Run() + + err = shutdown() + if err != nil { + log.Fatal("Err shutting down e2e tests", err) + } + + os.Exit(code) +} + +func setup() error { + log.Println("setting up e2e tests") + + return nil +} + +func shutdown() error { + log.Println("shutting down redis e2e tests") + + c := seshStore.Pool.Get() + defer c.Close() + + aLongTimeAgo := time.Now().Add(-1000 * time.Hour) + + for idx := range issuedSessionIDs { + _, err := c.Do("EXPIREAT", issuedSessionIDs[idx], aLongTimeAgo.Unix()) + if err != nil { + return errors.New("Could not delete issued session id. Error: " + err.Error()) + } + } + + return nil +} + +// TestE2E tests the entire system +func TestE2E(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test") + } + + log.Println("running e2e tests") + + // set up the session service + seshStore = store.New(store.Options{}) + + // e.g. `$ openssl rand -base64 64` + authKey := "DOZDgBdMhGLImnk0BGYgOUI+h1n7U+OdxcZPctMbeFCsuAom2aFU4JPV4Qj11hbcb5yaM4WDuNP/3B7b+BnFhw==" + authOptions := auth.Options{ + Key: []byte(authKey), + } + seshAuth, customErr := auth.New(authOptions) + if customErr != nil { + log.Fatal(customErr) + } + + transportOptions := transport.Options{ + Secure: false, // note: can't use secure cookies in development! + } + seshTransport := transport.New(transportOptions) + + seshOptions := Options{ + ExpirationDuration: 5 * time.Second, + } + sesh = New(seshStore, seshAuth, seshTransport, seshOptions) + + // make sure that we can connect + c := seshStore.Pool.Get() + defer c.Close() + + _, err := c.Do("PING") + if err != nil { + t.Errorf("Could not ping redis server. Error: %s\n", err.Error()) + } + + // set up the test servers + issueServer := httptest.NewServer(recoverHandler(issueSession)) + defer issueServer.Close() + + requireServer := httptest.NewServer(recoverHandler(requiresSession)) + defer requireServer.Close() + + clearServer := httptest.NewServer(recoverHandler(clearSession)) + defer clearServer.Close() + + // first, let's send a request to the require server without session. This should err. + res, err := http.Get(requireServer.URL) + if err != nil { + t.Errorf("Couldn't send request to test server; Err: %v\n", err) + } + + if res.StatusCode != 401 { + t.Errorf("Expected unathorized (401), received: %d\n", res.StatusCode) + } + + // not let's get a valid session + res, err = http.Get(issueServer.URL) + if err != nil { + t.Errorf("Couldn't send request to test server; Err: %v\n", err) + } + + if res.StatusCode != 200 { + t.Errorf("Expected unathorized (200), received: %d\n", res.StatusCode) + } + + // now let's send to the require server, without a valid csrf. This should err. + // first, grab the csrf + rc := res.Cookies() + var csrf string + var sessionCookieIndex int + var originalExpiresAt time.Time + for i, cookie := range rc { + if cookie.Name == "csrf" { + csrf = cookie.Value + } + if cookie.Name == "session" { + sessionCookieIndex = i + originalExpiresAt = cookie.Expires + } + } + + req, err := http.NewRequest("GET", requireServer.URL, nil) + if err != nil { + t.Errorf("Couldn't build request; Err: %v\n", err) + } + + req.AddCookie(rc[sessionCookieIndex]) + + // send the request + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + t.Errorf("Couldn't send request to test server; Err: %v\n", err) + } + + if resp.StatusCode != 401 { + t.Errorf("Expected status code 401, received: %d\n", resp.StatusCode) + } + + // now add the csrf to the header. This should NOT err. + req.Header.Add("X-CSRF-Token", csrf) + + // send the request + // but first sleep so we can tell the difference between the two expiry's + time.Sleep(2 * time.Second) // Pause + resp, err = client.Do(req) + if err != nil { + t.Errorf("Couldn't send request to test server; Err: %v\n", err) + } + + if resp.StatusCode != 200 { + t.Errorf("Expected status code 200, received: %d\n", resp.StatusCode) + } + + // was the expiration extended? + rc = resp.Cookies() + sessionCookieIndex = 0 + for i, cookie := range rc { + if cookie.Name == "session" { + sessionCookieIndex = i + } + } + + if rc[sessionCookieIndex].Expires.Sub(originalExpiresAt) <= 0 { + t.Errorf("Expected the session cookie to be extended; original expire at: %v, new expire at: %v\n", originalExpiresAt, rc[sessionCookieIndex].Expires) + } + + // now let's logout + req, err = http.NewRequest("GET", clearServer.URL, nil) + if err != nil { + t.Errorf("Couldn't build request; Err: %v\n", err) + } + + req.AddCookie(rc[sessionCookieIndex]) + req.Header.Add("X-CSRF-Token", csrf) + + resp, err = client.Do(req) + if err != nil { + t.Errorf("Couldn't send request to test server; Err: %v\n", err) + } + + if resp.StatusCode != 200 { + t.Errorf("Expected status code 200, received: %d\n", resp.StatusCode) + } + + // now test that the session is no longer valid + req, err = http.NewRequest("GET", requireServer.URL, nil) + if err != nil { + t.Errorf("Couldn't build request; Err: %v\n", err) + } + + req.AddCookie(rc[sessionCookieIndex]) + req.Header.Add("X-CSRF-Token", csrf) + + resp, err = client.Do(req) + if err != nil { + t.Errorf("Couldn't send request to test server; Err: %v\n", err) + } + + if resp.StatusCode != 401 { + t.Errorf("Expected status code 401, received: %d\n", resp.StatusCode) + } + + // great! now let's test that we can wait for a token to expire + // not let's get a valid session + res, err = http.Get(issueServer.URL) + if err != nil { + t.Errorf("Couldn't send request to test server; Err: %v\n", err) + } + + if res.StatusCode != 200 { + t.Errorf("Expected unathorized (200), received: %d\n", res.StatusCode) + } + + // now let's send to the require server + rc = res.Cookies() + csrf = "" + sessionCookieIndex = 0 + for i, cookie := range rc { + if cookie.Name == "csrf" { + csrf = cookie.Value + } + if cookie.Name == "session" { + sessionCookieIndex = i + } + } + + req, err = http.NewRequest("GET", requireServer.URL, nil) + if err != nil { + t.Errorf("Couldn't build request; Err: %v\n", err) + } + + req.AddCookie(rc[sessionCookieIndex]) + req.Header.Add("X-CSRF-Token", csrf) + + // send the request after waiting + time.Sleep(sesh.options.ExpirationDuration + 2*time.Second) // Pause + + resp, err = client.Do(req) + if err != nil { + t.Errorf("Couldn't send request to test server; Err: %v\n", err) + } + + if resp.StatusCode != 401 { + t.Errorf("Expected status code 401, received: %d\n", resp.StatusCode) + } +} diff --git a/service_unit_test.go b/service_unit_test.go new file mode 100644 index 0000000..ddcbbdc --- /dev/null +++ b/service_unit_test.go @@ -0,0 +1,380 @@ +// +build unit + +package sessions + +import ( + "errors" + "net/http" + "reflect" + "testing" + "time" + + "github.com/adam-hanna/sessions/sessionerrs" + "github.com/adam-hanna/sessions/user" +) + +var ( + MockedTestErr = sessionerrs.Custom{Code: -1, Err: errors.New("test err")} + + mockedStore = MockedStoreType{} + mockedAuth = MockedAuthType{} + mockedTransport = MockedTransportType{} + + erredStore = ErredStoreType{} + erredAuth = ErredAuthType{} + erredTransport = ErredTransportType{} + + opts = Options{ExpirationDuration: DefaultExpirationDuration} + + inputUserID = "testID" + inputJSON = "testJSON" + userSession = &user.Session{ + ID: "fakeID", + UserID: inputUserID, + JSON: inputJSON, + ExpiresAt: time.Now().Add(opts.ExpirationDuration).UTC(), + } +) + +type MockedAuthType struct { +} + +func (a *MockedAuthType) SignAndBase64Encode(sessionID string) (string, *sessionerrs.Custom) { + return "test", nil +} + +func (a *MockedAuthType) VerifyAndDecode(signed string) (string, *sessionerrs.Custom) { + return "test", nil +} + +type ErredAuthType struct { +} + +func (b *ErredAuthType) SignAndBase64Encode(sessionID string) (string, *sessionerrs.Custom) { + return "", &MockedTestErr +} + +func (b *ErredAuthType) VerifyAndDecode(signed string) (string, *sessionerrs.Custom) { + return "", &MockedTestErr +} + +type MockedStoreType struct { +} + +func (c *MockedStoreType) SaveUserSession(userSession *user.Session) *sessionerrs.Custom { + return nil +} + +func (c *MockedStoreType) DeleteUserSession(sessionID string) *sessionerrs.Custom { + return nil +} + +func (c *MockedStoreType) FetchValidUserSession(sessionID string) (*user.Session, *sessionerrs.Custom) { + return userSession, nil +} + +type ErredStoreType struct { +} + +func (d *ErredStoreType) SaveUserSession(userSession *user.Session) *sessionerrs.Custom { + return &MockedTestErr +} + +func (d *ErredStoreType) DeleteUserSession(sessionID string) *sessionerrs.Custom { + return &MockedTestErr +} + +func (d *ErredStoreType) FetchValidUserSession(sessionID string) (*user.Session, *sessionerrs.Custom) { + return nil, &MockedTestErr +} + +type MockedTransportType struct { +} + +func (e *MockedTransportType) SetSessionOnResponse(session string, userSession *user.Session, w http.ResponseWriter) *sessionerrs.Custom { + return nil +} + +func (e *MockedTransportType) DeleteSessionFromResponse(w http.ResponseWriter) *sessionerrs.Custom { + return nil +} + +func (e *MockedTransportType) FetchSessionIDFromRequest(r *http.Request) (string, *sessionerrs.Custom) { + return "test", nil +} + +type ErredTransportType struct { +} + +func (f *ErredTransportType) SetSessionOnResponse(session string, userSession *user.Session, w http.ResponseWriter) *sessionerrs.Custom { + return &MockedTestErr +} + +func (f *ErredTransportType) DeleteSessionFromResponse(w http.ResponseWriter) *sessionerrs.Custom { + return &MockedTestErr +} + +func (f *ErredTransportType) FetchSessionIDFromRequest(r *http.Request) (string, *sessionerrs.Custom) { + return "test", &MockedTestErr +} + +// TestNew tests the New function +func TestNew(t *testing.T) { + var expectedService = Service{ + store: &mockedStore, + auth: &mockedAuth, + transport: &mockedTransport, + options: opts, + } + + actualService := New(&mockedStore, &mockedAuth, &mockedTransport, Options{}) + if actualService == nil { + actualService = &Service{} + } + + assert := reflect.DeepEqual(expectedService, *actualService) + + if !assert { + t.Errorf("test failed; assert: %t, expected: %v, received: %v", assert, expectedService, actualService) + } +} + +// TestIssueUserSession tests the IssueUserSession function +func TestIssueUserSession(t *testing.T) { + var w http.ResponseWriter + + var tests = []struct { + input Service + expectedUserSession *user.Session + expectedErr *sessionerrs.Custom + }{ + { + Service{ + store: &erredStore, + auth: &erredAuth, + transport: &erredTransport, + options: opts, + }, + nil, + &MockedTestErr, + }, + { + Service{ + store: &mockedStore, + auth: &erredAuth, + transport: &erredTransport, + options: opts, + }, + nil, + &MockedTestErr, + }, + { + Service{ + store: &mockedStore, + auth: &mockedAuth, + transport: &erredTransport, + options: opts, + }, + userSession, // note: when transport is erred, the session, as well as an error, get returned + &MockedTestErr, + }, + { + Service{ + store: &mockedStore, + auth: &mockedAuth, + transport: &mockedTransport, + options: opts, + }, + userSession, + nil, + }, + } + + for idx, tt := range tests { + var assertSession bool + var assertErr bool + a, e := tt.input.IssueUserSession(inputUserID, inputJSON, w) + if a == nil { + assertSession = a == tt.expectedUserSession + a = &user.Session{} + } else { + t1 := time.Now().Add(tt.input.options.ExpirationDuration).UTC() + assertSession = tt.expectedUserSession.UserID == a.UserID && tt.expectedUserSession.JSON == a.JSON && + tt.expectedUserSession.ExpiresAt.Sub(t1) < 1*time.Second + } + if e == nil { + assertErr = e == tt.expectedErr + e = &sessionerrs.Custom{} + } else { + assertErr = reflect.DeepEqual(*tt.expectedErr, *e) + } + + if !assertSession || !assertErr { + t.Errorf("test #%d failed; input service: %v, assertSession: %t, assertErr: %t, expectedSession: %v, expectedErr: %v, received session: %v, received err: %v", idx+1, tt.input, assertSession, assertErr, tt.expectedUserSession, tt.expectedErr, *a, *e) + } + } +} + +// TestClearUserSession tests the ClearUserSession function +func TestClearUserSession(t *testing.T) { + var w http.ResponseWriter + + var tests = []struct { + input Service + expectedErr *sessionerrs.Custom + }{ + { + Service{ + store: &erredStore, + auth: &mockedAuth, + transport: &erredTransport, + options: opts, + }, + &MockedTestErr, + }, + { + Service{ + store: &mockedStore, + auth: &mockedAuth, + transport: &erredTransport, + options: opts, + }, + &MockedTestErr, + }, + { + Service{ + store: &mockedStore, + auth: &mockedAuth, + transport: &mockedTransport, + options: opts, + }, + nil, + }, + } + + for idx, tt := range tests { + e := tt.input.ClearUserSession(userSession, w) + assertErr := e == tt.expectedErr + + if !assertErr { + t.Errorf("test #%d failed; input service: %v, assertErr: %t, expectedErr: %v, received err: %v", idx+1, tt.input, assertErr, tt.expectedErr, e) + } + } +} + +// TestGetUserSession tests the GetUserSession function +func TestGetUserSession(t *testing.T) { + r := &http.Request{} + + var tests = []struct { + input Service + expectedSession *user.Session + expectedErr *sessionerrs.Custom + }{ + { + Service{ + store: &erredStore, + auth: &erredAuth, + transport: &erredTransport, + options: opts, + }, + nil, + &MockedTestErr, + }, + { + Service{ + store: &mockedStore, + auth: &erredAuth, + transport: &erredTransport, + options: opts, + }, + nil, + &MockedTestErr, + }, + { + Service{ + store: &mockedStore, + auth: &mockedAuth, + transport: &erredTransport, + options: opts, + }, + nil, + &MockedTestErr, + }, + { + Service{ + store: &mockedStore, + auth: &mockedAuth, + transport: &mockedTransport, + options: opts, + }, + userSession, + nil, + }, + } + + for idx, tt := range tests { + a, e := tt.input.GetUserSession(r) + assertErr := e == tt.expectedErr + assertSession := a == tt.expectedSession + + if !assertSession || !assertErr { + t.Errorf("test #%d failed; input service: %v, assertSession: %t, assertErr: %t, expected session: %v, expectedErr: %v, received session: %v, received err: %v", idx+1, tt.input, assertSession, assertErr, tt.expectedSession, tt.expectedErr, a, e) + } + } +} + +// TestGetUserSession tests the GetUserSession function +func TestExtendUserSession(t *testing.T) { + r := &http.Request{} + var w http.ResponseWriter + + var tests = []struct { + input Service + expectedErr *sessionerrs.Custom + }{ + { + Service{ + store: &erredStore, + auth: &mockedAuth, + transport: &erredTransport, + options: opts, + }, + &MockedTestErr, + }, + { + Service{ + store: &mockedStore, + auth: &mockedAuth, + transport: &erredTransport, + options: opts, + }, + &MockedTestErr, + }, + { + Service{ + store: &mockedStore, + auth: &mockedAuth, + transport: &mockedTransport, + options: opts, + }, + nil, + }, + } + + for idx, tt := range tests { + // let's use a test user session bc we don't want to mess with the one defined above + testUserSession := &user.Session{ + ExpiresAt: time.Now().UTC(), + } + e := tt.input.ExtendUserSession(testUserSession, r, w) + assertErr := e == tt.expectedErr + + newExpiresAt := time.Now().Add(tt.input.options.ExpirationDuration).UTC() + assertExtension := newExpiresAt.Sub(testUserSession.ExpiresAt) < 1*time.Second + + if !assertExtension || !assertErr { + t.Errorf("test #%d failed; input service: %v, assertSession: %t, assertErr: %t, expected expires at: %v, expectedErr: %v, received expires at: %v, received err: %v", idx+1, tt.input, assertExtension, assertErr, newExpiresAt, tt.expectedErr, testUserSession.ExpiresAt, e) + } + } +} diff --git a/service_util_unit_test.go b/service_util_unit_test.go new file mode 100644 index 0000000..f243521 --- /dev/null +++ b/service_util_unit_test.go @@ -0,0 +1,29 @@ +// +build unit + +package sessions + +import ( + "reflect" + "testing" + "time" +) + +// TestSetDefaultOptions tests the setDefaultOptions function +func TestSetDefaultOptions(t *testing.T) { + var tests = []struct { + input Options + expected Options + }{ + {Options{}, Options{ExpirationDuration: DefaultExpirationDuration}}, + {Options{ExpirationDuration: 1 * time.Second}, Options{ExpirationDuration: 1 * time.Second}}, + } + + for idx, tt := range tests { + setDefaultOptions(&tt.input) + assert := reflect.DeepEqual(tt.expected, tt.input) + + if !assert { + t.Errorf("test #%d failed; assert: %t, expected: %v, received: %v", idx+1, assert, tt.expected, tt.input) + } + } +} diff --git a/sessionerrs/sessionerrs.go b/sessionerrs/sessionerrs.go index 688179d..1710f5d 100644 --- a/sessionerrs/sessionerrs.go +++ b/sessionerrs/sessionerrs.go @@ -3,7 +3,7 @@ package sessionerrs // Custom is the error type returned by this session service package. // This custom error is useful for calling funcs to determine which http status code to return to clients on err type Custom struct { - // Code corresponds to an http status code (e.g. 401 Unauthorized) + // Code corresponds to an http status code (e.g. 401 Unauthorized, or 500 Internal Server Error) Code int // Err is the actual error thrown Err error diff --git a/store/service.go b/store/service.go index bc62b0e..2522bc6 100644 --- a/store/service.go +++ b/store/service.go @@ -132,6 +132,14 @@ func (s *Service) FetchValidUserSession(sessionID string) (*user.Session, *sessi Err: errors.New("error retrieving session data from store"), } } + for idx := range reply { + if reply[idx] == nil { + return nil, &sessionerrs.Custom{ + Code: 500, + Err: errors.New("error retrieving session data from store"), + } + } + } if _, err := redis.Scan(reply, &userID, &json, &expiresAtSeconds); err != nil { return nil, &sessionerrs.Custom{ Code: 500, diff --git a/store/service_integration_test.go b/store/service_integration_test.go new file mode 100644 index 0000000..04af4e7 --- /dev/null +++ b/store/service_integration_test.go @@ -0,0 +1,267 @@ +// +build integration + +package store + +import ( + "errors" + "fmt" + "log" + "os" + "reflect" + "testing" + "time" + + "github.com/adam-hanna/sessions/sessionerrs" + "github.com/adam-hanna/sessions/user" + "github.com/garyburd/redigo/redis" +) + +var ( + service *Service + validUserSession = &user.Session{ + ID: "validSessionID", + UserID: "validUserID", + JSON: "validJSON", + ExpiresAt: time.Now().Add(1 * time.Hour), + } + validUserSessionForSaving = &user.Session{ + ID: "validSessionForSavingID", + UserID: "validSessionForSavingID", + JSON: "validSessionForSavingJSON", + ExpiresAt: time.Now().Add(1 * time.Hour), + } + // session is invalid because it doesn't have json + inValidUserSession = &user.Session{ + ID: "invalidSessionID", + UserID: "invalidUserID", + // JSON: "invalidJSON", + ExpiresAt: time.Now().Add(1 * time.Hour), + } + expiredUserSession = &user.Session{ + ID: "expiredSessionID", + UserID: "expiredUserID", + JSON: "expiredJSON", + ExpiresAt: time.Now().Add(-100 * time.Hour), + } +) + +func TestMain(m *testing.M) { + err := setup() + if err != nil { + log.Fatal("Err setting up integration tests") + } + + code := m.Run() + + err = shutdown() + if err != nil { + log.Fatal("Err shutting down integration tests") + } + + os.Exit(code) +} + +func setup() error { + fmt.Println("setting up redis integration tests") + + options := Options{ + ConnectionAddress: os.Getenv("REDIS_URL"), + } + service = New(options) + c := service.Pool.Get() + defer c.Close() + + // VALID USER + _, err := c.Do("HMSET", validUserSession.ID, "UserID", validUserSession.UserID, "JSON", validUserSession.JSON, "ExpiresAtSeconds", validUserSession.ExpiresAt.Unix()) + if err != nil { + return errors.New("Could not set valid user") + } + _, err = c.Do("EXPIREAT", validUserSession.ID, validUserSession.ExpiresAt.Unix()) + if err != nil { + return errors.New("Could not set expiry for valid user") + } + + // INVALID USER + // note: the invalid user doesn't have JSON! + _, err = c.Do("HMSET", inValidUserSession.ID, "UserID", inValidUserSession.UserID, "ExpiresAtSeconds", inValidUserSession.ExpiresAt.Unix()) + if err != nil { + return errors.New("Could not set valid user") + } + _, err = c.Do("EXPIREAT", inValidUserSession.ID, inValidUserSession.ExpiresAt.Unix()) + if err != nil { + return errors.New("Could not set expiry for valid user") + } + + // EXPIRED USER + _, err = c.Do("HMSET", expiredUserSession.ID, "UserID", expiredUserSession.UserID, "JSON", expiredUserSession.JSON, "ExpiresAtSeconds", expiredUserSession.ExpiresAt.Unix()) + if err != nil { + return errors.New("Could not set valid user") + } + _, err = c.Do("EXPIREAT", expiredUserSession.ID, expiredUserSession.ExpiresAt.Unix()) + if err != nil { + return errors.New("Could not set expiry for valid user") + } + + return nil +} + +func shutdown() error { + fmt.Println("shutting down redis integration tests") + + c := service.Pool.Get() + defer c.Close() + + aLongTimeAgo := time.Now().Add(-1000 * time.Hour) + + // VALID USER + _, err := c.Do("EXPIREAT", validUserSession.ID, aLongTimeAgo.Unix()) + if err != nil { + return errors.New("Could not set EXPIREAT for validUserSession") + } + + // VALID USER + _, err = c.Do("EXPIREAT", validUserSessionForSaving.ID, aLongTimeAgo.Unix()) + if err != nil { + return errors.New("Could not set EXPIREAT for validUserSessionForSaving") + } + + // INVALID USER + _, err = c.Do("EXPIREAT", inValidUserSession.ID, aLongTimeAgo.Unix()) + if err != nil { + return errors.New("Could not set EXPIREAT for invaludUserSession") + } + + // EXPIRED USER + _, err = c.Do("EXPIREAT", expiredUserSession.ID, aLongTimeAgo.Unix()) + if err != nil { + return errors.New("Could not set EXPIREAT for expiredUserSession") + } + + return nil +} + +// TestSaveUserSession tests the SaveUserSession function +func TestSaveUserSession(t *testing.T) { + if testing.Short() { + t.Skip("skipping TestSaveUserSession, an integration test") + } + + tests := []struct { + input *user.Session + expectedErr *sessionerrs.Custom + expectToExist bool + }{ + {validUserSessionForSaving, nil, true}, + {expiredUserSession, nil, false}, + } + + c := service.Pool.Get() + defer c.Close() + + for idx, tt := range tests { + var assertErr bool + var assertExist bool + + e := service.SaveUserSession(tt.input) + if e == nil { + assertErr = e == tt.expectedErr + } else { + reflect.DeepEqual(*e, *tt.expectedErr) + } + + exists, err := redis.Bool(c.Do("EXISTS", tt.input.ID)) + if err != nil { + t.Errorf("Err in test #%d; cannot check if sessionID: %s exists\n", idx, tt.input.ID) + } + + assertExist = exists == tt.expectToExist + + if !assertErr || !assertExist { + t.Errorf("test #%d failed; assert err: %t, assert exists: %t, received err: %v, received exists: %t, expected err: %v, expected exists: %t, input: %v\n", idx+1, assertErr, assertExist, e, exists, tt.expectedErr, tt.expectToExist, tt.input) + } + } +} + +// TestFetchValidUserSession tests the FetchValidUserSession function +func TestFetchValidUserSession(t *testing.T) { + if testing.Short() { + t.Skip("skipping TestFetchValidUserSession, an integration test") + } + + tests := []struct { + input string + expectedUserSession *user.Session + expectedErr *sessionerrs.Custom + }{ + {validUserSession.ID, validUserSession, nil}, + {validUserSessionForSaving.ID, validUserSessionForSaving, nil}, + {expiredUserSession.ID, nil, &sessionerrs.Custom{Code: 401, Err: errors.New("session is expired or sessionID doesn't exist")}}, + {inValidUserSession.ID, nil, &sessionerrs.Custom{Code: 500, Err: errors.New("error retrieving session data from store")}}, + } + + for idx, tt := range tests { + var assertErr bool + var assertUserSession bool + + a, e := service.FetchValidUserSession(tt.input) + if e == nil { + assertErr = e == tt.expectedErr + } else { + if tt.expectedErr != nil { + assertErr = reflect.DeepEqual(*e, *tt.expectedErr) + } + } + if a == nil { + assertUserSession = a == tt.expectedUserSession + } else { + if tt.expectedUserSession != nil { + // note: we can't use deep equal here bc the expiry time might be off by a second or so + assertUserSession = a.ID == tt.expectedUserSession.ID && a.UserID == tt.expectedUserSession.UserID && + a.JSON == tt.expectedUserSession.JSON && a.ExpiresAt.Sub(tt.expectedUserSession.ExpiresAt) < 1 + } + } + + if !assertErr || !assertUserSession { + t.Errorf("test #%d failed; assert err: %t, assert user session: %t, received err: %v, received user session: %v, expected err: %v, expected user session: %v, input: %v\n", idx+1, assertErr, assertUserSession, e, a, tt.expectedErr, tt.expectedUserSession, tt.input) + } + } +} + +// TestDeleteUserSession tests the DeleteUserSession function +func TestDeleteUserSession(t *testing.T) { + if testing.Short() { + t.Skip("skipping TestDeleteUserSession, an integration test") + } + + tests := []struct { + input string + expectToExist bool + }{ + {validUserSession.ID, false}, + {validUserSessionForSaving.ID, false}, + {expiredUserSession.ID, false}, + {inValidUserSession.ID, false}, + } + + c := service.Pool.Get() + defer c.Close() + + for idx, tt := range tests { + var assertExist bool + + e := service.DeleteUserSession(tt.input) + if e != nil { + t.Errorf("Err in test #%d when deleting user session, expected err to be nil, received err: %v, input: %s\n", idx, e, tt.input) + } + + exists, err := redis.Bool(c.Do("EXISTS", tt.input)) + if err != nil { + t.Errorf("Err in test #%d when checking exists, expected err to be nil, received err: %v, input: %s\n", idx, err, tt.input) + } + + assertExist = exists == tt.expectToExist + if err != nil { + t.Errorf("Err in test #%d; assertExist: %t, expected exists: %t, received exists: %t, input: %s\n", idx, assertExist, tt.expectToExist, exists, tt.input) + } + } +} diff --git a/store/service_unit_test.go b/store/service_unit_test.go new file mode 100644 index 0000000..e19ed3d --- /dev/null +++ b/store/service_unit_test.go @@ -0,0 +1,58 @@ +// +build unit + +package store + +import ( + "testing" + "time" + + "github.com/garyburd/redigo/redis" +) + +var testOptions = Options{ConnectionAddress: "test", MaxIdleConnections: 5, MaxActiveConnections: 5, IdleTimeoutDuration: 1 * time.Second} + +// TestNew tests the New function +func TestNew(t *testing.T) { + var tests = []struct { + input Options + expected Service + }{ + { + Options{}, + Service{ + Pool: &redis.Pool{ + MaxActive: 0, + MaxIdle: DefaultMaxIdleConnections, + IdleTimeout: DefaultIdleTimeoutDuration, + Dial: func() (redis.Conn, error) { return redis.Dial("tcp", DefaultConnectionAddress) }, + }, + }, + }, + { + testOptions, + Service{ + Pool: &redis.Pool{ + MaxActive: testOptions.MaxActiveConnections, + MaxIdle: testOptions.MaxIdleConnections, + IdleTimeout: testOptions.IdleTimeoutDuration, + Dial: func() (redis.Conn, error) { return redis.Dial("tcp", testOptions.ConnectionAddress) }, + }, + }, + }, + } + + for idx, tt := range tests { + s := New(tt.input) + if s == nil { + s = &Service{} + } + + // note: how to check for connection address? + assert := tt.expected.Pool.MaxActive == s.Pool.MaxActive && tt.expected.Pool.MaxIdle == s.Pool.MaxIdle && + tt.expected.Pool.IdleTimeout == s.Pool.IdleTimeout + + if !assert { + t.Errorf("test #%d failed; assert: %t, input: %v, expected: %v, received: %v", idx+1, assert, tt.input, tt.expected.Pool, *s.Pool) + } + } +} diff --git a/store/service_util.go b/store/service_util.go index 0811307..09d2cbd 100644 --- a/store/service_util.go +++ b/store/service_util.go @@ -9,9 +9,9 @@ func setDefaultOptions(options *Options) { options.ConnectionAddress = DefaultConnectionAddress } // note @adam-hanna: what if someone sends in a value of 0? This will set it to default! - // if options.MaxIdleConnections == emptyOptions.MaxIdleConnections { - // options.MaxIdleConnections = DefaultMaxIdleConnections - // } + if options.MaxIdleConnections == emptyOptions.MaxIdleConnections { + options.MaxIdleConnections = DefaultMaxIdleConnections + } // note @adam-hanna: what if someone sends in a value of 0? This will set it to default! // if options.MaxActiveConnections == emptyOptions.MaxActiveConnections { // options.MaxActiveConnections = DefaultMaxActiveConnections diff --git a/store/service_util_unit_test.go b/store/service_util_unit_test.go new file mode 100644 index 0000000..6e299ab --- /dev/null +++ b/store/service_util_unit_test.go @@ -0,0 +1,29 @@ +// +build unit + +package store + +import ( + "reflect" + "testing" + "time" +) + +// TestSetDefaultOptions tests the setDefaultOptions function +func TestSetDefaultOptions(t *testing.T) { + var tests = []struct { + input Options + expected Options + }{ + {Options{}, Options{MaxIdleConnections: DefaultMaxIdleConnections, ConnectionAddress: DefaultConnectionAddress, IdleTimeoutDuration: DefaultIdleTimeoutDuration}}, + {Options{ConnectionAddress: "test", MaxIdleConnections: 5, MaxActiveConnections: 5, IdleTimeoutDuration: 1 * time.Second}, Options{ConnectionAddress: "test", MaxIdleConnections: 5, MaxActiveConnections: 5, IdleTimeoutDuration: 1 * time.Second}}, + } + + for idx, tt := range tests { + setDefaultOptions(&tt.input) + assert := reflect.DeepEqual(tt.expected, tt.input) + + if !assert { + t.Errorf("test #%d failed; assert: %t, expected: %v, received: %v\n", idx+1, assert, tt.expected, tt.input) + } + } +} diff --git a/transport/service.go b/transport/service.go index 014d79d..25d3757 100644 --- a/transport/service.go +++ b/transport/service.go @@ -72,8 +72,8 @@ func (s *Service) DeleteSessionFromResponse(w http.ResponseWriter) *sessionerrs. return nil } -// FetchSessionFromRequest retrieves a signed session id from a request -func (s *Service) FetchSessionFromRequest(r *http.Request) (string, *sessionerrs.Custom) { +// FetchSessionIDFromRequest retrieves a signed session id from a request +func (s *Service) FetchSessionIDFromRequest(r *http.Request) (string, *sessionerrs.Custom) { sessionCookie, err := r.Cookie(s.options.CookieName) if err != nil { if err == http.ErrNoCookie { diff --git a/transport/service_interface.go b/transport/service_interface.go index 93d0ceb..04e4436 100644 --- a/transport/service_interface.go +++ b/transport/service_interface.go @@ -11,5 +11,5 @@ import ( type ServiceInterface interface { SetSessionOnResponse(session string, userSession *user.Session, w http.ResponseWriter) *sessionerrs.Custom DeleteSessionFromResponse(w http.ResponseWriter) *sessionerrs.Custom - FetchSessionFromRequest(r *http.Request) (string, *sessionerrs.Custom) + FetchSessionIDFromRequest(r *http.Request) (string, *sessionerrs.Custom) } diff --git a/transport/service_unit_test.go b/transport/service_unit_test.go new file mode 100644 index 0000000..b847590 --- /dev/null +++ b/transport/service_unit_test.go @@ -0,0 +1,161 @@ +// +build unit + +package transport + +import ( + "errors" + "net/http" + "reflect" + "testing" + "time" + + "github.com/adam-hanna/sessions/sessionerrs" + "github.com/adam-hanna/sessions/user" +) + +var ( + testOptions = Options{CookieName: "test", CookiePath: "/", HTTPOnly: true, Secure: true} + testService = Service{options: testOptions} +) + +// Thanks! +// https://gist.github.com/karlseguin/5128461 +type FakeResponse struct { + headers http.Header + body []byte + status int +} + +func (f FakeResponse) Write(body []byte) (int, error) { + f.body = body + return len(body), nil +} + +func (f FakeResponse) WriteHeader(status int) { + f.status = status +} + +func (f FakeResponse) Header() http.Header { + return f.headers +} + +// TestNew tests the New function +func TestNew(t *testing.T) { + var tests = []struct { + input Options + expected Service + }{ + {Options{}, Service{options: Options{CookieName: DefaultCookieName, CookiePath: DefaultCookiePath}}}, + {testOptions, Service{options: testOptions}}, + } + + for idx, tt := range tests { + s := New(tt.input) + if s == nil { + s = &Service{} + } + assert := reflect.DeepEqual(tt.expected, *s) + + if !assert { + t.Errorf("test #%d failed; assert: %t, input: %v, expected: %v, received: %v", idx+1, assert, tt.input, tt.expected, *s) + } + } +} + +// TestSetSessionOnResponse tests the SetSessionOnResponse function +func TestSetSessionOnResponse(t *testing.T) { + u := user.New("testID", "", 1*time.Second) + m := make(map[string][]string, 1) + b := make([]byte, 1) + f := FakeResponse{m, b, 0} + var tests = []struct { + signedSessionID string + userSession *user.Session + w FakeResponse + }{ + {"testSignedSessionID", u, f}, + } + + for idx, tt := range tests { + expectedW := FakeResponse{m, b, 0} + sessionCookie := http.Cookie{ + Name: testService.options.CookieName, + Value: tt.signedSessionID, + Expires: tt.userSession.ExpiresAt, + Path: testService.options.CookiePath, + HttpOnly: testService.options.HTTPOnly, + Secure: testService.options.Secure, + } + http.SetCookie(expectedW, &sessionCookie) + + _ = testService.SetSessionOnResponse(tt.signedSessionID, tt.userSession, tt.w) + + assert := reflect.DeepEqual(tt.w, expectedW) + if !assert { + t.Errorf("test #%d failed; assert: %t, w: %v, expected: %v", idx+1, assert, tt.w, expectedW) + } + } +} + +// TestDeleteSessionFromResponse tests the DeleteSessionFromResponse function +func TestDeleteSessionFromResponse(t *testing.T) { + m := make(map[string][]string, 1) + b := make([]byte, 1) + f := FakeResponse{m, b, 0} + var tests = []struct { + w FakeResponse + }{ + {f}, + } + + for idx, tt := range tests { + expectedW := FakeResponse{m, b, 0} + aLongTimeAgo := time.Now().Add(-1000 * time.Hour) + nullSessionCookie := http.Cookie{ + Name: testService.options.CookieName, + Value: "", + Expires: aLongTimeAgo, + Path: testService.options.CookiePath, + HttpOnly: testService.options.HTTPOnly, + Secure: testService.options.Secure, + } + http.SetCookie(expectedW, &nullSessionCookie) + + _ = testService.DeleteSessionFromResponse(tt.w) + + assert := reflect.DeepEqual(tt.w, expectedW) + if !assert { + t.Errorf("test #%d failed; assert: %t, w: %v, expected: %v", idx+1, assert, tt.w, expectedW) + } + } +} + +// TestFetchSessionIDFromRequest tests the FetchSessionIDFromRequest function +func TestFetchSessionIDFromRequest(t *testing.T) { + var tests = []struct { + input http.Cookie + expectedString string + expectedErr sessionerrs.Custom + }{ + {http.Cookie{Name: testService.options.CookieName, Value: "testValue"}, "testValue", sessionerrs.Custom{}}, + {http.Cookie{Name: "badName"}, "", sessionerrs.Custom{Code: 401, Err: errors.New("no session on request")}}, + } + + for idx, tt := range tests { + m := make(map[string][]string, 1) + r := http.Request{Header: m} + r.AddCookie(&tt.input) + + s, e := testService.FetchSessionIDFromRequest(&r) + assert := true + if e != nil { + assert = tt.expectedErr.Err.Error() == e.Err.Error() + } else { + e = &sessionerrs.Custom{} + } + + if !assert || tt.expectedString != s { + t.Errorf("test #%d failed; assert: %t, expectedErr: %v, err: %v, expectedString: %s, string: %s", idx+1, assert, tt.expectedErr, e, tt.expectedString, s) + } + } +} diff --git a/transport/service_util_unit_test.go b/transport/service_util_unit_test.go new file mode 100644 index 0000000..024209a --- /dev/null +++ b/transport/service_util_unit_test.go @@ -0,0 +1,28 @@ +// +build unit + +package transport + +import ( + "reflect" + "testing" +) + +// TestSetDefaultOptions tests the setDefaultOptions function +func TestSetDefaultOptions(t *testing.T) { + var tests = []struct { + input Options + expected Options + }{ + {Options{}, Options{CookieName: DefaultCookieName, CookiePath: DefaultCookiePath}}, + {Options{CookieName: "test", CookiePath: "/", HTTPOnly: true, Secure: true}, Options{CookieName: "test", CookiePath: "/", HTTPOnly: true, Secure: true}}, + } + + for idx, tt := range tests { + setDefaultOptions(&tt.input) + assert := reflect.DeepEqual(tt.expected, tt.input) + + if !assert { + t.Errorf("test #%d failed; assert: %t, expected: %v, received: %v", idx+1, assert, tt.expected, tt.input) + } + } +} diff --git a/user/user_unit_test.go b/user/user_unit_test.go new file mode 100644 index 0000000..47281de --- /dev/null +++ b/user/user_unit_test.go @@ -0,0 +1,55 @@ +// +build unit + +package user + +import ( + "strings" + "testing" + "time" +) + +// TestNew tests the New func +func TestNew(t *testing.T) { + var tests = []struct { + inputUserID string + inputJSON string + inputDuration time.Duration + expected Session + }{ + { + "testID", + "testJSON", + 10 * time.Second, + Session{ + ID: "", // note: tested elsewhere + UserID: "testID", + ExpiresAt: time.Now().Add(10 * time.Second).UTC(), + JSON: "testJSON", + }, + }, + } + + for idx, tt := range tests { + a := New(tt.inputUserID, tt.inputJSON, tt.inputDuration) + if a == nil { + a = &Session{} + } + + if a.UserID != tt.inputUserID || a.JSON != tt.inputJSON || !testSessionID(a.ID) || !testExpiresAt(tt.inputDuration, a.ExpiresAt) { + t.Errorf("test #%d failed; inputUserID: %s, inputJSON: %s, inputDuration: %v, expected session: %v, received session: %v", idx+1, tt.inputUserID, tt.inputJSON, tt.inputDuration, tt.expected, *a) + } + } +} + +func testExpiresAt(inputDuration time.Duration, actualExpiresAt time.Time) bool { + t1 := time.Now().Add(inputDuration).UTC() + return actualExpiresAt.Sub(t1) < 1*time.Second +} + +func testSessionID(sessionID string) bool { + // eg 5f4cd331-c869-4871-bb41-76b726df9937 + parts := strings.Split(sessionID, "-") + return len([]byte(sessionID)) == 36 && len([]rune(sessionID)) == 36 && strings.Count(sessionID, "-") == 4 && + len(parts) == 5 && len([]rune(parts[0])) == 8 && len([]rune(parts[1])) == 4 && len([]rune(parts[2])) == 4 && + len([]rune(parts[3])) == 4 && len([]rune(parts[4])) == 12 +}