Skip to content

Commit

Permalink
create database schema on connection
Browse files Browse the repository at this point in the history
  • Loading branch information
Ajnasz committed Mar 9, 2024
1 parent d1142ac commit e22047a
Show file tree
Hide file tree
Showing 6 changed files with 124 additions and 21 deletions.
2 changes: 1 addition & 1 deletion api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ import (
"testing"
"time"

"github.com/Ajnasz/sekret.link/internal/durable"
"github.com/Ajnasz/sekret.link/internal/models"
"github.com/Ajnasz/sekret.link/internal/services"
"github.com/Ajnasz/sekret.link/internal/test/durable"
"github.com/Ajnasz/sekret.link/internal/uuid"
)

Expand Down
10 changes: 6 additions & 4 deletions cmd/sekret.link/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ func getAPIRoot(webExternalURL *url.URL) string {
return apiRoot
}

func getConfig() (*api.HandlerConfig, error) {
func getConfig(ctx context.Context) (*api.HandlerConfig, error) {
var (
externalURLParam string
expireSeconds int
Expand Down Expand Up @@ -153,20 +153,22 @@ func getConfig() (*api.HandlerConfig, error) {
if err != nil {
return nil, err
}
if err := models.PrepareDatabase(ctx, db); err != nil {
return nil, err
}
handlerConfig.DB = db

return &handlerConfig, nil
}

func main() {
handlerConfig, err := getConfig()
ctx, cancel := context.WithCancel(context.Background())
handlerConfig, err := getConfig(ctx)

if err != nil {
fmt.Fprintf(os.Stderr, "error: %s", err)
os.Exit(1)
}

ctx, cancel := context.WithCancel(context.Background())
go scheduleDeleteExpired(ctx, handlerConfig.DB)
httpServer := listen(*handlerConfig)

Expand Down
13 changes: 0 additions & 13 deletions internal/durable/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@ import (
"context"
"database/sql"
"fmt"
"os"

"github.com/Ajnasz/sekret.link/internal/config"
)

type ConnectionInfo struct {
Expand All @@ -22,16 +19,6 @@ func (c ConnectionInfo) String() string {
return fmt.Sprintf("postgres://%s:%s@%s:%d/%s?sslmode=%s", c.Username, c.Password, c.Host, c.Port, c.Database, c.SslMode)
}

// getPSQLTestConn returns connection string for tests
func getPSQLTestConn() string {
password := os.Getenv("POSTGRES_PASSWORD")

if password == "" {
password = "sekret_link_test"
}
return config.GetConnectionString(fmt.Sprintf("postgres://%s:%s@%s:%d/%s?sslmode=disable", "postgres", "password", "localhost", 5432, password))
}

func OpenDatabaseClient(ctx context.Context, connStr string) (*sql.DB, error) {
db, err := sql.Open("postgres", connStr)

Expand Down
2 changes: 1 addition & 1 deletion internal/models/entry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"testing"
"time"

"github.com/Ajnasz/sekret.link/internal/durable"
"github.com/Ajnasz/sekret.link/internal/test/durable"
"github.com/google/uuid"
)

Expand Down
112 changes: 112 additions & 0 deletions internal/models/prepera.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
package models

import (
"context"
"database/sql"

"github.com/Ajnasz/sekret.link/internal/key"
)

type dbExec func(context.Context, *sql.DB) error

func createTable(ctx context.Context, db *sql.DB) error {
q, err := db.PrepareContext(ctx, `CREATE TABLE IF NOT EXISTS
entries (
uuid uuid PRIMARY KEY,
data BYTEA,
remaining_reads SMALLINT DEFAULT 1,
delete_key CHAR(256) NOT NULL,
created TIMESTAMPTZ,
accessed TIMESTAMPTZ,
expire TIMESTAMPTZ
);`)

if err != nil {
return err
}
_, err = q.Exec()

return err
}

func addRemainingRead(ctx context.Context, db *sql.DB) error {
alterTable, err := db.PrepareContext(ctx, "ALTER TABLE entries ADD COLUMN IF NOT EXISTS remaining_reads SMALLINT DEFAULT 1;")

if err != nil {
return err
}

_, err = alterTable.Exec()

return err
}

func addDeleteKey(ctx context.Context, db *sql.DB) error {
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return err
}
alterTable, err := db.PrepareContext(ctx, "ALTER TABLE entries ADD COLUMN IF NOT EXISTS delete_key CHAR(256);")

if err != nil {
tx.Rollback()
return err
}

_, err = alterTable.ExecContext(ctx)

if err != nil {
tx.Rollback()
return err
}

rows, err := db.QueryContext(ctx, "SELECT uuid FROM entries WHERE delete_key IS NULL;")
if err != nil {
tx.Rollback()
return err
}

for rows.Next() {
var UUID string
if err := rows.Scan(&UUID); err != nil {
tx.Rollback()
return err
}

k, err := key.NewGeneratedKey()
if err != nil {
tx.Rollback()
return err
}

deleteKey := k.ToHex()

_, err = db.ExecContext(ctx, "UPDATE entries SET delete_key=$2 WHERE uuid=$1", UUID, deleteKey)
if err != nil {
tx.Rollback()
return err
}
}
_, err = db.ExecContext(ctx, "ALTER TABLE entries ALTER COLUMN delete_key SET NOT NULL;")
if err != nil {
tx.Rollback()
return err
}
return tx.Commit()
}

func PrepareDatabase(ctx context.Context, db *sql.DB) error {
actions := []dbExec{
createTable,
addRemainingRead,
addDeleteKey,
}

for _, action := range actions {
if err := action(ctx, db); err != nil {
return err
}
}

return nil
}
6 changes: 4 additions & 2 deletions internal/durable/test.go → internal/test/durable/durable.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,18 @@ package durable
import (
"context"
"database/sql"

"github.com/Ajnasz/sekret.link/internal/durable"
)

func TestConnection(ctx context.Context) (*sql.DB, error) {
config := ConnectionInfo{
config := durable.ConnectionInfo{
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "password",
Database: "sekret_link_test",
SslMode: "disable",
}
return OpenDatabaseClient(ctx, config.String())
return durable.OpenDatabaseClient(ctx, config.String())
}

0 comments on commit e22047a

Please sign in to comment.