Skip to content

Commit

Permalink
Enrich policy input (#540)
Browse files Browse the repository at this point in the history
* Update SDK

* Add hook to policy input for enrichment
* Fix bug in checking the verdict of policy execution
This currently only considers boolean values, as it is trying to figure
out if it should include `terminal` in the result output.
* Update method signatures to include hook struct
* Fix action runs and skip actions with false verdicts
* Update tests to reflect changes
* Add test case to show how to use hook info
  • Loading branch information
mostafa authored May 24, 2024
1 parent 94a747f commit 1c8196f
Show file tree
Hide file tree
Showing 8 changed files with 233 additions and 44 deletions.
16 changes: 9 additions & 7 deletions act/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (

type IRegistry interface {
Add(policy *sdkAct.Policy)
Apply(signals []sdkAct.Signal) []*sdkAct.Output
Apply(signals []sdkAct.Signal, hook sdkAct.Hook) []*sdkAct.Output
Run(output *sdkAct.Output, params ...sdkAct.Parameter) (any, *gerr.GatewayDError)
}

Expand Down Expand Up @@ -107,11 +107,11 @@ func (r *Registry) Add(policy *sdkAct.Policy) {
}

// Apply applies the signals to the registry and returns the outputs.
func (r *Registry) Apply(signals []sdkAct.Signal) []*sdkAct.Output {
func (r *Registry) Apply(signals []sdkAct.Signal, hook sdkAct.Hook) []*sdkAct.Output {
// If there are no signals, apply the default policy.
if len(signals) == 0 {
r.Logger.Debug().Msg("No signals provided, applying default signal")
return r.Apply([]sdkAct.Signal{*r.DefaultSignal})
return r.Apply([]sdkAct.Signal{*r.DefaultSignal}, hook)
}

// Separate terminal and non-terminal signals to find contradictions.
Expand Down Expand Up @@ -139,7 +139,7 @@ func (r *Registry) Apply(signals []sdkAct.Signal) []*sdkAct.Output {
}

// Apply the signal and append the output to the list of outputs.
output, err := r.apply(signal)
output, err := r.apply(signal, hook)
if err != nil {
r.Logger.Error().Err(err).Str("name", signal.Name).Msg("Error applying signal")
// If there is an error evaluating the policy, continue to the next signal.
Expand All @@ -155,14 +155,16 @@ func (r *Registry) Apply(signals []sdkAct.Signal) []*sdkAct.Output {
}

if len(outputs) == 0 && !evalErr {
return r.Apply([]sdkAct.Signal{*r.DefaultSignal})
return r.Apply([]sdkAct.Signal{*r.DefaultSignal}, hook)
}

return outputs
}

// apply applies the signal to the registry and returns the output.
func (r *Registry) apply(signal sdkAct.Signal) (*sdkAct.Output, *gerr.GatewayDError) {
func (r *Registry) apply(
signal sdkAct.Signal, hook sdkAct.Hook,
) (*sdkAct.Output, *gerr.GatewayDError) {
action, exists := r.Actions[signal.Name]
if !exists {
return nil, gerr.ErrActionNotMatched
Expand All @@ -178,12 +180,12 @@ func (r *Registry) apply(signal sdkAct.Signal) (*sdkAct.Output, *gerr.GatewayDEr
defer cancel()

// Evaluate the policy.
// TODO: Policy should be able to receive other parameters like server and client IPs, etc.
verdict, err := policy.Eval(
ctx, sdkAct.Input{
Name: signal.Name,
Policy: policy.Metadata,
Signal: signal.Metadata,
Hook: hook,
// Action dictates the sync mode, not the signal.
Sync: action.Sync,
},
Expand Down
175 changes: 169 additions & 6 deletions act/registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,17 @@ func Test_Apply(t *testing.T) {
})
assert.NotNil(t, actRegistry)

outputs := actRegistry.Apply([]sdkAct.Signal{
*sdkAct.Passthrough(),
})
outputs := actRegistry.Apply(
[]sdkAct.Signal{
*sdkAct.Passthrough(),
},
sdkAct.Hook{
Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT",
Priority: 1000,
Params: map[string]any{},
Result: map[string]any{},
},
)
assert.NotNil(t, outputs)
assert.Len(t, outputs, 1)
assert.Equal(t, "passthrough", outputs[0].MatchedPolicy)
Expand All @@ -225,7 +233,15 @@ func Test_Apply_NoSignals(t *testing.T) {
})
assert.NotNil(t, actRegistry)

outputs := actRegistry.Apply([]sdkAct.Signal{})
outputs := actRegistry.Apply(
[]sdkAct.Signal{},
sdkAct.Hook{
Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT",
Priority: 1000,
Params: map[string]any{},
Result: map[string]any{},
},
)
assert.NotNil(t, outputs)
assert.Len(t, outputs, 1)
assert.Equal(t, "passthrough", outputs[0].MatchedPolicy)
Expand Down Expand Up @@ -272,7 +288,12 @@ func Test_Apply_ContradictorySignals(t *testing.T) {
assert.NotNil(t, actRegistry)

for _, s := range signals {
outputs := actRegistry.Apply(s)
outputs := actRegistry.Apply(s, sdkAct.Hook{
Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT",
Priority: 1000,
Params: map[string]any{},
Result: map[string]any{},
})
assert.NotNil(t, outputs)
assert.Len(t, outputs, 2)
assert.Equal(t, "terminate", outputs[0].MatchedPolicy)
Expand Down Expand Up @@ -318,6 +339,11 @@ func Test_Apply_ActionNotMatched(t *testing.T) {

outputs := actRegistry.Apply([]sdkAct.Signal{
{Name: "non-existent"},
}, sdkAct.Hook{
Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT",
Priority: 1000,
Params: map[string]any{},
Result: map[string]any{},
})
assert.NotNil(t, outputs)
assert.Len(t, outputs, 1)
Expand Down Expand Up @@ -351,6 +377,11 @@ func Test_Apply_PolicyNotMatched(t *testing.T) {

outputs := actRegistry.Apply([]sdkAct.Signal{
*sdkAct.Terminate(),
}, sdkAct.Hook{
Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT",
Priority: 1000,
Params: map[string]any{},
Result: map[string]any{},
})
assert.NotNil(t, outputs)
assert.Len(t, outputs, 1)
Expand Down Expand Up @@ -399,6 +430,11 @@ func Test_Apply_NonBoolPolicy(t *testing.T) {

outputs := actRegistry.Apply([]sdkAct.Signal{
*sdkAct.Passthrough(),
}, sdkAct.Hook{
Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT",
Priority: 1000,
Params: map[string]any{},
Result: map[string]any{},
})
assert.NotNil(t, outputs)
assert.Len(t, outputs, 1)
Expand Down Expand Up @@ -447,6 +483,110 @@ func Test_Apply_BadPolicy(t *testing.T) {
}
}

// Test_Apply_Hook tests the Apply function of the act registry with a policy that
// has the hook info and makes use of it.
func Test_Apply_Hook(t *testing.T) {
buf := bytes.Buffer{}
logger := zerolog.New(&buf)

// Custom policy leveraging the hook info.
policies := map[string]*sdkAct.Policy{
"passthrough": sdkAct.MustNewPolicy(
"passthrough",
"true",
nil,
),
"log": sdkAct.MustNewPolicy(
"log",
`Signal.log == true && Policy.log == "enabled" &&
split(Hook.Params.client.remote, ":")[0] == "192.168.0.1"`,
map[string]any{
"log": "enabled",
},
),
}

actRegistry := NewActRegistry(
Registry{
Signals: BuiltinSignals(),
Policies: policies,
Actions: BuiltinActions(),
DefaultPolicyName: config.DefaultPolicy,
PolicyTimeout: config.DefaultPolicyTimeout,
DefaultActionTimeout: config.DefaultActionTimeout,
Logger: logger,
})
assert.NotNil(t, actRegistry)

hook := sdkAct.Hook{
Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT",
Priority: 1000,
// Input parameters for the hook.
Params: map[string]any{
"field": "value",
"server": map[string]any{
"local": "value",
"remote": "value",
},
"client": map[string]any{
"local": "value",
"remote": "192.168.0.1:15432",
},
"request": "Base64EncodedRequest",
"error": "",
},
// Output parameters for the hook.
Result: map[string]any{
"field": "value",
"server": map[string]any{
"local": "value",
"remote": "value",
},
"client": map[string]any{
"local": "value",
"remote": "value",
},
"request": "Base64EncodedRequest",
"error": "",
sdkAct.Signals: []any{
sdkAct.Log("error", "error message", map[string]any{"key": "value"}).ToMap(),
},
"response": "Base64EncodedResponse",
},
}

outputs := actRegistry.Apply(
[]sdkAct.Signal{
*sdkAct.Log(
"error",
"policy matched from incoming address 192.168.0.1, so we are seeing this error message",
map[string]any{"key": "value"},
),
},
hook,
)
assert.NotNil(t, outputs)
assert.Len(t, outputs, 1)
assert.Equal(t, "log", outputs[0].MatchedPolicy)
assert.Equal(t, outputs[0].Metadata, map[string]any{
"key": "value",
"level": "error",
"log": true,
"message": "policy matched from incoming address 192.168.0.1, so we are seeing this error message",
})
assert.False(t, outputs[0].Sync) // Asynchronous action.
assert.True(t, cast.ToBool(outputs[0].Verdict))
assert.False(t, outputs[0].Terminal)

result, err := actRegistry.Run(outputs[0], WithResult(hook.Result))
assert.Equal(t, err, gerr.ErrAsyncAction, "expected async action sentinel error")
assert.Nil(t, result, "expected nil result")

time.Sleep(time.Millisecond) // wait for async action to complete

assert.Contains(t, buf.String(), `{"level":"error","key":"value","message":"policy matched from incoming address 192.168.0.1, so we are seeing this error message"}`) //nolint:lll
}

// Test_Run tests the Run function of the act registry with a non-terminal action.
func Test_Run(t *testing.T) {
logger := zerolog.Logger{}
Expand All @@ -464,6 +604,11 @@ func Test_Run(t *testing.T) {

outputs := actRegistry.Apply([]sdkAct.Signal{
*sdkAct.Passthrough(),
}, sdkAct.Hook{
Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT",
Priority: 1000,
Params: map[string]any{},
Result: map[string]any{},
})
assert.NotNil(t, outputs)

Expand All @@ -489,6 +634,11 @@ func Test_Run_Terminate(t *testing.T) {

outputs := actRegistry.Apply([]sdkAct.Signal{
*sdkAct.Terminate(),
}, sdkAct.Hook{
Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT",
Priority: 1000,
Params: map[string]any{},
Result: map[string]any{},
})
assert.NotNil(t, outputs)
assert.Equal(t, "terminate", outputs[0].MatchedPolicy)
Expand Down Expand Up @@ -522,6 +672,11 @@ func Test_Run_Async(t *testing.T) {

outputs := actRegistry.Apply([]sdkAct.Signal{
*sdkAct.Log("info", "test", map[string]any{"async": true}),
}, sdkAct.Hook{
Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT",
Priority: 1000,
Params: map[string]any{},
Result: map[string]any{},
})
assert.NotNil(t, outputs)
assert.Equal(t, "log", outputs[0].MatchedPolicy)
Expand Down Expand Up @@ -647,7 +802,15 @@ func Test_Run_Timeout(t *testing.T) {
})
assert.NotNil(t, actRegistry)

outputs := actRegistry.Apply([]sdkAct.Signal{*signals[name]})
outputs := actRegistry.Apply(
[]sdkAct.Signal{*signals[name]},
sdkAct.Hook{
Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT",
Priority: 1000,
Params: map[string]any{},
Result: map[string]any{},
},
)
assert.NotNil(t, outputs)
assert.Equal(t, name, outputs[0].MatchedPolicy)
assert.Equal(t,
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ require (
github.com/codingsince1985/checksum v1.3.0
github.com/cybercyst/go-scaffold v0.0.0-20240404115540-744e601147cd
github.com/envoyproxy/protoc-gen-validate v1.0.4
github.com/gatewayd-io/gatewayd-plugin-sdk v0.2.13
github.com/gatewayd-io/gatewayd-plugin-sdk v0.2.14
github.com/getsentry/sentry-go v0.27.0
github.com/go-co-op/gocron v1.37.0
github.com/google/go-github/v53 v53.2.0
Expand Down
4 changes: 2 additions & 2 deletions go.sum

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

38 changes: 21 additions & 17 deletions network/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -856,31 +856,35 @@ func (pr *Proxy) shouldTerminate(result map[string]interface{}) (bool, map[strin
// The Terminal field is only present if the action wants to terminate the request,
// that is the `__terminal__` field is set in one of the outputs.
keys := maps.Keys(result)
if slices.Contains(keys, sdkAct.Terminal) {
var actionResult map[string]interface{}
for _, output := range outputs {
actRes, err := pr.PluginRegistry.ActRegistry.Run(
output, act.WithResult(result))
// If the action is async and we received a sentinel error,
// don't log the error.
if err != nil && !errors.Is(err, gerr.ErrAsyncAction) {
pr.Logger.Error().Err(err).Msg("Error running policy")
}
// The terminate action should return a map.
if v, ok := actRes.(map[string]interface{}); ok {
actionResult = v
}
terminate := slices.Contains(keys, sdkAct.Terminal) && cast.ToBool(result[sdkAct.Terminal])
actionResult := make(map[string]interface{})
for _, output := range outputs {
if !cast.ToBool(output.Verdict) {
pr.Logger.Debug().Msg(
"Skipping the action, because the verdict of the policy execution is false")
continue
}
actRes, err := pr.PluginRegistry.ActRegistry.Run(
output, act.WithResult(result))
// If the action is async and we received a sentinel error,
// don't log the error.
if err != nil && !errors.Is(err, gerr.ErrAsyncAction) {
pr.Logger.Error().Err(err).Msg("Error running policy")
}
// The terminate action should return a map.
if v, ok := actRes.(map[string]interface{}); ok {
actionResult = v
}
}
if terminate {
pr.Logger.Debug().Fields(
map[string]interface{}{
"function": "proxy.passthrough",
"reason": "terminate",
},
).Msg("Terminating request")
return cast.ToBool(result[sdkAct.Terminal]), actionResult
}

return false, result
return terminate, actionResult
}

// getPluginModifiedRequest is a function that retrieves the modified request
Expand Down
Loading

0 comments on commit 1c8196f

Please sign in to comment.