From e6e9a40385d0da7624d3cd3d648df444b619fd3f Mon Sep 17 00:00:00 2001 From: lvlcn-t <75443136+lvlcn-t@users.noreply.github.com> Date: Sat, 22 Jun 2024 18:22:44 +0200 Subject: [PATCH] feat: add server restart method & add optional healthz route --- apimanager/errors.go | 11 + apimanager/errors_test.go | 25 +++ apimanager/manager.go | 170 +++++++++++---- apimanager/manager_test.go | 411 ++++++++++++++++++++++++++++++++++++- 4 files changed, 574 insertions(+), 43 deletions(-) create mode 100644 apimanager/errors_test.go diff --git a/apimanager/errors.go b/apimanager/errors.go index b6d4109..331c389 100644 --- a/apimanager/errors.go +++ b/apimanager/errors.go @@ -10,3 +10,14 @@ func (e *ErrAlreadyRunning) Is(target error) bool { _, ok := target.(*ErrAlreadyRunning) return ok } + +type ErrNotRunning struct{} + +func (e *ErrNotRunning) Error() string { + return "cannot stop the server because it is not running" +} + +func (e *ErrNotRunning) Is(target error) bool { + _, ok := target.(*ErrNotRunning) + return ok +} diff --git a/apimanager/errors_test.go b/apimanager/errors_test.go new file mode 100644 index 0000000..2a1e723 --- /dev/null +++ b/apimanager/errors_test.go @@ -0,0 +1,25 @@ +package apimanager + +import ( + "testing" +) + +func TestErrAlreadyRunning(t *testing.T) { + err := &ErrAlreadyRunning{} + if err.Error() == "" { + t.Error("No error message") + } + if !err.Is(&ErrAlreadyRunning{}) { + t.Error("Is() should return true") + } +} + +func TestErrNotRunning(t *testing.T) { + err := &ErrNotRunning{} + if err.Error() == "" { + t.Error("No error message") + } + if !err.Is(&ErrNotRunning{}) { + t.Error("Is() should return true") + } +} diff --git a/apimanager/manager.go b/apimanager/manager.go index 22f7e4f..63a2f2a 100644 --- a/apimanager/manager.go +++ b/apimanager/manager.go @@ -16,35 +16,38 @@ import ( ) // shutdownTimeout is the timeout for the server to shut down. -const shutdownTimeout = 5 * time.Second +const shutdownTimeout = 15 * time.Second type Server interface { - // Run runs the server. - // Runs indefinitely until an error occurs or the server is shut down. - // - // If no health check route was mounted before, a health check route will be mounted. + // Run attaches all previously mounted routes and starts the server. + // Runs indefinitely until an error occurs, the server shuts down, or the provided context is done. // // Example setup: - // - // srv := api.NewServer(&Config{Address: ":8080"}) - // err := srv.Mount(RouteGroup{ - // Path: "/v1", - // App: fiber.New().Get("/hello", func(c fiber.Ctx) error { - // return c.SendString("Hello, World!") - // }), - // }) - // if err != nil { - // // handle error - // } - // - // _ = srv.Run(context.Background()) + // server := apimanager.New(nil) + // server.Mount(apimanager.Route{ + // Path: "/", + // Methods: []string{http.MethodGet}, + // Handler: func(c fiber.Ctx) error { + // return c.SendString("Hello, World!") + // }, + // }) + // // The server will listen on the default address ":8080" and respond with "Hello, World!" on a GET request to "/". + // server.Run(context.Background()) Run(ctx context.Context) error + // Restart restarts the server by shutting it down and starting it again. + // If any routes or groups are provided, they will be added to the server. + // All existing routes and groups will be preserved. + Restart(ctx context.Context, routes []Route, groups []RouteGroup) error + // Shutdown gracefully shuts down the server. + Shutdown(ctx context.Context) error // Mount adds the provided routes to the server. Mount(routes ...Route) error // MountGroup adds the provided route groups to the server. MountGroup(groups ...RouteGroup) error - // Shutdown gracefully shuts down the server. - Shutdown(ctx context.Context) error + // App returns the fiber app of the server. + App() *fiber.App + // Mounted returns all mounted routes, groups, and global middlewares. + Mounted() (routes []Route, groups []RouteGroup, middlewares []fiber.Handler) } // Route is a route to register to the server. @@ -73,6 +76,8 @@ type Config struct { Address string `yaml:"address" mapstructure:"address"` // BasePath is the base path of the API. BasePath string `yaml:"basePath" mapstructure:"basePath"` + // UseDefaultHealthz indicates if the default healthz handler should be used. + UseDefaultHealthz bool `yaml:"useDefaultHealthz" mapstructure:"useDefaultHealthz"` // TLS is the TLS configuration. TLS TLSConfig `yaml:"tls" mapstructure:"tls"` } @@ -88,7 +93,7 @@ type TLSConfig struct { } // IsEmpty checks if the configuration is empty. -func (c Config) IsEmpty() bool { +func (c Config) IsEmpty() bool { //nolint:gocritic // To ensure compatibility with viper, no pointer receiver is used. return reflect.DeepEqual(c, Config{}) } @@ -112,35 +117,51 @@ func (c *Config) Validate() error { return err } +// server is the server implementation. type server struct { - mu sync.Mutex - config *Config - app *fiber.App - router fiber.Router - routes []Route - groups []RouteGroup + // mu is the mutex to synchronize access to the server. + mu sync.Mutex + // config is the configuration of the server. + config *Config + // app is the fiber app of the server. + app *fiber.App + // router is the fiber root router of the server. + router fiber.Router + // routes are the routes to mount to the server on startup. + routes []Route + // groups are the route groups to mount to the server on startup. + groups []RouteGroup + // middlewares are the global middlewares to use for the server. middlewares []fiber.Handler - running bool + // running indicates if the server is running. + running bool } -// New creates a new server with the provided configuration. +// New creates a new server with the provided configuration and middlewares. +// If no configuration is provided, a default configuration will be used. +// If no middlewares are provided, a default set of middlewares will be used. func New(c *Config, middlewares ...fiber.Handler) Server { if c == nil { c = &Config{ - Address: ":8080", - BasePath: "/", + Address: ":8080", + BasePath: "/", + UseDefaultHealthz: false, } } - app := fiber.New() - if len(middlewares) == 0 { - middlewares = append(middlewares, middleware.Recover(), middleware.Logger("/healthz")) + if c.Address == "" { + c.Address = ":8080" } if c.BasePath == "" { c.BasePath = "/" } + app := fiber.New() + if len(middlewares) == 0 { + middlewares = append(middlewares, middleware.Recover(), middleware.Logger()) + } + return &server{ mu: sync.Mutex{}, config: c, @@ -153,9 +174,8 @@ func New(c *Config, middlewares ...fiber.Handler) Server { } } -// Run runs the server. -// It will mount a health check route if no health check route was mounted before. -// Runs indefinitely until an error occurs or the server is shut down. +// Run attaches all previously mounted routes and starts the server. +// Runs indefinitely until an error occurs, the server shuts down, or the provided context is done. func (s *server) Run(ctx context.Context) error { err := s.attachRoutes(ctx) if err != nil { @@ -184,6 +204,8 @@ func (s *server) Run(ctx context.Context) error { } // Mount adds the provided routes to the server. +// +// Note that mounting routes after the server has started will have no effect and will return an error. func (s *server) Mount(routes ...Route) error { s.mu.Lock() defer s.mu.Unlock() @@ -213,6 +235,8 @@ func (s *server) Mount(routes ...Route) error { } // MountGroup adds the provided route groups to the server. +// +// Note that mounting route groups after the server has started will have no effect and will return an error. func (s *server) MountGroup(groups ...RouteGroup) (err error) { s.mu.Lock() defer s.mu.Unlock() @@ -225,7 +249,6 @@ func (s *server) MountGroup(groups ...RouteGroup) (err error) { } // attachRoutes attaches the routes to the server. -// It will mount a health check route if no health check route was mounted before. func (s *server) attachRoutes(ctx context.Context) (err error) { s.mu.Lock() defer s.mu.Unlock() @@ -234,7 +257,9 @@ func (s *server) attachRoutes(ctx context.Context) (err error) { } // Always inject the provided context into the request user context. - _ = s.router.Use(middleware.Context(ctx)) + // To ensure all routes have access to the same logger a new logger instance is created and + // injected into the context if not already present. + _ = s.router.Use(middleware.Context(logger.IntoContext(ctx, logger.FromContext(ctx)))) defer func() { if r := recover(); r != nil { err = fmt.Errorf("failed to mount routes: %v", r) @@ -256,10 +281,15 @@ func (s *server) attachRoutes(ctx context.Context) (err error) { _ = s.router.Add(route.Methods, route.Path, route.Handler, route.Middlewares...) } + if s.config.UseDefaultHealthz { + _ = s.router.Get("/healthz", OkHandler) + } + s.running = true return nil } +// Shutdown gracefully shuts down the server. func (s *server) Shutdown(ctx context.Context) error { c, cancel := newContextWithTimeout(ctx) defer cancel() @@ -267,6 +297,61 @@ func (s *server) Shutdown(ctx context.Context) error { return errors.Join(ctx.Err(), s.app.ShutdownWithContext(c)) } +// Restart restarts the server by shutting it down and starting it again. +// If any routes or groups are provided, they will be added to the server. +// All existing routes and groups will be preserved. +// Runs indefinitely until an error occurs, the server shuts down, or the provided context is done. +func (s *server) Restart(ctx context.Context, routes []Route, groups []RouteGroup) error { + s.mu.Lock() + if !s.running { + s.mu.Unlock() + return &ErrNotRunning{} + } + s.mu.Unlock() + + s.toggleRunning() + defer s.toggleRunning() + if len(routes) > 0 { + err := s.Mount(routes...) + if err != nil { + return err + } + } + + if len(groups) > 0 { + err := s.MountGroup(groups...) + if err != nil { + return err + } + } + + err := s.Shutdown(ctx) + if err != nil { + return err + } + + return s.Run(ctx) +} + +// App returns the fiber app of the server. +func (s *server) App() *fiber.App { + s.mu.Lock() + defer s.mu.Unlock() + return s.app +} + +// Mounted returns all mounted routes, groups, and global middlewares. +func (s *server) Mounted() (routes []Route, groups []RouteGroup, middlewares []fiber.Handler) { + s.mu.Lock() + defer s.mu.Unlock() + return s.routes, s.groups, s.middlewares +} + +// OkHandler is a handler that returns an HTTP 200 OK response. +func OkHandler(c fiber.Ctx) error { + return c.Status(http.StatusOK).SendString("OK") +} + // newContextWithTimeout returns a new context with a timeout. func newContextWithTimeout(ctx context.Context) (context.Context, context.CancelFunc) { if deadline, ok := ctx.Deadline(); ok { @@ -295,3 +380,10 @@ func isValid(method string) bool { return false } } + +// toggleRunning toggles the running state of the server. +func (s *server) toggleRunning() { + s.mu.Lock() + defer s.mu.Unlock() + s.running = !s.running +} diff --git a/apimanager/manager_test.go b/apimanager/manager_test.go index 0995e52..af66197 100644 --- a/apimanager/manager_test.go +++ b/apimanager/manager_test.go @@ -37,7 +37,7 @@ func TestNewServer(t *testing.T) { groups: []RouteGroup{}, middlewares: []fiber.Handler{ middleware.Recover(), - middleware.Logger("/healthz"), + middleware.Logger(), }, } }, @@ -122,6 +122,29 @@ func TestNewServer(t *testing.T) { } }, }, + { + name: "New with empty config", + config: &Config{}, + middlewares: nil, + want: func(t *testing.T) *server { + app := fiber.New() + return &server{ + mu: sync.Mutex{}, + config: &Config{ + Address: ":8080", + BasePath: "/", + }, + app: app, + router: app.Group("/"), + routes: []Route{}, + groups: []RouteGroup{}, + middlewares: []fiber.Handler{ + middleware.Recover(), + middleware.Logger(), + }, + } + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -145,6 +168,8 @@ func TestNewServer(t *testing.T) { } func TestServer_Run(t *testing.T) { + const defaultRoutes = (1 * 9) // 1 route * 9 methods + tests := []struct { name string server Server @@ -231,7 +256,7 @@ func TestServer_Run(t *testing.T) { t.Errorf("server.running = %v, want %v", s.running, true) } - routes := len(tt.routes) + 9 + routes := len(tt.routes) + defaultRoutes for _, route := range tt.routes { if route.Handler == nil { routes-- @@ -408,6 +433,325 @@ func TestServer_Mount(t *testing.T) { } } +func TestServer_Shutdown(t *testing.T) { + tests := []struct { + name string + server Server + running bool + wantErr bool + }{ + { + name: "Shutdown without running server", + server: New(nil, nil), + running: false, + wantErr: false, + }, + { + name: "Shutdown with running server", + server: New(nil, nil), + running: true, + // Want false because we want indempotent behavior + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := tt.server.(*server) + s.running = tt.running + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + if !tt.wantErr { + err := s.Shutdown(ctx) + if (err != nil) != tt.wantErr { + t.Errorf("server.Shutdown() error = %v, wantErr %v", err, tt.wantErr) + } + } + }) + } +} + +func TestServer_Restart(t *testing.T) { + tests := []struct { + name string + server Server + running bool + routes []Route + groups []RouteGroup + wantErr bool + }{ + { + name: "Restart without routes", + server: New(nil, nil), + running: true, + routes: nil, + groups: nil, + wantErr: false, + }, + { + name: "Restart with routes", + server: New(nil, nil), + running: true, + routes: []Route{ + { + Methods: []string{http.MethodGet}, + Path: "/", + Handler: func(c fiber.Ctx) error { + return c.Status(http.StatusOK).SendString("Hello, World!") + }, + }, + }, + groups: nil, + wantErr: false, + }, + { + name: "Restart with invalid route", + server: New(nil, nil), + running: true, + routes: []Route{ + { + Methods: []string{http.MethodGet}, + Path: "", + Handler: nil, + }, + }, + groups: nil, + wantErr: true, + }, + { + name: "Restart with invalid method", + server: New(nil, nil), + running: true, + routes: []Route{ + { + Methods: []string{"INVALID"}, + Path: "/", + Handler: func(c fiber.Ctx) error { + return c.SendString("Hello, World!") + }, + }, + }, + groups: nil, + wantErr: true, + }, + { + name: "Restart with no methods", + server: New(nil, nil), + running: true, + routes: []Route{ + { + Methods: nil, + Path: "/", + Handler: func(c fiber.Ctx) error { + return c.SendString("Hello, World!") + }, + }, + }, + groups: nil, + wantErr: true, + }, + { + name: "Restart with groups", + server: New(nil, nil), + running: true, + routes: nil, + groups: []RouteGroup{ + { + Path: "/api", + App: fiber.New().Get("/", func(c fiber.Ctx) error { + return c.SendString("Hello, World!") + }), + }, + }, + wantErr: false, + }, + { + name: "Restart without running server", + server: New(nil, nil), + running: false, + routes: nil, + groups: nil, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := tt.server.(*server) + s.running = tt.running + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + err := s.Restart(ctx, tt.routes, tt.groups) + if (err != nil) != tt.wantErr { + if !errors.Is(err, context.DeadlineExceeded) { + t.Errorf("server.Restart() error = %v, wantErr %v", err, tt.wantErr) + } + } + }) + } +} + +func TestServer_App(t *testing.T) { + app := fiber.New() + tests := []struct { + name string + server Server + want *fiber.App + }{ + { + name: "Get app", + server: &server{ + app: app, + }, + want: app, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.server.App(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("server.App() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestServer_Mounted(t *testing.T) { + tests := []struct { + name string + server Server + wantRoutes int + wantGroups int + wantMiddlewares int + }{ + { + name: "No routes, groups or middleware mounted", + server: &server{ + routes: nil, + groups: nil, + middlewares: nil, + }, + }, + { + name: "Routes mounted", + server: &server{ + routes: []Route{ + { + Methods: []string{http.MethodGet}, + Path: "/", + Handler: func(c fiber.Ctx) error { + return c.SendString("Hello, World!") + }, + }, + }, + }, + wantRoutes: 1, + }, + { + name: "Groups mounted", + server: &server{ + groups: []RouteGroup{ + { + Path: "/api", + App: fiber.New().Get("/", func(c fiber.Ctx) error { + return c.SendString("Hello, World!") + }), + }, + }, + }, + wantGroups: 1, + }, + { + name: "Middleware mounted", + server: &server{ + middlewares: []fiber.Handler{ + middleware.Recover(), + middleware.Logger(), + }, + }, + wantMiddlewares: 2, + }, + { + name: "Routes, groups and middleware mounted", + server: &server{ + routes: []Route{ + { + Methods: []string{http.MethodGet}, + Path: "/", + Handler: func(c fiber.Ctx) error { + return c.SendString("Hello, World!") + }, + }, + }, + groups: []RouteGroup{ + { + Path: "/api", + App: fiber.New().Get("/", func(c fiber.Ctx) error { + return c.SendString("Hello, World!") + }), + }, + }, + middlewares: []fiber.Handler{ + middleware.Recover(), + middleware.Logger(), + }, + }, + wantRoutes: 1, + wantGroups: 1, + wantMiddlewares: 2, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := tt.server.(*server) + + routes, groups, mw := s.Mounted() + if len(routes) != tt.wantRoutes { + t.Errorf("server.Mounted() routes = %v, want %v", routes, tt.wantRoutes) + } + + if len(groups) != tt.wantGroups { + t.Errorf("server.Mounted() groups = %v, want %v", groups, tt.wantGroups) + } + + if len(mw) != tt.wantMiddlewares { + t.Errorf("server.Mounted() middlewares = %v, want %v", mw, tt.wantMiddlewares) + } + }) + } +} + +func TestConfig_IsEmpty(t *testing.T) { + tests := []struct { + name string + config *Config + want bool + }{ + { + name: "Empty config", + config: &Config{}, + want: true, + }, + { + name: "Non-empty config", + config: &Config{ + Address: ":8080", + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.config.IsEmpty(); got != tt.want { + t.Errorf("Config.IsEmpty() = %v, want %v", got, tt.want) + } + }) + } +} + func TestConfig_Validate(t *testing.T) { tests := []struct { name string @@ -415,8 +759,25 @@ func TestConfig_Validate(t *testing.T) { wantErr bool }{ { - name: "Valid config", - config: &Config{Address: ":8080"}, + name: "Valid config", + config: &Config{ + Address: ":8080", + TLS: TLSConfig{ + Enabled: false, + }, + }, + wantErr: false, + }, + { + name: "Valid config with TLS", + config: &Config{ + Address: ":8080", + TLS: TLSConfig{ + Enabled: true, + CertFile: "cert.pem", + CertKeyFile: "key.pem", + }, + }, wantErr: false, }, { @@ -424,6 +785,11 @@ func TestConfig_Validate(t *testing.T) { config: &Config{Address: ""}, wantErr: true, }, + { + name: "Invalid config with TLS", + config: &Config{Address: ":8080", TLS: TLSConfig{Enabled: true}}, + wantErr: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -434,3 +800,40 @@ func TestConfig_Validate(t *testing.T) { }) } } + +func TestOkHandler(t *testing.T) { + tests := []struct { + name string + want int + }{ + { + name: "OK handler", + want: http.StatusOK, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + app := fiber.New() + app.Get("/", OkHandler) + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/", http.NoBody) + if err != nil { + t.Fatalf("http.NewRequestWithContext() error = %v", err) + } + + resp, err := app.Test(req) + if err != nil { + t.Fatalf("app.Test() error = %v", err) + } + defer func() { + if err := resp.Body.Close(); err != nil { + t.Fatalf("resp.Body.Close() error = %v", err) + } + }() + + if resp.StatusCode != tt.want { + t.Errorf("OkHandler() = %v, want %v", resp.StatusCode, http.StatusOK) + } + }) + } +}