Skip to content

Commit

Permalink
feat(auth): add auth middleware to API
Browse files Browse the repository at this point in the history
  • Loading branch information
eizyc committed Jun 24, 2024
1 parent 8465432 commit 5b89f47
Show file tree
Hide file tree
Showing 9 changed files with 280 additions and 33 deletions.
15 changes: 13 additions & 2 deletions api/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,18 @@ package api

import (
"database/sql"
"errors"
"log"
"net/http"

db "github.com/eizyc/simplebank/db/sqlc"
"github.com/eizyc/simplebank/token"
"github.com/jackc/pgx/v5/pgconn"

"github.com/gin-gonic/gin"
)

type createAccountRequest struct {
Owner string `json:"owner" binding:"required"`
Currency string `json:"currency" binding:"required,currency"`
}

Expand All @@ -23,8 +24,9 @@ func (server *Server) createAccount(ctx *gin.Context) {
return
}

authPayload := ctx.MustGet(authorizationPayloadKey).(*token.Payload)
arg := db.CreateAccountParams{
Owner: req.Owner,
Owner: authPayload.Username,
Currency: req.Currency,
Balance: 0,
}
Expand Down Expand Up @@ -69,6 +71,13 @@ func (server *Server) getAccount(ctx *gin.Context) {
return
}

authPayload := ctx.MustGet(authorizationPayloadKey).(*token.Payload)
if account.Owner != authPayload.Username {
err := errors.New("account doesn't belong to the authenticated user")
ctx.JSON(http.StatusUnauthorized, errorResponse(err))
return
}

ctx.JSON(http.StatusOK, account)
}

Expand All @@ -84,7 +93,9 @@ func (server *Server) listAccounts(ctx *gin.Context) {
return
}

authPayload := ctx.MustGet(authorizationPayloadKey).(*token.Payload)
arg := db.ListAccountsParams{
Owner: authPayload.Username,
Limit: req.PageSize,
Offset: (req.PageID - 1) * req.PageSize,
}
Expand Down
93 changes: 75 additions & 18 deletions api/account_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,34 @@ import (
"net/http"
"net/http/httptest"
"testing"
"time"

mockdb "github.com/eizyc/simplebank/db/mock"
db "github.com/eizyc/simplebank/db/sqlc"
"github.com/eizyc/simplebank/token"
"github.com/eizyc/simplebank/util"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
)

func TestGetAccountAPI(t *testing.T) {
account := randomAccount()
user, _ := randomUser(t)
account := randomAccount(user.Username)

testCases := []struct {
name string
accountID int64
setupAuth func(t *testing.T, request *http.Request, tokenMaker token.Maker)
buildStubs func(store *mockdb.MockStore)
checkResponse func(t *testing.T, recoder *httptest.ResponseRecorder)
}{
{
name: "OK",
accountID: account.ID,
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, time.Minute)
},
buildStubs: func(store *mockdb.MockStore) {
store.EXPECT().
GetAccount(gomock.Any(), gomock.Eq(account.ID)).
Expand All @@ -41,9 +48,42 @@ func TestGetAccountAPI(t *testing.T) {
requireBodyMatchAccount(t, recorder.Body, account)
},
},
{
name: "UnauthorizedUser",
accountID: account.ID,
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, "unauthorized_user", time.Minute)
},
buildStubs: func(store *mockdb.MockStore) {
store.EXPECT().
GetAccount(gomock.Any(), gomock.Eq(account.ID)).
Times(1).
Return(account, nil)
},
checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
require.Equal(t, http.StatusUnauthorized, recorder.Code)
},
},
{
name: "NoAuthorization",
accountID: account.ID,
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
},
buildStubs: func(store *mockdb.MockStore) {
store.EXPECT().
GetAccount(gomock.Any(), gomock.Any()).
Times(0)
},
checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
require.Equal(t, http.StatusUnauthorized, recorder.Code)
},
},
{
name: "NotFound",
accountID: account.ID,
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, time.Minute)
},
buildStubs: func(store *mockdb.MockStore) {
store.EXPECT().
GetAccount(gomock.Any(), gomock.Eq(account.ID)).
Expand All @@ -57,6 +97,9 @@ func TestGetAccountAPI(t *testing.T) {
{
name: "InternalError",
accountID: account.ID,
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, time.Minute)
},
buildStubs: func(store *mockdb.MockStore) {
store.EXPECT().
GetAccount(gomock.Any(), gomock.Eq(account.ID)).
Expand All @@ -70,6 +113,9 @@ func TestGetAccountAPI(t *testing.T) {
{
name: "InvalidID",
accountID: 0,
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, time.Minute)
},
buildStubs: func(store *mockdb.MockStore) {
store.EXPECT().
GetAccount(gomock.Any(), gomock.Any()).
Expand Down Expand Up @@ -98,27 +144,32 @@ func TestGetAccountAPI(t *testing.T) {
request, err := http.NewRequest(http.MethodGet, url, nil)
require.NoError(t, err)

tc.setupAuth(t, request, server.tokenMaker)
server.router.ServeHTTP(recorder, request)
tc.checkResponse(t, recorder)
})
}
}

