Skip to content

Commit

Permalink
feat: add server restart method & add optional healthz route
Browse files Browse the repository at this point in the history
  • Loading branch information
lvlcn-t committed Jun 22, 2024
1 parent 8d5b72b commit e6e9a40
Show file tree
Hide file tree
Showing 4 changed files with 574 additions and 43 deletions.
11 changes: 11 additions & 0 deletions apimanager/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,14 @@ func (e *ErrAlreadyRunning) Is(target error) bool {
_, ok := target.(*ErrAlreadyRunning)
return ok
}

type ErrNotRunning struct{}

func (e *ErrNotRunning) Error() string {
return "cannot stop the server because it is not running"
}

func (e *ErrNotRunning) Is(target error) bool {
_, ok := target.(*ErrNotRunning)
return ok
}
25 changes: 25 additions & 0 deletions apimanager/errors_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package apimanager

import (
"testing"
)

func TestErrAlreadyRunning(t *testing.T) {
err := &ErrAlreadyRunning{}
if err.Error() == "" {
t.Error("No error message")
}
if !err.Is(&ErrAlreadyRunning{}) {
t.Error("Is() should return true")
}
}

func TestErrNotRunning(t *testing.T) {
err := &ErrNotRunning{}
if err.Error() == "" {
t.Error("No error message")
}
if !err.Is(&ErrNotRunning{}) {
t.Error("Is() should return true")
}
}
170 changes: 131 additions & 39 deletions apimanager/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,35 +16,38 @@ import (
)

// shutdownTimeout is the timeout for the server to shut down.
const shutdownTimeout = 5 * time.Second
const shutdownTimeout = 15 * time.Second

type Server interface {
// Run runs the server.
// Runs indefinitely until an error occurs or the server is shut down.
//
// If no health check route was mounted before, a health check route will be mounted.
// Run attaches all previously mounted routes and starts the server.
// Runs indefinitely until an error occurs, the server shuts down, or the provided context is done.
//
// Example setup:
//
// srv := api.NewServer(&Config{Address: ":8080"})
// err := srv.Mount(RouteGroup{
// Path: "/v1",
// App: fiber.New().Get("/hello", func(c fiber.Ctx) error {
// return c.SendString("Hello, World!")
// }),
// })
// if err != nil {
// // handle error
// }
//
// _ = srv.Run(context.Background())
// server := apimanager.New(nil)
// server.Mount(apimanager.Route{
// Path: "/",
// Methods: []string{http.MethodGet},
// Handler: func(c fiber.Ctx) error {
// return c.SendString("Hello, World!")
// },
// })
// // The server will listen on the default address ":8080" and respond with "Hello, World!" on a GET request to "/".
// server.Run(context.Background())
Run(ctx context.Context) error
// Restart restarts the server by shutting it down and starting it again.
// If any routes or groups are provided, they will be added to the server.
// All existing routes and groups will be preserved.
Restart(ctx context.Context, routes []Route, groups []RouteGroup) error
// Shutdown gracefully shuts down the server.
Shutdown(ctx context.Context) error
// Mount adds the provided routes to the server.
Mount(routes ...Route) error
// MountGroup adds the provided route groups to the server.
MountGroup(groups ...RouteGroup) error
// Shutdown gracefully shuts down the server.
Shutdown(ctx context.Context) error
// App returns the fiber app of the server.
App() *fiber.App
// Mounted returns all mounted routes, groups, and global middlewares.
Mounted() (routes []Route, groups []RouteGroup, middlewares []fiber.Handler)
}

// Route is a route to register to the server.
Expand Down Expand Up @@ -73,6 +76,8 @@ type Config struct {
Address string `yaml:"address" mapstructure:"address"`
// BasePath is the base path of the API.
BasePath string `yaml:"basePath" mapstructure:"basePath"`
// UseDefaultHealthz indicates if the default healthz handler should be used.
UseDefaultHealthz bool `yaml:"useDefaultHealthz" mapstructure:"useDefaultHealthz"`
// TLS is the TLS configuration.
TLS TLSConfig `yaml:"tls" mapstructure:"tls"`
}
Expand All @@ -88,7 +93,7 @@ type TLSConfig struct {
}

