diff --git a/builtins/parse_config.go b/builtins/parse_config.go index 44b3a20207..f2018aa650 100644 --- a/builtins/parse_config.go +++ b/builtins/parse_config.go @@ -56,7 +56,7 @@ func registerParseCombinedConfigFiles() { // parsed configuration as a Rego object. This can be used to parse all of the // configuration formats conftest supports in-line in Rego policies. func parseConfig(bctx rego.BuiltinContext, op1, op2 *ast.Term) (*ast.Term, error) { - args, err := decodeTypedArgs("", op1, op2) + args, err := decodeArgs[string](op1, op2) if err != nil { return nil, fmt.Errorf("decode args: %w", err) } @@ -77,11 +77,11 @@ func parseConfig(bctx rego.BuiltinContext, op1, op2 *ast.Term) (*ast.Term, error // parseConfigFile takes a config file path, parses the config file, and // returns the parsed configuration as a Rego object. func parseConfigFile(bctx rego.BuiltinContext, op1 *ast.Term) (*ast.Term, error) { - args, err := decodeTypedArgs("", op1) + path, err := decodeArg[string](op1) if err != nil { return nil, fmt.Errorf("decode args: %w", err) } - filePath := filepath.Join(filepath.Dir(bctx.Location.File), args[0]) + filePath := filepath.Join(filepath.Dir(bctx.Location.File), path) parser, err := parser.NewFromPath(filePath) if err != nil { @@ -103,22 +103,12 @@ func parseConfigFile(bctx rego.BuiltinContext, op1 *ast.Term) (*ast.Term, error) // parseCombinedConfigFiles takes multiple config file paths, parses the configs, // combines them, and returns that as a Rego object. func parseCombinedConfigFiles(bctx rego.BuiltinContext, op1 *ast.Term) (*ast.Term, error) { - iface, err := ast.ValueToInterface(op1.Value, nil) + paths, err := decodeSliceArg[string](op1) if err != nil { - return nil, fmt.Errorf("ast.ValueToInterface: %w", err) + return nil, fmt.Errorf("decode args: %w", err) } - slice, ok := iface.([]any) - if !ok { - return nil, fmt.Errorf("argument is not a slice") - } - - var paths []string - for i, s := range slice { - path, ok := s.(string) - if !ok { - return nil, fmt.Errorf("index %d is not expected type string", i) - } - paths = append(paths, filepath.Join(filepath.Dir(bctx.Location.File), path)) + for i, p := range paths { + paths[i] = filepath.Join(filepath.Dir(bctx.Location.File), p) } cfg, err := parser.ParseConfigurations(paths) @@ -134,16 +124,48 @@ func parseCombinedConfigFiles(bctx rego.BuiltinContext, op1 *ast.Term) (*ast.Ter return toAST(bctx, combined["Combined"], content) } -func decodeTypedArgs[T any](ty T, args ...*ast.Term) ([]T, error) { +func decodeSliceArg[T any](arg *ast.Term) ([]T, error) { + iface, err := ast.ValueToInterface(arg.Value, nil) + if err != nil { + return nil, fmt.Errorf("decode arg: %w", err) + } + ifaceSlice, ok := iface.([]any) + if !ok { + return nil, fmt.Errorf("decodeSliceArg used with non-slice value: (%T)%v", iface, iface) + } + + var t T + slice := make([]T, len(ifaceSlice)) + for i, val := range ifaceSlice { + v, ok := val.(T) + if !ok { + return nil, fmt.Errorf("slice index %d is not expected type %T, got %T", i, t, val) + } + slice[i] = v + } + + return slice, nil +} + +func decodeArg[T any](arg *ast.Term) (T, error) { + iface, err := ast.ValueToInterface(arg.Value, nil) + if err != nil { + return *new(T), fmt.Errorf("ast.ValueToInterface: %w", err) + } + v, ok := iface.(T) + if !ok { + return *new(T), fmt.Errorf("argument is not expected type, have %T", iface) + } + + return v, nil +} + +func decodeArgs[T any](args ...*ast.Term) ([]T, error) { decoded := make([]T, len(args)) for i, arg := range args { - iface, err := ast.ValueToInterface(arg.Value, nil) + v, err := decodeArg[T](arg) if err != nil { - return nil, fmt.Errorf("ast.ValueToInterface: %w", err) - } - v, ok := iface.(T) - if !ok { - return nil, fmt.Errorf("argument %d is not type %T, have %T", i, ty, iface) + return nil, err } decoded[i] = v }