From 5b89f47d2a7aa2a27ec29477fc28c8a5d25aa3af Mon Sep 17 00:00:00 2001 From: eizyc Date: Mon, 24 Jun 2024 12:35:43 -0500 Subject: [PATCH] feat(auth): add auth middleware to API --- api/account.go | 15 +++++- api/account_test.go | 93 +++++++++++++++++++++++++++------- api/middleware.go | 54 ++++++++++++++++++++ api/middleware_test.go | 108 ++++++++++++++++++++++++++++++++++++++++ api/server.go | 9 ++-- api/transfer.go | 10 +++- db/query/account.sql | 5 +- db/sqlc/account.sql.go | 12 +++-- db/sqlc/account_test.go | 7 ++- 9 files changed, 280 insertions(+), 33 deletions(-) create mode 100644 api/middleware.go create mode 100644 api/middleware_test.go diff --git a/api/account.go b/api/account.go index d5b5b26..88e5b83 100644 --- a/api/account.go +++ b/api/account.go @@ -2,17 +2,18 @@ package api import ( "database/sql" + "errors" "log" "net/http" db "github.com/eizyc/simplebank/db/sqlc" + "github.com/eizyc/simplebank/token" "github.com/jackc/pgx/v5/pgconn" "github.com/gin-gonic/gin" ) type createAccountRequest struct { - Owner string `json:"owner" binding:"required"` Currency string `json:"currency" binding:"required,currency"` } @@ -23,8 +24,9 @@ func (server *Server) createAccount(ctx *gin.Context) { return } + authPayload := ctx.MustGet(authorizationPayloadKey).(*token.Payload) arg := db.CreateAccountParams{ - Owner: req.Owner, + Owner: authPayload.Username, Currency: req.Currency, Balance: 0, } @@ -69,6 +71,13 @@ func (server *Server) getAccount(ctx *gin.Context) { return } + authPayload := ctx.MustGet(authorizationPayloadKey).(*token.Payload) + if account.Owner != authPayload.Username { + err := errors.New("account doesn't belong to the authenticated user") + ctx.JSON(http.StatusUnauthorized, errorResponse(err)) + return + } + ctx.JSON(http.StatusOK, account) } @@ -84,7 +93,9 @@ func (server *Server) listAccounts(ctx *gin.Context) { return } + authPayload := ctx.MustGet(authorizationPayloadKey).(*token.Payload) arg := db.ListAccountsParams{ + Owner: authPayload.Username, Limit: req.PageSize, Offset: (req.PageID - 1) * req.PageSize, } diff --git a/api/account_test.go b/api/account_test.go index 5feefc8..c9b6d94 100644 --- a/api/account_test.go +++ b/api/account_test.go @@ -9,9 +9,11 @@ import ( "net/http" "net/http/httptest" "testing" + "time" mockdb "github.com/eizyc/simplebank/db/mock" db "github.com/eizyc/simplebank/db/sqlc" + "github.com/eizyc/simplebank/token" "github.com/eizyc/simplebank/util" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" @@ -19,17 +21,22 @@ import ( ) func TestGetAccountAPI(t *testing.T) { - account := randomAccount() + user, _ := randomUser(t) + account := randomAccount(user.Username) testCases := []struct { name string accountID int64 + setupAuth func(t *testing.T, request *http.Request, tokenMaker token.Maker) buildStubs func(store *mockdb.MockStore) checkResponse func(t *testing.T, recoder *httptest.ResponseRecorder) }{ { name: "OK", accountID: account.ID, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, time.Minute) + }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT(). GetAccount(gomock.Any(), gomock.Eq(account.ID)). @@ -41,9 +48,42 @@ func TestGetAccountAPI(t *testing.T) { requireBodyMatchAccount(t, recorder.Body, account) }, }, + { + name: "UnauthorizedUser", + accountID: account.ID, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, "unauthorized_user", time.Minute) + }, + buildStubs: func(store *mockdb.MockStore) { + store.EXPECT(). + GetAccount(gomock.Any(), gomock.Eq(account.ID)). + Times(1). + Return(account, nil) + }, + checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { + require.Equal(t, http.StatusUnauthorized, recorder.Code) + }, + }, + { + name: "NoAuthorization", + accountID: account.ID, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + }, + buildStubs: func(store *mockdb.MockStore) { + store.EXPECT(). + GetAccount(gomock.Any(), gomock.Any()). + Times(0) + }, + checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { + require.Equal(t, http.StatusUnauthorized, recorder.Code) + }, + }, { name: "NotFound", accountID: account.ID, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, time.Minute) + }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT(). GetAccount(gomock.Any(), gomock.Eq(account.ID)). @@ -57,6 +97,9 @@ func TestGetAccountAPI(t *testing.T) { { name: "InternalError", accountID: account.ID, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, time.Minute) + }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT(). GetAccount(gomock.Any(), gomock.Eq(account.ID)). @@ -70,6 +113,9 @@ func TestGetAccountAPI(t *testing.T) { { name: "InvalidID", accountID: 0, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, time.Minute) + }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT(). GetAccount(gomock.Any(), gomock.Any()). @@ -98,6 +144,7 @@ func TestGetAccountAPI(t *testing.T) { request, err := http.NewRequest(http.MethodGet, url, nil) require.NoError(t, err) + tc.setupAuth(t, request, server.tokenMaker) server.router.ServeHTTP(recorder, request) tc.checkResponse(t, recorder) }) @@ -105,20 +152,24 @@ func TestGetAccountAPI(t *testing.T) { } func TestCreateAccountAPI(t *testing.T) { - account := randomAccount() + user, _ := randomUser(t) + account := randomAccount(user.Username) testCases := []struct { name string body gin.H + setupAuth func(t *testing.T, request *http.Request, tokenMaker token.Maker) buildStubs func(store *mockdb.MockStore) checkResponse func(recoder *httptest.ResponseRecorder) }{ { name: "OK", body: gin.H{ - "owner": account.Owner, "currency": account.Currency, }, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, time.Minute) + }, buildStubs: func(store *mockdb.MockStore) { arg := db.CreateAccountParams{ Owner: account.Owner, @@ -137,41 +188,46 @@ func TestCreateAccountAPI(t *testing.T) { }, }, { - name: "InternalError", + name: "NoAuthorization", body: gin.H{ - "owner": account.Owner, "currency": account.Currency, }, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT(). CreateAccount(gomock.Any(), gomock.Any()). - Times(1). - Return(db.Account{}, sql.ErrConnDone) + Times(0) }, checkResponse: func(recorder *httptest.ResponseRecorder) { - require.Equal(t, http.StatusInternalServerError, recorder.Code) + require.Equal(t, http.StatusUnauthorized, recorder.Code) }, }, { - name: "InvalidCurrency", + name: "InternalError", body: gin.H{ - "owner": account.Owner, - "currency": "invalid", + "currency": account.Currency, + }, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, time.Minute) }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT(). CreateAccount(gomock.Any(), gomock.Any()). - Times(0) + Times(1). + Return(db.Account{}, sql.ErrConnDone) }, checkResponse: func(recorder *httptest.ResponseRecorder) { - require.Equal(t, http.StatusBadRequest, recorder.Code) + require.Equal(t, http.StatusInternalServerError, recorder.Code) }, }, { - name: "InvalidOwner", + name: "InvalidCurrency", body: gin.H{ - "owner": "", - "currency": account.Currency, + "currency": "invalid", + }, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, time.Minute) }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT(). @@ -204,16 +260,17 @@ func TestCreateAccountAPI(t *testing.T) { request, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(data)) require.NoError(t, err) + tc.setupAuth(t, request, server.tokenMaker) server.router.ServeHTTP(recorder, request) tc.checkResponse(recorder) }) } } -func randomAccount() db.Account { +func randomAccount(owner string) db.Account { return db.Account{ ID: util.RandomInt(1, 1000), - Owner: util.RandomOwner(), + Owner: owner, Balance: util.RandomMoney(), Currency: util.RandomCurrency(), } diff --git a/api/middleware.go b/api/middleware.go new file mode 100644 index 0000000..2c17931 --- /dev/null +++ b/api/middleware.go @@ -0,0 +1,54 @@ +package api + +import ( + "errors" + "fmt" + "net/http" + "strings" + + "github.com/eizyc/simplebank/token" + "github.com/gin-gonic/gin" +) + +const ( + authorizationHeaderKey = "authorization" + authorizationTypeBearer = "bearer" + authorizationPayloadKey = "authorization_payload" +) + +// AuthMiddleware creates a gin middleware for authorization +func authMiddleware(tokenMaker token.Maker) gin.HandlerFunc { + return func(ctx *gin.Context) { + authorizationHeader := ctx.GetHeader(authorizationHeaderKey) + + if len(authorizationHeader) == 0 { + err := errors.New("authorization header is not provided") + ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err)) + return + } + + fields := strings.Fields(authorizationHeader) + if len(fields) < 2 { + err := errors.New("invalid authorization header format") + ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err)) + return + } + + authorizationType := strings.ToLower(fields[0]) + if authorizationType != authorizationTypeBearer { + err := fmt.Errorf("unsupported authorization type %s", authorizationType) + ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err)) + return + } + + accessToken := fields[1] + payload, err := tokenMaker.VerifyToken(accessToken) + if err != nil { + ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err)) + return + } + + ctx.Set(authorizationPayloadKey, payload) + ctx.Next() + } +} diff --git a/api/middleware_test.go b/api/middleware_test.go new file mode 100644 index 0000000..4a2ca54 --- /dev/null +++ b/api/middleware_test.go @@ -0,0 +1,108 @@ +package api + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/eizyc/simplebank/token" + "github.com/eizyc/simplebank/util" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func addAuthorization( + t *testing.T, + request *http.Request, + tokenMaker token.Maker, + authorizationType string, + username string, + duration time.Duration, +) { + token, err := tokenMaker.CreateToken(username, duration) + require.NoError(t, err) + + authorizationHeader := fmt.Sprintf("%s %s", authorizationType, token) + request.Header.Set(authorizationHeaderKey, authorizationHeader) +} + +func TestAuthMiddleware(t *testing.T) { + username := util.RandomOwner() + + testCases := []struct { + name string + setupAuth func(t *testing.T, request *http.Request, tokenMaker token.Maker) + checkResponse func(t *testing.T, recorder *httptest.ResponseRecorder) + }{ + { + name: "OK", + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, username, time.Minute) + }, + checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { + require.Equal(t, http.StatusOK, recorder.Code) + }, + }, + { + name: "NoAuthorization", + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + }, + checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { + require.Equal(t, http.StatusUnauthorized, recorder.Code) + }, + }, + { + name: "UnsupportedAuthorization", + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, "unsupported", username, time.Minute) + }, + checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { + require.Equal(t, http.StatusUnauthorized, recorder.Code) + }, + }, + { + name: "InvalidAuthorizationFormat", + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, "", username, time.Minute) + }, + checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { + require.Equal(t, http.StatusUnauthorized, recorder.Code) + }, + }, + { + name: "ExpiredToken", + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, username, -time.Minute) + }, + checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { + require.Equal(t, http.StatusUnauthorized, recorder.Code) + }, + }, + } + + for i := range testCases { + tc := testCases[i] + + t.Run(tc.name, func(t *testing.T) { + server := newTestServer(t, nil) + authPath := "/auth" + server.router.GET( + authPath, + authMiddleware(server.tokenMaker), + func(ctx *gin.Context) { + ctx.JSON(http.StatusOK, gin.H{}) + }, + ) + + recorder := httptest.NewRecorder() + request, err := http.NewRequest(http.MethodGet, authPath, nil) + require.NoError(t, err) + + tc.setupAuth(t, request, server.tokenMaker) + server.router.ServeHTTP(recorder, request) + tc.checkResponse(t, recorder) + }) + } +} diff --git a/api/server.go b/api/server.go index 84d9dbe..879a7cf 100644 --- a/api/server.go +++ b/api/server.go @@ -49,11 +49,12 @@ func (server *Server) setupRouter() { router.POST("/users", server.createUser) router.POST("/users/login", server.loginUser) - router.POST("/accounts", server.createAccount) - router.GET("/accounts/:id", server.getAccount) - router.GET("/accounts", server.listAccounts) + authRoutes := router.Group("/").Use(authMiddleware(server.tokenMaker)) + authRoutes.POST("/accounts", server.createAccount) + authRoutes.GET("/accounts/:id", server.getAccount) + authRoutes.GET("/accounts", server.listAccounts) - router.POST("/transfers", server.createTransfer) + authRoutes.POST("/transfers", server.createTransfer) server.router = router } diff --git a/api/transfer.go b/api/transfer.go index 3c3180c..c769fad 100644 --- a/api/transfer.go +++ b/api/transfer.go @@ -6,6 +6,7 @@ import ( "net/http" db "github.com/eizyc/simplebank/db/sqlc" + "github.com/eizyc/simplebank/token" "github.com/gin-gonic/gin" ) @@ -24,11 +25,18 @@ func (server *Server) createTransfer(ctx *gin.Context) { return } - _, valid := server.validAccount(ctx, req.FromAccountID, req.Currency) + fromAccount, valid := server.validAccount(ctx, req.FromAccountID, req.Currency) if !valid { return } + authPayload := ctx.MustGet(authorizationPayloadKey).(*token.Payload) + if fromAccount.Owner != authPayload.Username { + err := errors.New("from account doesn't belong to the authenticated user") + ctx.JSON(http.StatusUnauthorized, errorResponse(err)) + return + } + _, valid = server.validAccount(ctx, req.ToAccountID, req.Currency) if !valid { return diff --git a/db/query/account.sql b/db/query/account.sql index 4536302..81bfeb9 100644 --- a/db/query/account.sql +++ b/db/query/account.sql @@ -18,9 +18,10 @@ FOR NO KEY UPDATE; -- name: ListAccounts :many SELECT * FROM accounts +WHERE owner = $1 ORDER BY id -LIMIT $1 -OFFSET $2; +LIMIT $2 +OFFSET $3; -- name: UpdateAccount :one UPDATE accounts diff --git a/db/sqlc/account.sql.go b/db/sqlc/account.sql.go index 2b0f16e..9f3e838 100644 --- a/db/sqlc/account.sql.go +++ b/db/sqlc/account.sql.go @@ -112,18 +112,20 @@ func (q *Queries) GetAccountForUpdate(ctx context.Context, id int64) (Account, e const listAccounts = `-- name: ListAccounts :many SELECT id, owner, balance, currency, created_at FROM accounts +WHERE owner = $1 ORDER BY id -LIMIT $1 -OFFSET $2 +LIMIT $2 +OFFSET $3 ` type ListAccountsParams struct { - Limit int32 `json:"limit"` - Offset int32 `json:"offset"` + Owner string `json:"owner"` + Limit int32 `json:"limit"` + Offset int32 `json:"offset"` } func (q *Queries) ListAccounts(ctx context.Context, arg ListAccountsParams) ([]Account, error) { - rows, err := q.db.Query(ctx, listAccounts, arg.Limit, arg.Offset) + rows, err := q.db.Query(ctx, listAccounts, arg.Owner, arg.Limit, arg.Offset) if err != nil { return nil, err } diff --git a/db/sqlc/account_test.go b/db/sqlc/account_test.go index 7e22b70..5994818 100644 --- a/db/sqlc/account_test.go +++ b/db/sqlc/account_test.go @@ -21,9 +21,11 @@ func createRandomAccount(t *testing.T) Account { account, err := testStore.CreateAccount(context.Background(), arg) require.NoError(t, err) require.NotEmpty(t, account) + require.Equal(t, arg.Owner, account.Owner) require.Equal(t, arg.Balance, account.Balance) require.Equal(t, arg.Currency, account.Currency) + require.NotZero(t, account.ID) require.NotZero(t, account.CreatedAt) return account @@ -77,11 +79,13 @@ func TestDeleteAccount(t *testing.T) { } func TestListAccounts(t *testing.T) { + var lastAccount Account for i := 0; i < 10; i++ { - createRandomAccount(t) + lastAccount = createRandomAccount(t) } arg := ListAccountsParams{ + Owner: lastAccount.Owner, Limit: 5, Offset: 0, } @@ -92,5 +96,6 @@ func TestListAccounts(t *testing.T) { for _, account := range accounts { require.NotEmpty(t, account) + require.Equal(t, lastAccount.Owner, account.Owner) } }