Skip to content

Commit

Permalink
feat: add content type to entry creation and reading
Browse files Browse the repository at this point in the history
The commit introduces the ability to specify a content type when
creating an entry.
The content type is stored in the database and returned in the content
type header if the client doesn't request application/json content
explicitly, but in that case the returned value is already different
  • Loading branch information
Ajnasz committed May 21, 2024
1 parent 59f71d4 commit 4b56708
Show file tree
Hide file tree
Showing 14 changed files with 103 additions and 66 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ BUILD_ARGS=-trimpath -ldflags '-w -s'
all: clean linux

run:
@cd cmd/sekret.link && POSTGRES_URL="postgres://postgres:password@localhost:5432/sekret_link_test?sslmode=disable" go run . -webExternalURL=/api -base62
@cd cmd/sekret.link && POSTGRES_URL="postgres://postgres:password@localhost:5432/sekret_link_test?sslmode=disable" go run . -webExternalURL=/api

.PHONY: build
build: build/${BINARY_NAME}.linux.amd64
Expand Down
4 changes: 2 additions & 2 deletions api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ func TestGetEntry(t *testing.T) {

keyManager := services.NewEntryKeyManager(db, &models.EntryKeyModel{}, hasher.NewSHA256Hasher(), encrypter)
entryManager := services.NewEntryManager(db, &models.EntryModel{}, encrypter, keyManager)
meta, encKey, err := entryManager.CreateEntry(ctx, []byte(testCase.Value), 1, time.Second*10)
meta, encKey, err := entryManager.CreateEntry(ctx, "text/plain", []byte(testCase.Value), 1, time.Second*10)

if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -444,7 +444,7 @@ func TestGetEntryJSON(t *testing.T) {

keyManager := services.NewEntryKeyManager(db, &models.EntryKeyModel{}, hasher.NewSHA256Hasher(), encrypter)
entryManager := services.NewEntryManager(db, &models.EntryModel{}, encrypter, keyManager)
meta, encKey, err := entryManager.CreateEntry(ctx, []byte(testCase.Value), 1, time.Second*10)
meta, encKey, err := entryManager.CreateEntry(ctx, "text/plain", []byte(testCase.Value), 1, time.Second*10)
if err != nil {
t.Error(err)
}
Expand Down
6 changes: 4 additions & 2 deletions internal/api/createentry.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ type CreateEntryParser interface {

// CreateEntryManager is an interface for creating entries
type CreateEntryManager interface {
CreateEntry(ctx context.Context, body []byte, maxReads int, expiration time.Duration) (*services.EntryMeta, key.Key, error)
CreateEntry(ctx context.Context, contentType string, body []byte, maxReads int, expiration time.Duration) (*services.EntryMeta, key.Key, error)
}

// CreateEntryView is an interface for rendering the create entry response
Expand Down Expand Up @@ -60,9 +60,11 @@ func (c CreateHandler) handle(w http.ResponseWriter, r *http.Request) error {
return errors.Join(ErrRequestParseError, err)
}

contentType := r.Header.Get("Content-Type")

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
entry, key, err := c.entryManager.CreateEntry(ctx, data.Body, data.MaxReads, data.Expiration)
entry, key, err := c.entryManager.CreateEntry(ctx, contentType, data.Body, data.MaxReads, data.Expiration)

if err != nil {
return err
Expand Down
9 changes: 6 additions & 3 deletions internal/api/createentry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,12 @@ type MockEntryManager struct {

func (m *MockEntryManager) CreateEntry(
ctx context.Context,
contentType string,
body []byte,
maxReads int,
expiration time.Duration,
) (*services.EntryMeta, key.Key, error) {
args := m.Called(ctx, body, maxReads, expiration)
args := m.Called(ctx, contentType, body, maxReads, expiration)

if args.Get(1) == nil {
return args.Get(0).(*services.EntryMeta), nil, args.Error(2)
Expand Down Expand Up @@ -68,6 +69,7 @@ func Test_CreateEntryHandle(t *testing.T) {
view := new(MockEntryView)

request := httptest.NewRequest("POST", "http://example.com/foo", data)
request.Header.Set("Content-Type", "text/plain")
response := httptest.NewRecorder()

parser.On("Parse", request).Return(&parsers.CreateEntryRequestData{}, nil)
Expand All @@ -76,7 +78,7 @@ func Test_CreateEntryHandle(t *testing.T) {
if err != nil {
t.Fatal(err)
}
entryManager.On("CreateEntry", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&services.EntryMeta{}, *retKey, nil)
entryManager.On("CreateEntry", mock.Anything, "text/plain", mock.Anything, mock.Anything, mock.Anything).Return(&services.EntryMeta{}, *retKey, nil)
view.On("Render", mock.Anything, mock.Anything, mock.Anything).Return()

handler := NewCreateHandler(10, parser, entryManager, view)
Expand Down Expand Up @@ -121,12 +123,13 @@ func Test_CreateEntryHandleError(t *testing.T) {
view := new(MockEntryView)

request := httptest.NewRequest("POST", "http://example.com/foo", data)
request.Header.Set("Content-Type", "text/plain")
response := httptest.NewRecorder()

parser.On("Parse", request).Return(&parsers.CreateEntryRequestData{}, nil)
k, err := key.NewGeneratedKey()
assert.NoError(t, err)
entryManager.On("CreateEntry", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&services.EntryMeta{}, *k, errors.New("error"))
entryManager.On("CreateEntry", mock.Anything, "text/plain", mock.Anything, mock.Anything, mock.Anything).Return(&services.EntryMeta{}, *k, errors.New("error"))
view.On("RenderError", mock.Anything, mock.Anything, mock.Anything).Return()

handler := NewCreateHandler(10, parser, entryManager, view)
Expand Down
16 changes: 9 additions & 7 deletions internal/api/getentry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,15 @@ func TestGetHandle(t *testing.T) {

managerMock.On("ReadEntry", mock.Anything, "a6a9d8cc-db7f-11ee-8f4f-3b41146b31eb", *k).
Return(&services.Entry{
UUID: "a6a9d8cc-db7f-11ee-8f4f-3b41146b31eb",
Data: []byte{18, 18, 18, 18, 174, 173, 15},
RemainingReads: 0,
DeleteKey: "12121212aeadf",
Created: time.Now().Add(time.Minute * -1),
Accessed: time.Now(),
Expire: time.Now().Add(time.Minute * 1),
Data: []byte{18, 18, 18, 18, 174, 173, 15},
EntryMeta: services.EntryMeta{
UUID: "a6a9d8cc-db7f-11ee-8f4f-3b41146b31eb",
RemainingReads: 0,
DeleteKey: "12121212aeadf",
Created: time.Now().Add(time.Minute * -1),
Accessed: time.Now(),
Expire: time.Now().Add(time.Minute * 1),
},
}, nil)

request := httptest.NewRequest("GET", "http://example.com/foo/a6a9d8cc-db7f-11ee-8f4f-3b41146b31eb/12121212aeadf", nil)
Expand Down
13 changes: 7 additions & 6 deletions internal/models/entry.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ type EntryMeta struct {
Created time.Time
Accessed sql.NullTime
Expire time.Time
ContentType string
}

// uuid uuid PRIMARY KEY,
Expand All @@ -50,14 +51,14 @@ func (e *EntryModel) getDeleteKey() (string, error) {
}

// CreateEntry creates a new entry into the database
func (e *EntryModel) CreateEntry(ctx context.Context, tx *sql.Tx, uuid string, data []byte, remainingReads int, expire time.Duration) (*EntryMeta, error) {
func (e *EntryModel) CreateEntry(ctx context.Context, tx *sql.Tx, uuid string, contenType string, data []byte, remainingReads int, expire time.Duration) (*EntryMeta, error) {
deleteKey, err := e.getDeleteKey()
if err != nil {
return nil, errors.Join(err, ErrCreateEntry)
}

now := time.Now()
res, err := tx.ExecContext(ctx, `INSERT INTO entries (uuid, data, created, expire, remaining_reads, delete_key) VALUES ($1, $2, $3, $4, $5, $6) RETURNING uuid, delete_key;`, uuid, data, now, now.Add(expire), remainingReads, deleteKey)
res, err := tx.ExecContext(ctx, `INSERT INTO entries (uuid, data, created, expire, remaining_reads, delete_key, content_type) VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING uuid, delete_key;`, uuid, data, now, now.Add(expire), remainingReads, deleteKey, contenType)

if err != nil {
return nil, errors.Join(err, ErrCreateEntry)
Expand Down Expand Up @@ -89,9 +90,9 @@ func (e *EntryModel) Use(ctx context.Context, tx *sql.Tx, uuid string) error {
// ReadEntry reads a entry from the database
// and updates the read count
func (e *EntryModel) ReadEntry(ctx context.Context, tx *sql.Tx, uuid string) (*Entry, error) {
row := tx.QueryRow("SELECT uuid, data, remaining_reads, delete_key, created, accessed, expire FROM entries WHERE uuid=$1 AND remaining_reads > 0 LIMIT 1", uuid)
row := tx.QueryRow("SELECT uuid, data, remaining_reads, delete_key, created, accessed, expire, content_type FROM entries WHERE uuid=$1 AND remaining_reads > 0 LIMIT 1", uuid)
var s Entry
err := row.Scan(&s.UUID, &s.Data, &s.RemainingReads, &s.DeleteKey, &s.Created, &s.Accessed, &s.Expire)
err := row.Scan(&s.UUID, &s.Data, &s.RemainingReads, &s.DeleteKey, &s.Created, &s.Accessed, &s.Expire, &s.ContentType)
if err != nil {
if err == sql.ErrNoRows {
return nil, ErrEntryNotFound
Expand All @@ -103,9 +104,9 @@ func (e *EntryModel) ReadEntry(ctx context.Context, tx *sql.Tx, uuid string) (*E
}

func (e *EntryModel) ReadEntryMeta(ctx context.Context, tx *sql.Tx, uuid string) (*EntryMeta, error) {
row := tx.QueryRow("SELECT created, accessed, expire, remaining_reads, delete_key FROM entries WHERE uuid=$1 AND remaining_reads > 0 LIMIT 1", uuid)
row := tx.QueryRow("SELECT created, accessed, expire, remaining_reads, delete_key, content_type FROM entries WHERE uuid=$1 AND remaining_reads > 0 LIMIT 1", uuid)
var s EntryMeta
err := row.Scan(&s.Created, &s.Accessed, &s.Expire, &s.RemainingReads, &s.DeleteKey)
err := row.Scan(&s.Created, &s.Accessed, &s.Expire, &s.RemainingReads, &s.DeleteKey, &s.ContentType)
if err != nil {
if err == sql.ErrNoRows {
return nil, ErrEntryNotFound
Expand Down
4 changes: 2 additions & 2 deletions internal/models/entry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func Test_EntryModel_CreateEntry(t *testing.T) {

model := &EntryModel{}

meta, err := model.CreateEntry(ctx, tx, uid, data, remainingReads, expire)
meta, err := model.CreateEntry(ctx, tx, uid, "text/plain", data, remainingReads, expire)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -86,7 +86,7 @@ func Test_EntryModel_Use(t *testing.T) {

model := &EntryModel{}

meta, err := model.CreateEntry(ctx, tx, uid, data, remainingReads, expire)
meta, err := model.CreateEntry(ctx, tx, uid, "text/plain", data, remainingReads, expire)
if err != nil {
t.Fatal(err)
}
Expand Down
6 changes: 3 additions & 3 deletions internal/models/entrykey_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func createTestEntryKey(ctx context.Context, tx *sql.Tx) (string, string, error)

entryModel := &EntryModel{}

_, err := entryModel.CreateEntry(ctx, tx, uid, []byte("test data"), 2, 3600)
_, err := entryModel.CreateEntry(ctx, tx, uid, "text/plain", []byte("test data"), 2, 3600)

if err != nil {
return "", "", err
Expand Down Expand Up @@ -66,7 +66,7 @@ func Test_EntryKeyModel_Create(t *testing.T) {
uid := uuid.New().String()

entryModel := &EntryModel{}
_, err = entryModel.CreateEntry(ctx, tx, uid, []byte("test data"), 2, 3600)
_, err = entryModel.CreateEntry(ctx, tx, uid, "text/plain", []byte("test data"), 2, 3600)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -124,7 +124,7 @@ func Test_EntryKeyModel_Get(t *testing.T) {
uid := uuid.New().String()

entryModel := &EntryModel{}
_, err = entryModel.CreateEntry(ctx, tx, uid, []byte("test data"), 2, 3600)
_, err = entryModel.CreateEntry(ctx, tx, uid, "text/plain", []byte("test data"), 2, 3600)
if err != nil {
if err := tx.Rollback(); err != nil {
t.Error(err)
Expand Down
20 changes: 20 additions & 0 deletions internal/models/migrate/entry.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ func (e *EntryMigration) Alter(ctx context.Context, tx *sql.Tx) error {
return err
}

if err := e.addContentType(ctx, tx); err != nil {
return err
}

return nil
}

Expand Down Expand Up @@ -103,3 +107,19 @@ func (*EntryMigration) addDeleteKey(ctx context.Context, tx *sql.Tx) error {

return nil
}

func (e *EntryMigration) addContentType(ctx context.Context, tx *sql.Tx) error {
alterTable, err := tx.PrepareContext(ctx, "ALTER TABLE entries ADD COLUMN IF NOT EXISTS content_type VARCHAR(256) NOT NULL DEFAULT '';")

if err != nil {
return fmt.Errorf("failed to add delete_key column: %w", err)
}

_, err = alterTable.Exec()

if err != nil {
return fmt.Errorf("failed to add remaining_reads column: %w", err)
}

return nil
}
1 change: 1 addition & 0 deletions internal/models/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ func (m *MockEntryModel) CreateEntry(
ctx context.Context,
tx *sql.Tx,
UUID string,
contentType string,
data []byte,
remainingReads int,
expire time.Duration) (*EntryMeta, error) {
Expand Down
32 changes: 16 additions & 16 deletions internal/services/entrymanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,12 @@ type EntryMeta struct {
Created time.Time
Accessed time.Time
Expire time.Time
ContentType string
}

type Entry struct {
UUID string
Data []byte
RemainingReads int
DeleteKey string
Created time.Time
Accessed time.Time
Expire time.Time
EntryMeta
Data []byte
}

type EntryKeyData struct {
Expand Down Expand Up @@ -71,7 +67,7 @@ func NewEntryManager(db *sql.DB, model EntryModel, crypto EncrypterFactory, keyM
// It stores the encrypted data in the database
// It stores the key in the key manager
// It returns the meta data of the entry and the key
func (e *EntryManager) CreateEntry(ctx context.Context, data []byte, remainingReads int, expire time.Duration) (*EntryMeta, key.Key, error) {
func (e *EntryManager) CreateEntry(ctx context.Context, contentType string, data []byte, remainingReads int, expire time.Duration) (*EntryMeta, key.Key, error) {
uid := uuid.NewUUIDString()

tx, err := e.db.Begin()
Expand All @@ -96,7 +92,7 @@ func (e *EntryManager) CreateEntry(ctx context.Context, data []byte, remainingRe
}
return nil, nil, errors.Join(ErrCreateEntryFailed, err)
}
meta, err := e.model.CreateEntry(ctx, tx, uid, encryptedData, remainingReads, expire)
meta, err := e.model.CreateEntry(ctx, tx, uid, contentType, encryptedData, remainingReads, expire)
if err != nil {
if rollbackErr := tx.Rollback(); rollbackErr != nil {
return nil, nil, errors.Join(ErrCreateEntryFailed, err, rollbackErr)
Expand Down Expand Up @@ -125,6 +121,7 @@ func (e *EntryManager) CreateEntry(ctx context.Context, data []byte, remainingRe
Created: meta.Created,
Accessed: meta.Accessed.Time,
Expire: meta.Expire,
ContentType: meta.ContentType,
}, kek, nil
}

Expand Down Expand Up @@ -219,13 +216,16 @@ func (e *EntryManager) ReadEntry(ctx context.Context, UUID string, k key.Key) (*
}

return &Entry{
UUID: entry.UUID,
Data: decryptedData,
RemainingReads: entry.RemainingReads - 1,
DeleteKey: entry.DeleteKey,
Created: entry.Created,
Accessed: entry.Accessed.Time,
Expire: entry.Expire,
EntryMeta: EntryMeta{
UUID: entry.UUID,
RemainingReads: entry.RemainingReads - 1,
DeleteKey: entry.DeleteKey,
Created: entry.Created,
Accessed: entry.Accessed.Time,
Expire: entry.Expire,
ContentType: entry.ContentType,
},
Data: decryptedData,
}, nil
}

Expand Down
21 changes: 12 additions & 9 deletions internal/services/entrymanager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func Test_EntryService_Create(t *testing.T) {
keyManager.On("CreateWithTx", ctx, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&EntryKey{}, *kek, nil)

service := NewEntryManager(db, entryModel, crypto, keyManager)
meta, key, err := service.CreateEntry(ctx, data, 1, time.Minute)
meta, key, err := service.CreateEntry(ctx, "text/plain", data, 1, time.Minute)

assert.NoError(t, err)
assert.NotNil(t, meta)
Expand Down Expand Up @@ -115,7 +115,7 @@ func TestCreateError(t *testing.T) {
keyManager := new(MockEntryKeyer)

service := NewEntryManager(db, entryModel, crypto, keyManager)
meta, key, err := service.CreateEntry(ctx, data, 1, time.Minute)
meta, key, err := service.CreateEntry(ctx, "text/plain", data, 1, time.Minute)

assert.Error(t, err)
assert.Nil(t, meta)
Expand Down Expand Up @@ -189,13 +189,16 @@ func TestReadEntry(t *testing.T) {
assert.NoError(t, err)
assert.NotNil(t, data)
assert.Equal(t, Entry{
UUID: entry.UUID,
Data: []byte("data"),
RemainingReads: 0,
DeleteKey: entry.DeleteKey,
Created: entry.Created,
Accessed: entry.Accessed.Time,
Expire: entry.Expire,
Data: []byte("data"),
EntryMeta: EntryMeta{
UUID: entry.UUID,
RemainingReads: 0,
DeleteKey: entry.DeleteKey,
Created: entry.Created,
Accessed: entry.Accessed.Time,
Expire: entry.Expire,
ContentType: entry.ContentType,
},
}, *data)

entryModel.AssertExpectations(t)
Expand Down
2 changes: 1 addition & 1 deletion internal/services/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
// EntryModel is the interface for the entry model
// It is used to create, read and access entries
type EntryModel interface {
CreateEntry(ctx context.Context, tx *sql.Tx, UUID string, data []byte, remainingReads int, expire time.Duration) (*models.EntryMeta, error)
CreateEntry(ctx context.Context, tx *sql.Tx, UUID string, contentType string, data []byte, remainingReads int, expire time.Duration) (*models.EntryMeta, error)
ReadEntry(ctx context.Context, tx *sql.Tx, UUID string) (*models.Entry, error)
Use(ctx context.Context, tx *sql.Tx, UUID string) error
DeleteEntry(ctx context.Context, tx *sql.Tx, UUID string, deleteKey string) error
Expand Down
Loading

0 comments on commit 4b56708

Please sign in to comment.