diff --git a/config/loader.go b/config/loader.go index 7a8fe4d..e0e3a5d 100644 --- a/config/loader.go +++ b/config/loader.go @@ -50,18 +50,12 @@ type Fallback func() (string, error) // // Note: If the configuration is a pointer to a struct, the experimental feature behind viper.ExperimentalBindStruct() will not be used. func Load[T Settings](path string, fallbacks ...Fallback) (cfg T, err error) { - k := reflect.TypeOf(cfg).Kind() - if k != reflect.Struct && k != reflect.Pointer { - return cfg, errors.New("configuration must be a struct or a pointer to a struct") - } - - var opts []viper.Option - // The feature behind this option only works on direct struct values (not pointers) - if k != reflect.Pointer { - opts = append(opts, viper.ExperimentalBindStruct()) + cfg, err = ensureStruct(cfg) + if err != nil { + return cfg, fmt.Errorf("given type is not a struct: %w", err) } - v := viper.NewWithOptions(opts...) + v := viper.NewWithOptions(viper.ExperimentalBindStruct()) v.SetFs(fsys) if path == "" { if len(fallbacks) == 0 { @@ -121,3 +115,23 @@ func defaultFallback() (string, error) { } return filepath.Join(home, bin, "config.yaml"), nil } + +// ensureStruct ensures that the provided value is a struct or a pointer to a struct. +func ensureStruct[T any](value T) (T, error) { + var empty T + t := reflect.TypeOf(value) + + for t.Kind() == reflect.Pointer { + t = t.Elem() + } + + if t.Kind() != reflect.Struct { + return empty, errors.New("value must be a struct or a pointer to a struct") + } + + if reflect.TypeOf(value).Kind() == reflect.Pointer && reflect.ValueOf(value).IsNil() { + return reflect.New(t).Interface().(T), nil + } + + return value, nil +} diff --git a/config/loader_test.go b/config/loader_test.go index 7b6c275..c712927 100644 --- a/config/loader_test.go +++ b/config/loader_test.go @@ -16,7 +16,7 @@ type config struct { } func (c config) IsEmpty() bool { - return c == (config{}) + return reflect.DeepEqual(c, config{}) } func TestLoad(t *testing.T) { @@ -108,6 +108,50 @@ func TestLoad_InvalidType(t *testing.T) { } } +func TestLoad_Pointer(t *testing.T) { + tests := []struct { + name string + path string + want Settings + wantErr bool + }{ + { + name: "pointer", + path: "testdata/config.yaml", + want: &config{ + Host: "localhost", + Port: 8080, + }, + wantErr: false, + }, + { + name: "nil pointer", + path: "testdata/config.yaml", + want: (*config)(nil), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + setup(t, tt.path, tt.want) + if reflect.ValueOf(tt.want).IsNil() { + tt.want = &config{} + } + + got, err := Load[*config](tt.path) + if (err != nil) != tt.wantErr { + t.Errorf("Load() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if err == nil && !reflect.DeepEqual(got, tt.want) { + t.Errorf("Load() = %v, want %v", got, tt.want) + } + }) + } +} + func setup(t *testing.T, path string, cfg Settings, fallbacks ...Fallback) { t.Helper() if path == "" {