From 9f1216c4f33460cfcecd819348dc3e4d0c667ee0 Mon Sep 17 00:00:00 2001 From: Mikolaj Gasior Date: Wed, 1 Jan 2025 23:28:26 +0100 Subject: [PATCH] Return `error` not `*ErrUmbrella` + add a sample app (#3) * Add example application --- .gitignore | 1 + Makefile | 23 ++++++++++++++ README.md | 59 ++++++++++++------------------------ cmd/example1/main.go | 72 ++++++++++++++++++++++++++++++++++++++++++++ err.go | 4 +-- internal.go | 60 ++++++++++++++++++------------------ umbrella.go | 28 ++++++++--------- version.go | 2 +- 8 files changed, 163 insertions(+), 86 deletions(-) create mode 100644 .gitignore create mode 100644 Makefile create mode 100644 cmd/example1/main.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..afb0a4d --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +/cmd/example1/example1 diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..6b9dbb3 --- /dev/null +++ b/Makefile @@ -0,0 +1,23 @@ +.DEFAULT_GOAL := help + +.PHONY: help test + +test: ## Runs tests + go test + +run-example1: clean ## Runs sample app + @echo "* Creating docker container with PostgreSQL" + docker run --name umbrella-example1 -d -e POSTGRES_PASSWORD=upass -e POSTGRES_USER=uuser -e POSTGRES_DB=udb -p 54321:5432 postgres:13 + @echo "* Sleeping for 10 seconds to give database time to initialize..." + @sleep 10 + @echo "* Building and starting application..." + cd cmd/example1 && go build . + cd cmd/example1 && ./example1 + @echo "* Removing previously created docker container..." + + +clean: ## Removes all created dockers + docker rm -f umbrella-example1 + +help: ## Displays this help + @awk 'BEGIN {FS = ":.*##"; printf "$(MAKEFILE_NAME)\n\nUsage:\n make \033[1;36m\033[0m\n\nTargets:\n"} /^[a-zA-Z0-9_-]+:.*?##/ { printf " \033[1;36m%-25s\033[0m %s\n", $$1, $$2 }' $(MAKEFILE_LIST) diff --git a/README.md b/README.md index 780edb7..096f79a 100644 --- a/README.md +++ b/README.md @@ -15,45 +15,26 @@ Package umbrella provides a simple authentication mechanism for an HTTP endpoint 5. [Motivation](#motivation) ## Sample code -The following code snippet shows how the module can be used. - -```go -// database connection -dbConn, _ = sql.Open("postgres", "host=localhost user=myuser password=mypass port=5432 dbname=mydb sslmode=disable") - -// umbrella controller -u := NewUmbrella(dbConn, "tblprefix_", &JWTConfig{ - Key: "SomeSecretKey--.", - Issuer: "SomeIssuer", - ExpirationMinutes: 15, -}, nil) - -// create db tables -_ := u.CreateDBTables() - -// http server -// uri with registration, activation, login (returns auth token), logout endpoint -http.Handle("/umbrella/", u.GetHTTPHandler("/umbrella/")) -// restricted stuff that requires signing in (a token in http header) -http.Handle("/restricted_stuff/", u.GetHTTPHandlerWrapper( - getRestrictedStuffHTTPHandler(), - umbrella.HandlerConfig{}, -)) -http.ListenAndServe(":8001", nil) - -// wrap http handler with a check for logged user -func getRestrictedStuffHTTPHandler() http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - userID := umbrella.GetUserIDFromRequest(r) - if userID != 0 { - w.WriteHeader(http.StatusOK) - w.Write([]byte("RestrictedAreaContent")) - } else { - w.WriteHeader(http.StatusUnauthorized) - w.Write([]byte("NoAccess")) - } - }) -} +A working application can be found in the `cmd/example1`. Type `make run-example1` to start an HTTP server and check the endpoints as shown below. jq is used to parse out the token from JSON output, however, it can be done manually as well. + +```bash +# run the application +% make run-example1 +# ...some output... + +# sign in to get a token +% UMB_TOKEN=$(curl -s -X POST -d "email=admin@example.com&password=admin" http://localhost:8001/umbrella/login | jq -r '.data.token') + +# call restricted endpoint without the token +% curl http://localhost:8001/secret_stuff/ +YouHaveToBeLoggedIn + +# call it again with token +% curl -H "Authorization: Bearer $UMB_TOKEN" http://localhost:8001/secret_stuff/ +SecretStuffOnlyForAdmin% + +# remove temporary postgresql docker +make clean ``` ## Database connection diff --git a/cmd/example1/main.go b/cmd/example1/main.go new file mode 100644 index 0000000..55ef97d --- /dev/null +++ b/cmd/example1/main.go @@ -0,0 +1,72 @@ +package main + +import ( + "database/sql" + "log" + "net/http" + + "github.com/go-phings/umbrella" + _ "github.com/lib/pq" +) + +const dbDSN = "host=localhost user=uuser password=upass port=54321 dbname=udb sslmode=disable" +const tblPrefix = "p_" + +func main() { + db, err := sql.Open("postgres", dbDSN) + if err != nil { + log.Fatal("Error connecting to db") + } + + // create umbrella controller + u := *umbrella.NewUmbrella(db, tblPrefix, &umbrella.JWTConfig{ + Key: "someSecretKey", + Issuer: "someIssuer", + ExpirationMinutes: 15, + }, &umbrella.UmbrellaConfig{ + TagName: "ui", + }) + + // create database tables + err = u.CreateDBTables() + if err != nil { + log.Fatalf("error creating database tables: %s", err.Error()) + } + + // create admin user + key, err := u.CreateUser("admin@example.com", "admin", map[string]string{ + "Name": "admin", + }) + if err != nil { + log.Fatalf("error with creating admin: %s", err.Error()) + } + err = u.ConfirmEmail(key) + if err != nil { + log.Fatalf("error with confirming admin email: %s", err.Error()) + } + + // /umbrella/{login,logout,register,confirm} + http.Handle("/umbrella/", u.GetHTTPHandler("/umbrella/")) + + // secret stuff + http.Handle("/secret_stuff/", u.GetHTTPHandlerWrapper(secretStuff(), umbrella.HandlerConfig{})) + + log.Fatal(http.ListenAndServe(":8001", nil)) +} + +func secretStuff() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + userID := umbrella.GetUserIDFromRequest(r) + switch userID { + case 1: + w.WriteHeader(http.StatusOK) + w.Write([]byte("SecretStuffOnlyForAdmin")) + case 0: + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte("YouHaveToBeLoggedIn")) + default: + w.WriteHeader(http.StatusOK) + w.Write([]byte("SecretStuffForOtherUser")) + } + }) +} diff --git a/err.go b/err.go index 7c6df01..62d9aeb 100644 --- a/err.go +++ b/err.go @@ -7,10 +7,10 @@ type ErrUmbrella struct { Err error } -func (e *ErrUmbrella) Error() string { +func (e ErrUmbrella) Error() string { return e.Err.Error() } -func (e *ErrUmbrella) Unwrap() error { +func (e ErrUmbrella) Unwrap() error { return e.Err } diff --git a/internal.go b/internal.go index fc40303..df353be 100644 --- a/internal.go +++ b/internal.go @@ -85,7 +85,7 @@ func (u Umbrella) handleConfirm(w http.ResponseWriter, r *http.Request) { err2 := u.ConfirmEmail(key) if err2 != nil { - var errUmb *ErrUmbrella + var errUmb ErrUmbrella if errors.As(err2, &errUmb) { if errUmb.Op == "NoRow" || errUmb.Op == "UserInactive" { u.writeErrText(w, http.StatusNotFound, "invalid_key") @@ -146,7 +146,7 @@ func (u Umbrella) handleLogin(w http.ResponseWriter, r *http.Request, setCookie token, expiresAt, err := u.login(email, password) if err != nil { - var errUmb *ErrUmbrella + var errUmb ErrUmbrella if errors.As(err, &errUmb) { if errUmb.Op == "NoRow" || errUmb.Op == "UserInactive" || errUmb.Op == "InvalidPassword" { if failureURI != "" { @@ -216,7 +216,7 @@ func (u Umbrella) handleCheck(w http.ResponseWriter, r *http.Request) { token2, expiresAt, _, err := u.check(token, refresh) if err != nil { - var errUmb *ErrUmbrella + var errUmb ErrUmbrella if errors.As(err, &errUmb) { if errUmb.Op == "InvalidToken" || errUmb.Op == "UserInactive" || errUmb.Op == "Expired" || errUmb.Op == "InvalidSession" || errUmb.Op == "InvalidUser" || errUmb.Op == "ParseToken" { u.writeErrText(w, http.StatusNotFound, "invalid_credentials") @@ -297,7 +297,7 @@ func (u Umbrella) handleLogout(w http.ResponseWriter, r *http.Request, useCookie err := u.logout(token) if err != nil { - var errUmb *ErrUmbrella + var errUmb ErrUmbrella if errors.As(err, &errUmb) { if errUmb.Op == "InvalidToken" || errUmb.Op == "Expired" || errUmb.Op == "ParseToken" || errUmb.Op == "InvalidSession" { if successURI != "" { @@ -369,26 +369,26 @@ func (u Umbrella) writeOK(w http.ResponseWriter, status int, data map[string]int } } -func (u Umbrella) login(email string, password string) (string, int64, *ErrUmbrella) { +func (u Umbrella) login(email string, password string) (string, int64, error) { user := u.Interfaces.User() got, err := user.GetByEmail(email) if !got { if err == nil { - return "", 0, &ErrUmbrella{ + return "", 0, ErrUmbrella{ Op: "NoRow", Err: nil, } } if err != nil { - return "", 0, &ErrUmbrella{ + return "", 0, ErrUmbrella{ Op: "GetFromDB", Err: err, } } } if user.GetFlags()&FlagUserActive == 0 || user.GetFlags()&FlagUserAllowLogin == 0 { - return "", 0, &ErrUmbrella{ + return "", 0, ErrUmbrella{ Op: "UserInactive", Err: err, } @@ -396,14 +396,14 @@ func (u Umbrella) login(email string, password string) (string, int64, *ErrUmbre passwordInDBDecoded, err := base64.StdEncoding.DecodeString(user.GetPassword()) if err != nil { - return "", 0, &ErrUmbrella{ + return "", 0, ErrUmbrella{ Op: "InvalidPassword", Err: err, } } err = bcrypt.CompareHashAndPassword(passwordInDBDecoded, []byte(password)) if err != nil { - return "", 0, &ErrUmbrella{ + return "", 0, ErrUmbrella{ Op: "InvalidPassword", Err: err, } @@ -412,7 +412,7 @@ func (u Umbrella) login(email string, password string) (string, int64, *ErrUmbre sUUID := uuid.New().String() token, expiresAt, err := u.createToken(sUUID) if err != nil { - return "", 0, &ErrUmbrella{ + return "", 0, ErrUmbrella{ Op: "CreateToken", Err: err, } @@ -429,7 +429,7 @@ func (u Umbrella) login(email string, password string) (string, int64, *ErrUmbre errCrud := u.orm.Save(sess) if errCrud != nil { - return "", 0, &ErrUmbrella{ + return "", 0, ErrUmbrella{ Op: "SaveToDB", Err: errCrud, } @@ -451,7 +451,7 @@ func (u *Umbrella) getSession(key string) (*Session, error) { return sessions[0].(*Session), nil } -func (u Umbrella) logout(token string) *ErrUmbrella { +func (u Umbrella) logout(token string) error { sID, errUmbrella := u.parseTokenWithCheck(token) if errUmbrella != nil { return errUmbrella @@ -460,13 +460,13 @@ func (u Umbrella) logout(token string) *ErrUmbrella { session, err := u.getSession(sID) if session == nil { if err == nil { - return &ErrUmbrella{ + return ErrUmbrella{ Op: "NoRow", Err: nil, } } if err != nil { - return &ErrUmbrella{ + return ErrUmbrella{ Op: "GetFromDB", Err: err, } @@ -474,7 +474,7 @@ func (u Umbrella) logout(token string) *ErrUmbrella { } if session.Flags&FlagSessionActive == 0 || session.Flags&FlagSessionLoggedOut > 0 { - return &ErrUmbrella{ + return ErrUmbrella{ Op: "InvalidSession", Err: err, } @@ -486,7 +486,7 @@ func (u Umbrella) logout(token string) *ErrUmbrella { } errCrud := u.orm.Save(session) if errCrud != nil { - return &ErrUmbrella{ + return ErrUmbrella{ Op: "SaveToDB", Err: errCrud, } @@ -495,7 +495,7 @@ func (u Umbrella) logout(token string) *ErrUmbrella { return nil } -func (u Umbrella) check(token string, refresh bool) (string, int64, int64, *ErrUmbrella) { +func (u Umbrella) check(token string, refresh bool) (string, int64, int64, error) { sID, errUmbrella := u.parseTokenWithCheck(token) if errUmbrella != nil { return "", 0, 0, errUmbrella @@ -504,13 +504,13 @@ func (u Umbrella) check(token string, refresh bool) (string, int64, int64, *ErrU session, err := u.getSession(sID) if session == nil { if err == nil { - return "", 0, 0, &ErrUmbrella{ + return "", 0, 0, ErrUmbrella{ Op: "InvalidSession", Err: nil, } } if err != nil { - return "", 0, 0, &ErrUmbrella{ + return "", 0, 0, ErrUmbrella{ Op: "GetFromDB", Err: err, } @@ -518,7 +518,7 @@ func (u Umbrella) check(token string, refresh bool) (string, int64, int64, *ErrU } if session.Flags&FlagSessionActive == 0 || session.Flags&FlagSessionLoggedOut > 0 { - return "", 0, 0, &ErrUmbrella{ + return "", 0, 0, ErrUmbrella{ Op: "InvalidSession", Err: err, } @@ -528,20 +528,20 @@ func (u Umbrella) check(token string, refresh bool) (string, int64, int64, *ErrU got, err := user.GetByID(session.UserID) if !got { if err == nil { - return "", 0, 0, &ErrUmbrella{ + return "", 0, 0, ErrUmbrella{ Op: "InvalidUser", Err: err, } } if err != nil { - return "", 0, 0, &ErrUmbrella{ + return "", 0, 0, ErrUmbrella{ Op: "GetFromDB", Err: err, } } } if user.GetFlags()&FlagUserActive == 0 || user.GetFlags()&FlagUserAllowLogin == 0 { - return "", 0, 0, &ErrUmbrella{ + return "", 0, 0, ErrUmbrella{ Op: "UserInactive", Err: nil, } @@ -550,7 +550,7 @@ func (u Umbrella) check(token string, refresh bool) (string, int64, int64, *ErrU if refresh { token2, expiresAt, err := u.createToken(sID) if err != nil { - return "", 0, 0, &ErrUmbrella{ + return "", 0, 0, ErrUmbrella{ Op: "CreateToken", Err: err, } @@ -559,7 +559,7 @@ func (u Umbrella) check(token string, refresh bool) (string, int64, int64, *ErrU session.ExpiresAt = expiresAt errCrud := u.orm.Save(session) if errCrud != nil { - return "", 0, 0, &ErrUmbrella{ + return "", 0, 0, ErrUmbrella{ Op: "SaveToDB", Err: errCrud, } @@ -570,24 +570,24 @@ func (u Umbrella) check(token string, refresh bool) (string, int64, int64, *ErrU return token, 0, session.UserID, nil } -func (u Umbrella) parseTokenWithCheck(token string) (string, *ErrUmbrella) { +func (u Umbrella) parseTokenWithCheck(token string) (string, error) { sID, expired, err := u.parseToken(token) if err != nil { - return "", &ErrUmbrella{ + return "", ErrUmbrella{ Op: "ParseToken", Err: err, } } if expired { - return "", &ErrUmbrella{ + return "", ErrUmbrella{ Op: "Expired", Err: err, } } if !u.isValidSessionID(sID) { - return "", &ErrUmbrella{ + return "", ErrUmbrella{ Op: "InvalidSession", Err: err, } diff --git a/umbrella.go b/umbrella.go index 00b73c4..df92179 100644 --- a/umbrella.go +++ b/umbrella.go @@ -119,11 +119,11 @@ func NewUmbrella(dbConn *sql.DB, tblPrefix string, jwtConfig *JWTConfig, cfg *Um return u } -func (u Umbrella) CreateDBTables() *ErrUmbrella { +func (u Umbrella) CreateDBTables() error { user := u.Interfaces.User() err := user.CreateDBTable() if err != nil { - return &ErrUmbrella{ + return ErrUmbrella{ Op: "CreateDBTables", Err: err, } @@ -131,7 +131,7 @@ func (u Umbrella) CreateDBTables() *ErrUmbrella { err2 := u.orm.CreateTables(&Session{}) if err2 != nil { - return &ErrUmbrella{ + return ErrUmbrella{ Op: "CreateDBTables", Err: err2, } @@ -139,7 +139,7 @@ func (u Umbrella) CreateDBTables() *ErrUmbrella { err2 = u.orm.CreateTables(&Permission{}) if err2 != nil { - return &ErrUmbrella{ + return ErrUmbrella{ Op: "CreateDBTables", Err: err2, } @@ -250,10 +250,10 @@ func (u Umbrella) getURIFromRequest(r *http.Request, uri string) string { return xs[0] } -func (u Umbrella) GeneratePassword(pass string) (string, *ErrUmbrella) { +func (u Umbrella) GeneratePassword(pass string) (string, error) { passEncrypted, err := bcrypt.GenerateFromPassword([]byte(pass), bcrypt.DefaultCost) if err != nil { - return "", &ErrUmbrella{ + return "", ErrUmbrella{ Op: "GeneratePassword", Err: err, } @@ -262,7 +262,7 @@ func (u Umbrella) GeneratePassword(pass string) (string, *ErrUmbrella) { return base64.StdEncoding.EncodeToString(passEncrypted), nil } -func (u Umbrella) CreateUser(email string, pass string, extraFields map[string]string) (string, *ErrUmbrella) { +func (u Umbrella) CreateUser(email string, pass string, extraFields map[string]string) (string, error) { pass, err := u.GeneratePassword(pass) if err != nil { return "", err @@ -290,7 +290,7 @@ func (u Umbrella) CreateUser(email string, pass string, extraFields map[string]s err2 := user.Save() if err2 != nil { - return "", &ErrUmbrella{ + return "", ErrUmbrella{ Op: "SaveToDB", Err: err2, } @@ -299,26 +299,25 @@ func (u Umbrella) CreateUser(email string, pass string, extraFields map[string]s return key, nil } -func (u Umbrella) ConfirmEmail(key string) *ErrUmbrella { +func (u Umbrella) ConfirmEmail(key string) error { user := u.Interfaces.User() got, err := user.GetByEmailActivationKey(key) - if !got { if err == nil { - return &ErrUmbrella{ + return ErrUmbrella{ Op: "NoRow", Err: nil, } } if err != nil { - return &ErrUmbrella{ + return ErrUmbrella{ Op: "GetFromDB", Err: err, } } } if user.GetFlags()&FlagUserActive == 0 { - return &ErrUmbrella{ + return ErrUmbrella{ Op: "UserInactive", Err: err, } @@ -328,7 +327,8 @@ func (u Umbrella) ConfirmEmail(key string) *ErrUmbrella { user.SetEmailActivationKey("") err = user.Save() if err != nil { - return &ErrUmbrella{ + log.Print("save to db err") + return ErrUmbrella{ Op: "SaveToDB", Err: err, } diff --git a/version.go b/version.go index b6bc66f..7601a47 100644 --- a/version.go +++ b/version.go @@ -1,3 +1,3 @@ package umbrella -const VERSION = "0.8.0" +const VERSION = "0.8.1"