Skip to content

Commit

Permalink
misc: Clean up generics usage in parse_config builtins (open-policy-a…
Browse files Browse the repository at this point in the history
…gent#898)

The type does not need to be derived from a parameter to the generic function.

Signed-off-by: James Alseth <[email protected]>
  • Loading branch information
jalseth authored Dec 13, 2023
1 parent 8a6e121 commit 392ddaf
Showing 1 changed file with 46 additions and 24 deletions.
70 changes: 46 additions & 24 deletions builtins/parse_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func registerParseCombinedConfigFiles() {
// parsed configuration as a Rego object. This can be used to parse all of the
// configuration formats conftest supports in-line in Rego policies.
func parseConfig(bctx rego.BuiltinContext, op1, op2 *ast.Term) (*ast.Term, error) {
args, err := decodeTypedArgs("", op1, op2)
args, err := decodeArgs[string](op1, op2)
if err != nil {
return nil, fmt.Errorf("decode args: %w", err)
}
Expand All @@ -77,11 +77,11 @@ func parseConfig(bctx rego.BuiltinContext, op1, op2 *ast.Term) (*ast.Term, error
// parseConfigFile takes a config file path, parses the config file, and
// returns the parsed configuration as a Rego object.
func parseConfigFile(bctx rego.BuiltinContext, op1 *ast.Term) (*ast.Term, error) {
args, err := decodeTypedArgs("", op1)
path, err := decodeArg[string](op1)
if err != nil {
return nil, fmt.Errorf("decode args: %w", err)
}
filePath := filepath.Join(filepath.Dir(bctx.Location.File), args[0])
filePath := filepath.Join(filepath.Dir(bctx.Location.File), path)

parser, err := parser.NewFromPath(filePath)
if err != nil {
Expand All @@ -103,22 +103,12 @@ func parseConfigFile(bctx rego.BuiltinContext, op1 *ast.Term) (*ast.Term, error)
// parseCombinedConfigFiles takes multiple config file paths, parses the configs,
// combines them, and returns that as a Rego object.
func parseCombinedConfigFiles(bctx rego.BuiltinContext, op1 *ast.Term) (*ast.Term, error) {
iface, err := ast.ValueToInterface(op1.Value, nil)
paths, err := decodeSliceArg[string](op1)
if err != nil {
return nil, fmt.Errorf("ast.ValueToInterface: %w", err)
return nil, fmt.Errorf("decode args: %w", err)
}
slice, ok := iface.([]any)
if !ok {
return nil, fmt.Errorf("argument is not a slice")
}

var paths []string
for i, s := range slice {
path, ok := s.(string)
if !ok {
return nil, fmt.Errorf("index %d is not expected type string", i)
}
paths = append(paths, filepath.Join(filepath.Dir(bctx.Location.File), path))
for i, p := range paths {
paths[i] = filepath.Join(filepath.Dir(bctx.Location.File), p)
}

cfg, err := parser.ParseConfigurations(paths)
Expand All @@ -134,16 +124,48 @@ func parseCombinedConfigFiles(bctx rego.BuiltinContext, op1 *ast.Term) (*ast.Ter
return toAST(bctx, combined["Combined"], content)
}

func decodeTypedArgs[T any](ty T, args ...*ast.Term) ([]T, error) {
func decodeSliceArg[T any](arg *ast.Term) ([]T, error) {
iface, err := ast.ValueToInterface(arg.Value, nil)
if err != nil {
return nil, fmt.Errorf("decode arg: %w", err)
}
ifaceSlice, ok := iface.([]any)
if !ok {
return nil, fmt.Errorf("decodeSliceArg used with non-slice value: (%T)%v", iface, iface)
}

var t T
slice := make([]T, len(ifaceSlice))
for i, val := range ifaceSlice {
v, ok := val.(T)
if !ok {
return nil, fmt.Errorf("slice index %d is not expected type %T, got %T", i, t, val)
}
slice[i] = v
}

return slice, nil
}

func decodeArg[T any](arg *ast.Term) (T, error) {
iface, err := ast.ValueToInterface(arg.Value, nil)
if err != nil {
return *new(T), fmt.Errorf("ast.ValueToInterface: %w", err)
}
v, ok := iface.(T)
if !ok {
return *new(T), fmt.Errorf("argument is not expected type, have %T", iface)
}

return v, nil
}

func decodeArgs[T any](args ...*ast.Term) ([]T, error) {
decoded := make([]T, len(args))
for i, arg := range args {
iface, err := ast.ValueToInterface(arg.Value, nil)
v, err := decodeArg[T](arg)
if err != nil {
return nil, fmt.Errorf("ast.ValueToInterface: %w", err)
}
v, ok := iface.(T)
if !ok {
return nil, fmt.Errorf("argument %d is not type %T, have %T", i, ty, iface)
return nil, err
}
decoded[i] = v
}
Expand Down

0 comments on commit 392ddaf

Please sign in to comment.