Skip to content

Commit

Permalink
refactor: misc (#110)
Browse files Browse the repository at this point in the history
  • Loading branch information
eliottness authored Oct 2, 2024
1 parent c26efc5 commit 24e43d9
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 72 deletions.
18 changes: 11 additions & 7 deletions decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,18 @@ func decodeErrors(obj *bindings.WafObject) (map[string][]string, error) {
return wafErrors, nil
}

func decodeDiagnostics(obj *bindings.WafObject) (*Diagnostics, error) {
func decodeDiagnostics(obj *bindings.WafObject) (Diagnostics, error) {
if !obj.IsMap() {
return nil, errors.ErrInvalidObjectType
return Diagnostics{}, errors.ErrInvalidObjectType
}
if obj.Value == 0 && obj.NbEntries > 0 {
return nil, errors.ErrNilObjectPtr
return Diagnostics{}, errors.ErrNilObjectPtr
}

var diags Diagnostics
var err error
var (
diags Diagnostics
err error
)
for i := uint64(0); i < obj.NbEntries; i++ {
objElem := unsafe.CastWithOffset[bindings.WafObject](obj.Value, i)
key := unsafe.GostringSized(unsafe.Cast[byte](objElem.ParameterName), objElem.ParameterNameLength)
Expand All @@ -62,6 +64,8 @@ func decodeDiagnostics(obj *bindings.WafObject) (*Diagnostics, error) {
diags.Rules, err = decodeDiagnosticsEntry(objElem)
case "rules_data":
diags.RulesData, err = decodeDiagnosticsEntry(objElem)
case "exclusion_data":
diags.RulesData, err = decodeDiagnosticsEntry(objElem)
case "rules_override":
diags.RulesOverrides, err = decodeDiagnosticsEntry(objElem)
case "processors":
Expand All @@ -74,11 +78,11 @@ func decodeDiagnostics(obj *bindings.WafObject) (*Diagnostics, error) {
// ignore?
}
if err != nil {
return nil, err
return Diagnostics{}, err
}
}

return &diags, nil
return diags, nil
}

func decodeDiagnosticsEntry(obj *bindings.WafObject) (*DiagnosticEntry, error) {
Expand Down
27 changes: 13 additions & 14 deletions errors/waf.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,24 +33,23 @@ const (
ErrEmptyRuleAddresses
)

var errorStrMap = map[RunError]string{
ErrInternal: "internal waf error",
ErrInvalidObject: "invalid waf object",
ErrInvalidArgument: "invalid waf argument",
ErrTimeout: "waf timeout",
ErrOutOfMemory: "out of memory",
ErrEmptyRuleAddresses: "empty rule addresses",
}

// Error returns the string representation of the RunError.
func (e RunError) Error() string {
switch e {
case ErrInternal:
return "internal waf error"
case ErrTimeout:
return "waf timeout"
case ErrInvalidObject:
return "invalid waf object"
case ErrInvalidArgument:
return "invalid waf argument"
case ErrOutOfMemory:
return "out of memory"
case ErrEmptyRuleAddresses:
return "empty rule addresses"
default:
description, ok := errorStrMap[e]
if !ok {
return fmt.Sprintf("unknown waf error %d", e)
}

return description
}

// PanicError is an error type wrapping a recovered panic value that happened
Expand Down
52 changes: 16 additions & 36 deletions handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,39 +66,35 @@ func NewHandle(rules any, keyObfuscatorRegex string, valueObfuscatorRegex string
diagnosticsWafObj := new(bindings.WafObject)
defer wafLib.WafObjectFree(diagnosticsWafObj)

cHandle := wafLib.WafInit(obj, config, diagnosticsWafObj)
// Upon failure, the WAF may have produced some diagnostics to help signal what went wrong...
var (
diags = new(Diagnostics)
diagsErr error
)
if !diagnosticsWafObj.IsInvalid() {
diags, diagsErr = decodeDiagnostics(diagnosticsWafObj)
unsafe.KeepAlive(encoder.cgoRefs)

return newHandle(wafLib.WafInit(obj, config, diagnosticsWafObj), diagnosticsWafObj)
}

// newHandle creates a new Handle from a C handle (nullable) and a diagnostics object.
// and it handles the multiple ways a WAF initialization can fail.
func newHandle(cHandle bindings.WafHandle, diagnosticsWafObj *bindings.WafObject) (*Handle, error) {
diags, diagsErr := decodeDiagnostics(diagnosticsWafObj)
if cHandle == 0 && diagsErr != nil { // WAF Failed initialization and we manage to decode the diagnostics, return the diagnostics error
if err := diags.TopLevelError(); err != nil {
return nil, fmt.Errorf("could not instantiate the WAF: %w", err)
}
}

if cHandle == 0 {
// WAF Failed initialization, report the best possible error...
if diags != nil && diagsErr == nil {
// We were able to parse out some diagnostics from the WAF!
err = diags.TopLevelError()
if err != nil {
return nil, fmt.Errorf("could not instantiate the WAF: %w", err)
}
}
return nil, errors.New("could not instantiate the WAF")
}

// The WAF successfully initialized at this stage...
// The WAF successfully initialized at this stage but if the diagnostics decoding failed, we still need to cleanup
if diagsErr != nil {
wafLib.WafDestroy(cHandle)
return nil, fmt.Errorf("could not decode the WAF diagnostics: %w", diagsErr)
}

unsafe.KeepAlive(encoder.cgoRefs)

handle := &Handle{
cHandle: cHandle,
diagnostics: *diags,
diagnostics: diags,
}

handle.refCounter.Store(1) // We count the handle itself in the counter
Expand Down Expand Up @@ -170,25 +166,9 @@ func (handle *Handle) Update(newRules any) (*Handle, error) {
}

diagnosticsWafObj := new(bindings.WafObject)

cHandle := wafLib.WafUpdate(handle.cHandle, obj, diagnosticsWafObj)
unsafe.KeepAlive(encoder.cgoRefs)
if cHandle == 0 {
return nil, errors.New("could not update the WAF instance")
}

defer wafLib.WafObjectFree(diagnosticsWafObj)

if err != nil { // Something is very wrong
return nil, fmt.Errorf("could not decode the WAF ruleset errors: %w", err)
}

newHandle := &Handle{
cHandle: cHandle,
}

newHandle.refCounter.Store(1) // We count the handle itself in the counter
return newHandle, nil
return newHandle(wafLib.WafUpdate(handle.cHandle, obj, diagnosticsWafObj), diagnosticsWafObj)
}

// Close puts the handle in termination state, when all the contexts are closed the handle will be destroyed
Expand Down
7 changes: 4 additions & 3 deletions internal/bindings/safe.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
"github.com/pkg/errors"
)

func newPanicError(in func() error, err error) *wafErrors.PanicError {
func newPanicError(in any, err error) *wafErrors.PanicError {
return &wafErrors.PanicError{
In: runtime.FuncForPC(reflect.ValueOf(in).Pointer()).Name(),
Err: err,
Expand All @@ -24,7 +24,7 @@ func newPanicError(in func() error, err error) *wafErrors.PanicError {

// tryCall calls function `f` and recovers from any panic occurring while it
// executes, returning it in a `PanicError` object type.
func tryCall(f func() error) (err error) {
func tryCall[T any](f func() T) (res T, err error) {
defer func() {
r := recover()
if r == nil {
Expand All @@ -43,5 +43,6 @@ func tryCall(f func() error) (err error) {

err = newPanicError(f, err)
}()
return f()
res = f()
return
}
13 changes: 7 additions & 6 deletions internal/bindings/safe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@ package bindings

import (
"errors"
wafErrors "github.com/DataDog/go-libddwaf/v3/errors"
"strconv"
"testing"

wafErrors "github.com/DataDog/go-libddwaf/v3/errors"

"github.com/stretchr/testify/require"
)

Expand All @@ -21,7 +22,7 @@ func TestTryCall(t *testing.T) {
t.Run("panic", func(t *testing.T) {
t.Run("error", func(t *testing.T) {
// panic called with an error
err := tryCall(func() error {
_, err := tryCall(func() error {
panic(myPanicErr)
})
require.Error(t, err)
Expand All @@ -33,7 +34,7 @@ func TestTryCall(t *testing.T) {
t.Run("string", func(t *testing.T) {
// panic called with a string
str := "woops"
err := tryCall(func() error {
_, err := tryCall(func() error {
panic(str)
})
require.Error(t, err)
Expand All @@ -45,7 +46,7 @@ func TestTryCall(t *testing.T) {
t.Run("int", func(t *testing.T) {
// panic called with an int to cover the default fallback in tryCall
var i int64 = 42
err := tryCall(func() error {
_, err := tryCall(func() error {
panic(i)
})
require.Error(t, err)
Expand All @@ -56,14 +57,14 @@ func TestTryCall(t *testing.T) {
})

t.Run("error", func(t *testing.T) {
err := tryCall(func() error {
err, _ := tryCall(func() error {
return myErr
})
require.Equal(t, myErr, err)
})

t.Run("no error", func(t *testing.T) {
err := tryCall(func() error {
err, _ := tryCall(func() error {
return nil
})
require.NoError(t, err)
Expand Down
8 changes: 2 additions & 6 deletions internal/bindings/waf_dl.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (
"github.com/ebitengine/purego"
)

// wafDl is the type wrapper for all C calls to the waf
// WafDl is the type wrapper for all C calls to the waf
// It uses `libwaf` to make C calls
// All calls must go through this one-liner to be type safe
// since purego calls are not type safe
Expand Down Expand Up @@ -71,11 +71,7 @@ func NewWafDl() (dl *WafDl, err error) {
dl = &WafDl{symbols, handle}

// Try calling the waf to make sure everything is fine
err = tryCall(func() error {
dl.WafGetVersion()
return nil
})
if err != nil {
if _, err = tryCall(dl.WafGetVersion); err != nil {
if closeErr := purego.Dlclose(handle); closeErr != nil {
err = errors.Join(err, fmt.Errorf("error released the shared libddwaf library: %w", closeErr))
}
Expand Down
2 changes: 2 additions & 0 deletions waf.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ type Diagnostics struct {
Exclusions *DiagnosticEntry
RulesOverrides *DiagnosticEntry
RulesData *DiagnosticEntry
ExclusionData *DiagnosticEntry
Processors *DiagnosticEntry
Scanners *DiagnosticEntry
Version string
Expand All @@ -44,6 +45,7 @@ func (d *Diagnostics) TopLevelError() error {
"exclusions": d.Exclusions,
"rules_override": d.RulesOverrides,
"rules_data": d.RulesData,
"exclusion_data": d.ExclusionData,
"processors": d.Processors,
"scanners": d.Scanners,
}
Expand Down

0 comments on commit 24e43d9

Please sign in to comment.