diff --git a/act/registry.go b/act/registry.go index 1a45a85d..be65dad7 100644 --- a/act/registry.go +++ b/act/registry.go @@ -14,7 +14,7 @@ import ( type IRegistry interface { Add(policy *sdkAct.Policy) - Apply(signals []sdkAct.Signal) []*sdkAct.Output + Apply(signals []sdkAct.Signal, hook sdkAct.Hook) []*sdkAct.Output Run(output *sdkAct.Output, params ...sdkAct.Parameter) (any, *gerr.GatewayDError) } @@ -107,11 +107,11 @@ func (r *Registry) Add(policy *sdkAct.Policy) { } // Apply applies the signals to the registry and returns the outputs. -func (r *Registry) Apply(signals []sdkAct.Signal) []*sdkAct.Output { +func (r *Registry) Apply(signals []sdkAct.Signal, hook sdkAct.Hook) []*sdkAct.Output { // If there are no signals, apply the default policy. if len(signals) == 0 { r.Logger.Debug().Msg("No signals provided, applying default signal") - return r.Apply([]sdkAct.Signal{*r.DefaultSignal}) + return r.Apply([]sdkAct.Signal{*r.DefaultSignal}, hook) } // Separate terminal and non-terminal signals to find contradictions. @@ -139,7 +139,7 @@ func (r *Registry) Apply(signals []sdkAct.Signal) []*sdkAct.Output { } // Apply the signal and append the output to the list of outputs. - output, err := r.apply(signal) + output, err := r.apply(signal, hook) if err != nil { r.Logger.Error().Err(err).Str("name", signal.Name).Msg("Error applying signal") // If there is an error evaluating the policy, continue to the next signal. @@ -155,14 +155,16 @@ func (r *Registry) Apply(signals []sdkAct.Signal) []*sdkAct.Output { } if len(outputs) == 0 && !evalErr { - return r.Apply([]sdkAct.Signal{*r.DefaultSignal}) + return r.Apply([]sdkAct.Signal{*r.DefaultSignal}, hook) } return outputs } // apply applies the signal to the registry and returns the output. -func (r *Registry) apply(signal sdkAct.Signal) (*sdkAct.Output, *gerr.GatewayDError) { +func (r *Registry) apply( + signal sdkAct.Signal, hook sdkAct.Hook, +) (*sdkAct.Output, *gerr.GatewayDError) { action, exists := r.Actions[signal.Name] if !exists { return nil, gerr.ErrActionNotMatched @@ -178,12 +180,12 @@ func (r *Registry) apply(signal sdkAct.Signal) (*sdkAct.Output, *gerr.GatewayDEr defer cancel() // Evaluate the policy. - // TODO: Policy should be able to receive other parameters like server and client IPs, etc. verdict, err := policy.Eval( ctx, sdkAct.Input{ Name: signal.Name, Policy: policy.Metadata, Signal: signal.Metadata, + Hook: hook, // Action dictates the sync mode, not the signal. Sync: action.Sync, }, diff --git a/act/registry_test.go b/act/registry_test.go index c5cc3e7c..13abf0bd 100644 --- a/act/registry_test.go +++ b/act/registry_test.go @@ -196,9 +196,17 @@ func Test_Apply(t *testing.T) { }) assert.NotNil(t, actRegistry) - outputs := actRegistry.Apply([]sdkAct.Signal{ - *sdkAct.Passthrough(), - }) + outputs := actRegistry.Apply( + []sdkAct.Signal{ + *sdkAct.Passthrough(), + }, + sdkAct.Hook{ + Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT", + Priority: 1000, + Params: map[string]any{}, + Result: map[string]any{}, + }, + ) assert.NotNil(t, outputs) assert.Len(t, outputs, 1) assert.Equal(t, "passthrough", outputs[0].MatchedPolicy) @@ -225,7 +233,15 @@ func Test_Apply_NoSignals(t *testing.T) { }) assert.NotNil(t, actRegistry) - outputs := actRegistry.Apply([]sdkAct.Signal{}) + outputs := actRegistry.Apply( + []sdkAct.Signal{}, + sdkAct.Hook{ + Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT", + Priority: 1000, + Params: map[string]any{}, + Result: map[string]any{}, + }, + ) assert.NotNil(t, outputs) assert.Len(t, outputs, 1) assert.Equal(t, "passthrough", outputs[0].MatchedPolicy) @@ -272,7 +288,12 @@ func Test_Apply_ContradictorySignals(t *testing.T) { assert.NotNil(t, actRegistry) for _, s := range signals { - outputs := actRegistry.Apply(s) + outputs := actRegistry.Apply(s, sdkAct.Hook{ + Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT", + Priority: 1000, + Params: map[string]any{}, + Result: map[string]any{}, + }) assert.NotNil(t, outputs) assert.Len(t, outputs, 2) assert.Equal(t, "terminate", outputs[0].MatchedPolicy) @@ -318,6 +339,11 @@ func Test_Apply_ActionNotMatched(t *testing.T) { outputs := actRegistry.Apply([]sdkAct.Signal{ {Name: "non-existent"}, + }, sdkAct.Hook{ + Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT", + Priority: 1000, + Params: map[string]any{}, + Result: map[string]any{}, }) assert.NotNil(t, outputs) assert.Len(t, outputs, 1) @@ -351,6 +377,11 @@ func Test_Apply_PolicyNotMatched(t *testing.T) { outputs := actRegistry.Apply([]sdkAct.Signal{ *sdkAct.Terminate(), + }, sdkAct.Hook{ + Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT", + Priority: 1000, + Params: map[string]any{}, + Result: map[string]any{}, }) assert.NotNil(t, outputs) assert.Len(t, outputs, 1) @@ -399,6 +430,11 @@ func Test_Apply_NonBoolPolicy(t *testing.T) { outputs := actRegistry.Apply([]sdkAct.Signal{ *sdkAct.Passthrough(), + }, sdkAct.Hook{ + Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT", + Priority: 1000, + Params: map[string]any{}, + Result: map[string]any{}, }) assert.NotNil(t, outputs) assert.Len(t, outputs, 1) @@ -447,6 +483,110 @@ func Test_Apply_BadPolicy(t *testing.T) { } } +// Test_Apply_Hook tests the Apply function of the act registry with a policy that +// has the hook info and makes use of it. +func Test_Apply_Hook(t *testing.T) { + buf := bytes.Buffer{} + logger := zerolog.New(&buf) + + // Custom policy leveraging the hook info. + policies := map[string]*sdkAct.Policy{ + "passthrough": sdkAct.MustNewPolicy( + "passthrough", + "true", + nil, + ), + "log": sdkAct.MustNewPolicy( + "log", + `Signal.log == true && Policy.log == "enabled" && + split(Hook.Params.client.remote, ":")[0] == "192.168.0.1"`, + map[string]any{ + "log": "enabled", + }, + ), + } + + actRegistry := NewActRegistry( + Registry{ + Signals: BuiltinSignals(), + Policies: policies, + Actions: BuiltinActions(), + DefaultPolicyName: config.DefaultPolicy, + PolicyTimeout: config.DefaultPolicyTimeout, + DefaultActionTimeout: config.DefaultActionTimeout, + Logger: logger, + }) + assert.NotNil(t, actRegistry) + + hook := sdkAct.Hook{ + Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT", + Priority: 1000, + // Input parameters for the hook. + Params: map[string]any{ + "field": "value", + "server": map[string]any{ + "local": "value", + "remote": "value", + }, + "client": map[string]any{ + "local": "value", + "remote": "192.168.0.1:15432", + }, + "request": "Base64EncodedRequest", + "error": "", + }, + // Output parameters for the hook. + Result: map[string]any{ + "field": "value", + "server": map[string]any{ + "local": "value", + "remote": "value", + }, + "client": map[string]any{ + "local": "value", + "remote": "value", + }, + "request": "Base64EncodedRequest", + "error": "", + sdkAct.Signals: []any{ + sdkAct.Log("error", "error message", map[string]any{"key": "value"}).ToMap(), + }, + "response": "Base64EncodedResponse", + }, + } + + outputs := actRegistry.Apply( + []sdkAct.Signal{ + *sdkAct.Log( + "error", + "policy matched from incoming address 192.168.0.1, so we are seeing this error message", + map[string]any{"key": "value"}, + ), + }, + hook, + ) + assert.NotNil(t, outputs) + assert.Len(t, outputs, 1) + assert.Equal(t, "log", outputs[0].MatchedPolicy) + assert.Equal(t, outputs[0].Metadata, map[string]any{ + "key": "value", + "level": "error", + "log": true, + "message": "policy matched from incoming address 192.168.0.1, so we are seeing this error message", + }) + assert.False(t, outputs[0].Sync) // Asynchronous action. + assert.True(t, cast.ToBool(outputs[0].Verdict)) + assert.False(t, outputs[0].Terminal) + + result, err := actRegistry.Run(outputs[0], WithResult(hook.Result)) + assert.Equal(t, err, gerr.ErrAsyncAction, "expected async action sentinel error") + assert.Nil(t, result, "expected nil result") + + time.Sleep(time.Millisecond) // wait for async action to complete + + assert.Contains(t, buf.String(), `{"level":"error","key":"value","message":"policy matched from incoming address 192.168.0.1, so we are seeing this error message"}`) //nolint:lll +} + // Test_Run tests the Run function of the act registry with a non-terminal action. func Test_Run(t *testing.T) { logger := zerolog.Logger{} @@ -464,6 +604,11 @@ func Test_Run(t *testing.T) { outputs := actRegistry.Apply([]sdkAct.Signal{ *sdkAct.Passthrough(), + }, sdkAct.Hook{ + Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT", + Priority: 1000, + Params: map[string]any{}, + Result: map[string]any{}, }) assert.NotNil(t, outputs) @@ -489,6 +634,11 @@ func Test_Run_Terminate(t *testing.T) { outputs := actRegistry.Apply([]sdkAct.Signal{ *sdkAct.Terminate(), + }, sdkAct.Hook{ + Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT", + Priority: 1000, + Params: map[string]any{}, + Result: map[string]any{}, }) assert.NotNil(t, outputs) assert.Equal(t, "terminate", outputs[0].MatchedPolicy) @@ -522,6 +672,11 @@ func Test_Run_Async(t *testing.T) { outputs := actRegistry.Apply([]sdkAct.Signal{ *sdkAct.Log("info", "test", map[string]any{"async": true}), + }, sdkAct.Hook{ + Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT", + Priority: 1000, + Params: map[string]any{}, + Result: map[string]any{}, }) assert.NotNil(t, outputs) assert.Equal(t, "log", outputs[0].MatchedPolicy) @@ -647,7 +802,15 @@ func Test_Run_Timeout(t *testing.T) { }) assert.NotNil(t, actRegistry) - outputs := actRegistry.Apply([]sdkAct.Signal{*signals[name]}) + outputs := actRegistry.Apply( + []sdkAct.Signal{*signals[name]}, + sdkAct.Hook{ + Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT", + Priority: 1000, + Params: map[string]any{}, + Result: map[string]any{}, + }, + ) assert.NotNil(t, outputs) assert.Equal(t, name, outputs[0].MatchedPolicy) assert.Equal(t, diff --git a/go.mod b/go.mod index caed14f2..e062d121 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,7 @@ require ( github.com/codingsince1985/checksum v1.3.0 github.com/cybercyst/go-scaffold v0.0.0-20240404115540-744e601147cd github.com/envoyproxy/protoc-gen-validate v1.0.4 - github.com/gatewayd-io/gatewayd-plugin-sdk v0.2.13 + github.com/gatewayd-io/gatewayd-plugin-sdk v0.2.14 github.com/getsentry/sentry-go v0.27.0 github.com/go-co-op/gocron v1.37.0 github.com/google/go-github/v53 v53.2.0 diff --git a/go.sum b/go.sum index 0666a23e..665e5af0 100644 --- a/go.sum +++ b/go.sum @@ -161,8 +161,8 @@ github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7z github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= -github.com/gatewayd-io/gatewayd-plugin-sdk v0.2.13 h1:zjsMK6m/DwaD8vHmPDKhMyhUuWlRzF4Y8FO3hNmujZg= -github.com/gatewayd-io/gatewayd-plugin-sdk v0.2.13/go.mod h1:TN8dII/sN3awR0znv2vY25rhHLN9XyMTNnEIUWjioMk= +github.com/gatewayd-io/gatewayd-plugin-sdk v0.2.14 h1:h1lw4Hw0ugF+dOC2ytbEQA2exfZreLutR79+nBFLYSg= +github.com/gatewayd-io/gatewayd-plugin-sdk v0.2.14/go.mod h1:TN8dII/sN3awR0znv2vY25rhHLN9XyMTNnEIUWjioMk= github.com/getsentry/sentry-go v0.27.0 h1:Pv98CIbtB3LkMWmXi4Joa5OOcwbmnX88sF5qbK3r3Ps= github.com/getsentry/sentry-go v0.27.0/go.mod h1:lc76E2QywIyW8WuBnwl8Lc4bkmQH4+w1gwTf25trprY= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= diff --git a/network/proxy.go b/network/proxy.go index 3a6f47d8..69583fbc 100644 --- a/network/proxy.go +++ b/network/proxy.go @@ -856,31 +856,35 @@ func (pr *Proxy) shouldTerminate(result map[string]interface{}) (bool, map[strin // The Terminal field is only present if the action wants to terminate the request, // that is the `__terminal__` field is set in one of the outputs. keys := maps.Keys(result) - if slices.Contains(keys, sdkAct.Terminal) { - var actionResult map[string]interface{} - for _, output := range outputs { - actRes, err := pr.PluginRegistry.ActRegistry.Run( - output, act.WithResult(result)) - // If the action is async and we received a sentinel error, - // don't log the error. - if err != nil && !errors.Is(err, gerr.ErrAsyncAction) { - pr.Logger.Error().Err(err).Msg("Error running policy") - } - // The terminate action should return a map. - if v, ok := actRes.(map[string]interface{}); ok { - actionResult = v - } + terminate := slices.Contains(keys, sdkAct.Terminal) && cast.ToBool(result[sdkAct.Terminal]) + actionResult := make(map[string]interface{}) + for _, output := range outputs { + if !cast.ToBool(output.Verdict) { + pr.Logger.Debug().Msg( + "Skipping the action, because the verdict of the policy execution is false") + continue + } + actRes, err := pr.PluginRegistry.ActRegistry.Run( + output, act.WithResult(result)) + // If the action is async and we received a sentinel error, + // don't log the error. + if err != nil && !errors.Is(err, gerr.ErrAsyncAction) { + pr.Logger.Error().Err(err).Msg("Error running policy") } + // The terminate action should return a map. + if v, ok := actRes.(map[string]interface{}); ok { + actionResult = v + } + } + if terminate { pr.Logger.Debug().Fields( map[string]interface{}{ "function": "proxy.passthrough", "reason": "terminate", }, ).Msg("Terminating request") - return cast.ToBool(result[sdkAct.Terminal]), actionResult } - - return false, result + return terminate, actionResult } // getPluginModifiedRequest is a function that retrieves the modified request diff --git a/plugin/plugin_registry.go b/plugin/plugin_registry.go index 07c7ebc5..d50a0c47 100644 --- a/plugin/plugin_registry.go +++ b/plugin/plugin_registry.go @@ -20,6 +20,7 @@ import ( goplugin "github.com/hashicorp/go-plugin" "github.com/mitchellh/mapstructure" "github.com/rs/zerolog" + "github.com/spf13/cast" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" "google.golang.org/grpc" @@ -49,7 +50,7 @@ type IRegistry interface { Shutdown() LoadPlugins(ctx context.Context, plugins []config.Plugin, startTimeout time.Duration) RegisterHooks(ctx context.Context, pluginID sdkPlugin.Identifier) - Apply(hookName string, result *v1.Struct) ([]*sdkAct.Output, bool) + Apply(hook sdkAct.Hook) ([]*sdkAct.Output, bool) // Hook management IHook @@ -329,7 +330,14 @@ func (reg *Registry) Run( continue } - out, terminal := reg.Apply(hookName.String(), result) + out, terminal := reg.Apply( + sdkAct.Hook{ + Name: hookName.String(), + Priority: uint(priority), + Params: params.AsMap(), + Result: result.AsMap(), + }, + ) outputs = append(outputs, out...) if terminal { @@ -352,16 +360,16 @@ func (reg *Registry) Run( } // Apply applies policies to the result. -func (reg *Registry) Apply(hookName string, result *v1.Struct) ([]*sdkAct.Output, bool) { +func (reg *Registry) Apply(hook sdkAct.Hook) ([]*sdkAct.Output, bool) { _, span := otel.Tracer(config.TracerName).Start(reg.ctx, "Apply") defer span.End() // Get signals from the result. - signals := getSignals(result.AsMap()) + signals := getSignals(hook.Result) // Apply policies to the signals. // The outputs contain the verdicts of the policies and their metadata. // And using this list, the caller can take further actions. - outputs := applyPolicies(hookName, signals, reg.Logger, reg.ActRegistry) + outputs := applyPolicies(hook, signals, reg.Logger, reg.ActRegistry) // If no policies are found, return a default output. // Note: this should never happen, as the default policy is always loaded. @@ -373,7 +381,7 @@ func (reg *Registry) Apply(hookName string, result *v1.Struct) ([]*sdkAct.Output // Check if any of the policies have a terminal action. var terminal bool for _, output := range outputs { - if output.Verdict != nil && output.Terminal { + if output.Verdict != nil && cast.ToBool(output.Verdict) && output.Terminal { terminal = true break } diff --git a/plugin/utils.go b/plugin/utils.go index e11bda2f..d117b5f3 100644 --- a/plugin/utils.go +++ b/plugin/utils.go @@ -76,7 +76,10 @@ func getSignals(result map[string]any) []sdkAct.Signal { // applyPolicies applies the policies to the signals and returns the outputs. func applyPolicies( - hookName string, signals []sdkAct.Signal, logger zerolog.Logger, reg act.IRegistry, + hook sdkAct.Hook, + signals []sdkAct.Signal, + logger zerolog.Logger, + reg act.IRegistry, ) []*sdkAct.Output { signalNames := []string{} for _, signal := range signals { @@ -85,15 +88,15 @@ func applyPolicies( logger.Debug().Fields( map[string]interface{}{ - "hook": hookName, + "hook": hook.Name, "signals": signalNames, }, ).Msg("Detected signals from the plugin hook") - outputs := reg.Apply(signals) + outputs := reg.Apply(signals, hook) logger.Debug().Fields( map[string]interface{}{ - "hook": hookName, + "hook": hook.Name, "outputs": outputs, }, ).Msg("Applied policies to signals") diff --git a/plugin/utils_test.go b/plugin/utils_test.go index 690c492e..68e573e8 100644 --- a/plugin/utils_test.go +++ b/plugin/utils_test.go @@ -102,7 +102,16 @@ func Test_applyPolicies(t *testing.T) { }) output := applyPolicies( - "onTrafficFromClient", []sdkAct.Signal{*sdkAct.Passthrough()}, logger, actRegistry) + sdkAct.Hook{ + Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT", + Priority: 1000, + Params: map[string]any{}, + Result: map[string]any{}, + }, + []sdkAct.Signal{*sdkAct.Passthrough()}, + logger, + actRegistry, + ) assert.Len(t, output, 1) assert.Equal(t, "passthrough", output[0].MatchedPolicy) assert.Nil(t, output[0].Metadata)