func TestCreateAccountAPI(t *testing.T) {
account := randomAccount()
user, _ := randomUser(t)
account := randomAccount(user.Username)

testCases := []struct {
name string
body gin.H
setupAuth func(t *testing.T, request *http.Request, tokenMaker token.Maker)
buildStubs func(store *mockdb.MockStore)
checkResponse func(recoder *httptest.ResponseRecorder)
}{
{
name: "OK",
body: gin.H{
"owner": account.Owner,
"currency": account.Currency,
},
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, time.Minute)
},
buildStubs: func(store *mockdb.MockStore) {
arg := db.CreateAccountParams{
Owner: account.Owner,
Expand All @@ -137,41 +188,46 @@ func TestCreateAccountAPI(t *testing.T) {
},
},
{
name: "InternalError",
name: "NoAuthorization",
body: gin.H{
"owner": account.Owner,
"currency": account.Currency,
},
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
},
buildStubs: func(store *mockdb.MockStore) {
store.EXPECT().
CreateAccount(gomock.Any(), gomock.Any()).
Times(1).
Return(db.Account{}, sql.ErrConnDone)
Times(0)
},
checkResponse: func(recorder *httptest.ResponseRecorder) {
require.Equal(t, http.StatusInternalServerError, recorder.Code)
require.Equal(t, http.StatusUnauthorized, recorder.Code)
},
},
{
name: "InvalidCurrency",
name: "InternalError",
body: gin.H{
"owner": account.Owner,
"currency": "invalid",
"currency": account.Currency,
},
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, time.Minute)
},
buildStubs: func(store *mockdb.MockStore) {
store.EXPECT().
CreateAccount(gomock.Any(), gomock.Any()).
Times(0)
Times(1).
Return(db.Account{}, sql.ErrConnDone)
},
checkResponse: func(recorder *httptest.ResponseRecorder) {
require.Equal(t, http.StatusBadRequest, recorder.Code)
require.Equal(t, http.StatusInternalServerError, recorder.Code)
},
},
{
name: "InvalidOwner",
name: "InvalidCurrency",
body: gin.H{
"owner": "",
"currency": account.Currency,
"currency": "invalid",
},
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, user.Username, time.Minute)
},
buildStubs: func(store *mockdb.MockStore) {
store.EXPECT().
Expand Down Expand Up @@ -204,16 +260,17 @@ func TestCreateAccountAPI(t *testing.T) {
request, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(data))
require.NoError(t, err)

tc.setupAuth(t, request, server.tokenMaker)
server.router.ServeHTTP(recorder, request)
tc.checkResponse(recorder)
})
}
}

func randomAccount() db.Account {
func randomAccount(owner string) db.Account {
return db.Account{
ID: util.RandomInt(1, 1000),
Owner: util.RandomOwner(),
Owner: owner,
Balance: util.RandomMoney(),
Currency: util.RandomCurrency(),
}
Expand Down
54 changes: 54 additions & 0 deletions api/middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package api

import (
"errors"
"fmt"
"net/http"
"strings"

"github.com/eizyc/simplebank/token"
"github.com/gin-gonic/gin"
)

const (
authorizationHeaderKey = "authorization"
authorizationTypeBearer = "bearer"
authorizationPayloadKey = "authorization_payload"
)

// AuthMiddleware creates a gin middleware for authorization
func authMiddleware(tokenMaker token.Maker) gin.HandlerFunc {
return func(ctx *gin.Context) {
authorizationHeader := ctx.GetHeader(authorizationHeaderKey)

if len(authorizationHeader) == 0 {
err := errors.New("authorization header is not provided")
ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err))
return
}

fields := strings.Fields(authorizationHeader)
if len(fields) < 2 {
err := errors.New("invalid authorization header format")
ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err))
return
}

authorizationType := strings.ToLower(fields[0])
if authorizationType != authorizationTypeBearer {
err := fmt.Errorf("unsupported authorization type %s", authorizationType)
ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err))
return
}

accessToken := fields[1]
payload, err := tokenMaker.VerifyToken(accessToken)
if err != nil {
ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err))
return
}

ctx.Set(authorizationPayloadKey, payload)
ctx.Next()
}
}
Loading

0 comments on commit 5b89f47

Please sign in to comment.