Skip to content

Commit

Permalink
improve route handlers
Browse files Browse the repository at this point in the history
  • Loading branch information
Ajnasz committed Mar 8, 2024
1 parent c06f2c1 commit 68a6c63
Show file tree
Hide file tree
Showing 9 changed files with 129 additions and 76 deletions.
4 changes: 3 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ clean:
test:
go test -v ./...

.PHONY: curl
.PHONY: curl curl-bad
curl:
curl -v --data-binary @go.mod localhost:8080/api/ | xargs -I {} curl localhost:8080{}
curl-bad:
curl -v localhost:8080/api/57c04c70-dd58-11ee-98fc-ebbaf68907f4/8bc419a2de0ccf0b165cd978f8894b77403a2f06019916af0bf48bcade88f518
74 changes: 62 additions & 12 deletions api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@ package api

import (
"database/sql"
"fmt"
"log"
"net/http"
"net/url"
"path"
"strings"

"github.com/Ajnasz/sekret.link/api/middlewares"
"github.com/Ajnasz/sekret.link/internal/api"
"github.com/Ajnasz/sekret.link/internal/models"
"github.com/Ajnasz/sekret.link/internal/parsers"
Expand Down Expand Up @@ -94,17 +98,63 @@ func (s SecretHandler) Options(w http.ResponseWriter, r *http.Request) {
// Your OPTIONS method logic goes here
w.WriteHeader(http.StatusOK)
}
func (s SecretHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodPost:
s.Post(w, r)
case http.MethodGet:
s.Get(w, r)
case http.MethodDelete:
s.Delete(w, r)
case http.MethodOptions:
s.Options(w, r)
default:
http.Error(w, "Not found", http.StatusNotFound)

// NotFound handler
func (s SecretHandler) NotFound(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Not found", http.StatusNotFound)
}

func clearApiRoot(apiRoot string) string {
apiRoot = path.Clean(path.Join("/", apiRoot))

if strings.HasSuffix(apiRoot, "/") {
return apiRoot
}

return apiRoot + "/"
}

func (s SecretHandler) RegisterHandlers(mux *http.ServeMux, apiRoot string) {
mux.Handle(
fmt.Sprintf("GET %s", path.Join("/", apiRoot, "{uuid}", "{key}")),
http.StripPrefix(
apiRoot,
middlewares.SetupLogging(
middlewares.SetupHeaders(http.HandlerFunc(s.Get)),
),
),
)

mux.Handle(
fmt.Sprintf("POST %s", clearApiRoot(apiRoot)),
http.StripPrefix(
path.Join("/", apiRoot),
middlewares.SetupLogging(
middlewares.SetupHeaders(http.HandlerFunc(s.Post)),
),
),
)

mux.Handle(
fmt.Sprintf("DELETE %s", path.Join("/", apiRoot, "{uuid}", "{key}", "{deleteKey}")),
http.StripPrefix(
apiRoot,
middlewares.SetupLogging(
middlewares.SetupHeaders(http.HandlerFunc(s.Delete)),
),
),
)

mux.Handle(
fmt.Sprintf("OPTIONS %s", clearApiRoot(apiRoot)),
http.StripPrefix(
apiRoot,
middlewares.SetupLogging(
middlewares.SetupHeaders(http.HandlerFunc(s.Options)),
),
),
)

mux.Handle("/", middlewares.SetupLogging(middlewares.SetupHeaders(http.HandlerFunc(s.NotFound))))

}
44 changes: 34 additions & 10 deletions api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ func TestCreateEntry(t *testing.T) {

resp := w.Result()
body, err := io.ReadAll(resp.Body)
fmt.Println("BODYDYDYDDY", body, err)

if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -370,9 +369,18 @@ func TestGetEntry(t *testing.T) {

req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("http://example.com/%s/%s", testCase.UUID, hex.EncodeToString(rsakey)), nil)
w := httptest.NewRecorder()
NewSecretHandler(NewHandlerConfig(connection, connection.GetDB())).ServeHTTP(w, req)

mux := http.NewServeMux()
secretHandler := NewSecretHandler(NewHandlerConfig(connection, connection.GetDB()))
secretHandler.RegisterHandlers(mux, "")

mux.ServeHTTP(w, req)

resp := w.Result()

if resp.StatusCode != http.StatusOK {
t.Fatalf("expected statuscode %d got %d", http.StatusOK, resp.StatusCode)
}
body, _ := io.ReadAll(resp.Body)

actual := string(body)
Expand Down Expand Up @@ -420,7 +428,11 @@ func TestGetEntryJSON(t *testing.T) {
req := httptest.NewRequest("GET", fmt.Sprintf("http://example.com/%s/%s", testCase.UUID, hex.EncodeToString(rsakey)), nil)
req.Header.Add("Accept", "application/json")
w := httptest.NewRecorder()
NewSecretHandler(NewHandlerConfig(connection, connection.GetDB())).ServeHTTP(w, req)

mux := http.NewServeMux()
secretHandler := NewSecretHandler(NewHandlerConfig(connection, connection.GetDB()))
secretHandler.RegisterHandlers(mux, "")
mux.ServeHTTP(w, req)

resp := w.Result()
if resp.StatusCode != 200 {
Expand Down Expand Up @@ -457,12 +469,19 @@ func TestSetAndGetEntry(t *testing.T) {
connection.Close()
})

req := httptest.NewRequest("POST", "http://example.com", bytes.NewReader([]byte(testCase)))
req := httptest.NewRequest("POST", "http://example.com/", bytes.NewReader([]byte(testCase)))
w := httptest.NewRecorder()

NewSecretHandler(NewHandlerConfig(connection, connection.GetDB())).ServeHTTP(w, req)
mux := http.NewServeMux()
secretHandler := NewSecretHandler(NewHandlerConfig(connection, connection.GetDB()))
secretHandler.RegisterHandlers(mux, "")
mux.ServeHTTP(w, req)

resp := w.Result()

if resp.StatusCode != 200 {
t.Fatalf("expected statuscode %d got %d", 200, resp.StatusCode)
}
body, _ := io.ReadAll(resp.Body)

responseURL := string(body)
Expand All @@ -474,7 +493,8 @@ func TestSetAndGetEntry(t *testing.T) {

req = httptest.NewRequest("GET", fmt.Sprintf("http://example.com/%s/%s", savedUUID, keyString), nil)
w = httptest.NewRecorder()
NewSecretHandler(NewHandlerConfig(connection, connection.GetDB())).ServeHTTP(w, req)

mux.ServeHTTP(w, req)

resp = w.Result()
body, _ = io.ReadAll(resp.Body)
Expand Down Expand Up @@ -694,10 +714,14 @@ func FuzzSetAndGetEntry(f *testing.F) {
t.Log("empty")
return
}
req := httptest.NewRequest("POST", "http://example.com", bytes.NewReader([]byte(testCase)))
w := httptest.NewRecorder()

NewSecretHandler(NewHandlerConfig(connection, connection.GetDB())).ServeHTTP(w, req)
mux := http.NewServeMux()
secretHandler := NewSecretHandler(NewHandlerConfig(connection, connection.GetDB()))
secretHandler.RegisterHandlers(mux, "")

req := httptest.NewRequest("POST", "http://example.com/", bytes.NewReader([]byte(testCase)))
w := httptest.NewRecorder()
mux.ServeHTTP(w, req)

resp := w.Result()
body, _ := io.ReadAll(resp.Body)
Expand All @@ -711,7 +735,7 @@ func FuzzSetAndGetEntry(f *testing.F) {

req = httptest.NewRequest("GET", fmt.Sprintf("http://example.com/%s/%s", savedUUID, keyString), nil)
w = httptest.NewRecorder()
NewSecretHandler(NewHandlerConfig(connection, connection.GetDB())).ServeHTTP(w, req)
mux.ServeHTTP(w, req)

resp = w.Result()
body, _ = io.ReadAll(resp.Body)
Expand Down
6 changes: 5 additions & 1 deletion api/middlewares/middlewares.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@ import (

func SetupLogging(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Println(fmt.Sprintf("%s: %s", r.Method, r.URL.Path))
if r.Method == http.MethodGet {
log.Println(fmt.Sprintf("%s: %s", r.Method, "/***"))
} else {
log.Println(fmt.Sprintf("%s: %s", r.Method, r.URL.Path))
}
h.ServeHTTP(w, r)
})
}
Expand Down
41 changes: 2 additions & 39 deletions cmd/sekret.link/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import (
"time"

"github.com/Ajnasz/sekret.link/api"
"github.com/Ajnasz/sekret.link/api/middlewares"
"github.com/Ajnasz/sekret.link/config"
"github.com/Ajnasz/sekret.link/storage"
"github.com/Ajnasz/sekret.link/storage/postgresql"
Expand Down Expand Up @@ -64,47 +63,11 @@ func scheduleDeleteExpired(ctx context.Context, entryStorage storage.Writer) {

func listen(handlerConfig api.HandlerConfig) *http.Server {
mux := http.NewServeMux()
secretHandler := api.NewSecretHandler(handlerConfig)

apiRoot := getAPIRoot(handlerConfig.WebExternalURL)
mux.Handle(
fmt.Sprintf("GET %s", apiRoot),
http.StripPrefix(
apiRoot,
middlewares.SetupLogging(
middlewares.SetupHeaders(http.HandlerFunc(secretHandler.Get)),
),
),
)
mux.Handle(
fmt.Sprintf("POST %s", apiRoot),
http.StripPrefix(
apiRoot,
middlewares.SetupLogging(
middlewares.SetupHeaders(http.HandlerFunc(secretHandler.Post)),
),
),
)

mux.Handle(
fmt.Sprintf("DELETE %s", apiRoot),
http.StripPrefix(
apiRoot,
middlewares.SetupLogging(
middlewares.SetupHeaders(http.HandlerFunc(secretHandler.Delete)),
),
),
)

mux.Handle(
fmt.Sprintf("OPTIONS %s", apiRoot),
http.StripPrefix(
apiRoot,
middlewares.SetupLogging(
middlewares.SetupHeaders(http.HandlerFunc(secretHandler.Options)),
),
),
)
secretHandler := api.NewSecretHandler(handlerConfig)
secretHandler.RegisterHandlers(mux, apiRoot)

httpServer := &http.Server{
Addr: ":8080",
Expand Down
2 changes: 1 addition & 1 deletion internal/api/getentry.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func (g GetHandler) handle(w http.ResponseWriter, r *http.Request) error {
request, err := g.parser.Parse(r)

if err != nil {
return errors.Join(ErrRequestParseError, err)
return err
}

ctx, cancel := context.WithCancel(context.Background())
Expand Down
3 changes: 3 additions & 0 deletions internal/parsers/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,6 @@ var ErrInvalidURL = errors.New("invalid URL")

// ErrInvalidUUID is returned when the UUID is invalid
var ErrInvalidUUID = errors.New("invalid UUID")

// ErrInvalidKey is returned when the key is invalid
var ErrInvalidKey = errors.New("invalid key")
17 changes: 9 additions & 8 deletions internal/parsers/getentry.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ package parsers
import (
"encoding/hex"
"errors"
"fmt"
"net/http"
"path"

"github.com/google/uuid"
)
Expand All @@ -21,14 +21,15 @@ type GetEntryRequestData struct {
Key []byte
}

func (g GetEntryParser) Parse(u *http.Request) (GetEntryRequestData, error) {
urlPath := u.URL.Path
pathDir, keyString := path.Split(urlPath)
func (g GetEntryParser) Parse(req *http.Request) (GetEntryRequestData, error) {
var reqData GetEntryRequestData
if len(pathDir) < 1 {
return reqData, ErrInvalidURL
keyString := req.PathValue("key")
if keyString == "" {
fmt.Println("EMPTY KEY", req.URL.Path)
return reqData, ErrInvalidKey
}
_, uuidFromPath := path.Split(pathDir[0 : len(pathDir)-1])

uuidFromPath := req.PathValue("uuid")
UUID, err := uuid.Parse(uuidFromPath)

if err != nil {
Expand All @@ -37,7 +38,7 @@ func (g GetEntryParser) Parse(u *http.Request) (GetEntryRequestData, error) {
key, err := hex.DecodeString(keyString)

if err != nil {
return reqData, err
return reqData, errors.Join(ErrInvalidKey, err)
}
return GetEntryRequestData{UUID: UUID.String(), Key: key, KeyString: keyString}, nil
}
14 changes: 10 additions & 4 deletions internal/views/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"crypto/aes"
"encoding/hex"
"errors"
"log"
"net/http"
"strings"

Expand All @@ -18,7 +17,6 @@ import (
var ErrCreateKey = errors.New("create key failed")

func (e EntryView) RenderCreateEntryErrorResponse(w http.ResponseWriter, r *http.Request, err error) {
log.Println("CREATE ENTRY ERROR", err)
if errors.Is(err, parsers.ErrInvalidExpirationDate) {
http.Error(w, "Invalid expiration", http.StatusBadRequest)
return
Expand All @@ -35,7 +33,6 @@ func (e EntryView) RenderCreateEntryErrorResponse(w http.ResponseWriter, r *http
}

func (e EntryView) RenderReadEntryError(w http.ResponseWriter, r *http.Request, err error) {
log.Println(err)
if errors.Is(err, services.ErrEntryExpired) {
http.Error(w, "Gone", http.StatusGone)
return
Expand All @@ -46,6 +43,16 @@ func (e EntryView) RenderReadEntryError(w http.ResponseWriter, r *http.Request,
return
}

if errors.Is(err, parsers.ErrInvalidUUID) {
http.Error(w, "Bad request", http.StatusBadRequest)
return
}

if errors.Is(err, parsers.ErrInvalidKey) {
http.Error(w, "Bad request", http.StatusBadRequest)
return
}

if errors.Is(err, hex.ErrLength) {
http.Error(w, "Bad request", http.StatusBadRequest)
return
Expand All @@ -61,7 +68,6 @@ func (e EntryView) RenderReadEntryError(w http.ResponseWriter, r *http.Request,
}

func (e EntryView) RenderDeleteEntryError(w http.ResponseWriter, r *http.Request, err error) {
log.Println(err)

if errors.Is(err, entries.ErrEntryNotFound) || errors.Is(err, models.ErrEntryNotFound) {
http.Error(w, "Not Found", http.StatusNotFound)
Expand Down

0 comments on commit 68a6c63

Please sign in to comment.