Skip to content

Commit

Permalink
fix: ensure struct value by reflect type check
Browse files Browse the repository at this point in the history
  • Loading branch information
lvlcn-t committed Jul 3, 2024
1 parent 30572e7 commit 0b65658
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 11 deletions.
34 changes: 24 additions & 10 deletions config/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,18 +50,12 @@ type Fallback func() (string, error)
//
// Note: If the configuration is a pointer to a struct, the experimental feature behind viper.ExperimentalBindStruct() will not be used.
func Load[T Settings](path string, fallbacks ...Fallback) (cfg T, err error) {
k := reflect.TypeOf(cfg).Kind()
if k != reflect.Struct && k != reflect.Pointer {
return cfg, errors.New("configuration must be a struct or a pointer to a struct")
}

var opts []viper.Option
// The feature behind this option only works on direct struct values (not pointers)
if k != reflect.Pointer {
opts = append(opts, viper.ExperimentalBindStruct())
cfg, err = ensureStruct(cfg)
if err != nil {
return cfg, fmt.Errorf("given type is not a struct: %w", err)
}

v := viper.NewWithOptions(opts...)
v := viper.NewWithOptions(viper.ExperimentalBindStruct())
v.SetFs(fsys)
if path == "" {
if len(fallbacks) == 0 {
Expand Down Expand Up @@ -121,3 +115,23 @@ func defaultFallback() (string, error) {
}
return filepath.Join(home, bin, "config.yaml"), nil
}

// ensureStruct ensures that the provided value is a struct or a pointer to a struct.
func ensureStruct[T any](value T) (T, error) {
var empty T
t := reflect.TypeOf(value)

for t.Kind() == reflect.Pointer {
t = t.Elem()
}

if t.Kind() != reflect.Struct {
return empty, errors.New("value must be a struct or a pointer to a struct")
}

if reflect.TypeOf(value).Kind() == reflect.Pointer && reflect.ValueOf(value).IsNil() {
return reflect.New(t).Interface().(T), nil
}

return value, nil
}
46 changes: 45 additions & 1 deletion config/loader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ type config struct {
}

func (c config) IsEmpty() bool {
return c == (config{})
return reflect.DeepEqual(c, config{})
}

func TestLoad(t *testing.T) {
Expand Down Expand Up @@ -108,6 +108,50 @@ func TestLoad_InvalidType(t *testing.T) {
}
}

func TestLoad_Pointer(t *testing.T) {
tests := []struct {
name string
path string
want Settings
wantErr bool
}{
{
name: "pointer",
path: "testdata/config.yaml",
want: &config{
Host: "localhost",
Port: 8080,
},
wantErr: false,
},
{
name: "nil pointer",
path: "testdata/config.yaml",
want: (*config)(nil),
wantErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
setup(t, tt.path, tt.want)
if reflect.ValueOf(tt.want).IsNil() {
tt.want = &config{}
}

got, err := Load[*config](tt.path)
if (err != nil) != tt.wantErr {
t.Errorf("Load() error = %v, wantErr %v", err, tt.wantErr)
return
}

if err == nil && !reflect.DeepEqual(got, tt.want) {
t.Errorf("Load() = %v, want %v", got, tt.want)
}
})
}
}

func setup(t *testing.T, path string, cfg Settings, fallbacks ...Fallback) {
t.Helper()
if path == "" {
Expand Down

0 comments on commit 0b65658

Please sign in to comment.