diff --git a/gateway/db/cache.go b/gateway/db/cache.go index 2f368c6..c7ee5bd 100644 --- a/gateway/db/cache.go +++ b/gateway/db/cache.go @@ -1,25 +1,26 @@ -package db +package db import ( - "context" - "errors" - "fmt" - log "log/slog" - "strconv" - "sync" - "time" - - "github.com/google/uuid" - "github.com/redis/go-redis/v9" - rl "github.com/go-redis/redis_rate/v10" - - "porters/common" + "context" + "errors" + "fmt" + log "log/slog" + "strconv" + "strings" + "sync" + "time" + + rl "github.com/go-redis/redis_rate/v10" + "github.com/google/uuid" + "github.com/redis/go-redis/v9" + + "porters/common" ) const ( ACCOUNT_SET = "VALID_ACCOUNTS" REDIS = "redis" - MISSED_FALSE = "0001-01-01T00:00:00Z" + MISSED_FALSE = "0001-01-01T00:00:00Z" ) // access redis functions through this object @@ -28,380 +29,413 @@ type Cache struct { } type RefreshTask struct { - ref refreshable + ref refreshable } type refreshable interface { - refreshAt() time.Time - refresh(ctx context.Context) error + refreshAt() time.Time + refresh(ctx context.Context) error } type Incrementable interface { - Key() string - Field() string + Key() string + Field() string } type Decrementable interface { - Key() string - Field() string + Key() string + Field() string } var client *redis.Client var redisMutex sync.Once func getCache() *redis.Client { - redisMutex.Do(func() { - opts, err := redis.ParseURL(common.GetConfig(common.REDIS_URL)) - if err != nil { - log.Warn("valid REDIS_URL not provided", "err", err) - opts = &redis.Options{ + redisMutex.Do(func() { + opts, err := redis.ParseURL(common.GetConfig(common.REDIS_URL)) + if err != nil { + log.Warn("valid REDIS_URL not provided", "err", err) + opts = &redis.Options{ Addr: common.GetConfig(common.REDIS_ADDR), - Username: common.GetConfig(common.REDIS_USER), - Password: common.GetConfig(common.REDIS_PASSWORD), + Username: common.GetConfig(common.REDIS_USER), + Password: common.GetConfig(common.REDIS_PASSWORD), DB: 0, - } - } - client = redis.NewClient(opts) - }) - return client + } + } + client = redis.NewClient(opts) + }) + return client } func needsRefresh(r refreshable) bool { - return r.refreshAt().Compare(time.Now()) > 0 + return r.refreshAt().Compare(time.Now()) > 0 } func (c *Cache) Healthcheck() *common.HealthCheckStatus { - hcs := common.NewHealthCheckStatus() - client := getCache() - ctx := context.Background() - status, err := client.Ping(ctx).Result() - if err != nil { - hcs.AddError(REDIS, err) - } else { - hcs.AddHealthy(REDIS, status) - } + hcs := common.NewHealthCheckStatus() + client := getCache() + ctx := context.Background() + status, err := client.Ping(ctx).Result() + if err != nil { + hcs.AddError(REDIS, err) + } else { + hcs.AddHealthy(REDIS, status) + } - return hcs + return hcs } func (t *Tenant) cache(ctx context.Context) error { - cached := time.Now() - err := getCache().HSet(ctx, t.Key(), - "active", t.Active, - "balance", t.Balance, - "cached", cached).Err() - if err != nil { - return err - } - return nil + cached := time.Now() + err := getCache().HSet(ctx, t.Key(), + "active", t.Active, + "balance", t.Balance, + "cached", cached).Err() + if err != nil { + return err + } + return nil } func (a *App) cache(ctx context.Context) error { - cached := time.Now() - err := getCache().HSet(ctx, a.Key(), - "active", a.Active, - "tenant", a.Tenant.Id, - "cached", cached, - "missedAt", a.MissedAt).Err() - if err != nil { - return err - } - return nil + cached := time.Now() + err := getCache().HSet(ctx, a.Key(), + "active", a.Active, + "tenant", a.Tenant.Id, + "cached", cached, + "missedAt", a.MissedAt).Err() + if err != nil { + return err + } + return nil } func (ar *Apprule) cache(ctx context.Context) error { - err := getCache().HSet(ctx, ar.Key(), - "active", ar.Active, - "value", ar.Value, - "appId", ar.App.Id, - "ruleType", ar.RuleType, - "cachedAt", time.Now()).Err() - if err != nil { - return err - } - return nil + err := getCache().HSet(ctx, ar.Key(), + "active", ar.Active, + "value", ar.Value, + "appId", ar.App.Id, + "ruleType", ar.RuleType, + "cachedAt", time.Now()).Err() + if err != nil { + return err + } + return nil } func (p *Product) cache(ctx context.Context) error { - err := getCache().HSet(ctx, p.Key(), - "poktId", p.PoktId, - "weight", p.Weight, - "active", p.Active, - "cachedAt", time.Now(), - "missedAt", p.MissedAt).Err() - if err != nil { - return err - } - return nil + err := getCache().HSet(ctx, p.Key(), + "poktId", p.PoktId, + "weight", p.Weight, + "active", p.Active, + "cachedAt", time.Now(), + "missedAt", p.MissedAt).Err() + if err != nil { + return err + } + return nil } func (t *Tenant) Lookup(ctx context.Context) error { - fromContext, ok := common.FromContext(ctx, TENANT) - if ok { - *t = *fromContext.(*Tenant) - } else { - key := t.Key() - result, err := getCache().HGetAll(ctx, key).Result() - if err != nil || len(result) == 0 { - log.Debug("tenant cache missing", "key", key) - t.refresh(ctx) - } else { - t.Active, _ = strconv.ParseBool(result["active"]) - t.Balance, _ = strconv.Atoi(result["balance"]) - t.CachedAt, _ = time.Parse(time.RFC3339, result["cachedAt"]) - } - - common.UpdateContext(ctx, t) - - if expired(t) { - common.GetTaskQueue().Add(&RefreshTask{ - ref: t, - }) - } - } - return nil + fromContext, ok := common.FromContext(ctx, TENANT) + if ok { + *t = *fromContext.(*Tenant) + } else { + key := t.Key() + result, err := getCache().HGetAll(ctx, key).Result() + if err != nil || len(result) == 0 { + log.Debug("tenant cache missing", "key", key) + t.refresh(ctx) + } else { + t.Active, _ = strconv.ParseBool(result["active"]) + t.Balance, _ = strconv.Atoi(result["balance"]) + t.CachedAt, _ = time.Parse(time.RFC3339, result["cachedAt"]) + } + + common.UpdateContext(ctx, t) + + if expired(t) { + common.GetTaskQueue().Add(&RefreshTask{ + ref: t, + }) + } + } + return nil } func (a *App) Lookup(ctx context.Context) error { - fromContext, ok := common.FromContext(ctx, APP) - if ok { - *a = *fromContext.(*App) - } else { - key := a.Key() - result, err := getCache().HGetAll(ctx, key).Result() - if err != nil || len(result) == 0 { - log.Debug("missed app", "appkey", key) - a.refresh(ctx) - } else if result["missedAt"] != MISSED_FALSE { - if backoff(result["missedAt"]) { - // NOOP - } else { - a.refresh(ctx) - } - } else { - log.Debug("got app from cache", "app", a.HashId()) - a.Active, _ = strconv.ParseBool(result["active"]) - a.Tenant.Id = result["tenant"] - a.Tenant.Lookup(ctx) - } - common.UpdateContext(ctx, a) - - if expired(a) { - common.GetTaskQueue().Add(&RefreshTask{ - ref: a, - }) - } - } - return nil + fromContext, ok := common.FromContext(ctx, APP) + if ok { + *a = *fromContext.(*App) + } else { + key := a.Key() + result, err := getCache().HGetAll(ctx, key).Result() + if err != nil || len(result) == 0 { + log.Debug("missed app", "appkey", key) + a.refresh(ctx) + } else if result["missedAt"] != MISSED_FALSE { + if backoff(result["missedAt"]) { + // NOOP + } else { + a.refresh(ctx) + } + } else { + log.Debug("got app from cache", "app", a.HashId()) + a.Active, _ = strconv.ParseBool(result["active"]) + a.Tenant.Id = result["tenant"] + a.Tenant.Lookup(ctx) + } + common.UpdateContext(ctx, a) + + if expired(a) { + common.GetTaskQueue().Add(&RefreshTask{ + ref: a, + }) + } + } + return nil } func (a *App) Rules(ctx context.Context) (Apprules, error) { - rules := make([]Apprule, 0) - pattern := fmt.Sprintf("%s:%s", APPRULE, a.Id) - - iter := ScanKeys(ctx, pattern) - for iter.Next(ctx) { - key := iter.Val() - result, err := getCache().HGetAll(ctx, key).Result() - if err != nil { - log.Error("error during scan", "err", err) + rules := make([]Apprule, 0) + pattern := fmt.Sprintf("%s:%s", APPRULE, a.Id) + + iter := ScanKeys(ctx, pattern) + for iter.Next(ctx) { + key := iter.Val() + result, err := getCache().HGetAll(ctx, key).Result() + if err != nil { + log.Error("error during scan", "err", err) + continue + } + + // Extract the actual ID from the Redis key + parts := strings.Split(key, ":") + if len(parts) != 3 || parts[0] != APPRULE { + log.Error("Invalid key format", "key", key) continue } - id := key - active, _ := strconv.ParseBool(result["active"]) - cachedAt, _ := time.Parse(time.RFC3339, result["cachedAt"]) - ar := Apprule{ - Id: id, + appId := parts[1] + appRuleId := parts[2] + + active, _ := strconv.ParseBool(result["active"]) + cachedAt, _ := time.Parse(time.RFC3339, result["cachedAt"]) + ar := Apprule{ + Id: appRuleId, Active: active, Value: result["value"], - RuleType: result["ruleType"], - CachedAt: cachedAt, - } - rules = append(rules, ar) - } - return rules, nil + RuleType: result["ruleType"], + CachedAt: cachedAt, + App: App{Id: appId}, + } + + // Check if the Apprule needs to be refreshed + if expired(&ar) { + log.Debug("Apprule is expired, adding refresh task", "apprule", ar.Id) + common.GetTaskQueue().Add(&RefreshTask{ + ref: &ar, + }) + } + + rules = append(rules, ar) + } + return rules, nil } // Lookup by name, p should have a valid "Name" set before lookup func (p *Product) Lookup(ctx context.Context) error { - fromContext, ok := common.FromContext(ctx, PRODUCT) - if ok { - *p = *fromContext.(*Product) - } else { - key := p.Key() - log.Debug("finding product from cache", "prodkey", key) - result, err := getCache().HGetAll(ctx, key).Result() - if err != nil || len(result) == 0 { - log.Debug("missed product", "prodkey", key) - p.refresh(ctx) - } else if result["missedAt"] != MISSED_FALSE { - if backoff(result["missedAt"]) { - // NOOP - } else { - p.refresh(ctx) - } - } else { - p.PoktId, _ = result["poktId"] - p.Weight, _ = strconv.Atoi(result["weight"]) - p.Active, _ = strconv.ParseBool(result["active"]) - } - - common.UpdateContext(ctx, p) - - if expired(p) { - common.GetTaskQueue().Add(&RefreshTask{ - ref: p, - }) - } - } - return nil + fromContext, ok := common.FromContext(ctx, PRODUCT) + if ok { + *p = *fromContext.(*Product) + } else { + key := p.Key() + log.Debug("finding product from cache", "prodkey", key) + result, err := getCache().HGetAll(ctx, key).Result() + if err != nil || len(result) == 0 { + log.Debug("missed product", "prodkey", key) + p.refresh(ctx) + } else if result["missedAt"] != MISSED_FALSE { + if backoff(result["missedAt"]) { + // NOOP + } else { + p.refresh(ctx) + } + } else { + p.PoktId, _ = result["poktId"] + p.Weight, _ = strconv.Atoi(result["weight"]) + p.Active, _ = strconv.ParseBool(result["active"]) + } + + common.UpdateContext(ctx, p) + + if expired(p) { + common.GetTaskQueue().Add(&RefreshTask{ + ref: p, + }) + } + } + return nil } func RelaytxFromKey(ctx context.Context, key string) (*Relaytx, bool) { - relaycount := GetIntVal(ctx, key) - rtx := reverseRelaytxKey(key) - if relaycount > 0 && rtx.AppId != "" && rtx.ProductName != "" { - uuid := uuid.New() - rtx.Id = uuid.String() - rtx.Reference = uuid.String() - rtx.Amount = relaycount - rtx.TxType = Credit - return rtx, true - } - return rtx, false + relaycount := GetIntVal(ctx, key) + rtx := reverseRelaytxKey(key) + if relaycount > 0 && rtx.AppId != "" && rtx.ProductName != "" { + uuid := uuid.New() + rtx.Id = uuid.String() + rtx.Reference = uuid.String() + rtx.Amount = relaycount + rtx.TxType = Credit + return rtx, true + } + return rtx, false } // Refresh does the psql calls to build cache func (t *Tenant) refresh(ctx context.Context) error { - err := t.fetch(ctx) - if err != nil { - log.Error("something's wrong", "tenant", t.Id, "err", err) - return err - } else { - err := t.canonicalBalance(ctx) - if err != nil { - log.Error("error getting balance", "tenant", t.Id, "err", err) - } - t.cache(ctx) - } - return nil + err := t.fetch(ctx) + if err != nil { + log.Error("something's wrong", "tenant", t.Id, "err", err) + return err + } else { + err := t.canonicalBalance(ctx) + if err != nil { + log.Error("error getting balance", "tenant", t.Id, "err", err) + } + t.cache(ctx) + } + return nil } func (a *App) refresh(ctx context.Context) error { - err := a.fetch(ctx) - if err != nil { - log.Error("err seen refreshing app", "app", a.HashId(), "err", err) - a.MissedAt = time.Now() - } else { - a.Tenant.Lookup(ctx) - } - a.cache(ctx) - - rules, err := a.fetchRules(ctx) - if err != nil { - log.Error("error accessing rules", "app", a.HashId(), "err", err) - return err - } - for _, r := range rules { - r.cache(ctx) - } - return nil + err := a.fetch(ctx) + if err != nil { + log.Error("err seen refreshing app", "app", a.HashId(), "err", err) + a.MissedAt = time.Now() + } else { + a.Tenant.Lookup(ctx) + } + a.cache(ctx) + + rules, err := a.fetchRules(ctx) + if err != nil { + log.Error("error accessing rules", "app", a.HashId(), "err", err) + return err + } + for _, r := range rules { + r.cache(ctx) + } + return nil +} + +func (ar *Apprule) refresh(ctx context.Context) error { + err := ar.fetch(ctx) + if err != nil { + log.Error("error seen refreshing apprule", "apprule", ar.Id, "err", err) + return err + } + ar.cache(ctx) + return nil } func (p *Product) refresh(ctx context.Context) error { - err := p.fetch(ctx) - if err != nil { - log.Error("err getting product", "product", p.Name, "err", err) - p.MissedAt = time.Now() - } - p.cache(ctx) - return nil + err := p.fetch(ctx) + if err != nil { + log.Error("err getting product", "product", p.Name, "err", err) + p.MissedAt = time.Now() + } + p.cache(ctx) + return nil } func (t *RefreshTask) Run() { - ctx := context.Background() - err := t.ref.refresh(ctx) - if err != nil { - common.GetTaskQueue().ReportError(errors.New(t.Error())) - } + ctx := context.Background() + err := t.ref.refresh(ctx) + if err != nil { + common.GetTaskQueue().ReportError(errors.New(t.Error())) + } } func (t *RefreshTask) Error() string { - return "error processing refresh" + return "error processing refresh" } func (t *Tenant) refreshAt() time.Time { - return t.CachedAt.Add(1 * time.Minute) + return t.CachedAt.Add(1 * time.Minute) } func (a *App) refreshAt() time.Time { - return a.CachedAt.Add(1 * time.Minute) + return a.CachedAt.Add(1 * time.Minute) +} + +func (ar *Apprule) refreshAt() time.Time { + return ar.CachedAt.Add(1 * time.Minute) } // Products rarely change, hourly is ok func (p *Product) refreshAt() time.Time { - return p.CachedAt.Add(1 * time.Hour) + return p.CachedAt.Add(1 * time.Hour) } func backoff(missedAt string) bool { - missedTime, err := time.Parse(time.RFC3339, missedAt) - if err != nil { - return false // something is wrong with time format, refresh to fix - } - return time.Now().Before(missedTime.Add(5 * time.Minute)) + missedTime, err := time.Parse(time.RFC3339, missedAt) + if err != nil { + return false // something is wrong with time format, refresh to fix + } + return time.Now().Before(missedTime.Add(5 * time.Minute)) } func expired(ref refreshable) bool { - return time.Now().After(ref.refreshAt()) + return time.Now().After(ref.refreshAt()) } // utility function func GetIntVal(ctx context.Context, name string) int { - result, err := getCache().Get(ctx, name).Result() - if err != nil { - return 0 - } - intval, err := strconv.Atoi(result) - if err != nil { - return 0 - } - return intval + result, err := getCache().Get(ctx, name).Result() + if err != nil { + return 0 + } + intval, err := strconv.Atoi(result) + if err != nil { + return 0 + } + return intval } func IncrementField(ctx context.Context, incr Incrementable, amount int) int { - incrBy := int64(amount) - newVal, err := getCache().HIncrBy(ctx, incr.Key(), incr.Field(), incrBy).Result() - if err != nil { - return -1 - } - return int(newVal) + incrBy := int64(amount) + newVal, err := getCache().HIncrBy(ctx, incr.Key(), incr.Field(), incrBy).Result() + if err != nil { + return -1 + } + return int(newVal) } func DecrementField(ctx context.Context, decr Decrementable, amount int) int { - decrBy := -int64(amount) - newVal, err := getCache().HIncrBy(ctx, decr.Key(), decr.Field(), decrBy).Result() - if err != nil { - return -1 - } - return int(newVal) + decrBy := -int64(amount) + newVal, err := getCache().HIncrBy(ctx, decr.Key(), decr.Field(), decrBy).Result() + if err != nil { + return -1 + } + return int(newVal) } func IncrementCounter(ctx context.Context, key string, amount int) int { - incrBy := int64(amount) - newVal, err := getCache().IncrBy(ctx, key, incrBy).Result() - if err != nil { - return -1 - } - return int(newVal) + incrBy := int64(amount) + newVal, err := getCache().IncrBy(ctx, key, incrBy).Result() + if err != nil { + return -1 + } + return int(newVal) } func DecrementCounter(ctx context.Context, key string, amount int) int { - decrBy := int64(amount) - newVal, err := getCache().DecrBy(ctx, key, decrBy).Result() - if err != nil { - return -1 - } - return int(newVal) + decrBy := int64(amount) + newVal, err := getCache().DecrBy(ctx, key, decrBy).Result() + if err != nil { + return -1 + } + return int(newVal) } // returns false if counter already exists @@ -410,31 +444,31 @@ func InitCounter(ctx context.Context, key string, initValue int) (bool, error) { } func ReconcileRelays(ctx context.Context, rtx *Relaytx) (func() bool, error) { - // Can ignore new value - _, err := getCache().DecrBy(ctx, rtx.Key(), int64(rtx.Amount)).Result() - if err != nil { + // Can ignore new value + _, err := getCache().DecrBy(ctx, rtx.Key(), int64(rtx.Amount)).Result() + if err != nil { return func() bool {return false}, err - } - - updateFunc := func() bool { - background := context.Background() - pgerr := rtx.write(background) - if pgerr != nil { - log.Error("couldn't write relaytx", "pgerr", pgerr) - return false - } - return true - } - return updateFunc, nil + } + + updateFunc := func() bool { + background := context.Background() + pgerr := rtx.write(background) + if pgerr != nil { + log.Error("couldn't write relaytx", "pgerr", pgerr) + return false + } + return true + } + return updateFunc, nil } func ScanKeys(ctx context.Context, key string) *redis.ScanIterator { - scankey := fmt.Sprintf("%s:*", key) - iter := getCache().Scan(ctx, 0, scankey, 0).Iterator() - return iter + scankey := fmt.Sprintf("%s:*", key) + iter := getCache().Scan(ctx, 0, scankey, 0).Iterator() + return iter } func Limiter() *rl.Limiter { - rdb := getCache() - return rl.NewLimiter(rdb) + rdb := getCache() + return rl.NewLimiter(rdb) } diff --git a/gateway/db/canonical.go b/gateway/db/canonical.go index 8098fcc..ea590be 100644 --- a/gateway/db/canonical.go +++ b/gateway/db/canonical.go @@ -1,28 +1,29 @@ package db import ( - "context" - "database/sql" - "errors" - log "log/slog" - "sync" + "context" + "database/sql" + "fmt" + "errors" + log "log/slog" + "sync" - "github.com/lib/pq" + "github.com/lib/pq" - "porters/common" + "porters/common" ) // Cache sits in front of canonical db (postgres) // Protect access to this to reduce it to minimal traffic type DB interface { - common.HealthCheck - Conn(ctx context.Context) (*sql.Conn, error) - Report() sql.DBStats + common.HealthCheck + Conn(ctx context.Context) (*sql.Conn, error) + Report() sql.DBStats } type Canonical struct { - // any instance specific stuff here - // singleton holds + // any instance specific stuff here + // singleton holds } type DBFunc func(*sql.Conn) error @@ -31,45 +32,45 @@ var postgresPool *sql.DB var postgresMutex sync.Once func getCanonicalDB() *sql.DB { - postgresMutex.Do(func() { - connStr := common.GetConfig(common.DATABASE_URL) - connector, err := pq.NewConnector(connStr) - if err != nil { - log.Error("Cannot connect to postgres", "err", err) - panic("database required") - } - postgresPool = sql.OpenDB(connector) - }) - return postgresPool + postgresMutex.Do(func() { + connStr := common.GetConfig(common.DATABASE_URL) + connector, err := pq.NewConnector(connStr) + if err != nil { + log.Error("Cannot connect to postgres", "err", err) + panic("database required") + } + postgresPool = sql.OpenDB(connector) + }) + return postgresPool } // Wrapping conn function to not expose the DB outside package func (c *Canonical) Conn(ctx context.Context) (*sql.Conn, error) { - db := getCanonicalDB() - return db.Conn(ctx) + db := getCanonicalDB() + return db.Conn(ctx) } func (c *Canonical) Healthcheck() *common.HealthCheckStatus { - hc := common.NewHealthCheckStatus() - db := getCanonicalDB() - err := db.Ping() - if err != nil { - hc.AddError("postgres", err) - } else { - hc.AddHealthy("postgres", "connected") - } - return hc + hc := common.NewHealthCheckStatus() + db := getCanonicalDB() + err := db.Ping() + if err != nil { + hc.AddError("postgres", err) + } else { + hc.AddHealthy("postgres", "connected") + } + return hc } func (t *Tenant) fetch(ctx context.Context) error { - db := getCanonicalDB() - query := `SELECT id, active FROM "Tenant" WHERE id = $1 AND "deletedAt" IS NULL` - row := db.QueryRowContext(ctx, query, t.Id) - err := row.Scan(&t.Id, &t.Active) - if err != nil { - return err - } - return nil + db := getCanonicalDB() + query := `SELECT id, active FROM "Tenant" WHERE id = $1 AND "deletedAt" IS NULL` + row := db.QueryRowContext(ctx, query, t.Id) + err := row.Scan(&t.Id, &t.Active) + if err != nil { + return err + } + return nil } // Special function on tenant to update the "official" balance @@ -77,8 +78,8 @@ func (t *Tenant) fetch(ctx context.Context) error { // cached balance is incremented on new CREDIT txns and and needs to track last // "createdAt" func (t *Tenant) canonicalBalance(ctx context.Context) error { - db := getCanonicalDB() - query := `SELECT payment.balance - relay.usage as net FROM + db := getCanonicalDB() + query := `SELECT payment.balance - relay.usage as net FROM (SELECT COALESCE(SUM(case when "transactionType"='CREDIT' then amount else 0 end) - SUM(case when "transactionType"='DEBIT' then amount else 0 end), 0) @@ -86,73 +87,84 @@ func (t *Tenant) canonicalBalance(ctx context.Context) error { (SELECT COALESCE(SUM(case when "transactionType"='CREDIT' then amount else 0 end) - SUM(case when "transactionType"='DEBIT' then amount else 0 end), 0) - AS usage FROM "RelayLedger" WHERE "tenantId" = $1) as relay` - row := db.QueryRowContext(ctx, query, t.Id) - err := row.Scan(&t.Balance) - if err != nil { - return err - } - return nil + AS usage FROM "RelayLedger" WHERE "tenantId" = $1) as relay` + row := db.QueryRowContext(ctx, query, t.Id) + err := row.Scan(&t.Balance) + if err != nil { + return err + } + return nil } func (a *App) fetch(ctx context.Context) error { - db := getCanonicalDB() - row := db.QueryRowContext(ctx, `SELECT id, active, "tenantId" FROM "App" WHERE id = $1 AND "deletedAt" IS NULL`, a.Id) - err := row.Scan(&a.Id, &a.Active, &a.Tenant.Id) - if err != nil { - return err - } - return nil + db := getCanonicalDB() + row := db.QueryRowContext(ctx, `SELECT id, active, "tenantId" FROM "App" WHERE id = $1 AND "deletedAt" IS NULL`, a.Id) + err := row.Scan(&a.Id, &a.Active, &a.Tenant.Id) + if err != nil { + return err + } + return nil } func (a *App) fetchRules(ctx context.Context) (Apprules, error) { - rules := make([]Apprule, 0) - db := getCanonicalDB() - rows, err := db.QueryContext(ctx, `SELECT "AppRule".id, "AppRule".value, "AppRule".active, "RuleType".name FROM "AppRule", "RuleType" WHERE "AppRule"."appId" = $1 AND "AppRule"."deletedAt" IS NULL AND "RuleType".active = '1' AND "AppRule"."ruleId" = "RuleType"."id"`, a.Id) - if err != nil { - return nil, err - } - defer rows.Close() - - for rows.Next() { - apprule := Apprule{} - err := rows.Scan(&apprule.Id, &apprule.Value, &apprule.Active, &apprule.RuleType) - if err != nil { - return nil, err - } - apprule.App = *a - rules = append(rules, apprule) - } - return rules, nil + rules := make([]Apprule, 0) + db := getCanonicalDB() + rows, err := db.QueryContext(ctx, `SELECT "AppRule".id, "AppRule".value, "AppRule".active, "RuleType".name FROM "AppRule", "RuleType" WHERE "AppRule"."appId" = $1 AND "AppRule"."deletedAt" IS NULL AND "RuleType".active = '1' AND "AppRule"."ruleId" = "RuleType"."id"`, a.Id) + if err != nil { + return nil, err + } + defer rows.Close() + + for rows.Next() { + apprule := Apprule{} + err := rows.Scan(&apprule.Id, &apprule.Value, &apprule.Active, &apprule.RuleType) + if err != nil { + return nil, err + } + apprule.App = *a + rules = append(rules, apprule) + } + return rules, nil } func (p *Product) fetch(ctx context.Context) error { - db := getCanonicalDB() - row := db.QueryRowContext(ctx, `SELECT id, name, "poktId", weight, active FROM "Products" WHERE name = $1`, p.Name) - err := row.Scan(&p.Id, &p.Name, &p.PoktId, &p.Weight, &p.Active) - if err != nil { - return err - } - return nil + db := getCanonicalDB() + row := db.QueryRowContext(ctx, `SELECT id, name, "poktId", weight, active FROM "Products" WHERE name = $1`, p.Name) + err := row.Scan(&p.Id, &p.Name, &p.PoktId, &p.Weight, &p.Active) + if err != nil { + return err + } + return nil +} + +func (ar *Apprule) fetch(ctx context.Context) error { + db := getCanonicalDB() + row := db.QueryRowContext(ctx, `SELECT "AppRule".id, "AppRule".value, "AppRule".active, "RuleType".name FROM "AppRule", "RuleType" WHERE "AppRule".id = $1 AND "AppRule"."ruleId" = "RuleType".id`, ar.Id) + err := row.Scan(&ar.Id, &ar.Value, &ar.Active, &ar.RuleType) + if err != nil { + log.Error("Error during row.Scan", "err", err, "appruleId", ar.Id) + return fmt.Errorf("failed to fetch Apprule with ID %s: %w", ar.Id, err) + } + return nil } func (rtx *Relaytx) write(ctx context.Context) error { - db := getCanonicalDB() - res, err := db.ExecContext(ctx, `INSERT INTO "RelayLedger" + db := getCanonicalDB() + res, err := db.ExecContext(ctx, `INSERT INTO "RelayLedger" ("id", "tenantId", "referenceId", "amount", "productId", "transactionType") VALUES - ($1, (SELECT "tenantId" FROM "App" WHERE id = $2), $3, $4, (SELECT id FROM "Products" WHERE name = $5), 'CREDIT')`, rtx.Id, rtx.AppId, rtx.Reference, rtx.Amount, rtx.ProductName) - if err != nil { - return err - } else { - rows, err := res.RowsAffected() - if err != nil { - return err - } - if rows != 1 { - return errors.New("unable to insert to RelayLedger") - } else { - return nil - } - } + ($1, (SELECT "tenantId" FROM "App" WHERE id = $2), $3, $4, (SELECT id FROM "Products" WHERE name = $5), 'CREDIT')`, rtx.Id, rtx.AppId, rtx.Reference, rtx.Amount, rtx.ProductName) + if err != nil { + return err + } else { + rows, err := res.RowsAffected() + if err != nil { + return err + } + if rows != 1 { + return errors.New("unable to insert to RelayLedger") + } else { + return nil + } + } } diff --git a/gateway/main.go b/gateway/main.go index 0920890..4e07d3c 100644 --- a/gateway/main.go +++ b/gateway/main.go @@ -17,6 +17,7 @@ func main() { proxy.Register(&plugins.ApiKeyAuth{"X-API"}) proxy.Register(&plugins.BalanceTracker{}) proxy.Register(&plugins.LeakyBucketPlugin{"APP"}) + proxy.Register(&plugins.ProductFilter{}) proxy.Register(&plugins.UserAgentFilter{}) proxy.Register(&plugins.AllowedOriginFilter{}) proxy.Register(proxy.NewReconciler(300)) // seconds diff --git a/gateway/plugins/origin.go b/gateway/plugins/origin.go index a761e5c..eda2608 100644 --- a/gateway/plugins/origin.go +++ b/gateway/plugins/origin.go @@ -1,99 +1,105 @@ package plugins import ( - "context" - log "log/slog" - "net/http" - "strings" + "context" + log "log/slog" + "net/http" + "strings" - "porters/db" - "porters/proxy" + "porters/db" + "porters/proxy" ) const ( - ORIGIN_HEADER string = "Origin" - ALLOWED_ORIGIN = "allowed-origins" + ORIGIN_HEADER string = "Origin" + ALLOWED_ORIGIN = "allowed-origins" ) type AllowedOriginFilter struct { - } func (a *AllowedOriginFilter) Name() string { - return "Allowed Origin Filter" + return "Allowed Origin Filter" } func (a *AllowedOriginFilter) Key() string { - return "ORIGIN" + return "ORIGIN" } func (a *AllowedOriginFilter) Load() { - log.Debug("loading plugin", "plugin", a.Name()) + log.Debug("loading plugin", "plugin", a.Name()) } func (a *AllowedOriginFilter) HandleRequest(req *http.Request) error { - ctx := req.Context() - origin := req.Header.Get(ORIGIN_HEADER) - app := &db.App{ - Id: proxy.PluckAppId(req), - } - err := app.Lookup(ctx) - if err != nil { - return proxy.NewHTTPError(http.StatusNotFound) - } - - rules := a.getRulesForScope(ctx, app) - allow := a.matchesRules(origin, rules) - - if !allow { - return proxy.NewHTTPError(http.StatusUnauthorized) - } - - return nil + ctx := req.Context() + origin := req.Header.Get(ORIGIN_HEADER) + app := &db.App{ + Id: proxy.PluckAppId(req), + } + err := app.Lookup(ctx) + if err != nil { + return proxy.NewHTTPError(http.StatusNotFound) + } + + rules := a.getRulesForScope(ctx, app) + allow := a.matchesRules(origin, rules) + + if !allow { + return proxy.NewHTTPError(http.StatusUnauthorized) + } + + return nil } func (a *AllowedOriginFilter) HandleResponse(resp *http.Response) error { - ctx := resp.Request.Context() - app := &db.App{ - Id: proxy.PluckAppId(resp.Request), - } - err := app.Lookup(ctx) - if err != nil { - return nil // don't modify header - } - - rules := a.getRulesForScope(ctx, app) - if len(rules) > 0 { - allowedOrigins := strings.Join(rules, ",") - resp.Header.Set("Access-Control-Allow-Origin", allowedOrigins) - } - return nil + ctx := resp.Request.Context() + app := &db.App{ + Id: proxy.PluckAppId(resp.Request), + } + err := app.Lookup(ctx) + if err != nil { + return nil // don't modify header + } + + rules := a.getRulesForScope(ctx, app) + resp.Header.Set("Access-Control-Allow-Headers", "authorization, content-type, server") + resp.Header.Set("Access-Control-Allow-Methods", "GET,POST,OPTIONS") + + var allowedOrigins string + if len(rules) > 0 { + allowedOrigins = strings.Join(rules, ",") + } else { + allowedOrigins = "*" // default value if no rules are found + } + resp.Header.Set("Access-Control-Allow-Origin", allowedOrigins) + + return nil } func (a *AllowedOriginFilter) getRulesForScope(ctx context.Context, app *db.App) []string { - origins := make([]string, 0) - rules, err := app.Rules(ctx) - if err != nil { - log.Error("couldn't get rules", "app", app.HashId(), "err", err) - } else { - for _, rule := range rules { - if rule.RuleType != ALLOWED_ORIGIN || !rule.Active { - continue - } - origins = append(origins, rule.Value) - } - } - return origins + origins := make([]string, 0) + rules, err := app.Rules(ctx) + if err != nil { + log.Error("couldn't get rules", "app", app.HashId(), "err", err) + } else { + for _, rule := range rules { + if rule.RuleType != ALLOWED_ORIGIN || !rule.Active { + continue + } + origins = append(origins, rule.Value) + } + } + return origins } func (a *AllowedOriginFilter) matchesRules(origin string, rules []string) bool { - if len(rules) == 0 { - return true - } - for _, rule := range rules { - if strings.EqualFold(rule, origin) { - return true - } - } - return false + if len(rules) == 0 { + return true + } + for _, rule := range rules { + if strings.EqualFold(rule, origin) { + return true + } + } + return false } diff --git a/gateway/plugins/productfilter.go b/gateway/plugins/productfilter.go index 336b05c..f4179e4 100644 --- a/gateway/plugins/productfilter.go +++ b/gateway/plugins/productfilter.go @@ -1,74 +1,81 @@ package plugins import ( - "context" - log "log/slog" - "net/http" + "context" + log "log/slog" + "net/http" - "porters/db" - "porters/proxy" + "porters/db" + "porters/proxy" ) const ( - ALLOWED_PRODUCTS string = "approved-chains" + ALLOWED_PRODUCTS string = "approved-chains" ) type ProductFilter struct { - } func (p *ProductFilter) Name() string { - return "Approved Chains" + return "Approved Chains" } func (p *ProductFilter) Key() string { - return "ALLOWEDPRODUCT" + return "ALLOWEDPRODUCT" } func (p *ProductFilter) Load() { - log.Debug("loading plugin", "plugin", p.Name()) + log.Debug("loading plugin", "plugin", p.Name()) } func (p *ProductFilter) HandleRequest(req *http.Request) error { - ctx := req.Context() - product := proxy.PluckProductName(req) - app := &db.App{ - Id: proxy.PluckAppId(req), - } - err := app.Lookup(ctx) - if err != nil { - return proxy.NewHTTPError(http.StatusNotFound) - } - - rules := p.getRulesForScope(ctx, app) - allow := (len(rules) == 0) - - for _, rule := range rules { - if rule == product { - allow = true - break - } - } - - if !allow { - return proxy.NewHTTPError(http.StatusUnauthorized) - } - return nil + ctx := req.Context() + product := proxy.PluckProductName(req) + + app := &db.App{ + Id: proxy.PluckAppId(req), + } + + err := app.Lookup(ctx) + if err != nil { + return proxy.NewHTTPError(http.StatusNotFound) + } + + rules := p.getRulesForScope(ctx, app) + + allow := (len(rules) == 0) + + for _, rule := range rules { + log.Debug("Checking rule against product", "rule", rule, "product", product) + if rule == product { + allow = true + break + } + } + + if !allow { + log.Error("Unauthorized access attempt", "product", product) + return proxy.NewHTTPError(http.StatusUnauthorized) + } + + log.Debug("Request allowed", "product", product) + return nil } func (p *ProductFilter) getRulesForScope(ctx context.Context, app *db.App) []string { - products := make([]string, 0) - rules, err := app.Rules(ctx) - if err != nil { - log.Error("couldn't read rules", "app", app.HashId(), "err", err) - } else { - for _, rule := range rules { - if rule.RuleType != ALLOWED_PRODUCTS || !rule.Active { - continue - } - log.Debug("allowing product", "product", rule.Value) - products = append(products, rule.Value) - } - } - return products + products := make([]string, 0) + rules, err := app.Rules(ctx) + if err != nil { + log.Error("couldn't read rules", "app", app.HashId(), "err", err) + } else { + for _, rule := range rules { + if rule.RuleType != ALLOWED_PRODUCTS || !rule.Active { + log.Debug("blocking product", "product", rule.Value) + continue + } + log.Debug("allowing product", "product", rule.Value) + products = append(products, rule.Value) + } + } + return products } diff --git a/gateway/proxy/proxy.go b/gateway/proxy/proxy.go index 0fb9758..f6b0470 100644 --- a/gateway/proxy/proxy.go +++ b/gateway/proxy/proxy.go @@ -1,213 +1,215 @@ package proxy import ( - "context" - "errors" - "fmt" - log "log/slog" - "net/url" - "net/http" - "net/http/httputil" - "time" - - "github.com/gorilla/mux" - - "porters/common" - "porters/db" - "porters/utils" + "context" + "errors" + "fmt" + log "log/slog" + "net/http" + "net/http/httputil" + "net/url" + "time" + + "github.com/gorilla/mux" + + "porters/common" + "porters/db" + "porters/utils" ) var server *http.Server func Start() { - proxyUrl := common.GetConfig(common.PROXY_TO) - remote, err := url.Parse(proxyUrl) - if err != nil { - log.Error("unable to parse proxy to", "err", err) - panic("unable to start with invalid remote url") - } - log.Debug("proxying to remote", "url", remote) - - handler := func(proxy *httputil.ReverseProxy) func(http.ResponseWriter, *http.Request) { - return func(resp http.ResponseWriter, req *http.Request) { - setupContext(req) - proxy.ServeHTTP(resp, req) - } - } - - revProxy := setupProxy(remote) - router := mux.NewRouter() - - proxyRouter := addProxyRoutes(router) - proxyRouter.HandleFunc(fmt.Sprintf(`/{%s}`, APP_PATH), handler(revProxy)) - - _ = addHealthcheckRoute(router) - _ = addMetricsRoute(router) - - port := fmt.Sprintf(":%d", common.GetConfigInt(common.PORT)) - server = &http.Server{Addr: port, Handler: router} - go func() { - err := server.ListenAndServe() - if err != nil { - log.Error("server error encountered", "err", err) - } - }() + proxyUrl := common.GetConfig(common.PROXY_TO) + remote, err := url.Parse(proxyUrl) + if err != nil { + log.Error("unable to parse proxy to", "err", err) + panic("unable to start with invalid remote url") + } + log.Debug("proxying to remote", "url", remote) + + handler := func(proxy *httputil.ReverseProxy) func(http.ResponseWriter, *http.Request) { + return func(resp http.ResponseWriter, req *http.Request) { + setupContext(req) + proxy.ServeHTTP(resp, req) + } + } + + revProxy := setupProxy(remote) + router := mux.NewRouter() + + proxyRouter := addProxyRoutes(router) + proxyRouter.HandleFunc(fmt.Sprintf(`/{%s}`, APP_PATH), handler(revProxy)) + + _ = addHealthcheckRoute(router) + _ = addMetricsRoute(router) + + port := fmt.Sprintf(":%d", common.GetConfigInt(common.PORT)) + server = &http.Server{Addr: port, Handler: router} + go func() { + err := server.ListenAndServe() + if err != nil { + log.Error("server error encountered", "err", err) + } + }() } func Stop() { - // 5 second shutdown - shutdownTime := time.Duration(common.GetConfigInt(common.SHUTDOWN_DELAY)) * time.Second - ctx, cancel := context.WithTimeout(context.Background(), shutdownTime) - defer cancel() - - err := server.Shutdown(ctx) - if err != nil { - log.Error("error shutting down", "err", err) - } else { - log.Info("shutdown successful") - } + // 5 second shutdown + shutdownTime := time.Duration(common.GetConfigInt(common.SHUTDOWN_DELAY)) * time.Second + ctx, cancel := context.WithTimeout(context.Background(), shutdownTime) + defer cancel() + + err := server.Shutdown(ctx) + if err != nil { + log.Error("error shutting down", "err", err) + } else { + log.Info("shutdown successful") + } } func RequestCanceler(req *http.Request) context.CancelCauseFunc { - ctx, cancel := context.WithCancelCause(req.Context()) - *req = *req.WithContext(ctx) - return cancel + ctx, cancel := context.WithCancelCause(req.Context()) + *req = *req.WithContext(ctx) + return cancel } func setupProxy(remote *url.URL) *httputil.ReverseProxy { - revProxy := httputil.NewSingleHostReverseProxy(remote) - reg := GetRegistry() - - defaultDirector := revProxy.Director - revProxy.Director = func(req *http.Request) { - defaultDirector(req) - - cancel := RequestCanceler(req) - req.Host = remote.Host - - poktId, ok := lookupPoktId(req) - if !ok { - cancel(ChainNotSupportedError) - } - target := utils.NewTarget(remote, poktId) - req.URL = target.URL() - - - for _, p := range (*reg).plugins { - h, ok := p.(PreHandler) - if ok { - select { - case <-req.Context().Done(): - return - default: - err := h.HandleRequest(req) - if err != nil { - cancel(err) - } - } - } - } - - // Cancel if necessary lifecycle stages not completed - lifecycle := lifecycleFromContext(req.Context()) - if !lifecycle.checkComplete() { - err := LifecycleIncompleteError - log.Debug("lifecycle incomplete", "mask", lifecycle) - cancel(err) - } - - if common.Enabled(common.INSTRUMENT_ENABLED) { - ctx := req.Context() - instr, ok := common.FromContext(ctx, common.INSTRUMENT) - if ok { - start := instr.(*common.Instrument).Timestamp - elapsed := time.Now().Sub(start) - common.LatencyHistogram.WithLabelValues("setup").Observe(float64(elapsed)) - - ctx = common.UpdateContext(ctx, common.StartInstrument()) - *req = *req.WithContext(ctx) - } - } - } - - revProxy.ModifyResponse = func(resp *http.Response) error { - ctx := resp.Request.Context() - defaultHeaders(resp) - - if common.Enabled(common.INSTRUMENT_ENABLED) { - instr, ok := common.FromContext(ctx, common.INSTRUMENT) - if ok { - start := instr.(*common.Instrument).Timestamp - elapsed := time.Now().Sub(start) - common.LatencyHistogram.WithLabelValues("serve").Observe(float64(elapsed)) - } - } - - var err error - for _, p := range (*reg).plugins { - h, ok := p.(PostHandler) - if ok { - newerr := h.HandleResponse(resp) - if newerr != nil { - err = errors.Join(err, newerr) - } - } - } - - if resp.StatusCode < 400 && err == nil { - updater := db.NewUsageUpdater(ctx, "success") - common.GetTaskQueue().Add(updater) - } - - return err - } - - revProxy.ErrorHandler = func(resp http.ResponseWriter, req *http.Request, err error) { - ctx := req.Context() - var httpErr *HTTPError - cause := context.Cause(ctx) - - updater := db.NewUsageUpdater(ctx, "failure") - common.GetTaskQueue().Add(updater) - - log.Debug("cancel cause", "cause", cause) - if errors.As(cause, &httpErr) { - status := httpErr.code - http.Error(resp, http.StatusText(status), status) - } else if err != nil { - status := http.StatusBadGateway - http.Error(resp, http.StatusText(status), status) - } - } - - return revProxy + revProxy := httputil.NewSingleHostReverseProxy(remote) + reg := GetRegistry() + + defaultDirector := revProxy.Director + revProxy.Director = func(req *http.Request) { + defaultDirector(req) + + cancel := RequestCanceler(req) + req.Host = remote.Host + + poktId, ok := lookupPoktId(req) + if !ok { + cancel(ChainNotSupportedError) + } + target := utils.NewTarget(remote, poktId) + req.URL = target.URL() + + for _, p := range (*reg).plugins { + h, ok := p.(PreHandler) + if ok { + select { + case <-req.Context().Done(): + return + default: + err := h.HandleRequest(req) + if err != nil { + log.Error("Failed running HandleRequest proxy filter", "filter", p.Name()) + cancel(err) + } + } + } + } + + // Cancel if necessary lifecycle stages not completed + lifecycle := lifecycleFromContext(req.Context()) + if !lifecycle.checkComplete() { + err := LifecycleIncompleteError + log.Debug("lifecycle incomplete", "mask", lifecycle) + cancel(err) + } + + if common.Enabled(common.INSTRUMENT_ENABLED) { + ctx := req.Context() + instr, ok := common.FromContext(ctx, common.INSTRUMENT) + if ok { + start := instr.(*common.Instrument).Timestamp + elapsed := time.Now().Sub(start) + common.LatencyHistogram.WithLabelValues("setup").Observe(float64(elapsed)) + + ctx = common.UpdateContext(ctx, common.StartInstrument()) + *req = *req.WithContext(ctx) + } + } + } + + revProxy.ModifyResponse = func(resp *http.Response) error { + ctx := resp.Request.Context() + defaultHeaders(resp) + + if common.Enabled(common.INSTRUMENT_ENABLED) { + instr, ok := common.FromContext(ctx, common.INSTRUMENT) + if ok { + start := instr.(*common.Instrument).Timestamp + elapsed := time.Now().Sub(start) + common.LatencyHistogram.WithLabelValues("serve").Observe(float64(elapsed)) + } + } + + var err error + for _, p := range (*reg).plugins { + h, ok := p.(PostHandler) + if ok { + newerr := h.HandleResponse(resp) + if newerr != nil { + err = errors.Join(err, newerr) + } + } + } + + if resp.StatusCode < 400 && err == nil { + updater := db.NewUsageUpdater(ctx, "success") + common.GetTaskQueue().Add(updater) + } + + return err + } + + revProxy.ErrorHandler = func(resp http.ResponseWriter, req *http.Request, err error) { + ctx := req.Context() + var httpErr *HTTPError + cause := context.Cause(ctx) + + updater := db.NewUsageUpdater(ctx, "failure") + common.GetTaskQueue().Add(updater) + + log.Debug("cancel cause", "cause", cause) + if errors.As(cause, &httpErr) { + status := httpErr.code + http.Error(resp, http.StatusText(status), status) + } else if err != nil { + status := http.StatusBadGateway + http.Error(resp, http.StatusText(status), status) + } + } + + return revProxy } func setupContext(req *http.Request) { - ctx := req.Context() - ctx = common.UpdateContext(ctx, &Lifecycle{}) - if common.Enabled(common.INSTRUMENT_ENABLED) { - ctx = common.UpdateContext(ctx, common.StartInstrument()) - } - *req = *req.WithContext(ctx) + ctx := req.Context() + ctx = common.UpdateContext(ctx, &Lifecycle{}) + if common.Enabled(common.INSTRUMENT_ENABLED) { + ctx = common.UpdateContext(ctx, common.StartInstrument()) + } + *req = *req.WithContext(ctx) } // Add or remove headers on response // Dealing with CORS mostly func defaultHeaders(resp *http.Response) { - resp.Header.Set("Access-Control-Allow-Origin", "*") + resp.Header.Set("Access-Control-Allow-Headers", "authorization, content-type, server") + resp.Header.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") + resp.Header.Set("Access-Control-Allow-Origin", "*") } func lookupPoktId(req *http.Request) (string, bool) { - ctx := req.Context() - name := PluckProductName(req) - product := &db.Product{Name: name} - err := product.Lookup(ctx) - if err != nil { - log.Error("product not found", "product", product.Name, "err", err) - return "", false - } - productCtx := common.UpdateContext(ctx, product) - *req = *req.WithContext(productCtx) - return product.PoktId, true + ctx := req.Context() + name := PluckProductName(req) + product := &db.Product{Name: name} + err := product.Lookup(ctx) + if err != nil { + log.Error("product not found", "product", product.Name, "err", err) + return "", false + } + productCtx := common.UpdateContext(ctx, product) + *req = *req.WithContext(productCtx) + return product.PoktId, true } diff --git a/gateway/proxy/reconciler.go b/gateway/proxy/reconciler.go index e544314..e7ab642 100644 --- a/gateway/proxy/reconciler.go +++ b/gateway/proxy/reconciler.go @@ -1,69 +1,69 @@ package proxy import ( - "context" - log "log/slog" - "time" + "context" + log "log/slog" + "time" - "porters/common" - "porters/db" + "porters/common" + "porters/db" ) type Reconciler struct { - runEvery time.Duration - ticker *time.Ticker + runEvery time.Duration + ticker *time.Ticker } type reconcileTask struct { - *common.RetryTask - relaytx *db.Relaytx + *common.RetryTask + relaytx *db.Relaytx } func NewReconciler(seconds int) *Reconciler { - return &Reconciler{ - runEvery: time.Duration(seconds) * time.Second, - } + return &Reconciler{ + runEvery: time.Duration(seconds) * time.Second, + } } func (r *Reconciler) Name() string { - return "Usage Reconciliation" + return "Usage Reconciliation" } func (r *Reconciler) Key() string { - return "RECONCILE" + return "RECONCILE" } func (r *Reconciler) Load() { - r.ticker = time.NewTicker(r.runEvery) - go r.spawnTasks() + r.ticker = time.NewTicker(r.runEvery) + go r.spawnTasks() } func (r *Reconciler) spawnTasks() { - queue := common.GetTaskQueue() - ctx := context.Background() - for range r.ticker.C { - iter := db.ScanKeys(ctx, "RELAYTX") - for iter.Next(ctx) { - rtxkey := iter.Val() // use for building relaytx - - rtx, ok := db.RelaytxFromKey(ctx, rtxkey) - if ok { - task := &reconcileTask{ - relaytx: rtx, - } - queue.Add(task) - } - } - } + queue := common.GetTaskQueue() + ctx := context.Background() + for range r.ticker.C { + iter := db.ScanKeys(ctx, "RELAYTX") + for iter.Next(ctx) { + rtxkey := iter.Val() // use for building relaytx + + rtx, ok := db.RelaytxFromKey(ctx, rtxkey) + if ok { + task := &reconcileTask{ + relaytx: rtx, + } + queue.Add(task) + } + } + } } func (t *reconcileTask) Run() { - ctx := context.Background() - replayfunc, err := db.ReconcileRelays(ctx, t.relaytx) - if err != nil { - log.Error("unable to access cached relay use", "err", err) - } + ctx := context.Background() + replayfunc, err := db.ReconcileRelays(ctx, t.relaytx) + if err != nil { + log.Error("unable to access cached relay use", "err", err) + } - t.RetryTask = common.NewRetryTask(replayfunc, 5, 1 * time.Minute) - t.RetryTask.Run() + t.RetryTask = common.NewRetryTask(replayfunc, 5, 1*time.Minute) + t.RetryTask.Run() } diff --git a/gateway/proxy/registry.go b/gateway/proxy/registry.go index 2ca4318..83eb316 100644 --- a/gateway/proxy/registry.go +++ b/gateway/proxy/registry.go @@ -4,47 +4,48 @@ package proxy // Uses singleton so different parts of the lifecycle have access to plugins import ( - "errors" - log "log/slog" - "sync" + "errors" + log "log/slog" + "sync" ) type registry struct { - plugins []Plugin - keySet map[string]Plugin + plugins []Plugin + keySet map[string]Plugin } var pluginRegistry *registry = nil var registryMutex sync.Once func GetRegistry() *registry { - registryMutex.Do(func() { - pluginRegistry = ®istry{ - plugins: make([]Plugin, 0), - keySet: make(map[string]Plugin), - } - }) - return pluginRegistry + registryMutex.Do(func() { + pluginRegistry = ®istry{ + plugins: make([]Plugin, 0), + keySet: make(map[string]Plugin), + } + }) + return pluginRegistry } func Register(plugin Plugin) { - _ = GetRegistry() // init singleton - err := avoidCollision(plugin) - if err != nil { - log.Error("unable to load plugin", "plugin", plugin.Name(), "err", err.Error()) - return - } - plugin.Load() - - pluginRegistry.plugins = append(pluginRegistry.plugins, plugin) + _ = GetRegistry() // init singleton + err := avoidCollision(plugin) + if err != nil { + log.Error("unable to load plugin", "plugin", plugin.Name(), "err", err.Error()) + return + } + plugin.Load() + + pluginRegistry.plugins = append(pluginRegistry.plugins, plugin) + log.Info("Registered plugin successfully", "plugin", plugin.Name()) } func avoidCollision(plugin Plugin) error { - _ = GetRegistry() // just to make sure - key := plugin.Key() - if pluginRegistry.keySet[key] != nil { - return errors.New("another plugin uses same key") - } - pluginRegistry.keySet[key] = plugin - return nil + _ = GetRegistry() // just to make sure + key := plugin.Key() + if pluginRegistry.keySet[key] != nil { + return errors.New("another plugin uses same key") + } + pluginRegistry.keySet[key] = plugin + return nil } diff --git a/web-portal/backend/src/apps/apps.service.ts b/web-portal/backend/src/apps/apps.service.ts index 09d4440..22c4a96 100644 --- a/web-portal/backend/src/apps/apps.service.ts +++ b/web-portal/backend/src/apps/apps.service.ts @@ -271,6 +271,7 @@ export class AppsService { }, }, data: { + active: false, deletedAt: new Date(), }, });