Skip to content

Commit

Permalink
Adds worker api key type (#1177)
Browse files Browse the repository at this point in the history
  • Loading branch information
nickzelei authored Jan 24, 2024
1 parent 69c6aba commit 336ee68
Show file tree
Hide file tree
Showing 20 changed files with 403 additions and 90 deletions.
1 change: 1 addition & 0 deletions backend/charts/api/templates/api-env-vars.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,4 @@ stringData:
{{- end }}

NEOSYNC_CLOUD: {{ .Values.neosyncCloud.enabled | default "false" | quote }}
NEOSYNC_CLOUD_ALLOWED_WORKER_API_KEYS: {{ join "," .Values.neosyncCloud.workerApiKeys }}
1 change: 1 addition & 0 deletions backend/charts/api/values.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,4 @@ volumeMounts: []

neosyncCloud:
enabled: false
workerApiKeys: []
31 changes: 29 additions & 2 deletions backend/internal/apikey/apikey.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,33 +8,60 @@ import (
"github.com/google/uuid"
)

type ApiKeyType string

const (
AccountApiKey ApiKeyType = "account"
WorkerApiKey ApiKeyType = "worker"
)

const (
prefix = "neo"
accountTokenId = "at"
workerTokenId = "wt"
v1 = "v1"
separator = "_"

uuidPattern = `[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-4[0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}`
)

var (
v1Prefix = strings.Join([]string{prefix, accountTokenId, v1}, separator)
v1AtPrefix = strings.Join([]string{prefix, accountTokenId, v1}, separator)
v1WtPrefix = strings.Join([]string{prefix, workerTokenId, v1}, separator)

v1AccountTokenPattern = fmt.Sprintf(
`^(%s)%s(%s)%sv([\d+])%s%s$`,
prefix, separator, accountTokenId, separator, separator, uuidPattern,
)
v1AccountTokenRegex = regexp.MustCompile(v1AccountTokenPattern)

v1WorkerTokenPattern = fmt.Sprintf(
`^(%s)%s(%s)%sv([\d+])%s%s$`,
prefix, separator, workerTokenId, separator, separator, uuidPattern,
)
v1WorkerTokenRegex = regexp.MustCompile(v1WorkerTokenPattern)
)

func NewV1AccountKey() string {
return v1AccountKey(uuid.NewString())
}

func v1AccountKey(suffix string) string {
return v1Prefix + separator + suffix
return v1AtPrefix + separator + suffix
}

func IsValidV1AccountKey(apikey string) bool {
return v1AccountTokenRegex.MatchString(apikey)
}

func NewV1WorkerKey() string {
return v1WorkerKey(uuid.NewString())
}

func v1WorkerKey(suffix string) string {
return v1WtPrefix + separator + suffix
}

func IsValidV1WorkerKey(apiKey string) bool {
return v1WorkerTokenRegex.MatchString(apiKey)
}
29 changes: 29 additions & 0 deletions backend/internal/apikey/apikey_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,33 @@ func Test_IsValidV1AccountKey(t *testing.T) {
t,
IsValidV1AccountKey(NewV1AccountKey()),
)
assert.False(
t,
IsValidV1AccountKey(NewV1WorkerKey()),
"worker keys should not pass as valid account keys",
)
}

func Test_NewV1WrokerKey(t *testing.T) {
assert.NotEmpty(t, NewV1WorkerKey())
}

func Test_v1WorkerKey(t *testing.T) {
assert.Equal(
t,
v1WorkerKey("foo-bar"),
"neo_wt_v1_foo-bar",
)
}

func Test_IsValidV1WorkerKey(t *testing.T) {
assert.True(
t,
IsValidV1WorkerKey(NewV1WorkerKey()),
)
assert.False(
t,
IsValidV1WorkerKey(NewV1AccountKey()),
"account keys should not pass as valid worker keys",
)
}
97 changes: 73 additions & 24 deletions backend/internal/auth/apikey/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ package auth_apikey

