Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: refactor goroutine handling #388

Merged
merged 2 commits into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions cmd/ssh-portal-api/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ import (
"github.com/uselagoon/ssh-portal/internal/metrics"
"github.com/uselagoon/ssh-portal/internal/rbac"
"github.com/uselagoon/ssh-portal/internal/sshportalapi"
"golang.org/x/sync/errgroup"
)

const (
metricsPort = ":9911"
)

// ServeCmd represents the serve command.
Expand All @@ -31,10 +36,6 @@ type ServeCmd struct {

// Run the serve command to ssh-portal API requests.
func (cmd *ServeCmd) Run(log *slog.Logger) error {
// metrics needs a separate context because deferred Shutdown() will exit
// immediately the context is done, which is the case for ctx on SIGTERM.
m := metrics.NewServer(log, ":9911")
defer m.Shutdown(context.Background()) //nolint:errcheck
// get main process context, which cancels on SIGTERM
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM)
defer stop()
Expand Down Expand Up @@ -65,6 +66,14 @@ func (cmd *ServeCmd) Run(log *slog.Logger) error {
if err != nil {
return fmt.Errorf("couldn't init keycloak client: %v", err)
}
// start serving NATS requests
return sshportalapi.ServeNATS(ctx, stop, log, p, l, k, cmd.NATSURL)
// set up goroutine handler
eg, ctx := errgroup.WithContext(ctx)
// start the metrics server
metrics.Serve(ctx, eg, metricsPort)
// start serving SSH token requests
eg.Go(func() error {
// start serving NATS requests
return sshportalapi.ServeNATS(ctx, stop, log, p, l, k, cmd.NATSURL)
})
return eg.Wait()
}
22 changes: 16 additions & 6 deletions cmd/ssh-portal/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ import (
"github.com/uselagoon/ssh-portal/internal/k8s"
"github.com/uselagoon/ssh-portal/internal/metrics"
"github.com/uselagoon/ssh-portal/internal/sshserver"
"golang.org/x/sync/errgroup"
)

const (
metricsPort = ":9912"
)

// ServeCmd represents the serve command.
Expand All @@ -26,10 +31,6 @@ type ServeCmd struct {

// Run the serve command to handle SSH connection requests.
func (cmd *ServeCmd) Run(log *slog.Logger) error {
// metrics needs a separate context because deferred Shutdown() will exit
// immediately the context is done, which is the case for ctx on SIGTERM.
m := metrics.NewServer(log, ":9912")
defer m.Shutdown(context.Background()) //nolint:errcheck
// get main process context, which cancels on SIGTERM
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM)
defer stop()
Expand Down Expand Up @@ -60,6 +61,7 @@ func (cmd *ServeCmd) Run(log *slog.Logger) error {
if err != nil {
return fmt.Errorf("couldn't listen on port %d: %v", cmd.SSHServerPort, err)
}
defer l.Close()
// get kubernetes client
c, err := k8s.NewClient()
if err != nil {
Expand All @@ -72,6 +74,14 @@ func (cmd *ServeCmd) Run(log *slog.Logger) error {
hostkeys = append(hostkeys, []byte(hk))
}
}
// start serving SSH connection requests
return sshserver.Serve(ctx, log, nc, l, c, hostkeys, cmd.LogAccessEnabled)
// set up goroutine handler
eg, ctx := errgroup.WithContext(ctx)
// start the metrics server
metrics.Serve(ctx, eg, metricsPort)
// start serving SSH token requests
eg.Go(func() error {
// start serving SSH connection requests
return sshserver.Serve(ctx, log, nc, l, c, hostkeys, cmd.LogAccessEnabled)
})
return eg.Wait()
}
21 changes: 15 additions & 6 deletions cmd/ssh-token/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ import (
"github.com/uselagoon/ssh-portal/internal/metrics"
"github.com/uselagoon/ssh-portal/internal/rbac"
"github.com/uselagoon/ssh-portal/internal/sshtoken"
"golang.org/x/sync/errgroup"
)

const (
metricsPort = ":9948"
)

// ServeCmd represents the serve command.
Expand All @@ -37,10 +42,6 @@ type ServeCmd struct {

// Run the serve command to ssh-portal API requests.
func (cmd *ServeCmd) Run(log *slog.Logger) error {
// metrics needs a separate context because deferred Shutdown() will exit
// immediately the context is done, which is the case for ctx on SIGTERM.
m := metrics.NewServer(log, ":9948")
defer m.Shutdown(context.Background()) //nolint:errcheck
// get main process context, which cancels on SIGTERM
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM)
defer stop()
Expand Down Expand Up @@ -85,6 +86,7 @@ func (cmd *ServeCmd) Run(log *slog.Logger) error {
if err != nil {
return fmt.Errorf("couldn't listen on port %d: %v", cmd.SSHServerPort, err)
}
defer l.Close()
// check for persistent host key arguments
var hostkeys [][]byte
for _, hk := range []string{cmd.HostKeyECDSA, cmd.HostKeyED25519,
Expand All @@ -93,7 +95,14 @@ func (cmd *ServeCmd) Run(log *slog.Logger) error {
hostkeys = append(hostkeys, []byte(hk))
}
}
// set up goroutine handler
eg, ctx := errgroup.WithContext(ctx)
// start the metrics server
metrics.Serve(ctx, eg, metricsPort)
// start serving SSH token requests
return sshtoken.Serve(ctx, log, l, p, ldb, keycloakToken, keycloakPermission,
hostkeys)
eg.Go(func() error {
return sshtoken.Serve(ctx, log, l, p, ldb, keycloakToken, keycloakPermission,
hostkeys)
})
return eg.Wait()
}
45 changes: 31 additions & 14 deletions internal/metrics/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,46 @@
package metrics

import (
"log/slog"
"context"
"fmt"
"net/http"
"time"

"github.com/prometheus/client_golang/prometheus/promhttp"
"golang.org/x/sync/errgroup"
)

// NewServer returns a *http.Server serving prometheus metrics in a new
// goroutine.
// Caller should defer Shutdown() for cleanup.
func NewServer(log *slog.Logger, addr string) *http.Server {
const (
metricsReadTimeout = 2 * time.Second
metricsShutdownTimeout = 2 * time.Second
)

// Serve runs a prometheus metrics server in goroutines managed by eg. It will
// gracefully exit with a two second timeout.
// Callers should Wait() on eg before exiting.
func Serve(ctx context.Context, eg *errgroup.Group, metricsPort string) {
// configure metrics server
mux := http.NewServeMux()
mux.Handle("/metrics", promhttp.Handler())
s := http.Server{
Addr: addr,
metricsSrv := http.Server{
Addr: metricsPort,
ReadTimeout: metricsReadTimeout,
WriteTimeout: metricsReadTimeout,
Handler: mux,
ReadTimeout: 16 * time.Second,
WriteTimeout: 16 * time.Second,
}
go func() {
if err := s.ListenAndServe(); err != http.ErrServerClosed {
log.Error("metrics server did not shut down cleanly", slog.Any("error", err))
// start metrics server
eg.Go(func() error {
if err := metricsSrv.ListenAndServe(); err != http.ErrServerClosed {
return fmt.Errorf("metrics server exited with error: %v", err)
}
}()
return &s
return nil
})
// start metrics server shutdown handler for graceful shutdown
eg.Go(func() error {
<-ctx.Done()
timeoutCtx, cancel :=
context.WithTimeout(context.Background(), metricsShutdownTimeout)
defer cancel()
return metricsSrv.Shutdown(timeoutCtx)
})
}
2 changes: 1 addition & 1 deletion internal/sshtoken/authhandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func pubKeyAuth(log *slog.Logger, ldb LagoonDBService) ssh.PublicKeyHandler {
authnSuccessTotal.Inc()
ctx.SetValue(userUUID, user.UUID)
log.Info("authentication successful",
slog.String("userID", user.UUID.String()))
slog.String("userUUID", user.UUID.String()))
return true
}
}