// IsEmpty checks if the configuration is empty.
func (c Config) IsEmpty() bool {
func (c Config) IsEmpty() bool { //nolint:gocritic // To ensure compatibility with viper, no pointer receiver is used.
return reflect.DeepEqual(c, Config{})
}

Expand All @@ -112,35 +117,51 @@ func (c *Config) Validate() error {
return err
}

// server is the server implementation.
type server struct {
mu sync.Mutex
config *Config
app *fiber.App
router fiber.Router
routes []Route
groups []RouteGroup
// mu is the mutex to synchronize access to the server.
mu sync.Mutex
// config is the configuration of the server.
config *Config
// app is the fiber app of the server.
app *fiber.App
// router is the fiber root router of the server.
router fiber.Router
// routes are the routes to mount to the server on startup.
routes []Route
// groups are the route groups to mount to the server on startup.
groups []RouteGroup
// middlewares are the global middlewares to use for the server.
middlewares []fiber.Handler
running bool
// running indicates if the server is running.
running bool
}

// New creates a new server with the provided configuration.
// New creates a new server with the provided configuration and middlewares.
// If no configuration is provided, a default configuration will be used.
// If no middlewares are provided, a default set of middlewares will be used.
func New(c *Config, middlewares ...fiber.Handler) Server {
if c == nil {
c = &Config{
Address: ":8080",
BasePath: "/",
Address: ":8080",
BasePath: "/",
UseDefaultHealthz: false,
}
}

app := fiber.New()
if len(middlewares) == 0 {
middlewares = append(middlewares, middleware.Recover(), middleware.Logger("/healthz"))
if c.Address == "" {
c.Address = ":8080"
}

if c.BasePath == "" {
c.BasePath = "/"
}

app := fiber.New()
if len(middlewares) == 0 {
middlewares = append(middlewares, middleware.Recover(), middleware.Logger())
}

return &server{
mu: sync.Mutex{},
config: c,
Expand All @@ -153,9 +174,8 @@ func New(c *Config, middlewares ...fiber.Handler) Server {
}
}

// Run runs the server.
// It will mount a health check route if no health check route was mounted before.
// Runs indefinitely until an error occurs or the server is shut down.
// Run attaches all previously mounted routes and starts the server.
// Runs indefinitely until an error occurs, the server shuts down, or the provided context is done.
func (s *server) Run(ctx context.Context) error {
err := s.attachRoutes(ctx)
if err != nil {
Expand Down Expand Up @@ -184,6 +204,8 @@ func (s *server) Run(ctx context.Context) error {
}

// Mount adds the provided routes to the server.
//
// Note that mounting routes after the server has started will have no effect and will return an error.
func (s *server) Mount(routes ...Route) error {
s.mu.Lock()
defer s.mu.Unlock()
Expand Down Expand Up @@ -213,6 +235,8 @@ func (s *server) Mount(routes ...Route) error {
}

// MountGroup adds the provided route groups to the server.
//
// Note that mounting route groups after the server has started will have no effect and will return an error.
func (s *server) MountGroup(groups ...RouteGroup) (err error) {
s.mu.Lock()
defer s.mu.Unlock()
Expand All @@ -225,7 +249,6 @@ func (s *server) MountGroup(groups ...RouteGroup) (err error) {
}

// attachRoutes attaches the routes to the server.
// It will mount a health check route if no health check route was mounted before.
func (s *server) attachRoutes(ctx context.Context) (err error) {
s.mu.Lock()
defer s.mu.Unlock()
Expand All @@ -234,7 +257,9 @@ func (s *server) attachRoutes(ctx context.Context) (err error) {
}

// Always inject the provided context into the request user context.
_ = s.router.Use(middleware.Context(ctx))
// To ensure all routes have access to the same logger a new logger instance is created and
// injected into the context if not already present.
_ = s.router.Use(middleware.Context(logger.IntoContext(ctx, logger.FromContext(ctx))))
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("failed to mount routes: %v", r)
Expand All @@ -256,17 +281,77 @@ func (s *server) attachRoutes(ctx context.Context) (err error) {
_ = s.router.Add(route.Methods, route.Path, route.Handler, route.Middlewares...)
}

if s.config.UseDefaultHealthz {
_ = s.router.Get("/healthz", OkHandler)
}

s.running = true
return nil
}

// Shutdown gracefully shuts down the server.
func (s *server) Shutdown(ctx context.Context) error {
c, cancel := newContextWithTimeout(ctx)
defer cancel()

return errors.Join(ctx.Err(), s.app.ShutdownWithContext(c))
}

// Restart restarts the server by shutting it down and starting it again.
// If any routes or groups are provided, they will be added to the server.
// All existing routes and groups will be preserved.
// Runs indefinitely until an error occurs, the server shuts down, or the provided context is done.
func (s *server) Restart(ctx context.Context, routes []Route, groups []RouteGroup) error {
s.mu.Lock()
if !s.running {
s.mu.Unlock()
return &ErrNotRunning{}
}
s.mu.Unlock()

s.toggleRunning()
defer s.toggleRunning()
if len(routes) > 0 {
err := s.Mount(routes...)
if err != nil {
return err
}
}

if len(groups) > 0 {
err := s.MountGroup(groups...)
if err != nil {
return err
}
}

err := s.Shutdown(ctx)
if err != nil {
return err
}

return s.Run(ctx)
}

// App returns the fiber app of the server.
func (s *server) App() *fiber.App {
s.mu.Lock()
defer s.mu.Unlock()
return s.app
}

// Mounted returns all mounted routes, groups, and global middlewares.
func (s *server) Mounted() (routes []Route, groups []RouteGroup, middlewares []fiber.Handler) {
s.mu.Lock()
defer s.mu.Unlock()
return s.routes, s.groups, s.middlewares
}

// OkHandler is a handler that returns an HTTP 200 OK response.
func OkHandler(c fiber.Ctx) error {
return c.Status(http.StatusOK).SendString("OK")
}

// newContextWithTimeout returns a new context with a timeout.
func newContextWithTimeout(ctx context.Context) (context.Context, context.CancelFunc) {
if deadline, ok := ctx.Deadline(); ok {
Expand Down Expand Up @@ -295,3 +380,10 @@ func isValid(method string) bool {
return false
}
}

// toggleRunning toggles the running state of the server.
func (s *server) toggleRunning() {
s.mu.Lock()
defer s.mu.Unlock()
s.running = !s.running
}
Loading

0 comments on commit e6e9a40

Please sign in to comment.