Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enrich policy input #540

Merged
merged 7 commits into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions act/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (

type IRegistry interface {
Add(policy *sdkAct.Policy)
Apply(signals []sdkAct.Signal) []*sdkAct.Output
Apply(signals []sdkAct.Signal, hook sdkAct.Hook) []*sdkAct.Output
Run(output *sdkAct.Output, params ...sdkAct.Parameter) (any, *gerr.GatewayDError)
}

Expand Down Expand Up @@ -107,11 +107,11 @@ func (r *Registry) Add(policy *sdkAct.Policy) {
}

// Apply applies the signals to the registry and returns the outputs.
func (r *Registry) Apply(signals []sdkAct.Signal) []*sdkAct.Output {
func (r *Registry) Apply(signals []sdkAct.Signal, hook sdkAct.Hook) []*sdkAct.Output {
// If there are no signals, apply the default policy.
if len(signals) == 0 {
r.Logger.Debug().Msg("No signals provided, applying default signal")
return r.Apply([]sdkAct.Signal{*r.DefaultSignal})
return r.Apply([]sdkAct.Signal{*r.DefaultSignal}, hook)
}

// Separate terminal and non-terminal signals to find contradictions.
Expand Down Expand Up @@ -139,7 +139,7 @@ func (r *Registry) Apply(signals []sdkAct.Signal) []*sdkAct.Output {
}

// Apply the signal and append the output to the list of outputs.
output, err := r.apply(signal)
output, err := r.apply(signal, hook)
if err != nil {
r.Logger.Error().Err(err).Str("name", signal.Name).Msg("Error applying signal")
// If there is an error evaluating the policy, continue to the next signal.
Expand All @@ -155,14 +155,16 @@ func (r *Registry) Apply(signals []sdkAct.Signal) []*sdkAct.Output {
}

if len(outputs) == 0 && !evalErr {
return r.Apply([]sdkAct.Signal{*r.DefaultSignal})
return r.Apply([]sdkAct.Signal{*r.DefaultSignal}, hook)
}

return outputs
}

// apply applies the signal to the registry and returns the output.
func (r *Registry) apply(signal sdkAct.Signal) (*sdkAct.Output, *gerr.GatewayDError) {
func (r *Registry) apply(
signal sdkAct.Signal, hook sdkAct.Hook,
) (*sdkAct.Output, *gerr.GatewayDError) {
action, exists := r.Actions[signal.Name]
if !exists {
return nil, gerr.ErrActionNotMatched
Expand All @@ -178,12 +180,12 @@ func (r *Registry) apply(signal sdkAct.Signal) (*sdkAct.Output, *gerr.GatewayDEr
defer cancel()

// Evaluate the policy.
// TODO: Policy should be able to receive other parameters like server and client IPs, etc.
verdict, err := policy.Eval(
ctx, sdkAct.Input{
Name: signal.Name,
Policy: policy.Metadata,
Signal: signal.Metadata,
Hook: hook,
// Action dictates the sync mode, not the signal.
Sync: action.Sync,
},
Expand Down
175 changes: 169 additions & 6 deletions act/registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,17 @@ func Test_Apply(t *testing.T) {
})
assert.NotNil(t, actRegistry)

outputs := actRegistry.Apply([]sdkAct.Signal{
*sdkAct.Passthrough(),
})
outputs := actRegistry.Apply(
[]sdkAct.Signal{
*sdkAct.Passthrough(),
},
sdkAct.Hook{
Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT",
Priority: 1000,
Params: map[string]any{},
Result: map[string]any{},
},
)
assert.NotNil(t, outputs)
assert.Len(t, outputs, 1)
assert.Equal(t, "passthrough", outputs[0].MatchedPolicy)
Expand All @@ -225,7 +233,15 @@ func Test_Apply_NoSignals(t *testing.T) {
})
assert.NotNil(t, actRegistry)

outputs := actRegistry.Apply([]sdkAct.Signal{})
outputs := actRegistry.Apply(
[]sdkAct.Signal{},
sdkAct.Hook{
Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT",
Priority: 1000,
Params: map[string]any{},
Result: map[string]any{},
},
)
assert.NotNil(t, outputs)
assert.Len(t, outputs, 1)
assert.Equal(t, "passthrough", outputs[0].MatchedPolicy)
Expand Down Expand Up @@ -272,7 +288,12 @@ func Test_Apply_ContradictorySignals(t *testing.T) {
assert.NotNil(t, actRegistry)

for _, s := range signals {
outputs := actRegistry.Apply(s)
outputs := actRegistry.Apply(s, sdkAct.Hook{
Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT",
Priority: 1000,
Params: map[string]any{},
Result: map[string]any{},
})
assert.NotNil(t, outputs)
assert.Len(t, outputs, 2)
assert.Equal(t, "terminate", outputs[0].MatchedPolicy)
Expand Down Expand Up @@ -318,6 +339,11 @@ func Test_Apply_ActionNotMatched(t *testing.T) {

outputs := actRegistry.Apply([]sdkAct.Signal{
{Name: "non-existent"},
}, sdkAct.Hook{
Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT",
Priority: 1000,
Params: map[string]any{},
Result: map[string]any{},
})
assert.NotNil(t, outputs)
assert.Len(t, outputs, 1)
Expand Down Expand Up @@ -351,6 +377,11 @@ func Test_Apply_PolicyNotMatched(t *testing.T) {

outputs := actRegistry.Apply([]sdkAct.Signal{
*sdkAct.Terminate(),
}, sdkAct.Hook{
Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT",
Priority: 1000,
Params: map[string]any{},
Result: map[string]any{},
})
assert.NotNil(t, outputs)
assert.Len(t, outputs, 1)
Expand Down Expand Up @@ -399,6 +430,11 @@ func Test_Apply_NonBoolPolicy(t *testing.T) {

outputs := actRegistry.Apply([]sdkAct.Signal{
*sdkAct.Passthrough(),
}, sdkAct.Hook{
Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT",
Priority: 1000,
Params: map[string]any{},
Result: map[string]any{},
})
assert.NotNil(t, outputs)
assert.Len(t, outputs, 1)
Expand Down Expand Up @@ -447,6 +483,110 @@ func Test_Apply_BadPolicy(t *testing.T) {
}
}

// Test_Apply_Hook tests the Apply function of the act registry with a policy that
// has the hook info and makes use of it.
func Test_Apply_Hook(t *testing.T) {
buf := bytes.Buffer{}
logger := zerolog.New(&buf)

// Custom policy leveraging the hook info.
policies := map[string]*sdkAct.Policy{
"passthrough": sdkAct.MustNewPolicy(
"passthrough",
"true",
nil,
),
"log": sdkAct.MustNewPolicy(
"log",
`Signal.log == true && Policy.log == "enabled" &&
split(Hook.Params.client.remote, ":")[0] == "192.168.0.1"`,
map[string]any{
"log": "enabled",
},
),
}

actRegistry := NewActRegistry(
Registry{
Signals: BuiltinSignals(),
Policies: policies,
Actions: BuiltinActions(),
DefaultPolicyName: config.DefaultPolicy,
PolicyTimeout: config.DefaultPolicyTimeout,
DefaultActionTimeout: config.DefaultActionTimeout,
Logger: logger,
})
assert.NotNil(t, actRegistry)

hook := sdkAct.Hook{
Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT",
Priority: 1000,
// Input parameters for the hook.
Params: map[string]any{
"field": "value",
"server": map[string]any{
"local": "value",
"remote": "value",
},
"client": map[string]any{
"local": "value",
"remote": "192.168.0.1:15432",
},
"request": "Base64EncodedRequest",
"error": "",
},
// Output parameters for the hook.
Result: map[string]any{
"field": "value",
"server": map[string]any{
"local": "value",
"remote": "value",
},
"client": map[string]any{
"local": "value",
"remote": "value",
},
"request": "Base64EncodedRequest",
"error": "",
sdkAct.Signals: []any{
sdkAct.Log("error", "error message", map[string]any{"key": "value"}).ToMap(),
},
"response": "Base64EncodedResponse",
},
}

outputs := actRegistry.Apply(
[]sdkAct.Signal{
*sdkAct.Log(
"error",
"policy matched from incoming address 192.168.0.1, so we are seeing this error message",
map[string]any{"key": "value"},
),
},
hook,
)
assert.NotNil(t, outputs)
assert.Len(t, outputs, 1)
assert.Equal(t, "log", outputs[0].MatchedPolicy)
assert.Equal(t, outputs[0].Metadata, map[string]any{
"key": "value",
"level": "error",
"log": true,
"message": "policy matched from incoming address 192.168.0.1, so we are seeing this error message",
})
assert.False(t, outputs[0].Sync) // Asynchronous action.
assert.True(t, cast.ToBool(outputs[0].Verdict))
assert.False(t, outputs[0].Terminal)

result, err := actRegistry.Run(outputs[0], WithResult(hook.Result))
assert.Equal(t, err, gerr.ErrAsyncAction, "expected async action sentinel error")
assert.Nil(t, result, "expected nil result")

time.Sleep(time.Millisecond) // wait for async action to complete

assert.Contains(t, buf.String(), `{"level":"error","key":"value","message":"policy matched from incoming address 192.168.0.1, so we are seeing this error message"}`) //nolint:lll
}

// Test_Run tests the Run function of the act registry with a non-terminal action.
func Test_Run(t *testing.T) {
logger := zerolog.Logger{}
Expand All @@ -464,6 +604,11 @@ func Test_Run(t *testing.T) {

outputs := actRegistry.Apply([]sdkAct.Signal{
*sdkAct.Passthrough(),
}, sdkAct.Hook{
Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT",
Priority: 1000,
Params: map[string]any{},
Result: map[string]any{},
})
assert.NotNil(t, outputs)

Expand All @@ -489,6 +634,11 @@ func Test_Run_Terminate(t *testing.T) {

outputs := actRegistry.Apply([]sdkAct.Signal{
*sdkAct.Terminate(),
}, sdkAct.Hook{
Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT",
Priority: 1000,
Params: map[string]any{},
Result: map[string]any{},
})
assert.NotNil(t, outputs)
assert.Equal(t, "terminate", outputs[0].MatchedPolicy)
Expand Down Expand Up @@ -522,6 +672,11 @@ func Test_Run_Async(t *testing.T) {

outputs := actRegistry.Apply([]sdkAct.Signal{
*sdkAct.Log("info", "test", map[string]any{"async": true}),
}, sdkAct.Hook{
Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT",
Priority: 1000,
Params: map[string]any{},
Result: map[string]any{},
})
assert.NotNil(t, outputs)
assert.Equal(t, "log", outputs[0].MatchedPolicy)
Expand Down Expand Up @@ -647,7 +802,15 @@ func Test_Run_Timeout(t *testing.T) {
})
assert.NotNil(t, actRegistry)

outputs := actRegistry.Apply([]sdkAct.Signal{*signals[name]})
outputs := actRegistry.Apply(
[]sdkAct.Signal{*signals[name]},
sdkAct.Hook{
Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT",
Priority: 1000,
Params: map[string]any{},
Result: map[string]any{},
},
)
assert.NotNil(t, outputs)
assert.Equal(t, name, outputs[0].MatchedPolicy)
assert.Equal(t,
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ require (
github.com/codingsince1985/checksum v1.3.0
github.com/cybercyst/go-scaffold v0.0.0-20240404115540-744e601147cd
github.com/envoyproxy/protoc-gen-validate v1.0.4
github.com/gatewayd-io/gatewayd-plugin-sdk v0.2.13
github.com/gatewayd-io/gatewayd-plugin-sdk v0.2.14
github.com/getsentry/sentry-go v0.27.0
github.com/go-co-op/gocron v1.37.0
github.com/google/go-github/v53 v53.2.0
Expand Down
4 changes: 2 additions & 2 deletions go.sum

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

38 changes: 21 additions & 17 deletions network/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -856,31 +856,35 @@ func (pr *Proxy) shouldTerminate(result map[string]interface{}) (bool, map[strin
// The Terminal field is only present if the action wants to terminate the request,
// that is the `__terminal__` field is set in one of the outputs.
keys := maps.Keys(result)
if slices.Contains(keys, sdkAct.Terminal) {
var actionResult map[string]interface{}
for _, output := range outputs {
actRes, err := pr.PluginRegistry.ActRegistry.Run(
output, act.WithResult(result))
// If the action is async and we received a sentinel error,
// don't log the error.
if err != nil && !errors.Is(err, gerr.ErrAsyncAction) {
pr.Logger.Error().Err(err).Msg("Error running policy")
}
// The terminate action should return a map.
if v, ok := actRes.(map[string]interface{}); ok {
actionResult = v
}
terminate := slices.Contains(keys, sdkAct.Terminal) && cast.ToBool(result[sdkAct.Terminal])
actionResult := make(map[string]interface{})
for _, output := range outputs {
if !cast.ToBool(output.Verdict) {
pr.Logger.Debug().Msg(
"Skipping the action, because the verdict of the policy execution is false")
continue
}
actRes, err := pr.PluginRegistry.ActRegistry.Run(
output, act.WithResult(result))
// If the action is async and we received a sentinel error,
// don't log the error.
if err != nil && !errors.Is(err, gerr.ErrAsyncAction) {
pr.Logger.Error().Err(err).Msg("Error running policy")
}
// The terminate action should return a map.
if v, ok := actRes.(map[string]interface{}); ok {
actionResult = v
}
}
if terminate {
pr.Logger.Debug().Fields(
map[string]interface{}{
"function": "proxy.passthrough",
"reason": "terminate",
},
).Msg("Terminating request")
return cast.ToBool(result[sdkAct.Terminal]), actionResult
}

return false, result
return terminate, actionResult
}

// getPluginModifiedRequest is a function that retrieves the modified request
Expand Down
Loading