import (
"context"
"crypto/subtle"
"errors"
"net/http"
"time"

"connectrpc.com/connect"
db_queries "github.com/nucleuscloud/neosync/backend/gen/go/db"
"github.com/nucleuscloud/neosync/backend/internal/apikey"
nucleuserrors "github.com/nucleuscloud/neosync/backend/internal/errors"
Expand All @@ -14,8 +16,9 @@ import (

type TokenContextKey struct{}
type TokenContextData struct {
RawToken string
ApiKey *db_queries.NeosyncApiAccountApiKey
RawToken string
ApiKey *db_queries.NeosyncApiAccountApiKey
ApiKeyType apikey.ApiKeyType
}

var (
Expand All @@ -28,43 +31,61 @@ type Queries interface {
}

type Client struct {
q Queries
db db_queries.DBTX
q Queries
db db_queries.DBTX
allowedWorkerApiKeys []string
allowedWorkerProcedures map[string]any
}

func New(
queries Queries,
db db_queries.DBTX,
allowedWorkerApiKeys []string,
allowedWorkerProcedures []string,
) *Client {
return &Client{q: queries, db: db}
allowedWorkerProcedureSet := map[string]any{}
for _, procedure := range allowedWorkerProcedures {
allowedWorkerProcedureSet[procedure] = struct{}{}
}
return &Client{q: queries, db: db, allowedWorkerApiKeys: allowedWorkerApiKeys, allowedWorkerProcedures: allowedWorkerProcedureSet}
}

func (c *Client) InjectTokenCtx(ctx context.Context, header http.Header) (context.Context, error) {
func (c *Client) InjectTokenCtx(ctx context.Context, header http.Header, spec connect.Spec) (context.Context, error) {
token, err := utils.GetBearerTokenFromHeader(header, "Authorization")
if err != nil {
return nil, err
}
if !apikey.IsValidV1AccountKey(token) {
return nil, InvalidApiKeyErr
}

hashedKeyValue := utils.ToSha256(
token,
)
apiKey, err := c.q.GetAccountApiKeyByKeyValue(ctx, c.db, hashedKeyValue)
if err != nil {
return nil, err
}
if apikey.IsValidV1AccountKey(token) {
hashedKeyValue := utils.ToSha256(
token,
)
apiKey, err := c.q.GetAccountApiKeyByKeyValue(ctx, c.db, hashedKeyValue)
if err != nil {
return nil, err
}

if time.Now().After(apiKey.ExpiresAt.Time) {
return nil, ApiKeyExpiredErr
}
if time.Now().After(apiKey.ExpiresAt.Time) {
return nil, ApiKeyExpiredErr
}

newctx := context.WithValue(ctx, TokenContextKey{}, &TokenContextData{
RawToken: token,
ApiKey: &apiKey,
})
return newctx, err
newctx := context.WithValue(ctx, TokenContextKey{}, &TokenContextData{
RawToken: token,
ApiKey: &apiKey,
ApiKeyType: apikey.AccountApiKey,
})
return newctx, nil
} else if apikey.IsValidV1WorkerKey(token) &&
isApiKeyAllowed(c.allowedWorkerApiKeys, token) &&
isProcedureAllowed(c.allowedWorkerProcedures, spec.Procedure) {
newctx := context.WithValue(ctx, TokenContextKey{}, &TokenContextData{
RawToken: token,
ApiKey: nil,
ApiKeyType: apikey.WorkerApiKey,
})
return newctx, nil
}
return nil, InvalidApiKeyErr
}

func GetTokenDataFromCtx(ctx context.Context) (*TokenContextData, error) {
Expand All @@ -75,3 +96,31 @@ func GetTokenDataFromCtx(ctx context.Context) (*TokenContextData, error) {

return data, nil
}

func isApiKeyAllowed(allowedKeys []string, key string) bool {
for _, allowedKey := range allowedKeys {
if secureCompare(allowedKey, key) {
return true
}
}
return false
}

func isProcedureAllowed(allowedProcedures map[string]any, procedure string) bool {
_, ok := allowedProcedures[procedure]
return ok
}

func secureCompare(a, b string) bool {
// Convert strings to byte slices for comparison
aBytes := []byte(a)
bBytes := []byte(b)

// Check length first; if they differ, return false immediately
if len(aBytes) != len(bBytes) {
return false
}

// Use ConstantTimeCompare for a timing-attack resistant comparison
return subtle.ConstantTimeCompare(aBytes, bBytes) == 1
}
Loading

0 comments on commit 336ee68

Please sign in to comment.